Skip to content

Commit 7468cd6

Browse files
authored
Merge pull request #273 from huangshiyu13/main
add MAT network test
2 parents 6cd773a + f1ecdef commit 7468cd6

1 file changed

Lines changed: 59 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import numpy as np
23+
import pytest
24+
from gymnasium import spaces
25+
26+
from openrl.configs.config import create_config_parser
27+
from openrl.modules.networks.MAT_network import MultiAgentTransformer
28+
29+
30+
@pytest.fixture(scope="module", params=[""])
31+
def config(request):
32+
cfg_parser = create_config_parser()
33+
cfg = cfg_parser.parse_args(request.param.split())
34+
return cfg
35+
36+
37+
@pytest.mark.unittest
38+
def test_MAT_network(config):
39+
net = MultiAgentTransformer(
40+
config,
41+
input_space=spaces.Discrete(2),
42+
action_space=spaces.Discrete(2),
43+
)
44+
net.get_actor_para()
45+
net.get_critic_para()
46+
47+
obs = np.zeros([1, 2])
48+
rnn_states = np.zeros(2)
49+
masks = np.zeros(2)
50+
action = np.zeros(1)
51+
net.get_actions(obs=obs, masks=masks)
52+
net.eval_actions(
53+
obs=obs, rnn_states=rnn_states, action=action, masks=masks, action_masks=None
54+
)
55+
net.get_values(critic_obs=obs, masks=masks)
56+
57+
58+
if __name__ == "__main__":
59+
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 commit comments

Comments
 (0)