Skip to content

Commit 911550e

Browse files
committed
feat: add comparison operators for libdestruct objects
1 parent 57def08 commit 911550e

2 files changed

Lines changed: 105 additions & 5 deletions

File tree

libdestruct/common/obj.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,56 @@ def __repr__(self: obj) -> str:
130130
"""Return a string representation of the object."""
131131
return f"{self.__class__.__name__}({self.get()})"
132132

133-
def __eq__(self: obj, value: object) -> bool:
133+
def _compare_value(self: obj, other: object) -> tuple[object, object] | None:
134+
"""Extract comparable values from self and other, or None if incompatible."""
135+
self_val = self.value
136+
if isinstance(other, obj):
137+
return self_val, other.value
138+
if isinstance(other, int | float):
139+
return self_val, other
140+
return None
141+
142+
def __eq__(self: obj, other: object) -> bool:
134143
"""Return whether the object is equal to the given value."""
135-
if not isinstance(value, obj):
136-
return False
137-
138-
return self.get() == value.get()
144+
pair = self._compare_value(other)
145+
if pair is None:
146+
return NotImplemented
147+
return pair[0] == pair[1]
148+
149+
def __ne__(self: obj, other: object) -> bool:
150+
"""Return whether the object is not equal to the given value."""
151+
pair = self._compare_value(other)
152+
if pair is None:
153+
return NotImplemented
154+
return pair[0] != pair[1]
155+
156+
def __lt__(self: obj, other: object) -> bool:
157+
"""Return whether this object is less than the given value."""
158+
pair = self._compare_value(other)
159+
if pair is None:
160+
return NotImplemented
161+
return pair[0] < pair[1]
162+
163+
def __le__(self: obj, other: object) -> bool:
164+
"""Return whether this object is less than or equal to the given value."""
165+
pair = self._compare_value(other)
166+
if pair is None:
167+
return NotImplemented
168+
return pair[0] <= pair[1]
169+
170+
def __gt__(self: obj, other: object) -> bool:
171+
"""Return whether this object is greater than the given value."""
172+
pair = self._compare_value(other)
173+
if pair is None:
174+
return NotImplemented
175+
return pair[0] > pair[1]
176+
177+
def __ge__(self: obj, other: object) -> bool:
178+
"""Return whether this object is greater than or equal to the given value."""
179+
pair = self._compare_value(other)
180+
if pair is None:
181+
return NotImplemented
182+
return pair[0] >= pair[1]
139183

140184
def hexdump(self: obj) -> str:
141185
"""Return a hex dump of this object's bytes."""

test/scripts/types_unit_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,5 +513,61 @@ class big_t(struct):
513513
self.assertGreater(len(lines), 1)
514514

515515

516+
class ComparisonTest(unittest.TestCase):
517+
"""Comparison operators on primitive types."""
518+
519+
def test_int_gt_python_int(self):
520+
x = c_int.from_bytes((10).to_bytes(4, "little"))
521+
self.assertTrue(x > 5)
522+
self.assertFalse(x > 10)
523+
524+
def test_int_lt_python_int(self):
525+
x = c_int.from_bytes((3).to_bytes(4, "little"))
526+
self.assertTrue(x < 5)
527+
self.assertFalse(x < 3)
528+
529+
def test_int_ge_le(self):
530+
x = c_int.from_bytes((7).to_bytes(4, "little"))
531+
self.assertTrue(x >= 7)
532+
self.assertTrue(x >= 6)
533+
self.assertFalse(x >= 8)
534+
self.assertTrue(x <= 7)
535+
self.assertTrue(x <= 8)
536+
self.assertFalse(x <= 6)
537+
538+
def test_int_eq_python_int(self):
539+
x = c_int.from_bytes((42).to_bytes(4, "little"))
540+
self.assertTrue(x == 42)
541+
self.assertFalse(x == 43)
542+
543+
def test_int_ne_python_int(self):
544+
x = c_int.from_bytes((42).to_bytes(4, "little"))
545+
self.assertTrue(x != 43)
546+
self.assertFalse(x != 42)
547+
548+
def test_float_gt_python_float(self):
549+
x = c_float.from_bytes(pystruct.pack("<f", 3.14))
550+
self.assertTrue(x > 3.0)
551+
self.assertFalse(x > 4.0)
552+
553+
def test_float_eq_python_float(self):
554+
x = c_double.from_bytes(pystruct.pack("<d", 2.5))
555+
self.assertTrue(x == 2.5)
556+
self.assertFalse(x == 2.6)
557+
558+
def test_obj_vs_obj(self):
559+
a = c_int.from_bytes((10).to_bytes(4, "little"))
560+
b = c_int.from_bytes((20).to_bytes(4, "little"))
561+
self.assertTrue(a < b)
562+
self.assertTrue(b > a)
563+
self.assertTrue(a != b)
564+
self.assertFalse(a == b)
565+
566+
def test_comparison_returns_not_implemented_for_incompatible(self):
567+
x = c_int.from_bytes((1).to_bytes(4, "little"))
568+
self.assertFalse(x == "hello")
569+
self.assertTrue(x != "hello")
570+
571+
516572
if __name__ == "__main__":
517573
unittest.main()

0 commit comments

Comments
 (0)