|
77 | 77 | "import time\n", |
78 | 78 | "import torch\n", |
79 | 79 | "from vmas import make_env\n", |
| 80 | + "from vmas.simulator.core import Agent\n", |
| 81 | + "\n", |
| 82 | + "def _get_deterministic_action(agent: Agent, continuous: bool, env):\n", |
| 83 | + " if continuous:\n", |
| 84 | + " action = -agent.action.u_range_tensor.expand(env.batch_dim, agent.action_size)\n", |
| 85 | + " else:\n", |
| 86 | + " action = (\n", |
| 87 | + " torch.tensor([1], device=env.device, dtype=torch.long)\n", |
| 88 | + " .unsqueeze(-1)\n", |
| 89 | + " .expand(env.batch_dim, 1)\n", |
| 90 | + " )\n", |
| 91 | + " return action.clone()\n", |
80 | 92 | "\n", |
81 | 93 | "def use_vmas_env(\n", |
82 | | - " render: bool = False,\n", |
83 | | - " save_render: bool = False,\n", |
84 | | - " num_envs: int = 32,\n", |
85 | | - " n_steps: int = 100,\n", |
86 | | - " device: str = \"cpu\",\n", |
87 | | - " scenario: Union[str, BaseScenario]= \"waterfall\",\n", |
88 | | - " n_agents: int = 4,\n", |
89 | | - " continuous_actions: bool = True,\n", |
| 94 | + " render: bool,\n", |
| 95 | + " num_envs: int,\n", |
| 96 | + " n_steps: int,\n", |
| 97 | + " device: str,\n", |
| 98 | + " scenario: Union[str, BaseScenario],\n", |
| 99 | + " continuous_actions: bool,\n", |
| 100 | + " random_action: bool,\n", |
90 | 101 | " **kwargs\n", |
91 | 102 | "):\n", |
92 | | - " \"\"\"Example function to use a vmas environment\n", |
| 103 | + " \"\"\"Example function to use a vmas environment.\n", |
| 104 | + " \n", |
| 105 | + " This is a simplification of the function in `vmas.examples.use_vmas_env.py`.\n", |
93 | 106 | "\n", |
94 | 107 | " Args:\n", |
95 | 108 | " continuous_actions (bool): Whether the agents have continuous or discrete actions\n", |
96 | | - " n_agents (int): Number of agents\n", |
97 | 109 | " scenario (str): Name of scenario\n", |
98 | 110 | " device (str): Torch device to use\n", |
99 | 111 | " render (bool): Whether to render the scenario\n", |
100 | | - " save_render (bool): Whether to save render of the scenario\n", |
101 | 112 | " num_envs (int): Number of vectorized environments\n", |
102 | 113 | " n_steps (int): Number of steps before returning done\n", |
103 | | - "\n", |
104 | | - " Returns:\n", |
| 114 | + " random_action (bool): Use random actions or have all agents perform the down action\n", |
105 | 115 | "\n", |
106 | 116 | " \"\"\"\n", |
107 | | - " assert not (save_render and not render), \"To save the video you have to render it\"\n", |
108 | | - "\n", |
109 | | - " simple_2d_action = (\n", |
110 | | - " [0, -1.0] if continuous_actions else [3]\n", |
111 | | - " ) # Simple action for an agent with 2d actions\n", |
112 | 117 | "\n", |
113 | 118 | " scenario_name = scenario if isinstance(scenario,str) else scenario.__class__.__name__\n", |
114 | 119 | "\n", |
|
117 | 122 | " num_envs=num_envs,\n", |
118 | 123 | " device=device,\n", |
119 | 124 | " continuous_actions=continuous_actions,\n", |
120 | | - " wrapper=None,\n", |
121 | | - " seed=None,\n", |
| 125 | + " seed=0,\n", |
122 | 126 | " # Environment specific variables\n", |
123 | | - " n_agents=n_agents,\n", |
124 | 127 | " **kwargs\n", |
125 | 128 | " )\n", |
126 | 129 | "\n", |
|
134 | 137 | "\n", |
135 | 138 | " actions = []\n", |
136 | 139 | " for i, agent in enumerate(env.agents):\n", |
137 | | - " action = torch.tensor(\n", |
138 | | - " simple_2d_action,\n", |
139 | | - " device=device,\n", |
140 | | - " ).repeat(num_envs, 1)\n", |
| 140 | + " if not random_action:\n", |
| 141 | + " action = _get_deterministic_action(agent, continuous_actions, env)\n", |
| 142 | + " else:\n", |
| 143 | + " action = env.get_random_action(agent)\n", |
141 | 144 | "\n", |
142 | 145 | " actions.append(action)\n", |
143 | 146 | "\n", |
144 | 147 | " obs, rews, dones, info = env.step(actions)\n", |
145 | 148 | "\n", |
146 | 149 | " if render:\n", |
147 | 150 | " frame = env.render(\n", |
148 | | - " mode=\"rgb_array\" if save_render else \"human\",\n", |
| 151 | + " mode=\"rgb_array\",\n", |
149 | 152 | " agent_index_focus=None, # Can give the camera an agent index to focus on\n", |
150 | | - " visualize_when_rgb=True,\n", |
151 | 153 | " )\n", |
152 | | - " if save_render:\n", |
153 | | - " frame_list.append(frame)\n", |
| 154 | + " frame_list.append(frame)\n", |
154 | 155 | "\n", |
155 | 156 | " total_time = time.time() - init_time\n", |
156 | 157 | " print(\n", |
157 | 158 | " f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n", |
158 | 159 | " f\"for {scenario_name} scenario.\"\n", |
159 | 160 | " )\n", |
160 | 161 | "\n", |
161 | | - " if render and save_render:\n", |
| 162 | + " if render:\n", |
162 | 163 | " from moviepy.editor import ImageSequenceClip\n", |
163 | 164 | " fps=30\n", |
164 | 165 | " clip = ImageSequenceClip(frame_list, fps=fps)\n", |
|
177 | 178 | "use_vmas_env(\n", |
178 | 179 | " scenario=scenario_name,\n", |
179 | 180 | " render=True,\n", |
180 | | - " save_render=True,\n", |
181 | 181 | " num_envs=32,\n", |
182 | | - " n_steps=150,\n", |
| 182 | + " n_steps=100,\n", |
183 | 183 | " device=\"cuda\",\n", |
184 | | - " continuous_actions=True,\n", |
| 184 | + " continuous_actions=False,\n", |
| 185 | + " random_action=False,\n", |
185 | 186 | " # Environment specific variables\n", |
186 | 187 | " n_agents=4,\n", |
187 | 188 | ")" |
|
0 commit comments