Skip to content

Commit 8a193b5

Browse files
committed
[Differentiable] Further remove inplace operations and introduce world.zero_grad()
1 parent d588548 commit 8a193b5

4 files changed

Lines changed: 65 additions & 10 deletions

File tree

vmas/simulator/core.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,21 @@ def rot(self, rot: Tensor):
282282
self._rot = rot.to(self._device)
283283

284284
def _reset(self, env_index: typing.Optional[int]):
285-
for attr in [self.pos, self.rot, self.vel, self.ang_vel]:
285+
for attr_name in ["pos", "rot", "vel", "ang_vel"]:
286+
attr = self.__getattribute__(attr_name)
286287
if attr is not None:
287288
if env_index is None:
288-
attr[:] = 0.0
289+
self.__setattr__(attr_name, torch.zeros_like(attr))
289290
else:
290-
attr[env_index] = 0.0
291+
self.__setattr__(
292+
attr_name, TorchUtils.where_from_index(env_index, 0, attr)
293+
)
294+
295+
def zero_grad(self):
296+
for attr_name in ["pos", "rot", "vel", "ang_vel"]:
297+
attr = self.__getattribute__(attr_name)
298+
if attr is not None:
299+
self.__setattr__(attr_name, attr.detach())
291300

292301
def _spawn(self, dim_c: int, dim_p: int):
293302
self.pos = torch.zeros(
@@ -363,14 +372,25 @@ def torque(self, value):
363372

364373
@override(EntityState)
365374
def _reset(self, env_index: typing.Optional[int]):
366-
for attr in [self.c, self.force, self.torque]:
375+
for attr_name in ["c", "force", "torque"]:
376+
attr = self.__getattribute__(attr_name)
367377
if attr is not None:
368378
if env_index is None:
369-
attr[:] = 0.0
379+
self.__setattr__(attr_name, torch.zeros_like(attr))
370380
else:
371-
attr[env_index] = 0.0
381+
self.__setattr__(
382+
attr_name, TorchUtils.where_from_index(env_index, 0, attr)
383+
)
372384
super()._reset(env_index)
373385

386+
@override(EntityState)
387+
def zero_grad(self):
388+
for attr_name in ["c", "force", "torque"]:
389+
attr = self.__getattribute__(attr_name)
390+
if attr is not None:
391+
self.__setattr__(attr_name, attr.detach())
392+
super().zero_grad()
393+
374394
@override(EntityState)
375395
def _spawn(self, dim_c: int, dim_p: int):
376396
if dim_c > 0:
@@ -492,12 +512,21 @@ def _to_tensor(self, value):
492512
)
493513

494514
def _reset(self, env_index: typing.Optional[int]):
495-
for attr in [self.u, self.c]:
515+
for attr_name in ["u", "c"]:
516+
attr = self.__getattribute__(attr_name)
496517
if attr is not None:
497518
if env_index is None:
498-
attr[:] = 0.0
519+
self.__setattr__(attr_name, torch.zeros_like(attr))
499520
else:
500-
attr[env_index] = 0.0
521+
self.__setattr__(
522+
attr_name, TorchUtils.where_from_index(env_index, 0, attr)
523+
)
524+
525+
def zero_grad(self):
526+
for attr_name in ["u", "c"]:
527+
attr = self.__getattribute__(attr_name)
528+
if attr is not None:
529+
self.__setattr__(attr_name, attr.detach())
501530

502531

503532
# properties and state of physical world entity
@@ -690,6 +719,9 @@ def _spawn(self, dim_c: int, dim_p: int):
690719
def _reset(self, env_index: int):
691720
self.state._reset(env_index)
692721

722+
def zero_grad(self):
723+
self.state.zero_grad()
724+
693725
def set_pos(self, pos: Tensor, batch_index: int):
694726
self._set_state_property(EntityState.pos, self.state, pos, batch_index)
695727

@@ -988,6 +1020,11 @@ def _reset(self, env_index: int):
9881020
self.dynamics.reset(env_index)
9891021
super()._reset(env_index)
9901022

1023+
def zero_grad(self):
1024+
self.action.zero_grad()
1025+
self.dynamics.zero_grad()
1026+
super().zero_grad()
1027+
9911028
@override(Entity)
9921029
def to(self, device: torch.device):
9931030
super().to(device)
@@ -1112,6 +1149,10 @@ def reset(self, env_index: int):
11121149
for e in self.entities:
11131150
e._reset(env_index)
11141151

1152+
def zero_grad(self):
1153+
for e in self.entities:
1154+
e.zero_grad()
1155+
11151156
@property
11161157
def agents(self) -> List[Agent]:
11171158
return self._agents

vmas/simulator/dynamics/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ def __init__(
1717
def reset(self, index: Union[Tensor, int] = None):
1818
pass
1919

20+
def zero_grad(self):
21+
pass
22+
2023
@property
2124
def agent(self):
2225
if self._agent is None:

vmas/simulator/dynamics/drone.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ def reset(self, index: Union[Tensor, int] = None):
4949
device=self.world.device,
5050
)
5151
else:
52-
self.drone_state[index] = 0.0
52+
self.drone_state = vmas.TorchUtils.where_from_index(
53+
index, 0.0, self.drone_state
54+
)
55+
56+
def zero_grad(self):
57+
self.drone_state = self.drone_state.detach()
5358

5459
def needs_reset(self) -> Tensor:
5560
# Constraint roll and pitch within +-30 degrees

vmas/simulator/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,12 @@ def recursive_require_grad_(value: Union[Dict[str, Tensor], Tensor, List[Tensor]
218218
for val in value:
219219
TorchUtils.recursive_require_grad_(val)
220220

221+
@staticmethod
222+
def where_from_index(env_index, new_value, old_value):
223+
mask = torch.zeros_like(old_value, dtype=torch.bool, device=old_value.device)
224+
mask[env_index] = True
225+
return torch.where(mask, new_value, old_value)
226+
221227

222228
class ScenarioUtils:
223229
@staticmethod

0 commit comments

Comments
 (0)