Giter Club home page Giter Club logo

deep-reinforcement-learning-book's Introduction

Deep-Reinforcement-Learning-Book

書籍「つくりながら学ぶ!深層強化学習」、著者:株式会社電通国際情報サービス 小川雄太郎、出版社: マイナビ出版 (2018/6/28) のサポートリポジトリです。

最下部にFAQを追記しました(2019年3月24日最新)

最下部に正誤表を記載しております(2019年1月31日最新)。

図 ブロック崩しを攻略(A2Cを使用し、GPU1枚で3時間の学習後)

図 迷路をランダムに移動

図 迷路を強化学習

図 迷路内の各位置の価値を学習

図 倒立振子を制御

正誤表

[1] 初版:p. 46

パラメータθの更新量の式において、符号がマイナスであるべき部分がプラスになっていました。これに伴い以下3点の修正をお願いします。

[1-1] p. 46 式(2行目)
⊿θ_{s, a_j} = {N(s_i, a_j) + P(s_i, a_j) N(s_i, a)} / T
↓
⊿θ_{s, a_j} = {N(s_i, a_j) - P(s_i, a_j) N(s_i, a)} / T


[1-2] p. 47 コード(上段)
delta_theta[i, j] = (N_ij + pi[i, j] * N_i) / T
↓
delta_theta[i, j] = (N_ij - pi[i, j] * N_i) / T


[1-3] p. 48 コード
stop_epsilon = 10**-8
↓
stop_epsilon = 10**-4

[2] 4.3「PyTorchで手書き数字画像の分類課題MNISTを実装」

p.109、「mnist = fetch_mldata('MNIST original') が実行できない問題」への対処法。 該当ファイル:4_3_PyTorch_MNIST.ipynbを修正しました。

[3] 2.3「方策反復法の実装」

p.46、●方策反復法に従い方策を更新する、の2つの数式内の変数の添え字を訂正いたします。

⊿θ_{s, a_j}
↓
⊿θ_{s_i, a_j}

FAQ

7.4 A2C実装(前半)のクラスNet実装内の、value, actor_output = self(x)の動作について

7.4 A2Cの実装(前半)のClass Netの実装(初版ではp.217-219)において、

def act(self, x):、def get_value(self, x):、def evaluate_actions(self, x, action):

などの関数定義内で

value, actor_output = self(x)

というコードがあります。この部分のself(x)の解説補足です。

このself(x)は同じクラスNetのforward関数を実行しています。

def forward(self, x):のreturn を見ると

return critic_output, actor_output

となっていますが、これらが、self(x)のreturn値である、valueとactor_outputに対応しています。

なぜself(x)でforward()が実行されるのか補足します。

まずself()は自分自身を意味しますので、self(x)とはNet(x)を示します。

そしてクラスNetはnn.Moduleクラスを継承しています。

nn.Moduleクラスは__call__()というメソッドを持っており、その中で、forward()が実行されるように指定されています。

この__call__()というメソッドは、Pythonの一般的なメソッドです。

そのクラスのオブジェクトが具体的な関数を指定されずに呼び出されたときに動作する関数です。

よって、Net(x)の具体的なオブジェクトがあったとします。

例えば、

net = Net(n_in=4, n_mid=32, n_out=2)

でnetというクラスNetのオブジェクトができます。

ここで

net(x)

と、具体的な関数を指定せずに実行すると

netの__call__(x)が実行され、

この__call__()のなかにある、

net.forward(x)

が実行されることになります。

つまり、self(x) → Net(x) → Netの__call__(x) → Net.forward(x)

という関係になっています。

そのためself(x)によってforward(x)が実行されています。

deep-reinforcement-learning-book's People

Contributors

ogawa-yutaro avatar yutaroogawa avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

deep-reinforcement-learning-book's Issues

7_breakout_learning.ipynbでエラー

「7_breakout_learning.ipynb」の実行中に次のエラーが発生しました。
なんとなくmultiprocessingに関するエラーかなと思っています。

ちなみにライブラリのバージョンは次のとおりです。
gym==0.17.1
matplotlib==2.2.5
JSAnimation==0.1
pyglet==1.5.0
torch==0.4.1

---------------------------------------------------------------------------
ConnectionResetError                      Traceback (most recent call last)
<ipython-input-15-ed82eff94473> in <module>()
      1 # 実行
      2 breakout_env = Environment()
----> 3 breakout_env.run()

4 frames
/usr/lib/python3.7/multiprocessing/connection.py in _recv(self, size, read)
    377         remaining = size
    378         while remaining > 0:
--> 379             chunk = read(handle, remaining)
    380             n = len(chunk)
    381             if n == 0:

ConnectionResetError: [Errno 104] Connection reset by peer

3.2中,for step in range(0, 200): frames.append(env.render(mode='rgb_array')) 此句代码运行的时候出现错误TypeError: render() got an unexpected keyword argument 'mode',不知道要怎么改

TypeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_14388/2358979688.py in
6
7 for step in range(0, 200):
----> 8 frames.append(env.render(mode='rgb_array')) # framesに各時刻の画像を追加していく
9 action = np.random.choice(2) # 0(カートを左に押す), 1(カートを右に押す)をランダムに返す
10 observation, reward, done, info = env.step(action) # actionを実行する

D:\Anaconda\anaconda3\lib\site-packages\gym\core.py in render(self, *args, **kwargs)
327 ) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
328 """Renders the environment."""
--> 329 return self.env.render(*args, **kwargs)
330
331 def close(self):

D:\Anaconda\anaconda3\lib\site-packages\gym\wrappers\order_enforcing.py in render(self, *args, **kwargs)
49 "set disable_render_order_enforcing=True on the OrderEnforcer wrapper."
50 )
---> 51 return self.env.render(*args, **kwargs)
52
53 @Property

D:\Anaconda\anaconda3\lib\site-packages\gym\wrappers\env_checker.py in render(self, *args, **kwargs)
51 if self.checked_render is False:
52 self.checked_render = True
---> 53 return env_render_passive_checker(self.env, *args, **kwargs)
54 else:
55 return self.env.render(*args, **kwargs)

D:\Anaconda\anaconda3\lib\site-packages\gym\utils\passive_env_checker.py in env_render_passive_checker(env, *args, **kwargs)
314 )
315
--> 316 result = env.render(*args, **kwargs)
317
318 # TODO: Check that the result is correct

TypeError: render() got an unexpected keyword argument 'mode'

ライセンス表記について

今この教科書を読んでいるのですが、勉強過程で自分用に整理した参考実装を公開したいと思っております。何らかのOSSライセンスを明示して頂けると幸いです。

7章P.233の、current_obsの上書きでエラーが出ます

下の画像のようなエラーが発生したので、

current_obs[:, :-1] = current_obs[:, 1:] # 0~2番目に1~3番目を上書き

この部分を下のように書き換えたところ

current_obs[:, :-1] = current_obs[:, 1:].clone() # 0~2番目に1~3番目を上書き

正常に動作しました。参考まで。

方策勾配法での表記について質問です

教科書を買わせていただきました。とてもわかりやすくスムーズに学習が進んでおります。ありがとうございます。
一点質問なのですが
p.46の方策勾配法についての記述で、2箇所ある
⊿θ_{s, a_j}

⊿θ_{s_i, a_j}
ではないのかと思ったのですがどうでしょうか、?

HTML(anim.to_jshtml())によるエラー

本書にて学習させていただいております。
とても読みやすかったのですが、サンプルコードを実施するにあたり
下記問題がありましたので、ご報告させていただきます。

■以下該当箇所

2.2 迷路内をランダムに探索させる

# 初期化関数とフレームごとの描画関数を用いて動画を作成する
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(
    state_history), interval=200, repeat=False)

エラーメッセージ

HTML(anim.to_jshtml())
---> 26 HTML(anim.to_jshtml())

AttributeError: 'FuncAnimation' object has no attribute 'to_jshtml'

以上よろしくお願いいたします。

python str format

完全に寝ぼけて誤ったIssueを作ってしまいました…Close致します🙇

P.S. 著作、躓くこと無く読了できました。人にも薦めております。

第五章中,使用pytorch实现DQN时,显示ValueError: unknown file extension: .mp4

在运行第五章的代码时,显示无法正常识别.mp4文件。
第三章的倒立摆实验没有使用pytorch,可以正常生成视频,但是将第三章的代码放在第5章的环境中运行也不能正常生成视频。
由此推断视频的无法生成和pytorch环境无法识别视频有关,请问各位学者该问题有什么办法解决吗?

KeyError Traceback (most recent call last)
D:\Anaconda\anaconda3\envs\pytorch\lib\site-packages\PIL\Image.py in save(self, fp, format, **params)
2214 try:
-> 2215 format = EXTENSION[ext]
2216 except KeyError as e:

KeyError: '.mp4'

ValueError: unknown file extension: .mp4

第四章中,用了作者新改的代码,后还是报了错误,请问这是什么原因呀

File D:\Anaconda\anaconda3\lib\site-packages\pandas\core\indexes\base.py:3631, in Index.get_loc(self, key, method, tolerance)
3629 return self._engine.get_loc(casted_key)
3630 except KeyError as err:
-> 3631 raise KeyError(key) from err
3632 except TypeError:
3633 # If we have a listlike key, _check_indexing_error will raise
3634 # InvalidIndexError. Otherwise we fall through and re-raise
3635 # the TypeError.
3636 self._check_indexing_error(key)

KeyError: 0

fetch_mldataによるエラー

4章でmnistをダウンロードするためにsklearn のfetch_mldataを使っているようですが,mldata.orgが落ちているためか使うことができません.また復旧の見込みもないようです.(scikit-learn/scikit-learn#8588)
代わりにsklearn 0.20で追加された
fetch_openmlを使って(https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_openml.html#sklearn.datasets.fetch_openml)

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1,)

とするか,
scikit-learn/scikit-learn#8588 (comment)
などの対処法を使うように書いた方が良いのではないかと思います.

該当箇所
https://github.com/YutaroOgawa/Deep-Reinforcement-Learning-Book/blob/master/program/4_3_PyTorch_MNIST.ipynb


# 手書き数字の画像データMNISTをダウンロード

from sklearn.datasets import fetch_mldata

mnist = fetch_mldata('MNIST original', data_home=".")  # data_homeは保存先を指定します

3_2_try_CartPole.ipynb にて TypeError

お世話になっております.
当方の環境で 3_2_try_CartPole.ipynb を実行すると,下記のような TypeError が出て困っています.

環境

  • MacBook Pro(2019), OSX 10.15.5
  • conda 4.8.3
  • gym 0.17.2
  • matplotlib 2.2.5
  • JSAnimation 0.1
  • pyglet 1.2.4
  • ffmpeg 4.2.2

エラー

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-649a9cc1e6eb> in <module>
      6 
      7 for step in range(0, 200):
----> 8     frames.append(env.render(mode='rgb_array'))  # framesに各時刻の画像を追加していく
      9     action = np.random.choice(2)  # 0(カートを左に押す), 1(カートを右に押す)をランダムに返す
     10     observation, reward, done, info = env.step(action)  # actionを実行する

/opt/anaconda3/lib/python3.7/site-packages/gym/core.py in render(self, mode, **kwargs)
    231 
    232     def render(self, mode='human', **kwargs):
--> 233         return self.env.render(mode, **kwargs)
    234 
    235     def close(self):

/opt/anaconda3/lib/python3.7/site-packages/gym/envs/classic_control/cartpole.py in render(self, mode)
    211         self.poletrans.set_rotation(-x[2])
    212 
--> 213         return self.viewer.render(return_rgb_array=mode == 'rgb_array')
    214 
    215     def close(self):

/opt/anaconda3/lib/python3.7/site-packages/gym/envs/classic_control/rendering.py in render(self, return_rgb_array)
    113             buffer = pyglet.image.get_buffer_manager().get_color_buffer()
    114             image_data = buffer.get_image_data()
--> 115             arr = np.frombuffer(image_data.get_data(), dtype=np.uint8)
    116             # In https://github.com/openai/gym-http-api/issues/2, we
    117             # discovered that someone using Xmonad on Arch was having

TypeError: get_data() missing 2 required positional arguments: 'format' and 'pitch'

5.3 & 5.4にdeprecation warningが出ます。

macOSで実行してますが、

# main クラス
cartpole_env = Environment()
cartpole_env.run()

を実行すると、以下のようなwarningが出ます。

/Users/distiller/project/conda/conda-bld/pytorch_1573049287641/work/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

エントロピー項の計算について

A2C の損失関数は、Criticの損失を最小化し、方策損失とそのエントロピーを最大化するので、方策損失とそのエントロピーにはそれぞれマイナスを掛けて、損失全体は、
value_loss_coef - action_gain - entropy
という形になっていると思います.。しかし、Net#evaluate_actions()の中ではエントロピーには既にマイナスを掛けています。これは対数の計算によりマイナス(第4象限)になった値をひっくり返してプラスに変えるためでしょうか?

QテーブルからDQNのニューラルネットワークに変換する部分がよく分かりません。

小川先生、

いつも、丁寧にご返事いただき、ありがとうございます。
先生の本によると、Qテーブルからニューラルネットワークにした際は、P124にあるように、入力層には、状態を、出力層には、P123にあるように、各行動の行動価値関数が入り、バックプロパゲーションで、誤差関数を観測すると読みました。
いつも、ややかけ離れたような話で恐縮なのですが、以下のプログラムでは、tensorflowとtf_agentsを使っているのですが、
MyQNetworkクラスの入力に、observation_spec, action_spec,がはいっており、観測値と行動が入っているようなのですが、出力がどこで、同フィードバックがかかっているのかよくわからず、このオセロのプログラムの著者や、ネットの様々なところで質問を投げているのですが、どうにも分からない状態です。こちら、オセロのトレーニングをするためのプログラムなのですが、Qテーブルをニューラルネットに投げた際、入力が何で、出力が何なのか、お分かりになるでしょうか。
いつも、本の内容と少し違うことを質問してしまって大変恐縮です。しかし、このあたりがクリアにならないと、DQNのプログラムを書くことは難しいのではないかとも考えていて、藁をもすがる思いで質問させていただいております。
無理であれば、しかたがないので諦めます。
下記に、問題のコード部分を書きます。(tf_agentが独自に持つ関数があって、そこが不明なのかもしれません)

#####ここから

'''
リバーシプログラム:エージェント学習プログラム(CNN,DQNを利用)
Copyright(c) 2020 Koji Makino and Hiromitsu Nishizaki All Rights Reserved.
'''
import tensorflow as tf
from tensorflow import keras

from tf_agents.environments import gym_wrapper, py_environment, tf_py_environment
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import network
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.policies import policy_saver
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory, policy_step as ps
from tf_agents.specs import array_spec
from tf_agents.utils import common, nest_utils

import numpy as np
import random
import copy

SIZE = 4 # 盤面のサイズ SIZE*SIZE
NONE = 0 # 盤面のある座標にある石:なし
BLACK = 1# 盤面のある座標にある石:黒
WHITE = 2# 盤面のある座標にある石:白
STONE = {NONE:" ", BLACK:"●", WHITE:"○"}# 石の表示用
ROWLABEL = {'a':1, 'b':2, 'c':3, 'd':4, 'e':5, 'f':6, 'g':7, 'h':8} #ボードの横軸ラベル
N2L = ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] # 横軸ラベルの逆引き用
REWARD_WIN = 1 # 勝ったときの報酬
REWARD_LOSE = -1 # 負けたときの報酬
# 2次元のボード上での隣接8方向の定義(左から,上,右上,右,右下,下,左下,左,左上)
DIR = ((-1,0), (-1,1), (0,1), (1,1), (1,0), (1, -1), (0,-1), (-1,-1))
#シミュレータークラス
class Board(py_environment.PyEnvironment):  
  def __init__(self):
    super(Board, self).__init__()  
    self._observation_spec = array_spec.BoundedArraySpec(
      shape=(SIZE,SIZE,1), dtype=np.float32, minimum=0, maximum=2
    )
    self._action_spec = array_spec.BoundedArraySpec(
      shape=(), dtype=np.int32, minimum=0, maximum=SIZE*SIZE-1
    )
    self.reset()
  def observation_spec(self):
    return self._observation_spec
  def action_spec(self):
    return self._action_spec
#ボードの初期化
  def _reset(self):
    self.board = np.zeros((SIZE, SIZE, 1), dtype=np.float32) # 全ての石をクリア.ボードは2次元配列(i, j)で定義する.
    mid = SIZE // 2 # 真ん中の基準ポジション
    # 初期4つの石を配置
    self.board[mid, mid] = WHITE
    self.board[mid-1, mid-1] = WHITE
    self.board[mid-1, mid] = BLACK
    self.board[mid, mid-1] = BLACK
    self.winner = NONE # 勝者
    self.turn = random.choice([BLACK,WHITE])
    self.game_end = False # ゲーム終了チェックフラグ
    self.pss = 0 # パスチェック用フラグ.双方がパスをするとゲーム終了
    self.nofb = 0 # ボード上の黒石の数
    self.nofw = 0 # ボード上の白石の数
    self.available_pos = self.search_positions() # self.turnの石が置ける場所のリスト

    time_step = ts.restart(self.board)
    return nest_utils.batch_nested_array(time_step)
#行動による状態変化(石を置く&リバース処理)
  def _step(self, pos):
    pos = nest_utils.unbatch_nested_array(pos)
    pos = divmod(pos, SIZE)     #一次元座標を二次元に変換
    if self.is_available(pos):
      self.board[pos[0], pos[1]] = self.turn
      self.do_reverse(pos) # リバース
    self.end_check()#終了したかチェック
    time_step = ts.transition(self.board, reward=0, discount=1)
    return nest_utils.batch_nested_array(time_step)
#ターンチェンジ
  def change_turn(self, role=None):
    if role is NONE:
      role = random.choice([WHITE,BLACK])
    if role is None or role != self.turn:
      self.turn = WHITE if self.turn == BLACK else BLACK
      self.available_pos = self.search_positions() # 石が置ける場所を探索しておく
#ランダムに石を置く場所を決める(ε-greedy用)
  def random_action(self):
    if len(self.available_pos) > 0:
      pos = random.choice(self.available_pos) # 置く場所をランダムに決める
      pos = pos[0] * SIZE + pos[1] # 1次元座標に変換(NNの教師データは1次元でないといけない)
      return pos
    return False # 置く場所なし
#リバース処理
  def do_reverse(self, pos):
    for di, dj in DIR:
      opp = BLACK if self.turn == WHITE else WHITE # 対戦相手の石
      boardcopy = self.board.copy() # 一旦ボードをコピーする(copyを使わないと参照渡しになるので注意)
      i = pos[0]
      j = pos[1]
      flag = False # 挟み判定用フラグ
      while 0 <= i < SIZE and 0 <= j < SIZE: # (i,j)座標が盤面内に収まっている間繰り返す
        i += di # i座標(縦)をずらす
        j += dj # j座標(横)をずらす
        if 0 <= i < SIZE and 0 <= j < SIZE and boardcopy[i,j] == opp:  # 盤面に収まっており,かつ相手の石だったら
          flag = True
          boardcopy[i,j] = self.turn # 自分の石にひっくり返す
        elif not(0 <= i < SIZE and 0 <= j < SIZE) or (flag == False and boardcopy[i,j] != opp):
          break
        elif boardcopy[i,j] == self.turn and flag == True: # 自分と同じ色の石がくれば挟んでいるのでリバース処理を確定
          self.board = boardcopy.copy() # ボードを更新
          break

#石が置ける場所をリストアップする.石が置ける場所がなければ「パス」となる
  def search_positions(self):
    pos = []
    emp = np.where(self.board == 0) # 石が置かれていない場所を取得
    for i in range(emp[0].size): # 石が置かれていない全ての座標に対して
      p = (emp[0][i], emp[1][i]) # (i,j)座標に変換
      if self.is_available(p):
        pos.append(p) # 石が置ける場所の座標リストの生成
        #print(pos)
    return pos
#石が置けるかをチェックする
  def is_available(self, pos):
    if self.board[pos[0], pos[1]] != NONE: # 既に石が置いてあれば,置けない
      return False
    opp = BLACK if self.turn == WHITE else WHITE
    for di, dj in DIR: # 8方向の挟み(リバースできるか)チェック
      #print(di,dj)
      i = pos[0]
      j = pos[1]
      flag = False # 挟み判定用フラグ
      while 0 <= i < SIZE and 0 <= j < SIZE: # (i,j)座標が盤面内に収まっている間繰り返す
        i += di # i座標(縦)をずらす
        j += dj # j座標(横)をずらす
        if 0 <= i < SIZE and 0 <= j < SIZE and self.board[i,j] == opp: #盤面に収まっており,かつ相手の石だったら
          flag = True
        elif not(0 <= i < SIZE and 0 <= j < SIZE) or (flag == False and self.board[i,j] != opp) or self.board[i,j] == NONE:        
          break
        elif self.board[i,j] == self.turn and flag == True: # 自分と同じ色の石          
          return True
    return False
    
#ゲーム終了チェック
  def end_check(self):
    if np.count_nonzero(self.board) == SIZE * SIZE or self.pss == 2: # ボードに全て石が埋まるか,双方がパスがしたら
      self.game_end = True
      self.nofb = len(np.where(self.board==BLACK)[0])
      self.nofw = len(np.where(self.board==WHITE)[0])
      if self.nofb > self.nofw:
        self.winner = BLACK
      elif self.nofb < self.nofw:
        self.winner = WHITE
      else:
        self.winner = NONE
#ボードの表示(人間との対戦用)
  def show_board(self):
    print('  ', end='')      
    for i in range(1, SIZE + 1):
      print(f' {N2L[i]}', end='') # 横軸ラベル表示
    print('')
    for i in range(0, SIZE):
      print(f'{i+1:2d} ', end='')
      for j in range(0, SIZE):
        print(f'{STONE[int(self.board[i][j])]} ', end='') 
      print('')
#パスしたときの処理  
  def add_pass(self):
    self.pss += 1
#パスした後の処理  
  def clear_pass(self):
    self.pss = 0
  
  @property
  def batched(self):
    return True

  @property
  def batch_size(self):
    return 1
#ネットワークの設定
class MyQNetwork(network.Network):
  def __init__(self, observation_spec, action_spec, n_hidden_channels=256, name='QNetwork'):
    super(MyQNetwork,self).__init__(
      input_tensor_spec=observation_spec, 
      state_spec=(), 
      name=name
    )
    n_action = action_spec.maximum - action_spec.minimum + 1
    self.model = keras.Sequential(
      [
        keras.layers.Conv2D(4, 2, 1, activation='relu'),
        keras.layers.Conv2D(8, 2, 1, activation='relu'),
        keras.layers.Conv2D(16, 2, 1, activation='relu'),
        keras.layers.Dense(256, kernel_initializer='he_normal'),
        keras.layers.Flatten(),
        keras.layers.Dense(n_action, kernel_initializer='he_normal'),
      ]
    )
  def call(self, observation, step_type=None, network_state=(), training=True):
    observation = tf.cast(observation, tf.float32)
    actions = self.model(observation, training=training)
    return actions, network_state
#ランダム行動を行うときのポリシー
def random_policy_step(random_action_function):
  random_act = random_action_function()
  if random_act is not False:
    return ps.PolicyStep(
          action=tf.constant([random_act]),
          state=(),
          info=()
        )
  else:
    raise Exception("No position avaliable.")

def main():
#環境の設定
  env_py = Board()
  env = tf_py_environment.TFPyEnvironment(env_py)
#黒と白の2つを宣言するために先に宣言
  primary_network = {}
  agent = {}
  replay_buffer = {}
  iterator = {}
  policy = {}
  tf_policy_saver = {}

  n_step_update = 1
  for role in [BLACK, WHITE]:#黒と白のそれぞれの設定
#ネットワークの設定
    primary_network[role] = MyQNetwork(env.observation_spec(), env.action_spec())
#    print("obs is",env.observation_spec())
#    print("env is",env.action_spec())
#エージェントの設定
    agent[role] = dqn_agent.DqnAgent(
      env.time_step_spec(),
      env.action_spec(),
      q_network = primary_network[role],
      optimizer = keras.optimizers.Adam(learning_rate=1e-3),
      n_step_update = n_step_update,
      target_update_period=100,#0,
      gamma=0.99,
      train_step_counter = tf.Variable(0),
      epsilon_greedy = 0.0,
    )
    agent[role].initialize()
    agent[role].train = common.function(agent[role].train)
#行動の設定
    policy[role] = agent[role].collect_policy
#データの保存の設定
    replay_buffer[role] = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=agent[role].collect_data_spec,
      batch_size=env.batch_size,
      max_length=10**6,
    )
    dataset = replay_buffer[role].as_dataset(
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        sample_batch_size=16,
        num_steps=n_step_update+1,
      ).prefetch(tf.data.experimental.AUTOTUNE)
    iterator[role] = iter(dataset)
#ポリシーの保存設定
    tf_policy_saver[role] = policy_saver.PolicySaver(agent[role].policy)

  num_episodes = 100#0
  decay_episodes = 70#0
  epsilon = np.concatenate( [np.linspace(start=1.0, stop=0.1, num=decay_episodes),0.1 * np.ones(shape=(num_episodes-decay_episodes,)),],0)

  action_step_counter = 0
  replay_start_size = 100#0

  winner_counter = {BLACK:0, WHITE:0, NONE:0}#黒と白の勝った回数と引き分けの回数
  episode_average_loss = {BLACK:[], WHITE:[]}#黒と白の平均loss

  for episode in range(1, num_episodes + 1):
    policy[WHITE]._epsilon = epsilon[episode-1]#ε-greedy法用
    policy[BLACK]._epsilon = epsilon[episode-1]
    env.reset()

    rewards = {BLACK:0, WHITE:0}# 報酬リセット
    previous_time_step = {BLACK:None, WHITE:None}
    previous_policy_step = {BLACK:None, WHITE:None}

    while not env.game_end: # ゲームが終わるまで繰り返す
      if not env.available_pos:# 石が置けない場合はパス
        env.add_pass()
        env.end_check()
      else:# 石を置く処理
        current_time_step = env.current_time_step()
        while True: # 置ける場所が見つかるまで繰り返す
          if previous_time_step[env.turn] is None:#1手目は学習データを作らない
            pass
          else:
            previous_step_reward = tf.constant([rewards[env.turn],],dtype=tf.float32)
            current_time_step = current_time_step._replace(reward=previous_step_reward)

            traj = trajectory.from_transition( previous_time_step[env.turn], previous_policy_step[env.turn], current_time_step )#データの生成
            replay_buffer[env.turn].add_batch( traj )#データの保存

            if action_step_counter >= 2*replay_start_size:#事前データ作成用
              experience, _ = next(iterator[env.turn])
              loss_info = agent[env.turn].train(experience=experience)#学習
              episode_average_loss[env.turn].append(loss_info.loss.numpy())
            else:
              action_step_counter += 1
          if random.random() < epsilon[episode-1]:#ε-greedy法によるランダム動作
            policy_step = random_policy_step(env.random_action)#設定したランダムポリシー
          else:
            policy_step = policy[env.turn].action(current_time_step)#状態から行動の決定

          previous_time_step[env.turn] = current_time_step#1つ前の状態の保存
          previous_policy_step[env.turn] = policy_step#1つ前の行動の保存

          pos = policy_step.action.numpy()[0]

          pos = divmod(pos, SIZE) # 座標を2次元(i,j)に変換

          if env.is_available(pos):
            rewards[env.turn] = 0
            break
          else:
            rewards[env.turn] = REWARD_LOSE # 石が置けない場所であれば負の報酬                    
        
        env.step(policy_step.action)# 石を配置
        env.clear_pass() # 石が配置できた場合にはパスフラグをリセットしておく(双方が連続パスするとゲーム終了する)

      if env.game_end:#ゲーム終了時の処理
        if env.winner == BLACK:#黒が勝った場合
          rewards[BLACK] = REWARD_WIN  # 黒の勝ち報酬
          rewards[WHITE] = REWARD_LOSE # 白の負け報酬
          winner_counter[BLACK] += 1
        elif env.winner == WHITE:#白が勝った場合
          rewards[BLACK] = REWARD_LOSE
          rewards[WHITE] = REWARD_WIN
          
          winner_counter[WHITE] += 1
        else:#引き分けの場合
          winner_counter[NONE] += 1
        #エピソードを終了して学習
        final_time_step = env.current_time_step()#最後の状態の呼び出し
        for role in [WHITE, BLACK]:
          final_time_step = final_time_step._replace(step_type = tf.constant([2], dtype=tf.int32), reward = tf.constant([rewards[role]], dtype=tf.float32), )#最後の状態の報酬の変更
          traj = trajectory.from_transition( previous_time_step[role], previous_policy_step[role], final_time_step )#データの生成
          replay_buffer[role].add_batch( traj )#事前データ作成用
          if action_step_counter >= 2*replay_start_size:
            experience, _ = next(iterator[role])
            loss_info = agent[role].train(experience=experience)
            episode_average_loss[role].append(loss_info.loss.numpy())
      else:        
        env.change_turn()

    # 学習の進捗表示 (100エピソードごと)
    if episode % 100 == 0:      
      print(f'==== Episode {episode}: black win {winner_counter[BLACK]}, white win {winner_counter[WHITE]}, draw {winner_counter[NONE]} ====')
      if len(episode_average_loss[BLACK]) == 0:
        episode_average_loss[BLACK].append(0)
      print(f'<BLACK> AL: {np.mean(episode_average_loss[BLACK]):.4f}, PE:{policy[BLACK]._epsilon:.6f}')
      if len(episode_average_loss[WHITE]) == 0:
        episode_average_loss[WHITE].append(0)
      print(f'<WHITE> AL:{np.mean(episode_average_loss[WHITE]):.4f}, PE:{policy[WHITE]._epsilon:.6f}')
      # カウンタ変数の初期化      
      winner_counter = {BLACK:0, WHITE:0, NONE:0}
      episode_average_loss = {WHITE:[], BLACK:[]}

    if episode % (num_episodes//10) == 0:
      tf_policy_saver[BLACK].save(f"policy_black_{episode}")
      tf_policy_saver[WHITE].save(f"policy_white_{episode}")

if __name__ == '__main__':
  main()
####ここまで

入力ファイルを表形式で持った強化学習のアルゴリズムについて

小川先生

本の内容と直接関係がなく、大変恐縮なのですが、例えばDQNを報酬、行動、状態などをエクセルのような表形式で持たせて、DQNを行うようなPythonのアルゴリズムをご存じでしょうか。今のデータがそのような形になっており、いろいろと検索をしたり、聞いて回っているのですが、なかなか良い回答が得られず、恥ずかしながら、こちらで先生ならどのように考えられるか、少しご相談したいと思いました。
本と関係ないので、分からないということであれば、クローズで大丈夫です。
何卒、よろしくお願い申し上げます。

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.