|
3 | 3 | # All rights reserved. |
4 | 4 | import random |
5 | 5 | from ctypes import byref |
6 | | -from typing import Callable, Dict, List, Optional, Tuple, Union |
| 6 | +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | 9 | import torch |
@@ -197,6 +197,21 @@ def step(self, actions: Union[List, Dict]): |
197 | 197 | dones: Tensor of len 'self.num_envs' of which each element is a bool |
198 | 198 | infos : List of len 'self.n_agents' of which each element is a dictionary for which each key is a metric
199 | 199 | and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)' |
| 200 | +
|
| 201 | + Examples: |
| 202 | + >>> import vmas |
| 203 | + >>> env = vmas.make_env( |
| 204 | + ... scenario="waterfall", # can be scenario name or BaseScenario class |
| 205 | + ... num_envs=32, |
| 206 | + ... device="cpu", # Or "cuda" for GPU |
| 207 | + ... continuous_actions=True, |
| 208 | + ... max_steps=None, # Defines the horizon. None is infinite horizon. |
| 209 | + ... seed=None, # Seed of the environment |
| 210 | + ... n_agents=3, # Additional arguments you want to pass to the scenario |
| 211 | + ... ) |
| 212 | + >>> obs = env.reset() |
| 213 | + >>> for _ in range(10): |
| 214 | + ... obs, rews, dones, info = env.step(env.get_random_actions()) |
200 | 215 | """ |
201 | 216 | if isinstance(actions, Dict): |
202 | 217 | actions_dict = actions |
@@ -352,6 +367,76 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE): |
352 | 367 | f"Invalid type of observation {obs} for agent {agent.name}" |
353 | 368 | ) |
354 | 369 |
|
| 370 | + def get_random_action(self, agent: Agent) -> torch.Tensor: |
| 371 | + """Returns a random action for the given agent. |
| 372 | +
|
| 373 | + Args: |
| 374 | + agent (Agent): The agent to get the action for |
| 375 | +
|
| 376 | + Returns: |
| 377 | + torch.Tensor: the random action tensor with shape ``(agent.batch_dim, agent.action_size)``
| 378 | +
|
| 379 | + """ |
| 380 | + if self.continuous_actions: |
| 381 | + actions = [] |
| 382 | + for action_index in range(agent.action_size): |
| 383 | + actions.append( |
| 384 | + torch.zeros( |
| 385 | + agent.batch_dim, |
| 386 | + device=agent.device, |
| 387 | + dtype=torch.float32, |
| 388 | + ).uniform_( |
| 389 | + -agent.action.u_range_tensor[action_index], |
| 390 | + agent.action.u_range_tensor[action_index], |
| 391 | + ) |
| 392 | + ) |
| 393 | + if self.world.dim_c != 0 and not agent.silent: |
| 394 | + # If the agent needs to communicate |
| 395 | + for _ in range(self.world.dim_c): |
| 396 | + actions.append( |
| 397 | + torch.zeros( |
| 398 | + agent.batch_dim, |
| 399 | + device=agent.device, |
| 400 | + dtype=torch.float32, |
| 401 | + ).uniform_( |
| 402 | + 0, |
| 403 | + 1, |
| 404 | + ) |
| 405 | + ) |
| 406 | + action = torch.stack(actions, dim=-1) |
| 407 | + else: |
| 408 | + action = torch.randint( |
| 409 | + low=0, |
| 410 | + high=self.get_agent_action_space(agent).n, |
| 411 | + size=(agent.batch_dim,), |
| 412 | + device=agent.device, |
| 413 | + ) |
| 414 | + return action |
| 415 | + |
| 416 | + def get_random_actions(self) -> Sequence[torch.Tensor]: |
| 417 | + """Returns random actions for all agents that you can feed to :meth:`step`
| 418 | +
|
| 419 | + Returns: |
| 420 | + Sequence[torch.Tensor]: the random actions for the agents
| 421 | +
|
| 422 | + Examples: |
| 423 | + >>> import vmas |
| 424 | + >>> env = vmas.make_env( |
| 425 | + ... scenario="waterfall", # can be scenario name or BaseScenario class |
| 426 | + ... num_envs=32, |
| 427 | + ... device="cpu", # Or "cuda" for GPU |
| 428 | + ... continuous_actions=True, |
| 429 | + ... max_steps=None, # Defines the horizon. None is infinite horizon. |
| 430 | + ... seed=None, # Seed of the environment |
| 431 | + ... n_agents=3, # Additional arguments you want to pass to the scenario |
| 432 | + ... ) |
| 433 | + >>> obs = env.reset() |
| 434 | + >>> for _ in range(10): |
| 435 | + ... obs, rews, dones, info = env.step(env.get_random_actions()) |
| 436 | +
|
| 437 | + """ |
| 438 | + return [self.get_random_action(agent) for agent in self.agents] |
| 439 | + |
355 | 440 | def _check_discrete_action(self, action: Tensor, low: int, high: int, type: str): |
356 | 441 | assert torch.all( |
357 | 442 | (action >= torch.tensor(low, device=self.device)) |
|
0 commit comments