Skip to content

Commit 03505ab

Browse files
committed
fix: implement caching for dereferenced pointer objects
1 parent bda7078 commit 03505ab

2 files changed

Lines changed: 105 additions & 5 deletions

File tree

libdestruct/common/ptr/ptr.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def __init__(self: ptr, resolver: Resolver, wrapper: type | None = None) -> None
5757
"""
5858
super().__init__(resolver)
5959
self.wrapper = wrapper
60+
self._cached_unwrap: obj | None = None
61+
self._cache_valid: bool = False
6062

6163
def get(self: ptr) -> int:
6264
"""Return the value of the pointer."""
@@ -73,32 +75,47 @@ def to_bytes(self: obj) -> bytes:
7375
def _set(self: ptr, value: int) -> None:
7476
"""Set the value of the pointer to the given value."""
7577
self.resolver.modify(self.size, 0, value.to_bytes(self.size, self.endianness))
78+
self.invalidate()
79+
80+
def invalidate(self: ptr) -> None:
81+
"""Clear the cached unwrap result."""
82+
self._cached_unwrap = None
83+
self._cache_valid = False
7684

7785
def unwrap(self: ptr, length: int | None = None) -> obj:
7886
"""Return the object pointed to by the pointer.
7987
8088
Args:
8189
length: The length of the object in memory this points to.
8290
"""
91+
if self._cache_valid:
92+
return self._cached_unwrap
93+
8394
address = self.get()
8495

8596
if self.wrapper:
8697
if length:
8798
raise ValueError("Length is not supported when unwrapping a pointer to a wrapper object.")
8899

89-
return self.wrapper(self.resolver.absolute_from_own(address))
90-
91-
if not length:
92-
length = 1
100+
result = self.wrapper(self.resolver.absolute_from_own(address))
101+
elif not length:
102+
result = self.resolver.resolve(1, 0)
103+
else:
104+
result = self.resolver.resolve(length, 0)
93105

94-
return self.resolver.resolve(length, 0)
106+
self._cached_unwrap = result
107+
self._cache_valid = True
108+
return result
95109

96110
def try_unwrap(self: ptr, length: int | None = None) -> obj | None:
97111
"""Return the object pointed to by the pointer, if it is valid.
98112
99113
Args:
100114
length: The length of the object in memory this points to.
101115
"""
116+
if self._cache_valid:
117+
return self._cached_unwrap
118+
102119
address = self.get()
103120

104121
try:

test/scripts/types_unit_test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,89 @@ def test_ptr_arithmetic_chain(self):
174174

175175
self.assertEqual((p + 2)[0].value, 3)
176176

177+
def test_unwrap_cached(self):
178+
"""Two unwrap() calls return the same object."""
179+
class test_t(struct):
180+
a: c_int
181+
p: ptr = ptr_to_self()
182+
183+
memory = bytearray(12)
184+
memory[0:4] = (42).to_bytes(4, "little")
185+
memory[4:12] = (0).to_bytes(8, "little")
186+
187+
lib = inflater(memory)
188+
test = lib.inflate(test_t, 0)
189+
190+
r1 = test.p.unwrap()
191+
r2 = test.p.unwrap()
192+
self.assertIs(r1, r2)
193+
194+
def test_invalidate_clears_cache(self):
195+
"""invalidate() causes next unwrap() to return a new object."""
196+
class test_t(struct):
197+
a: c_int
198+
p: ptr = ptr_to_self()
199+
200+
memory = bytearray(12)
201+
memory[0:4] = (42).to_bytes(4, "little")
202+
memory[4:12] = (0).to_bytes(8, "little")
203+
204+
lib = inflater(memory)
205+
test = lib.inflate(test_t, 0)
206+
207+
r1 = test.p.unwrap()
208+
test.p.invalidate()
209+
r2 = test.p.unwrap()
210+
self.assertIsNot(r1, r2)
211+
212+
def test_cache_reflects_memory_change(self):
213+
"""After memory change + invalidate, unwrap gets new value."""
214+
class test_t(struct):
215+
a: c_int
216+
p: ptr = ptr_to_self()
217+
218+
memory = bytearray(12)
219+
memory[0:4] = (42).to_bytes(4, "little")
220+
memory[4:12] = (0).to_bytes(8, "little")
221+
222+
lib = inflater(memory)
223+
test = lib.inflate(test_t, 0)
224+
225+
self.assertEqual(test.p.unwrap().a.value, 42)
226+
memory[0:4] = (99).to_bytes(4, "little")
227+
test.p.invalidate()
228+
self.assertEqual(test.p.unwrap().a.value, 99)
229+
230+
def test_try_unwrap_cached(self):
231+
"""try_unwrap() also uses cache."""
232+
class test_t(struct):
233+
a: c_int
234+
p: ptr = ptr_to_self()
235+
236+
memory = bytearray(12)
237+
memory[0:4] = (42).to_bytes(4, "little")
238+
memory[4:12] = (0).to_bytes(8, "little")
239+
240+
lib = inflater(memory)
241+
test = lib.inflate(test_t, 0)
242+
243+
r1 = test.p.try_unwrap()
244+
r2 = test.p.try_unwrap()
245+
self.assertIs(r1, r2)
246+
247+
def test_cache_invalidated_on_set(self):
248+
"""ptr.value = new_addr auto-invalidates the cache."""
249+
memory = bytearray(8 + 8) # ptr + two c_int slots
250+
memory[0:8] = (8).to_bytes(8, "little") # points to offset 8
251+
memory[8:12] = (10).to_bytes(4, "little")
252+
memory[12:16] = (20).to_bytes(4, "little")
253+
254+
p = ptr(MemoryResolver(memory, 0), c_int)
255+
256+
self.assertEqual(p.unwrap().value, 10)
257+
p.value = 12 # now points to offset 12
258+
self.assertEqual(p.unwrap().value, 20)
259+
177260

178261
class FloatTest(unittest.TestCase):
179262
"""c_float and c_double types."""

0 commit comments

Comments
 (0)