|
19 | 19 | from gel._internal import _qb |
20 | 20 | from gel._internal._schemapath import ( |
21 | 21 | TypeNameIntersection, |
| 22 | + TypeNameUnion, |
22 | 23 | ) |
23 | 24 | from gel._internal import _type_expression |
24 | 25 | from gel._internal._xmethod import classonlymethod |
@@ -256,6 +257,17 @@ class BaseGelModelIntersection( |
256 | 257 | rhs: ClassVar[type[AbstractGelModel]] |
257 | 258 |
|
258 | 259 |
|
| 260 | +class BaseGelModelUnion( |
| 261 | + BaseGelModel, |
| 262 | + _type_expression.Union, |
| 263 | + Generic[_T_Lhs, _T_Rhs], |
| 264 | +): |
| 265 | + __gel_type_class__: ClassVar[type] |
| 266 | + |
| 267 | + lhs: ClassVar[type[AbstractGelModel]] |
| 268 | + rhs: ClassVar[type[AbstractGelModel]] |
| 269 | + |
| 270 | + |
259 | 271 | T = TypeVar('T') |
260 | 272 | U = TypeVar('U') |
261 | 273 |
|
@@ -429,3 +441,93 @@ def process_path_alias( |
429 | 441 | _type_intersection_cache[lhs][rhs] = result |
430 | 442 |
|
431 | 443 | return result |
| 444 | + |
| 445 | + |
| 446 | +_type_union_cache: weakref.WeakKeyDictionary[ |
| 447 | + type[AbstractGelModel], |
| 448 | + weakref.WeakKeyDictionary[ |
| 449 | + type[AbstractGelModel], |
| 450 | + type[ |
| 451 | + BaseGelModelUnion[type[AbstractGelModel], type[AbstractGelModel]] |
| 452 | + ], |
| 453 | + ], |
| 454 | +] = weakref.WeakKeyDictionary() |
| 455 | + |
| 456 | + |
| 457 | +def create_union( |
| 458 | + lhs: _T_Lhs, |
| 459 | + rhs: _T_Rhs, |
| 460 | +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]: |
| 461 | + """Create a runtime union type which acts like a GelModel.""" |
| 462 | + |
| 463 | + if (lhs_entry := _type_union_cache.get(lhs)) and ( |
| 464 | + rhs_entry := lhs_entry.get(rhs) |
| 465 | + ): |
| 466 | + return rhs_entry # type: ignore[return-value] |
| 467 | + |
| 468 | + # Combine pointer reflections from args |
| 469 | + ptr_reflections: dict[str, _qb.GelPointerReflection] = { |
| 470 | + p_name: p_refl |
| 471 | + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() |
| 472 | + if p_name in rhs.__gel_reflection__.pointers |
| 473 | + } |
| 474 | + |
| 475 | + # Create type reflection for union type |
| 476 | + class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801 |
| 477 | + expr_object_types: set[type[AbstractGelModel]] = getattr( |
| 478 | + lhs.__gel_reflection__, 'expr_object_types', {lhs} |
| 479 | + ) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs}) |
| 480 | + |
| 481 | + type_name = TypeNameUnion( |
| 482 | + args=( |
| 483 | + lhs.__gel_reflection__.type_name, |
| 484 | + rhs.__gel_reflection__.type_name, |
| 485 | + ) |
| 486 | + ) |
| 487 | + |
| 488 | + pointers = ptr_reflections |
| 489 | + |
| 490 | + @classmethod |
| 491 | + def object( |
| 492 | + cls, |
| 493 | + ) -> Any: |
| 494 | + raise NotImplementedError( |
| 495 | + "Type expressions schema objects are inaccessible" |
| 496 | + ) |
| 497 | + |
| 498 | + result = type( |
| 499 | + f"({lhs.__name__} | {rhs.__name__})", |
| 500 | + (BaseGelModelUnion,), |
| 501 | + { |
| 502 | + 'lhs': lhs, |
| 503 | + 'rhs': rhs, |
| 504 | + '__gel_reflection__': __gel_reflection__, |
| 505 | + }, |
| 506 | + ) |
| 507 | + |
| 508 | + # Generate path aliases for pointers. |
| 509 | + # |
| 510 | + # These are used to generate the appropriate path prefix when getting |
| 511 | + # pointers in shapes. |
| 512 | + path_aliases: dict[str, _qb.PathAlias] = { |
| 513 | + p_name: l_path_alias |
| 514 | + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() |
| 515 | + if ( |
| 516 | + hasattr(lhs, p_name) |
| 517 | + and (l_path_alias := getattr(lhs, p_name, None)) is not None |
| 518 | + and isinstance(l_path_alias, _qb.PathAlias) |
| 519 | + ) |
| 520 | + if ( |
| 521 | + hasattr(rhs, p_name) |
| 522 | + and (r_path_alias := getattr(rhs, p_name, None)) is not None |
| 523 | + and isinstance(r_path_alias, _qb.PathAlias) |
| 524 | + ) |
| 525 | + } |
| 526 | + for p_name, path_alias in path_aliases.items(): |
| 527 | + setattr(result, p_name, path_alias) |
| 528 | + |
| 529 | + if lhs not in _type_union_cache: |
| 530 | + _type_union_cache[lhs] = weakref.WeakKeyDictionary() |
| 531 | + _type_union_cache[lhs][rhs] = result |
| 532 | + |
| 533 | + return result |
0 commit comments