33"""
44
55import ast
6- from ast import Call , Expr , Load , Name , Subscript , Tuple , keyword
6+ from ast import Call , Expr , Load , Name , Subscript , Tuple , expr , keyword
77from operator import attrgetter
88from typing import Optional , cast
99
@@ -82,13 +82,14 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
8282 return _param .get ("default" ) == cdd .shared .ast_utils .NoneStr , None
8383 elif _param ["typ" ].startswith ("Optional[" ):
8484 _param ["typ" ] = _param ["typ" ][len ("Optional[" ) : - 1 ]
85- nullable = True
85+ nullable : bool = True
8686 if "Literal[" in _param ["typ" ]:
8787 parsed_typ : Call = cast (
8888 Call , cdd .shared .ast_utils .get_value (ast .parse (_param ["typ" ]).body [0 ])
8989 )
90- if parsed_typ .value .id != "Literal" :
91- return nullable , parsed_typ .value
90+ assert parsed_typ .value .id == "Literal" , "Expected `Literal` got: {!r}" .format (
91+ parsed_typ .value .id
92+ )
9293 val = cdd .shared .ast_utils .get_value (parsed_typ .slice )
9394 (
9495 args .append (
@@ -112,7 +113,7 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
112113 else _update_args_infer_typ_sqlalchemy_for_scalar (_param , args , x_typ_sql )
113114 )
114115 elif _param ["typ" ].startswith ("List[" ):
115- after_generic = _param ["typ" ][len ("List[" ) :]
116+ after_generic : str = _param ["typ" ][len ("List[" ) :]
116117 if "struct" in after_generic : # "," in after_generic or
117118 name : Name = Name (id = "JSON" , ctx = Load (), lineno = None , col_offset = None )
118119 else :
@@ -175,42 +176,53 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql):
175176 )
176177 )
177178 elif _param .get ("typ" ).startswith ("Union[" ):
178- # Hack to remove the union type. Enum parse seems to be incorrect?
179- union_typ : Subscript = cast (Subscript , ast .parse (_param ["typ" ]).body [0 ])
180- assert isinstance (
181- union_typ .value , Subscript
182- ), "Expected `Subscript` got `{type_name}`" .format (
183- type_name = type (union_typ .value ).__name__
184- )
185- union_typ_tuple = (
186- union_typ .value .slice if PY_GTE_3_9 else union_typ .value .slice .value
187- )
188- assert isinstance (
189- union_typ_tuple , Tuple
190- ), "Expected `Tuple` got `{type_name}`" .format (
191- type_name = type (union_typ_tuple ).__name__
192- )
193- assert (
194- len (union_typ_tuple .elts ) == 2
195- ), "Expected length of 2 got `{tuple_len}`" .format (
196- tuple_len = len (union_typ_tuple .elts )
197- )
198- left , right = map (attrgetter ("id" ), union_typ_tuple .elts )
199- args .append (
200- Name (
201- (
202- cdd .sqlalchemy .utils .emit_utils .typ2column_type [right ]
203- if right in cdd .sqlalchemy .utils .emit_utils .typ2column_type
204- else cdd .sqlalchemy .utils .emit_utils .typ2column_type .get (left , left )
205- ),
206- Load (),
207- lineno = None ,
208- col_offset = None ,
209- )
210- )
179+ args .append (_handle_union_of_length_2 (_param ["typ" ]))
211180 else :
212181 _update_args_infer_typ_sqlalchemy_for_scalar (_param , args , x_typ_sql )
213182 return nullable , None
214183
215184
185+ def _handle_union_of_length_2 (typ ):
186+ """
187+ Internal function to turn `str` to `Name`
188+
189+ :param typ: `str` which evaluates to `ast.Subscript`
190+ :type typ: ```str```
191+
192+ :return: Parsed out name
193+ :rtype: ```Name```
194+ """
195+ # Hack to remove the union type. Enum parse seems to be incorrect?
196+ union_typ : Subscript = cast (Subscript , ast .parse (typ ).body [0 ])
197+ assert isinstance (
198+ union_typ .value , Subscript
199+ ), "Expected `Subscript` got `{type_name}`" .format (
200+ type_name = type (union_typ .value ).__name__
201+ )
202+ union_typ_tuple : expr = (
203+ union_typ .value .slice if PY_GTE_3_9 else union_typ .value .slice .value
204+ )
205+ assert isinstance (
206+ union_typ_tuple , Tuple
207+ ), "Expected `Tuple` got `{type_name}`" .format (
208+ type_name = type (union_typ_tuple ).__name__
209+ )
210+ assert (
211+ len (union_typ_tuple .elts ) == 2
212+ ), "Expected length of 2 got `{tuple_len}`" .format (
213+ tuple_len = len (union_typ_tuple .elts )
214+ )
215+ left , right = map (attrgetter ("id" ), union_typ_tuple .elts )
216+ return Name (
217+ (
218+ cdd .sqlalchemy .utils .emit_utils .typ2column_type [right ]
219+ if right in cdd .sqlalchemy .utils .emit_utils .typ2column_type
220+ else cdd .sqlalchemy .utils .emit_utils .typ2column_type .get (left , left )
221+ ),
222+ Load (),
223+ lineno = None ,
224+ col_offset = None ,
225+ )
226+
227+
216228__all__ = ["update_args_infer_typ_sqlalchemy" ]
0 commit comments