|
| 1 | +"""""" |
| 2 | + |
| 3 | +import os |
| 4 | +import sys |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from openrl.configs.config import create_config_parser |
| 9 | +from openrl.envs.common import make |
| 10 | +from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper |
| 11 | +from openrl.envs.wrappers.extra_wrappers import ZeroRewardWrapper |
| 12 | +from openrl.envs.wrappers.monitor import Monitor |
| 13 | +from openrl.modules.common import GAILNet as Net |
| 14 | +from openrl.modules.common import PPONet |
| 15 | +from openrl.runners.common import GAILAgent as Agent |
| 16 | +from openrl.runners.common import PPOAgent |
| 17 | + |
| 18 | + |
@pytest.fixture(scope="function")
def gen_data(tmpdir):
    """Roll out an (untrained) PPO policy on CartPole-v1 and record 5 episodes.

    Returns:
        Path to the pickled trajectory file written by ``GenDataWrapper``.
    """
    data_path = os.path.join(tmpdir, "data.pkl")
    print("generate data....")
    env = make(
        "CartPole-v1",
        env_num=2,
        asynchronous=True,
        env_wrappers=[Monitor],
    )
    collector = PPOAgent(PPONet(env))
    # GenDataWrapper records transitions and signals done after total_episode episodes.
    env = GenDataWrapper(env, data_save_path=data_path, total_episode=5)
    obs, info = env.reset()
    finished = False
    while not finished:
        # Query the policy for the next action given the current observation.
        action, _ = collector.act(obs, deterministic=True)
        obs, reward, finished, info = env.step(action)
    env.close()
    print("generate data done!")
    return data_path
| 43 | + |
| 44 | + |
@pytest.fixture(
    scope="function", params=[" --gail_use_action false", " --gail_use_action true"]
)
def config(request, gen_data):
    """Build a GAIL training config, parametrized over ``--gail_use_action``.

    ``gen_data`` supplies the expert trajectory file consumed by the GAIL reward.
    """
    base_opts = (
        "--episode_length 5 --use_recurrent_policy true --use_joint_action_loss true"
        " --use_valuenorm true --use_adv_normalize true --reward_class.id GAILReward"
    )
    full_opts = base_opts + request.param + " --expert_data " + gen_data
    parser = create_config_parser()
    return parser.parse_args(full_opts.split())
| 58 | + |
| 59 | + |
@pytest.mark.unittest
def test_train_gail(config):
    """Smoke-test GAIL training on CartPole-v1 with the native reward zeroed out."""
    # ZeroRewardWrapper suppresses the environment reward so training exercises
    # the GAIL reward path configured via --reward_class.id GAILReward.
    env = make(
        "CartPole-v1",
        env_num=2,
        cfg=config,
        env_wrappers=[ZeroRewardWrapper],
    )
    agent = Agent(Net(env, cfg=config))
    agent.train(total_time_steps=200)
    env.close()
| 72 | + |
| 73 | + |
| 74 | +if __name__ == "__main__": |
| 75 | + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) |
0 commit comments