stable baselines3では、模倣学習のfeatureがimitationというライブラリに移譲されることになりました。
stable-baselines3.readthedocs.io
これにより、(過渡期である事も要因であるとは思いますが)以前は非常に簡単にできていた模倣学習に一手間必要になりました。
そこで今回は、stable baselines3とimitationを使った模倣学習の実行について備忘録を残しておきたいと思います(これを書いている時点では、ドキュメントがほぼなかったので・・)
エキスパートデータの収集
言わずもかな逆強化学習ではまずエキスパートデータを用意する必要があります。 今回はatariのインベーダー環境に対して、人手で操作してエキスパートデータを生成してみたいと思います
stable baselinesの2系ではstable_baselines.gail.generate_expert_traj
を使うことで簡単に軌跡の記録ができましたが、3系ではimitationのTrajectoryデータ型として記録していきます
1つのエピソードに対して1つのTrajectoryオブジェクトとなり、複数エピソードではTrajectoryのArrayとして保存することになります。
以下は、実装例です
from imitation.data.types import Trajectory from stable_baselines3.common.atari_wrappers import * import gym import pyglet from pyglet.window import key import time import pickle def get_key_state(win, key_handler): key_state = set() win.dispatch_events() for key_code, pressed in key_handler.items(): if pressed: key_state.add(key_code) return key_state def human_expert(_state, win, key_handler): key_state = get_key_state(win, key_handler) action = 0 if key.SPACE in key_state: action = 1 elif key.LEFT in key_state: action = 3 elif key.RIGHT in key_state: action = 4 time.sleep(1.0 / 30.0) return action def main(): record_episodes = 1 ENV_ID = 'SpaceInvaders-v0' env = gym.make(ENV_ID) env.render() win = pyglet.window.Window(width=300, height=100, vsync=False) key_handler = pyglet.window.key.KeyStateHandler() win.push_handlers(key_handler) pyglet.app.platform_event_loop.start() while len(get_key_state(win, key_handler)) == 0: time.sleep(1.0 / 30.0) trajectorys = [] for i in range(0, record_episodes): state = env.reset() actions = [] infos = [] observations = [state] while True: env.render() action = human_expert(state, win, key_handler) state, reward, done, info = env.step(action) actions.append(action) observations.append(state) infos.append(info) if done: ts = Trajectory(obs=np.array(observations), acts=np.array(actions), infos=np.array(infos)) trajectorys.append(ts) break with open("invader_expert.pickle", mode="wb") as f: pickle.dump(trajectorys, f) if __name__ == '__main__': main()
キー入力を受け取る部分の実装はnpakaさんの以下の記事を参考にさせて頂きました
インベーダーが下手すぎてエキスパートとは程遠いデータを収集することができました。
BCによる学習
こちらはimitationのquickstartと同様です。
https://github.com/HumanCompatibleAI/imitation/blob/master/examples/quickstart.py
with open("invader_expert.pickle", "rb") as f: trajectories = pickle.load(f) transitions = rollout.flatten_trajectories(trajectories) ENV_ID = 'SpaceInvaders-v0' venv = util.make_vec_env(ENV_ID, n_envs=2) logger.configure(".BC/") bc_trainer = bc.BC(venv.observation_space, venv.action_space, expert_data=transitions) bc_trainer.train(n_epochs=100) bc_trainer.save_policy('space_invader_policy_v0')
保存したpolicyは以下のようにロードができます
bc_trainer = bc.reconstruct_policy("space_invader_policy_v0")
BC→順強化学習とモデルの実行
まず、BCの学習により得たポリシーを使ってインベーダーをプレイさせてみます。 こちらはポリシーのpredictを使って簡単に実行が可能です。
def main(): ENV_ID = 'SpaceInvaders-v0' env = gym.make(ENV_ID) bc_trainer = bc.reconstruct_policy("space_invader_policy_v0") state = env.reset() while True: env.render() action = bc_trainer.predict(state) state, reward, done, info = env.step(action) if done: break if __name__ == '__main__': main()
次に、このポリシーをベースに順強化学習をさせることも可能です。
エキスパートのデータを渡して、初期の探索を手助けしてあげた後、順強化学習による最適化を進めたい・・というのが直感的なニーズですが、良いのか悪いのかは正直よくわかりません。
こちらは、単純に強化学習モデルクラスの第一引数であるPolicyに先ほどのものを入れて初期化→再学習すればよさそうなんですが、素直に入れると動きません。
そこで、ハック的なテクニックですが以下のようにする必要があります。
class CopyPolicy(ActorCriticPolicy): def __new__(cls, *args, **kwargs): return bc_trainer.policy model = sb3.PPO(CopyPolicy, venv, verbose=0) model.learn(total_timesteps=128000, callback=callback)
このアプローチは以下のissueで言及されています
AIRLによる学習
こちらも、imitationのquickstartに記載されている通りで大丈夫です
https://github.com/HumanCompatibleAI/imitation/blob/master/examples/quickstart.py
最後に.gen_algo.save
でPPOモデルを保存します。
ゲームの実行にはこのモデルが必要になります。
venv = util.make_vec_env(ENV_ID, n_envs=2) logger.configure(".AIRL/") airl_trainer = adversarial.AIRL( venv, expert_data=transitions, expert_batch_size=32, gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024), ) airl_trainer.train(total_timesteps=2048) airl_trainer.gen_algo.save("airl_trainer_gen_algo")
モデルの実行
PPO.loadで先ほど保存したモデルをロードしてきます。 後は特に違いはありません。
def main(): ENV_ID = 'SpaceInvaders-v0' env = gym.make(ENV_ID) model = PPO.load("airl_trainer_gen_algo") state = env.reset() while True: env.render() action = model.predict(state) state, reward, done, info = env.step(action) if done: break if __name__ == '__main__': main()