Skip to content

Commit e7c0c52

Browse files
committed
refactoring:
- use context vars for inner calls (skip extra kwargs passing) - define a method to prepare db item id
1 parent 27c39e6 commit e7c0c52

2 files changed

Lines changed: 58 additions & 71 deletions

File tree

fastapi_jsonapi/views/view_base.py

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
2-
from functools import partial
2+
from contextvars import ContextVar
33
from typing import (
4-
Callable,
54
Dict,
65
Iterable,
76
List,
@@ -30,19 +29,37 @@
3029
logger = logging.getLogger(__name__)
3130

3231

32+
previous_resource_type_ctx_var: ContextVar[str] = ContextVar("previous_resource_type_ctx_var")
33+
related_field_name_ctx_var: ContextVar[str] = ContextVar("related_field_name_ctx_var")
34+
relationships_schema_ctx_var: ContextVar[Type[BaseModel]] = ContextVar("relationships_schema_ctx_var")
35+
object_schema_ctx_var: ContextVar[Type[JSONAPIObjectSchema]] = ContextVar("object_schema_ctx_var")
36+
included_object_schema_ctx_var: ContextVar[Type[TypeSchema]] = ContextVar("included_object_schema_ctx_var")
37+
relationship_info_ctx_var: ContextVar[RelationshipInfo] = ContextVar("relationship_info_ctx_var")
38+
39+
3340
class ViewBase:
3441
def __init__(self, jsonapi: RoutersJSONAPI, **options):
3542
self.jsonapi = jsonapi
3643
self.options = options
3744

45+
@classmethod
46+
def get_db_item_id(cls, item_from_db: TypeModel):
47+
"""
48+
TODO: check if id is None? raise?
49+
TODO: any another conversion for id to string?
50+
:param item_from_db:
51+
:return:
52+
"""
53+
return str(item_from_db.id)
54+
3855
@classmethod
3956
def prepare_related_object_data(
4057
cls,
4158
item_from_db: TypeModel,
42-
included_object_schema: Type[TypeSchema],
43-
relationship_info: RelationshipInfo,
4459
) -> Tuple[Dict[str, Union[str, int]], Optional[TypeSchema]]:
45-
item_id = str(item_from_db.id)
60+
included_object_schema: Type[TypeSchema] = included_object_schema_ctx_var.get()
61+
relationship_info: RelationshipInfo = relationship_info_ctx_var.get()
62+
item_id = cls.get_db_item_id(item_from_db)
4663
data_for_relationship = {"id": item_id}
4764
processed_object = included_object_schema(
4865
id=item_id,
@@ -56,20 +73,12 @@ def prepare_related_object_data(
5673
def prepare_data_for_relationship(
5774
cls,
5875
related_db_item: Union[List[TypeModel], TypeModel],
59-
relationship_info: RelationshipInfo,
60-
included_object_schema: Type[TypeSchema],
6176
) -> Tuple[Optional[Dict[str, Union[str, int]]], List[TypeSchema]]:
62-
prepare_related_db_item = partial(
63-
cls.prepare_related_object_data,
64-
included_object_schema=included_object_schema,
65-
relationship_info=relationship_info,
66-
)
67-
6877
included_objects = []
6978
if isinstance(related_db_item, Iterable):
7079
data_for_relationship = []
7180
for included_item in related_db_item:
72-
relation_data, processed_object = prepare_related_db_item(
81+
relation_data, processed_object = cls.prepare_related_object_data(
7382
item_from_db=included_item,
7483
)
7584
data_for_relationship.append(relation_data)
@@ -79,7 +88,7 @@ def prepare_data_for_relationship(
7988
if related_db_item is None:
8089
return None, included_objects
8190

82-
data_for_relationship, processed_object = prepare_related_db_item(
91+
data_for_relationship, processed_object = cls.prepare_related_object_data(
8392
item_from_db=related_db_item,
8493
)
8594
if processed_object:
@@ -90,12 +99,13 @@ def prepare_data_for_relationship(
9099
def update_related_object(
91100
cls,
92101
relationship_data: Union[Dict[str, str], List[Dict[str, str]]],
93-
relationships_schema: Type[BaseModel],
94-
object_schema: Type[JSONAPIObjectSchema],
95102
included_objects: Dict[Tuple[str, str], TypeSchema],
96103
cache_key: Tuple[str, str],
97104
related_field_name: str,
98105
):
106+
relationships_schema: Type[BaseModel] = relationships_schema_ctx_var.get()
107+
object_schema: Type[JSONAPIObjectSchema] = object_schema_ctx_var.get()
108+
99109
relationship_data_schema = get_related_schema(relationships_schema, related_field_name)
100110
parent_included_object = included_objects.get(cache_key)
101111
new_relationships = {}
@@ -132,15 +142,13 @@ def update_known_included(
132142
def process_single_db_item_and_prepare_includes(
133143
cls,
134144
parent_db_item: TypeModel,
135-
previous_resource_type: str,
136-
related_field_name: str,
137145
included_objects: Dict[Tuple[str, str], TypeSchema],
138-
relationships_schema: Type[BaseModel],
139-
object_schema: Type[JSONAPIObjectSchema],
140-
process_db_item: Callable,
141146
):
147+
previous_resource_type: str = previous_resource_type_ctx_var.get()
148+
related_field_name: str = related_field_name_ctx_var.get()
149+
142150
next_current_db_item = []
143-
cache_key = (str(parent_db_item.id), previous_resource_type)
151+
cache_key = (cls.get_db_item_id(parent_db_item), previous_resource_type)
144152
current_db_item = getattr(parent_db_item, related_field_name)
145153
current_is_single = False
146154
if not isinstance(current_db_item, Iterable):
@@ -151,7 +159,7 @@ def process_single_db_item_and_prepare_includes(
151159

152160
for db_item in current_db_item:
153161
next_current_db_item.append(db_item)
154-
data_for_relationship, new_included = process_db_item(
162+
data_for_relationship, new_included = cls.prepare_data_for_relationship(
155163
related_db_item=db_item,
156164
)
157165

@@ -168,8 +176,6 @@ def process_single_db_item_and_prepare_includes(
168176

169177
cls.update_related_object(
170178
relationship_data=relationship_data_items,
171-
relationships_schema=relationships_schema,
172-
object_schema=object_schema,
173179
included_objects=included_objects,
174180
cache_key=cache_key,
175181
related_field_name=related_field_name,
@@ -181,31 +187,14 @@ def process_single_db_item_and_prepare_includes(
181187
def process_db_items_and_prepare_includes(
182188
cls,
183189
parent_db_items: List[TypeModel],
184-
previous_resource_type: str,
185-
related_field_name: str,
186-
relationship_info: RelationshipInfo,
187-
included_object_schema: Type[JSONAPIObjectSchema],
188190
included_objects: Dict[Tuple[str, str], TypeSchema],
189-
relationships_schema: Type[BaseModel],
190-
object_schema: Type[JSONAPIObjectSchema],
191191
):
192-
process_db_item = partial(
193-
cls.prepare_data_for_relationship,
194-
relationship_info=relationship_info,
195-
included_object_schema=included_object_schema,
196-
)
197-
198192
next_current_db_item = []
199193

200194
for parent_db_item in parent_db_items:
201195
new_next_items = cls.process_single_db_item_and_prepare_includes(
202196
parent_db_item=parent_db_item,
203-
previous_resource_type=previous_resource_type,
204-
related_field_name=related_field_name,
205197
included_objects=included_objects,
206-
relationships_schema=relationships_schema,
207-
object_schema=object_schema,
208-
process_db_item=process_db_item,
209198
)
210199
next_current_db_item.extend(new_next_items)
211200
return next_current_db_item
@@ -222,13 +211,8 @@ def process_include_with_nested(
222211
root_item_key: item_as_schema,
223212
}
224213
previous_resource_type = item_as_schema.type
225-
for related_field_name in include.split(SPLIT_REL):
226-
# TODO: right now if you want to do:
227-
# `include=user.posts.comments.author`,
228-
# you actually have to do:
229-
# `include=user,user.posts,user.posts.comments,user.posts.comments.author`
230-
# (pass all levels to the nested)
231214

215+
for related_field_name in include.split(SPLIT_REL):
232216
object_schemas = self.jsonapi.create_jsonapi_object_schemas(
233217
schema=current_relation_schema,
234218
includes=[related_field_name],
@@ -247,15 +231,17 @@ def process_include_with_nested(
247231
# xxx: less if/else
248232
current_db_item = [current_db_item]
249233

234+
# ctx vars to skip multi-level args passing
235+
relationships_schema_ctx_var.set(relationships_schema)
236+
object_schema_ctx_var.set(object_schemas.object_jsonapi_schema)
237+
previous_resource_type_ctx_var.set(previous_resource_type)
238+
related_field_name_ctx_var.set(related_field_name)
239+
relationship_info_ctx_var.set(relationship_info)
240+
included_object_schema_ctx_var.set(included_object_schema)
241+
250242
current_db_item = self.process_db_items_and_prepare_includes(
251243
parent_db_items=current_db_item,
252-
previous_resource_type=previous_resource_type,
253-
related_field_name=related_field_name,
254-
relationship_info=relationship_info,
255-
included_object_schema=included_object_schema,
256244
included_objects=included_objects,
257-
relationships_schema=relationships_schema,
258-
object_schema=object_schemas.object_jsonapi_schema,
259245
)
260246

261247
previous_resource_type = relationship_info.resource_type
@@ -272,7 +258,7 @@ def process_db_object(
272258
included_objects = []
273259

274260
item_as_schema = object_schemas.object_jsonapi_schema(
275-
id=str(item.id), # TODO: error if None?
261+
id=self.get_db_item_id(item),
276262
attributes=object_schemas.attributes_schema.from_orm(item),
277263
)
278264

tests/test_api/test_api_sqla_with_includes.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from fastapi_jsonapi.schema_base import BaseModel, Field, RelationshipInfo
1919
from fastapi_jsonapi.views.detail_view import DetailViewBase
2020
from fastapi_jsonapi.views.list_view import ListViewBase
21+
from fastapi_jsonapi.views.view_base import ViewBase
2122

2223
pytestmark = pytest.mark.asyncio
2324

@@ -802,7 +803,7 @@ async def test_get_users(client: AsyncClient, user_1: User, user_2: User):
802803
users = [user_1, user_2]
803804
assert len(users_data) == len(users)
804805
for user_data, user in zip(users_data, users):
805-
assert user_data["id"] == str(user.id)
806+
assert user_data["id"] == ViewBase.get_db_item_id(user)
806807
assert user_data["type"] == "user"
807808

808809

@@ -816,11 +817,11 @@ async def test_get_user_with_bio_relation(
816817
assert response.status_code == status.HTTP_200_OK
817818
response_data = response.json()
818819
assert "data" in response_data, response_data
819-
assert response_data["data"]["id"] == str(user_1.id)
820+
assert response_data["data"]["id"] == ViewBase.get_db_item_id(user_1)
820821
assert response_data["data"]["type"] == "user"
821822
assert "included" in response_data, response_data
822823
included_bio = response_data["included"][0]
823-
assert included_bio["id"] == str(user_1_bio.id)
824+
assert included_bio["id"] == ViewBase.get_db_item_id(user_1_bio)
824825
assert included_bio["type"] == "user_bio"
825826

826827

@@ -839,12 +840,12 @@ async def test_get_users_with_bio_relation(
839840
users = [user_1, user_2]
840841
assert len(users_data) == len(users)
841842
for user_data, user in zip(users_data, users):
842-
assert user_data["id"] == str(user.id)
843+
assert user_data["id"] == ViewBase.get_db_item_id(user)
843844
assert user_data["type"] == "user"
844845

845846
assert "included" in response_data, response_data
846847
included_bio = response_data["included"][0]
847-
assert included_bio["id"] == str(user_1_bio.id)
848+
assert included_bio["id"] == ViewBase.get_db_item_id(user_1_bio)
848849
assert included_bio["type"] == "user_bio"
849850

850851

@@ -872,11 +873,11 @@ async def test_get_posts_with_users(
872873
included_users = response_data["included"]
873874
assert len(included_users) == len(users)
874875
for user_data, user in zip(included_users, users):
875-
assert user_data["id"] == str(user.id)
876+
assert user_data["id"] == ViewBase.get_db_item_id(user)
876877
assert user_data["type"] == "user"
877878

878879
for post_data, post in zip(posts_data, posts):
879-
assert post_data["id"] == str(post.id)
880+
assert post_data["id"] == ViewBase.get_db_item_id(post)
880881
assert post_data["type"] == "post"
881882

882883
all_posts_data = list(posts_data)
@@ -892,7 +893,7 @@ async def test_get_posts_with_users(
892893
idx_start = next_idx
893894

894895
u1_relation = {
895-
"id": str(user.id),
896+
"id": ViewBase.get_db_item_id(user),
896897
"type": "user",
897898
}
898899
for post_data in posts_data:
@@ -937,7 +938,7 @@ async def test_get_users_with_all_inner_relations(
937938
users_data,
938939
[(user_1, user_1_posts, user_1_bio), (user_2, user_2_posts, None)],
939940
):
940-
assert user_data["id"] == str(user.id)
941+
assert user_data["id"] == ViewBase.get_db_item_id(user)
941942
assert user_data["type"] == "user"
942943
user_relationships = user_data["relationships"]
943944
posts_relation = user_relationships["posts"]["data"]
@@ -952,7 +953,7 @@ async def test_get_users_with_all_inner_relations(
952953
continue
953954

954955
assert bio_relation == {
955-
"id": str(user_1_bio.id),
956+
"id": ViewBase.get_db_item_id(user_1_bio),
956957
"type": "user_bio",
957958
}
958959

@@ -962,21 +963,21 @@ async def test_get_users_with_all_inner_relations(
962963
(user_2_posts, user_1_comments_for_u2_posts, user_1),
963964
]:
964965
for post, post_comment in zip(posts, comments):
965-
post_data = included_data[("post", str(post.id))]
966+
post_data = included_data[("post", ViewBase.get_db_item_id(post))]
966967
post_relationships = post_data["relationships"]
967968
assert "comments" in post_relationships
968969
post_comments_relation = post_relationships["comments"]["data"]
969970
post_comments = [post_comment]
970971
assert len(post_comments_relation) == len(post_comments)
971972
for comment_relation_data, comment in zip(post_comments_relation, post_comments):
972973
assert comment_relation_data == {
973-
"id": str(comment.id),
974+
"id": ViewBase.get_db_item_id(comment),
974975
"type": "post_comment",
975976
}
976977

977-
comment_data = included_data[("post_comment", str(comment.id))]
978+
comment_data = included_data[("post_comment", ViewBase.get_db_item_id(comment))]
978979
assert comment_data["relationships"]["author"]["data"] == {
979-
"id": str(comment_author.id),
980+
"id": ViewBase.get_db_item_id(comment_author),
980981
"type": "user",
981982
}
982-
assert ("user", str(comment_author.id)) in included_data
983+
assert ("user", ViewBase.get_db_item_id(comment_author)) in included_data

0 commit comments

Comments
 (0)