FLINTERS Engineer's Blog

FLINTERSのエンジニアによる技術ブログ

ScalaでDeeplearning4jを使い自動運転で峠を攻める!

こんにちは。菅野です。

最近、AIとか機械学習とかが話題ですね。
AIに仕事を奪われる職業がどうとかの記事もよく見かけます。
このブログ記事もAIが書いてくれたら良いのにと思っている今日この頃です。

…でも思ってるだけでは仕事を奪ってくれないので、やっぱり何かしら自分で作るしか無さそう。
という訳で、今回はJavaディープラーニングが出来るDeeplearning4jを使って機械学習を試します!

プロジェクトD

さて、何を作りましょう?
最終的には私の仕事を勝手にやってくれるAIを作りたいです。
でも、はじめは簡単なものから少しずつ作っていこうと思います。
よくディープラーニングでネタにされるのは手書き文字の識別ですが、正直面白くないので道路上を自動運転するAIを作ります!

嘘です。
いきなり作るのは無理があるので、画像の道路が左カーブなのか、右カーブなのか、あるいは直線なのかを分類し画面に表示して自動運転する夢を見るアプリを作ろうと思います。
目指せ公道最速伝説。

Deeplearning4jとは

Deeplearning4jは名前の通りJava用のディープラーニングのライブラリです。 deeplearning4j.org

ディープラーニングとは、多層構造のニューラルネットワークを使用した機械学習の分野の名前です。
ニューラルネットワークは人間の脳の神経回路を参考にした、機械学習を行うための数学モデルです。
Deeplearning4jの公式サイトでも解説がされていて、なんと日本語化もされているので一読の価値があります!

機械学習分野だとPythonC++等が主流ですが、JavaScalaでアプリ作ってる人たちにとってはJavaアプリから学習したモデルをサクッと使えると便利ですよね。
Deeplearning4jはそんな人たち向けのライブラリだと思います。

他のメリットとしては、直接Sparkをサポートしていて分散環境上で簡単に動かせることが挙げられます。
ということはAmazon EMR上ですぐに動かせるということです!

また、CUDAを使いGPUを活用することが出来ます。
私はRadeon派なので手元のマシンでは使えないのですが!
でも、OpenCLにも対応することが予定されているので対応されたら私にとっては更にメリットが増えます。

いざ、データ収集の旅へ

今回行う道路形状の分類は機械学習の中でも「教師あり学習」と呼ばれる手法で行います。
当然、お手本となるデータが必要です。

仕方なく(?)、バイクに乗ってツーリングに出かけます。

f:id:zakknak:20170604225614j:plain ロケ地:ビーナスライン(長野県)

途中いろいろあるわけですが、割愛してまたPC上での作業に戻ります。

プログラムに読み込ませる用の教師データの作成に移ります。
まず、「0」「-1」「1」という名前でフォルダを3つ作ります。(名前自体に意味は無い)
f:id:zakknak:20170604232941p:plain

撮影した動画から一定間隔ごとにスナップショットを取り、
左カーブの画像を「-1」のフォルダ、右カーブの画像を「1」のフォルダ、直線の画像を「0」のフォルダに振り分けます。

-1
f:id:zakknak:20170604233041p:plain

1
f:id:zakknak:20170604233055p:plain

0
f:id:zakknak:20170604233105p:plain

全部で約14000枚の画像を用意してみました。
後に行う学習が速くなるように、予めグレースケール画像にしています。

さぁ、学習だ

教師用データが揃ったので、これから機械に学習してもらいます。

Deeplearning4jを読み込む

まず、Deeplearning4j(以下DL4J)を読み込むところからです。
私はScalaを使いたいのでbuild.sbtに依存ライブラリを書きます。

DL4Jは高速にベクトル演算を行うためのライブラリであるND4Jに依存しているのですが、
そのライブラリは演算を行うためにネイティブのライブラリを使用しています。

私が持っている中で最速のマシンはゲーム用であるWin機なので、今回はWindows環境でND4Jを動作させました。
やはりネイティブのdllを使用しているのが厄介で、DL4Jの現時点の最新バージョンは0.8.0なのですが
上手くネイティブのライブラリの依存が解決できなかったので今回は0.7.0を使用します。

ちなみにネイティブ部分以外にも悩ましいところがあり、DL4Jは0.4、0.6、0.7のバージョンアップで色々とインターフェースが変わっていて結構ヤンチャな開発が活発なライブラリです。

最終的な依存の指定は以下のとおりです。見慣れないオプションが必要になりますが、
ネイティブのdllはjarに含まれているので別途用意したりする必要はありません。そのまま読み込んでくれます!

javaOptions in run += """-Djava.library.path="""""
javaOptions in run += """-Djavacpp.platform="windows-x86_64""""

classpathTypes += "maven-plugin"

libraryDependencies ++= Seq(
  "org.deeplearning4j" % "deeplearning4j-core" % "0.7.0",
  "org.deeplearning4j" %% "deeplearning4j-ui" % "0.7.0",
  "org.nd4j" % "nd4j-native" % "0.7.0" classifier "" classifier "windows-x86_64",
  "org.bytedeco" % "javacv-platform" % "1.3.2"
)

教師データを学習する

DL4Jが使えるようになったら、学習の準備を始めます。

今回は画像認識なので「畳み込みニューラルネットワーク」というものを使用したいと思います。 deeplearning4j.org DL4Jの公式サイトに畳み込みネットワークの解説があるので、詳細はそちらを参照してください。

一からニューラルネットワークを設計するのはちょっと厳しいので、DL4Jのサンプルにある動物の識別のコードにある「AlexNet」と呼ばれる有名なネットワークの実装を参考に作ります。 github.com

DL4Jの各サンプルコードを継ぎ接ぎしたような感じになってしまったのですが、教師データを読み込んで学習するコードは以下のようになりました。

object BuildModel extends App {

  val nChannels = 1 // 画像のチャネル数(グレースケールなので1)
  val outputNum = 3 // ニューラルネットの出力の数(左カーブ、直線、右カーブの3つ)
  val batchSize = 128 // 教師データを一括処理する数
  val nEpochs = 10 // 繰り返し学習数
  val iterations = 2 // 一回の学習で行うイテレーション
  val seed = 123 // プログラムを再実行する場合にRandomが同じ値で固定されるようにするためのシード(適当な固定値)

  // 読み込んだ画像のサイズ(実サイズと違うならリサイズされて読み込まれる)
  val width = 100
  val height = 100

  val randNumGen = new Random(seed)

  // datavec:DL4J向けの教師データ読み込みライブラリ。
  // ParentPathLabelGeneratorは親のディレクトリの名前をラベルにしてデータをグループ分けする。
  // 今回は-1,0,1のラベルが付いたデータとしてImageRecordReaderを使ってそれぞれが読み込まれる。
  val labelGen = new ParentPathLabelGenerator()
  val fileSplit = new FileSplit(dataPath.toFile, BaseImageLoader.ALLOWED_FORMATS, randNumGen)
  val pathFilter = new BalancedPathFilter(randNumGen, BaseImageLoader.ALLOWED_FORMATS, labelGen)
  val Array(testInput, trainInput) = fileSplit.sample(pathFilter, 10, 90) // 教師用データの中から訓練データとテスト用データに分割する

  val trainDataReader = new ImageRecordReader(height, width, nChannels, labelGen)
  trainDataReader.initialize(trainInput)
  val trainData = new RecordReaderDataSetIterator(trainDataReader, batchSize)

  val testDataReader = new ImageRecordReader(height, width, nChannels, labelGen)
  testDataReader.initialize(testInput)
  val testData = new RecordReaderDataSetIterator(testDataReader, batchSize)


  val nonZeroBias = 1 // ネットワークのバイアス
  val dropOut = 0.5 // より学習を促進させるための係数

  // AlexNet
  val conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .weightInit(WeightInit.DISTRIBUTION)
    .dist(new NormalDistribution(0.0, 0.01))
    .activation("relu")
    .updater(Updater.NESTEROVS)
    .iterations(iterations)
    .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .learningRate(1e-2)
    .biasLearningRate(1e-2 * 2)
    .learningRateDecayPolicy(LearningRatePolicy.Step)
    .lrPolicyDecayRate(0.1)
    .lrPolicySteps(100000)
    .regularization(true)
    .l2(5 * 1e-4)
    .momentum(0.9)
    .miniBatch(false)
    .list()
    .layer(0, new ConvolutionLayer.Builder(Array(11, 11), Array(4, 4), Array(3, 3)).name("cnn1").nIn(nChannels).nOut(96).biasInit(0).build())
    .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
    .layer(2, new SubsamplingLayer.Builder(Array(3, 3), Array(2, 2)).name("maxpool1").build())
    .layer(3, new ConvolutionLayer.Builder(Array(5, 5), Array(1, 1), Array(2, 2)).name("cnn2").nOut(256).biasInit(nonZeroBias).build())
    .layer(4, new LocalResponseNormalization.Builder().name("lrn2").build())
    .layer(5, new SubsamplingLayer.Builder(Array(3, 3), Array(2, 2)).name("maxpool2").build())
    .layer(6, new ConvolutionLayer.Builder(Array(3, 3), Array(1, 1), Array(1, 1)).name("cnn3").nOut(384).biasInit(0).build())
    .layer(7, new ConvolutionLayer.Builder(Array(3, 3), Array(1, 1), Array(1, 1)).name("cnn4").nOut(384).biasInit(nonZeroBias).build())
    .layer(8, new ConvolutionLayer.Builder(Array(3, 3), Array(1, 1), Array(1, 1)).name("cnn5").nOut(256).biasInit(nonZeroBias).build())
    .layer(9, new SubsamplingLayer.Builder(Array(3, 3), Array(2, 2)).name("maxpool3").build())
    .layer(10, new DenseLayer.Builder().name("ffn1").nOut(4096).biasInit(nonZeroBias).dropOut(dropOut).dist(new GaussianDistribution(0, 0.005)).build())
    .layer(11, new DenseLayer.Builder().name("ffn2").nOut(4096).biasInit(nonZeroBias).dropOut(dropOut).dist(new GaussianDistribution(0, 0.005)).build())
    .layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
      .name("output")
      .nOut(outputNum)
      .activation("softmax")
      .build())
    .backprop(true)
    .pretrain(false)
    .setInputType(InputType.convolutional(height, width, nChannels))
    .build()

  val model = new MultiLayerNetwork(conf)
  model.init()

  // DL4J UIを使って学習状態をWeb画面から見られるようにする
  val uiServer = UIServer.getInstance()
  val statsStorage = new InMemoryStatsStorage()
  uiServer.attach(statsStorage)
  model.setListeners(new StatsListener(statsStorage))

  0 to nEpochs foreach { i =>
    model.fit(trainData)

    println(s"*** Completed epoch $i ***")
    println("Evaluate model....")
    val eval = new Evaluation(outputNum)
    while(testData.hasNext()){
      val ds = testData.next()
      val output = model.output(ds.getFeatureMatrix(), false)
      eval.eval(ds.getLabels(), output)
    }
    println(eval.stats())
    testData.reset();
  }

  // 学習済みのモデルをファイルに保存
  ModelSerializer.writeModel(model, new File(modelPath), false)
  println("saved")

}

情報量が多いと学習に時間がかかるので、入力画像は100x100のグレースケール画像にしています。
判定には十分だと思います。
出力パターンは左、直線、右の3つです。

このプログラムを実行すると、後は勝手にDL4Jがネットワークの学習を始めます。

deeplearning4j.org 前提知識がないと何のこっちゃとなりますが、リンクにあるニューラルネットワークの図の「input layer」が画像にあたり、「hidden layer」で判定されて「output layer(今回は-1,0,1の3つ)」のどれかが出力される仕組みで、
「hidden layer」にあたる箇所で

  • -1の画像が入力されたら-1が出力される
  • 0の画像が入力されたら0が出力される
  • 1の画像が入力されたら1が出力される

となるように、ひたすらコスト関数と呼ばれるものを調整する作業をします。

それと、途中で訓練データとテストデータを分けていますが、
このテストデータはその時点での学習の成果を確認するために使うデータです。
これを使って現在の学習状態での認識の精度を確認します。

学習はひたすら時間がかかります。その学習状態はDL4J UIで簡単に確認できます。
DL4J UIにはPlay frameworkが使われていて、localhostの9000番ポートにアクセスすれば見ることが出来ます。

f:id:zakknak:20170605011755p:plain そこまで嬉しくないですが、日本語化出来ます(笑)
モデルスコアが減少していくのが正常な状態です。


6時間後

学習完了

先程のプログラムは6時間で最後までいきました。
本来はどこで学習を終えるかは学習状態を見て判断するのですが、今回は雰囲気的に出来てそうであればそれで良しとします。

f:id:zakknak:20170605012835p:plain 6時間後はこのような感じです。もう少し効率よく学習できるのかもしれないです。

学習過程での精度は以下のようになりました。
(このラベル出力では0が左、1が直線、2が右になってしまっています…。)

学習の初期ではテストデータを入力して正解とは左右逆と判定する事案が5回発生しています。
これでは路上には出れません。

*** Completed epoch 3 ***
Evaluate model....

Examples labeled as 0 classified by model as 0: 309 times
Examples labeled as 0 classified by model as 1: 17 times
Examples labeled as 0 classified by model as 2: 3 times
Examples labeled as 1 classified by model as 0: 87 times
Examples labeled as 1 classified by model as 1: 207 times
Examples labeled as 1 classified by model as 2: 36 times
Examples labeled as 2 classified by model as 0: 2 times
Examples labeled as 2 classified by model as 1: 36 times
Examples labeled as 2 classified by model as 2: 292 times


==========================Scores========================================
 Accuracy:        0.817
 Precision:       0.8182
 Recall:          0.8171
 F1 Score:        0.8177
========================================================================

後半ではかなり改善しています。

*** Completed epoch 9 ***
Evaluate model....

Examples labeled as 0 classified by model as 0: 289 times
Examples labeled as 0 classified by model as 1: 39 times
Examples labeled as 0 classified by model as 2: 1 times
Examples labeled as 1 classified by model as 0: 28 times
Examples labeled as 1 classified by model as 1: 263 times
Examples labeled as 1 classified by model as 2: 39 times
Examples labeled as 2 classified by model as 1: 19 times
Examples labeled as 2 classified by model as 2: 311 times


==========================Scores========================================
 Accuracy:        0.8726
 Precision:       0.8723
 Recall:          0.8726
 F1 Score:        0.8725
========================================================================

最後にはテストデータでは左右逆に判定される事が無くなりました。

*** Completed epoch 10 ***
Evaluate model....

Examples labeled as 0 classified by model as 0: 295 times
Examples labeled as 0 classified by model as 1: 34 times
Examples labeled as 1 classified by model as 0: 26 times
Examples labeled as 1 classified by model as 1: 260 times
Examples labeled as 1 classified by model as 2: 44 times
Examples labeled as 2 classified by model as 1: 19 times
Examples labeled as 2 classified by model as 2: 311 times


==========================Scores========================================
 Accuracy:        0.8756
 Precision:       0.8752
 Recall:          0.8757
 F1 Score:        0.8754
========================================================================

とりあえずそれっぽい傾向を示す学習モデルが出来上がったので、路上に出ようと思います。

ちなみに機械学習に興味を持ち基礎から勉強したい場合は、Courseraで受講できるAndrew Ng先生の機械学習のコースはオススメです。 www.coursera.org

オレオレAIによる自動運転(風)

学習モデルは出来上がりましたが、実際にコイツに命を預けるのは不安しか無いので、
動画からスナップショットを切り出してその画像を判定させるアプリを作りました。

学習済みモデルの使い方

学習済みモデルを読み込んで、判定したい画像を判定させるやり方は以下のとおりです。

val model = ModelSerializer.restoreMultiLayerNetwork(modelFile) // 作成したファイルを読み込み
val loader = new NativeImageLoader(100, 100, 1) // 入力画像は学習したときと同じ100x100のグレースケール画像にする
val imageDataArr = loader.asMatrix(inputFile) // 画像を行列に変換してロード
val output = model.output(imageDataArr)  // 判定
// 今回は要素3の配列として結果が返ってくる。左、直線、右のそれぞれの出力値で、一番大きい値がニューラルネットワークとしての出力値になる。
// result: 0=左, 1=直線, 2=右
val result = Nd4j.argMax(output).getInt(0)

これを応用して作ったアプリが以下です。
走行風景の動画から0.5秒ごとにスナップショットを取り、学習したモデルを使って判定して結果を表示します。

余談ですが、ScalaFX便利です。JavaFXのラッパーですが、Javaで組み立てたり、fxmlなんか使うよりもすごく楽にGUIが作れます。

object AutoPilot extends JFXApp {

  // 動画の用意
  val mediaSrc = """movie.mp4"""
  val path = Paths.get(mediaSrc)
  val media = new Media(path.toUri.toString)
  val mediaPlayer = new MediaPlayer(media) {
    volume = 0.3
    onStalled = println("stalled")
  }

  val mediaView = new MediaView(mediaPlayer) {
    preserveRatio = true
    fitWidth = 1920 / 2
  }

  // モデルの用意
  val model = ModelSerializer.restoreMultiLayerNetwork(new File(modelPath))

  // 縮小用
  val param = new SnapshotParameters {
    transform = new Scale(100d / 1920 * 2, 100d / 1080 * 2)
  }
  val scaledWidth = 100
  val scaledHeight = 100
  val image = new WritableImage(scaledWidth, scaledHeight)
  val colorConvert = new ColorConvertOp(ColorSpace.getInstance(ColorSpace.CS_GRAY), null)

  val scheduler = new Timer()
  scheduler.schedule(new TimerTask {
    def run(): Unit = {
      Platform.runLater {
        mediaView.snapshot(param, image)
        val convertedImage = colorConvert.filter(SwingFXUtils.fromFXImage(image, null), new BufferedImage(scaledWidth, scaledHeight, BufferedImage.TYPE_BYTE_GRAY))
        SwingFXUtils.toFXImage(convertedImage, image)
        imageView.setImage(image)

        val baos = new ByteArrayOutputStream()
        ImageIO.write(convertedImage, "jpg", baos)
        val is = new ByteArrayInputStream(baos.toByteArray())
        val loader = new NativeImageLoader(100, 100, 1)
        val imageDataArr = loader.asMatrix(is)
        val output = model.output(imageDataArr)

        val result = Nd4j.argMax(output).getInt(0)
        val position = result - 1
        positionText.text = position.toString
        bar.data = ObservableBuffer(
          XYChart.Series[Number, String](
            "",
            ObservableBuffer(XYChart.Data[Number, String](position, ""))
          )
        )
      }
    }
  }, 5000, 500)

  override def stopApp() = {
    scheduler.cancel()
  }

  val timeText = new Text()
  val positionText = new Text()

  val playButton = new Button(text = "Play/Stop") {
    onAction = (e: ActionEvent) =>
      if (mediaPlayer.status.value == PLAYING) {
        mediaPlayer.pause()
      } else {
        mediaPlayer.play()
        mediaPlayer.seek(Duration.valueOf("18m"))
      }
  }

  val imageView = new ImageView() // 縮小画像確認用

  // 左右の位置表示用のグラフを描画する
  val axis = new NumberAxis {
    lowerBound = -1
    upperBound = 1
    autoRanging = false
    tickUnit = 1
  }
  val category = new CategoryAxis()
  val axisValue = ObservableBuffer(XYChart.Data[Number, String](0, ""))
  val bar = new BarChart(axis, category) {
    title = "Control"
    prefHeight = 20
    data = ObservableBuffer(
      XYChart.Series[Number, String](
        "",
        axisValue
      )
    )
  }

  // メインウインドウ
 stage = new PrimaryStage {
    title = "AutoPilot"
    width = 1000
    height = 1000
    scene = new Scene {
      stylesheets.add(this.getClass.getClassLoader.getResource("style.css").toString)
      content = new VBox {
        children = Seq(
          timeText,
          mediaView,
          bar,
          playButton,
          positionText,
          imageView
        )
      }
    }
  }
  
}

自動運転?

さて、ドキドキしながら動画を再生した結果は?

View post on imgur.com
i.imgur.com

判定結果が見やすいように、左右はオレンジのバーで表示しています。
それっぽい動きをしています! 途中、対向車に突っ込もうとしてるけど!

左下の表示は、実際にニューラルネットワークに入力している画像です。

View post on imgur.com
i.imgur.com 続いてこちらは教師データに多い、木の間を走る区間です。
ピッタリ意図したとおりの出力結果になっています。

View post on imgur.com
i.imgur.com こちらは教師データに含めなかった区間です。
今までの学習を元に、未知のデータについてもわりと分類できています。

View post on imgur.com
i.imgur.com 別の道でも試してみました。
画像そのものではなく道路と言うものの特徴(ってどんなのだ?)を学習しているため、
問題なく左右の判定ができています。途中、民家に突っ込もうとしてるけど!

それなりに判定できている箇所を載せましたが、全体では7,8割くらいは意図した通りの出力になりました。
たまに崖下ダイブを企てたり、交差点で左右に迷いだしたりする仮免以下の運転ですが、とりあえず作ったものでもそれなりに動くので正直驚きました。

おわり

Deeplearning4jを使ってとりあえず画像認識する手順を紹介しました。
機械学習は敷居が低くなくて、中でもDeeplearning4jは手軽じゃない方なのでとっつきにくいですが、
自分で作ったモデルが勝手に判定をする様子は見てて面白いです!

この分野は考えればさまざまなことに活用できると思うので、何かしらの機械学習ライブラリ、
もしくはH2O(https://www.h2o.ai/)のようなプログラミングが要らないアプリが使えるように勉強するのは良いかもしれません!

ここまで読んでいただいてありがとうございます。