2121import gymnasium as gym
2222import numpy as np
2323from gymnasium import spaces
24+ from gymnasium .utils .step_api_compatibility import (
25+ convert_to_terminated_truncated_step_api ,
26+ )
2427from gymnasium .wrappers import AutoResetWrapper , StepAPICompatibility
2528
2629from openrl .envs .wrappers import BaseObservationWrapper , BaseRewardWrapper , BaseWrapper
@@ -46,6 +49,76 @@ def step(self, action):
4649 return obs , total_reward , term , trunc , info
4750
4851
52+ def convert_to_done_step_api (
53+ step_returns ,
54+ is_vector_env : bool = False ,
55+ ):
56+ if len (step_returns ) == 4 :
57+ return step_returns
58+ else :
59+ assert len (step_returns ) == 5
60+ observations , rewards , terminated , truncated , infos = step_returns
61+
62+ # Cases to handle - info single env / info vector env (list) / info vector env (dict)
63+ # if truncated[0]:
64+ # import pdb;
65+ # pdb.set_trace()
66+
67+ if is_vector_env is False :
68+ if isinstance (terminated , list ):
69+ infos ["TimeLimit.truncated" ] = truncated [0 ] and not terminated [0 ]
70+ done_return = np .logical_or (terminated , truncated )
71+ else :
72+ if truncated or terminated :
73+ infos ["TimeLimit.truncated" ] = truncated and not terminated
74+ done_return = terminated or truncated
75+ return (
76+ observations ,
77+ rewards ,
78+ done_return ,
79+ infos ,
80+ )
81+ elif isinstance (infos , list ):
82+ for info , env_truncated , env_terminated in zip (
83+ infos , truncated , terminated
84+ ):
85+ if env_truncated or env_terminated :
86+ info ["TimeLimit.truncated" ] = env_truncated and not env_terminated
87+ return (
88+ observations ,
89+ rewards ,
90+ np .logical_or (terminated , truncated ),
91+ infos ,
92+ )
93+ elif isinstance (infos , dict ):
94+ if np .logical_or (np .any (truncated ), np .any (terminated )):
95+ infos ["TimeLimit.truncated" ] = np .logical_and (
96+ truncated , np .logical_not (terminated )
97+ )
98+ return (
99+ observations ,
100+ rewards ,
101+ np .logical_or (terminated , truncated ),
102+ infos ,
103+ )
104+ else :
105+ raise TypeError (
106+ "Unexpected value of infos, as is_vector_envs=False, expects `info` to"
107+ f" be a list or dict, actual type: { type (infos )} "
108+ )
109+
110+
111+ def step_api_compatibility (
112+ step_returns ,
113+ output_truncation_bool : bool = True ,
114+ is_vector_env : bool = False ,
115+ ):
116+ if output_truncation_bool :
117+ return convert_to_terminated_truncated_step_api (step_returns , is_vector_env )
118+ else :
119+ return convert_to_done_step_api (step_returns , is_vector_env )
120+
121+
49122class RemoveTruncated (StepAPICompatibility , BaseWrapper ):
50123 def __init__ (
51124 self ,
@@ -54,6 +127,12 @@ def __init__(
54127 output_truncation_bool = False
55128 super ().__init__ (env , output_truncation_bool = output_truncation_bool )
56129
130+ def step (self , action ):
131+ step_returns = self .env .step (action )
132+ return step_api_compatibility (
133+ step_returns , self .output_truncation_bool , self .is_vector_env
134+ )
135+
57136
58137class FlattenObservation (BaseObservationWrapper ):
59138 def __init__ (self , env : gym .Env ):
0 commit comments