Skip to content

Commit e0a14b6

Browse files
authored
Support passing python types to scalars in casts. (#953)
Instead of generating cast functions for everything and letting the server figure it out, generate cast functions for scalars based on the reflected explicit casts. Add support to convert python types into gel scalars if possible. For example: ```py std.str.cast(1) std.str.cast("abc") ``` This works for basic scalars (like `int64` and `str`) as well as enums.
1 parent e520d9a commit e0a14b6

7 files changed

Lines changed: 195 additions & 18 deletions

File tree

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 144 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,17 +2575,38 @@ def _write_enum_scalar_cast(
25752575
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
25762576
cast_op = self.import_name(BASE_IMPL, "CastOp")
25772577

2578+
py_to_gel_casts = self._get_scalar_py_to_gel_casts(stype)
2579+
if not py_to_gel_casts:
2580+
return
2581+
2582+
arg_name: str = "expr"
2583+
arg_types: list[str] = [expr_compat, *py_to_gel_casts.keys()]
2584+
25782585
with self._classmethod_def(
25792586
"cast",
2580-
[f"expr: {expr_compat}"],
2587+
[f"{arg_name}: {' | '.join(arg_types)}"],
25812588
type_self,
25822589
):
2590+
self.write()
2591+
self.write(f"match {arg_name}:")
2592+
with self.indented():
2593+
for py_type, gel_cast in py_to_gel_casts.items():
2594+
self.write(f"case {py_type}():")
2595+
with self.indented():
2596+
cast_text: str
2597+
if isinstance(gel_cast, str):
2598+
cast_text = f"{gel_cast}({arg_name})"
2599+
else:
2600+
cast_text = gel_cast(arg_name)
2601+
self.write(f"{arg_name} = {cast_text}")
2602+
2603+
self.write()
25832604
self.write(f"return {aexpr}( # type: ignore [return-value]")
25842605
with self.indented():
25852606
self.write("cls,")
25862607
self.write(f"{cast_op}(")
25872608
with self.indented():
2588-
self.write("expr=expr,")
2609+
self.write(f"expr={arg_name},")
25892610
self.write("type_=cls.__gel_reflection__.type_name,")
25902611
self.write(")")
25912612
self.write(")")
@@ -2818,25 +2839,38 @@ def _write_regular_scalar_cast(
28182839
self_ = self.import_name("typing_extensions", "Self")
28192840
type_self = f"{type_}[{self_}]"
28202841

2821-
if signature_only:
2822-
self.write()
2823-
with self._classmethod_def(
2824-
"cast",
2825-
[f"expr: {expr_compat}"],
2826-
type_self,
2827-
):
2842+
py_to_gel_casts = self._get_scalar_py_to_gel_casts(stype)
2843+
if not py_to_gel_casts:
2844+
return
2845+
2846+
arg_name: str = "expr"
2847+
arg_types: list[str] = [expr_compat, *py_to_gel_casts.keys()]
2848+
2849+
with self._classmethod_def(
2850+
"cast",
2851+
[f"{arg_name}: {' | '.join(arg_types)}"],
2852+
type_self,
2853+
):
2854+
if signature_only:
28282855
self.write("...")
28292856

2830-
else:
2831-
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
2832-
cast_op = self.import_name(BASE_IMPL, "CastOp")
2857+
else:
2858+
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
2859+
cast_op = self.import_name(BASE_IMPL, "CastOp")
28332860

2834-
self.write()
2835-
with self._classmethod_def(
2836-
"cast",
2837-
[f"expr: {expr_compat}"],
2838-
type_self,
2839-
):
2861+
self.write(f"match {arg_name}:")
2862+
with self.indented():
2863+
for py_type, gel_cast in py_to_gel_casts.items():
2864+
self.write(f"case {py_type}():")
2865+
with self.indented():
2866+
cast_text: str
2867+
if isinstance(gel_cast, str):
2868+
cast_text = f"{gel_cast}({arg_name})"
2869+
else:
2870+
cast_text = gel_cast(arg_name)
2871+
self.write(f"{arg_name} = {cast_text}")
2872+
2873+
self.write()
28402874
self.write(f"return {aexpr}( # type: ignore [return-value]")
28412875
with self.indented():
28422876
self.write("cls,")
@@ -2847,6 +2881,98 @@ def _write_regular_scalar_cast(
28472881
self.write(")")
28482882
self.write(")")
28492883

2884+
def _get_scalar_py_to_gel_casts(
2885+
self,
2886+
stype: reflection.ScalarType,
2887+
) -> dict[str, str | Callable[[str], str]] | None:
2888+
if not (explicit_casts := self._casts.explicit_casts_to.get(stype.id)):
2889+
return None
2890+
2891+
py_to_gel_casts: dict[str, str | Callable[[str], str]] = {}
2892+
2893+
# Determine if the result type can be directly cast from a literal
2894+
direct_py_type_name: tuple[str, str] | None = None
2895+
2896+
if py_type_names := _qbmodel.get_py_type_for_scalar(
2897+
stype.name,
2898+
consider_generic=False,
2899+
):
2900+
# with consider_generic=False, there should be 1 value
2901+
direct_py_type_name = py_type_names[0]
2902+
if literal_name := _qbmodel.get_literal_name_for_py_type(
2903+
direct_py_type_name
2904+
):
2905+
py_type = self.import_name(*direct_py_type_name)
2906+
literal = self.import_name(BASE_IMPL, literal_name)
2907+
2908+
py_to_gel_casts[py_type] = lambda x: (
2909+
f"{literal}("
2910+
f"val={x},"
2911+
f"type_=cls.__gel_reflection__.type_name,"
2912+
f")"
2913+
)
2914+
2915+
# Determine what python types can converted to a gel type before cast
2916+
2917+
# Get the gel types that can be cast to result type
2918+
scalar_arg_types = [
2919+
arg_type
2920+
for arg_type_id in explicit_casts
2921+
if (arg_type := self._types.get(arg_type_id))
2922+
if reflection.is_scalar_type(arg_type)
2923+
if arg_type.schemapath not in GENERIC_TYPES
2924+
]
2925+
2926+
# Find the python types associated with the gel types
2927+
py_to_scalar_types: dict[
2928+
tuple[str, str], list[reflection.ScalarType]
2929+
] = {}
2930+
for scalar_arg_type in scalar_arg_types:
2931+
if py_type_names := _qbmodel.get_py_type_for_scalar(
2932+
scalar_arg_type.name,
2933+
consider_generic=False,
2934+
):
2935+
# with consider_generic=False, there should be 1 value
2936+
py_type_name = py_type_names[0]
2937+
2938+
if py_type_name == direct_py_type_name:
2939+
# Skip the directly converted type
2940+
continue
2941+
2942+
if py_type_name not in py_to_scalar_types:
2943+
py_to_scalar_types[py_type_name] = []
2944+
2945+
py_to_scalar_types[py_type_name].append(scalar_arg_type)
2946+
2947+
# Pick the best gel type to convert to
2948+
for py_type_name, scalar_types in py_to_scalar_types.items():
2949+
py_type = self.import_name(*py_type_name)
2950+
2951+
scalars_with_rank: list[tuple[reflection.ScalarType, int]] = []
2952+
for scalar_type in scalar_types:
2953+
rank = _qbmodel.get_py_type_scalar_match_rank(
2954+
py_type_name, scalar_type.name
2955+
)
2956+
if rank is None:
2957+
continue
2958+
scalars_with_rank.append((scalar_type, rank))
2959+
2960+
if not scalars_with_rank:
2961+
# This can happen for scalars which don't convert to simple
2962+
# python primitives. eg. ext::pgvector::halfvec
2963+
continue
2964+
2965+
best_scalar_type = min(
2966+
scalars_with_rank, key=operator.itemgetter(1)
2967+
)[0]
2968+
gel_type = self.get_type(
2969+
best_scalar_type, import_time=ImportTime.typecheck_runtime
2970+
)
2971+
2972+
py_to_gel_casts[py_type] = gel_type
2973+
2974+
return py_to_gel_casts
2975+
28502976
def render_callable_return_type(
28512977
self,
28522978
tp: reflection.Type,

gel/_internal/_qb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919

2020
from ._expressions import (
21+
BigIntLiteral,
2122
BinaryOp,
2223
BoolLiteral,
2324
BytesLiteral,
@@ -112,6 +113,7 @@
112113
"AnnotatedPath",
113114
"AnnotatedVar",
114115
"BaseAlias",
116+
"BigIntLiteral",
115117
"BinaryOp",
116118
"BoolLiteral",
117119
"BytesLiteral",

gel/_internal/_qb/_protocols.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ def edgeql(
121121
return value
122122

123123

124+
ExprPrimitive = TypeAliasType(
125+
"ExprPrimitive",
126+
int,
127+
)
128+
129+
124130
def edgeql_qb_expr(
125131
x: ExprCompatible | ExprClosure,
126132
*,

gel/_internal/_qbmodel/_abstract/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
Tuple,
9696
TupleMeta,
9797
UUIDImpl,
98+
get_literal_name_for_py_type,
9899
get_py_type_from_gel_type,
99100
get_base_scalars_backed_by_py_type,
100101
get_overlapping_py_types,
@@ -177,6 +178,7 @@
177178
"empty_set_if_none",
178179
"field_descriptor",
179180
"get_base_scalars_backed_by_py_type",
181+
"get_literal_name_for_py_type",
180182
"get_overlapping_py_types",
181183
"get_proxy_linkprops",
182184
"get_py_base_for_scalar",

gel/_internal/_qbmodel/_abstract/_primitive.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,15 @@ def get_py_type_scalar_match_rank(
817817
decimal.Decimal: _qb.DecimalLiteral,
818818
}
819819

820+
_py_type_name_to_literal_name: dict[tuple[str, str], str] = {
821+
("builtins", "bool"): "BoolLiteral",
822+
("builtins", "int"): "IntLiteral",
823+
("builtins", "float"): "FloatLiteral",
824+
("builtins", "str"): "StringLiteral",
825+
("builtins", "bytes"): "BytesLiteral",
826+
("decimal", "Decimal"): "DecimalLiteral",
827+
}
828+
820829

821830
_PT_co = TypeVar("_PT_co", bound=PyConstType, covariant=True)
822831
_ST = TypeVar("_ST", bound=GelScalarType, default=GelScalarType)
@@ -850,6 +859,12 @@ def get_literal_for_scalar(
850859
)
851860

852861

862+
def get_literal_name_for_py_type(
863+
py_type_name: tuple[str, str],
864+
) -> str | None:
865+
return _py_type_name_to_literal_name.get(py_type_name)
866+
867+
853868
class PyTypeScalar(
854869
_typing_parametric.ParametricType,
855870
GelScalarType,

gel/models/pydantic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,14 @@
2222
OPERAND_IS_ALIAS,
2323
AnnotatedExpr,
2424
BaseAlias,
25+
BigIntLiteral,
26+
BoolLiteral,
27+
BytesLiteral,
2528
CastOp,
29+
DecimalLiteral,
2630
EmptyDirection,
2731
Direction,
32+
FloatLiteral,
2833
GelLinkMetadata,
2934
GelObjectTypeMetadata,
3035
GelPointerReflection,
@@ -34,10 +39,12 @@
3439
ExprCompatible,
3540
IndexOp,
3641
InfixOp,
42+
IntLiteral,
3743
FuncCall,
3844
ObjectWhenType,
3945
PathAlias,
4046
SetLiteral,
47+
StringLiteral,
4148
UnaryOp,
4249
construct_infix_op_chain,
4350
)
@@ -117,6 +124,9 @@
117124
"Array",
118125
"ArrayMeta",
119126
"BaseAlias",
127+
"BigIntLiteral",
128+
"BoolLiteral",
129+
"BytesLiteral",
120130
"Cardinality",
121131
"CastOp",
122132
"ComputedLink",
@@ -128,12 +138,14 @@
128138
"DateImpl",
129139
"DateTimeImpl",
130140
"DateTimeLike",
141+
"DecimalLiteral",
131142
"DefaultValue",
132143
"DeferredImport",
133144
"Direction",
134145
"EmptyDirection",
135146
"ExprClosure",
136147
"ExprCompatible",
148+
"FloatLiteral",
137149
"FuncCall",
138150
"GelLinkMetadata",
139151
"GelLinkModel",
@@ -151,6 +163,7 @@
151163
"IdProperty",
152164
"IndexOp",
153165
"InfixOp",
166+
"IntLiteral",
154167
"JSONImpl",
155168
"LazyClassProperty",
156169
"LinkClassNamespace",
@@ -183,6 +196,7 @@
183196
"RequiredMultiLinkWithProps",
184197
"SchemaPath",
185198
"SetLiteral",
199+
"StringLiteral",
186200
"TimeDeltaImpl",
187201
"TimeImpl",
188202
"Tuple",

tests/test_qb.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,14 @@ def test_qb_cast_scalar_01(self):
13561356
result = self.client.get(std.str.cast(std.int64(1)))
13571357
self.assertEqual(result, "1")
13581358

1359+
# python scalar to scalar
1360+
from models.orm_qb import std
1361+
1362+
result = self.client.get(std.str.cast(1))
1363+
self.assertEqual(result, "1")
1364+
result = self.client.get(std.str.cast("1"))
1365+
self.assertEqual(result, "1")
1366+
13591367
def test_qb_cast_scalar_02(self):
13601368
# enum to scalar
13611369
from models.orm_qb import default, std
@@ -1370,6 +1378,10 @@ def test_qb_cast_scalar_03(self):
13701378
result = self.client.get(default.Color.cast(std.str("Red")))
13711379
self.assertEqual(result, default.Color.Red)
13721380

1381+
# python scalar to enum
1382+
result = self.client.get(default.Color.cast("Red"))
1383+
self.assertEqual(result, default.Color.Red)
1384+
13731385
def test_qb_cast_array_01(self):
13741386
# array[scalar] to array[scalar]
13751387
from models.orm_qb import std

0 commit comments

Comments
 (0)