-
Notifications
You must be signed in to change notification settings - Fork 143
Expand file tree
/
Copy pathunchecked_base_model.py
More file actions
396 lines (328 loc) · 14.7 KB
/
unchecked_base_model.py
File metadata and controls
396 lines (328 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# This file was auto-generated by Fern from our API Definition.
import datetime as dt
import inspect
import typing
import uuid
import pydantic
import typing_extensions
from .pydantic_utilities import (
IS_PYDANTIC_V2,
ModelField,
UniversalBaseModel,
get_args,
get_origin,
is_literal_type,
is_union,
parse_date,
parse_datetime,
parse_obj_as,
)
from .serialization import convert_and_respect_annotation_metadata, get_field_to_alias_mapping
from pydantic_core import PydanticUndefined
class UnionMetadata:
    """
    ``Annotated`` marker naming the discriminant field of a tagged union.

    ``_convert_union_type`` reads the value of ``discriminant`` from the incoming
    object and picks the union member whose field of that name has a matching default.

    Instances behave as value objects: two markers with the same discriminant compare
    equal and hash the same, so identical ``Annotated[...]`` aliases are interchangeable.
    """

    discriminant: str

    def __init__(self, *, discriminant: str) -> None:
        self.discriminant = discriminant

    def __repr__(self) -> str:
        return f"{type(self).__name__}(discriminant={self.discriminant!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, UnionMetadata):
            return NotImplemented
        return self.discriminant == other.discriminant

    def __hash__(self) -> int:
        # Must agree with __eq__: equal markers hash equally.
        return hash((UnionMetadata, self.discriminant))
# Type variable bound to pydantic.BaseModel. The classmethods in this module annotate
# `cls: typing.Type["Model"]` and return "Model", so each call is typed as returning
# the concrete subclass it was invoked on.
Model = typing.TypeVar("Model", bound=pydantic.BaseModel)
class UncheckedBaseModel(UniversalBaseModel):
    """
    Base model that allows extra keys and whose ``construct`` skips validation.

    ``construct`` (and, on Pydantic v2, ``model_construct``) recursively coerces field
    values with ``construct_type`` instead of validating them. Unknown keys are kept
    as extras (``extra="allow"``).
    """

    if IS_PYDANTIC_V2:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow")  # type: ignore # Pydantic v2
    else:

        class Config:
            extra = pydantic.Extra.allow

    if IS_PYDANTIC_V2:

        @classmethod
        def model_validate(
            cls: typing.Type["Model"],
            obj: typing.Any,
            *args: typing.Any,
            **kwargs: typing.Any,
        ) -> "Model":
            """
            Ensure that when using Pydantic v2's `model_validate` entrypoint we still
            respect our FieldMetadata-based aliasing.
            """
            dealiased_obj = convert_and_respect_annotation_metadata(
                object_=obj,
                annotation=cls,
                direction="read",
            )
            return super().model_validate(dealiased_obj, *args, **kwargs)  # type: ignore[misc]

    @classmethod
    def model_construct(
        cls: typing.Type["Model"],
        _fields_set: typing.Optional[typing.Set[str]] = None,
        **values: typing.Any,
    ) -> "Model":
        """Route Pydantic v2's ``model_construct`` entrypoint through ``construct``."""
        return cls.construct(_fields_set=_fields_set, **values)

    # Allow construct to not validate model
    # Implementation taken from: https://github.com/pydantic/pydantic/issues/1168#issuecomment-817742836
    @classmethod
    def construct(
        cls: typing.Type["Model"],
        _fields_set: typing.Optional[typing.Set[str]] = None,
        **values: typing.Any,
    ) -> "Model":
        """
        Build an instance from ``values`` without running validation.

        Each declared field is looked up in ``values`` by its Pydantic alias or its
        FieldMetadata alias. It is also looked up by field name when there is no alias,
        or when population by name is enabled and the alias key is absent. Found values
        are coerced with ``construct_type``; missing fields get their default. Keys that
        match no field, alias or FieldMetadata alias are kept as extras.

        Args:
            _fields_set: Names to report as explicitly set. The set is copied, never
                mutated. Defaults to the keys of ``values``.
            **values: Raw field data keyed by alias or field name.

        Returns:
            A new instance of ``cls``.
        """
        m = cls.__new__(cls)
        fields_values: typing.Dict[str, typing.Any] = {}

        # Work on a private copy so a caller-supplied set is never mutated by the adds below.
        _fields_set = set(values.keys()) if _fields_set is None else set(_fields_set)

        fields = _get_model_fields(cls)
        populate_by_name = _get_is_populate_by_name(cls)
        field_aliases = get_field_to_alias_mapping(cls)

        for name, field in fields.items():
            # Key here is only used to pull data from the values dict.
            # You should always use the NAME of the field for field_values, etc.,
            # because that's how the object is constructed from a pydantic perspective.
            key = field.alias
            if (key is None or field.alias == name) and name in field_aliases:
                key = field_aliases[name]

            if key is None or (key not in values and populate_by_name):  # Added this to allow population by field name
                key = name

            if key in values:
                if IS_PYDANTIC_V2:
                    type_ = field.annotation  # type: ignore # Pydantic v2
                else:
                    type_ = typing.cast(typing.Type, field.outer_type_)  # type: ignore # Pydantic < v1.10.15

                fields_values[name] = (
                    construct_type(object_=values[key], type_=type_) if type_ is not None else values[key]
                )
                _fields_set.add(name)
            else:
                default = _get_field_default(field)
                fields_values[name] = default

                # If the default values are non-null act like they've been set.
                # This effectively allows exclude_unset to work like exclude_none where
                # the latter passes through intentionally set none values.
                # Identity checks: `!=` would misbehave for defaults with elementwise __eq__.
                if default is not None and default is not PydanticUndefined:
                    _fields_set.add(name)

        # Add extras back in. A key is extra if it is not a field name, a Pydantic alias,
        # nor a FieldMetadata alias of some field.
        known_keys = set(fields)
        known_keys.update(field.alias for field in fields.values())
        known_keys.update(field_aliases.values())
        extras: typing.Dict[str, typing.Any] = {}
        for key, value in values.items():
            if key in known_keys:
                continue
            if IS_PYDANTIC_V2:
                extras[key] = value
            else:
                _fields_set.add(key)
                fields_values[key] = value

        object.__setattr__(m, "__dict__", fields_values)

        if IS_PYDANTIC_V2:
            object.__setattr__(m, "__pydantic_private__", None)
            object.__setattr__(m, "__pydantic_extra__", extras)
            object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
        else:
            object.__setattr__(m, "__fields_set__", _fields_set)
            m._init_private_attributes()  # type: ignore # Pydantic v1
        return m
def _validate_collection_items_compatible(collection: typing.Any, target_type: typing.Type[typing.Any]) -> bool:
"""
Validate that all items in a collection are compatible with the target type.
Args:
collection: The collection to validate (list, set, or dict values)
target_type: The target type to validate against
Returns:
True if all items are compatible, False otherwise
"""
if inspect.isclass(target_type) and issubclass(target_type, pydantic.BaseModel):
for item in collection:
try:
# Try to validate the item against the target type
if isinstance(item, dict):
parse_obj_as(target_type, item)
else:
# If it's not a dict, it might already be the right type
if not isinstance(item, target_type):
return False
except Exception:
return False
return True
def _is_pydantic_model_class(candidate: typing.Any) -> bool:
    """Return True if ``candidate`` is a Pydantic model class; never raises."""
    try:
        return inspect.isclass(candidate) and issubclass(candidate, pydantic.BaseModel)
    except TypeError:
        # Some generic aliases (e.g. PEP 585 ``list[int]`` on certain Python versions)
        # pass ``inspect.isclass`` but are rejected by ``issubclass``.
        return False


def _literal_fields_match(model: typing.Any, object_: typing.Any) -> bool:
    """Return False if any ``Literal``-typed field of ``model`` conflicts with the value found in ``object_``."""
    alias_mapping: typing.Optional[typing.Mapping[str, str]] = None
    for field_name, field in _get_model_fields(model).items():
        if IS_PYDANTIC_V2:
            field_type = field.annotation  # type: ignore # Pydantic v2
        else:
            field_type = field.outer_type_  # type: ignore # Pydantic v1
        if not is_literal_type(field_type):  # type: ignore[arg-type]
            continue
        if alias_mapping is None:
            # Computed once per model, and only for models that have Literal fields.
            alias_mapping = get_field_to_alias_mapping(model)
        name_or_alias = alias_mapping.get(field_name, field_name)
        if isinstance(object_, dict):
            object_value = object_.get(name_or_alias)
        else:
            object_value = getattr(object_, name_or_alias, None)
        # An absent (None) value is not evidence against this member; only a conflicting one is.
        if object_value is not None and _get_field_default(field) != object_value:
            return False
    return True


def _convert_undiscriminated_union_type(union_type: typing.Type[typing.Any], object_: typing.Any) -> typing.Any:
    """
    Best-effort coercion of ``object_`` into one member of an untagged union.

    Strategy, in order:
      1. If ``typing.Any`` is a member, return ``object_`` untouched.
      2. Validated parse against each Pydantic model member, and against ``List[Model]``
         members when ``object_`` is a list whose items all validate; first success wins.
      3. Unvalidated ``construct_type`` against model members whose ``Literal`` fields
         agree with the values present in ``object_``.
      4. Unvalidated ``construct_type`` against every member in declaration order.
      5. If every attempt raised, return ``object_`` unchanged.

    Args:
        union_type: The (un-annotated) union type.
        object_: The raw value to coerce.

    Returns:
        The coerced value, or ``object_`` itself when no member could be built.
    """
    inner_types = get_args(union_type)
    if typing.Any in inner_types:
        return object_

    # Pass 1: validated parsing.
    for inner_type in inner_types:
        # Handle lists of objects that need parsing.
        if get_origin(inner_type) is list and isinstance(object_, list):
            list_args = get_args(inner_type)
            # A bare `List` carries no element type to validate against.
            list_inner_type = list_args[0] if list_args else None
            try:
                if _is_pydantic_model_class(list_inner_type) and _validate_collection_items_compatible(
                    object_, list_inner_type
                ):
                    return [parse_obj_as(object_=item, type_=list_inner_type) for item in object_]
            except Exception:
                pass

        if _is_pydantic_model_class(inner_type):
            try:
                # Attempt a validated parse until one works.
                return parse_obj_as(inner_type, object_)
            except Exception:
                continue

    # Pass 2: unvalidated construction, limited to models whose Literal fields agree.
    for inner_type in inner_types:
        if _is_pydantic_model_class(inner_type) and _literal_fields_match(inner_type, object_):
            try:
                return construct_type(object_=object_, type_=inner_type)
            except Exception:
                continue

    # Pass 3: unvalidated construction against every member, in declaration order.
    for inner_type in inner_types:
        try:
            return construct_type(object_=object_, type_=inner_type)
        except Exception:
            continue

    # Nothing could be built. Keep the raw value, as construct_type does for types it
    # cannot coerce, rather than silently replacing it with None.
    return object_
def _convert_union_type(type_: typing.Type[typing.Any], object_: typing.Any) -> typing.Any:
    """
    Coerce ``object_`` into a member of the union ``type_``.

    If ``type_`` is ``Annotated[Union[...], UnionMetadata(discriminant=...)]``, the value
    of the discriminant is read from ``object_``. A mapping is indexed; any other object
    is accessed by attribute. The first model member whose discriminant field default
    equals that value is constructed. Otherwise, or if that lookup fails, it falls back
    to ``_convert_undiscriminated_union_type``.

    Args:
        type_: A union type, optionally wrapped in ``Annotated``.
        object_: The raw value to coerce.

    Returns:
        The coerced value.
    """
    base_type = get_origin(type_) or type_
    union_type = type_
    if base_type == typing_extensions.Annotated:  # type: ignore[comparison-overlap]
        union_type = get_args(type_)[0]
        annotated_metadata = get_args(type_)[1:]
        for metadata in annotated_metadata:
            if not isinstance(metadata, UnionMetadata):
                continue
            try:
                # Read the tag once; it does not depend on the member being tested.
                # Mappings are indexed so a discriminant named like a dict method
                # (e.g. "items") is not resolved to a bound method.
                if isinstance(object_, typing.Mapping):
                    objects_discriminant = object_[metadata.discriminant]
                else:
                    objects_discriminant = getattr(object_, metadata.discriminant)

                # Cast to the correct type, based on the discriminant.
                for inner_type in get_args(union_type):
                    if not (inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel)):
                        # e.g. NoneType in an Optional union: cannot carry the tag, skip it.
                        continue
                    discriminant_field = _get_model_fields(inner_type).get(metadata.discriminant)
                    if discriminant_field is None:
                        continue
                    if _get_field_default(discriminant_field) == objects_discriminant:
                        return construct_type(object_=object_, type_=inner_type)
            except Exception:
                # Allow to fall through to our regular union handling.
                pass

    return _convert_undiscriminated_union_type(union_type, object_)
def construct_type(*, type_: typing.Type[typing.Any], object_: typing.Any) -> typing.Any:
    """
    Best-effort, non-validating coercion of ``object_`` to ``type_``, applied recursively.

    This is the type-level counterpart of ``UncheckedBaseModel.construct``.
      - Typed ``dict``/``list``/``set`` values are rebuilt element by element.
      - Unions are resolved via ``_convert_union_type``.
      - Pydantic models are built with ``model_construct``/``construct``; an object that
        is already an instance of the model is returned as-is.
      - datetime, date, UUID, int and bool values are parsed.

    When a coercion does not apply, or a scalar parse fails, ``object_`` is returned
    unchanged. Model construction from a non-mapping ``object_`` is not guarded and
    raises; union resolution relies on that to move on to the next member.

    Args:
        type_: Target annotation (plain class, generic alias, ``Annotated``, union, ...).
        object_: The raw value to coerce.

    Returns:
        The coerced value, ``None`` for ``None`` input, or ``object_`` itself.
    """
    # Short circuit when dealing with optionals: don't try to coerce None to a type.
    if object_ is None:
        return None

    base_type = get_origin(type_) or type_
    is_annotated = base_type == typing_extensions.Annotated  # type: ignore[comparison-overlap]
    maybe_annotation_members = get_args(type_)
    is_annotated_union = is_annotated and is_union(get_origin(maybe_annotation_members[0]))

    if base_type == typing.Any:  # type: ignore[comparison-overlap]
        return object_

    if base_type == dict:
        if not isinstance(object_, typing.Mapping):
            return object_
        dict_args = get_args(type_)
        if len(dict_args) != 2:
            # Bare `dict` / `typing.Dict`: no key/value types to coerce to.
            return object_
        key_type, items_type = dict_args
        return {
            construct_type(object_=key, type_=key_type): construct_type(object_=item, type_=items_type)
            for key, item in object_.items()
        }

    if base_type == list:
        if not isinstance(object_, list):
            return object_
        list_args = get_args(type_)
        if not list_args:
            # Bare `list` / `typing.List`: no element type to coerce to.
            return object_
        inner_type = list_args[0]
        return [construct_type(object_=entry, type_=inner_type) for entry in object_]

    if base_type == set:
        if not isinstance(object_, (set, list)):
            return object_
        set_args = get_args(type_)
        if not set_args:
            # Bare `set` / `typing.Set`: no element type to coerce to.
            return object_
        inner_type = set_args[0]
        return {construct_type(object_=entry, type_=inner_type) for entry in object_}

    if is_union(base_type) or is_annotated_union:
        return _convert_union_type(type_, object_)

    # Cannot do an `issubclass` with a literal type; also confirm we have a class before the call.
    model_type: typing.Any = None
    if not is_literal_type(type_):
        if inspect.isclass(base_type) and issubclass(base_type, pydantic.BaseModel):
            model_type = base_type
        elif (
            is_annotated
            and inspect.isclass(maybe_annotation_members[0])
            and issubclass(maybe_annotation_members[0], pydantic.BaseModel)
        ):
            model_type = maybe_annotation_members[0]
    if model_type is not None:
        # Already built: nothing to coerce, and `**instance` would raise TypeError.
        if isinstance(object_, model_type):
            return object_
        if IS_PYDANTIC_V2:
            return type_.model_construct(**object_)
        return type_.construct(**object_)

    if base_type == dt.datetime:
        try:
            return parse_datetime(object_)
        except Exception:
            return object_

    if base_type == dt.date:
        try:
            return parse_date(object_)
        except Exception:
            return object_

    if base_type == uuid.UUID:
        try:
            return uuid.UUID(object_)
        except Exception:
            return object_

    if base_type == int:
        try:
            return int(object_)
        except Exception:
            return object_

    if base_type == bool:
        try:
            if isinstance(object_, str):
                stringified_object = object_.lower()
                return stringified_object == "true" or stringified_object == "1"
            return bool(object_)
        except Exception:
            return object_

    return object_
def _get_is_populate_by_name(model: typing.Type["Model"]) -> bool:
    """
    Report whether ``model`` may be populated by Python field name as well as by alias.

    Pydantic v2 reads ``populate_by_name`` or its Pydantic >= 2.11 successor
    ``validate_by_name`` from ``model_config``. Pydantic v1 reads
    ``Config.allow_population_by_field_name``.
    """
    if IS_PYDANTIC_V2:
        config = model.model_config  # type: ignore # Pydantic v2
        # Honour either spelling so models configured with the newer key keep working.
        return config.get("populate_by_name", False) or config.get("validate_by_name", False)
    return model.__config__.allow_population_by_field_name  # type: ignore # Pydantic v1
# Type of the per-field objects returned by `_get_model_fields`.
# Pydantic V1 swapped the typing of __fields__'s values from ModelField to FieldInfo,
# and V2 exposes FieldInfo via model.model_fields, so the alias covers both V1 cases
# as well as V2.
PydanticField = typing.Union[ModelField, pydantic.fields.FieldInfo]
def _get_model_fields(
    model: typing.Type["Model"],
) -> typing.Mapping[str, PydanticField]:
    """
    Return the declared fields of ``model`` keyed by field name.

    Pydantic v2 exposes them as ``model_fields``; Pydantic v1 as ``__fields__``.
    """
    fields_attribute = "model_fields" if IS_PYDANTIC_V2 else "__fields__"
    return getattr(model, fields_attribute)
def _get_field_default(field: PydanticField) -> typing.Any:
    """
    Return the default value declared for ``field``.

    Uses ``field.get_default()`` when available, otherwise ``field.default``. On
    Pydantic v2 the "no default" sentinel ``PydanticUndefined`` is mapped to ``None``.

    Args:
        field: A v1 ``ModelField``/``FieldInfo`` or a v2 ``FieldInfo``.

    Returns:
        The default value, or ``None`` when a v2 field has no default.
    """
    try:
        value = field.get_default()  # type: ignore # Pydantic < v1.10.15
    except Exception:
        # Objects without get_default() (or whose default factory fails) fall back
        # to the raw `default` attribute.
        value = field.default
    # Identity check: the sentinel is a singleton, and `==` would misbehave for
    # defaults with elementwise __eq__.
    if IS_PYDANTIC_V2 and value is PydanticUndefined:
        return None
    return value