AsaHP
AI・深層学習

SB3 DQN サンプル - 交通流の強化学習ツール highway-env

強化学習におけるOpenAI Gymを拡張させたツールに、交通流を扱うhighway-envがある。その解説ページにあるSB3 DQNのサンプルについて説明する。

以下は2023/2時点での情報である。

SB3 DQNの意味と前処理

SB3はStable Baselines3というライブラリで、強化学習のエージェント機能を持つ。DQNは深層Q学習の略で、深層強化学習では一般的な手法である。

残念ながらSB3は、2023/1の時点で最新のGymに対応していない。よってバージョンの調整が必要になる。以下はバージョン調整を入れた前処理である。

Q学習は、Q関数と呼ばれる行動後の価値を元に学習を行うものである。DQNはそれを深層化したものである。

# Install environment and agent
#!pip install highway-env
##### 修正点、全部バージョン調整
!pip install git+https://github.com/DLR-RM/stable-baselines3
!pip install 'gym==0.25.2'
!pip install 'highway_env==1.6.0'
# TODO: we use the bleeding edge version because the current stable version does not support the latest gym>=0.21 versions. Revert back to stable at the next SB3 release.
#!pip install git+https://github.com/DLR-RM/stable-baselines3

# Environment
import gym
import highway_env

# Agent
from stable_baselines3 import DQN

# Visualization utils
%load_ext tensorboard
import sys
from tqdm.notebook import trange
!pip install tensorboardx gym pyvirtualdisplay
!apt-get install -y xvfb python-opengl ffmpeg
!git clone https://github.com/eleurent/highway-env.git 2> /dev/null
sys.path.insert(0, '/content/highway-env/scripts/')
from utils import record_videos, show_videos

上記の「修正点」は私がソースを修正した箇所である。このページに書いてあるソースは、上から順に実行すると2023/1時点のColabで動作するようになっている。

TensorBoard起動

以下の処理でTensorBoardを起動する。TensorBoard上で計算途中の残差などを確認できる。

%tensorboard --logdir "highway_dqn"

学習

以下の処理で学習を行う。model.learnのパラメータを減らせば学習回数を減らして処理時間を短くできる。現在の値だとColab上で17分位かかる。

highway-fast-v0という環境は、流れの速い高速道路での車線変更である。

model = DQN('MlpPolicy', "highway-fast-v0",
                policy_kwargs=dict(net_arch=[256, 256]),
                learning_rate=5e-4,
                buffer_size=15000,
                learning_starts=200,
                batch_size=32,
                gamma=0.8,
                train_freq=1,
                gradient_steps=1,
                target_update_interval=50,
                exploration_fraction=0.7,
                verbose=1,
                tensorboard_log="highway_dqn/")
model.learn(int(2e4))

テスト結果

以下の処理でテスト結果を動画で確認できる。実行するごとに結果は変わる。ある程度衝突を回避しているが、2回に1回位は衝突するように見える。

env = gym.make("highway-fast-v0")
env = record_videos(env)
for episode in trange(3, desc="Test episodes"):
    obs, done = env.reset(), False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(int(action))
env.close()
show_videos()

目次

背景と概要

Getting Started

SB3 DQN サンプル

Highway Planning サンプル

Parking HER サンプル

Social Attention DQN サンプル

利用論文

AI・深層学習