Skip to content

Commit 99f562c

Browse files
[Notbooks] Refresh notebooks (#90)
* amend * amend * amend * amend
1 parent ce242ab commit 99f562c

3 files changed

Lines changed: 37 additions & 35 deletions

File tree

notebooks/VMAS_RLlib.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
"! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git\n",
4545
"%cd /content/VectorizedMultiAgentSimulator\n",
4646
"!pip install -e .\n",
47-
"!pip install \"ray[rllib]\"==2.2 wandb"
47+
"!pip install \"ray[rllib]\"==2.2 wandb\n",
48+
"!pip install \"pydantic<2\" numpy==1.23.5"
4849
]
4950
},
5051
{

notebooks/VMAS_Use_vmas_environment.ipynb

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -77,38 +77,43 @@
7777
"import time\n",
7878
"import torch\n",
7979
"from vmas import make_env\n",
80+
"from vmas.simulator.core import Agent\n",
81+
"\n",
82+
"def _get_deterministic_action(agent: Agent, continuous: bool, env):\n",
83+
" if continuous:\n",
84+
" action = -agent.action.u_range_tensor.expand(env.batch_dim, agent.action_size)\n",
85+
" else:\n",
86+
" action = (\n",
87+
" torch.tensor([1], device=env.device, dtype=torch.long)\n",
88+
" .unsqueeze(-1)\n",
89+
" .expand(env.batch_dim, 1)\n",
90+
" )\n",
91+
" return action.clone()\n",
8092
"\n",
8193
"def use_vmas_env(\n",
82-
" render: bool = False,\n",
83-
" save_render: bool = False,\n",
84-
" num_envs: int = 32,\n",
85-
" n_steps: int = 100,\n",
86-
" device: str = \"cpu\",\n",
87-
" scenario: Union[str, BaseScenario]= \"waterfall\",\n",
88-
" n_agents: int = 4,\n",
89-
" continuous_actions: bool = True,\n",
94+
" render: bool,\n",
95+
" num_envs: int,\n",
96+
" n_steps: int,\n",
97+
" device: str,\n",
98+
" scenario: Union[str, BaseScenario],\n",
99+
" continuous_actions: bool,\n",
100+
" random_action: bool,\n",
90101
" **kwargs\n",
91102
"):\n",
92-
" \"\"\"Example function to use a vmas environment\n",
103+
" \"\"\"Example function to use a vmas environment.\n",
104+
" \n",
105+
" This is a simplification of the function in `vmas.examples.use_vmas_env.py`.\n",
93106
"\n",
94107
" Args:\n",
95108
" continuous_actions (bool): Whether the agents have continuous or discrete actions\n",
96-
" n_agents (int): Number of agents\n",
97109
" scenario (str): Name of scenario\n",
98110
" device (str): Torch device to use\n",
99111
" render (bool): Whether to render the scenario\n",
100-
" save_render (bool): Whether to save render of the scenario\n",
101112
" num_envs (int): Number of vectorized environments\n",
102113
" n_steps (int): Number of steps before returning done\n",
103-
"\n",
104-
" Returns:\n",
114+
" random_action (bool): Use random actions or have all agents perform the down action\n",
105115
"\n",
106116
" \"\"\"\n",
107-
" assert not (save_render and not render), \"To save the video you have to render it\"\n",
108-
"\n",
109-
" simple_2d_action = (\n",
110-
" [0, -1.0] if continuous_actions else [3]\n",
111-
" ) # Simple action for an agent with 2d actions\n",
112117
"\n",
113118
" scenario_name = scenario if isinstance(scenario,str) else scenario.__class__.__name__\n",
114119
"\n",
@@ -117,10 +122,8 @@
117122
" num_envs=num_envs,\n",
118123
" device=device,\n",
119124
" continuous_actions=continuous_actions,\n",
120-
" wrapper=None,\n",
121-
" seed=None,\n",
125+
" seed=0,\n",
122126
" # Environment specific variables\n",
123-
" n_agents=n_agents,\n",
124127
" **kwargs\n",
125128
" )\n",
126129
"\n",
@@ -134,31 +137,29 @@
134137
"\n",
135138
" actions = []\n",
136139
" for i, agent in enumerate(env.agents):\n",
137-
" action = torch.tensor(\n",
138-
" simple_2d_action,\n",
139-
" device=device,\n",
140-
" ).repeat(num_envs, 1)\n",
140+
" if not random_action:\n",
141+
" action = _get_deterministic_action(agent, continuous_actions, env)\n",
142+
" else:\n",
143+
" action = env.get_random_action(agent)\n",
141144
"\n",
142145
" actions.append(action)\n",
143146
"\n",
144147
" obs, rews, dones, info = env.step(actions)\n",
145148
"\n",
146149
" if render:\n",
147150
" frame = env.render(\n",
148-
" mode=\"rgb_array\" if save_render else \"human\",\n",
151+
" mode=\"rgb_array\",\n",
149152
" agent_index_focus=None, # Can give the camera an agent index to focus on\n",
150-
" visualize_when_rgb=True,\n",
151153
" )\n",
152-
" if save_render:\n",
153-
" frame_list.append(frame)\n",
154+
" frame_list.append(frame)\n",
154155
"\n",
155156
" total_time = time.time() - init_time\n",
156157
" print(\n",
157158
" f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n",
158159
" f\"for {scenario_name} scenario.\"\n",
159160
" )\n",
160161
"\n",
161-
" if render and save_render:\n",
162+
" if render:\n",
162163
" from moviepy.editor import ImageSequenceClip\n",
163164
" fps=30\n",
164165
" clip = ImageSequenceClip(frame_list, fps=fps)\n",
@@ -177,11 +178,11 @@
177178
"use_vmas_env(\n",
178179
" scenario=scenario_name,\n",
179180
" render=True,\n",
180-
" save_render=True,\n",
181181
" num_envs=32,\n",
182-
" n_steps=150,\n",
182+
" n_steps=100,\n",
183183
" device=\"cuda\",\n",
184-
" continuous_actions=True,\n",
184+
" continuous_actions=False,\n",
185+
" random_action=False,\n",
185186
" # Environment specific variables\n",
186187
" n_agents=4,\n",
187188
")"

vmas/examples/use_vmas_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _get_deterministic_action(agent: Agent, continuous: bool, env):
2020
.unsqueeze(-1)
2121
.expand(env.batch_dim, 1)
2222
)
23-
return action
23+
return action.clone()
2424

2525

2626
def use_vmas_env(

0 commit comments

Comments
 (0)