1- # Copyright (c) 2022-2023 .
1+ # Copyright (c) 2022-2024 .
22# ProrokLab (https://www.proroklab.org/)
33# All rights reserved.
44
88from vmas .simulator .core import Agent , Box , Landmark , Sphere , World
99from vmas .simulator .heuristic_policy import BaseHeuristicPolicy
1010from vmas .simulator .scenario import BaseScenario
11- from vmas .simulator .utils import Color
11+ from vmas .simulator .utils import Color , ScenarioUtils
1212
1313
1414class Scenario (BaseScenario ):
@@ -20,12 +20,25 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
2020 self .package_mass = kwargs .get ("package_mass" , 50 )
2121
2222 self .shaping_factor = 100
23+ self .world_semidim = 1
24+ self .agent_radius = 0.03
2325
2426 # Make world
25- world = World (batch_dim , device )
27+ world = World (
28+ batch_dim ,
29+ device ,
30+ x_semidim = self .world_semidim
31+ + 2 * self .agent_radius
32+ + max (self .package_length , self .package_width ),
33+ y_semidim = self .world_semidim
34+ + 2 * self .agent_radius
35+ + max (self .package_length , self .package_width ),
36+ )
2637 # Add agents
2738 for i in range (n_agents ):
28- agent = Agent (name = f"agent_{ i } " , shape = Sphere (0.03 ), u_multiplier = 0.6 )
39+ agent = Agent (
40+ name = f"agent_{ i } " , shape = Sphere (self .agent_radius ), u_multiplier = 0.6
41+ )
2942 world .add_agent (agent )
3043 # Add landmarks
3144 goal = Landmark (
@@ -52,35 +65,50 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
5265 return world
5366
5467 def reset_world_at (self , env_index : int = None ):
68+ # Random pos between -1 and 1
69+ ScenarioUtils .spawn_entities_randomly (
70+ self .world .agents ,
71+ self .world ,
72+ env_index ,
73+ min_dist_between_entities = self .agent_radius * 2 ,
74+ x_bounds = (
75+ - self .world_semidim ,
76+ self .world_semidim ,
77+ ),
78+ y_bounds = (
79+ - self .world_semidim ,
80+ self .world_semidim ,
81+ ),
82+ )
83+ agent_occupied_positions = torch .stack (
84+ [agent .state .pos for agent in self .world .agents ], dim = 1
85+ )
86+ if env_index is not None :
87+ agent_occupied_positions = agent_occupied_positions [env_index ].unsqueeze (0 )
88+
5589 goal = self .world .landmarks [0 ]
56- goal .set_pos (
57- torch .zeros (
58- (1 , self .world .dim_p )
59- if env_index is not None
60- else (self .world .batch_dim , self .world .dim_p ),
61- device = self .world .device ,
62- dtype = torch .float32 ,
63- ).uniform_ (
64- - 1.0 ,
65- 1.0 ,
90+ ScenarioUtils .spawn_entities_randomly (
91+ [goal ] + self .packages ,
92+ self .world ,
93+ env_index ,
94+ min_dist_between_entities = max (
95+ package .shape .circumscribed_radius () + goal .shape .radius + 0.01
96+ for package in self .packages
6697 ),
67- batch_index = env_index ,
98+ x_bounds = (
99+ - self .world_semidim ,
100+ self .world_semidim ,
101+ ),
102+ y_bounds = (
103+ - self .world_semidim ,
104+ self .world_semidim ,
105+ ),
106+ occupied_positions = agent_occupied_positions ,
68107 )
108+
69109 for i , package in enumerate (self .packages ):
70- package .set_pos (
71- torch .zeros (
72- (1 , self .world .dim_p )
73- if env_index is not None
74- else (self .world .batch_dim , self .world .dim_p ),
75- device = self .world .device ,
76- dtype = torch .float32 ,
77- ).uniform_ (
78- - 1.0 ,
79- 1.0 ,
80- ),
81- batch_index = env_index ,
82- )
83110 package .on_goal = self .world .is_overlapping (package , package .goal )
111+
84112 if env_index is None :
85113 package .global_shaping = (
86114 torch .linalg .vector_norm (
@@ -95,38 +123,6 @@ def reset_world_at(self, env_index: int = None):
95123 )
96124 * self .shaping_factor
97125 )
98- for i , agent in enumerate (self .world .agents ):
99- # Random pos between -1 and 1
100- agent .set_pos (
101- torch .zeros (
102- (1 , self .world .dim_p )
103- if env_index is not None
104- else (self .world .batch_dim , self .world .dim_p ),
105- device = self .world .device ,
106- dtype = torch .float32 ,
107- ).uniform_ (
108- - 1.0 ,
109- 1.0 ,
110- ),
111- batch_index = env_index ,
112- )
113- for package in self .packages :
114- while self .world .is_overlapping (
115- agent , package , env_index = env_index
116- ).any ():
117- agent .set_pos (
118- torch .zeros (
119- (1 , self .world .dim_p )
120- if env_index is not None
121- else (self .world .batch_dim , self .world .dim_p ),
122- device = self .world .device ,
123- dtype = torch .float32 ,
124- ).uniform_ (
125- - 1.0 ,
126- 1.0 ,
127- ),
128- batch_index = env_index ,
129- )
130126
131127 def reward (self , agent : Agent ):
132128 is_first = agent == self .world .agents [0 ]
@@ -343,12 +339,4 @@ def get_action(
343339
344340
345341if __name__ == "__main__" :
346- render_interactively (
347- __file__ ,
348- control_two_agents = True ,
349- n_agents = 4 ,
350- n_packages = 1 ,
351- package_width = 0.15 ,
352- package_length = 0.15 ,
353- package_mass = 50 ,
354- )
342+ render_interactively (__file__ , control_two_agents = True )
0 commit comments