Skip to content

Commit bd4e106

Browse files
committed
[Notebook] Update example notebook
1 parent b1112e9 commit bd4e106

1 file changed

Lines changed: 104 additions & 109 deletions

File tree

notebooks/VMAS_Use_vmas_environment.ipynb

Lines changed: 104 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -29,35 +29,16 @@
2929
"id": "0NsC_EwfCF5I"
3030
}
3131
},
32-
{
33-
"cell_type": "code",
34-
"execution_count": null,
35-
"metadata": {
36-
"id": "cP9ijqwvIXGd",
37-
"cellView": "form"
38-
},
39-
"outputs": [],
40-
"source": [
41-
"#@title\n",
42-
"! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git"
43-
]
44-
},
4532
{
4633
"cell_type": "code",
4734
"source": [
4835
"#@title\n",
36+
"! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git\n",
4937
"%cd /content/VectorizedMultiAgentSimulator\n",
50-
"\n",
51-
"!pip install -r requirements.txt\n",
52-
"!apt-get update\n",
53-
"!apt-get install -y x11-utils \n",
54-
"!apt-get install -y xvfb\n",
55-
"!apt-get install -y imagemagick\n",
5638
"!pip install -e ."
5739
],
5840
"metadata": {
59-
"id": "zjnXLxaOMLuv",
60-
"cellView": "form"
41+
"id": "zjnXLxaOMLuv"
6142
},
6243
"execution_count": null,
6344
"outputs": []
@@ -66,14 +47,15 @@
6647
"cell_type": "code",
6748
"source": [
6849
"#@title\n",
50+
"!sudo apt-get update\n",
51+
"!sudo apt-get install python3-opengl xvfb\n",
6952
"!pip install pyvirtualdisplay\n",
7053
"import pyvirtualdisplay\n",
7154
"display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))\n",
7255
"display.start()"
7356
],
7457
"metadata": {
75-
"id": "5wilTW60cNr4",
76-
"cellView": "form"
58+
"id": "1ZpWFjvHOpZJ"
7759
},
7860
"execution_count": null,
7961
"outputs": []
@@ -90,91 +72,97 @@
9072
{
9173
"cell_type": "code",
9274
"source": [
93-
"# Copyright (c) 2022.\n",
94-
"# ProrokLab (https://www.proroklab.org/)\n",
95-
"# All rights reserved.\n",
96-
"\n",
75+
"from vmas.simulator.scenario import BaseScenario\n",
76+
"from typing import Union\n",
9777
"import time\n",
98-
"import random\n",
9978
"import torch\n",
100-
"from PIL import Image\n",
101-
"\n",
10279
"from vmas import make_env\n",
10380
"\n",
104-
"scenario_name = \"waterfall\"\n",
105-
"\n",
106-
"# Scenario specific variables\n",
107-
"n_agents = 4\n",
81+
"def use_vmas_env(\n",
82+
" render: bool = False,\n",
83+
" save_render: bool = False,\n",
84+
" num_envs: int = 32,\n",
85+
" n_steps: int = 100,\n",
86+
" device: str = \"cpu\",\n",
87+
" scenario: Union[str, BaseScenario]= \"waterfall\",\n",
88+
" n_agents: int = 4,\n",
89+
" continuous_actions: bool = True,\n",
90+
" **kwargs\n",
91+
"):\n",
92+
" \"\"\"Example function to use a vmas environment\n",
93+
"\n",
94+
" Args:\n",
95+
" continuous_actions (bool): Whether the agents have continuous or discrete actions\n",
96+
" n_agents (int): Number of agents\n",
97+
" scenario (str): Name of scenario\n",
98+
" device (str): Torch device to use\n",
99+
" render (bool): Whether to render the scenario\n",
100+
" save_render (bool): Whether to save render of the scenario\n",
101+
" num_envs (int): Number of vectorized environments\n",
102+
" n_steps (int): Number of steps before returning done\n",
103+
"\n",
104+
" Returns:\n",
105+
"\n",
106+
" \"\"\"\n",
107+
" assert not (save_render and not render), \"To save the video you have to render it\"\n",
108+
"\n",
109+
" simple_2d_action = (\n",
110+
" [0, -1.0] if continuous_actions else [3]\n",
111+
" ) # Simple action for an agent with 2d actions\n",
112+
"\n",
113+
" scenario_name = scenario if isinstance(scenario,str) else scenario.__class__.__name__\n",
114+
"\n",
115+
" env = make_env(\n",
116+
" scenario=scenario,\n",
117+
" num_envs=num_envs,\n",
118+
" device=device,\n",
119+
" continuous_actions=continuous_actions,\n",
120+
" wrapper=None,\n",
121+
" seed=None,\n",
122+
" # Environment specific variables\n",
123+
" n_agents=n_agents,\n",
124+
" **kwargs\n",
125+
" )\n",
126+
"\n",
127+
" frame_list = [] # For creating a gif\n",
128+
" init_time = time.time()\n",
129+
" step = 0\n",
130+
"\n",
131+
" for s in range(n_steps):\n",
132+
" step += 1\n",
133+
" print(f\"Step {step}\")\n",
134+
"\n",
135+
" actions = []\n",
136+
" for i, agent in enumerate(env.agents):\n",
137+
" action = torch.tensor(\n",
138+
" simple_2d_action,\n",
139+
" device=device,\n",
140+
" ).repeat(num_envs, 1)\n",
108141
"\n",
109-
"num_envs = 32 # Number of vectorized environments\n",
110-
"continuous_actions = False\n",
111-
"device = \"cpu\" # or cuda or any other torch device\n",
112-
"n_steps = 100 # Number of steps before returning done\n",
113-
"dict_spaces = True # Whether to return obs, rewards, and infos as dictionaries with agent names (by default they are lists of len # of agents)\n",
114-
"\n",
115-
"simple_2d_action = (\n",
116-
" [0, 0.5] if continuous_actions else [3]\n",
117-
") # Simple action telling each agent to go down\n",
118-
"\n",
119-
"env = make_env(\n",
120-
" scenario=scenario_name,\n",
121-
" num_envs=num_envs,\n",
122-
" device=device,\n",
123-
" continuous_actions=continuous_actions,\n",
124-
" dict_spaces=dict_spaces,\n",
125-
" wrapper=None,\n",
126-
" seed=None,\n",
127-
" # Environment specific variables\n",
128-
" n_agents=n_agents,\n",
129-
")\n",
130-
"\n",
131-
"frame_list = [] # For creating a gif\n",
132-
"init_time = time.time()\n",
133-
"step = 0\n",
134-
"\n",
135-
"for s in range(n_steps):\n",
136-
" step += 1\n",
137-
" print(f\"Step {step}\")\n",
138-
"\n",
139-
" # VMAS actions can be either a list of tensors (one per agent)\n",
140-
" # or a dict of tensors (one entry per agent with its name as key)\n",
141-
" # Both action inputs can be used independently of what type of space its chosen\n",
142-
" dict_actions = random.choice([True, False])\n",
143-
"\n",
144-
" actions = {} if dict_actions else []\n",
145-
" for i, agent in enumerate(env.agents):\n",
146-
" action = torch.tensor(\n",
147-
" simple_2d_action,\n",
148-
" device=device,\n",
149-
" ).repeat(num_envs, 1)\n",
150-
" if dict_actions:\n",
151-
" actions.update({agent.name: action})\n",
152-
" else:\n",
153142
" actions.append(action)\n",
154143
"\n",
155-
" obs, rews, dones, info = env.step(actions)\n",
156-
"\n",
157-
" frame_list.append(\n",
158-
" Image.fromarray(env.render(mode=\"rgb_array\", agent_index_focus=None))\n",
159-
" ) # Can give the camera an agent index to focus on\n",
160-
"\n",
161-
"gif_name = scenario_name + \".gif\"\n",
162-
"\n",
163-
"# Produce a gif\n",
164-
"frame_list[0].save(\n",
165-
" gif_name,\n",
166-
" save_all=True,\n",
167-
" append_images=frame_list[1:],\n",
168-
" duration=3,\n",
169-
" loop=0,\n",
170-
")\n",
171-
"\n",
172-
"\n",
173-
"total_time = time.time() - init_time\n",
174-
"print(\n",
175-
" f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n",
176-
" f\"for {scenario_name} scenario.\"\n",
177-
")"
144+
" obs, rews, dones, info = env.step(actions)\n",
145+
"\n",
146+
" if render:\n",
147+
" frame = env.render(\n",
148+
" mode=\"rgb_array\" if save_render else \"human\",\n",
149+
" agent_index_focus=None, # Can give the camera an agent index to focus on\n",
150+
" visualize_when_rgb=True,\n",
151+
" )\n",
152+
" if save_render:\n",
153+
" frame_list.append(frame)\n",
154+
"\n",
155+
" total_time = time.time() - init_time\n",
156+
" print(\n",
157+
" f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n",
158+
" f\"for {scenario_name} scenario.\"\n",
159+
" )\n",
160+
"\n",
161+
" if render and save_render:\n",
162+
" from moviepy.editor import ImageSequenceClip\n",
163+
" fps=30\n",
164+
" clip = ImageSequenceClip(frame_list, fps=fps)\n",
165+
" clip.write_gif(f'{scenario_name}.gif', fps=fps)"
178166
],
179167
"metadata": {
180168
"id": "2Ol4AFeRQ3Ma"
@@ -185,26 +173,33 @@
185173
{
186174
"cell_type": "code",
187175
"source": [
188-
"from IPython.display import Image\n",
189-
"Image(open(f'{scenario_name}.gif','rb').read())"
176+
"scenario_name=\"waterfall\"\n",
177+
"use_vmas_env(\n",
178+
" scenario=scenario_name,\n",
179+
" render=True,\n",
180+
" save_render=True,\n",
181+
" num_envs=32,\n",
182+
" n_steps=150,\n",
183+
" device=\"cuda\",\n",
184+
" continuous_actions=True,\n",
185+
" # Environment specific variables\n",
186+
" n_agents=4,\n",
187+
")"
190188
],
191189
"metadata": {
192-
"id": "UPRa91hMPU1n"
190+
"id": "3cskWki-O8Ul"
193191
},
194192
"execution_count": null,
195193
"outputs": []
196194
},
197195
{
198196
"cell_type": "code",
199197
"source": [
200-
"import os\n",
201-
"# Requires imagemagick to be installed to convert the gif in faster format\n",
202-
"os.system(f\"convert -delay 1x30 -loop 0 {gif_name} {scenario_name}_fast.gif\")\n",
203198
"from IPython.display import Image\n",
204-
"Image(open(f'{scenario_name}_fast.gif','rb').read())"
199+
"Image(f'{scenario_name}.gif')"
205200
],
206201
"metadata": {
207-
"id": "BohliLebMOJB"
202+
"id": "UPRa91hMPU1n"
208203
},
209204
"execution_count": null,
210205
"outputs": []

0 commit comments

Comments
 (0)