JavaCVで機械学習する

投稿者: 加藤 淳
投稿日:
カテゴリ: programming

OpenCV には機械学習の実装がいくつか(単純ベイズ、k 近傍、SVM、決定木…)用意されていて、画像処理に限らず汎用目的で便利に使うことができます。実装ごとにクラス化されていて、学習(train)→ 予想(predict)という似たようなメソッドが用意されているため、学習結果がイマイチなら他に切り替える、ということも比較的簡単にできるようです。

この便利な機能は、Java 版ラッパーである JavaCV にもポーティングされています。そこで、とりあえず 5 次元ベクトル群を 2 つのクラスに分類するコードを書いてみました。下の画像では、5 次元中 2 次元を可視化しています。他のパラメタで実行してみた画面キャプチャやコード全体は記事末尾にあります。JavaCV のインストールなどに関しては前の記事「OpenCV を Java から使う」をどうぞ。

準備

// Parameters for the test
int numBins = 5;
int numSamples = 50;
int visualizationSize = 512;

numBins で次元、numSamples で学習に使うサンプルの数を決めます。visualizationSize は画面に表示するウィンドウの大きさです。

// Initialize the classifier.
opencv_ml.CvKNearest classifier = new opencv_ml.CvKNearest();
int k = 1;

まずは、k-Nearest Neighbor 法を使ってみることにします。

学習させる

// Initialize the training set.
opencv_core.CvMat trainingDataMat = CvMat.create(
    numSamples,
    numBins,
    opencv_core.CV_32FC1);
opencv_core.CvMat labelsMat = CvMat.create(
    numSamples,
    1,
    opencv_core.CV_32FC1);

行列 trainingDataMat にサンプルを、labelsMat にサンプルの分類を表す値を入れていきます。

double[] values = new double[numBins];
for (int i = 0; i < numSamples; i ++) {
    createTrainingData(values);

    // Be cautious, this shouldn't be put(i, values)!!
    trainingDataMat.put(i\*numBins, values);

    float category = categorize(values);
    labelsMat.put(i, category);
}

createTrainingData(double[] values)はランダムなサンプルの値を生成するメソッド、categorize(double[] values)は与えられたサンプルの分類(ground truth)を与えるメソッドです。

注意すべき点として、CvMat 型の put メソッドは第一引数に行列中で値を代入したい場所のインデックスを与えるのですが、例えば 5 行 10 列の行列に 10 列の配列データを代入するとして、3 行目に代入したいなら put(2, data)ではいけません。put(2 * 5, data)です。行列とはいっても中では普通の一次元配列として持っているようで、インデックスにはその一次元配列中での位置を指定しないといけないんですね。これで半日はまりました。

// Train the classifier.
try {
    classifier.train(trainingDataMat, labelsMat, null, false, k, false);
} catch (Exception e) {
    return;
}

CvKNearest::trainに書いてあるように引数を与えて、学習を行います。

新しいサンプルを分類する

// Predict and visualize the results for the test set.
opencv_core.CvMat predictionMat = CvMat.create( 1, numBins, opencv_core.CV_32FC1);
// 中略
createTrainingData(values);
// 中略
predictionMat.put(0, values);
// 中略
float label = classifier.find_nearest( predictionMat, k, null, null, null, null);

学習後は、新しいサンプルに対してCvKNearest::find_nearestを呼んで、既存のサンプルのなかで近いもの k 個の多数決で分類の予想値を取得できます。

他のアルゴリズムも試す

一種類できれば他のアルゴリズムに差し換えるのも簡単です。

// Initialize the classifier.
opencv_ml.CvSVM classifier = new opencv_ml.CvSVM();
opencv_ml.CvSVMParams params = new opencv_ml.CvSVMParams();
params.svm_type(opencv_ml.CvSVM.C_SVC);
params.kernel_type(opencv_ml.CvSVM.LINEAR);
params.term_crit(opencv_core.cvTermCriteria( opencv_core.CV_TERMCRIT_ITER, 1000, 0.001));
//(中略)
classifier.train(trainingDataMat, labelsMat, null, null, params);
//(中略)
opencv_core.CvMat resultMat = CvMat.create( visualizationSize \* visualizationSize, 1, opencv_core.CV_32FC1);
float result = classifier.predict(predictionMat, resultMat);
//(中略)
float label = (float) resultMat.get(x + y \* visualizationSize);

例えば、該当箇所をこのように書き換えれば SVM での機械学習ができます。k-Nearest よりも線引きがはっきりするアルゴリズムなので、実際、分布がはっきり分かれていますね。また、学習サンプルの個数があまり多くないので、線引きの傾きが ground truth と比べるとときどきズレるのも確認できました。

もっと試す

numBins を 2 にして 2 次元のデータで試しました。こっちのほうがアルゴリズムの違いが分かりやすいですね!

最近傍のサンプルを参照してクラス分けを行う k-Nearest Neighbor 法に…

ぱっくり線引きを行う SVM。

numSamples を増やしていくと、計算時間の差が歴然です。k-Nearest Neighbor 法は学習に全く時間がかからない代わりに分別にすごく時間がかかりますが、SVM は学習に少しだけ時間がかかるけど分別はすぐですね。

ソースコード全文

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.event.MouseEvent;
import java.awt.event.MouseAdapter;
import java.awt.image.BufferedImage;

import javax.swing.SwingUtilities;

import com.googlecode.javacv.CanvasFrame;
import com.googlecode.javacv.cpp.opencv_core;
import com.googlecode.javacv.cpp.opencv_core.CvMat;
import com.googlecode.javacv.cpp.opencv_ml;

public class MachineLearningKNN {
    public static void main(String[] args) throws Exception {
        new MachineLearningKNN();
    }

    // Parameters for the test
    int numBins = 5;
    int numSamples = 50;
    int visualizationSize = 512;

    CanvasFrame trainingFrame;

    public MachineLearningKNN() {
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                trainingFrame = new CanvasFrame("Machine learning");
                trainingFrame.setDefaultCloseOperation(CanvasFrame.EXIT_ON_CLOSE);
                trainingFrame.getCanvas().addMouseListener(new MouseAdapter() {
                    @Override
                    public void mouseClicked(MouseEvent arg0) { update(); }
                });
                update();
            }
        });
    }

    private void update() {

        // Prepare the window.
        BufferedImage image = new BufferedImage(
                visualizationSize, visualizationSize, BufferedImage.TYPE_4BYTE_ABGR);
        Graphics2D g2 = image.createGraphics();
        g2.setRenderingHint(
                RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_OFF);
        g2.setColor(Color.white);
        g2.fillRect(0, 0, image.getWidth(), image.getHeight());
        g2.dispose();
        trainingFrame.showImage(image);

        // Initialize the classifier.
        opencv_ml.CvKNearest classifier = new opencv_ml.CvKNearest();
        int k = 1;

        // Initialize the training set.
        opencv_core.CvMat trainingDataMat = CvMat.create(
                numSamples,
                numBins,
                opencv_core.CV_32FC1);
        opencv_core.CvMat labelsMat = CvMat.create(
                numSamples,
                1,
                opencv_core.CV_32FC1);
        double[] values = new double[numBins];
        for (int i = 0; i < numSamples; i ++) {
            createTrainingData(values);

            // Be cautious, this shouldn't be put(i, values)!!
            trainingDataMat.put(i*numBins, values);

            float category = categorize(values);
            labelsMat.put(i, category);
        }

        // Train the classifier.
        try {
            classifier.train(trainingDataMat, labelsMat, null, false, k, false);
        } catch (Exception e) {
            return;
        }

        // Predict and visualize the results for the test set.
        opencv_core.CvMat predictionMat = CvMat.create(
                1,
                numBins,
                opencv_core.CV_32FC1);
        Color green = new Color(25, 106, 115);
        Color orange = new Color(235, 162, 85);
        g2 = image.createGraphics();
        for (int x = 0; x < visualizationSize; x ++) {
            int ymax = numBins > 1 ? visualizationSize : 1;
            for (int y = 0; y < ymax; y ++) {
                createTrainingData(values);
                values[0] = (double) x / visualizationSize;
                if (numBins > 1) {
                    values[1] = (double) y / visualizationSize;
                }
                predictionMat.put(0, values);

                float label = classifier.find_nearest(
                        predictionMat, k, null, null,
                        null, null);
                g2.setColor(label == 1 ? green : orange);

                int sy, ey;
                if (numBins > 1) {
                    sy = ey = y;
                } else {
                    sy = 0; ey = visualizationSize - 1;
                }
                g2.drawLine(x, sy, x, ey);
            }
        }

        for (int i = 0; i < numSamples; i ++) {
            float label = (float) labelsMat.get(i);
            g2.setColor(label == 1 ? Color.white : Color.black);
            int x = (int) (trainingDataMat.get(i, 0) * visualizationSize);
            int y;
            if (numBins > 1) {
                y = (int) (trainingDataMat.get(i, 1) * visualizationSize);
            } else {
                y = visualizationSize / 2;
            }
            g2.fillOval(x, y, 5, 5);
        }
        g2.dispose();
        trainingFrame.showImage(image);
    }

    private void createTrainingData(double[] values) {
        for (int j = 0; j < numBins; j ++)
            values[j] = Math.random();
    }

    private float categorize(double[] params) {
        if (params.length == 1) return params[0] < 0.5 ? -1 : 1;
        return params[0] < params[1] ? -1 : 1;
    }
}