GANからWasserstein GANへ

generative adversarial network(GAN)からWasserstein generative adversarial network(WGAN)への道の整理をします。 こちらを参考にしました:

目次

Kullback–Leibler Divergence (KL divergence) と Jensen–Shannon Divergence (JS divergence)

まず、確率密度関数の類似度をはかる2つの指標を導入します。

Kullback–Leibler Divergence

2つの確率密度関数p(x)q(x)を考えます。KL divergenceはpがqからどれだけ異なるか、をはかる指標です。

KL Divergence \displaystyle{
D _ {KL}(p||q) = \int _ x p(x) \log \frac{p(x)}{q(x)} dx
}

KL divergenceの性質

  • KL divergenceはpとqに対して非対称(D _ {KL}(p||q) \neq D _ {KL}(q||p))です。すなわち、距離として使うことはできません。
  • p(x)がほぼ0で、q(x)が0でない場所ではqの影響が無視されます。

Jensen–Shannon Divergence

JS divergenceは2つの確率密度関数の類似度をはかるもう一つの指標です。また、範囲は[0,1]です。

JS Divergence \displaystyle{
D _ {JS}(p||q) = \frac{1}{2} D _ {KL} \left( p||\frac{p+q}{2} \right) +  \frac{1}{2} D _ {KL} \left( q||\frac{p+q}{2} \right)
}

JS divergenceはpとqに関して対称です。GANではこちらのJS divergenceによって p _ {generator}(x)p _ {data}(x)の類似度を測ります。

GAN

GANは、現実のデータ集合が与えられたとき、それらに似たデータを生成することを目指します。

GANは2つのモデルからできています。

  • Discriminator(D) : Discriminatorは入力されたサンプルが現実のデータかどうかを識別する、2値分類器です。
  • Generator (G) : Generatorはノイズz \sim p(z)を入力として受け取り、人工的なデータを出力します。その際、現実のデータの分布と似た分布を学習します。つまり、Discriminatorを騙すような(人工的なデータではあるが、現実のデータだと識別させるような)データを生成することを目指します。

これらの2つのモデルが互いを見抜く・騙すように訓練されて、十分学習が進めばGeneratorが現実のデータと見分けがつかないようなデータを生成できるようになる、というわけです。欲しいのは良いGeneratorです。

f:id:daiki-yosky:20190424180631p:plain
GANの概観

ここで、

  • p _ z : ノイズzの分布(一様分布を使うことが多いです)
  • p _ g : Generatorが生成するデータの分布
  • p _ r : 現実のデータの分布
  • D(x) : Discriminatorが、入力されたデータxを実際のデータだと判断する確率
  • G(x) : Generatorが、入力されたノイズzから生成するデータ

とします。

GANの目的関数

まず、Discriminatorは現実のデータを正しく本物だと識別してほしいです。つまり、 $$ \mathbb{E} _ {x \sim p _ r(x)} \left[ \log D(x) \right] $$ を最大化したいです。一方で、Generatorが生成したデータ G(z)を正しく偽物だと識別して欲しいので、 $$ \mathbb{E} _ {z \sim p _ z(z)} \left[ \log \left( 1 - D(G(z) \right) \right] $$ を最大化して欲しいです。

次に、Generatorに関しては生成したデータをDiscriminatorが本物だと誤分類させたいので、 $$ \mathbb{E} _ {z \sim p _ z(z)} \left[ \log \left( 1 - D(G(z) \right) \right] $$ を最小化したいです。

これらを組み合わせると、以下のようなmin-max lossになります。

Loss of GAN $$ \min_G \max_D L(D,G) = \mathbb{E} _ {x \sim p _ r(x)} \left[ \log D(x) \right] + \mathbb{E} _ {z \sim p _ z(z)} \left[ \log \left( 1 - D(G(z) \right) \right] \\ = \mathbb{E} _ {x \sim p _ r(x)} \left[ \log D(x) \right] + \mathbb{E} _ {x \sim p _ g(x)} \left[ \log \left( 1 - D(x) \right) \right] \tag{1}$$

密度比推定との関連

Discriminatorの学習は、密度比推定と深い関係があります。密度比とは、2つの確率密度関数( p _ r(x)p _ g(x))の比で、

$$ r(x) = \frac{p _ r(x)}{p _ g(x)} $$

です。

密度比を推定する方法

現実のデータ集合に仮にラベル+1を割り当て、Generatorが生成したデータに仮にラベル-1を割り当てることにします。 この時、ラベルがgivenという条件下のもとでデータの分布を表すことができて、

$$ p _ r(x) = p (x | y = +1) $$ $$ p _ g(x) = p (x | y = -1) $$

です。

密度比は、ベイズの定理から

\displaystyle{
r(x) = \frac{p _ r(x)}{p _ g(x)} \\
= \frac{p (x | y = +1)} {p _ g(x) = p (x | y = -1) } \\
= \frac {p(y=-1)p (x , y = +1)} {p(y=+1)p (x , y = -1)}  \\
= \frac {p(y=-1)p (y = +1|x)} {p(y=+1)p (y = -1|x)}
}

となります。p (y = +1|x)p (y = -1|x)は任意の2値分類器で求めることができて、それはまさにDiscriminatorです。\frac{p(y=-1)}{p(y=-1)}はデータ数の比で近似出来ます。 Discriminatorの損失にはBinary Cross Entropyを用いればよくて、それを変形すると(1)の目的関数になります。 つまり、結果としてDiscriminatorの学習はp _ r(x)p _ g(x)の密度比を推定するように行われることになります。


似たようなことは以下の論文にも記述されています。

[1610.02920] Generative Adversarial Nets from a Density Ratio Estimation Perspective

こちらはDiscriminatorが密度比推定を行なっていることに注目し、f-divergenceを最小化するGANを提案しています。

Discriminatorの最適解

先ほどの目的関数を最大化するDiscriminatorの最適解をまず求めてみます。 L(G,D)は期待値の部分を書き直せば

$$ L(G,D) = \int \left( p _ {r} (x) \log(D(x)) + p _ {g}(x) \log(1 - D(x)) \right) dx $$

とかけます。今我々の興味はL(G,Dを最大化するようなD(x)なので、

$$ \hat{x}=D(x), A = p _ r(x), B = p _ {g}(x) $$とおきます。

すると、

$$ f(\hat{x}) = A\log \hat{x} + B \log (1- \hat{x}) $$ とかけて、\hat{x}について微分すれば

$$ \frac{d f(\hat{x})}{d\hat{x}} = \frac{A-(A+B)\hat{x}} {\hat{x} (1- \hat{x})} $$ となります。これを0とおくと、最適なD(x)

$$ D^{\ast}(x) = \frac{A}{A+B} = \frac{ p _ r(x)} { p _ r(x)+p _ {g}(x)} $$

になります。

さらに、Generatorが最適に学習すれば、p _ gp _ rに近しいものになり、p _ g = p _ rのような状況では D^{\ast}(x) = \frac{1}{2} になります。これは、完璧なGeneratorができれば、Discriminatorはもはや機能しなくなる、ということです。

What is global optimal?

DiscriminatorとGeneratorが最適な学習をするとp _ g = p _ rD^{\ast}(x) = \frac{1}{2}になることは上で確認しました。 この時、GAN のlossは、

$$ L(G^{\ast}, D^{\ast}) = \int \left( p _ {r} (x) \log(D^{\ast}(x)) + p _ {g}(x) \log(1 - D^{\ast}(x)) \right) dx \tag{2}\\ = \log \frac{1}{2} \int p _ {r} (x) dx + \log \frac{1}{2} \int p _ {g} (x) dx = -2 \log 2 $$

なお(2)は  L(G, D^{\ast})に対応します。

GANの目的関数が意味すること

p _ rp _ gの間のJS divergenceは、

\displaystyle{
D _ {JS}(p _ r||p _ g) = \frac{1}{2} D _ {KL} \left( p _ r||\frac{p _ r+p _ g}{2} \right) +  \frac{1}{2} D _ {KL} \left( p _ g||\frac{p _ r+p _ g}{2} \right) \\
= \frac{1}{2} \left(  \int p _ r(x) \log \frac{2p _ r(x)}{p _ r(x)+p _ g(x) }dx    \right)   +   \frac{1}{2} \left(  \int p _ g(x) \log \frac{2p _ g(x)}{p _ r(x)+p _ g(x) }dx    \right) \\
= \frac{1}{2} \left( \int p _ r(x) \log 2 dx +  \int p _ r(x) \log \frac{p _ r(x)}{p _ r(x)+p _ g(x) }dx    \right)   +   \frac{1}{2} \left(   \int p _ g(x) \log 2 dx + \int p _ g(x) \log \frac{p _ g(x)}{p _ r(x)+p _ g(x) }dx    \right) \\
= \frac{1}{2} \left( \log 4 + L(G, D^{\ast}) \right)
}

と変形できて、

 L(G, D^{\ast}) = 2  D _ {JS}(p _ r||p _ g) - 2 \log 2

と表せます。 この式から、Discriminatorが最適である時、GANの目的関数 L(G, D^{\ast})p _ dp _ gの間のJS divergenceを定量化します。なお、Generatorが最適である時、JS divergenceは0になって、 L(G^{\ast}, D^{\ast})=-2\log 2と一致します。

GANの問題点

  • ナッシュ均衡を達成するのが困難

  • low dimensional supports

  • 勾配消失

  • mode collapse

  • 適切な評価指標が存在しない

Wasserstein GAN (WGAN)

Wasserstein distance

Wasserstein distanceとは、JS divergenceと同じように2つの確率密度関数の距離をはかる指標です。Wasserstein distanceはEarth Mover's distanceとも呼ばれ、短くEM distanceと呼ばれることもあります。

Wasserstein distanceは、ある確率密度関数を動かしてもう一つの確率密度関数に一致させるときの最小コストです。 以下では、確率密度を「土」として表現し、「土」の最適な輸送としてWasserstein distanceを考えます。

2つの確率密度関数 p _ r p _ gのWasserstein distanceは以下のように与えられます。

Wasserstein distance $$ W(p _ r, p _ g) = \inf _ {\gamma \sim \Pi(p _ r, p _ g)} \mathbb{E} _ {(x,y) \sim \gamma} \left[ ||x - y || \right] $$

infは下限で、wasserstein distanceを求めること自体が最適化問題になっています。

\gamma(x,y)p _ rある地点xからp _ gある地点yに動かす土の量です。正確には地点xから、全土の量\int p _ r(x) dxのうちどれだけを地点yへ輸送するか、という量です。 土を動かし、 p _ r p _ gに一致させることから、直ちに

\displaystyle{
\sum _ {x} \gamma(x,y) = p _ g(y)
}

が成り立ちます。(地点yへ動かされた土の量をxについて和をとると動かし終わった土の量p _ g(y)と一致するはず)

逆に、

\displaystyle{
\sum _ {y} \gamma(x,y) = p _ r(x)
}

も成り立ちます。(地点xから動かされた土の量をyについて和を取るともともとxにあった土の量p _ r(x)と一致するはず)

土の量に動かす距離||x-y||をかけることでコスト\gamma(x,y)||x - y||を算出します。 全てのx,yについてコストの平均をとると、

\displaystyle{
\sum _ {x,y} \gamma(x,y) ||x - y|| = \mathbb{E} _ {x,y \sim \gamma} ||x-y||
}

候補となる土の動かし方戦略\gammaのうち、総コストがもっとも小さいものをとればwasserstein distanceが求まります。

Wasserstein GAN がJS divergenceとKL divergenceよりも良い理由

確率密度関数が低次元かつ2つの確率密度関数に重なりが場合でもWasserstein distanceはより滑らかな表現を提供してくれます。 例えば、以下のような2つの2次元の確率密度PQを考えます。Pのx成分は0に固定し、y成分は[0,1]の一様分布に従います。一方でQのx成分は\thetaに固定しyは[0,1]の一様分布に従います。

f:id:daiki-yosky:20190424220931p:plain
PとQの概観

\theta \neq 0の時

  • \displaystyle{D _ {KL}(P||Q) = \int _ {x=0, y \sim U(0,1)} P \log \frac{P}{Q} dxdy =  \infty}
  • \displaystyle{D _ {KL}(Q||P) = \int _ {x=\theta, y \sim U(0,1)} Q \log \frac{Q}{P} dxdy =  \infty}
  • \displaystyle{D _ {JS}(P,Q) =  \frac{1}{2} D _ {KL} \left( P||\frac{P+Q}{2} \right) +  \frac{1}{2} D _ {KL} \left( Q||\frac{Q+P}{2} \right)  = \log 2}
  • \displaystyle{ W(Q||P) = |\theta|}

一方\theta = 0の時、PとQはx=0で完全に重なっていて、

  • \displaystyle{D _ {KL}(P||Q) ={D _ {KL}(Q||P)} = {D _ {JS}(P,Q)}} = 0
  • \displaystyle{ W(Q||P) = 0 = |\theta|}

このように、KL divergenceは2つの確率密度に重なりがない場合\inftyに発散してしまいます。 JS divergenceは\theta=0で突然ジャンプし、微分不可能になってしまいます。 Wasserstein distanceは\thetaの変化に対して滑らかで、勾配降下法で学習する場合に安定すると考えられます。

GANの損失としてのWasserstein distance

Wasserstein distanceはKantorovich-Rubinstein双対性を使って、

$$ W(p _ r, p _ g) = \frac{1}{K} \sup _ {||f|| _ {L} \leq K} \mathbb{E} _ {x \sim p _ {r}} [f(x)] - \mathbb{E} _ {x \sim p _ {g}} [f(x)] $$

と変換することができます。

Lipschitz 連続性

Wasserstein distanceのfには、||f|| _ {L} \leq Kという制約がついています。つまりfはK-リプシッツ連続である必要があります。 関数f:\mathbb{R} \to \mathbb{R}は以下の条件を満たす時にK-リプシッツ連続です。

ある定数K \geq 0が存在して、全てのx _ {1}, x _ {2} \in \mathbb{R}に対して、 $$ |f(x _ {1} - f(x _ {2})| \leq K |x _ {1} - x _ {2}| $$ これは直感的には、任意の区間で傾きがある値Kで抑えられるということを意味します。(Kはリプシッツ定数と呼ばれます)

任意の場所で微分可能な関数はリプシッツ連続です。なぜなら\frac{|f(x _ {1} - f(x _ {2})|}{|x _ {1} - x _ {2}|}にはboundが存在するからです。 しかし、リプシッツ連続だからと言って任意の場所で微分可能である訳ではありません。例えば、f(x)=|x|は原点で微分不可能です。

Wasserstein loss

f _ wがパラメータwをもつK-リプシッツ関数とします。Wasserstein GANでは、Discriminatorは良いf _ wを求めます。WGANの損失としては p_r (現実のデータの分布)とp _ g (Generatorが生むデータの分布)間のWasserstein distanceを採用します。つまり、学習が進むにつれてGeneratorは現実のデータの分布に近いデータの分布を出力できるようになります。

Loss of Wasserstein GAN $$ L(p _ {r}, p _ {g}) = W(p _ {r}, p _ {g}) = \max _ {w \in W} \mathbb{E} _ {x \sim p _ {r}} [f _ {w}(x)] - \mathbb{E} _ {z \sim p _ {z}} [f _ {w}(g _ {\theta}(z))] $$

\inf\maxで近似されています。

WGAN全体としては、こちらのLossを最小化することを目指します。

ここで重要なのが、f _ wのK-リプシッツ性を維持する方法です。簡単かつ強力な方法として、重みwを更新した後、wを[-0.01, 0.01]といった小さな範囲でクリップします。 それにより、パラメータ空間Wは小さくなり、f _ wの傾きはboundで抑えられます。WGANの著者らは、clipingよりも良いK-リプシッツ性を維持する方法があるはずだ、とも述べています。

Wasserstein GANの学習

Wasserstein lossのGeneratorのパラメータ  \theta に関する微分は、

$$ \frac{\partial}{\partial \theta} L(p _ {r}, p _ {g}) = \frac{\partial}{\partial \theta} - \mathbb{E} _ {z \sim p _ {z}} [f _ {w}(g _ {\theta}(z))] $$

であり、こちらはサンプル近似によって

$$ \frac{\partial}{\partial \theta} - \mathbb{E} _ {z \sim p _ {z}} [f _ {w}(g _ {\theta}(z))] = \frac{1}{M} \sum_{m=1}^{M} \frac{\partial}{\partial \theta} - f _ {w}(g _ {\theta}(z_m)) $$

と近似できます。 Mはバッチサイズです。

よって、WGAN全体の学習は

  1. Discriminatorのパラメータ{\bf w}に関して、WGANのLossを微分し、Wasserstein distanceの良い近似を求めるように{\bf w}を更新する

  2. Discriminatorのパラメータ{\bf w}のリプシッツ連続性を保つため、クリッピングを行う

  3. Generatorのパラメータ  \theta に関して、Lossを微分し、Wasserstein distanceを小さくするように \theta を更新する

以上を繰り返します。