Skip to content

Commit 8e03c38

Browse files
gmagogsfmfacebook-github-bot
authored andcommitted
Add prim::EnumName and prim::EnumValue ops (#41965)
Summary: [2/N] Implement Enum JIT support Add prim::EnumName and prim::EnumValue and their lowerings to support getting `name` and `value` attribute of Python enums. Supported: Enum-typed function targuments using Enum type and comparing them Support getting name/value attrs of enums TODO: Add PyThon sugared value for Enum Support Enum-typed return values Support enum values of different types in same Enum class Support serialization and deserialization Pull Request resolved: #41965 Reviewed By: eellison Differential Revision: D22714446 Pulled By: gmagogsfm fbshipit-source-id: db8c4e26b657e7782dbfc2b58a141add1263f76e
1 parent 6287f9e commit 8e03c38

File tree

7 files changed

+137
-3
lines changed

7 files changed

+137
-3
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ namespace c10 {
6666
_(prim, ListConstruct) \
6767
_(prim, ListUnpack) \
6868
_(prim, DictConstruct) \
69+
_(prim, EnumName) \
70+
_(prim, EnumValue) \
6971
_(prim, StringIndex) \
7072
_(prim, NumToTensor) \
7173
_(prim, Uninitialized) \

aten/src/ATen/core/jit_type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,7 +1145,7 @@ struct CAFFE2_API EnumType : public NamedType {
11451145
AT_ERROR(
11461146
"Cannot create Enum with value type '",
11471147
value->str(),
1148-
"', only int, float, Tensor and string keys are supported");
1148+
"', only int, float and string are supported");
11491149
}
11501150
}
11511151

test/jit/test_enum.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,44 @@ def tearDown(self):
2323
if self.saved_enum_env_var:
2424
os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var
2525

26+
def test_enum_value_types(self):
27+
global IntEnum
28+
29+
class IntEnum(Enum):
30+
FOO = 1
31+
BAR = 2
32+
33+
global FloatEnum
34+
35+
class FloatEnum(Enum):
36+
FOO = 1.2
37+
BAR = 2.3
38+
39+
global StringEnum
40+
41+
class StringEnum(Enum):
42+
FOO = "foo as in foo bar"
43+
BAR = "bar as in foo bar"
44+
45+
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
46+
return (a.name, b.name, c.name)
47+
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
48+
# is supported.
49+
with torch._jit_internal._disable_emit_hooks():
50+
torch.jit.script(supported_enum_types)
51+
52+
global TensorEnum
53+
54+
class TensorEnum(Enum):
55+
FOO = torch.tensor(0)
56+
BAR = torch.tensor(1)
57+
58+
def unsupported_enum_types(a: TensorEnum):
59+
return a.name
60+
61+
with self.assertRaisesRegex(RuntimeError, "Cannot create Enum with value type 'Tensor'"):
62+
torch.jit.script(unsupported_enum_types)
63+
2664
def test_enum_comp(self):
2765
global Color
2866

@@ -33,7 +71,7 @@ class Color(Enum):
3371
def enum_comp(x: Color, y: Color) -> bool:
3472
return x == y
3573

36-
# TODO(gmagogsfm): Re-anble hooks when serialization/deserialization
74+
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
3775
# is supported.
3876
with torch._jit_internal._disable_emit_hooks():
3977
scripted_enum_comp = torch.jit.script(enum_comp)
@@ -56,11 +94,47 @@ class Color(Enum):
5694
def enum_comp(x: Color, y: Color) -> bool:
5795
return x == y
5896

59-
# TODO(gmagogsfm): Re-anble hooks when serialization/deserialization
97+
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
6098
# is supported.
6199
with self.assertRaisesRegex(RuntimeError, "Could not unify type list"):
62100
scripted_enum_comp = torch.jit.script(enum_comp)
63101

102+
def test_enum_name(self):
103+
global Color
104+
105+
class Color(Enum):
106+
RED = 1
107+
GREEN = 2
108+
109+
def enum_name(x: Color) -> str:
110+
return x.name
111+
112+
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
113+
# is supported.
114+
with torch._jit_internal._disable_emit_hooks():
115+
scripted_enum_name = torch.jit.script(enum_name)
116+
117+
self.assertEqual(scripted_enum_name(Color.RED), Color.RED.name)
118+
self.assertEqual(scripted_enum_name(Color.GREEN), Color.GREEN.name)
119+
120+
def test_enum_value(self):
121+
global Color
122+
123+
class Color(Enum):
124+
RED = 1
125+
GREEN = 2
126+
127+
def enum_value(x: Color) -> int:
128+
return x.value
129+
130+
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
131+
# is supported.
132+
with torch._jit_internal._disable_emit_hooks():
133+
scripted_enum_value = torch.jit.script(enum_value)
134+
135+
self.assertEqual(scripted_enum_value(Color.RED), Color.RED.value)
136+
self.assertEqual(scripted_enum_value(Color.GREEN), Color.GREEN.value)
137+
64138

65139
# Tests that Enum support features are properly guarded before they are mature.
66140
class TestEnumFeatureGuard(JitTestCase):

torch/csrc/jit/frontend/sugared_value.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,19 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
162162
if (auto schema = iface->getMethod(field)) {
163163
return std::make_shared<MethodValue>(getValue(), field);
164164
}
165+
} else if (auto enum_type = value_->type()->cast<EnumType>()) {
166+
// Handle access to Enum's `name` and `value` attribute.
167+
auto& g = *m.graph();
168+
169+
if (field == "name") {
170+
auto n = g.insertNode(g.createEnumName(value_));
171+
return std::make_shared<SimpleValue>(n->output());
172+
}
173+
174+
if (field == "value") {
175+
auto n = g.insertNode(g.createEnumValue(value_));
176+
return std::make_shared<SimpleValue>(n->output());
177+
}
165178
}
166179

167180
// none of the more-specific cases worked, so see if this is a builtin method

torch/csrc/jit/ir/ir.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,21 @@ Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
15951595
return n;
15961596
}
15971597

1598+
Node* Graph::createEnumName(Value* e) {
1599+
e->type()->expect<EnumType>();
1600+
assert(e->type()->cast<EnumType>());
1601+
auto n = create(prim::EnumName, {e});
1602+
n->output()->setType(StringType::get());
1603+
return n;
1604+
}
1605+
1606+
Node* Graph::createEnumValue(Value* e) {
1607+
auto enum_type = e->type()->expect<EnumType>();
1608+
auto n = create(prim::EnumValue, {e});
1609+
n->output()->setType(enum_type->getValueType());
1610+
return n;
1611+
}
1612+
15981613
Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
15991614
auto n = create(prim::ListConstruct, values);
16001615
for (const auto& v : values) {

torch/csrc/jit/ir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,8 @@ struct Graph {
11151115
Value* idx,
11161116
const TypePtr& output_type);
11171117
TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
1118+
TORCH_API Node* createEnumName(Value* e);
1119+
TORCH_API Node* createEnumValue(Value* e);
11181120
TORCH_API Node* createList(
11191121
const TypePtr& elem_type,
11201122
at::ArrayRef<Value*> values);

torch/csrc/jit/runtime/register_prim_ops.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,34 @@ RegisterOperators reg(
311311
pack(stack, t.sizes().vec());
312312
},
313313
aliasAnalysisFromSchema()),
314+
Operator(
315+
"prim::EnumName(AnyEnumType enum) -> str",
316+
[](Stack* stack) {
317+
IValue e = pop(stack);
318+
push(stack, e.toEnumHolder()->name());
319+
},
320+
aliasAnalysisFromSchema()),
321+
Operator(
322+
"prim::EnumValue.int(AnyEnumType enum) -> int",
323+
[](Stack* stack) {
324+
IValue e = pop(stack);
325+
push(stack, e.toEnumHolder()->value());
326+
},
327+
aliasAnalysisFromSchema()),
328+
Operator(
329+
"prim::EnumValue.float(AnyEnumType enum) -> float",
330+
[](Stack* stack) {
331+
IValue e = pop(stack);
332+
push(stack, e.toEnumHolder()->value());
333+
},
334+
aliasAnalysisFromSchema()),
335+
Operator(
336+
"prim::EnumValue.str(AnyEnumType enum) -> str",
337+
[](Stack* stack) {
338+
IValue e = pop(stack);
339+
push(stack, e.toEnumHolder()->value());
340+
},
341+
aliasAnalysisFromSchema()),
314342
Operator(
315343
// note the compiler knows to type TupleIndex more accurately than it
316344
// is listed here.

0 commit comments

Comments
 (0)