GAN(敵対的生成ネットワーク)のアルゴリズムを丁寧に解説!

こんにちは!
IT企業に勤めて、約2年間でデータサイエンティストになったごぼちゃん(@XB37q)です!

このコラムでは、GAN(敵対的生成ネットワーク)のアルゴリズムを紹介しています。
ディープラーニングを中心とした現代のAI技術は、予測や分類だけでなく、新たなデータを生成するといったことも可能になってきています。
生成するアルゴリズムの代表であるGANについて、アルゴリズムの紹介していきます!

GANの概要

GAN(Generative Adversarial Network)は敵対的生成ネットワークと呼ばれ、ディープラーニングを使用した教師無し学習です。
Generator(生成器)とDiscriminator(識別器)の2つのネットワークから構成され、両者のネットワークが互いに対立しながら学習を実施するアルゴリズムです。

Generator(生成器)は、潜在変数から本物データに似た偽データ(画像や音声など)を生成します。
Discriminator(識別器)は、本物データとGenerator(生成器)が生成した偽データの2種類のデータに対して、本物データかどうかの確率を出力します。
GANでは、2つのネットワークの学習を進め、最終的にGeneratorが本物データと区別がつかない偽物データを生成できるようにすることを目指します。

GANの使われ方

GANから発展した様々なアルゴリズムが考えられ、それらのアルゴリズムを使用して、画像の生成や変換、演算などを行うことが可能です。

高解像度の画像生成

画像を生成することによって、少ない画像データを水増しします。

例)医療系の画像など入手しにくいものをGANで生成して訓練データとして用いる

文章からの画像起こし(StackGAN)

絵の特徴を文章で語っただけで画像に変換します。

画像の演算

生成したイメージをベクトル演算します。

例)「眼鏡をかけた男」と「眼鏡をかけない男」の差分を求め、それを「眼鏡をかけない女」にベクトル加算すると「眼鏡をかけた女」が生み出される

ビデオ予測

画像をもとに、その数秒先までの動画を予測します。

例)自動運転車のモニターに着けることで、歩行者や自転車の動きを予測する

識別モデルと生成モデル

識別モデルとは

識別モデル
識別モデル

識別モデルはクラスの境界を学習するモデルです。
与えられたデータxが生成される確率は考慮せずに、境界面のみを学習します。
データxが与えられると、そのxがクラスyに属する確率が分かります。
上記の数理モデルから、クラスyに属する確率が1番高いyを予測値として採用します。

生成モデルとは

生成モデル
生成モデル

データを学習して、あるクラスに属する擬似的なデータを作ることが可能です。
生成モデルの学習の目標は、単純な確率分布で表現される潜在変数から、ある傾向を持ったデータへの変換を実施することです。
潜在変数とは直接観測できない値であり、データはある確率分布に従っている潜在変数から生成されているものと考えます。
全てのデータを表現しており、潜在変数を決定するとデータを生成することが可能です。
データはある数理モデルから生成されたと考え、その生成過程を数理的にモデル化したもの
数理モデルは確率分布によって表される
データxがクラスyに属する確率p(y|x)をベイズの定理を使用してモデル化する
p(y|x)をモデル化するのではなく、ベイズの定理を使用してp(x|y)p(y)をモデル化する
学習時(分類時)には分母p(x)は不要なため、分母を無視したものをモデル化する
p(x|y):データxがクラスyから生成される確率、この関数を使用してデータを生成する
p(y):データ全体でのクラスの割合(クラス1のデータが60%ならp(y)=0.6)

GANの学習方法

学習の全体像

GANの学習構造
GANの学習構造

GANでは、GeneratorとDiscriminatorが交互に学習することによって、本物のデータに近いデータを生成します。

Generatorの目的は、Discriminatorが本物データと間違う生成データを生成することです。
Discriminatorの目的は、本物データと生成データを見分けることです。
GANの学習は2つのネットワークを交互に行うが、どちらも先に学習させてもよいのが特徴です。
また、損失関数はともに、交差エントロピー(2つの分布を使用する関数)を用い、パラメータの更新は確率的勾配降下法で行います。

Discriminator(識別器)の学習

Discriminator(識別器)の学習
Discriminator(識別器)の学習

本物データと生成データを見分けるように学習します。
「Generatorが生成した生成データ」と「本物データ」を学習の入力データとして使用します。
Discriminatorの学習では、Generatorは生成のみを行い、学習は行いません。

Generator(生成器)の学習

Generator(生成器)の学習
Generator(生成器)の学習

Discriminatorが本物データと間違う生成データを生成するように学習します。
潜在空間からランダムサンプリングした「潜在変数」を学習の入力データとします。
Generatorの学習では、Discriminatorは識別のみを行い、学習は行いません。

損失関数

GANの損失関数
GANの損失関数

損失関数とは、モデルの精度を評価するための予測と実績値のずれの大きさを表す関数です。

Generatorの学習の目的は、偽物データを本物データとして、Discriminatorをだますことのため、損失関数を小さくすることが重要です。
そのため、D(G(z))=1であることが望ましいです。
本物データと似ているデータを生成できるようになると、 Discriminator がうまく判別できなくなるため、D(G(z))が大きくなり、log(1- D(G(z)))は小さくなります。

Discriminatorの学習の目的は、Generatorの出力と本物データを正しく見分ける事のため、損失関数を大きくすることが重要です。
そのため、D(G(z))=0 かつD(x)=1であることが望ましいです。
本物データだと判別可能な場合はD(x)が大きくなり、logD(x)も大きくなるため、偽物データだと判別可能な場合はD(G(z))が小さくなり、log(1- D(G(z)))は大きくなります。

また、損失関数はLogをとることにより、マイナス値に近づくほど、急激に小さくなるため学習スピードが速くなります。
交差エントロピーが分類問題やニューラルネットワークで使用されることが多い理由と同じです。正解データは1、偽物データ0は無視できます。

学習時の注意点

GANの学習を行う場合、注意する点が2点存在します。

mode collapse現象
mode collapse現象

1つ目は、学習の不安定さです。
GANが持つ構成の複雑さから発生し、生成されるデータに偏りが生じて、1種類のデータしか生成しなくなる「mode collapse」現象があります。
「mode collapse」発生の予見は難しく、回避するにはひたすら試行錯誤を重ね、パラメータやネットワーク構成の見直しが必要になります。
GeneratorがDiscriminatorが苦手とする特定のパターンを集中的に生成する方向に学習してしまう(Discriminatorが苦手を学習した場合、Generatorはもう一方のパターンを集中的に生成することでだまそうとする)ことにより、「mode collapse」が発生してしまいます。

生成データに対する評価の難しさ
生成データに対する評価の難しさ

2つ目は、生成データの評価が難しさです。
分類問題や将来の予測が対象であれば、定量的にその学習結果を評価することは可能ですが、生成モデルの場合は、生成された類似データを定量的な評価指標を用いて評価すること(何をもって「よい類似」とみなすか)は難しいです。
そのため学習データに精通した者が、実際に生成された画像を見て学習結果を評価することが望ましい精度検証となります。

まとめ

  • GAN(Generative Adversarial Network)
    • 敵対的生成ネットワーク
    • ディープラーニングを使用した教師無し学習
    • Generator(生成器)とDiscriminator(識別器)の2つのネットワークから構成される
    • 2つのネットワークを互いに対立させながら学習を実施する
    • Generatorが本物データと区別がつかない偽物データを生成できるようにすることを目指す
  • GANの注意点
    • 学習が不安定
    • 生成データに対する評価が難し
  • Generator(生成器)
    • 潜在変数から本物データに似た偽データ(画像や音声など)を生成する
    • D(G(z))=1であることが望ましい
    • Discriminatorが本物データと間違う生成データを生成するように学習
  • Discriminator(識別器)
    • 本物データとGenerator(生成器)が生成した偽データの2種類のデータに対して、本物データかどうかの確率を出力する
    • 本物データと生成データを見分けるように学習
    • D(G(z))=0 かつD(x)=1であることが望ましい