
Commit 889f3f4

wconstab authored and pytorchmergebot committed
Revert D34178476: Update lazy_ir.py from lazy_tensor_staging
Test Plan: revert-hammer

Differential Revision: D34178476 (3842140)

Original commit changeset: 7190b2e0d82b

Original Phabricator Diff: D34178476 (3842140)

fbshipit-source-id: 4c969a355f01244c6f5acc52bc31679f2182aa55
(cherry picked from commit 1708207)
1 parent c192558 commit 889f3f4

File tree

16 files changed: +95, -243 lines


test/cpp/lazy/test_cache.cpp

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ namespace lazy {
 class CacheNode : public Node {
  public:
   explicit CacheNode(const std::string& str)
-      : Node(OpKind(), /* num_outputs */ 1, /* hash_func */ [&](bool bakeInSizes) -> hash_t { return Hash(str); }),
+      : Node(OpKind(), /* num_outputs */ 1, /* hash_seed */ Hash(str)),
        str_(str) {}
   ~CacheNode() override = default;

test/cpp/lazy/test_ir.cpp

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ namespace lazy {
 class TestLeafNode : public Node {
  public:
   explicit TestLeafNode(size_t param)
-      : Node(OpKind(), /* num_outputs */ 1, /* hash_func */ [&](bool bakeInSizes) -> hash_t { return Hash(param); }),
+      : Node(OpKind(), /* num_outputs */ 1, /* hash_seed */ Hash(param)),
        param_(param) {}
   ~TestLeafNode() override = default;

test/cpp/lazy/test_ir_util.cpp

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ namespace lazy {
 class IrUtilNode : public Node {
  public:
   explicit IrUtilNode()
-      : Node(OpKind(), /* num_outputs */ 1, /* hash_func */ [&](bool bakeInSizes) -> hash_t { return Hash(0); }) {}
+      : Node(OpKind(), /* num_outputs */ 1, /* hash_seed */ Hash(0)) {}
   ~IrUtilNode() override = default;

   void AddOperand(Value v) {
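
All three test nodes above revert from the staging-branch constructor, which took a hashing callback, back to the earlier constructor that takes a precomputed hash seed. A minimal standalone sketch of the two shapes follows; hash_t, Hash, and the two node structs here are simplified stand-ins for illustration, not the real torch::lazy declarations, and bakeInSizes is the flag visible in the deleted lines:

#include <cstdint>
#include <functional>
#include <string>

using hash_t = uint64_t;

// Stand-in for torch::lazy::Hash (illustrative only).
inline hash_t Hash(const std::string& s) { return std::hash<std::string>{}(s); }

// Staging-branch shape (removed by this revert): the node receives a hash
// *function*, so the caller can decide later whether shapes get baked in.
struct NodeWithHashFunc {
  NodeWithHashFunc(size_t num_outputs, std::function<hash_t(bool)> hash_func)
      : num_outputs_(num_outputs), hash_(hash_func(/*bakeInSizes=*/false)) {}
  size_t num_outputs_;
  hash_t hash_;
};

// Shape restored by the revert: the node receives a precomputed hash *seed*.
struct NodeWithHashSeed {
  NodeWithHashSeed(size_t num_outputs, hash_t hash_seed)
      : num_outputs_(num_outputs), hash_(hash_seed) {}
  size_t num_outputs_;
  hash_t hash_;
};

int main() {
  std::string str = "cache-node";
  // Mirrors the deleted and restored constructor calls from the diffs above.
  NodeWithHashFunc a(1, [&](bool bakeInSizes) -> hash_t { return Hash(str); });
  NodeWithHashSeed b(1, Hash(str));
  return a.hash_ == b.hash_ ? 0 : 1;
}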

tools/codegen/api/lazy.py

Lines changed: 9 additions & 36 deletions

@@ -1,10 +1,10 @@
 from typing import List, Union, Tuple
 from tools.codegen.model import (Type, BaseTy, BaseType, OptionalType,
                                  ListType, OperatorName, FunctionSchema,
-                                 Return, TensorOptionsArguments)
+                                 Return)
 from tools.codegen.api.types import (BaseCppType, BaseCType, OptionalCType,
                                      ConstRefCType, NamedCType,
-                                     MutRefCType, deviceT, layoutT,
+                                     MutRefCType,
                                      VectorCType, boolT, longT, doubleT, ListCType, stringT,
                                      scalarT, scalarTypeT, ArrayRefCType, ArrayCType, TupleCType)

@@ -33,9 +33,7 @@ def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, L
         if typ.name == BaseTy.Tensor:
             return BaseCType(valueT)
         elif typ.name == BaseTy.Scalar:
-            # at::scalar has special handling,
-            # and is wrapped in an IR value just like at::tensor
-            return BaseCType(valueT)
+            return BaseCType(scalarT)
         elif typ.name == BaseTy.ScalarType:
             return BaseCType(scalarTypeT)
         elif typ.name == BaseTy.int:

@@ -46,10 +44,6 @@ def process_ir_type(typ: Type) -> Union[BaseCType, VectorCType, OptionalCType, L
             return BaseCType(doubleT)
         elif typ.name == BaseTy.str:
             return BaseCType(stringT)
-        elif typ.name == BaseTy.Device:
-            return BaseCType(deviceT)
-        elif typ.name == BaseTy.Layout:
-            return BaseCType(layoutT)
         else:
             raise AssertionError(f"TODO add support for type {repr(typ)}")
     elif isinstance(typ, OptionalType):

@@ -71,30 +65,12 @@ def isValueType(typ: Union[Type, BaseCType, OptionalCType, ConstRefCType, MutRef
     being Tensor-like, but assumes the type has already been transformed.
     """
     if isinstance(typ, BaseCType):
-        # I am regretting my naming conventions, but now we are wrapping at::scalar in
-        # lazy value, while preserving other 'scalar' types as scalars in the IR
-        return typ.type == valueT or typ.type == scalarT
+        return typ.type == valueT
     elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
         return isValueType(typ.elem)
     else:
         return False

-def isWrappedScalarType(typ: Type) -> bool:
-    """
-    Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
-    Since we literally change the type from scalarT to valueT, information is lost.
-    This function helps build a list of wrapped scalars to save that information
-    """
-    if isinstance(typ, BaseType):
-        # I am regretting my naming conventions, but now we are wrapping at::scalar in
-        # lazy value, while preserving other 'scalar' types as scalars in the IR
-        return typ.name == BaseTy.Scalar
-    elif isinstance(typ, (OptionalType, ListType)):
-        return isWrappedScalarType(typ.elem)
-    else:
-        return False
-
-
 # Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
 # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
 # but carries type information from a native FunctionSchema modified for use with IR nodes,

@@ -111,8 +87,6 @@ class LazyIrSchema:
     # TODO: Need to handle collisions with argument names at some point
     returns: Tuple['Return', ...]

-    wrapped_scalar_names: List[str]
-
    def __init__(self, func: FunctionSchema):

        positional_arg_types = []

@@ -134,15 +108,14 @@ def __init__(self, func: FunctionSchema):
                           "tensor_options",
                           "post_tensor_options_kwarg_only",
                           "out"]:
-            curr_args = getattr(func.arguments, arg_field)
-            if curr_args is not None:
-                if isinstance(curr_args, TensorOptionsArguments):
-                    curr_args = curr_args.all()
-                keyword_arg_types.extend([NamedCType(arg.name, process_ir_type(arg.type)) for arg in curr_args])
+            if getattr(func.arguments, arg_field) is not None:
+                keyword_arg_types.extend([
+                    NamedCType(
+                        arg.name,
+                        process_ir_type(arg.type)) for arg in getattr(func.arguments, arg_field)])
         self.keyword_arg_types = tuple(keyword_arg_types)
         self.name = func.name
         self.returns = func.returns
-        self.wrapped_scalar_names = [arg.name for arg in func.schema_order_arguments() if isWrappedScalarType(arg.type)]

     @property
     def node_name(self) -> str: