|
29 | 29 | "id": "0NsC_EwfCF5I" |
30 | 30 | } |
31 | 31 | }, |
32 | | - { |
33 | | - "cell_type": "code", |
34 | | - "execution_count": null, |
35 | | - "metadata": { |
36 | | - "id": "cP9ijqwvIXGd", |
37 | | - "cellView": "form" |
38 | | - }, |
39 | | - "outputs": [], |
40 | | - "source": [ |
41 | | - "#@title\n", |
42 | | - "! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git" |
43 | | - ] |
44 | | - }, |
45 | 32 | { |
46 | 33 | "cell_type": "code", |
47 | 34 | "source": [ |
48 | 35 | "#@title\n", |
| 36 | + "! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git\n", |
49 | 37 | "%cd /content/VectorizedMultiAgentSimulator\n", |
50 | | - "\n", |
51 | | - "!pip install -r requirements.txt\n", |
52 | | - "!apt-get update\n", |
53 | | - "!apt-get install -y x11-utils \n", |
54 | | - "!apt-get install -y xvfb\n", |
55 | | - "!apt-get install -y imagemagick\n", |
56 | 38 | "!pip install -e ." |
57 | 39 | ], |
58 | 40 | "metadata": { |
59 | | - "id": "zjnXLxaOMLuv", |
60 | | - "cellView": "form" |
| 41 | + "id": "zjnXLxaOMLuv" |
61 | 42 | }, |
62 | 43 | "execution_count": null, |
63 | 44 | "outputs": [] |
|
66 | 47 | "cell_type": "code", |
67 | 48 | "source": [ |
68 | 49 | "#@title\n", |
| 50 | + "!sudo apt-get update\n", |
| 51 | + "!sudo apt-get install python3-opengl xvfb\n", |
69 | 52 | "!pip install pyvirtualdisplay\n", |
70 | 53 | "import pyvirtualdisplay\n", |
71 | 54 | "display = pyvirtualdisplay.Display(visible=False, size=(1400, 900))\n", |
72 | 55 | "display.start()" |
73 | 56 | ], |
74 | 57 | "metadata": { |
75 | | - "id": "5wilTW60cNr4", |
76 | | - "cellView": "form" |
| 58 | + "id": "1ZpWFjvHOpZJ" |
77 | 59 | }, |
78 | 60 | "execution_count": null, |
79 | 61 | "outputs": [] |
|
90 | 72 | { |
91 | 73 | "cell_type": "code", |
92 | 74 | "source": [ |
93 | | - "# Copyright (c) 2022.\n", |
94 | | - "# ProrokLab (https://www.proroklab.org/)\n", |
95 | | - "# All rights reserved.\n", |
96 | | - "\n", |
| 75 | + "from vmas.simulator.scenario import BaseScenario\n", |
| 76 | + "from typing import Union\n", |
97 | 77 | "import time\n", |
98 | | - "import random\n", |
99 | 78 | "import torch\n", |
100 | | - "from PIL import Image\n", |
101 | | - "\n", |
102 | 79 | "from vmas import make_env\n", |
103 | 80 | "\n", |
104 | | - "scenario_name = \"waterfall\"\n", |
105 | | - "\n", |
106 | | - "# Scenario specific variables\n", |
107 | | - "n_agents = 4\n", |
| 81 | + "def use_vmas_env(\n", |
| 82 | + " render: bool = False,\n", |
| 83 | + " save_render: bool = False,\n", |
| 84 | + " num_envs: int = 32,\n", |
| 85 | + " n_steps: int = 100,\n", |
| 86 | + " device: str = \"cpu\",\n", |
| 87 | + " scenario: Union[str, BaseScenario]= \"waterfall\",\n", |
| 88 | + " n_agents: int = 4,\n", |
| 89 | + " continuous_actions: bool = True,\n", |
| 90 | + " **kwargs\n", |
| 91 | + "):\n", |
| 92 | + " \"\"\"Example function to use a vmas environment\n", |
| 93 | + "\n", |
| 94 | + " Args:\n", |
| 95 | + " continuous_actions (bool): Whether the agents have continuous or discrete actions\n", |
| 96 | + " n_agents (int): Number of agents\n", |
| 97 | + " scenario (str): Name of scenario\n", |
| 98 | + " device (str): Torch device to use\n", |
| 99 | + " render (bool): Whether to render the scenario\n", |
| 100 | + " save_render (bool): Whether to save render of the scenario\n", |
| 101 | + " num_envs (int): Number of vectorized environments\n", |
| 102 | + " n_steps (int): Number of steps before returning done\n", |
| 103 | + "\n", |
| 104 | + " Returns:\n", |
| 105 | + "\n", |
| 106 | + " \"\"\"\n", |
| 107 | + " assert not (save_render and not render), \"To save the video you have to render it\"\n", |
| 108 | + "\n", |
| 109 | + " simple_2d_action = (\n", |
| 110 | + " [0, -1.0] if continuous_actions else [3]\n", |
| 111 | + " ) # Simple action for an agent with 2d actions\n", |
| 112 | + "\n", |
| 113 | + " scenario_name = scenario if isinstance(scenario,str) else scenario.__class__.__name__\n", |
| 114 | + "\n", |
| 115 | + " env = make_env(\n", |
| 116 | + " scenario=scenario,\n", |
| 117 | + " num_envs=num_envs,\n", |
| 118 | + " device=device,\n", |
| 119 | + " continuous_actions=continuous_actions,\n", |
| 120 | + " wrapper=None,\n", |
| 121 | + " seed=None,\n", |
| 122 | + " # Environment specific variables\n", |
| 123 | + " n_agents=n_agents,\n", |
| 124 | + " **kwargs\n", |
| 125 | + " )\n", |
| 126 | + "\n", |
| 127 | + " frame_list = [] # For creating a gif\n", |
| 128 | + " init_time = time.time()\n", |
| 129 | + " step = 0\n", |
| 130 | + "\n", |
| 131 | + " for s in range(n_steps):\n", |
| 132 | + " step += 1\n", |
| 133 | + " print(f\"Step {step}\")\n", |
| 134 | + "\n", |
| 135 | + " actions = []\n", |
| 136 | + " for i, agent in enumerate(env.agents):\n", |
| 137 | + " action = torch.tensor(\n", |
| 138 | + " simple_2d_action,\n", |
| 139 | + " device=device,\n", |
| 140 | + " ).repeat(num_envs, 1)\n", |
108 | 141 | "\n", |
109 | | - "num_envs = 32 # Number of vectorized environments\n", |
110 | | - "continuous_actions = False\n", |
111 | | - "device = \"cpu\" # or cuda or any other torch device\n", |
112 | | - "n_steps = 100 # Number of steps before returning done\n", |
113 | | - "dict_spaces = True # Weather to return obs, rewards, and infos as dictionaries with agent names (by default they are lists of len # of agents)\n", |
114 | | - "\n", |
115 | | - "simple_2d_action = (\n", |
116 | | - " [0, 0.5] if continuous_actions else [3]\n", |
117 | | - ") # Simple action tell each agent to go down\n", |
118 | | - "\n", |
119 | | - "env = make_env(\n", |
120 | | - " scenario=scenario_name,\n", |
121 | | - " num_envs=num_envs,\n", |
122 | | - " device=device,\n", |
123 | | - " continuous_actions=continuous_actions,\n", |
124 | | - " dict_spaces=dict_spaces,\n", |
125 | | - " wrapper=None,\n", |
126 | | - " seed=None,\n", |
127 | | - " # Environment specific variables\n", |
128 | | - " n_agents=n_agents,\n", |
129 | | - ")\n", |
130 | | - "\n", |
131 | | - "frame_list = [] # For creating a gif\n", |
132 | | - "init_time = time.time()\n", |
133 | | - "step = 0\n", |
134 | | - "\n", |
135 | | - "for s in range(n_steps):\n", |
136 | | - " step += 1\n", |
137 | | - " print(f\"Step {step}\")\n", |
138 | | - "\n", |
139 | | - " # VMAS actions can be either a list of tensors (one per agent)\n", |
140 | | - " # or a dict of tensors (one entry per agent with its name as key)\n", |
141 | | - " # Both action inputs can be used independently of what type of space its chosen\n", |
142 | | - " dict_actions = random.choice([True, False])\n", |
143 | | - "\n", |
144 | | - " actions = {} if dict_actions else []\n", |
145 | | - " for i, agent in enumerate(env.agents):\n", |
146 | | - " action = torch.tensor(\n", |
147 | | - " simple_2d_action,\n", |
148 | | - " device=device,\n", |
149 | | - " ).repeat(num_envs, 1)\n", |
150 | | - " if dict_actions:\n", |
151 | | - " actions.update({agent.name: action})\n", |
152 | | - " else:\n", |
153 | 142 | " actions.append(action)\n", |
154 | 143 | "\n", |
155 | | - " obs, rews, dones, info = env.step(actions)\n", |
156 | | - "\n", |
157 | | - " frame_list.append(\n", |
158 | | - " Image.fromarray(env.render(mode=\"rgb_array\", agent_index_focus=None))\n", |
159 | | - " ) # Can give the camera an agent index to focus on\n", |
160 | | - "\n", |
161 | | - "gif_name = scenario_name + \".gif\"\n", |
162 | | - "\n", |
163 | | - "# Produce a gif\n", |
164 | | - "frame_list[0].save(\n", |
165 | | - " gif_name,\n", |
166 | | - " save_all=True,\n", |
167 | | - " append_images=frame_list[1:],\n", |
168 | | - " duration=3,\n", |
169 | | - " loop=0,\n", |
170 | | - ")\n", |
171 | | - "\n", |
172 | | - "\n", |
173 | | - "total_time = time.time() - init_time\n", |
174 | | - "print(\n", |
175 | | - " f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n", |
176 | | - " f\"for {scenario_name} scenario.\"\n", |
177 | | - ")" |
| 144 | + " obs, rews, dones, info = env.step(actions)\n", |
| 145 | + "\n", |
| 146 | + " if render:\n", |
| 147 | + " frame = env.render(\n", |
| 148 | + " mode=\"rgb_array\" if save_render else \"human\",\n", |
| 149 | + " agent_index_focus=None, # Can give the camera an agent index to focus on\n", |
| 150 | + " visualize_when_rgb=True,\n", |
| 151 | + " )\n", |
| 152 | + " if save_render:\n", |
| 153 | + " frame_list.append(frame)\n", |
| 154 | + "\n", |
| 155 | + " total_time = time.time() - init_time\n", |
| 156 | + " print(\n", |
| 157 | + " f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n", |
| 158 | + " f\"for {scenario_name} scenario.\"\n", |
| 159 | + " )\n", |
| 160 | + "\n", |
| 161 | + " if render and save_render:\n", |
| 162 | + " from moviepy.editor import ImageSequenceClip\n", |
| 163 | + " fps=30\n", |
| 164 | + " clip = ImageSequenceClip(frame_list, fps=fps)\n", |
| 165 | + " clip.write_gif(f'{scenario_name}.gif', fps=fps)" |
178 | 166 | ], |
179 | 167 | "metadata": { |
180 | 168 | "id": "2Ol4AFeRQ3Ma" |
|
185 | 173 | { |
186 | 174 | "cell_type": "code", |
187 | 175 | "source": [ |
188 | | - "from IPython.display import Image\n", |
189 | | - "Image(open(f'{scenario_name}.gif','rb').read())" |
| 176 | + "scenario_name=\"waterfall\"\n", |
| 177 | + "use_vmas_env(\n", |
| 178 | + " scenario=scenario_name,\n", |
| 179 | + " render=True,\n", |
| 180 | + " save_render=True,\n", |
| 181 | + " num_envs=32,\n", |
| 182 | + " n_steps=150,\n", |
| 183 | + " device=\"cuda\",\n", |
| 184 | + " continuous_actions=True,\n", |
| 185 | + " # Environment specific variables\n", |
| 186 | + " n_agents=4,\n", |
| 187 | + ")" |
190 | 188 | ], |
191 | 189 | "metadata": { |
192 | | - "id": "UPRa91hMPU1n" |
| 190 | + "id": "3cskWki-O8Ul" |
193 | 191 | }, |
194 | 192 | "execution_count": null, |
195 | 193 | "outputs": [] |
196 | 194 | }, |
197 | 195 | { |
198 | 196 | "cell_type": "code", |
199 | 197 | "source": [ |
200 | | - "import os\n", |
201 | | - "# Requires imagemagick to be installed to convert the gif in faster format\n", |
202 | | - "os.system(f\"convert -delay 1x30 -loop 0 {gif_name} {scenario_name}_fast.gif\")\n", |
203 | 198 | "from IPython.display import Image\n", |
204 | | - "Image(open(f'{scenario_name}_fast.gif','rb').read())" |
| 199 | + "Image(f'{scenario_name}.gif')" |
205 | 200 | ], |
206 | 201 | "metadata": { |
207 | | - "id": "BohliLebMOJB" |
| 202 | + "id": "UPRa91hMPU1n" |
208 | 203 | }, |
209 | 204 | "execution_count": null, |
210 | 205 | "outputs": [] |
|
0 commit comments