@@ -2575,17 +2575,38 @@ def _write_enum_scalar_cast(
25752575 aexpr = self .import_name (BASE_IMPL , "AnnotatedExpr" )
25762576 cast_op = self .import_name (BASE_IMPL , "CastOp" )
25772577
2578+ py_to_gel_casts = self ._get_scalar_py_to_gel_casts (stype )
2579+ if not py_to_gel_casts :
2580+ return
2581+
2582+ arg_name : str = "expr"
2583+ arg_types : list [str ] = [expr_compat , * py_to_gel_casts .keys ()]
2584+
25782585 with self ._classmethod_def (
25792586 "cast" ,
2580- [f"expr : { expr_compat } " ],
2587+ [f"{ arg_name } : { ' | ' . join ( arg_types ) } " ],
25812588 type_self ,
25822589 ):
2590+ self .write ()
2591+ self .write (f"match { arg_name } :" )
2592+ with self .indented ():
2593+ for py_type , gel_cast in py_to_gel_casts .items ():
2594+ self .write (f"case { py_type } ():" )
2595+ with self .indented ():
2596+ cast_text : str
2597+ if isinstance (gel_cast , str ):
2598+ cast_text = f"{ gel_cast } ({ arg_name } )"
2599+ else :
2600+ cast_text = gel_cast (arg_name )
2601+ self .write (f"{ arg_name } = { cast_text } " )
2602+
2603+ self .write ()
25832604 self .write (f"return { aexpr } ( # type: ignore [return-value]" )
25842605 with self .indented ():
25852606 self .write ("cls," )
25862607 self .write (f"{ cast_op } (" )
25872608 with self .indented ():
2588- self .write ("expr=expr ," )
2609+ self .write (f "expr={ arg_name } ," )
25892610 self .write ("type_=cls.__gel_reflection__.type_name," )
25902611 self .write (")" )
25912612 self .write (")" )
@@ -2818,25 +2839,38 @@ def _write_regular_scalar_cast(
28182839 self_ = self .import_name ("typing_extensions" , "Self" )
28192840 type_self = f"{ type_ } [{ self_ } ]"
28202841
2821- if signature_only :
2822- self .write ()
2823- with self ._classmethod_def (
2824- "cast" ,
2825- [f"expr: { expr_compat } " ],
2826- type_self ,
2827- ):
2842+ py_to_gel_casts = self ._get_scalar_py_to_gel_casts (stype )
2843+ if not py_to_gel_casts :
2844+ return
2845+
2846+ arg_name : str = "expr"
2847+ arg_types : list [str ] = [expr_compat , * py_to_gel_casts .keys ()]
2848+
2849+ with self ._classmethod_def (
2850+ "cast" ,
2851+ [f"{ arg_name } : { ' | ' .join (arg_types )} " ],
2852+ type_self ,
2853+ ):
2854+ if signature_only :
28282855 self .write ("..." )
28292856
2830- else :
2831- aexpr = self .import_name (BASE_IMPL , "AnnotatedExpr" )
2832- cast_op = self .import_name (BASE_IMPL , "CastOp" )
2857+ else :
2858+ aexpr = self .import_name (BASE_IMPL , "AnnotatedExpr" )
2859+ cast_op = self .import_name (BASE_IMPL , "CastOp" )
28332860
2834- self .write ()
2835- with self ._classmethod_def (
2836- "cast" ,
2837- [f"expr: { expr_compat } " ],
2838- type_self ,
2839- ):
2861+ self .write (f"match { arg_name } :" )
2862+ with self .indented ():
2863+ for py_type , gel_cast in py_to_gel_casts .items ():
2864+ self .write (f"case { py_type } ():" )
2865+ with self .indented ():
2866+ cast_text : str
2867+ if isinstance (gel_cast , str ):
2868+ cast_text = f"{ gel_cast } ({ arg_name } )"
2869+ else :
2870+ cast_text = gel_cast (arg_name )
2871+ self .write (f"{ arg_name } = { cast_text } " )
2872+
2873+ self .write ()
28402874 self .write (f"return { aexpr } ( # type: ignore [return-value]" )
28412875 with self .indented ():
28422876 self .write ("cls," )
@@ -2847,6 +2881,98 @@ def _write_regular_scalar_cast(
28472881 self .write (")" )
28482882 self .write (")" )
28492883
2884+ def _get_scalar_py_to_gel_casts (
2885+ self ,
2886+ stype : reflection .ScalarType ,
2887+ ) -> dict [str , str | Callable [[str ], str ]] | None :
2888+ if not (explicit_casts := self ._casts .explicit_casts_to .get (stype .id )):
2889+ return None
2890+
2891+ py_to_gel_casts : dict [str , str | Callable [[str ], str ]] = {}
2892+
2893+ # Determine if the result type can be directly cast from a literal
2894+ direct_py_type_name : tuple [str , str ] | None = None
2895+
2896+ if py_type_names := _qbmodel .get_py_type_for_scalar (
2897+ stype .name ,
2898+ consider_generic = False ,
2899+ ):
2900+ # with consider_generic=False, there should be 1 value
2901+ direct_py_type_name = py_type_names [0 ]
2902+ if literal_name := _qbmodel .get_literal_name_for_py_type (
2903+ direct_py_type_name
2904+ ):
2905+ py_type = self .import_name (* direct_py_type_name )
2906+ literal = self .import_name (BASE_IMPL , literal_name )
2907+
2908+ py_to_gel_casts [py_type ] = lambda x : (
2909+ f"{ literal } ("
2910+ f"val={ x } ,"
2911+ f"type_=cls.__gel_reflection__.type_name,"
2912+ f")"
2913+ )
2914+
2915+ # Determine what python types can converted to a gel type before cast
2916+
2917+ # Get the gel types that can be cast to result type
2918+ scalar_arg_types = [
2919+ arg_type
2920+ for arg_type_id in explicit_casts
2921+ if (arg_type := self ._types .get (arg_type_id ))
2922+ if reflection .is_scalar_type (arg_type )
2923+ if arg_type .schemapath not in GENERIC_TYPES
2924+ ]
2925+
2926+ # Find the python types associated with the gel types
2927+ py_to_scalar_types : dict [
2928+ tuple [str , str ], list [reflection .ScalarType ]
2929+ ] = {}
2930+ for scalar_arg_type in scalar_arg_types :
2931+ if py_type_names := _qbmodel .get_py_type_for_scalar (
2932+ scalar_arg_type .name ,
2933+ consider_generic = False ,
2934+ ):
2935+ # with consider_generic=False, there should be 1 value
2936+ py_type_name = py_type_names [0 ]
2937+
2938+ if py_type_name == direct_py_type_name :
2939+ # Skip the directly converted type
2940+ continue
2941+
2942+ if py_type_name not in py_to_scalar_types :
2943+ py_to_scalar_types [py_type_name ] = []
2944+
2945+ py_to_scalar_types [py_type_name ].append (scalar_arg_type )
2946+
2947+ # Pick the best gel type to convert to
2948+ for py_type_name , scalar_types in py_to_scalar_types .items ():
2949+ py_type = self .import_name (* py_type_name )
2950+
2951+ scalars_with_rank : list [tuple [reflection .ScalarType , int ]] = []
2952+ for scalar_type in scalar_types :
2953+ rank = _qbmodel .get_py_type_scalar_match_rank (
2954+ py_type_name , scalar_type .name
2955+ )
2956+ if rank is None :
2957+ continue
2958+ scalars_with_rank .append ((scalar_type , rank ))
2959+
2960+ if not scalars_with_rank :
2961+ # This can happen for scalars which don't convert to simple
2962+ # python primitives. eg. ext::pgvector::halfvec
2963+ continue
2964+
2965+ best_scalar_type = min (
2966+ scalars_with_rank , key = operator .itemgetter (1 )
2967+ )[0 ]
2968+ gel_type = self .get_type (
2969+ best_scalar_type , import_time = ImportTime .typecheck_runtime
2970+ )
2971+
2972+ py_to_gel_casts [py_type ] = gel_type
2973+
2974+ return py_to_gel_casts
2975+
28502976 def render_callable_return_type (
28512977 self ,
28522978 tp : reflection .Type ,
0 commit comments