Skip to content

Commit 6137223

Browse files
DeepMindcopybara-github
authored andcommitted
Allow custom spec types to be passed through in composer obs updater.
PiperOrigin-RevId: 530928000 Change-Id: If155fb61320163c4f256300771dcc836bc7a6aae
1 parent 0cb8fcb commit 6137223

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

dm_control/composer/observation/updater.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,16 @@ def make_observation_spec_dict(enabled_dict):
195195
out_dict = type(enabled_dict)()
196196
for name, enabled in enabled_dict.items():
197197

198+
if (enabled.observable.aggregator is None
199+
and enabled.observable.array_spec is not None):
200+
# If possible, keep the original array spec, just updating the name
201+
# and modifying the dimension for buffering. Doing this allows for
202+
# custom spec types to be exposed by the environment where possible.
203+
out_dict[name] = enabled.observable.array_spec.replace(
204+
name=name, shape=enabled.buffer.shape
205+
)
206+
continue
207+
198208
if isinstance(enabled.observable.array_spec, specs.BoundedArray):
199209
bounds = (enabled.observable.array_spec.minimum,
200210
enabled.observable.array_spec.maximum)

dm_control/composer/observation/updater_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@ def array_spec(self):
5454
maximum=self._bounds[1])
5555

5656

57+
class MyArraySpec(specs.Array):
58+
pass
59+
60+
61+
class GenericObservableWithMyArraySpec(observable.Generic):
62+
@property
63+
def array_spec(self):
64+
datum = np.array(self(None, None))
65+
return MyArraySpec(shape=datum.shape, dtype=datum.dtype)
66+
67+
5768
class UpdaterTest(parameterized.TestCase):
5869

5970
@parameterized.parameters(list, tuple)
@@ -163,6 +174,23 @@ def testObservationSpecInference(self):
163174
self.assertCorrectSpec(spec['matrix'], (4, 2, 3), int, 'matrix')
164175
self.assertCorrectSpec(spec['sqrt'], (3,), float, 'sqrt')
165176

177+
def testCustomSpecTypePassedThrough(self):
178+
physics = fake_physics.FakePhysics()
179+
physics.observables['two_twos'] = GenericObservableWithMyArraySpec(
180+
lambda _: [2.0, 2.0], buffer_size=3
181+
)
182+
183+
physics.observables['two_twos'].enabled = True
184+
185+
observation_updater = updater.Updater(physics.observables)
186+
observation_updater.reset(physics=physics, random_state=None)
187+
188+
spec = observation_updater.observation_spec()
189+
self.assertIsInstance(spec['two_twos'], MyArraySpec)
190+
self.assertEqual(spec['two_twos'].shape, (3, 2))
191+
self.assertEqual(spec['two_twos'].dtype, float)
192+
self.assertEqual(spec['two_twos'].name, 'two_twos')
193+
166194
@parameterized.parameters(True, False)
167195
def testObservation(self, pad_with_initial_value):
168196
physics = fake_physics.FakePhysics()

0 commit comments

Comments
 (0)