Skip to content

Commit 461bed9

Browse files
kevinzakkacopybara-github
authored andcommitted
Add option to step the physics with mj_step instead of mj_step2->mj_step1.
PiperOrigin-RevId: 514728526 Change-Id: Icfac463419f1f9365b90ce36d4d96b0d5311ec73
1 parent 1cd211b commit 461bed9

3 files changed

Lines changed: 42 additions & 18 deletions

File tree

dm_control/composer/environment.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
169169
n_sub_steps=None,
170170
raise_exception_on_physics_error=True,
171171
strip_singleton_obs_buffer_dim=False,
172-
delayed_observation_padding=ObservationPadding.ZERO):
172+
delayed_observation_padding=ObservationPadding.ZERO,
173+
legacy_step: bool = True):
173174
"""Initializes an instance of `_CommonEnvironment`.
174175
175176
Args:
@@ -194,6 +195,9 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
194195
observables. If `ZERO` then the buffer is initially filled with zeroes.
195196
If `INITIAL_VALUE` then the buffer is initially filled with the first
196197
observation values.
198+
legacy_step: If True, steps the state with up-to-date position and
199+
velocity dependent fields. See Page 6 of
200+
https://arxiv.org/abs/2006.12983 for more information.
197201
"""
198202
if not isinstance(delayed_observation_padding, ObservationPadding):
199203
raise ValueError(
@@ -210,6 +214,7 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
210214
self._raise_exception_on_physics_error = raise_exception_on_physics_error
211215
self._strip_singleton_obs_buffer_dim = strip_singleton_obs_buffer_dim
212216
self._delayed_observation_padding = delayed_observation_padding
217+
self._legacy_step = legacy_step
213218

214219
if n_sub_steps is not None:
215220
warnings.simplefilter('once', DeprecationWarning)
@@ -248,6 +253,7 @@ def _recompile_physics(self):
248253
self._physics.free()
249254
self._physics = mjcf.Physics.from_mjcf_model(
250255
self._task.root_entity.mjcf_model)
256+
self._physics.legacy_step = self._legacy_step
251257

252258
def _make_observation_updater(self):
253259
pad_with_initial_value = (
@@ -291,7 +297,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
291297
raise_exception_on_physics_error=True,
292298
strip_singleton_obs_buffer_dim=False,
293299
max_reset_attempts=1,
294-
delayed_observation_padding=ObservationPadding.ZERO):
300+
delayed_observation_padding=ObservationPadding.ZERO,
301+
legacy_step: bool = True):
295302
"""Initializes an instance of `Environment`.
296303
297304
Args:
@@ -320,6 +327,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
320327
observables. If `ZERO` then the buffer is initially filled with zeroes.
321328
If `INITIAL_VALUE` then the buffer is initially filled with the first
322329
observation values.
330+
legacy_step: If True, steps the state with up-to-date position and
331+
velocity dependent fields.
323332
"""
324333
super().__init__(
325334
task=task,
@@ -328,7 +337,8 @@ def __init__(self, task, time_limit=float('inf'), random_state=None,
328337
n_sub_steps=n_sub_steps,
329338
raise_exception_on_physics_error=raise_exception_on_physics_error,
330339
strip_singleton_obs_buffer_dim=strip_singleton_obs_buffer_dim,
331-
delayed_observation_padding=delayed_observation_padding)
340+
delayed_observation_padding=delayed_observation_padding,
341+
legacy_step=legacy_step)
332342
self._max_reset_attempts = max_reset_attempts
333343
self._reset_next_step = True
334344

dm_control/mujoco/engine.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,30 +144,37 @@ def set_control(self, control):
144144
"""
145145
np.copyto(self.data.ctrl, control)
146146

147-
def step(self, nstep=1):
148-
"""Advances physics with up-to-date position and velocity dependent fields.
149-
150-
Args:
151-
nstep: Optional integer, number of steps to take.
152-
153-
The actuation can be updated by calling the `set_control` function first.
154-
"""
147+
def _step_with_up_to_date_position_velocity(self, nstep: int = 1) -> None:
148+
"""Physics step with up-to-date position and velocity dependent fields."""
155149
# In the case of Euler integration we assume mj_step1 has already been
156150
# called for this state, finish the step with mj_step2 and then update all
157151
# position and velocity related fields with mj_step1. This ensures that
158152
# (most of) mjData is in sync with qpos and qvel. In the case of non-Euler
159153
# integrators (e.g. RK4) an additional mj_step1 must be called after the
160154
# last mj_step to ensure mjData syncing.
155+
if self.model.opt.integrator != mujoco.mjtIntegrator.mjINT_RK4.value:
156+
mujoco.mj_step2(self.model.ptr, self.data.ptr)
157+
if nstep > 1:
158+
mujoco.mj_step(self.model.ptr, self.data.ptr, nstep-1)
159+
else:
160+
mujoco.mj_step(self.model.ptr, self.data.ptr, nstep)
161+
162+
mujoco.mj_step1(self.model.ptr, self.data.ptr)
163+
164+
def step(self, nstep: int = 1) -> None:
165+
"""Advances the physics state by `nstep`s.
166+
167+
Args:
168+
nstep: Optional integer, number of steps to take.
169+
170+
The actuation can be updated by calling the `set_control` function first.
171+
"""
161172
with self.check_invalid_state():
162-
if self.model.opt.integrator != mujoco.mjtIntegrator.mjINT_RK4.value:
163-
mujoco.mj_step2(self.model.ptr, self.data.ptr)
164-
if nstep > 1:
165-
mujoco.mj_step(self.model.ptr, self.data.ptr, nstep-1)
173+
if self.legacy_step:
174+
self._step_with_up_to_date_position_velocity(nstep)
166175
else:
167176
mujoco.mj_step(self.model.ptr, self.data.ptr, nstep)
168177

169-
mujoco.mj_step1(self.model.ptr, self.data.ptr)
170-
171178
def render(
172179
self,
173180
height=240,

dm_control/rl/control.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def __init__(self,
3434
time_limit=float('inf'),
3535
control_timestep=None,
3636
n_sub_steps=None,
37-
flat_observation=False):
37+
flat_observation=False,
38+
legacy_step: bool = True):
3839
"""Initializes a new `Environment`.
3940
4041
Args:
@@ -48,12 +49,16 @@ def __init__(self,
4849
`control_timestep` is not specified.
4950
flat_observation: If True, observations will be flattened and concatenated
5051
into a single numpy array.
52+
legacy_step: If True, steps the state with up-to-date position and
53+
velocity dependent fields. See Page 6 of
54+
https://arxiv.org/abs/2006.12983 for more information.
5155
5256
Raises:
5357
ValueError: If both `n_sub_steps` and `control_timestep` are supplied.
5458
"""
5559
self._task = task
5660
self._physics = physics
61+
self._physics.legacy_step = legacy_step
5762
self._flat_observation = flat_observation
5863

5964
if n_sub_steps is not None and control_timestep is not None:
@@ -201,6 +206,8 @@ def _spec_from_observation(observation):
201206
class Physics(metaclass=abc.ABCMeta):
202207
"""Simulates a physical environment."""
203208

209+
legacy_step: bool = True
210+
204211
@abc.abstractmethod
205212
def step(self, n_sub_steps=1):
206213
"""Updates the simulation state.

0 commit comments

Comments
 (0)