From 3f4a0b9e90c975572cbf2e5388fde1852cc476ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Tue, 17 Mar 2026 17:12:09 +0100 Subject: [PATCH 1/9] Properly serialise multidimensional arrays --- .gitignore | 1 + tests/test_to_dict.py | 115 ++++++++++++++++++++++++++++++++++++++++++ xobjects/array.py | 22 ++++++-- 3 files changed, 133 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 5284f58f..ef478923 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ cov_html .coverage .idea *.c +**/.DS_Store diff --git a/tests/test_to_dict.py b/tests/test_to_dict.py index 040dc87e..b363626b 100644 --- a/tests/test_to_dict.py +++ b/tests/test_to_dict.py @@ -39,3 +39,118 @@ class Uref(xo.UnionRef): assert b[1].a[0] == 3 assert b[5].d == 1 + + +def test_to_dict_array_multidimensional_static_shape(): + array_type = xo.Float64[2, 3] + array = array_type([[1, 2, 3], [4, 5, 6]]) + + assert array._to_dict() == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + rebuilt = array_type(array._to_dict()) + assert rebuilt._to_dict() == array._to_dict() + + +def test_to_dict_array_multidimensional_dynamic_shape(): + array_type = xo.Float64[:, :] + array = array_type(2, 3) + array[0, 0] = 1 + array[0, 1] = 2 + array[0, 2] = 3 + array[1, 0] = 4 + array[1, 1] = 5 + array[1, 2] = 6 + + assert array._to_dict() == [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + rebuilt = array_type(array._to_dict()) + assert rebuilt._to_dict() == array._to_dict() + + +def test_to_dict_array_of_structs(): + class Item(xo.Struct): + value = xo.Int64 + coords = xo.Float64[2] + + array_type = Item[:] + items = array_type( + [ + {"value": 3, "coords": [1, 2]}, + {"value": 7, "coords": [4, 5]}, + ] + ) + + assert items._to_dict() == [ + {"value": 3, "coords": [1.0, 2.0]}, + {"value": 7, "coords": [4.0, 5.0]}, + ] + rebuilt = array_type(items._to_dict()) + assert rebuilt._to_dict() == items._to_dict() + + +def test_to_dict_multidimensional_array_of_structs(): + class Item(xo.Struct): + value = xo.Int64 + coords = xo.Float64[2] + + array_type = Item[:, :] + items = array_type( + [ + [ + {"value": 1, "coords": [1, 2]}, + {"value": 2, "coords": [3, 4]}, + ], + [ + {"value": 3, "coords": [5, 6]}, + {"value": 4, "coords": [7, 8]}, + ], + ] + ) + + expected = [ + [ + {"value": 1, "coords": [1.0, 2.0]}, + {"value": 2, "coords": [3.0, 4.0]}, + ], + [ + {"value": 3, "coords": [5.0, 6.0]}, + {"value": 4, "coords": [7.0, 8.0]}, + ], + ] + + assert items._to_dict() == expected + rebuilt = array_type(items._to_dict()) + assert rebuilt._to_dict() == expected + + +def test_to_dict_multidimensional_array_of_structs_with_refs(): + class Item(xo.Struct): + values = xo.Ref[xo.Float64[:]] + weight = xo.Int64 + + array_type = Item[:, :] + items = array_type( + [ + [ + {"values": [1, 2], "weight": 3}, + {"values": [4, 5, 6], "weight": 7}, + ], + [ + {"values": [8], "weight": 9}, + {"values": [10, 11], "weight": 12}, + ], + ] + ) + + expected = [ + [ + {"values": [1.0, 2.0], "weight": 3}, + {"values": [4.0, 5.0, 6.0], "weight": 7}, + ], + [ + {"values": [8.0], "weight": 9}, + {"values": [10.0, 11.0], "weight": 12}, + ], + ] + + assert items._to_dict() == expected + rebuilt = array_type(items._to_dict()) + assert rebuilt._to_dict() == expected diff --git a/xobjects/array.py b/xobjects/array.py index be1191e0..6daab0ae 100644 --- a/xobjects/array.py +++ b/xobjects/array.py @@ -693,16 +693,28 @@ def _to_json(self): raise NameError("`_to_json` has been removed. Use `_to_dict` instead.") def _to_dict(self): - out = [] - for v in self: # TODO does not support multidimensional arrays + from .ref import is_ref, is_unionref + + if hasattr(self._itemtype, "_dtype"): + return self.to_nparray().tolist() + + out = np.empty(dtype=object, shape=self._shape) + + for idx in self._iter_index(): + idx = idx if isinstance(idx, tuple) else (idx,) + v = self[*idx] + if hasattr(v, "_to_dict"): vdata = v._to_dict() else: vdata = v - if self._has_refs and v is not None: + + if v is not None and (is_ref(self._itemtype) or is_unionref(self._itemtype)): vdata = (v.__class__.__name__, vdata) - out.append(vdata) - return out + + out[*idx] = vdata + + return out.tolist() def is_index(atype): From e7d75751c58a7763d0df8fa1741c73f9c6e5c558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Tue, 17 Mar 2026 17:17:43 +0100 Subject: [PATCH 2/9] black --- xobjects/array.py | 4 +++- xobjects/context_cpu.py | 4 ++-- xobjects/struct.py | 3 ++- xobjects/test_helpers.py | 3 ++- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/xobjects/array.py b/xobjects/array.py index 6daab0ae..efed639b 100644 --- a/xobjects/array.py +++ b/xobjects/array.py @@ -709,7 +709,9 @@ def _to_dict(self): else: vdata = v - if v is not None and (is_ref(self._itemtype) or is_unionref(self._itemtype)): + if v is not None and ( + is_ref(self._itemtype) or is_unionref(self._itemtype) + ): vdata = (v.__class__.__name__, vdata) out[*idx] = vdata diff --git a/xobjects/context_cpu.py b/xobjects/context_cpu.py index 8050a769..125a763d 100644 --- a/xobjects/context_cpu.py +++ b/xobjects/context_cpu.py @@ -316,7 +316,7 @@ def build_kernels( if _forbid_compile: raise RuntimeError("Compilation is forbidden") - if os.environ.get('XOBJECTS_FORBID_COMPILE'): + if os.environ.get("XOBJECTS_FORBID_COMPILE"): raise RuntimeError( "Compilation is forbidden by the environment variable " "XOBJECTS_FORBID_COMPILE" @@ -466,7 +466,7 @@ def compile_kernel( return Path(output_file) finally: # Clean temp files - if 'XOBJECTS_KEEP_BUILD_FILES' not in os.environ: + if "XOBJECTS_KEEP_BUILD_FILES" not in os.environ: files_to_remove = [ module_name + ".c", module_name + ".o", diff --git a/xobjects/struct.py b/xobjects/struct.py index 0fbbe937..d0444ba7 100644 --- a/xobjects/struct.py +++ b/xobjects/struct.py @@ -485,6 +485,7 @@ def compile_class_kernels( get_suitable_kernel, XSK_PREBUILT_KERNELS_LOCATION, ) + kernel_info = get_suitable_kernel( config={}, tracker_element_classes=[], @@ -496,7 +497,7 @@ def compile_class_kernels( Print.suppress = _print_state if kernel_info: kernels = context.kernels_from_file( - module_name=kernel_info['module_name'], + module_name=kernel_info["module_name"], containing_dir=XSK_PREBUILT_KERNELS_LOCATION, kernel_descriptions=cls._kernels, ) diff --git a/xobjects/test_helpers.py b/xobjects/test_helpers.py index 36806368..2d7c7382 100644 --- a/xobjects/test_helpers.py +++ b/xobjects/test_helpers.py @@ -107,8 +107,9 @@ def wrapper(*args, **kwargs): return decorator + def skip_if_forbid_compile(): - if os.environ.get('XOBJECTS_FORBID_COMPILE'): + if os.environ.get("XOBJECTS_FORBID_COMPILE"): pytest.skip( "Compilation is forbidden by the environment variable " "XOBJECTS_FORBID_COMPILE" From cba2c5174514429a980c2f62503dce20af02361c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Tue, 17 Mar 2026 17:31:15 +0100 Subject: [PATCH 3/9] Fix circular import ordering --- xobjects/array.py | 3 +-- xobjects/ref.py | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/xobjects/array.py b/xobjects/array.py index efed639b..f786b882 100644 --- a/xobjects/array.py +++ b/xobjects/array.py @@ -15,6 +15,7 @@ _to_slot_size, default_conf, ) +from .ref import is_ref, is_unionref from .scalar import Int64, is_scalar log = logging.getLogger(__name__) @@ -693,8 +694,6 @@ def _to_json(self): raise NameError("`_to_json` has been removed. Use `_to_dict` instead.") def _to_dict(self): - from .ref import is_ref, is_unionref - if hasattr(self._itemtype, "_dtype"): return self.to_nparray().tolist() diff --git a/xobjects/ref.py b/xobjects/ref.py index eaa26716..f94156d0 100644 --- a/xobjects/ref.py +++ b/xobjects/ref.py @@ -9,7 +9,6 @@ from .typeutils import Info, dispatch_arg, allocate_on_buffer, default_conf from .scalar import Int64 -from .array import Array log = logging.getLogger(__name__) @@ -72,6 +71,8 @@ def _inspect_args(self, arg): return Info(size=self._size) def __getitem__(self, shape): + from .array import Array + return Array.mk_arrayclass(self, shape) def _gen_data_paths(self, base=None): @@ -213,6 +214,8 @@ def _to_buffer(cls, buffer, offset, value, info=None): Int64._array_to_buffer(buffer, offset, ref) def __getitem__(cls, shape): + from .array import Array + return Array.mk_arrayclass(cls, shape) def _pre_init(cls, *arg, **kwargs): From bde88f2f06f69584d83f939eef052ce92ca93616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Wed, 18 Mar 2026 09:41:21 +0100 Subject: [PATCH 4/9] Correctly serialise UnionRef fields in Structs --- tests/test_to_dict.py | 119 ++++++++++++++++++++++++++++++++++++++++++ xobjects/struct.py | 10 ++++ 2 files changed, 129 insertions(+) diff --git a/tests/test_to_dict.py b/tests/test_to_dict.py index b363626b..237ce317 100644 --- a/tests/test_to_dict.py +++ b/tests/test_to_dict.py @@ -154,3 +154,122 @@ class Item(xo.Struct): assert items._to_dict() == expected rebuilt = array_type(items._to_dict()) assert rebuilt._to_dict() == expected + + +def test_to_dict_struct_with_unionref_field(): + class A(xo.Struct): + a = xo.Float64[:] + b = xo.Int64 + + class B(xo.Struct): + c = xo.Float64[:] + d = xo.Int64 + + class Uref(xo.UnionRef): + _reftypes = (A, B) + + class Item(xo.Struct): + ref = Uref + weight = xo.Int64 + + item = Item(ref=("A", {"a": [1, 2], "b": 3}), weight=11) + + expected = { + "ref": ("A", {"a": [1.0, 2.0], "b": 3}), + "weight": 11, + } + + assert item._to_dict() == expected + rebuilt = Item(item._to_dict()) + assert rebuilt._to_dict() == expected + + +def test_to_dict_struct_containing_array_of_structs_with_unionref_fields(): + class A(xo.Struct): + a = xo.Float64[:] + b = xo.Int64 + + class B(xo.Struct): + c = xo.Float64[:] + d = xo.Int64 + + class Uref(xo.UnionRef): + _reftypes = (A, B) + + class Item(xo.Struct): + ref = Uref + weight = xo.Int64 + + class Container(xo.Struct): + items = Item[:] + tag = xo.Int64 + + container = Container( + items=[ + {"ref": ("A", {"a": [1, 2], "b": 3}), "weight": 11}, + {"ref": ("B", {"c": [4, 5], "d": 6}), "weight": 12}, + ], + tag=99, + ) + + expected = { + "items": [ + {"ref": ("A", {"a": [1.0, 2.0], "b": 3}), "weight": 11}, + {"ref": ("B", {"c": [4.0, 5.0], "d": 6}), "weight": 12}, + ], + "tag": 99, + } + + assert container._to_dict() == expected + rebuilt = Container(container._to_dict()) + assert rebuilt._to_dict() == expected + + +def test_to_dict_struct_containing_multidimensional_array_of_structs_with_unionref_fields(): + class A(xo.Struct): + a = xo.Float64[:] + b = xo.Int64 + + class B(xo.Struct): + c = xo.Float64[:] + d = xo.Int64 + + class Uref(xo.UnionRef): + _reftypes = (A, B) + + class Item(xo.Struct): + ref = Uref + weight = xo.Int64 + + class Container(xo.Struct): + items = Item[:, :] + + container = Container( + items=[ + [ + {"ref": ("A", {"a": [1, 2], "b": 3}), "weight": 11}, + {"ref": ("B", {"c": [4, 5], "d": 6}), "weight": 12}, + ], + [ + {"ref": ("B", {"c": [7], "d": 8}), "weight": 13}, + {"ref": ("A", {"a": [9, 10], "b": 14}), "weight": 15}, + ], + ] + ) + + expected = { + "items": [ + [ + {"ref": ("A", {"a": [1.0, 2.0], "b": 3}), "weight": 11}, + {"ref": ("B", {"c": [4.0, 5.0], "d": 6}), "weight": 12}, + ], + [ + {"ref": ("B", {"c": [7.0], "d": 8}), "weight": 13}, + {"ref": ("A", {"a": [9.0, 10.0], "b": 14}), "weight": 15}, + ], + ] + } + + assert container._to_dict() == expected + rebuilt = Container(container._to_dict()) + assert rebuilt._to_dict() == expected diff --git a/xobjects/struct.py b/xobjects/struct.py index d0444ba7..635e5eb4 100644 --- a/xobjects/struct.py +++ b/xobjects/struct.py @@ -63,6 +63,7 @@ from .array import Array from .context import Source, Arg, Kernel from .context_cpu import ContextCpu +from .ref import is_unionref log = logging.getLogger(__name__) @@ -374,6 +375,15 @@ def _to_dict(self): out = {} for field in self._fields: v = field.__get__(self) + if is_unionref(field.ftype): + if v is None: + out[field.name] = None + else: + classname = v.__class__.__name__ + if hasattr(v, "_to_dict"): + v = v._to_dict() + out[field.name] = (classname, v) + continue if hasattr(v, "_to_dict"): v = v._to_dict() out[field.name] = v From 632740832958aa009ff9de070b02637eea8e8a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Wed, 18 Mar 2026 09:49:17 +0100 Subject: [PATCH 5/9] Correct serialisation of None fields of UnionRef --- tests/test_to_dict.py | 44 +++++++++++++++++++++++++++++++++++++++++++ xobjects/ref.py | 2 ++ 2 files changed, 46 insertions(+) diff --git a/tests/test_to_dict.py b/tests/test_to_dict.py index 237ce317..fb20f73c 100644 --- a/tests/test_to_dict.py +++ b/tests/test_to_dict.py @@ -41,6 +41,50 @@ class Uref(xo.UnionRef): assert b[5].d == 1 +def test_to_dict_unionref_none_roundtrip(): + class A(xo.Struct): + x = xo.Int64 + + class B(xo.Struct): + y = xo.Int64 + + class Uref(xo.UnionRef): + _reftypes = (A, B) + + uref = Uref(None) + + assert uref._to_dict() is None + rebuilt = Uref(uref._to_dict()) + assert rebuilt.get() is None + + +def test_to_dict_array_of_unionrefs_with_mixed_none_and_values(): + class A(xo.Struct): + x = xo.Int64 + + class B(xo.Struct): + y = xo.Int64 + + class Uref(xo.UnionRef): + _reftypes = (A, B) + + array_type = Uref[:] + items = array_type(4) + items[1] = A(x=5) + items[3] = B(y=9) + + expected = [ + None, + ("A", {"x": 5}), + None, + ("B", {"y": 9}), + ] + + assert items._to_dict() == expected + rebuilt = array_type(items._to_dict()) + assert rebuilt._to_dict() == expected + + def test_to_dict_array_multidimensional_static_shape(): array_type = xo.Float64[2, 3] array = array_type([[1, 2, 3], [4, 5, 6]]) diff --git a/xobjects/ref.py b/xobjects/ref.py index f94156d0..5df98ec9 100644 --- a/xobjects/ref.py +++ b/xobjects/ref.py @@ -301,6 +301,8 @@ def _to_json(self): def _to_dict(self): v = self.get() + if v is None: + return None classname = v.__class__.__name__ if hasattr(v, "_to_dict"): v = v._to_dict() From c90db2e7f756230ece0238c26304ed3ce640e79e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Tue, 24 Mar 2026 13:42:17 +0100 Subject: [PATCH 6/9] Allow passing None for NULL pointer valued arguments to kernels --- tests/test_kernel.py | 39 ++++++++++++++++++++++++++++++++++++ xobjects/context_cpu.py | 2 ++ xobjects/context_cupy.py | 2 ++ xobjects/context_pyopencl.py | 2 ++ 4 files changed, 45 insertions(+) diff --git a/tests/test_kernel.py b/tests/test_kernel.py index e8547bc5..e381c11b 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -44,6 +44,45 @@ def test_kernel_cpu(): assert y == 285.0 +@for_all_test_contexts +def test_kernel_pointer_none_maps_to_null(test_context): + src_code = """ + #include "xobjects/headers/common.h" + + GPUKERN void ptr_is_null( + GPUGLMEM const double* x, + GPUGLMEM double* out + ) { + VECTORIZE_OVER(tid, 1) + out[tid] = x == NULL ? 1.0 : 0.0; + END_VECTORIZE + } + """ + + kernel_descriptions = { + "ptr_is_null": xo.Kernel( + args=[ + xo.Arg(xo.Float64, pointer=True, const=True, name="x"), + xo.Arg(xo.Float64, pointer=True, name="out"), + ], + n_threads=1, + ) + } + + test_context.add_kernels( + sources=[src_code], + kernels=kernel_descriptions, + save_source_as=None, + compile=True, + ) + + out_dev = test_context.zeros(1, dtype=np.float64) + test_context.kernels.ptr_is_null(x=None, out=out_dev) + out_host = test_context.nparray_from_context_array(out_dev) + + assert out_host[0] == 1.0 + + @for_all_test_contexts def test_kernels(test_context): src_code = """ diff --git a/xobjects/context_cpu.py b/xobjects/context_cpu.py index 125a763d..90952bf3 100644 --- a/xobjects/context_cpu.py +++ b/xobjects/context_cpu.py @@ -799,6 +799,8 @@ def __init__( def to_function_arg(self, arg, value): if arg.pointer: + if value is None: + return self.ffi_interface.NULL if hasattr(arg.atype, "_dtype"): # it is numerical scalar if hasattr(value, "dtype"): # nparray slice_first_elem = value[tuple(value.ndim * [slice(0, 1)])] diff --git a/xobjects/context_cupy.py b/xobjects/context_cupy.py index acf7686c..bd65115b 100644 --- a/xobjects/context_cupy.py +++ b/xobjects/context_cupy.py @@ -661,6 +661,8 @@ def __init__( def to_function_arg(self, arg, value): if arg.pointer: + if value is None: + return 0 if hasattr(arg.atype, "_dtype"): # it is numerical scalar if hasattr(value, "dtype"): # nparray assert isinstance(value, cupy.ndarray) diff --git a/xobjects/context_pyopencl.py b/xobjects/context_pyopencl.py index 13bac43e..9538a3df 100644 --- a/xobjects/context_pyopencl.py +++ b/xobjects/context_pyopencl.py @@ -491,6 +491,8 @@ def __init__( def to_function_arg(self, arg, value): if arg.pointer: + if value is None: + return None if hasattr(arg.atype, "_dtype"): # it is numerical scalar if isinstance(value, cl.Buffer): return value From 7b40abc3e85013fc4f0ff4315de548658105c6aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Thu, 26 Mar 2026 16:50:17 +0100 Subject: [PATCH 7/9] Accept a list when deserialising a UnionRef --- xobjects/ref.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xobjects/ref.py b/xobjects/ref.py index 5df98ec9..ac10a259 100644 --- a/xobjects/ref.py +++ b/xobjects/ref.py @@ -183,7 +183,9 @@ def _to_buffer(cls, buffer, offset, value, info=None): else: if value is None: xobj = None - elif isinstance(value, tuple): + elif isinstance( + value, (tuple, list) + ): # accept list as it might be coming from JSON if len(value) == 0: xobj = None typeid = None From 81e9f522b180a02a79e4453609a125d8fee61054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Fri, 27 Mar 2026 10:47:19 +0100 Subject: [PATCH 8/9] Add a ThisClass option also to Struct --- tests/test_struct.py | 32 ++++++++++++++++++++++++++++++++ xobjects/__init__.py | 4 ++-- xobjects/hybrid_class.py | 18 +++++------------- xobjects/struct.py | 34 +++++++++++++++++++++++++++++++++- 4 files changed, 72 insertions(+), 16 deletions(-) diff --git a/tests/test_struct.py b/tests/test_struct.py index 19f763a2..76508a2b 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -333,6 +333,8 @@ class TestClass(xo.HybridClass): def myfun(self): return self._context.kernels.myfun(tc=self) + assert TestClass._kernels["myfun"].args[0].atype is TestClass._XoStruct + tc = TestClass(x=3, y=4, _context=test_context) tc.compile_kernels(only_if_needed=True) assert tc.myfun() == 12 @@ -343,3 +345,33 @@ def myfun(self): tc.compile_kernels(only_if_needed=True) assert tc.myfun() == 35 cffi_compile.assert_not_called() + + +@requires_context("ContextCpu") +def test_thisclass_placeholder_on_struct(): + test_context = xo.ContextCpu() + + class TestStruct(xo.Struct): + x = xo.Float64 + y = xo.Float64 + + _extra_c_sources = [""" + /*gpufun*/ double myfun(TestStruct tc){ + double x = TestStruct_get_x(tc); + double y = TestStruct_get_y(tc); + return x * y; + } + """] + _kernels = { + "myfun": xo.Kernel( + args=[xo.Arg(xo.ThisClass, name="tc")], + ret=xo.Arg(xo.Float64), + ), + } + + assert TestStruct._kernels["myfun"].args[0].atype is TestStruct + + ts = TestStruct(x=3, y=4, _context=test_context) + ts.compile_kernels() + + assert test_context.kernels.myfun(tc=ts) == 12 diff --git a/xobjects/__init__.py b/xobjects/__init__.py index 0d9f709b..087b1e75 100644 --- a/xobjects/__init__.py +++ b/xobjects/__init__.py @@ -17,7 +17,7 @@ ) from .array import Array from .string import String -from .struct import Struct, Field +from .struct import Struct, Field, ThisClass from .ref import Ref, UnionRef from .context_cpu import ContextCpu @@ -30,7 +30,7 @@ from .typeutils import context_default, get_a_buffer -from .hybrid_class import JEncoder, HybridClass, MetaHybridClass, ThisClass +from .hybrid_class import JEncoder, HybridClass, MetaHybridClass from .linkedarray import BypassLinked diff --git a/xobjects/hybrid_class.py b/xobjects/hybrid_class.py index 8859b2aa..f83086cd 100644 --- a/xobjects/hybrid_class.py +++ b/xobjects/hybrid_class.py @@ -127,8 +127,11 @@ def __new__(cls, name, bases, data): # Take xofields from data['_xofields'] or from bases xofields = _build_xofields_dict(bases, data) + xostruct_data = xofields.copy() + if "_kernels" in data.keys(): + xostruct_data["_kernels"] = data["_kernels"] - _XoStruct = type(_XoStruct_name, (Struct,), xofields) + _XoStruct = type(_XoStruct_name, (Struct,), xostruct_data) if "_rename" in data.keys(): rename = data["_rename"] @@ -184,14 +187,7 @@ def __new__(cls, name, bases, data): new_class._XoStruct._depends_on.extend(data["_depends_on"]) if "_kernels" in data.keys(): - kernels = data["_kernels"].copy() - for nn, kk in kernels.items(): - for aa in kk.args: - if aa.atype is ThisClass: - aa.atype = new_class._XoStruct - if isclass(aa.atype) and issubclass(aa.atype, HybridClass): - aa.atype = aa.atype._XoStruct - new_class._XoStruct._kernels.update(kernels) + new_class._kernels = new_class._XoStruct._kernels for ii, tt in enumerate(new_class._XoStruct._depends_on): if isclass(tt) and issubclass(tt, HybridClass): @@ -422,7 +418,3 @@ def __repr__(self): vvrepr = repr(vv) args.append(f"{fname}={vvrepr}") return f'{type(self).__name__}({", ".join(args)})' - - -class ThisClass: # Place holder - pass diff --git a/xobjects/struct.py b/xobjects/struct.py index 635e5eb4..f11c547a 100644 --- a/xobjects/struct.py +++ b/xobjects/struct.py @@ -46,6 +46,7 @@ """ +import copy import logging from typing import Callable, Optional @@ -68,6 +69,18 @@ log = logging.getLogger(__name__) +class ThisClass: # Place holder + pass + + +def _resolve_kernel_arg_type(owner, atype): + if atype is ThisClass: + return owner + if hasattr(atype, "_XoStruct"): + return atype._XoStruct + return atype + + class Field: def __init__( self, ftype, default=None, readonly=False, default_factory=None @@ -266,7 +279,26 @@ def _inspect_args(cls, *args, **kwargs): if "_kernels" not in data.keys(): data["_kernels"] = {} - return type.__new__(cls, name, bases, data) + new_class = type.__new__(cls, name, bases, data) + resolved_kernels = {} + for kernel_name, kernel in (new_class._kernels or {}).items(): + resolved_kernel = copy.copy(kernel) + resolved_kernel.args = [] + for arg in kernel.args: + resolved_arg = copy.copy(arg) + resolved_arg.atype = _resolve_kernel_arg_type( + new_class, resolved_arg.atype + ) + resolved_kernel.args.append(resolved_arg) + if isinstance(kernel.ret, Arg): + resolved_ret = copy.copy(kernel.ret) + resolved_ret.atype = _resolve_kernel_arg_type( + new_class, resolved_ret.atype + ) + resolved_kernel.ret = resolved_ret + resolved_kernels[kernel_name] = resolved_kernel + new_class._kernels = resolved_kernels + return new_class def __getitem__(cls, shape): return Array.mk_arrayclass(cls, shape) From 7b1e036264f708788c52b9b5c20789e9487d8e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20=C5=81opaciuk?= Date: Mon, 30 Mar 2026 10:10:39 +0200 Subject: [PATCH 9/9] Fix pytest config in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d64975b1..816611e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ include = '\.pyi?$' [project.entry-points.xobjects] include = "xobjects" -[pytest] +[tool.pytest] markers = [ "context_dependent: marks test as one that depends on the execution context", ]