Skip to content

Commit 280486c

Browse files
committed
add typing_extensions.Literals
1 parent 4a14b5d commit 280486c

File tree

5 files changed

+86
-0
lines changed

5 files changed

+86
-0
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing_extensions import Literal
2+
3+
4+
def func_with_literals(
5+
input1: Literal[
6+
'a',
7+
1,
8+
Literal[2,3],
9+
],
10+
) -> Literal[True, 0.5, Literal[3]]:
11+
return input1
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from py_codegen.type_extractor.__tests__.utils import cleanup, traverse
2+
from py_codegen.type_extractor.nodes.FunctionFound import FunctionFound
3+
from py_codegen.type_extractor.nodes.LiteralFound import LiteralFound
4+
from py_codegen.type_extractor.nodes.TypeOR import TypeOR
5+
from py_codegen.type_extractor.type_extractor import TypeExtractor
6+
from py_codegen.test_fixtures.func_with_literals import func_with_literals
7+
8+
9+
def test_func_with_list():
10+
type_collector = TypeExtractor()
11+
12+
type_collector.add(None)(func_with_literals)
13+
14+
func_found_cleaned = cleanup(
15+
type_collector.collected_types[func_with_literals.__qualname__],
16+
)
17+
assert func_found_cleaned == traverse(
18+
FunctionFound(
19+
name=func_with_literals.__qualname__,
20+
params={
21+
'input1': TypeOR(
22+
a=LiteralFound('a'),
23+
b=TypeOR(
24+
a=LiteralFound(1),
25+
b=TypeOR(
26+
a=LiteralFound(2),
27+
b=LiteralFound(3),
28+
),
29+
),
30+
),
31+
},
32+
return_type=TypeOR(
33+
a=LiteralFound(True),
34+
b=TypeOR(
35+
a=LiteralFound(0.5),
36+
b=LiteralFound(3),
37+
),
38+
)
39+
),
40+
cleanup,
41+
)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Tuple
2+
from typing_extensions import Literal
3+
4+
from py_codegen.type_extractor.__base__ import BaseTypeExtractor
5+
from py_codegen.type_extractor.nodes.LiteralFound import LiteralFound
6+
from py_codegen.type_extractor.nodes.TypeOR import TypeOR
7+
8+
9+
def literal_found_middleware(typ, type_extractor: BaseTypeExtractor):
10+
typ_origin = typ.__origin__
11+
if typ_origin is not Literal:
12+
return
13+
return __process_literal_args(typ.__args__)
14+
15+
16+
def __process_literal_args(args: Tuple):
17+
current = LiteralFound(args[0])
18+
try:
19+
if args[0].__origin__ is Literal:
20+
current = __process_literal_args(args[0].__args__)
21+
except:
22+
pass
23+
if len(args) == 1:
24+
return current
25+
return TypeOR(current, __process_literal_args(args[1:]))
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from typing import NamedTuple, Any
2+
3+
from py_codegen.type_extractor.nodes.BaseNodeType import BaseNodeType
4+
5+
6+
class LiteralFound(NamedTuple, BaseNodeType): # type: ignore
7+
value: Any

py_codegen/type_extractor/type_extractor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from py_codegen.type_extractor.middlewares.dict_found import dict_found_middleware
1717
from py_codegen.type_extractor.middlewares.function_found import func_found_middleware
1818
from py_codegen.type_extractor.middlewares.list_found import list_found_middleware
19+
from py_codegen.type_extractor.middlewares.literal_found import literal_found_middleware
1920
from py_codegen.type_extractor.middlewares.mapping_found import mapping_found_middleware
2021
from py_codegen.type_extractor.middlewares.tuple_found import tuple_found_middleware
2122
from py_codegen.type_extractor.middlewares.type_or import typeor_middleware
@@ -44,6 +45,7 @@ class TypeExtractor(BaseTypeExtractor):
4445
list_found_middleware,
4546
typeor_middleware,
4647
typeddict_found_middleware,
48+
literal_found_middleware,
4749
dict_found_middleware,
4850
tuple_found_middleware,
4951
class_found_middleware,

0 commit comments

Comments
 (0)