Skip to content

Commit e8c6ee9

Browse files
authored
Merge pull request #276 from huangshiyu13/main
- add net, gail test
2 parents 91ac0df + 0c2ea2b commit e8c6ee9

7 files changed

Lines changed: 319 additions & 1 deletion

File tree

openrl/envs/common/build_envs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
from typing import Callable, Iterable, List, Optional, Union
44

5+
import gymnasium as gym
56
from gymnasium import Env
67

78
from openrl.envs.wrappers.base_wrapper import BaseWrapper
@@ -33,7 +34,7 @@ def _make_env() -> Env:
3334
if need_env_id:
3435
new_kwargs["env_id"] = env_id
3536
new_kwargs["env_num"] = env_num
36-
if id.startswith("ALE/"):
37+
if id.startswith("ALE/") or id in gym.envs.registry.keys():
3738
new_kwargs.pop("cfg", None)
3839

3940
env = make(

openrl/modules/vdn_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __init__(
6868
device=device,
6969
)
7070
self.cfg = cfg
71+
self.obs_space = input_space
72+
self.act_space = act_space
7173

7274
def lr_decay(self, episode, episodes):
7375
update_linear_schedule(self.optimizers["q_net"], episode, episodes, self.lr)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
""""""
2+
3+
import os
4+
import sys
5+
6+
import pytest
7+
8+
from openrl.configs.config import create_config_parser
9+
from openrl.envs.common import make
10+
from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper
11+
from openrl.envs.wrappers.extra_wrappers import ZeroRewardWrapper
12+
from openrl.envs.wrappers.monitor import Monitor
13+
from openrl.modules.common import GAILNet as Net
14+
from openrl.modules.common import PPONet
15+
from openrl.runners.common import GAILAgent as Agent
16+
from openrl.runners.common import PPOAgent
17+
18+
19+
@pytest.fixture(scope="function")
def gen_data(tmpdir):
    """Roll out a freshly constructed PPO agent and return the data file path.

    Two asynchronous ``CartPole-v1`` environments (wrapped with ``Monitor``)
    are driven by a ``PPOAgent`` built from ``PPONet`` until the
    ``GenDataWrapper`` reports ``done``. The wrapper is configured with
    ``data_save_path`` pointing into pytest's ``tmpdir`` and
    ``total_episode=5``.

    Args:
        tmpdir: pytest's built-in per-test temporary directory fixture.

    Returns:
        The path ``<tmpdir>/data.pkl`` that was handed to ``GenDataWrapper``.
        NOTE(review): presumably the wrapper writes the collected episodes
        there; confirm against ``GenDataWrapper``.
    """
    tmp_data_path = os.path.join(tmpdir, "data.pkl")
    env_wrappers = [
        Monitor,
    ]
    print("generate data....")
    env = make(
        "CartPole-v1",
        env_num=2,
        asynchronous=True,
        env_wrappers=env_wrappers,
    )
    # Close the environment even when the rollout fails, so a failing test
    # does not leave the asynchronous environments open.
    try:
        agent = PPOAgent(PPONet(env))
        env = GenDataWrapper(env, data_save_path=tmp_data_path, total_episode=5)
        obs, info = env.reset()
        done = False
        while not done:
            # Based on environmental observation input, predict next action.
            action, _ = agent.act(obs, deterministic=True)
            obs, r, done, info = env.step(action)
    finally:
        env.close()
    print("generate data done!")
    return tmp_data_path
43+
44+
45+
@pytest.fixture(
    scope="function", params=[" --gail_use_action false", " --gail_use_action true"]
)
def config(request, gen_data):
    """Build the GAIL training configuration for each ``gail_use_action`` value.

    Args:
        request: pytest fixture request; ``request.param`` carries the
            ``--gail_use_action`` flag fragment for this parametrization.
        gen_data: path returned by the ``gen_data`` fixture, passed to the
            parser as ``--expert_data``.

    Returns:
        The namespace produced by ``create_config_parser().parse_args``.
    """
    input_str = (
        "--episode_length 5 --use_recurrent_policy true --use_joint_action_loss true"
        " --use_valuenorm true --use_adv_normalize true --reward_class.id GAILReward"
    )
    input_str += request.param
    # Append the path as its own argv element instead of whitespace-splitting
    # it, so a temporary directory whose path contains spaces stays intact.
    argv = input_str.split() + ["--expert_data", gen_data]
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(argv)
    return cfg
58+
59+
60+
@pytest.mark.unittest
def test_train_gail(config):
    """Train a GAIL agent for 200 steps on two ``CartPole-v1`` environments.

    The environments are wrapped with ``ZeroRewardWrapper``. The test passes
    if construction and ``agent.train`` complete without raising.

    Args:
        config: configuration namespace from the ``config`` fixture.
    """
    env = make("CartPole-v1", env_num=2, cfg=config, env_wrappers=[ZeroRewardWrapper])
    # Always close the environment, including when network construction or
    # training raises.
    try:
        net = Net(
            env,
            cfg=config,
        )
        # initialize the trainer
        agent = Agent(net)
        agent.train(total_time_steps=200)
    finally:
        env.close()
72+
73+
74+
if __name__ == "__main__":
    # Run pytest on this file only ("-sv" flags), then exit with its status.
    this_file = os.path.basename(__file__)
    exit_status = pytest.main(["-sv", this_file])
    sys.exit(exit_status)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import pytest
23+
24+
from openrl.configs.config import create_config_parser
25+
from openrl.envs.common import make
26+
from openrl.envs.wrappers.extra_wrappers import AddStep
27+
from openrl.modules.common import DDPGNet as Net
28+
from openrl.runners.common import DDPGAgent as Agent
29+
30+
# Wrappers handed to make() by the train() helper in this test module.
env_wrappers = [AddStep]
31+
32+
33+
@pytest.fixture(scope="module", params=[""])
def config(request):
    """Return the configuration parsed from this parametrization's argument string.

    The only parameter is the empty string, so the parser receives an empty
    argument list.
    """
    cli_args = request.param.split()
    parser = create_config_parser()
    return parser.parse_args(cli_args)
38+
39+
40+
def train(Agent, Net, env_name, env_num, total_time_steps, config):
    """Build an environment, network and agent, then train for a fixed budget.

    Args:
        Agent: agent class, called as ``Agent(net)``.
        Net: network class, called as ``Net(env, cfg=config)``.
        env_name: environment id passed to ``make``.
        env_num: number of environments passed to ``make``.
        total_time_steps: step budget passed to ``agent.train``.
        config: configuration namespace passed to both ``make`` and ``Net``.
    """
    cfg = config
    env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers)
    # Close the environment even if network construction or training fails.
    try:
        net = Net(
            env,
            cfg=cfg,
        )
        # initialize the trainer
        agent = Agent(net)
        # start training for the requested number of steps
        agent.train(total_time_steps=total_time_steps)
    finally:
        env.close()
53+
54+
55+
@pytest.mark.unittest
def test_ddpg_net(config):
    """Train DDPG for 100 steps on two ``IdentityEnvcontinuous`` environments."""
    train(
        Agent,
        Net,
        env_name="IdentityEnvcontinuous",
        env_num=2,
        total_time_steps=100,
        config=config,
    )
58+
59+
60+
if __name__ == "__main__":
    # Run pytest on this file only ("-sv" flags), then exit with its status.
    this_file = os.path.basename(__file__)
    exit_status = pytest.main(["-sv", this_file])
    sys.exit(exit_status)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import pytest
23+
24+
from openrl.configs.config import create_config_parser
25+
from openrl.envs.common import make
26+
from openrl.envs.wrappers.extra_wrappers import AddStep
27+
from openrl.modules.common import DQNNet as Net
28+
from openrl.runners.common import DQNAgent as Agent
29+
30+
# Wrappers handed to make() by the train() helper in this test module.
env_wrappers = [AddStep]
31+
32+
33+
@pytest.fixture(scope="module", params=[""])
def config(request):
    """Return the configuration parsed from this parametrization's argument string.

    The only parameter is the empty string, so the parser receives an empty
    argument list.
    """
    cli_args = request.param.split()
    parser = create_config_parser()
    return parser.parse_args(cli_args)
38+
39+
40+
def train(Agent, Net, env_name, env_num, total_time_steps, config):
    """Build an environment, network and agent, then train for a fixed budget.

    Args:
        Agent: agent class, called as ``Agent(net)``.
        Net: network class, called as ``Net(env, cfg=config)``.
        env_name: environment id passed to ``make``.
        env_num: number of environments passed to ``make``.
        total_time_steps: step budget passed to ``agent.train``.
        config: configuration namespace passed to both ``make`` and ``Net``.
    """
    cfg = config
    env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers)
    # Close the environment even if network construction or training fails.
    try:
        net = Net(
            env,
            cfg=cfg,
        )
        # initialize the trainer
        agent = Agent(net)
        # start training for the requested number of steps
        agent.train(total_time_steps=total_time_steps)
    finally:
        env.close()
53+
54+
55+
@pytest.mark.unittest
def test_dqn_net(config):
    """Train DQN for 100 steps on two ``IdentityEnv`` environments."""
    train(
        Agent,
        Net,
        env_name="IdentityEnv",
        env_num=2,
        total_time_steps=100,
        config=config,
    )
58+
59+
60+
if __name__ == "__main__":
    # Run pytest on this file only ("-sv" flags), then exit with its status.
    this_file = os.path.basename(__file__)
    exit_status = pytest.main(["-sv", this_file])
    sys.exit(exit_status)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import pytest
23+
24+
from openrl.configs.config import create_config_parser
25+
from openrl.envs.common import make
26+
from openrl.envs.wrappers.extra_wrappers import AddStep
27+
from openrl.modules.common import SACNet as Net
28+
from openrl.runners.common import SACAgent as Agent
29+
30+
# Wrappers handed to make() by the train() helper in this test module.
env_wrappers = [AddStep]
31+
32+
33+
@pytest.fixture(scope="module", params=[""])
def config(request):
    """Return the configuration parsed from this parametrization's argument string.

    The only parameter is the empty string, so the parser receives an empty
    argument list.
    """
    cli_args = request.param.split()
    parser = create_config_parser()
    return parser.parse_args(cli_args)
38+
39+
40+
def train(Agent, Net, env_name, env_num, total_time_steps, config):
    """Build an environment, network and agent, then train for a fixed budget.

    Args:
        Agent: agent class, called as ``Agent(net)``.
        Net: network class, called as ``Net(env, cfg=config)``.
        env_name: environment id passed to ``make``.
        env_num: number of environments passed to ``make``.
        total_time_steps: step budget passed to ``agent.train``.
        config: configuration namespace passed to both ``make`` and ``Net``.
    """
    cfg = config
    env = make(env_name, env_num=env_num, cfg=cfg, env_wrappers=env_wrappers)
    # Close the environment even if network construction or training fails.
    try:
        net = Net(
            env,
            cfg=cfg,
        )
        # initialize the trainer
        agent = Agent(net)
        # start training for the requested number of steps
        agent.train(total_time_steps=total_time_steps)
    finally:
        env.close()
53+
54+
55+
@pytest.mark.unittest
def test_sac_net(config):
    """Train SAC for 100 steps on two ``IdentityEnvcontinuous`` environments."""
    train(
        Agent,
        Net,
        env_name="IdentityEnvcontinuous",
        env_num=2,
        total_time_steps=100,
        config=config,
    )
58+
59+
60+
if __name__ == "__main__":
    # Run pytest on this file only ("-sv" flags), then exit with its status.
    this_file = os.path.basename(__file__)
    exit_status = pytest.main(["-sv", this_file])
    sys.exit(exit_status)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import pytest
23+
24+
from openrl.configs.config import create_config_parser
25+
from openrl.envs.common import make
26+
from openrl.envs.wrappers.mat_wrapper import MATWrapper
27+
from openrl.modules.common import VDNNet
28+
from openrl.runners.common import VDNAgent as Agent
29+
30+
31+
@pytest.fixture(scope="module", params=[""])
def config(request):
    """Return the configuration parsed from this parametrization's argument string.

    The only parameter is the empty string, so the parser receives an empty
    argument list.
    """
    cli_args = request.param.split()
    parser = create_config_parser()
    return parser.parse_args(cli_args)
36+
37+
38+
@pytest.mark.unittest
def test_vdn_net(config):
    """Train VDN for 100 steps on two asynchronous ``simple_spread`` environments.

    The vectorized environment is wrapped with ``MATWrapper`` before the
    network is built. The test passes if construction and ``agent.train``
    complete without raising.

    Args:
        config: configuration namespace from the ``config`` fixture; it is
            passed to ``VDNNet`` only, not to ``make``.
    """
    env_num = 2
    env = make(
        "simple_spread",
        env_num=env_num,
        asynchronous=True,
    )
    # Close the environment even when wrapping, network construction or
    # training raises, so a failing test does not leave it open.
    try:
        env = MATWrapper(env)

        net = VDNNet(env, cfg=config)
        # initialize the trainer
        agent = Agent(net)
        # start training
        agent.train(total_time_steps=100)
    finally:
        env.close()
54+
55+
56+
if __name__ == "__main__":
    # Run pytest on this file only ("-sv" flags), then exit with its status.
    this_file = os.path.basename(__file__)
    exit_status = pytest.main(["-sv", this_file])
    sys.exit(exit_status)

0 commit comments

Comments
 (0)