Skip to content

Commit 6cd773a

Browse files
authored
Merge pull request #272 from huangshiyu13/main
fix PettingZoo bugs
2 parents b71b07b + 6e9ce0f commit 6cd773a

4 files changed

Lines changed: 83 additions & 0 deletions

File tree

examples/custom_env/rock_paper_scissors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
import functools
21+
import time
2122

2223
import gymnasium
2324
import numpy as np

openrl/envs/vec_env/sync_venv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
""""""
18+
import time
1819
from copy import deepcopy
1920
from typing import Any, Callable, Iterable, List, Optional, Sequence, Union
2021

@@ -202,6 +203,7 @@ def _step(self, actions: ActType):
202203
self._truncateds[i],
203204
info,
204205
) = returns
206+
205207
need_reset = _need_reset and (
206208
all(self._terminateds[i]) or all(self._truncateds[i])
207209
)

openrl/envs/wrappers/extra_wrappers.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import gymnasium as gym
2222
import numpy as np
2323
from gymnasium import spaces
24+
from gymnasium.utils.step_api_compatibility import (
25+
convert_to_terminated_truncated_step_api,
26+
)
2427
from gymnasium.wrappers import AutoResetWrapper, StepAPICompatibility
2528

2629
from openrl.envs.wrappers import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper
@@ -46,6 +49,76 @@ def step(self, action):
4649
return obs, total_reward, term, trunc, info
4750

4851

52+
def convert_to_done_step_api(
    step_returns,
    is_vector_env: bool = False,
):
    """Convert an env step return to the legacy done-based 4-tuple API.

    Args:
        step_returns: Either the old 4-tuple ``(obs, reward, done, info)``
            (returned unchanged) or the new 5-tuple
            ``(obs, reward, terminated, truncated, info)``.
        is_vector_env: Whether ``step_returns`` comes from a vectorized env,
            in which case ``info`` is a list (one dict per sub-env) or a dict
            of arrays.

    Returns:
        A 4-tuple ``(obs, reward, done, info)`` where ``done`` combines
        ``terminated`` and ``truncated``, and ``info`` carries the legacy
        ``"TimeLimit.truncated"`` flag where applicable.

    Raises:
        TypeError: If ``is_vector_env`` is True and ``info`` is neither a
            list nor a dict.
    """
    if len(step_returns) == 4:
        # Already in the old done-based format; pass through untouched.
        return step_returns

    assert len(step_returns) == 5
    observations, rewards, terminated, truncated, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    if is_vector_env is False:
        if isinstance(terminated, list):
            # Per-agent lists from a multi-agent (e.g. PettingZoo) env:
            # flag is derived from the first agent, done is elementwise.
            infos["TimeLimit.truncated"] = truncated[0] and not terminated[0]
            done_return = np.logical_or(terminated, truncated)
        else:
            if truncated or terminated:
                infos["TimeLimit.truncated"] = truncated and not terminated
            done_return = terminated or truncated
        return (
            observations,
            rewards,
            done_return,
            infos,
        )
    elif isinstance(infos, list):
        # Vector env with one info dict per sub-env.
        for info, env_truncated, env_terminated in zip(
            infos, truncated, terminated
        ):
            if env_truncated or env_terminated:
                info["TimeLimit.truncated"] = env_truncated and not env_terminated
        return (
            observations,
            rewards,
            np.logical_or(terminated, truncated),
            infos,
        )
    elif isinstance(infos, dict):
        # Vector env with a single dict of batched arrays.
        if np.logical_or(np.any(truncated), np.any(terminated)):
            infos["TimeLimit.truncated"] = np.logical_and(
                truncated, np.logical_not(terminated)
            )
        return (
            observations,
            rewards,
            np.logical_or(terminated, truncated),
            infos,
        )
    else:
        # Only reachable when is_vector_env is True, so the message says so
        # (the previous message incorrectly claimed is_vector_envs=False).
        raise TypeError(
            "Unexpected value of infos, as is_vector_env=True, expects `info` to"
            f" be a list or dict, actual type: {type(infos)}"
        )
109+
110+
111+
def step_api_compatibility(
    step_returns,
    output_truncation_bool: bool = True,
    is_vector_env: bool = False,
):
    """Convert a step return between the terminated/truncated and done APIs.

    Args:
        step_returns: The tuple returned by ``env.step``.
        output_truncation_bool: If True, emit the new 5-tuple
            (terminated/truncated) API; otherwise emit the legacy done-based
            4-tuple API.
        is_vector_env: Whether the step return comes from a vectorized env.

    Returns:
        The step return converted to the requested API.
    """
    converter = (
        convert_to_terminated_truncated_step_api
        if output_truncation_bool
        else convert_to_done_step_api
    )
    return converter(step_returns, is_vector_env)
120+
121+
49122
class RemoveTruncated(StepAPICompatibility, BaseWrapper):
50123
def __init__(
51124
self,
@@ -54,6 +127,12 @@ def __init__(
54127
output_truncation_bool = False
55128
super().__init__(env, output_truncation_bool=output_truncation_bool)
56129

130+
def step(self, action):
    """Step the wrapped env and convert its return to the configured step API."""
    return step_api_compatibility(
        self.env.step(action),
        self.output_truncation_bool,
        self.is_vector_env,
    )
135+
57136

58137
class FlattenObservation(BaseObservationWrapper):
59138
def __init__(self, env: gym.Env):

openrl/selfplay/wrappers/base_multiplayer_wrapper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def reset(self, *, seed: Optional[int] = None, **kwargs):
104104
action = self.get_opponent_action(
105105
player_name, observation, reward, termination, truncation, info
106106
)
107+
107108
self.env.step(action)
108109

109110
def on_episode_end(

0 commit comments

Comments
 (0)