@@ -36,7 +36,9 @@ def test_all_scenarios_included():
3636
3737@pytest .mark .parametrize ("scenario" , scenario_names ())
3838@pytest .mark .parametrize ("continuous_actions" , [True , False ])
39- def test_use_vmas_env (scenario , continuous_actions , num_envs = 10 , n_steps = 10 ):
39+ def test_use_vmas_env (
40+ scenario , continuous_actions , dict_spaces = True , num_envs = 10 , n_steps = 10
41+ ):
4042 render = True
4143 if sys .platform .startswith ("win32" ):
4244 # Windows on github servers has issues with pyglet
@@ -51,9 +53,36 @@ def test_use_vmas_env(scenario, continuous_actions, num_envs=10, n_steps=10):
5153 continuous_actions = continuous_actions ,
5254 num_envs = num_envs ,
5355 n_steps = n_steps ,
56+ dict_spaces = dict_spaces ,
5457 )
5558
5659
60+ @pytest .mark .parametrize ("scenario" , scenario_names ())
61+ def test_multi_discrete_actions (scenario , num_envs = 10 , n_steps = 10 ):
62+ env = make_env (
63+ scenario = scenario ,
64+ num_envs = num_envs ,
65+ seed = 0 ,
66+ multidiscrete_actions = True ,
67+ continuous_actions = False ,
68+ )
69+ for _ in range (n_steps ):
70+ env .step (env .get_random_actions ())
71+
72+
73+ @pytest .mark .parametrize ("scenario" , scenario_names ())
74+ def test_non_dict_spaces_actions (scenario , num_envs = 10 , n_steps = 10 ):
75+ env = make_env (
76+ scenario = scenario ,
77+ num_envs = num_envs ,
78+ seed = 0 ,
79+ continuous_actions = True ,
80+ dict_spaces = False ,
81+ )
82+ for _ in range (n_steps ):
83+ env .step (env .get_random_actions ())
84+
85+
5786@pytest .mark .parametrize ("scenario" , scenario_names ())
5887def test_partial_reset (scenario , num_envs = 10 , n_steps = 10 ):
5988 env = make_env (
@@ -70,6 +99,19 @@ def test_partial_reset(scenario, num_envs=10, n_steps=10):
7099 env_index = 0
71100
72101
102+ @pytest .mark .parametrize ("scenario" , scenario_names ())
103+ def test_global_reset (scenario , num_envs = 10 , n_steps = 10 ):
104+ env = make_env (
105+ scenario = scenario ,
106+ num_envs = num_envs ,
107+ seed = 0 ,
108+ )
109+ for step in range (n_steps ):
110+ env .step (env .get_random_actions ())
111+ if step == n_steps // 2 :
112+ env .reset ()
113+
114+
73115@pytest .mark .parametrize ("scenario" , vmas .scenarios + vmas .mpe_scenarios )
74116def test_vmas_differentiable (scenario , n_steps = 10 , n_envs = 10 ):
75117 if scenario == "football" or scenario == "simple_crypto" :
0 commit comments