|
1 | | -# Copyright (c) 2022-2023. |
| 1 | +# Copyright (c) 2022-2024. |
2 | 2 | # ProrokLab (https://www.proroklab.org/) |
3 | 3 | # All rights reserved. |
4 | 4 | """ |
@@ -101,21 +101,14 @@ def _cycle(self): |
101 | 101 | self.reset = False |
102 | 102 | total_rew = [0] * self.n_agents |
103 | 103 |
|
104 | | - action_list = [ |
105 | | - [0.0] * self.env.unwrapped().get_agent_action_size(agent) |
106 | | - for agent in self.agents |
107 | | - ] |
| 104 | + action_list = [[0.0] * agent.action_size for agent in self.agents] |
108 | 105 | action_list[self.current_agent_index] = self.u[ |
109 | | - : self.env.unwrapped().get_agent_action_size( |
110 | | - self.agents[self.current_agent_index] |
111 | | - ) |
| 106 | + : self.agents[self.current_agent_index].action_size |
112 | 107 | ] |
113 | 108 |
|
114 | 109 | if self.n_agents > 1 and self.control_two_agents: |
115 | 110 | action_list[self.current_agent_index2] = self.u2[ |
116 | | - : self.env.unwrapped().get_agent_action_size( |
117 | | - self.agents[self.current_agent_index2] |
118 | | - ) |
| 111 | + : self.agents[self.current_agent_index2].action_size |
119 | 112 | ] |
120 | 113 | obs, rew, done, info = self.env.step(action_list) |
121 | 114 |
|
@@ -167,56 +160,60 @@ def _write_values(self, index: int, message: str): |
167 | 160 | def _key_press(self, k, mod): |
168 | 161 | from pyglet.window import key |
169 | 162 |
|
170 | | - agent_range = self.agents[self.current_agent_index].u_range |
171 | | - agent_rot_range = self.agents[self.current_agent_index].u_rot_range |
| 163 | + agent_range = self.agents[self.current_agent_index].action.u_range_tensor |
| 164 | + try: |
| 165 | + if k == key.LEFT: |
| 166 | + self.keys[0] = agent_range[0] |
| 167 | + elif k == key.RIGHT: |
| 168 | + self.keys[1] = agent_range[0] |
| 169 | + elif k == key.DOWN: |
| 170 | + self.keys[2] = agent_range[1] |
| 171 | + elif k == key.UP: |
| 172 | + self.keys[3] = agent_range[1] |
| 173 | + elif k == key.M: |
| 174 | + self.keys[4] = agent_range[2] |
| 175 | + elif k == key.N: |
| 176 | + self.keys[5] = agent_range[2] |
| 177 | + elif k == key.TAB: |
| 178 | + self.current_agent_index = self._increment_selected_agent_index( |
| 179 | + self.current_agent_index |
| 180 | + ) |
| 181 | + if self.control_two_agents: |
| 182 | + while self.current_agent_index == self.current_agent_index2: |
| 183 | + self.current_agent_index = self._increment_selected_agent_index( |
| 184 | + self.current_agent_index |
| 185 | + ) |
172 | 186 |
|
173 | | - if k == key.LEFT: |
174 | | - self.keys[0] = agent_range |
175 | | - elif k == key.RIGHT: |
176 | | - self.keys[1] = agent_range |
177 | | - elif k == key.DOWN: |
178 | | - self.keys[2] = agent_range |
179 | | - elif k == key.UP: |
180 | | - self.keys[3] = agent_range |
181 | | - elif k == key.M: |
182 | | - self.keys[4] = agent_rot_range |
183 | | - elif k == key.N: |
184 | | - self.keys[5] = agent_rot_range |
185 | | - elif k == key.TAB: |
186 | | - self.current_agent_index = self._increment_selected_agent_index( |
187 | | - self.current_agent_index |
188 | | - ) |
189 | 187 | if self.control_two_agents: |
190 | | - while self.current_agent_index == self.current_agent_index2: |
191 | | - self.current_agent_index = self._increment_selected_agent_index( |
192 | | - self.current_agent_index |
193 | | - ) |
194 | | - |
195 | | - if self.control_two_agents: |
196 | | - agent2_range = self.agents[self.current_agent_index2].u_range |
197 | | - agent2_rot_range = self.agents[self.current_agent_index2].u_rot_range |
198 | | - |
199 | | - if k == key.A: |
200 | | - self.keys2[0] = agent2_range |
201 | | - elif k == key.D: |
202 | | - self.keys2[1] = agent2_range |
203 | | - elif k == key.S: |
204 | | - self.keys2[2] = agent2_range |
205 | | - elif k == key.W: |
206 | | - self.keys2[3] = agent2_range |
207 | | - elif k == key.E: |
208 | | - self.keys2[4] = agent2_rot_range |
209 | | - elif k == key.Q: |
210 | | - self.keys2[5] = agent2_rot_range |
211 | | - |
212 | | - elif k == key.LSHIFT: |
213 | | - self.current_agent_index2 = self._increment_selected_agent_index( |
| 188 | + agent2_range = self.agents[ |
214 | 189 | self.current_agent_index2 |
215 | | - ) |
216 | | - while self.current_agent_index == self.current_agent_index2: |
| 190 | + ].action.u_range_tensor |
| 191 | + |
| 192 | + if k == key.A: |
| 193 | + self.keys2[0] = agent2_range[0] |
| 194 | + elif k == key.D: |
| 195 | + self.keys2[1] = agent2_range[0] |
| 196 | + elif k == key.S: |
| 197 | + self.keys2[2] = agent2_range[1] |
| 198 | + elif k == key.W: |
| 199 | + self.keys2[3] = agent2_range[1] |
| 200 | + elif k == key.E: |
| 201 | + self.keys2[4] = agent2_range[2] |
| 202 | + elif k == key.Q: |
| 203 | + self.keys2[5] = agent2_range[2] |
| 204 | + |
| 205 | + elif k == key.LSHIFT: |
217 | 206 | self.current_agent_index2 = self._increment_selected_agent_index( |
218 | 207 | self.current_agent_index2 |
219 | 208 | ) |
| 209 | + while self.current_agent_index == self.current_agent_index2: |
| 210 | + self.current_agent_index2 = ( |
| 211 | + self._increment_selected_agent_index( |
| 212 | + self.current_agent_index2 |
| 213 | + ) |
| 214 | + ) |
| 215 | + except IndexError: |
| 216 | + print("Action not available") |
220 | 217 |
|
221 | 218 | if k == key.R: |
222 | 219 | self.reset = True |
|
0 commit comments