JKになりたい

何か書きたいことを書きます。主にWeb方面の技術系記事が多いかも。

PPOにまつわる備忘録

何の記事か

最近、HuggingFaceの強化学習チュートリアルをみてたんですよ。

で、PPOのこの最後の目的関数の意味がわからなかったんですよね。

これ。

 L_t^{CLIP+VF+S}(\theta) = \hat{E_t} \Big[ L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S\big[\pi_{\theta}(s_t) \big] \Big] (1)

各項は「クリップされた代理方策目的関数 - 価値関数の目的関数 + エントロピーボーナス」となっています。

本記事はこれを理解するための備忘録です。ただ、多分色々解釈間違ってるんでご指摘いただけると嬉しいです。


なぜ1つの目的関数でActorとCriticを更新できるの?

一般的なActorCriticでは、

Actorは  \Delta_{\theta} J(\theta)= E\big[ \sum_{t=0}^{T}  (R_t + \gamma V_w (S_t + 1) - V_w(S_t)) \Delta_\theta  \log{\pi_{\theta}}(A_t | S_t)   \big] (2)

Criticは \Delta_w =  R_t + \gamma V_w(S_{t+1}) - V_w(S_t) (3)

で最適化しましょう、となっていました。それぞれ目的関数が定義され、それぞれ最適化します。 (Criticで使われている上式はAdvantage Functionと呼ばれ、この記事の後でも取り扱います)

しかし(1)は、(2)(3)が一つの目的関数に結合されているような見た目になっていますね。

これで問題ないのは、偏微分される対象がActorとCriticで異なるため、片方の項について勾配がゼロになるためです。

更に、同じにしておくことでActorとCriticのネットワークを共有するような実装であっても目的関数の定義がそのまま使えるという実装上のメリットがあります。

ActorとCriticのネットワークを共有したいときって?

例えば、ピクセルを直接Stateとして入力に用いる際に前段のレイヤーとしてCNNを用意することはよくあります。

このとき、CNNのパラメータをActorとCriticでそれぞれ学習させたとしても似たようなものになると考えられます。そしたら、一緒でいいじゃん、みたいな。

実際、A3CではActorとCriticでネットワークを共有するような実装になっています。

pylessons.com

これで精度が担保されるのかはよくわからないですが、少なくとも最適化するべきパラメータが減る分、計算効率は格段に改善できそうですね。

目的関数が干渉しあって精度が悪化してしまう問題は存在するので、一長一短はあります。このあたりPPGと呼ばれる改善案も提案されていたりするので、興味のある方はこの論文を読むのが良いかも。

https://arxiv.org/pdf/2009.04416.pdf

目的関数の各項について

まず最初に。

E[-]は経験的に推定された期待値で、実際は複数のエピソードからなるミニバッチの平均値をとることを意味します。 (各式にハットがついてるのは、それが推定であることを示しています)

また、この目的関数は最大化することが目標ですが、実装上ではこれに-1をかけて最小化問題として解きます。(通常DNNライブラリはSGDしか対応してないので)

(1) クリップされた代理方策目的関数

 L^{CLIP}(\theta) =\hat E_t  \big[ min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t) \big]

ここで、Aは方策勾配法でよく用いられる「ベースライン」と呼ばれるアイデアを元にしたもので、アドバンテージ関数と呼ばれています。

 A_t = V_t^{target} - V_{\omega_{old}}(s_t)

式の通り、結果が推定値と比較してどの程度良かったかor悪かったかを表現しています。 こうすることで、分散が抑えられます。分散が抑えられるということは、サンプル効率が良くなるということですね。

rは比率関数と呼ばれ下記のように定義されます。

 r_t(\theta) = \frac{\pi_\theta (a_t | s_t)}{\pi_{\theta_{old}}(a_t | s_t)}

この値が1より大きい場合、現在のアクションが実行される可能性が高くなりますし、0~1であるならば、(古いポリシーと比較して)可能性が下がります。

比率にすることで、急激な更新を抑える役割があります。アドバンテージ関数同様、こちらもサンプル効率を向上させるための工夫ですね。

TRPOではKLダイバージェンスを利用してましたが、よりシンプルになりました。

ただし、前のポリシーと比較して、今のポリシーの方が可能性がすごく高くなる場合、大幅な勾配となってしまう恐れがあります。

最後に、これを防ぐために「クリップ」を行います。

で、最初の式になるわけですね。 原論文では比率関数は0.8 ~ 1.2の間にクリップされます。

 L^{CLIP}(\theta) =\hat E_t  \big[ min(r_t(\theta)\hat A_t, clip(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat A_t) \big]

クリップされた場合は、 (1 - \epsilon)\hat A_tもしくは (1 + \epsilon)\hat A_tとなるため、勾配はゼロになります。(=方策は更新されない)

まとめると、PPOのこの部分は

(1) ポリシーのパフォーマンスを低めに見積り、慎重に更新する

(2) 破壊的な大きな勾配での更新を避ける

ための工夫といえそうです。

(2) 価値関数の目的関数

価値観数の更新にまつわる部分です。以下式の添字ωがCriticのネットワークのパラメータですね。

 L^{V} = \hat E_t \big[ (V_{\omega}(s_t) - V_t^{target})^{2}  \big]

ミニバッチに含まれる複数の学習例を平均した二乗誤差を最小化することで学習が進みます。

こちらは特に言及することもないかと。

(3) エントロピーボーナス

最初はエントロピーボーナスってなんやねん!状態でした。

これは、(Action空間が離散空間である場合に限りますが)アクションをサンプリングする際の確率を平等に割り当てる方向に報酬を与えることにより、探索範囲を広げるための項です。

 L^{S}(\theta) = \hat{E_t} \Big[ c_2 S\big[\pi_{\theta}(s_t) \big] \Big]

つまり、あるStateに対してAction空間の正規化された確率推定値(0.25, 0.25, 0.25, 0.25)と(0.8, 0.05, 0.05, 0.1)を比較すると、前者に大きな報酬を与えたい、ということです。

これを実現するために、シャノンのエントロピー(平均情報量)を用います。

情報量の話は言及しません。シャノンのエントロピーの定義式は以下の通り。

 H(X) = - \sum_i p_i \log{p_i}

これを使うと、不確実性が高いときに大きな報酬を与えることができますね。

 L^{S}(\theta) = \hat{E_t} \Big[ - c_2 \sum_m^{M} \pi_{\theta}(a_m | s_t) \log{\pi_{\theta}(a_m | s_t)} \big] \Big]

※ MはAction空間の次元数になります

参考

https://fse.studenttheses.ub.rug.nl/25709/1/mAI_2021_BickD.pdf

The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track

第6回 今更だけど基礎から強化学習を勉強する PPO編 #Python - Qiita

A3CでCartPole (強化学習) - どこから見てもメンダコ