11from typing import List , Union , Tuple
22from tools .codegen .model import (Type , BaseTy , BaseType , OptionalType ,
33 ListType , OperatorName , FunctionSchema ,
4- Return , TensorOptionsArguments )
4+ Return )
55from tools .codegen .api .types import (BaseCppType , BaseCType , OptionalCType ,
66 ConstRefCType , NamedCType ,
7- MutRefCType , deviceT , layoutT ,
7+ MutRefCType ,
88 VectorCType , boolT , longT , doubleT , ListCType , stringT ,
99 scalarT , scalarTypeT , ArrayRefCType , ArrayCType , TupleCType )
1010
@@ -33,9 +33,7 @@ def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, L
3333 if typ .name == BaseTy .Tensor :
3434 return BaseCType (valueT )
3535 elif typ .name == BaseTy .Scalar :
36- # at::scalar has special handling,
37- # and is wrapped in an IR value just like at::tensor
38- return BaseCType (valueT )
36+ return BaseCType (scalarT )
3937 elif typ .name == BaseTy .ScalarType :
4038 return BaseCType (scalarTypeT )
4139 elif typ .name == BaseTy .int :
@@ -46,10 +44,6 @@ def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, L
4644 return BaseCType (doubleT )
4745 elif typ .name == BaseTy .str :
4846 return BaseCType (stringT )
49- elif typ .name == BaseTy .Device :
50- return BaseCType (deviceT )
51- elif typ .name == BaseTy .Layout :
52- return BaseCType (layoutT )
5347 else :
5448 raise AssertionError (f"TODO add support for type { repr (typ )} " )
5549 elif isinstance (typ , OptionalType ):
@@ -71,30 +65,12 @@ def isValueType(typ: Union[Type, BaseCType, OptionalCType, ConstRefCType, MutRef
7165 being Tensor-like, but assumes the type has already been transformed.
7266 """
7367 if isinstance (typ , BaseCType ):
74- # I am regretting my naming conventions, but now we are wrapping at::scalar in
75- # lazy value, while preserving other 'scalar' types as scalars in the IR
76- return typ .type == valueT or typ .type == scalarT
68+ return typ .type == valueT
7769 elif isinstance (typ , (OptionalCType , ListCType , VectorCType )):
7870 return isValueType (typ .elem )
7971 else :
8072 return False
8173
82- def isWrappedScalarType (typ : Type ) -> bool :
83- """
84- Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
85- Since we literally change the type from scalarT to valueT, information is lost.
86- This function helps build a list of wrapped scalars to save that information
87- """
88- if isinstance (typ , BaseType ):
89- # I am regretting my naming conventions, but now we are wrapping at::scalar in
90- # lazy value, while preserving other 'scalar' types as scalars in the IR
91- return typ .name == BaseTy .Scalar
92- elif isinstance (typ , (OptionalType , ListType )):
93- return isWrappedScalarType (typ .elem )
94- else :
95- return False
96-
97-
9874# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
9975# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
10076# but carries type information from a native FunctionSchema modified for use with IR nodes,
@@ -111,8 +87,6 @@ class LazyIrSchema:
11187 # TODO: Need to handle collisions with argument names at some point
11288 returns : Tuple ['Return' , ...]
11389
114- wrapped_scalar_names : List [str ]
115-
11690 def __init__ (self , func : FunctionSchema ):
11791
11892 positional_arg_types = []
@@ -134,15 +108,14 @@ def __init__(self, func: FunctionSchema):
134108 "tensor_options" ,
135109 "post_tensor_options_kwarg_only" ,
136110 "out" ]:
137- curr_args = getattr (func .arguments , arg_field )
138- if curr_args is not None :
139- if isinstance ( curr_args , TensorOptionsArguments ):
140- curr_args = curr_args . all ()
141- keyword_arg_types . extend ([ NamedCType ( arg . name , process_ir_type (arg .type )) for arg in curr_args ])
111+ if getattr (func .arguments , arg_field ) is not None :
112+ keyword_arg_types . extend ([
113+ NamedCType (
114+ arg . name ,
115+ process_ir_type (arg .type )) for arg in getattr ( func . arguments , arg_field ) ])
142116 self .keyword_arg_types = tuple (keyword_arg_types )
143117 self .name = func .name
144118 self .returns = func .returns
145- self .wrapped_scalar_names = [arg .name for arg in func .schema_order_arguments () if isWrappedScalarType (arg .type )]
146119
147120 @property
148121 def node_name (self ) -> str :
0 commit comments