Skip to content

Commit 884e75b

Browse files
committed
[BugFix] dtype in simple_crypto observations
1 parent 8dba87b commit 884e75b

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

vmas/scenarios/mpe/simple_crypto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def observation(self, agent: Agent):
181181
key,
182182
],
183183
dim=-1,
184-
)
184+
).to(torch.float)
185185
# listener
186186
if not agent.speaker and not agent.adversary:
187-
return torch.cat([key, *comm], dim=-1)
187+
return torch.cat([key, *comm], dim=-1).to(torch.float)
188188
# adv
189189
if not agent.speaker and agent.adversary:
190-
return torch.cat([*comm], dim=-1)
190+
return torch.cat([*comm], dim=-1).to(torch.float)

0 commit comments

Comments
 (0)