@@ -1,7 +1,7 @@
 # %%
 import os
 from collections import OrderedDict, deque, namedtuple
-from typing import List, Tuple
+from typing import Iterator, List, Tuple
 
 import gym
 import numpy as np
@@ -99,7 +99,7 @@ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
         self.buffer = buffer
         self.sample_size = sample_size
 
-    def __iter__(self) -> Tuple:
+    def __iter__(self) -> Iterator[Tuple]:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
         for i in range(len(dones)):
             yield states[i], actions[i], rewards[i], dones[i], new_states[i]
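The annotation change above matters because `__iter__` contains a `yield`, which makes it a generator: it returns an iterator over tuples, not a tuple. A minimal standalone sketch of the same point (names are illustrative, not from the tutorial):

    from typing import Iterator, Tuple

    def sample_iter(n: int) -> Iterator[Tuple[int, int]]:
        # Any function body containing `yield` produces a generator, so the
        # correct return annotation is the iterator type, not the yielded type.
        for i in range(n):
            yield i, i * 2

    for state, target in sample_iter(3):
        print(state, target)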
@@ -247,7 +247,7 @@ def populate(self, steps: int = 1000) -> None:
         Args:
             steps: number of random steps to populate the buffer with
         """
-        for i in range(steps):
+        for _ in range(steps):
             self.agent.play_step(self.net, epsilon=1.0)
 
     def forward(self, x: Tensor) -> Tensor:
@@ -273,7 +273,7 @@ def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
         """
         states, actions, rewards, dones, next_states = batch
 
-        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
+        state_action_values = self.net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)
 
         with torch.no_grad():
             next_state_values = self.target_net(next_states).max(1)[0]
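The `.long()` cast above addresses a dtype requirement of `torch.gather`: index tensors must be int64, and actions coming back out of a numpy-backed replay buffer can arrive as floats. A standalone sketch of the failure mode and the fix (shapes and values are illustrative):

    import torch

    q_values = torch.randn(4, 2)                   # Q-values for 4 states, 2 actions
    actions = torch.tensor([0.0, 1.0, 1.0, 0.0])   # float dtype, as a numpy buffer may yield

    # q_values.gather(1, actions.unsqueeze(-1)) raises a RuntimeError here:
    # gather requires int64 indices, hence the .long() cast.
    chosen = q_values.gather(1, actions.long().unsqueeze(-1)).squeeze(-1)
    print(chosen)  # the Q-value of the taken action for each state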
@@ -284,6 +284,11 @@ def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
 
         return nn.MSELoss()(state_action_values, expected_state_action_values)
 
+    def get_epsilon(self, start: int, end: int, frames: int) -> float:
+        if self.global_step > frames:
+            return end
+        return start - (self.global_step / frames) * (start - end)
+
     def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
         """Carries out a single step through the environment to update the replay buffer. Then calculates loss
         based on the minibatch received.
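The new `get_epsilon` helper replaces the inline `max(...)` expression removed in the next hunk, where operator precedence divided only the literal `1` by `eps_last_frame`, so epsilon collapsed to `eps_end` almost immediately. The helper anneals epsilon linearly from `start` to `end` over `frames` global steps, then holds it at `end`. A standalone version with the step passed in explicitly, to show the schedule at its endpoints:

    # Linear annealing: decay from `start` to `end` over `frames` steps,
    # then hold at `end`. Mirrors get_epsilon above, with the global step
    # passed in as an argument so it runs standalone.
    def linear_epsilon(step: int, start: float, end: float, frames: int) -> float:
        if step > frames:
            return end
        return start - (step / frames) * (start - end)

    for step in (0, 500, 1000, 2000):
        print(step, linear_epsilon(step, start=1.0, end=0.02, frames=1000))
    # prints 1.0, 0.51, 0.02, 0.02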
@@ -296,14 +301,13 @@ def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
             Training loss and log metrics
         """
         device = self.get_device(batch)
-        epsilon = max(
-            self.hparams.eps_end,
-            self.hparams.eps_start - self.global_step + 1 / self.hparams.eps_last_frame,
-        )
+        epsilon = self.get_epsilon(self.hparams.eps_start, self.hparams.eps_end, self.hparams.eps_last_frame)
+        self.log("epsilon", epsilon)
 
         # step through environment with agent
         reward, done = self.agent.play_step(self.net, epsilon, device)
         self.episode_reward += reward
+        self.log("episode reward", self.episode_reward)
 
         # calculates training loss
         loss = self.dqn_mse_loss(batch)
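The two `self.log(...)` calls added above use PyTorch Lightning's built-in logging from inside `training_step`; each call records a scalar under the given key at every step. A minimal sketch of the same pattern, assuming pytorch_lightning is installed (the module and metric names are illustrative, not from the tutorial):

    import pytorch_lightning as pl
    import torch
    from torch import nn

    class LoggingSketch(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).pow(2).mean()
            # Scalars logged here are sent to the attached logger each step;
            # prog_bar=True also shows the value in the progress bar.
            self.log("train_loss", loss, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)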