
Commit f09a406

[Feature] Actions detached from physics and allow any number of actions (#76)
* First * Amend * Amend * Amend * Amend * Amend * Amend * Amend * Amend * Amend
1 parent 127a9ec commit f09a406

14 files changed: 496 additions & 387 deletions


README.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -405,8 +405,9 @@ To create a fake screen you need to have `Xvfb` installed.
 - [ ] Reset any number of dimensions
 - [ ] Improve test efficiency and add new tests
 - [ ] Implement 1D camera sensor
-- [ ] Allow any number of actions
 - [ ] Implement 2D birds eye view camera sensor
+- [ ] Implement 2D drone dynamics
+- [X] Allow any number of actions
 - [X] Improve VMAS performance
 - [X] Dict obs support in torchrl
 - [X] Make TextLine a Geom usable in a scenario
```

vmas/examples/use_vmas_env.py

Lines changed: 43 additions & 45 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2022-2023.
+# Copyright (c) 2022-2024.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
 import random
@@ -11,45 +11,53 @@
 from vmas.simulator.utils import save_video
 
 
-def _get_random_action(agent: Agent, continuous: bool):
+def _get_random_action(agent: Agent, continuous: bool, env):
     if continuous:
-        action = torch.zeros(
-            (agent.batch_dim, 2),
-            device=agent.device,
-            dtype=torch.float32,
-        ).uniform_(
-            -agent.action.u_range,
-            agent.action.u_range,
-        )
-        if agent.u_rot_range > 0:
-            action = torch.cat(
-                [
-                    action,
+        actions = []
+        for action_index in range(agent.action_size):
+            actions.append(
+                torch.zeros(
+                    agent.batch_dim,
+                    device=agent.device,
+                    dtype=torch.float32,
+                ).uniform_(
+                    -agent.action.u_range_tensor[action_index],
+                    agent.action.u_range_tensor[action_index],
+                )
+            )
+        if env.world.dim_c != 0 and not agent.silent:
+            # If the agent needs to communicate
+            for _ in range(env.world.dim_c):
+                actions.append(
                     torch.zeros(
-                        (agent.batch_dim, 1),
+                        agent.batch_dim,
                         device=agent.device,
                         dtype=torch.float32,
                     ).uniform_(
-                        -agent.action.u_rot_range,
-                        agent.action.u_rot_range,
-                    ),
-                ],
-                dim=-1,
-            )
+                        0,
+                        1,
+                    )
+                )
+        action = torch.stack(actions, dim=-1)
     else:
         action = torch.randint(
-            low=0, high=5, size=(agent.batch_dim,), device=agent.device
+            low=0,
+            high=env.get_agent_action_space(agent).n,
+            size=(agent.batch_dim,),
+            device=agent.device,
+        )
+    return action
+
+
+def _get_deterministic_action(agent: Agent, continuous: bool, env):
+    if continuous:
+        action = -agent.action.u_range_tensor.expand(env.batch_dim, agent.action_size)
+    else:
+        action = (
+            torch.tensor([1], device=env.device, dtype=torch.long)
+            .unsqueeze(-1)
+            .expand(env.batch_dim, 1)
         )
-        if agent.u_rot_range > 0:
-            action = torch.stack(
-                [
-                    action,
-                    torch.randint(
-                        low=0, high=3, size=(agent.batch_dim,), device=agent.device
-                    ),
-                ],
-                dim=-1,
-            )
     return action
 
 
@@ -85,13 +93,6 @@ def use_vmas_env(
     dict_spaces = True  # Whether to return obs, rewards, and infos as dictionaries with agent names
     # (by default they are lists of len # of agents)
 
-    simple_2d_action = (
-        [0, -1.0] if continuous_actions else [3]
-    )  # Simple action for an agent with 2d actions
-    simple_3d_action = (
-        [0, -1.0, 0.1] if continuous_actions else [3, 1]
-    )  # Simple action for an agent with 3d actions (2d forces and torque)
-
     env = make_env(
         scenario=scenario_name,
         num_envs=num_envs,
@@ -120,12 +121,9 @@ def use_vmas_env(
         actions = {} if dict_actions else []
         for i, agent in enumerate(env.agents):
             if not random_action:
-                action = torch.tensor(
-                    simple_2d_action if agent.u_rot_range == 0 else simple_3d_action,
-                    device=device,
-                ).repeat(num_envs, 1)
+                action = _get_deterministic_action(agent, continuous_actions, env)
             else:
-                action = _get_random_action(agent, continuous_actions)
+                action = _get_random_action(agent, continuous_actions, env)
             if dict_actions:
                 actions.update({agent.name: action})
             else:
@@ -158,5 +156,5 @@ def use_vmas_env(
         render=True,
         save_render=False,
         random_action=False,
-        continuous_actions=True,
+        continuous_actions=False,
     )
```
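Taken together, the new helpers sample one value per action dimension and stack them on the last dim, so any `action_size` works. Below is a minimal standalone rollout sketch in the same spirit (not part of the commit; the scenario name, env count, and step count are placeholder choices, and the `float(...)` casts are just a convenience):

```python
import torch
from vmas import make_env

env = make_env(scenario="balance", num_envs=4, device="cpu", continuous_actions=True)
obs = env.reset()
for _ in range(10):
    actions = []
    for agent in env.agents:
        # One uniform sample per action dimension, each bounded by that
        # dimension's own limit from agent.action.u_range_tensor.
        action = torch.stack(
            [
                torch.zeros(agent.batch_dim, device=agent.device).uniform_(
                    -float(agent.action.u_range_tensor[i]),
                    float(agent.action.u_range_tensor[i]),
                )
                for i in range(agent.action_size)
            ],
            dim=-1,
        )
        actions.append(action)
    obs, rews, dones, info = env.step(actions)
```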

vmas/interactive_rendering.py

Lines changed: 52 additions & 55 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2022-2023.
+# Copyright (c) 2022-2024.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
 """
@@ -101,21 +101,14 @@ def _cycle(self):
             self.reset = False
             total_rew = [0] * self.n_agents
 
-        action_list = [
-            [0.0] * self.env.unwrapped().get_agent_action_size(agent)
-            for agent in self.agents
-        ]
+        action_list = [[0.0] * agent.action_size for agent in self.agents]
         action_list[self.current_agent_index] = self.u[
-            : self.env.unwrapped().get_agent_action_size(
-                self.agents[self.current_agent_index]
-            )
+            : self.agents[self.current_agent_index].action_size
         ]
 
         if self.n_agents > 1 and self.control_two_agents:
             action_list[self.current_agent_index2] = self.u2[
-                : self.env.unwrapped().get_agent_action_size(
-                    self.agents[self.current_agent_index2]
-                )
+                : self.agents[self.current_agent_index2].action_size
             ]
         obs, rew, done, info = self.env.step(action_list)
 
@@ -167,56 +160,60 @@ def _write_values(self, index: int, message: str):
     def _key_press(self, k, mod):
         from pyglet.window import key
 
-        agent_range = self.agents[self.current_agent_index].u_range
-        agent_rot_range = self.agents[self.current_agent_index].u_rot_range
+        agent_range = self.agents[self.current_agent_index].action.u_range_tensor
+        try:
+            if k == key.LEFT:
+                self.keys[0] = agent_range[0]
+            elif k == key.RIGHT:
+                self.keys[1] = agent_range[0]
+            elif k == key.DOWN:
+                self.keys[2] = agent_range[1]
+            elif k == key.UP:
+                self.keys[3] = agent_range[1]
+            elif k == key.M:
+                self.keys[4] = agent_range[2]
+            elif k == key.N:
+                self.keys[5] = agent_range[2]
+            elif k == key.TAB:
+                self.current_agent_index = self._increment_selected_agent_index(
+                    self.current_agent_index
+                )
+                if self.control_two_agents:
+                    while self.current_agent_index == self.current_agent_index2:
+                        self.current_agent_index = self._increment_selected_agent_index(
+                            self.current_agent_index
+                        )
 
-        if k == key.LEFT:
-            self.keys[0] = agent_range
-        elif k == key.RIGHT:
-            self.keys[1] = agent_range
-        elif k == key.DOWN:
-            self.keys[2] = agent_range
-        elif k == key.UP:
-            self.keys[3] = agent_range
-        elif k == key.M:
-            self.keys[4] = agent_rot_range
-        elif k == key.N:
-            self.keys[5] = agent_rot_range
-        elif k == key.TAB:
-            self.current_agent_index = self._increment_selected_agent_index(
-                self.current_agent_index
-            )
             if self.control_two_agents:
-                while self.current_agent_index == self.current_agent_index2:
-                    self.current_agent_index = self._increment_selected_agent_index(
-                        self.current_agent_index
-                    )
-
-        if self.control_two_agents:
-            agent2_range = self.agents[self.current_agent_index2].u_range
-            agent2_rot_range = self.agents[self.current_agent_index2].u_rot_range
-
-            if k == key.A:
-                self.keys2[0] = agent2_range
-            elif k == key.D:
-                self.keys2[1] = agent2_range
-            elif k == key.S:
-                self.keys2[2] = agent2_range
-            elif k == key.W:
-                self.keys2[3] = agent2_range
-            elif k == key.E:
-                self.keys2[4] = agent2_rot_range
-            elif k == key.Q:
-                self.keys2[5] = agent2_rot_range
-
-            elif k == key.LSHIFT:
-                self.current_agent_index2 = self._increment_selected_agent_index(
+                agent2_range = self.agents[
                     self.current_agent_index2
-                )
-                while self.current_agent_index == self.current_agent_index2:
+                ].action.u_range_tensor
+
+                if k == key.A:
+                    self.keys2[0] = agent2_range[0]
+                elif k == key.D:
+                    self.keys2[1] = agent2_range[0]
+                elif k == key.S:
+                    self.keys2[2] = agent2_range[1]
+                elif k == key.W:
+                    self.keys2[3] = agent2_range[1]
+                elif k == key.E:
+                    self.keys2[4] = agent2_range[2]
+                elif k == key.Q:
+                    self.keys2[5] = agent2_range[2]
+
+                elif k == key.LSHIFT:
                     self.current_agent_index2 = self._increment_selected_agent_index(
                         self.current_agent_index2
                     )
+                    while self.current_agent_index == self.current_agent_index2:
+                        self.current_agent_index2 = (
+                            self._increment_selected_agent_index(
+                                self.current_agent_index2
+                            )
+                        )
+        except IndexError:
+            print("Action not available")
 
         if k == key.R:
             self.reset = True
```
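The key handler now reads per-dimension limits from `u_range_tensor` and relies on `IndexError` to ignore keys an agent's action space does not cover. A tiny standalone illustration of that guard (not from the commit):

```python
import torch

u_range = torch.tensor([1.0, 1.0])  # an agent with only 2 action dimensions
try:
    value = u_range[2]  # what pressing M or N would request (index 2)
except IndexError:
    print("Action not available")  # the same guard used in _key_press
```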

vmas/make_env.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2022-2023.
+# Copyright (c) 2022-2024.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
 
@@ -22,6 +22,7 @@ def make_env(
     max_steps: Optional[int] = None,
     seed: Optional[int] = None,
     dict_spaces: bool = False,
+    multidiscrete_actions: bool = False,
     **kwargs,
 ):
     """
@@ -35,7 +36,10 @@ def make_env(
         max_steps: Maximum number of steps in each vectorized environment after which done is returned
         seed: seed
         dict_spaces: Whether to use dictionary i/o spaces with format {agent_name: tensor}
-             for obs, rewards, and info instead of tuples.
+            for obs, rewards, and info instead of tuples.
+        multidiscrete_actions (bool): Whether to use MultiDiscrete action spaces when continuous_actions=False.
+            Otherwise (default), the action space will be Discrete, the cartesian product of an
+            agent's action spaces.
     **kwargs ():
 
     Returns:
@@ -55,6 +59,7 @@ def make_env(
         max_steps=max_steps,
         seed=seed,
         dict_spaces=dict_spaces,
+        multidiscrete_actions=multidiscrete_actions,
         **kwargs,
     )
 
```
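A usage sketch of the new flag (not from the commit; the scenario name is an arbitrary example): with `multidiscrete_actions=True` each discrete agent gets a MultiDiscrete space with one entry per action (and communication) dimension, while with the default `False` it gets a single Discrete space over the cartesian product of those entries.

```python
from vmas import make_env

env = make_env(
    scenario="transport",
    num_envs=8,
    device="cpu",
    continuous_actions=False,
    multidiscrete_actions=True,
)
print(env.action_space)  # assumed to expose the per-agent spaces
```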

vmas/scenarios/debug/diff_drive.py

Lines changed: 20 additions & 17 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2022-2023.
+# Copyright (c) 2022-2024.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
 import typing
@@ -8,7 +8,8 @@
 
 from vmas import render_interactively
 from vmas.simulator.core import Agent, World
-from vmas.simulator.dynamics.diff_drive import DiffDriveDynamics
+from vmas.simulator.dynamics.diff_drive import DiffDrive
+from vmas.simulator.dynamics.holonomic_with_rot import HolonomicWithRotation
 from vmas.simulator.scenario import BaseScenario
 from vmas.simulator.utils import Color, ScenarioUtils
 
@@ -39,16 +40,24 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
         world = World(batch_dim, device, substeps=10)
 
         for i in range(self.n_agents):
-            agent = Agent(
-                name=f"agent_{i}",
-                collide=True,
-                render_action=True,
-                u_range=1,
-                u_rot_range=1,
-                u_rot_multiplier=0.001,
-            )
             if i == 0:
-                agent.dynamics = DiffDriveDynamics(agent, world, integration="rk4")
+                agent = Agent(
+                    name=f"diff_drive_{i}",
+                    collide=True,
+                    render_action=True,
+                    u_range=[1, 1],
+                    u_multiplier=[1, 0.001],
+                    dynamics=DiffDrive(world, integration="rk4"),
+                )
+            else:
+                agent = Agent(
+                    name=f"holo_rot_{i}",
+                    collide=True,
+                    render_action=True,
+                    u_range=[1, 1, 1],
+                    u_multiplier=[1, 1, 0.001],
+                    dynamics=HolonomicWithRotation(),
+                )
 
             world.add_agent(agent)
 
@@ -64,12 +73,6 @@ def reset_world_at(self, env_index: int = None):
             y_bounds=(-1, 1),
         )
 
-    def process_action(self, agent: Agent):
-        try:
-            agent.dynamics.process_force()
-        except AttributeError:
-            pass
-
     def reward(self, agent: Agent):
         return torch.zeros(self.world.batch_dim)
 
```
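The action shape now comes from the agent constructor rather than from a `process_action` hook: `u_range` and `u_multiplier` accept one entry per action dimension, and the dynamics object is passed in directly. A standalone construction sketch under those assumptions (batch size and device are placeholders):

```python
import torch

from vmas.simulator.core import Agent, World
from vmas.simulator.dynamics.diff_drive import DiffDrive

world = World(batch_dim=2, device=torch.device("cpu"), substeps=10)
agent = Agent(
    name="diff_drive_0",
    collide=True,
    render_action=True,
    u_range=[1, 1],           # per-dimension action limits
    u_multiplier=[1, 0.001],  # per-dimension scaling; the second entry drives rotation
    dynamics=DiffDrive(world, integration="rk4"),
)
world.add_agent(agent)
# agent.action_size should now be 2: a forward command and a rotation command.
```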
