Kaggleでも使われる!弱学習器を組み合わせるアンサンブル学習

こんにちは!
IT企業に勤めて、約2年間でデータサイエンティストになったごぼちゃん(@XB37q)です!

このコラムでは、アンサンブル学習と呼ばれるAI手法を紹介します!
精度を向上させるために考えられた少し人間味のあるAIのため、考えた人は面白い発想だなと感じています。

アンサンブル学習とは

アンサンブル学習とは
アンサンブル学習とは

アンサンブル学習とは、弱学習器と呼ばれる分析手法を複数個組み合わせてAIモデルを構築する方法です。

「3人集まれば文殊の知恵」とあるように、1個の手法では弱くてもたくさんの分析手法を組み合わせれば強くなるだろうという考え方です!

弱学習器とは

弱学習器とは
弱学習器とは

弱学習器とは、単体では精度があまり出ないAIの分析手法のことです。
弱学習器と呼ぶ分析手法の定義は存在しません。
一般的には、構造がシンプルであり、計算量が少ない分析手法が弱学習器として呼ばれることが多いです。

決定木分析や単純なニューラルネットワークが弱学習器の代表です。

アンサンブル学習 バギング

バギング
バギング

アンサンブル学習には、バギングと呼ばれる方法があります。
バギングとは、並列に弱学習器を構築し、それらを合わせた結果を出力する学習方法です。
バギング(bagging)は、「bootstrap(ブートストラップ)」 と 「aggregating(集約)」から成り立っています。

バギング bootstrap(ブートストラップ)

ブートストラップ(bootstrap)
ブートストラップ(bootstrap)

ブートストラップとは、N個のデータから重複を許してランダムにn回データを選ぶことです。
少しだけ違うデータから、少しだけ違う弱学習器を作ることが可能です。

構築した弱学習器の違いは、使っているデータだけです。
そのため、弱学習器同士に相関が高くなることが起きる可能性が高く、あまり精度が向上しない可能性もあると言われています。
例えば、Nで100のデータのうち、ランダムに95回のデータを選んで作った決定木同士は相関が高くなる可能性が高いということです。

バギング aggregating(集約)

集約(Aggregating)
集約(Aggregating)


そして、それぞれの弱学習器の結果を集約して、結果を出力します。
並列に弱学習器を構築するため、モデル構築をするスピードは速いです。
しかし、精度の面ではブースティングより劣ることが多いと言われています。

ランダムフォレスト

弱学習器に決定木分析を使ったバギングにランダムフォレストが存在します。
上記バギングの特徴を持ちながら、分析に使用する説明変数をランダムに選定することで、過学習を防ごうとしています。

アンサンブル学習 ブースティング

ブースティング
ブースティング

アンサンブル学習には、ブースティングと呼ばれる方法があります。
ブースティングとは、直列に弱学習器を構築し、それらを合わせた結果を出力する分析手法です。
構築した弱学習器の弱点を克服するように、次の弱学習器を構築するため、バギングより精度が良いと言われています。

しかし、並列に弱学習器を構築しないため、学習に時間がかかると言われています。
また、弱学習器の数を増やしすぎると過学習する傾向にあります。

ブースティングの学習手順

ブースティングの学習の流れ
ブースティングの学習の流れ

ブースティングは誤差を学習するようにモデルを構築すると紹介しましたが、正確にはモデルを構築する度に、学習データに変化を与えています。
つまり、データに変化を与えることが、モデルをレベルアップして構築することになります。

モデルを構築していく手順は大きく区別して6つに区別できます。
5つ目の手順の「データの重みを更新する」ことが、「データに変化を与える」ということになります。

  1. 重みを初期化する
  2. モデルの数を設定する
  3. 誤差が最小になるようにモデルを構築する
  4. モデルの信頼度を計算する
  5. データの重みを更新する
  6. モデルの完成

ブースティングでは、この手順の3~5を繰り返しながら、学習を行います。

それぞれ数式を交えながら、紹介していきましょう!

初期値の設定

初期値の設定
初期値の設定

まずは、全てのデータを同様に扱うため、重みを初期化します。(1)

次に、モデルを学習させる回数、つまり構築するモデルの数を設定します。(2)

誤差が最小になるようにモデルを構築する

誤差が最小になるようにモデルを構築する
誤差が最小になるようにモデルを構築する

まずは誤差が最小になるように、1つの学習モデルを構築します。

I(y_m(x_i)≠t_i)は、0か1を出力する関数であり、正しく分類できた場合は0、正しく分類できなかった場合は1を出力します。
つまり、正しく分類できなかったデータが多い場合は、誤差E_Mが大きくなります。

モデルの信頼度を計算する

モデルの信頼度を計算する
モデルの信頼度を計算する

次に、モデルの信頼度を計算します。
信頼度は先ほどの誤差E_Mを使用して計算を行います。

モデルの信頼度が、モデルの予測値をどれだけ信頼するかといった値です。

データの重みを更新する

データの重みを更新する
データの重みを更新する

モデルの構築と信頼度の計算が終わった後に、データの重みを更新します。
ここで、正しく分類できたデータは重みを更新せず、正しく分類できなかったデータに対して重みを更新します。
つまり、誤差を無くすようにモデルを構築していくために、データに変化を加えていくのです。

そして、(3)~(5)を繰り返して、モデルを構築していきます。

モデルの完成

ブースティングの完成
ブースティングの完成

モデルをすべて作成し終えた段階で、ブースティングの学習は終了です。
複数のモデルを合わせた最終的な予測値をY_mとすると、Y_mはsin(全てのモデルの信頼度×モデルの予測値)になります。

また、sinは0か1を出力する関数であり、以下を表します。

  • sinは、0より大きければ+1
  • sinは、0より小さければ-1

この式からもわかるとおり、信頼度が高いモデルの予測値が重視される傾向にあります。

XGBoostとLightGBM

ブースティングを使うためのpythonライブラリには、XGBoostとLightGBMが有名です。

XGBoostは決定木の葉を同時に展開していく手法です。
LightGBMは1つの幹を完璧に展開していくため、早く構築できる手法です。

詳しくは他コラムで紹介しています!

まとめ

この投稿では、アンサンブル学習と呼ばれるAI手法を紹介しました!

  • アンサンブル学習
    • 弱学習器を組み合わせる方法
  • 弱学習器
    • 単体では精度が高くなりにくい分析手法
    • 決定木分析、ニューラルネットワークが代表
  • バギング
    • 並列に弱学習器を構築
    • 学習速度が速い
    • 精度が低い
  • ブースティング
    • 直列に弱学習器を構築
    • 学習速度が遅い
    • 精度が高いが、過学習する傾向がある