@@ -169,7 +169,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
169169 n_sub_steps = None ,
170170 raise_exception_on_physics_error = True ,
171171 strip_singleton_obs_buffer_dim = False ,
172- delayed_observation_padding = ObservationPadding .ZERO ):
172+ delayed_observation_padding = ObservationPadding .ZERO ,
173+ legacy_step : bool = True ):
173174 """Initializes an instance of `_CommonEnvironment`.
174175
175176 Args:
@@ -194,6 +195,9 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
194195 observables. If `ZERO` then the buffer is initially filled with zeroes.
195196 If `INITIAL_VALUE` then the buffer is initially filled with the first
196197 observation values.
198+ legacy_step: If True, steps the state with up-to-date position and
199+ velocity dependent fields. See Page 6 of
200+ https://arxiv.org/abs/2006.12983 for more information.
197201 """
198202 if not isinstance (delayed_observation_padding , ObservationPadding ):
199203 raise ValueError (
@@ -210,6 +214,7 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
210214 self ._raise_exception_on_physics_error = raise_exception_on_physics_error
211215 self ._strip_singleton_obs_buffer_dim = strip_singleton_obs_buffer_dim
212216 self ._delayed_observation_padding = delayed_observation_padding
217+ self ._legacy_step = legacy_step
213218
214219 if n_sub_steps is not None :
215220 warnings .simplefilter ('once' , DeprecationWarning )
@@ -248,6 +253,7 @@ def _recompile_physics(self):
248253 self ._physics .free ()
249254 self ._physics = mjcf .Physics .from_mjcf_model (
250255 self ._task .root_entity .mjcf_model )
256+ self ._physics .legacy_step = self ._legacy_step
251257
252258 def _make_observation_updater (self ):
253259 pad_with_initial_value = (
@@ -291,7 +297,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
291297 raise_exception_on_physics_error = True ,
292298 strip_singleton_obs_buffer_dim = False ,
293299 max_reset_attempts = 1 ,
294- delayed_observation_padding = ObservationPadding .ZERO ):
300+ delayed_observation_padding = ObservationPadding .ZERO ,
301+ legacy_step : bool = True ):
295302 """Initializes an instance of `Environment`.
296303
297304 Args:
@@ -320,6 +327,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
320327 observables. If `ZERO` then the buffer is initially filled with zeroes.
321328 If `INITIAL_VALUE` then the buffer is initially filled with the first
322329 observation values.
330+ legacy_step: If True, steps the state with up-to-date position and
331+ velocity dependent fields.
323332 """
324333 super ().__init__ (
325334 task = task ,
@@ -328,7 +337,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
328337 n_sub_steps = n_sub_steps ,
329338 raise_exception_on_physics_error = raise_exception_on_physics_error ,
330339 strip_singleton_obs_buffer_dim = strip_singleton_obs_buffer_dim ,
331- delayed_observation_padding = delayed_observation_padding )
340+ delayed_observation_padding = delayed_observation_padding ,
341+ legacy_step = legacy_step )
332342 self ._max_reset_attempts = max_reset_attempts
333343 self ._reset_next_step = True
334344
0 commit comments