Skip to content

Commit e520d9a

Browse files
authored
Fix passing type intersections to overloaded functions. (#955)
Fixes queries such as `std.distinct(default.Foo.is_(default.Bar))`
1 parent 6a9a396 commit e520d9a

5 files changed

Lines changed: 119 additions & 6 deletions

File tree

gel/_internal/_qbmodel/_abstract/_methods.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from gel._internal._schemapath import (
2121
TypeNameIntersection,
2222
)
23+
from gel._internal import _type_expression
2324
from gel._internal._xmethod import classonlymethod
2425

2526
from ._base import AbstractGelModel
@@ -246,6 +247,7 @@ def __edgeql_qb_expr__(cls) -> _qb.Expr: # pyright: ignore [reportIncompatibleM
246247

247248
class BaseGelModelIntersection(
248249
BaseGelModel,
250+
_type_expression.Intersection,
249251
Generic[_T_Lhs, _T_Rhs],
250252
):
251253
__gel_type_class__: ClassVar[type]

gel/_internal/_type_expression.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-PackageName: gel-python
2+
# SPDX-License-Identifier: Apache-2.0
3+
# SPDX-FileCopyrightText: Copyright Gel Data Inc. and the contributors.
4+
5+
import typing
6+
7+
8+
class Intersection:
9+
lhs: typing.ClassVar[type]
10+
rhs: typing.ClassVar[type]
11+
12+
13+
class Union:
14+
lhs: typing.ClassVar[type]
15+
rhs: typing.ClassVar[type]

gel/_internal/_typing_dispatch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import typing
3535

3636
from gel._internal import _namespace
37+
from gel._internal import _type_expression
3738
from gel._internal import _typing_eval
3839
from gel._internal import _typing_inspect
3940
from gel._internal import _typing_parametric
@@ -66,6 +67,10 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
6667
# subtypes of the variable bounds.
6768
# This lets us handle cases like:
6869
# std.array[Object] <: std.array[_T_anytype].
70+
71+
if issubclass(lhs, _type_expression.Intersection):
72+
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
73+
6974
if _typing_inspect.is_generic_alias(tp):
7075
origin = typing.get_origin(tp)
7176
args = typing.get_args(tp)

tests/dbsetup/orm_qb.gel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,4 +618,8 @@ type Link_Inh_A {
618618
};
619619
};
620620

621+
function Read_Inh_A(x: Inh_A) -> int64 using (x.a ?? -1);
622+
function Read_Inh_A_Overload(x: Inh_A) -> int64 using (x.a ?? -1);
623+
function Read_Inh_A_Overload(x: Inh_AB) -> int64 using (x.ab ?? -1);
624+
621625
}

tests/test_qb.py

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,9 +1695,7 @@ def test_qb_is_type_basic_07(self):
16951695
# Link TypeIntersection
16961696
from models.orm_qb import default
16971697

1698-
result = self.client.query(
1699-
default.Link_Inh_A.l.is_(default.Inh_B)
1700-
)
1698+
result = self.client.query(default.Link_Inh_A.l.is_(default.Inh_B))
17011699

17021700
self._assertObjectsWithFields(
17031701
result,
@@ -1900,9 +1898,9 @@ def test_qb_is_type_for_01(self):
19001898
from models.orm_qb import default, std
19011899

19021900
result = self.client.query(
1903-
std.for_(
1904-
default.Inh_A.is_(default.Inh_B), lambda x: x
1905-
).select(a=True)
1901+
std.for_(default.Inh_A.is_(default.Inh_B), lambda x: x).select(
1902+
a=True
1903+
)
19061904
)
19071905

19081906
self._assertObjectsWithFields(
@@ -2014,6 +2012,95 @@ def test_qb_is_type_for_03(self):
20142012
excluded_fields={'b', 'c', 'ab', 'ac', 'bc', 'abc', 'ab_ac'},
20152013
)
20162014

2015+
def test_qb_is_type_as_function_arg_01(self):
2016+
# Test that type exprs produced by is_ can be passed as function args
2017+
from models.orm_qb import default, std
2018+
2019+
result = self.client.query(
2020+
std.distinct(default.Inh_A.is_(default.Inh_B)).select('*')
2021+
)
2022+
2023+
self._assertObjectsWithFields(
2024+
result,
2025+
"a",
2026+
[
2027+
(
2028+
default.Inh_AB,
2029+
{
2030+
"a": 4,
2031+
"b": 5,
2032+
},
2033+
),
2034+
(
2035+
default.Inh_ABC,
2036+
{
2037+
"a": 13,
2038+
"b": 14,
2039+
},
2040+
),
2041+
(
2042+
default.Inh_AB_AC,
2043+
{
2044+
"a": 17,
2045+
"b": 18,
2046+
},
2047+
),
2048+
],
2049+
excluded_fields={'c', 'ab', 'ac', 'bc', 'abc', 'ab_ac'},
2050+
)
2051+
2052+
def test_qb_is_type_as_function_arg_02(self):
2053+
# Test that complex type exprs produced by is_ can be passed as
2054+
# function args
2055+
from models.orm_qb import default, std
2056+
2057+
result = self.client.query(
2058+
std.distinct(
2059+
default.Inh_A.is_(default.Inh_B).is_(default.Inh_C)
2060+
).select('*')
2061+
)
2062+
2063+
self._assertObjectsWithFields(
2064+
result,
2065+
"a",
2066+
[
2067+
(
2068+
default.Inh_ABC,
2069+
{
2070+
"a": 13,
2071+
"b": 14,
2072+
"c": 15,
2073+
},
2074+
),
2075+
(
2076+
default.Inh_AB_AC,
2077+
{
2078+
"a": 17,
2079+
"b": 18,
2080+
"c": 19,
2081+
},
2082+
),
2083+
],
2084+
excluded_fields={'ab', 'ac', 'bc', 'abc', 'ab_ac'},
2085+
)
2086+
2087+
def test_qb_is_type_as_function_arg_03(self):
2088+
# Test that exprs produced by is_ can be passed as function args to
2089+
# user defined function
2090+
from models.orm_qb import default
2091+
2092+
# Note, we do Inh_A[is Inh_B] since is_ currently pretends its return
2093+
# type is its argument type.
2094+
result = self.client.query(
2095+
default.Read_Inh_A(default.Inh_B.is_(default.Inh_A))
2096+
)
2097+
self.assertEqual(sorted(result), [4, 13, 17])
2098+
2099+
result = self.client.query(
2100+
default.Read_Inh_A_Overload(default.Inh_B.is_(default.Inh_A))
2101+
)
2102+
self.assertEqual(sorted(result), [6, 13, 20])
2103+
20172104

20182105
class TestQueryBuilderModify(tb.ModelTestCase):
20192106
"""This test suite is for data manipulation using QB."""

0 commit comments

Comments
 (0)