Skip to content

Commit b261518

Browse files
committed
auto convert id to column type
1 parent 01afd7d commit b261518

1 file changed

Lines changed: 35 additions & 12 deletions

File tree

fastapi_jsonapi/data_layers/sqla_orm.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
44

55
from sqlalchemy import delete, func, select
6-
from sqlalchemy.exc import DatabaseError, DBAPIError, NoResultFound
6+
from sqlalchemy.exc import DBAPIError, NoResultFound
77
from sqlalchemy.ext.asyncio import AsyncSession
88
from sqlalchemy.inspection import inspect
99
from sqlalchemy.orm import joinedload, selectinload
@@ -38,7 +38,6 @@
3838
from pydantic import BaseModel as PydanticBaseModel
3939
from sqlalchemy.sql import Select
4040

41-
4241
log = logging.getLogger(__name__)
4342

4443

@@ -56,6 +55,7 @@ def __init__(
5655
url_id_field: str = "id",
5756
eagerload_includes: bool = True,
5857
query: Optional["Select"] = None,
58+
auto_convert_id_to_column_type: bool = True,
5959
**kwargs: Any,
6060
):
6161
"""
@@ -85,6 +85,25 @@ def __init__(
8585
self.session = session
8686
self.eagerload_includes_ = eagerload_includes
8787
self._query = query
88+
self.auto_convert_id_to_column_type = auto_convert_id_to_column_type
89+
90+
def prepare_id_value(self, col: InstrumentedAttribute, value: Any) -> Any:
91+
"""
92+
Convert value to the required python type.
93+
Type is declared on the SQLA column.
94+
95+
:param col:
96+
:param value:
97+
:return:
98+
"""
99+
if not self.auto_convert_id_to_column_type:
100+
return value
101+
102+
py_type = col.type.python_type
103+
if not isinstance(value, py_type):
104+
value = py_type(value)
105+
106+
return value
88107

89108
async def apply_relationships(self, obj: TypeModel, data_create: BaseJSONAPIItemInSchema) -> None:
90109
"""
@@ -180,11 +199,12 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
180199
self.session.add(obj)
181200
try:
182201
await self.session.commit()
183-
except DatabaseError:
202+
except DBAPIError:
184203
log.exception("Could not create object with data create %s", data_create)
185204
msg = "Object creation error"
186205
raise HTTPException(msg, pointer="/data")
187206
except Exception as e:
207+
log.exception("Error creating object with data create %s", data_create)
188208
await self.session.rollback()
189209
msg = f"Object creation error: {e}"
190210
raise HTTPException(msg, pointer="/data")
@@ -475,7 +495,9 @@ async def get_related_object(
475495
:param id_value: related object id value
476496
:return: a related SQLA ORM object
477497
"""
478-
stmt = select(related_model).where(getattr(related_model, related_id_field) == id_value)
498+
id_field = getattr(related_model, related_id_field)
499+
id_value = self.prepare_id_value(id_field, id_value)
500+
stmt = select(related_model).where(id_field == id_value)
479501
try:
480502
related_object = (await self.session.execute(stmt)).scalar_one()
481503
except NoResultFound:
@@ -497,16 +519,16 @@ async def get_related_objects_list(
497519
:param ids:
498520
:return:
499521
"""
500-
stmt = select(related_model).where(getattr(related_model, related_id_field).in_(ids))
522+
id_field = getattr(related_model, related_id_field)
523+
ids = [self.prepare_id_value(id_field, _id) for _id in ids]
524+
stmt = select(related_model).where(id_field.in_(ids))
501525

502526
related_objects = (await self.session.execute(stmt)).scalars().all()
503527
object_ids = [getattr(obj, related_id_field) for obj in related_objects]
504528

505529
not_found_ids = ids
506530
if object_ids:
507-
obj_type = type(object_ids[0])
508-
ids = {obj_type(_id) for _id in ids}
509-
not_found_ids = ids.difference(object_ids)
531+
not_found_ids = set(ids).difference(object_ids)
510532

511533
if not_found_ids:
512534
msg = f"Objects for {related_model.__name__} with ids: {not_found_ids} not found"
@@ -615,9 +637,8 @@ def retrieve_object_query(
615637
:param filter_value: the value to filter with
616638
:return sqlalchemy query: a query from sqlalchemy
617639
"""
618-
query: "Select" = self.query(view_kwargs)
619-
# noinspection PyNoneFunctionAssignment,PyTypeChecker
620-
query: "Select" = query.where(filter_field == filter_value)
640+
value = self.prepare_id_value(filter_field, filter_value)
641+
query: "Select" = self.query(view_kwargs).where(filter_field == value)
621642
return query
622643

623644
def query(self, view_kwargs: dict) -> "Select":
@@ -637,7 +658,9 @@ async def before_create_object(self, model_kwargs: dict, view_kwargs: dict):
637658
:param model_kwargs: the data validated by pydantic.
638659
:param view_kwargs: kwargs from the resource view.
639660
"""
640-
pass
661+
if (id_value := model_kwargs.get("id")) and self.auto_convert_id_to_column_type:
662+
model_field = self.get_object_id_field()
663+
model_kwargs.update(id=self.prepare_id_value(model_field, id_value))
641664

642665
async def after_create_object(self, obj: TypeModel, model_kwargs: dict, view_kwargs: dict):
643666
"""

0 commit comments

Comments
 (0)