Skip to content

Commit b41c141

Browse files
committed
[Feature] Optionally clamp input actions
1 parent b434185 commit b41c141

2 files changed

Lines changed: 9 additions & 1 deletion

File tree

vmas/make_env.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def make_env(
2323
seed: Optional[int] = None,
2424
dict_spaces: bool = False,
2525
multidiscrete_actions: bool = False,
26+
clamp_actions: bool = False,
2627
**kwargs,
2728
):
2829
"""
@@ -40,6 +41,8 @@ def make_env(
4041
multidiscrete_actions (bool): Whether to use multidiscrete_actions action spaces when continuous_actions=False.
4142
Otherwise, (default) the action space will be Discrete, and it will be the cartesian product of the
4243
action spaces of an agent.
44+
clamp_actions: Weather to clamp input actions to the range instead of throwing
45+
an error when continuous_actions is True and actions are out of bounds
4346
**kwargs ():
4447
4548
Returns:
@@ -60,6 +63,7 @@ def make_env(
6063
seed=seed,
6164
dict_spaces=dict_spaces,
6265
multidiscrete_actions=multidiscrete_actions,
66+
clamp_actions=clamp_actions,
6367
**kwargs,
6468
)
6569

vmas/simulator/environment/environment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import torch
1010
from gym import spaces
1111
from torch import Tensor
12-
1312
from vmas.simulator.core import Agent, TorchVectorizedObject
1413
from vmas.simulator.scenario import BaseScenario
1514
import vmas.simulator.utils
@@ -43,6 +42,7 @@ def __init__(
4342
seed: Optional[int] = None,
4443
dict_spaces: bool = False,
4544
multidiscrete_actions: bool = False,
45+
clamp_actions: bool = False,
4646
**kwargs,
4747
):
4848
if multidiscrete_actions:
@@ -60,6 +60,7 @@ def __init__(
6060
self.max_steps = max_steps
6161
self.continuous_actions = continuous_actions
6262
self.dict_spaces = dict_spaces
63+
self.clamp_action = clamp_actions
6364

6465
self.reset(seed=seed)
6566

@@ -379,6 +380,9 @@ def _set_action(self, action, agent):
379380
f"Agent {agent.name} has wrong action size, got {action.shape[1]}, "
380381
f"expected {self.get_agent_action_size(agent)}"
381382
)
383+
if self.clamp_action and self.continuous_actions:
384+
action = action.clamp(-agent.action.u_range, agent.action.u_range)
385+
382386
action_index = 0
383387

384388
if self.continuous_actions:

0 commit comments

Comments
 (0)