33from typing import TYPE_CHECKING , Any , Iterable , List , Optional , Tuple , Type
44
55from sqlalchemy import delete , func , select
6- from sqlalchemy .exc import DatabaseError , DBAPIError , NoResultFound
6+ from sqlalchemy .exc import DBAPIError , NoResultFound
77from sqlalchemy .ext .asyncio import AsyncSession
88from sqlalchemy .inspection import inspect
99from sqlalchemy .orm import joinedload , selectinload
3838 from pydantic import BaseModel as PydanticBaseModel
3939 from sqlalchemy .sql import Select
4040
41-
4241log = logging .getLogger (__name__ )
4342
4443
@@ -56,6 +55,7 @@ def __init__(
5655 url_id_field : str = "id" ,
5756 eagerload_includes : bool = True ,
5857 query : Optional ["Select" ] = None ,
58+ auto_convert_id_to_column_type : bool = True ,
5959 ** kwargs : Any ,
6060 ):
6161 """
@@ -85,6 +85,25 @@ def __init__(
8585 self .session = session
8686 self .eagerload_includes_ = eagerload_includes
8787 self ._query = query
88+ self .auto_convert_id_to_column_type = auto_convert_id_to_column_type
89+
90+ def prepare_id_value (self , col : InstrumentedAttribute , value : Any ) -> Any :
91+ """
92+ Convert value to the required python type.
93+ Type is declared on the SQLA column.
94+
95+ :param col:
96+ :param value:
97+ :return:
98+ """
99+ if not self .auto_convert_id_to_column_type :
100+ return value
101+
102+ py_type = col .type .python_type
103+ if not isinstance (value , py_type ):
104+ value = py_type (value )
105+
106+ return value
88107
89108 async def apply_relationships (self , obj : TypeModel , data_create : BaseJSONAPIItemInSchema ) -> None :
90109 """
@@ -180,11 +199,12 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
180199 self .session .add (obj )
181200 try :
182201 await self .session .commit ()
183- except DatabaseError :
202+ except DBAPIError :
184203 log .exception ("Could not create object with data create %s" , data_create )
185204 msg = "Object creation error"
186205 raise HTTPException (msg , pointer = "/data" )
187206 except Exception as e :
207+ log .exception ("Error creating object with data create %s" , data_create )
188208 await self .session .rollback ()
189209 msg = f"Object creation error: { e } "
190210 raise HTTPException (msg , pointer = "/data" )
@@ -475,7 +495,9 @@ async def get_related_object(
475495 :param id_value: related object id value
476496 :return: a related SQLA ORM object
477497 """
478- stmt = select (related_model ).where (getattr (related_model , related_id_field ) == id_value )
498+ id_field = getattr (related_model , related_id_field )
499+ id_value = self .prepare_id_value (id_field , id_value )
500+ stmt = select (related_model ).where (id_field == id_value )
479501 try :
480502 related_object = (await self .session .execute (stmt )).scalar_one ()
481503 except NoResultFound :
@@ -497,16 +519,16 @@ async def get_related_objects_list(
497519 :param ids:
498520 :return:
499521 """
500- stmt = select (related_model ).where (getattr (related_model , related_id_field ).in_ (ids ))
522+ id_field = getattr (related_model , related_id_field )
523+ ids = [self .prepare_id_value (id_field , _id ) for _id in ids ]
524+ stmt = select (related_model ).where (id_field .in_ (ids ))
501525
502526 related_objects = (await self .session .execute (stmt )).scalars ().all ()
503527 object_ids = [getattr (obj , related_id_field ) for obj in related_objects ]
504528
505529 not_found_ids = ids
506530 if object_ids :
507- obj_type = type (object_ids [0 ])
508- ids = {obj_type (_id ) for _id in ids }
509- not_found_ids = ids .difference (object_ids )
531+ not_found_ids = set (ids ).difference (object_ids )
510532
511533 if not_found_ids :
512534 msg = f"Objects for { related_model .__name__ } with ids: { not_found_ids } not found"
@@ -615,9 +637,8 @@ def retrieve_object_query(
615637 :param filter_value: the value to filter with
616638 :return sqlalchemy query: a query from sqlalchemy
617639 """
618- query : "Select" = self .query (view_kwargs )
619- # noinspection PyNoneFunctionAssignment,PyTypeChecker
620- query : "Select" = query .where (filter_field == filter_value )
640+ value = self .prepare_id_value (filter_field , filter_value )
641+ query : "Select" = self .query (view_kwargs ).where (filter_field == value )
621642 return query
622643
623644 def query (self , view_kwargs : dict ) -> "Select" :
@@ -637,7 +658,9 @@ async def before_create_object(self, model_kwargs: dict, view_kwargs: dict):
637658 :param model_kwargs: the data validated by pydantic.
638659 :param view_kwargs: kwargs from the resource view.
639660 """
640- pass
661+ if (id_value := model_kwargs .get ("id" )) and self .auto_convert_id_to_column_type :
662+ model_field = self .get_object_id_field ()
663+ model_kwargs .update (id = self .prepare_id_value (model_field , id_value ))
641664
642665 async def after_create_object (self , obj : TypeModel , model_kwargs : dict , view_kwargs : dict ):
643666 """
0 commit comments