@@ -282,12 +282,21 @@ def rot(self, rot: Tensor):
282282 self ._rot = rot .to (self ._device )
283283
284284 def _reset (self , env_index : typing .Optional [int ]):
285- for attr in [self .pos , self .rot , self .vel , self .ang_vel ]:
285+ for attr_name in ["pos" , "rot" , "vel" , "ang_vel" ]:
286+ attr = self .__getattribute__ (attr_name )
286287 if attr is not None :
287288 if env_index is None :
288- attr [:] = 0.0
289+ self . __setattr__ ( attr_name , torch . zeros_like ( attr ))
289290 else :
290- attr [env_index ] = 0.0
291+ self .__setattr__ (
292+ attr_name , TorchUtils .where_from_index (env_index , 0 , attr )
293+ )
294+
295+ def zero_grad (self ):
296+ for attr_name in ["pos" , "rot" , "vel" , "ang_vel" ]:
297+ attr = self .__getattribute__ (attr_name )
298+ if attr is not None :
299+ self .__setattr__ (attr_name , attr .detach ())
291300
292301 def _spawn (self , dim_c : int , dim_p : int ):
293302 self .pos = torch .zeros (
@@ -363,14 +372,25 @@ def torque(self, value):
363372
364373 @override (EntityState )
365374 def _reset (self , env_index : typing .Optional [int ]):
366- for attr in [self .c , self .force , self .torque ]:
375+ for attr_name in ["c" , "force" , "torque" ]:
376+ attr = self .__getattribute__ (attr_name )
367377 if attr is not None :
368378 if env_index is None :
369- attr [:] = 0.0
379+ self . __setattr__ ( attr_name , torch . zeros_like ( attr ))
370380 else :
371- attr [env_index ] = 0.0
381+ self .__setattr__ (
382+ attr_name , TorchUtils .where_from_index (env_index , 0 , attr )
383+ )
372384 super ()._reset (env_index )
373385
386+ @override (EntityState )
387+ def zero_grad (self ):
388+ for attr_name in ["c" , "force" , "torque" ]:
389+ attr = self .__getattribute__ (attr_name )
390+ if attr is not None :
391+ self .__setattr__ (attr_name , attr .detach ())
392+ super ().zero_grad ()
393+
374394 @override (EntityState )
375395 def _spawn (self , dim_c : int , dim_p : int ):
376396 if dim_c > 0 :
@@ -492,12 +512,21 @@ def _to_tensor(self, value):
492512 )
493513
494514 def _reset (self , env_index : typing .Optional [int ]):
495- for attr in [self .u , self .c ]:
515+ for attr_name in ["u" , "c" ]:
516+ attr = self .__getattribute__ (attr_name )
496517 if attr is not None :
497518 if env_index is None :
498- attr [:] = 0.0
519+ self . __setattr__ ( attr_name , torch . zeros_like ( attr ))
499520 else :
500- attr [env_index ] = 0.0
521+ self .__setattr__ (
522+ attr_name , TorchUtils .where_from_index (env_index , 0 , attr )
523+ )
524+
525+ def zero_grad (self ):
526+ for attr_name in ["u" , "c" ]:
527+ attr = self .__getattribute__ (attr_name )
528+ if attr is not None :
529+ self .__setattr__ (attr_name , attr .detach ())
501530
502531
503532# properties and state of physical world entity
@@ -690,6 +719,9 @@ def _spawn(self, dim_c: int, dim_p: int):
690719 def _reset (self , env_index : int ):
691720 self .state ._reset (env_index )
692721
722+ def zero_grad (self ):
723+ self .state .zero_grad ()
724+
693725 def set_pos (self , pos : Tensor , batch_index : int ):
694726 self ._set_state_property (EntityState .pos , self .state , pos , batch_index )
695727
@@ -988,6 +1020,11 @@ def _reset(self, env_index: int):
9881020 self .dynamics .reset (env_index )
9891021 super ()._reset (env_index )
9901022
1023+ def zero_grad (self ):
1024+ self .action .zero_grad ()
1025+ self .dynamics .zero_grad ()
1026+ super ().zero_grad ()
1027+
9911028 @override (Entity )
9921029 def to (self , device : torch .device ):
9931030 super ().to (device )
@@ -1112,6 +1149,10 @@ def reset(self, env_index: int):
11121149 for e in self .entities :
11131150 e ._reset (env_index )
11141151
1152+ def zero_grad (self ):
1153+ for e in self .entities :
1154+ e .zero_grad ()
1155+
11151156 @property
11161157 def agents (self ) -> List [Agent ]:
11171158 return self ._agents
0 commit comments