SB3 DQN サンプル - 交通流の強化学習ツール highway-env
強化学習におけるOpenAI Gymを拡張させたツールに、交通流を扱うhighway-envがある。その解説ページにあるSB3 DQNのサンプルについて説明する。
以下は2023/2時点での情報である。
SB3 DQNの意味と前処理
SB3はStable Baselines3というライブラリで、強化学習のエージェント機能を持つ。DQNは深層Q学習の略で、深層強化学習では一般的な手法である。
残念ながらSB3は、2023/1の時点で最新のGymに対応していない。よってバージョンの調整が必要になる。以下はバージョン調整を入れた前処理である。
Q学習は、Q関数と呼ばれる行動後の価値を元に学習を行うものである。DQNはそれを深層化したものである。
# Install environment and agent #!pip install highway-env ##### 修正点、全部バージョン調整 !pip install git+https://github.com/DLR-RM/stable-baselines3 !pip install 'gym==0.25.2' !pip install 'highway_env==1.6.0' # TODO: we use the bleeding edge version because the current stable version does not support the latest gym>=0.21 versions. Revert back to stable at the next SB3 release. #!pip install git+https://github.com/DLR-RM/stable-baselines3 # Environment import gym import highway_env # Agent from stable_baselines3 import DQN # Visualization utils %load_ext tensorboard import sys from tqdm.notebook import trange !pip install tensorboardx gym pyvirtualdisplay !apt-get install -y xvfb python-opengl ffmpeg !git clone https://github.com/eleurent/highway-env.git 2> /dev/null sys.path.insert(0, '/content/highway-env/scripts/') from utils import record_videos, show_videos
上記の「修正点」は私がソースを修正した箇所である。このページに書いてあるソースは、上から順に実行すると2023/1時点のColabで動作するようになっている。
TensorBoard起動
以下の処理でTensorBoardを起動する。TensorBoard上で計算途中の残差などを確認できる。
%tensorboard --logdir "highway_dqn"
学習
以下の処理で学習を行う。model.learnのパラメータを減らせば学習回数を減らして処理時間を短くできる。現在の値だとColab上で17分位かかる。
highway-fast-v0という環境は、流れの速い高速道路での車線変更である。
model = DQN('MlpPolicy', "highway-fast-v0", policy_kwargs=dict(net_arch=[256, 256]), learning_rate=5e-4, buffer_size=15000, learning_starts=200, batch_size=32, gamma=0.8, train_freq=1, gradient_steps=1, target_update_interval=50, exploration_fraction=0.7, verbose=1, tensorboard_log="highway_dqn/") model.learn(int(2e4))
テスト結果
以下の処理でテスト結果を動画で確認できる。実行するごとに結果は変わる。ある程度衝突を回避しているが、2回に1回位は衝突するように見える。
env = gym.make("highway-fast-v0") env = record_videos(env) for episode in trange(3, desc="Test episodes"): obs, done = env.reset(), False while not done: action, _ = model.predict(obs, deterministic=True) obs, reward, done, info = env.step(int(action)) env.close() show_videos()
目次
SB3 DQN サンプル