Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ namespace c10 {
_(prim, ListConstruct) \
_(prim, ListUnpack) \
_(prim, DictConstruct) \
_(prim, EnumName) \
_(prim, EnumValue) \
_(prim, StringIndex) \
_(prim, NumToTensor) \
_(prim, Uninitialized) \
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ struct CAFFE2_API EnumType : public NamedType {
AT_ERROR(
"Cannot create Enum with value type '",
value->str(),
"', only int, float, Tensor and string keys are supported");
"', only int, float and string are supported");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not including bool as a type here, but we are registering a kernel for it

}
}

Expand Down
78 changes: 76 additions & 2 deletions test/jit/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,44 @@ def tearDown(self):
if self.saved_enum_env_var:
os.environ["EXPERIMENTAL_ENUM_SUPPORT"] = self.saved_enum_env_var

def test_enum_value_types(self):
    # Enums with int, float, or str values are scriptable; any other
    # value type (here: Tensor) must be rejected at script time with a
    # descriptive error.
    global IntEnum

    class IntEnum(Enum):
        FOO = 1
        BAR = 2

    global FloatEnum

    class FloatEnum(Enum):
        FOO = 1.2
        BAR = 2.3

    global StringEnum

    class StringEnum(Enum):
        FOO = "foo as in foo bar"
        BAR = "bar as in foo bar"

    def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
        return (a.name, b.name, c.name)

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        torch.jit.script(supported_enum_types)

    global TensorEnum

    class TensorEnum(Enum):
        FOO = torch.tensor(0)
        BAR = torch.tensor(1)

    def unsupported_enum_types(a: TensorEnum):
        return a.name

    with self.assertRaisesRegex(RuntimeError, "Cannot create Enum with value type 'Tensor'"):
        torch.jit.script(unsupported_enum_types)

def test_enum_comp(self):
global Color

Expand All @@ -33,7 +71,7 @@ class Color(Enum):
def enum_comp(x: Color, y: Color) -> bool:
return x == y

# TODO(gmagogsfm): Re-anble hooks when serialization/deserialization
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with torch._jit_internal._disable_emit_hooks():
scripted_enum_comp = torch.jit.script(enum_comp)
Expand All @@ -56,11 +94,47 @@ class Color(Enum):
def enum_comp(x: Color, y: Color) -> bool:
return x == y

# TODO(gmagogsfm): Re-anble hooks when serialization/deserialization
# TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
# is supported.
with self.assertRaisesRegex(RuntimeError, "Could not unify type list"):
scripted_enum_comp = torch.jit.script(enum_comp)

def test_enum_name(self):
    # Scripted access to `x.name` should match eager-mode Enum.name.
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_name(x: Color) -> str:
        return x.name

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted_enum_name = torch.jit.script(enum_name)

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(scripted_enum_name(member), member.name)

def test_enum_value(self):
    # Scripted access to `x.value` should match eager-mode Enum.value.
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_value(x: Color) -> int:
        return x.value

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted_enum_value = torch.jit.script(enum_value)

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(scripted_enum_value(member), member.value)


# Tests that Enum support features are properly guarded before they are mature.
class TestEnumFeatureGuard(JitTestCase):
Expand Down
13 changes: 13 additions & 0 deletions torch/csrc/jit/frontend/sugared_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,19 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
if (auto schema = iface->getMethod(field)) {
return std::make_shared<MethodValue>(getValue(), field);
}
} else if (auto enum_type = value_->type()->cast<EnumType>()) {
// Handle access to Enum's `name` and `value` attribute.
auto& g = *m.graph();

if (field == "name") {
auto n = g.insertNode(g.createEnumName(value_));
return std::make_shared<SimpleValue>(n->output());
}

if (field == "value") {
auto n = g.insertNode(g.createEnumValue(value_));
return std::make_shared<SimpleValue>(n->output());
}
}

// none of the more-specific cases worked, so see if this is a builtin method
Expand Down
15 changes: 15 additions & 0 deletions torch/csrc/jit/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,21 @@ Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
return n;
}

// Creates a prim::EnumName node that extracts the member name of the
// given enum value as a string.
// `e` must be of EnumType; expect<> throws a descriptive error
// otherwise, so no additional assert is needed (the previous
// `assert(e->type()->cast<EnumType>())` was redundant).
Node* Graph::createEnumName(Value* e) {
  e->type()->expect<EnumType>();
  auto n = create(prim::EnumName, {e});
  // An enum member's name is always a string, regardless of value type.
  n->output()->setType(StringType::get());
  return n;
}

// Creates a prim::EnumValue node that extracts the underlying value of
// the given enum. Input must be of EnumType; expect<> throws otherwise.
Node* Graph::createEnumValue(Value* e) {
  const auto enumType = e->type()->expect<EnumType>();
  Node* node = create(prim::EnumValue, {e});
  // The output is typed with the enum's declared value type.
  node->output()->setType(enumType->getValueType());
  return node;
}

Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
auto n = create(prim::ListConstruct, values);
for (const auto& v : values) {
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,8 @@ struct Graph {
Value* idx,
const TypePtr& output_type);
TORCH_API Node* createTupleSlice(Value* tup, int64_t beg, int64_t end);
TORCH_API Node* createEnumName(Value* e);
TORCH_API Node* createEnumValue(Value* e);
TORCH_API Node* createList(
const TypePtr& elem_type,
at::ArrayRef<Value*> values);
Expand Down
28 changes: 28 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,34 @@ RegisterOperators reg(
pack(stack, t.sizes().vec());
},
aliasAnalysisFromSchema()),
Operator(
"prim::EnumName(AnyEnumType enum) -> str",
[](Stack* stack) {
IValue e = pop(stack);
push(stack, e.toEnumHolder()->name());
},
aliasAnalysisFromSchema()),
Operator(
"prim::EnumValue.int(AnyEnumType enum) -> int",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is at least worth a discussion, because the duplication of the kernel has some binary implications. However, I think not-standard ops (unschematized ops) have some other complications around mobile and stuff so I think this is probably worth it. @ljk53 what do you think ?

@SplitInfinity

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the "binary implication" you are referring to with dup kernels?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each operator increases the size of the mobile build (if it is included in a mobile model). In this case, we could unify all of the prim::EnumValue operators, but at the cost of making it a non-standard op whose schema doesn't express its true types. I think our current decision is good.

[](Stack* stack) {
IValue e = pop(stack);
push(stack, e.toEnumHolder()->value());
},
aliasAnalysisFromSchema()),
Operator(
"prim::EnumValue.float(AnyEnumType enum) -> float",
[](Stack* stack) {
IValue e = pop(stack);
push(stack, e.toEnumHolder()->value());
},
aliasAnalysisFromSchema()),
Operator(
"prim::EnumValue.str(AnyEnumType enum) -> str",
[](Stack* stack) {
IValue e = pop(stack);
push(stack, e.toEnumHolder()->value());
},
aliasAnalysisFromSchema()),
Operator(
// note the compiler knows to type TupleIndex more accurately than it
// is listed here.
Expand Down