less than 1 minute read

変分ベイズの手順について、調べるたびに分かった気になっては、また必要になったら調べるということを繰り返しているので、ここにメモしておく。

細かい議論というよりは雰囲気。

問題設定

次のような確率モデルを考える。

\[x \sim p_\theta(x|z)p(z)\]

ただし、各変数の意味は以下の通りである。

  • データ\(x\) 。観測されている。
  • 潜在変数\(z\) 。値は未知。

データの観測分布

\[p_\theta (x|z)\]

は、パラメトリックなモデルで、未知パラメータ\(\theta\)はデータから推定したい。また、潜在変数の分布 \(p(z)\)も未知の分布でデータから決定したいが、こちらは(当面)ノンパラメトリックに一般の分布全体から選ぶ。

最尤推定

\(z\)は観測されていないので、データの尤度は\(z\)について積分して以下で与えられる。

\[p_\theta(x) = \int p_\theta(x|z) p(z) dz\]

これを、パラメータ\(\theta\)と確率分布\(p(z)\)を動かして最大化したい。

以下、天下りにEMアルゴリズムを導出する。

下限の評価

まず、jensenの不等式(凹関数を引数の平均値で評価すると、凹関数の値の平均値よりも大きいという性質。logが凹関数に注意)から

\[\log\ p_\theta(x) \\ = \log \int p_\theta(x|z) p(z) dz \\ \ge \int p(z) \log\ p_\theta(x|z) dz\]

が言える。この値を変分下限と呼ぶ。

EMアルゴリズムでは、対数尤度の代わりに変分下限を最大化する。元の\(\log p_\theta(x)\)を最大化しようとすると、対数を取る前に確率の値そのものを積分する必要があり扱いずらそう。一方で、変分下限の式では積分の対象が「確率の対数」になっている。 だいたいの場合、確率分布は\(\mathrm{定数}\ \exp(\mathrm{何か分かりやすい式})\)になっているので、確率の対数を取ると扱いやすくなることが多い。

下限の最大化

今、仮に\(z\)の分布\(p(z)\)の候補が何か与えられているとする(適当でよい)。もし、その\(p(z)\)に対してこの変分下限の積分が解析的にできると、\(\theta\)についての最大化を行うプログラムが書ける。例えば\(z\)が離散変数で積分が和に落ちる場合などは簡単に扱うことができる。 このステップが、EMアルゴリズムのMステップである。

今、適当に選んだ分布\(p(z)\)について下限を最大化して\(\theta\)を得たが、これだけでは\(p(z)\)の選択に恣意性が残る。これを改善したい。

下限と対数尤度の差

先の下限と真の対数尤度の差を計算すると

\[\log\ p_\theta(x) - \int p(z) \log\ p_\theta(x|z) dz \\ = \log\ p_\theta(x)\int p(z) dz - \int p(z) \log\ p_\theta(x|z) dz \\ = \int p(z) \log\ p_\theta(x) dz - \int p(z) \log\ p_\theta(x|z) dz \\ = \int p(z) \log\ p_\theta(x) dz - \int p(z) \log\ \frac{p_\theta(x, z)}{p(z)} dz \\ = \int p(z) \log\ p_\theta(x) dz - \int p(z) \log\ p_\theta(x, z) dz + \int p(z) \log p(z) dz \\ = \int p(z) \log\ \frac{1}{p_\theta(z|x)} dz + \int p(z) \log p(z) dz \\ = \int p(z) \log\ \frac{p(z)}{p_\theta(z|x)} dz\]

となる。無理矢理感もあるが、最後にKL情報量が出てくる。つまり以下が成り立つ。

\[\log\ p_\theta(x) - \int p(z) \log\ p_\theta(x|z) dz = \mathrm{KL}[p_\theta(z|x) || p(z)]\]

下限の改善

上の評価から、対数尤度と変分下限の食い違いはKL情報量で与えられている。KL情報量が最小になるのは両分布が一致するときなので、

\[p(z) := p_\theta(z|x)\]

によって、\(p(z)\)を更新する。これがEステップである。

式の中で陽に表現されていないが、対数尤度\(\log\ p_\theta(x)\)やその変分下限は実は\(z\)の事前分布\(p(z)\)に依存している。つまり、\(p(z)\)を更新すると、それに対する最適な\(\theta\)も変化する。つまり、再度Mステップを実施する余地がある。

というわけで、以下、M→E→M→・・と繰り返すことができる(実際には適当なところでやめる)。

想像

上の議論は「対数尤度を最大化したい」というモチベーションが背景にあるストーリーだが、実際には、変分下限が大事な量なのだと思われる。おそらく、単純に任意の事前分布\(p(z)\)を用いて対数尤度を最大化すると、値はいくらでも大きくなりうるのではなかろうか(つまり過適合する)。

変分下限は以下のように分解できる。

\[\int p(z) \log\ p_\theta(x|z) dz = \log\ p_\theta(x) - \mathrm{KL}[p_\theta(z|x) || p(z)]\]

つまり、第二項のKLダイバージェンスが罰則項のように効いている。これによって、\(p(z)\)が変なところにいかないように歯止めが利くのだろう。

VAE

VAEでは、損失関数として前に上げた変分下限を用いる。

EMアルゴリズムとは異なり、\(p(z)\)は適当な次元の標準正規分布で固定する。代わりに、

\[p_\theta(x|z), \ p_\phi(z|x)\]

を適当な深さのニューラルネットで表現できる範囲に制限し、変分下限を最大化する重みをバックプロパゲーションで求める。

VAEの実装

データ\(x\)を入力として、適当な次元の平均・分散ベクトルを計算するネットワークを定義する(エンコーダー)。エンコーダーからの出力と正規乱数から、潜在変数の事後分布からの実現値を得る。これを入力として、データを再現するネットワークを定義する(デコーダー)。

デコーダーの出力と入力\(x\)から尤度が計算できる。エンコーダーの出力(平均・分散ベクトル)からKLダイバージェンスが計算できる。これらによって、損失関数が計算できる。

VAEの良いところ

エンコーダーの出力(平均値)ではなく、正規乱数によってそこから少し乱された値がデコーダーの入力となる。これによって、出力が近くなるべき入力が潜在変数の空間でも近くなるようにエンコーダーが学習される。

エンコーダーの出力(平均と分散)と標準正規分布とのKLダイバージェンスを小さくするように学習される。これによって、潜在変数の空間の原点近くにエンコードの結果を集めて、できるだけ空間を満たすようにエンコーダーが学習される。

タグ:

カテゴリー:

更新日時: