Skip to content

Commit b434185

Browse files
committed
[Scenario] Update transport to never spawn package on goal
1 parent f09a406 commit b434185

1 file changed

Lines changed: 57 additions & 69 deletions

File tree

vmas/scenarios/transport.py

Lines changed: 57 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2023.
1+
# Copyright (c) 2022-2024.
22
# ProrokLab (https://www.proroklab.org/)
33
# All rights reserved.
44

@@ -8,7 +8,7 @@
88
from vmas.simulator.core import Agent, Box, Landmark, Sphere, World
99
from vmas.simulator.heuristic_policy import BaseHeuristicPolicy
1010
from vmas.simulator.scenario import BaseScenario
11-
from vmas.simulator.utils import Color
11+
from vmas.simulator.utils import Color, ScenarioUtils
1212

1313

1414
class Scenario(BaseScenario):
@@ -20,12 +20,25 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2020
self.package_mass = kwargs.get("package_mass", 50)
2121

2222
self.shaping_factor = 100
23+
self.world_semidim = 1
24+
self.agent_radius = 0.03
2325

2426
# Make world
25-
world = World(batch_dim, device)
27+
world = World(
28+
batch_dim,
29+
device,
30+
x_semidim=self.world_semidim
31+
+ 2 * self.agent_radius
32+
+ max(self.package_length, self.package_width),
33+
y_semidim=self.world_semidim
34+
+ 2 * self.agent_radius
35+
+ max(self.package_length, self.package_width),
36+
)
2637
# Add agents
2738
for i in range(n_agents):
28-
agent = Agent(name=f"agent_{i}", shape=Sphere(0.03), u_multiplier=0.6)
39+
agent = Agent(
40+
name=f"agent_{i}", shape=Sphere(self.agent_radius), u_multiplier=0.6
41+
)
2942
world.add_agent(agent)
3043
# Add landmarks
3144
goal = Landmark(
@@ -52,35 +65,50 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
5265
return world
5366

5467
def reset_world_at(self, env_index: int = None):
68+
# Random pos between -1 and 1
69+
ScenarioUtils.spawn_entities_randomly(
70+
self.world.agents,
71+
self.world,
72+
env_index,
73+
min_dist_between_entities=self.agent_radius * 2,
74+
x_bounds=(
75+
-self.world_semidim,
76+
self.world_semidim,
77+
),
78+
y_bounds=(
79+
-self.world_semidim,
80+
self.world_semidim,
81+
),
82+
)
83+
agent_occupied_positions = torch.stack(
84+
[agent.state.pos for agent in self.world.agents], dim=1
85+
)
86+
if env_index is not None:
87+
agent_occupied_positions = agent_occupied_positions[env_index].unsqueeze(0)
88+
5589
goal = self.world.landmarks[0]
56-
goal.set_pos(
57-
torch.zeros(
58-
(1, self.world.dim_p)
59-
if env_index is not None
60-
else (self.world.batch_dim, self.world.dim_p),
61-
device=self.world.device,
62-
dtype=torch.float32,
63-
).uniform_(
64-
-1.0,
65-
1.0,
90+
ScenarioUtils.spawn_entities_randomly(
91+
[goal] + self.packages,
92+
self.world,
93+
env_index,
94+
min_dist_between_entities=max(
95+
package.shape.circumscribed_radius() + goal.shape.radius + 0.01
96+
for package in self.packages
6697
),
67-
batch_index=env_index,
98+
x_bounds=(
99+
-self.world_semidim,
100+
self.world_semidim,
101+
),
102+
y_bounds=(
103+
-self.world_semidim,
104+
self.world_semidim,
105+
),
106+
occupied_positions=agent_occupied_positions,
68107
)
108+
69109
for i, package in enumerate(self.packages):
70-
package.set_pos(
71-
torch.zeros(
72-
(1, self.world.dim_p)
73-
if env_index is not None
74-
else (self.world.batch_dim, self.world.dim_p),
75-
device=self.world.device,
76-
dtype=torch.float32,
77-
).uniform_(
78-
-1.0,
79-
1.0,
80-
),
81-
batch_index=env_index,
82-
)
83110
package.on_goal = self.world.is_overlapping(package, package.goal)
111+
84112
if env_index is None:
85113
package.global_shaping = (
86114
torch.linalg.vector_norm(
@@ -95,38 +123,6 @@ def reset_world_at(self, env_index: int = None):
95123
)
96124
* self.shaping_factor
97125
)
98-
for i, agent in enumerate(self.world.agents):
99-
# Random pos between -1 and 1
100-
agent.set_pos(
101-
torch.zeros(
102-
(1, self.world.dim_p)
103-
if env_index is not None
104-
else (self.world.batch_dim, self.world.dim_p),
105-
device=self.world.device,
106-
dtype=torch.float32,
107-
).uniform_(
108-
-1.0,
109-
1.0,
110-
),
111-
batch_index=env_index,
112-
)
113-
for package in self.packages:
114-
while self.world.is_overlapping(
115-
agent, package, env_index=env_index
116-
).any():
117-
agent.set_pos(
118-
torch.zeros(
119-
(1, self.world.dim_p)
120-
if env_index is not None
121-
else (self.world.batch_dim, self.world.dim_p),
122-
device=self.world.device,
123-
dtype=torch.float32,
124-
).uniform_(
125-
-1.0,
126-
1.0,
127-
),
128-
batch_index=env_index,
129-
)
130126

131127
def reward(self, agent: Agent):
132128
is_first = agent == self.world.agents[0]
@@ -343,12 +339,4 @@ def get_action(
343339

344340

345341
if __name__ == "__main__":
346-
render_interactively(
347-
__file__,
348-
control_two_agents=True,
349-
n_agents=4,
350-
n_packages=1,
351-
package_width=0.15,
352-
package_length=0.15,
353-
package_mass=50,
354-
)
342+
render_interactively(__file__, control_two_agents=True)

0 commit comments

Comments (0)