Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions openapi_core/casting/schemas/casters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Type
Expand All @@ -28,6 +27,14 @@ def __init__(
self.schema_caster = schema_caster

def __call__(self, value: Any) -> Any:
self.validate(value)

return self.cast(value)

def validate(self, value: Any) -> None:
pass

def cast(self, value: Any) -> Any:
return value


Expand All @@ -37,18 +44,9 @@ def __call__(self, value: Any) -> Any:
class PrimitiveTypeCaster(Generic[PrimitiveType], PrimitiveCaster):
primitive_type: Type[PrimitiveType] = NotImplemented

def __call__(self, value: Union[str, bytes]) -> Any:
self.validate(value)

def cast(self, value: Union[str, bytes]) -> PrimitiveType:
return self.primitive_type(value) # type: ignore [call-arg]

def validate(self, value: Any) -> None:
# FIXME: don't cast data from media type deserializer
# See https://github.com/python-openapi/openapi-core/issues/706
# if not isinstance(value, (str, bytes)):
# raise ValueError("should cast only from string or bytes")
pass


class IntegerCaster(PrimitiveTypeCaster[int]):
primitive_type = int
Expand All @@ -61,22 +59,18 @@ class NumberCaster(PrimitiveTypeCaster[float]):
class BooleanCaster(PrimitiveTypeCaster[bool]):
primitive_type = bool

def __call__(self, value: Union[str, bytes]) -> Any:
self.validate(value)

return self.primitive_type(forcebool(value))

def validate(self, value: Any) -> None:
super().validate(value)

# FIXME: don't cast data from media type deserializer
# See https://github.com/python-openapi/openapi-core/issues/706
if isinstance(value, bool):
return

if value.lower() not in ["false", "true"]:
raise ValueError("not a boolean format")

def cast(self, value: Union[str, bytes]) -> bool:
return self.primitive_type(forcebool(value))


class ArrayCaster(PrimitiveCaster):
@property
Expand All @@ -85,19 +79,21 @@ def items_caster(self) -> "SchemaCaster":
items_schema = self.schema.get("items", SchemaPath.from_dict({}))
return self.schema_caster.evolve(items_schema)

def __call__(self, value: Any) -> List[Any]:
def validate(self, value: Any) -> None:
# str and bytes are not arrays according to the OpenAPI spec
if isinstance(value, (str, bytes)) or not isinstance(value, Iterable):
raise CastError(value, self.schema["type"])
raise ValueError("not an array format")

try:
return list(map(self.items_caster.cast, value))
except (ValueError, TypeError):
raise CastError(value, self.schema["type"])
def cast(self, value: list[Any]) -> list[Any]:
return list(map(self.items_caster.cast, value))


class ObjectCaster(PrimitiveCaster):
def __call__(self, value: Any) -> Any:
def validate(self, value: Any) -> None:
if not isinstance(value, dict):
raise ValueError("not an object format")

def cast(self, value: dict[str, Any]) -> dict[str, Any]:
return self._cast_proparties(value)

def evolve(self, schema: SchemaPath) -> "ObjectCaster":
Expand All @@ -109,9 +105,11 @@ def evolve(self, schema: SchemaPath) -> "ObjectCaster":
self.schema_caster.evolve(schema),
)

def _cast_proparties(self, value: Any, schema_only: bool = False) -> Any:
def _cast_proparties(
self, value: dict[str, Any], schema_only: bool = False
) -> dict[str, Any]:
if not isinstance(value, dict):
raise CastError(value, self.schema["type"])
raise ValueError("not an object format")

all_of_schemas = self.schema_validator.iter_all_of_schemas(value)
for all_of_schema in all_of_schemas:
Expand Down
4 changes: 0 additions & 4 deletions openapi_core/casting/schemas/datatypes.py

This file was deleted.

4 changes: 2 additions & 2 deletions openapi_core/casting/schemas/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from dataclasses import dataclass
from typing import Any

from openapi_core.exceptions import OpenAPIError
from openapi_core.deserializing.exceptions import DeserializeError


@dataclass
class CastError(OpenAPIError):
class CastError(DeserializeError):
"""Schema cast operation error"""

value: Any
Expand Down
8 changes: 1 addition & 7 deletions openapi_core/deserializing/media_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from openapi_core.deserializing.media_types.util import plain_loads
from openapi_core.deserializing.media_types.util import urlencoded_form_loads
from openapi_core.deserializing.media_types.util import xml_loads
from openapi_core.deserializing.styles import style_deserializers_factory

__all__ = ["media_type_deserializers_factory"]
__all__ = ["media_type_deserializers", "MediaTypeDeserializersFactory"]

media_type_deserializers: MediaTypeDeserializersDict = defaultdict(
lambda: binary_loads,
Expand All @@ -30,8 +29,3 @@
"multipart/form-data": data_form_loads,
}
)

media_type_deserializers_factory = MediaTypeDeserializersFactory(
style_deserializers_factory,
media_type_deserializers=media_type_deserializers,
)
27 changes: 27 additions & 0 deletions openapi_core/deserializing/media_types/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from jsonschema_path import SchemaPath

from openapi_core.casting.schemas.factories import SchemaCastersFactory
from openapi_core.deserializing.media_types.datatypes import (
MediaTypeDeserializersDict,
)
Expand All @@ -12,6 +13,7 @@
from openapi_core.deserializing.media_types.deserializers import (
MediaTypesDeserializer,
)
from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict
from openapi_core.deserializing.styles.factories import (
StyleDeserializersFactory,
)
Expand All @@ -28,6 +30,31 @@ def __init__(
media_type_deserializers = {}
self.media_type_deserializers = media_type_deserializers

@classmethod
def from_schema_casters_factory(
cls,
schema_casters_factory: SchemaCastersFactory,
style_deserializers: Optional[StyleDeserializersDict] = None,
media_type_deserializers: Optional[MediaTypeDeserializersDict] = None,
) -> "MediaTypeDeserializersFactory":
from openapi_core.deserializing.media_types import (
media_type_deserializers as default_media_type_deserializers,
)
from openapi_core.deserializing.styles import (
style_deserializers as default_style_deserializers,
)

style_deserializers_factory = StyleDeserializersFactory(
schema_casters_factory,
style_deserializers=style_deserializers
or default_style_deserializers,
)
return cls(
style_deserializers_factory,
media_type_deserializers=media_type_deserializers
or default_media_type_deserializers,
)

def create(
self,
mimetype: str,
Expand Down
6 changes: 1 addition & 5 deletions openapi_core/deserializing/styles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from openapi_core.deserializing.styles.util import simple_loads
from openapi_core.deserializing.styles.util import space_delimited_loads

__all__ = ["style_deserializers_factory"]
__all__ = ["style_deserializers", "StyleDeserializersFactory"]

style_deserializers: StyleDeserializersDict = {
"matrix": matrix_loads,
Expand All @@ -21,7 +21,3 @@
"pipeDelimited": pipe_delimited_loads,
"deepObject": deep_object_loads,
}

style_deserializers_factory = StyleDeserializersFactory(
style_deserializers=style_deserializers,
)
22 changes: 17 additions & 5 deletions openapi_core/deserializing/styles/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from typing import Mapping
from typing import Optional

from jsonschema_path import SchemaPath

from openapi_core.casting.schemas.casters import SchemaCaster
from openapi_core.casting.schemas.exceptions import CastError
from openapi_core.deserializing.exceptions import DeserializeError
from openapi_core.deserializing.styles.datatypes import DeserializerCallable

Expand All @@ -13,13 +17,16 @@ def __init__(
style: str,
explode: bool,
name: str,
schema_type: str,
schema: SchemaPath,
caster: SchemaCaster,
deserializer_callable: Optional[DeserializerCallable] = None,
):
self.style = style
self.explode = explode
self.name = name
self.schema_type = schema_type
self.schema = schema
self.schema_type = schema.getkey("type", "")
self.caster = caster
self.deserializer_callable = deserializer_callable

def deserialize(self, location: Mapping[str, Any]) -> Any:
Expand All @@ -28,8 +35,13 @@ def deserialize(self, location: Mapping[str, Any]) -> Any:
return location[self.name]

try:
return self.deserializer_callable(
value = self.deserializer_callable(
self.explode, self.name, self.schema_type, location
)
except (ValueError, TypeError, AttributeError):
raise DeserializeError(self.style, self.name)
except (ValueError, TypeError, AttributeError) as exc:
raise DeserializeError(self.style, self.name) from exc

try:
return self.caster.cast(value)
except (ValueError, TypeError, AttributeError) as exc:
raise CastError(value, self.schema_type) from exc
8 changes: 5 additions & 3 deletions openapi_core/deserializing/styles/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

from jsonschema_path import SchemaPath

from openapi_core.casting.schemas.factories import SchemaCastersFactory
from openapi_core.deserializing.styles.datatypes import StyleDeserializersDict
from openapi_core.deserializing.styles.deserializers import StyleDeserializer


class StyleDeserializersFactory:
def __init__(
self,
schema_casters_factory: SchemaCastersFactory,
style_deserializers: Optional[StyleDeserializersDict] = None,
):
self.schema_casters_factory = schema_casters_factory
if style_deserializers is None:
style_deserializers = {}
self.style_deserializers = style_deserializers
Expand All @@ -22,9 +25,8 @@ def create(
schema: SchemaPath,
name: str,
) -> StyleDeserializer:
schema_type = schema.getkey("type", "")

deserialize_callable = self.style_deserializers.get(style)
caster = self.schema_casters_factory.create(schema)
return StyleDeserializer(
style, explode, name, schema_type, deserialize_callable
style, explode, name, schema, caster, deserialize_callable
)
20 changes: 12 additions & 8 deletions openapi_core/unmarshalling/request/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@
from openapi_spec_validator.validation.types import SpecValidatorType

from openapi_core.casting.schemas.factories import SchemaCastersFactory
from openapi_core.deserializing.media_types import (
media_type_deserializers_factory,
)
from openapi_core.deserializing.media_types.datatypes import (
MediaTypeDeserializersDict,
)
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
from openapi_core.deserializing.styles import style_deserializers_factory
from openapi_core.deserializing.styles.factories import (
StyleDeserializersFactory,
)
Expand All @@ -43,8 +39,12 @@ def __init__(
self,
spec: SchemaPath,
base_url: Optional[str] = None,
style_deserializers_factory: StyleDeserializersFactory = style_deserializers_factory,
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
style_deserializers_factory: Optional[
StyleDeserializersFactory
] = None,
media_type_deserializers_factory: Optional[
MediaTypeDeserializersFactory
] = None,
schema_casters_factory: Optional[SchemaCastersFactory] = None,
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
path_finder_cls: Optional[PathFinderType] = None,
Expand Down Expand Up @@ -74,8 +74,12 @@ def __init__(
self,
spec: SchemaPath,
base_url: Optional[str] = None,
style_deserializers_factory: StyleDeserializersFactory = style_deserializers_factory,
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
style_deserializers_factory: Optional[
StyleDeserializersFactory
] = None,
media_type_deserializers_factory: Optional[
MediaTypeDeserializersFactory
] = None,
schema_casters_factory: Optional[SchemaCastersFactory] = None,
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
path_finder_cls: Optional[PathFinderType] = None,
Expand Down
12 changes: 6 additions & 6 deletions openapi_core/unmarshalling/request/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,12 @@
from openapi_spec_validator.validation.types import SpecValidatorType

from openapi_core.casting.schemas.factories import SchemaCastersFactory
from openapi_core.deserializing.media_types import (
media_type_deserializers_factory,
)
from openapi_core.deserializing.media_types.datatypes import (
MediaTypeDeserializersDict,
)
from openapi_core.deserializing.media_types.factories import (
MediaTypeDeserializersFactory,
)
from openapi_core.deserializing.styles import style_deserializers_factory
from openapi_core.deserializing.styles.factories import (
StyleDeserializersFactory,
)
Expand Down Expand Up @@ -85,8 +81,12 @@ def __init__(
self,
spec: SchemaPath,
base_url: Optional[str] = None,
style_deserializers_factory: StyleDeserializersFactory = style_deserializers_factory,
media_type_deserializers_factory: MediaTypeDeserializersFactory = media_type_deserializers_factory,
style_deserializers_factory: Optional[
StyleDeserializersFactory
] = None,
media_type_deserializers_factory: Optional[
MediaTypeDeserializersFactory
] = None,
schema_casters_factory: Optional[SchemaCastersFactory] = None,
schema_validators_factory: Optional[SchemaValidatorsFactory] = None,
path_finder_cls: Optional[PathFinderType] = None,
Expand Down
Loading
Loading