Skip to content

Commit 753fba7

Browse files
committed
Move envpool to examples
1 parent 448cc6f commit 753fba7

7 files changed

Lines changed: 160 additions & 67 deletions

File tree

examples/envpool/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## Installation
2+
3+
4+
Install envpool with:
5+
6+
``` shell
7+
pip install envpool
8+
```
9+
10+
Note 1: envpool only supports the Linux operating system.
11+
12+
## Usage
13+
14+
You can use `OpenRL` to train CartPole (with envpool) via:
15+
16+
``` shell
17+
python train_ppo.py
18+
```
19+
20+
You can also add custom wrappers in `envpool_wrapper.py`. Currently we have `VecAdapter` and `VecMonitor` wrappers.

examples/envpool/make_env.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import copy
2+
import inspect
3+
from typing import Callable, Iterable, List, Optional, Union
4+
5+
import envpool
6+
from gymnasium import Env
7+
8+
9+
from openrl.envs.vec_env import (AsyncVectorEnv, RewardWrapper,
10+
SyncVectorEnv, VecMonitorWrapper)
11+
from openrl.envs.vec_env.vec_info import VecInfoFactory
12+
from openrl.envs.wrappers.base_wrapper import BaseWrapper
13+
from openrl.rewards import RewardFactory
14+
15+
16+
def build_envs(
    make,
    id: str,
    env_num: int = 1,
    wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
    need_env_id: bool = False,
    **kwargs,
) -> List[Callable[[], Env]]:
    """Build a list of zero-argument environment factory callables.

    Args:
        make: Environment factory (e.g. ``envpool.make``) invoked as
            ``make(id, **kwargs)``.
        id: Environment id forwarded to ``make`` and recorded on the env spec.
        env_num: Number of parallel environments to create factories for.
        wrappers: A single wrapper callable, or a list of them applied in order.
            ``BaseWrapper`` subclasses additionally receive ``cfg``.
        need_env_id: If True, inject ``env_id``/``env_num`` into the kwargs
            passed to ``make``.
        **kwargs: Extra keyword arguments forwarded to ``make``. The ``cfg``
            entry and the ``envpool`` marker are consumed here and NOT
            forwarded.

    Returns:
        A list of ``env_num`` callables, each creating one wrapped env.
    """
    cfg = kwargs.get("cfg", None)

    def create_env(env_id: int, env_num: int, need_env_id: bool) -> Callable[[], Env]:
        def _make_env() -> Env:
            new_kwargs = copy.deepcopy(kwargs)
            # cfg is consumed locally for wrapper construction; the underlying
            # factory (envpool.make) does not accept a ``cfg`` keyword, so it
            # must not be forwarded.
            new_kwargs.pop("cfg", None)
            if need_env_id:
                new_kwargs["env_id"] = env_id
                new_kwargs["env_num"] = env_num
            if "envpool" in new_kwargs:
                # For now envpool doesn't support any render mode, and it
                # doesn't store the id anywhere, so drop this marker before
                # calling the factory.
                new_kwargs.pop("envpool")
            env = make(
                id,
                **new_kwargs,
            )
            # envpool does not record the id on the spec itself.
            env.unwrapped.spec.id = id

            if wrappers is not None:
                if callable(wrappers):
                    if issubclass(wrappers, BaseWrapper):
                        env = wrappers(env, cfg=cfg)
                    else:
                        env = wrappers(env)
                elif isinstance(wrappers, Iterable) and all(
                    [callable(w) for w in wrappers]
                ):
                    for wrapper in wrappers:
                        if (
                            issubclass(wrapper, BaseWrapper)
                            and "cfg" in inspect.signature(wrapper.__init__).parameters
                        ):
                            env = wrapper(env, cfg=cfg)
                        else:
                            env = wrapper(env)
                else:
                    raise NotImplementedError

            return env

        return _make_env

    env_fns = [create_env(env_id, env_num, need_env_id) for env_id in range(env_num)]
    return env_fns
68+
69+
70+
def make_envpool_envs(
    id: str,
    env_num: int = 1,
    **kwargs,
):
    """Create environment factory functions backed by ``envpool.make``.

    Args:
        id: envpool task id, e.g. ``"CartPole-v1"``.
        env_num: Number of parallel environments.
        **kwargs: Must contain ``env_type`` (one of ``"gym"``, ``"dm"``,
            ``"gymnasium"``); may contain ``env_wrappers``, a list of
            wrappers applied to each created env.

    Returns:
        A list of zero-argument env factory callables.
    """
    assert "env_type" in kwargs
    assert kwargs.get("env_type") in ["gym", "dm", "gymnasium"]
    # Mark the kwargs so build_envs knows these are envpool environments.
    kwargs["envpool"] = True

    env_wrappers = kwargs.pop("env_wrappers", [])
    return build_envs(
        make=envpool.make,
        id=id,
        env_num=env_num,
        wrappers=env_wrappers,
        **kwargs,
    )
91+
92+
93+
def make(
    id: str,
    env_num: int = 1,
    asynchronous: bool = False,
    add_monitor: bool = True,
    render_mode: Optional[str] = None,
    auto_reset: bool = True,
    **kwargs,
):
    """Create a vectorized, reward-wrapped (and optionally monitored) envpool env.

    Args:
        id: envpool env id. An optional ``"envpool:"`` prefix (the legacy
            form used before envpool moved to examples) is accepted and
            stripped before lookup.
        env_num: Number of parallel environments.
        asynchronous: Use ``AsyncVectorEnv`` instead of ``SyncVectorEnv``.
        add_monitor: Wrap the result in ``VecMonitorWrapper``.
        render_mode: Forwarded to the vector env (envpool itself does not
            render).
        auto_reset: Forwarded to the vector env.
        **kwargs: Forwarded to ``make_envpool_envs`` (must include
            ``env_type``; may include ``cfg`` and ``env_wrappers``).

    Raises:
        NotImplementedError: If ``id`` is not a registered envpool env.
    """
    # Accept both "CartPole-v1" and the legacy "envpool:CartPole-v1" form.
    env_id = id.split(":")[-1]
    if env_id not in envpool.registration.list_all_envs():
        raise NotImplementedError(f"env {id} is not supported")

    cfg = kwargs.get("cfg", None)
    env_fns = make_envpool_envs(
        id=env_id,
        env_num=env_num,
        **kwargs,
    )
    if asynchronous:
        env = AsyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset)
    else:
        env = SyncVectorEnv(env_fns, render_mode=render_mode, auto_reset=auto_reset)

    # Reward wrapping: use cfg.reward_class when a config is supplied.
    reward_class = cfg.reward_class if cfg else None
    reward_class = RewardFactory.get_reward_class(reward_class, env)
    env = RewardWrapper(env, reward_class)

    if add_monitor:
        vec_info_class = cfg.vec_info_class if cfg else None
        vec_info_class = VecInfoFactory.get_vec_info_class(vec_info_class, env)
        env = VecMonitorWrapper(vec_info_class, env)

    return env

examples/envpool/train_ppo.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import numpy as np
1919

2020
from openrl.configs.config import create_config_parser
21-
from openrl.envs.common import make
22-
from openrl.envs.wrappers.envpool_wrappers import VecAdapter, VecMonitor
21+
from make_env import make
22+
from examples.envpool.envpool_wrappers import VecAdapter, VecMonitor
2323
from openrl.modules.common import PPONet as Net
2424
from openrl.modules.common.ppo_net import PPONet as Net
2525
from openrl.runners.common import PPOAgent as Agent
@@ -32,7 +32,7 @@ def train():
3232

3333
# create environment, set environment parallelism to 9
3434
env = make(
35-
"envpool:CartPole-v1",
35+
"CartPole-v1",
3636
render_mode=None,
3737
env_num=9,
3838
asynchronous=False,
@@ -45,7 +45,7 @@ def train():
4545
cfg=cfg,
4646
)
4747
# initialize the trainer
48-
agent = Agent(net, use_wandb=False, project_name="envpool:CartPole-v1")
48+
agent = Agent(net, use_wandb=False, project_name="CartPole-v1")
4949
# start training, set total number of training steps to 20000
5050
agent.train(total_time_steps=20000)
5151

@@ -58,7 +58,14 @@ def evaluation(agent):
5858
# Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human.
5959
render_mode = "group_human"
6060
render_mode = None
61-
env = make("CartPole-v1", render_mode=render_mode, env_num=9, asynchronous=True)
61+
env = make(
62+
"CartPole-v1",
63+
env_wrappers=[VecAdapter, VecMonitor],
64+
render_mode=render_mode,
65+
env_num=9,
66+
asynchronous=True,
67+
env_type="gym",
68+
)
6269
# The trained agent sets up the interactive environment it needs.
6370
agent.set_env(env)
6471
# Initialize the environment and get initial observations and environmental information.

openrl/envs/common/registration.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
""""""
1818
from typing import Callable, Optional
1919

20-
import envpool
2120
import gymnasium as gym
2221

2322
import openrl
@@ -155,18 +154,6 @@ def make(
155154
env_fns = make_PettingZoo_envs(
156155
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
157156
)
158-
elif (
159-
"envpool:" in id
160-
and id.split(":")[-1] in envpool.registration.list_all_envs()
161-
):
162-
from openrl.envs.envpool import make_envpool_envs
163-
164-
env_fns = make_envpool_envs(
165-
id=id.split(":")[-1],
166-
env_num=env_num,
167-
render_mode=convert_render_mode,
168-
**kwargs,
169-
)
170157
else:
171158
raise NotImplementedError(f"env {id} is not supported.")
172159

openrl/envs/envpool/__init__.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

setup.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def get_extra_requires() -> dict:
7676
"async_timeout",
7777
"pettingzoo[classic]",
7878
"trueskill",
79-
"envpool",
8079
],
8180
"selfplay_test": [
8281
"ray[default]>=2.7",
@@ -85,7 +84,6 @@ def get_extra_requires() -> dict:
8584
"fastapi",
8685
"pettingzoo[mpe]",
8786
"pettingzoo[butterfly]",
88-
"envpool",
8987
],
9088
"retro": ["gym-retro"],
9189
"super_mario": ["gym-super-mario-bros"],

0 commit comments

Comments
 (0)