Skip to content

Commit e22194c

Browse files
committed
Fix union and coalesce expressions not decoding to the correct type.
1 parent e520d9a commit e22194c

5 files changed

Lines changed: 149 additions & 0 deletions

File tree

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5795,6 +5795,45 @@ def resolve(
57955795
f"# type: ignore [assignment, misc, unused-ignore]"
57965796
)
57975797

5798+
if function.schemapath in {
5799+
SchemaPath('std', 'UNION'),
5800+
SchemaPath('std', 'IF'),
5801+
SchemaPath('std', '??'),
5802+
}:
5803+
# Special case for the UNION, IF and ?? operators
5804+
# Produce a union type instead of just taking the first
5805+
# valid type.
5806+
#
5807+
# See gel: compile_operator
5808+
create_union = self.import_name(
5809+
BASE_IMPL, "create_union"
5810+
)
5811+
5812+
tvars: list[str] = []
5813+
for param, path in sources:
5814+
if (
5815+
param.name in required_generic_params
5816+
or param.name in optional_generic_params
5817+
):
5818+
pn = param_vars[param.name]
5819+
tvar = f"__t_{pn}__"
5820+
5821+
resolve(pn, path, tvar)
5822+
tvars.append(tvar)
5823+
5824+
self.write(
5825+
f"{gtvar} = {tvars[0]} "
5826+
f"# type: ignore [assignment, misc, unused-ignore]"
5827+
)
5828+
for tvar in tvars[1:]:
5829+
self.write(
5830+
f"{gtvar} = {create_union}({gtvar}, {tvar}) "
5831+
f"# type: ignore ["
5832+
f"assignment, misc, unused-ignore]"
5833+
)
5834+
5835+
continue
5836+
57985837
# Try to infer generic type from required params first
57995838
for param, path in sources:
58005839
if param.name in required_generic_params:

gel/_internal/_qbmodel/_abstract/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
from ._methods import (
6666
BaseGelModel,
6767
BaseGelModelIntersection,
68+
BaseGelModelUnion,
69+
create_union,
6870
)
6971

7072

@@ -132,6 +134,7 @@
132134
"ArrayMeta",
133135
"BaseGelModel",
134136
"BaseGelModelIntersection",
137+
"BaseGelModelUnion",
135138
"ComputedLinkSet",
136139
"ComputedLinkWithPropsSet",
137140
"ComputedMultiLinkDescriptor",
@@ -174,6 +177,7 @@
174177
"TupleMeta",
175178
"UUIDImpl",
176179
"copy_or_ref_lprops",
180+
"create_union",
177181
"empty_set_if_none",
178182
"field_descriptor",
179183
"get_base_scalars_backed_by_py_type",

gel/_internal/_qbmodel/_abstract/_methods.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from gel._internal import _qb
2020
from gel._internal._schemapath import (
2121
TypeNameIntersection,
22+
TypeNameUnion,
2223
)
2324
from gel._internal import _type_expression
2425
from gel._internal._xmethod import classonlymethod
@@ -256,6 +257,17 @@ class BaseGelModelIntersection(
256257
rhs: ClassVar[type[AbstractGelModel]]
257258

258259

260+
class BaseGelModelUnion(
261+
BaseGelModel,
262+
_type_expression.Union,
263+
Generic[_T_Lhs, _T_Rhs],
264+
):
265+
__gel_type_class__: ClassVar[type]
266+
267+
lhs: ClassVar[type[AbstractGelModel]]
268+
rhs: ClassVar[type[AbstractGelModel]]
269+
270+
259271
T = TypeVar('T')
260272
U = TypeVar('U')
261273

@@ -429,3 +441,93 @@ def process_path_alias(
429441
_type_intersection_cache[lhs][rhs] = result
430442

431443
return result
444+
445+
446+
_type_union_cache: weakref.WeakKeyDictionary[
447+
type[AbstractGelModel],
448+
weakref.WeakKeyDictionary[
449+
type[AbstractGelModel],
450+
type[
451+
BaseGelModelUnion[type[AbstractGelModel], type[AbstractGelModel]]
452+
],
453+
],
454+
] = weakref.WeakKeyDictionary()
455+
456+
457+
def create_union(
458+
lhs: _T_Lhs,
459+
rhs: _T_Rhs,
460+
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]:
461+
"""Create a runtime union type which acts like a GelModel."""
462+
463+
if (lhs_entry := _type_union_cache.get(lhs)) and (
464+
rhs_entry := lhs_entry.get(rhs)
465+
):
466+
return rhs_entry # type: ignore[return-value]
467+
468+
# Combine pointer reflections from args
469+
ptr_reflections: dict[str, _qb.GelPointerReflection] = {
470+
p_name: p_refl
471+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
472+
if p_name in rhs.__gel_reflection__.pointers
473+
}
474+
475+
# Create type reflection for union type
476+
class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801
477+
expr_object_types: set[type[AbstractGelModel]] = getattr(
478+
lhs.__gel_reflection__, 'expr_object_types', {lhs}
479+
) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs})
480+
481+
type_name = TypeNameUnion(
482+
args=(
483+
lhs.__gel_reflection__.type_name,
484+
rhs.__gel_reflection__.type_name,
485+
)
486+
)
487+
488+
pointers = ptr_reflections
489+
490+
@classmethod
491+
def object(
492+
cls,
493+
) -> Any:
494+
raise NotImplementedError(
495+
"Type expressions schema objects are inaccessible"
496+
)
497+
498+
result = type(
499+
f"({lhs.__name__} | {rhs.__name__})",
500+
(BaseGelModelUnion,),
501+
{
502+
'lhs': lhs,
503+
'rhs': rhs,
504+
'__gel_reflection__': __gel_reflection__,
505+
},
506+
)
507+
508+
# Generate path aliases for pointers.
509+
#
510+
# These are used to generate the appropriate path prefix when getting
511+
# pointers in shapes.
512+
path_aliases: dict[str, _qb.PathAlias] = {
513+
p_name: l_path_alias
514+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
515+
if (
516+
hasattr(lhs, p_name)
517+
and (l_path_alias := getattr(lhs, p_name, None)) is not None
518+
and isinstance(l_path_alias, _qb.PathAlias)
519+
)
520+
if (
521+
hasattr(rhs, p_name)
522+
and (r_path_alias := getattr(rhs, p_name, None)) is not None
523+
and isinstance(r_path_alias, _qb.PathAlias)
524+
)
525+
}
526+
for p_name, path_alias in path_aliases.items():
527+
setattr(result, p_name, path_alias)
528+
529+
if lhs not in _type_union_cache:
530+
_type_union_cache[lhs] = weakref.WeakKeyDictionary()
531+
_type_union_cache[lhs][rhs] = result
532+
533+
return result

gel/_internal/_typing_dispatch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
7070

7171
if issubclass(lhs, _type_expression.Intersection):
7272
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
73+
elif issubclass(lhs, _type_expression.Union):
74+
return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
7375

7476
if _typing_inspect.is_generic_alias(tp):
7577
origin = typing.get_origin(tp)

gel/models/pydantic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
PyTypeScalarConstraint,
6969
RangeMeta,
7070
UUIDImpl,
71+
create_union,
7172
empty_set_if_none,
7273
)
7374

@@ -195,6 +196,7 @@
195196
"classonlymethod",
196197
"computed_field",
197198
"construct_infix_op_chain",
199+
"create_union",
198200
"dispatch_overload",
199201
"empty_set_if_none",
200202
)

0 commit comments

Comments
 (0)