@@ -1,55 +1,60 @@
-# Copyright (c) 2022-2023.
+# Copyright (c) 2022-2024.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
-import unittest
+
+import pytest
+import torch
 
 from vmas import make_env
 from vmas.scenarios import balance
 
 
-class TestBalance(unittest.TestCase):
+class TestBalance:
     def setup_env(
         self,
+        n_envs,
         **kwargs,
     ) -> None:
-        super().setUp()
         self.n_agents = kwargs.get("n_agents", 4)
 
         self.continuous_actions = True
-        self.n_envs = 15
         self.env = make_env(
             scenario="balance",
-            num_envs=self.n_envs,
+            num_envs=n_envs,
             device="cpu",
             continuous_actions=self.continuous_actions,
             # Environment specific variables
             **kwargs,
         )
         self.env.seed(0)
 
-    def test_heuristic(self):
+    @pytest.mark.parametrize("n_agents", [2, 5])
+    def test_heuristic(self, n_agents, n_steps=50, n_envs=4):
+        self.setup_env(
+            n_agents=n_agents, random_package_pos_on_line=False, n_envs=n_envs
+        )
+        policy = balance.HeuristicPolicy(self.continuous_actions)
 
-        for n_agents in [2, 5, 6, 10]:
-            self.setup_env(n_agents=n_agents, random_package_pos_on_line=False)
-            policy = balance.HeuristicPolicy(self.continuous_actions)
+        obs = self.env.reset()
 
-            obs = self.env.reset()
-            rews = None
+        prev_package_dist_to_goal = obs[0][:, 8:10]
 
-            for _ in range(100):
-                actions = []
-                for i in range(n_agents):
-                    obs_agent = obs[i]
+        for _ in range(n_steps):
+            actions = []
+            for i in range(n_agents):
+                obs_agent = obs[i]
+                package_dist_to_goal = obs_agent[:, 8:10]
 
-                    action_agent = policy.compute_action(
-                        obs_agent, self.env.agents[i].u_range
-                    )
+                action_agent = policy.compute_action(
+                    obs_agent, self.env.agents[i].u_range
+                )
 
-                    actions.append(action_agent)
+                actions.append(action_agent)
 
-                obs, new_rews, dones, _ = self.env.step(actions)
+            obs, new_rews, dones, _ = self.env.step(actions)
 
-                if rews is not None:
-                    for i in range(self.n_agents):
-                        self.assertTrue((new_rews[i] >= rews[i]).all())
-                rews = new_rews
+            assert (
+                torch.linalg.vector_norm(package_dist_to_goal, dim=-1)
+                <= torch.linalg.vector_norm(prev_package_dist_to_goal, dim=-1)
+            ).all()
+            prev_package_dist_to_goal = package_dist_to_goal
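For reference, a minimal sketch of how the migrated test is collected and run; the `tests/test_balance.py` path and the direct-call usage below are illustrative assumptions, not part of the commit:

```python
# Minimal sketch, not part of the commit above. pytest collects plain
# `Test*` classes even without a `unittest.TestCase` base, so dropping the
# base class keeps discovery working, and the parametrize marker expands
# test_heuristic into one case per value:
#
#   pytest tests/test_balance.py -k test_heuristic
#   -> test_heuristic[2], test_heuristic[5]
#
# The tests/test_balance.py path is hypothetical. Since the test method is
# now a plain function, it can also be driven by hand for debugging:
from tests.test_balance import TestBalance  # hypothetical import path

t = TestBalance()
t.test_heuristic(n_agents=2)  # n_steps=50, n_envs=4 by default
```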