Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fastapi_jsonapi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
pagination_default_offset: Optional[int] = None,
pagination_default_limit: Optional[int] = None,
methods: Iterable[str] = (),
max_cache_size: int = 0,
) -> None:
"""
Initialize router items.
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
self.pagination_default_number: Optional[int] = pagination_default_number
self.pagination_default_offset: Optional[int] = pagination_default_offset
self.pagination_default_limit: Optional[int] = pagination_default_limit
self.schema_builder = SchemaBuilder(resource_type=resource_type)
self.schema_builder = SchemaBuilder(resource_type=resource_type, max_cache_size=max_cache_size)

dto = self.schema_builder.create_schemas(
schema=schema,
Expand Down
47 changes: 45 additions & 2 deletions fastapi_jsonapi/schema_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""JSON API schemas builder class."""
from dataclasses import dataclass
from functools import lru_cache
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -122,8 +123,16 @@ class SchemaBuilder:
def __init__(
self,
resource_type: str,
max_cache_size: int = 0,
):
self._resource_type = resource_type
self._init_cache(max_cache_size)

def _init_cache(self, max_cache_size: int):
# TODO: remove crutch
self._get_info_from_schema_for_building_cached = lru_cache(maxsize=max_cache_size)(
self._get_info_from_schema_for_building_cached,
)

def _create_schemas_objects_list(self, schema: Type[BaseModel]) -> Type[JSONAPIResultListSchema]:
object_jsonapi_list_schema, list_jsonapi_schema = self.build_list_schemas(schema)
Expand Down Expand Up @@ -187,7 +196,7 @@ def build_schema_in(
) -> Tuple[Type[BaseJSONAPIDataInSchema], Type[BaseJSONAPIItemInSchema]]:
base_schema_name = schema_in.__name__.removesuffix("Schema") + schema_name_suffix

dto = self._get_info_from_schema_for_building(
dto = self._get_info_from_schema_for_building_wrapper(
base_name=base_schema_name,
schema=schema_in,
non_optional_relationships=non_optional_relationships,
Expand Down Expand Up @@ -258,6 +267,40 @@ def build_list_schemas(
includes=includes,
)

def _get_info_from_schema_for_building_cached(
self,
base_name: str,
schema: Type[BaseModel],
includes: Iterable[str],
non_optional_relationships: bool,
):
return self._get_info_from_schema_for_building(
base_name=base_name,
schema=schema,
includes=includes,
non_optional_relationships=non_optional_relationships,
)

def _get_info_from_schema_for_building_wrapper(
self,
base_name: str,
schema: Type[BaseModel],
includes: Iterable[str] = not_passed,
non_optional_relationships: bool = False,
):
"""
Wrapper function for return cached schema result
"""
if includes is not not_passed:
includes = tuple(includes)

return self._get_info_from_schema_for_building_cached(
base_name=base_name,
schema=schema,
includes=includes,
non_optional_relationships=non_optional_relationships,
)

def _get_info_from_schema_for_building(
self,
base_name: str,
Expand Down Expand Up @@ -494,7 +537,7 @@ def create_jsonapi_object_schemas(
if includes is not not_passed:
includes = set(includes)

dto = self._get_info_from_schema_for_building(
dto = self._get_info_from_schema_for_building_wrapper(
base_name=base_name,
schema=schema,
includes=includes,
Expand Down
7 changes: 6 additions & 1 deletion tests/fixtures/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,11 @@ def build_app_custom(
resource_type: str = "misc",
class_list: Type[ListViewBase] = ListViewBaseGeneric,
class_detail: Type[DetailViewBase] = DetailViewBaseGeneric,
max_cache_size: int = 0,
) -> FastAPI:
router: APIRouter = APIRouter()

RoutersJSONAPI(
jsonapi_routers = RoutersJSONAPI(
router=router,
path=path,
tags=["Misc"],
Expand All @@ -246,6 +247,7 @@ def build_app_custom(
schema_in_patch=schema_in_patch,
schema_in_post=schema_in_post,
model=model,
max_cache_size=max_cache_size,
)

app = build_app_plain()
Expand All @@ -254,6 +256,9 @@ def build_app_custom(
atomic = AtomicOperations()
app.include_router(atomic.router, prefix="")
init(app)

app.jsonapi_routers = jsonapi_routers

return app


Expand Down
220 changes: 220 additions & 0 deletions tests/test_api/test_api_sqla_with_includes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import chain, zip_longest
from json import dumps, loads
from typing import Dict, List, Literal, Set, Tuple
from unittest.mock import call, patch
from uuid import UUID, uuid4

import pytest
Expand All @@ -20,6 +21,7 @@
from starlette.datastructures import QueryParams

from fastapi_jsonapi.api import RoutersJSONAPI
from fastapi_jsonapi.schema_builder import SchemaBuilder
from fastapi_jsonapi.views.view_base import ViewBase
from tests.common import is_postgres_tests
from tests.fixtures.app import build_alphabet_app, build_app_custom
Expand Down Expand Up @@ -52,6 +54,8 @@
CustomUUIDItemAttributesSchema,
PostAttributesBaseSchema,
PostCommentAttributesBaseSchema,
PostCommentSchema,
PostSchema,
SelfRelationshipAttributesSchema,
SelfRelationshipSchema,
UserAttributesBaseSchema,
Expand Down Expand Up @@ -360,6 +364,215 @@ async def test_select_custom_fields_for_includes_without_requesting_includes(
"meta": {"count": 1, "totalPages": 1},
}

def _get_clear_mock_calls(self, mock_obj) -> list[call]:
mock_calls = mock_obj.mock_calls
return [call_ for call_ in mock_calls if call_ not in [call.__len__(), call.__str__()]]

def _prepare_info_schema_calls_to_assert(self, mock_calls) -> list[call]:
calls_to_check = []
for wrapper_call in mock_calls:
kwargs = wrapper_call.kwargs
kwargs["includes"] = sorted(kwargs["includes"], key=lambda x: x)

calls_to_check.append(
call(
*wrapper_call.args,
**kwargs,
),
)

return sorted(
calls_to_check,
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
)

async def test_check_get_info_schema_cache(
self,
user_1: User,
):
resource_type = "user_with_cache"
with suppress(KeyError):
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)

app_with_cache = build_app_custom(
model=User,
schema=UserSchema,
schema_in_post=UserInSchemaAllowIdOnPost,
schema_in_patch=UserPatchSchema,
resource_type=resource_type,
# set cache size to enable caching
max_cache_size=128,
)

target_func_name = "_get_info_from_schema_for_building"
url = app_with_cache.url_path_for(f"get_{resource_type}_list")
params = {
"include": "posts,posts.comments",
}

expected_len_with_cache = 6
expected_len_without_cache = 10

with patch.object(
SchemaBuilder,
target_func_name,
wraps=app_with_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building,
) as wrapped_func:
async with AsyncClient(app=app_with_cache, base_url="http://test") as client:
response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func))

# there are no duplicated calls
assert calls_to_check == sorted(
[
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
),
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts", "posts.comments"],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=["comments"],
non_optional_relationships=False,
),
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=["posts"],
non_optional_relationships=False,
),
],
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
)
assert wrapped_func.call_count == expected_len_with_cache

response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

# there are no new calls
assert wrapped_func.call_count == expected_len_with_cache

resource_type = "user_without_cache"
with suppress(KeyError):
RoutersJSONAPI.all_jsonapi_routers.pop(resource_type)

app_without_cache = build_app_custom(
model=User,
schema=UserSchema,
schema_in_post=UserInSchemaAllowIdOnPost,
schema_in_patch=UserPatchSchema,
resource_type=resource_type,
max_cache_size=0,
)

with patch.object(
SchemaBuilder,
target_func_name,
wraps=app_without_cache.jsonapi_routers.schema_builder._get_info_from_schema_for_building,
) as wrapped_func:
async with AsyncClient(app=app_without_cache, base_url="http://test") as client:
response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

calls_to_check = self._prepare_info_schema_calls_to_assert(self._get_clear_mock_calls(wrapped_func))

# there are duplicated calls
assert calls_to_check == sorted(
[
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
),
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
), # duplicate
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts", "posts.comments"],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
), # duplicate
call(
base_name="PostSchema",
schema=PostSchema,
includes=[],
non_optional_relationships=False,
), # duplicate
call(
base_name="PostSchema",
schema=PostSchema,
includes=["comments"],
non_optional_relationships=False,
),
call(
base_name="PostSchema",
schema=PostSchema,
includes=["comments"],
non_optional_relationships=False,
), # duplicate
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=[],
non_optional_relationships=False,
),
call(
base_name="PostCommentSchema",
schema=PostCommentSchema,
includes=["posts"],
non_optional_relationships=False,
), # duplicate
],
key=lambda x: (x.kwargs["base_name"], x.kwargs["includes"]),
)

assert wrapped_func.call_count == expected_len_without_cache

response = await client.get(url, params=params)
assert response.status_code == status.HTTP_200_OK, response.text

# there are new calls
assert wrapped_func.call_count == expected_len_without_cache * 2


class TestCreatePostAndComments:
async def test_get_posts_with_users(
Expand All @@ -371,6 +584,13 @@ async def test_get_posts_with_users(
user_1_posts: List[Post],
user_2_posts: List[Post],
):
call(
base_name="UserSchema",
schema=UserSchema,
includes=["posts"],
non_optional_relationships=False,
on_optional_relationships=False,
)
url = app.url_path_for("get_post_list")
url = f"{url}?include=user"
response = await client.get(url)
Expand Down