TF-Agentsによる強化学習
TensorFlowがバージョンアップして強化学習用のライブラリ TF-Agentsが使えるようになったようだ。fastaiは強化深層学習はサポートしないそうなので、RLがしたいときにはこれを使えば良い。
tensorflow/agents
TF-Agents is a library for Reinforcement Learning in TensorFlgithub.com
TF-Agents is a library for Reinforcement Learning in TensorFlgithub.com
ただプログラムはあまり綺麗ではなく、Pythonのバージョンも2のようだ。ChainerもRLに力を入れているようなので、比較して良い方を使うべきだろう。
SAC(Soft Actor Critic)などの新しめの手法も実装しているようで、色々比較して使いたいときにはTF-Agentsが良いだろう。
まずは色々インポートして(省略)から、定数パラメータを準備する。
env_name = 'CartPole-v0' # @param
num_iterations = 20000 # @param
initial_collect_steps = 1000 # @param
collect_steps_per_iteration = 1 # @param
replay_buffer_capacity = 100000 # @param
fc_layer_params = (100,)
batch_size = 64 # @param
learning_rate = 1e-3 # @param
log_interval = 200 # @param
num_eval_episodes = 10 # @param
eval_interval = 1000 # @param
Gymを使って倒立振子の環境を準備する。
env = suite_gym.load(env_name)
環境を観測するには以下のようにしてobservation属性をみる。
env.time_step_spec().observation
>>>
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name=None, minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
観測は、台車の現在位置と速度、振子の角度と速度である。
行動は0(左方向に動かす)と1(右方向に動かす)である。
env.action_spec()
>>>
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name=None, minimum=0, maximum=1)
Q値(状態と行動の組を与えると価値を返す関数)はニューラルネットであり、観測と行動と層パラメータを与えて構築する。
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
最適化はAdamとし、エージェントを生成し、初期化しておく。
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
tf_agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
train_step_counter=train_step_counter)
tf_agent.initialize()
次いで、方策を定義しておく。
eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
train_env.action_spec())
評価値(価値関数)を計算する関数を準備する。
def compute_avg_return(environment, policy, num_episodes=10):
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
return avg_return.numpy()[0]
compute_avg_return(eval_env, random_policy, num_eval_episodes)
モダンな強化学習はリプレイバッファを利用する。
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=tf_agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
最初にランダム方策を実行してリプレイバッファに溜めておく。これを用いて学習する。
def collect_step(environment, policy):
time_step = environment.current_time_step()
action_step = policy.action(time_step)
next_time_step = environment.step(action_step.action)
traj = trajectory.from_transition(time_step, action_step, next_time_step)
# Add trajectory to the replay buffer
replay_buffer.add_batch(traj)
for _ in range(initial_collect_steps):
collect_step(train_env, random_policy)
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3)
iterator = iter(dataset)
最後にエージェントを訓練する。
tf_agent.train = common.function(tf_agent.train)
# Reset the train step
tf_agent.train_step_counter.assign(0)
# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]
for _ in range(num_iterations):
# Collect one step using collect_policy and save to the replay buffer.
collect_step(train_env, tf_agent.collect_policy)
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = tf_agent.train(experience)
step = tf_agent.train_step_counter.numpy()
if step % log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss.loss))
if step % eval_interval == 0:
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step, avg_return))
returns.append(avg_return)
5分くらい回すと冒頭に示したように訓練が行われ、振子が倒れないようになる。