Skip to content

Commit c642907

Browse files
committed
[1/N] Implement Enum JIT support
* Enum support tempoartily hidden behind environemnt variable EXPERIMENTAL_ENUM_SUPPORT to avoid misuse * Add EnumType and AnyEnumType as first-class jit type * Add Enum-typed IValue * Enhanced aten::eq to support Enum Supported: Enum-typed function arguments using Enum type and comparing them TODO: Add PyThon sugared value for Enum Support getting name/value attrs of enums Support Enum-typed return values Support enum values of different types in same Enum class Support serialization and deserialization
1 parent 349c405 commit c642907

File tree

18 files changed

+408
-26
lines changed

18 files changed

+408
-26
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ TypePtr IValue::type() const {
9696
return toTuple()->type();
9797
case Tag::Generator:
9898
return GeneratorType::get();
99+
case Tag::Enum:
100+
// TODO(gmagogsfm): Implement this.
101+
TORCH_INTERNAL_ASSERT(false, "To be implemented");
99102
}
100103
// switch above is complete but this silences compiler warnings
101104
TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
@@ -264,6 +267,8 @@ IValue IValue::equals(const IValue& rhs) const {
264267
case Tag::Capsule:
265268
case Tag::Generator:
266269
return ptrEqual(lhs, rhs);
270+
case Tag::Enum:
271+
return lhs.toEnumHolder()->is(*rhs.toEnumHolder());
267272
case Tag::Uninitialized:
268273
// Unitialized ivalues show up in no-ops when the compiler can prove a
269274
// value will never be used. Just return false on any equality comparison.
@@ -501,6 +506,11 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
501506
// print this out the way python would do it
502507
return out << "<" << obj->name() << " object at " << obj.get() << ">";
503508
}
509+
case IValue::Tag::Enum:
510+
auto enum_holder = v.toEnumHolder();
511+
return out << "Enum<" << enum_holder->qualifiedClassName() << "." <<
512+
enum_holder->name() << ">";
513+
504514
}
505515
AT_ERROR("Tag not found: ", v.tagKind());
506516
}

aten/src/ATen/core/ivalue.h

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct ConstantString;
3737
struct GenericDict;
3838
struct Object;
3939
struct PyObjectHolder;
40+
struct EnumHolder;
4041
}
4142

4243
// This is an owning wrapper for a c10::optional<std::vector<T>>
@@ -77,24 +78,25 @@ struct OptionalArray {
7778
// retain/release calls.
7879

7980
#define TORCH_FORALL_TAGS(_) \
80-
_(None) \
81-
_(Tensor) \
82-
_(Double) \
83-
_(Int) \
84-
_(Bool) \
85-
_(Tuple) \
86-
_(String) \
87-
_(Blob) \
88-
_(GenericList) \
89-
_(GenericDict) \
90-
_(Future) \
91-
_(Device) \
92-
_(Object) \
93-
_(PyObject) \
94-
_(Uninitialized) \
95-
_(Capsule) \
96-
_(RRef) \
97-
_(Generator) \
81+
_(None) \
82+
_(Tensor) \
83+
_(Double) \
84+
_(Int) \
85+
_(Bool) \
86+
_(Tuple) \
87+
_(String) \
88+
_(Blob) \
89+
_(GenericList) \
90+
_(GenericDict) \
91+
_(Future) \
92+
_(Device) \
93+
_(Object) \
94+
_(PyObject) \
95+
_(Uninitialized) \
96+
_(Capsule) \
97+
_(RRef) \
98+
_(Generator) \
99+
_(Enum) \
98100

99101
// [doxygen private]
100102
// These methods are not actually private but we don't want to document them, so
@@ -407,13 +409,13 @@ struct CAFFE2_API IValue final {
407409
c10::List<bool> toBoolList() &&;
408410
c10::List<bool> toBoolList() const &;
409411

410-
//TensorList
412+
// TensorList
411413
bool isTensorList() const;
412414
c10::List<at::Tensor> toTensorList() &&;
413415
c10::List<at::Tensor> toTensorList() const &;
414416
std::vector<at::Tensor> toTensorVector() const;
415417

416-
//GenericList
418+
// GenericList
417419
IValue(c10::List<IValue> v);
418420
bool isList() const { return Tag::GenericList == tag; }
419421
c10::List<IValue> toList() &&;
@@ -479,6 +481,12 @@ struct CAFFE2_API IValue final {
479481
c10::intrusive_ptr<ivalue::PyObjectHolder> toPyObjectHolder() const &;
480482
PyObject* toPyObject() const;
481483

484+
// Enum
485+
explicit IValue(c10::intrusive_ptr<ivalue::EnumHolder> v);
486+
bool isEnum() const { return tag == Tag::Enum; }
487+
c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() &&;
488+
c10::intrusive_ptr<ivalue::EnumHolder> toEnumHolder() const &;
489+
482490
// None
483491
IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
484492
bool isNone() const {

aten/src/ATen/core/ivalue_inl.h

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
#include <c10/core/Scalar.h>
99
#include <c10/core/TensorImpl.h>
1010
#include <c10/core/UndefinedTensorImpl.h>
11+
#include <c10/util/intrusive_ptr.h>
1112
#include <ATen/core/Dict.h>
1213
#include <ATen/core/List.h>
14+
#include <ATen/core/qualified_name.h>
1315
#include <ATen/core/rref_interface.h>
1416

1517
namespace torch {
@@ -96,13 +98,21 @@ inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const & {
9698
return toIntrusivePtr<ivalue::Object>();
9799
}
98100
inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder() && {
99-
TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got", tagKind());
101+
TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
100102
return moveToIntrusivePtr<ivalue::PyObjectHolder>();
101103
}
102104
inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder() const & {
103-
TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got", tagKind());
105+
TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
104106
return toIntrusivePtr<ivalue::PyObjectHolder>();
105107
}
108+
inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
109+
TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
110+
return moveToIntrusivePtr<ivalue::EnumHolder>();
111+
}
112+
inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const & {
113+
TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
114+
return toIntrusivePtr<ivalue::EnumHolder>();
115+
}
106116
inline at::Tensor IValue::toTensor() && {
107117
AT_ASSERT(isTensor(), "Expected Tensor but got ", tagKind());
108118
return at::Tensor(moveToIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
@@ -216,6 +226,7 @@ struct CAFFE2_API Tuple : c10::intrusive_ptr_target {
216226

217227
struct Object;
218228
struct PyObjectHolder;
229+
struct EnumHolder;
219230
}
220231

221232
// Future
@@ -524,6 +535,39 @@ struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
524535
virtual ~PyObjectHolder() {};
525536
};
526537

538+
struct ivalue::EnumHolder : c10::intrusive_ptr_target {
539+
public:
540+
EnumHolder(c10::QualifiedName qualified_class_name, std::string name, IValue value)
541+
: qualified_class_name_(std::move(qualified_class_name)),
542+
name_(std::move(name)), value_(std::move(value)) {}
543+
544+
bool is(const ivalue::EnumHolder& rhs) {
545+
return *this == rhs;
546+
}
547+
548+
bool operator==(const ivalue::EnumHolder& o) const {
549+
return qualified_class_name_ == o.qualifiedClassName() &&
550+
name_ == o.name() && value_ == o.value();
551+
}
552+
553+
const std::string& qualifiedClassName() const {
554+
return qualified_class_name_.qualifiedName();
555+
}
556+
557+
const std::string& name() const {
558+
return name_;
559+
}
560+
561+
const IValue& value() const {
562+
return value_;
563+
}
564+
565+
private:
566+
c10::QualifiedName qualified_class_name_;
567+
std::string name_;
568+
IValue value_;
569+
};
570+
527571
#undef TORCH_FORALL_TAGS
528572

529573
namespace detail {
@@ -930,10 +974,17 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
930974
: tag(Tag::Object), is_intrusive_ptr(true) {
931975
payload.as_intrusive_ptr = v.release();
932976
}
977+
933978
inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
934979
: tag(Tag::PyObject), is_intrusive_ptr(true) {
935980
payload.as_intrusive_ptr = v.release();
936981
}
982+
983+
inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
984+
: tag(Tag::Enum), is_intrusive_ptr(true) {
985+
payload.as_intrusive_ptr = v.release();
986+
}
987+
937988
inline IValue IValue::make_capsule(intrusive_ptr<torch::CustomClassHolder> blob) {
938989
IValue iv;
939990
iv.tag = Tag::Capsule;
@@ -974,6 +1025,7 @@ inline const std::string& IValue::toStringRef() const {
9741025
inline PyObject* IValue::toPyObject() const {
9751026
return toPyObjectHolder()->getPyObject();
9761027
}
1028+
9771029
template<typename T>
9781030
inline optional<T> IValue::toOptional() {
9791031
if (this->isNone()) {

aten/src/ATen/core/jit_type.h

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ using OptNameList = c10::optional<std::vector<std::string>>;
2828

2929
#define C10_FORALL_TYPES(_) \
3030
_(AnyType) \
31+
_(EnumType) \
32+
_(AnyEnumType) \
3133
_(TensorType) \
3234
_(TupleType) \
3335
_(ListType) \
@@ -1025,7 +1027,8 @@ struct CAFFE2_API NamedType : public Type {
10251027
: Type(tk), name_(std::move(name)) {
10261028
TORCH_INTERNAL_ASSERT(
10271029
tk == TypeKind::TupleType || tk == TypeKind::FunctionType ||
1028-
tk == TypeKind::ClassType || tk == TypeKind::InterfaceType,
1030+
tk == TypeKind::ClassType || tk == TypeKind::InterfaceType ||
1031+
tk == TypeKind::EnumType,
10291032
"If you add a new kind of NamedType, ",
10301033
"please update the cast<NamedType> specialization and this assert");
10311034
}
@@ -1124,6 +1127,95 @@ struct CAFFE2_API TupleType : public NamedType {
11241127
std::shared_ptr<FunctionSchema> schema_;
11251128
};
11261129

1130+
struct EnumType;
1131+
using EnumTypePtr = std::shared_ptr<EnumType>;
1132+
struct CAFFE2_API EnumType : public NamedType {
1133+
friend struct Type;
1134+
static const TypeKind Kind = TypeKind::EnumType;
1135+
1136+
static EnumTypePtr create(
1137+
const c10::QualifiedName& qualified_name,
1138+
TypePtr value, std::weak_ptr<::torch::jit::CompilationUnit> cu) {
1139+
switch (value->kind()) {
1140+
case TypeKind::IntType:
1141+
case TypeKind::FloatType:
1142+
case TypeKind::StringType:
1143+
return EnumTypePtr(new EnumType(qualified_name, value, cu));
1144+
default:
1145+
AT_ERROR(
1146+
"Cannot create Enum with value type '",
1147+
value->str(),
1148+
"', only int, float, Tensor and string keys are supported");
1149+
}
1150+
}
1151+
1152+
std::string str() const override {
1153+
return "Enum<" + annotation_str() + ">";
1154+
}
1155+
1156+
std::string repr_str() const override {
1157+
return str();
1158+
}
1159+
1160+
TypePtr getValueType() const {
1161+
return value_type_;
1162+
}
1163+
1164+
bool operator==(const Type& rhs) const override {
1165+
if (auto enum_rhs = rhs.cast<EnumType>()) {
1166+
return name().value() == enum_rhs->name().value() &&
1167+
*getValueType() == *(enum_rhs->getValueType()) &&
1168+
this->compilation_unit() == enum_rhs->compilation_unit();
1169+
}
1170+
return false;
1171+
}
1172+
1173+
bool isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const override;
1174+
1175+
std::shared_ptr<const ::torch::jit::CompilationUnit> compilation_unit() const {
1176+
auto cu = cu_.lock();
1177+
return cu;
1178+
}
1179+
1180+
private:
1181+
EnumType(c10::QualifiedName name, TypePtr value_type, std::weak_ptr<torch::jit::CompilationUnit> cu)
1182+
: NamedType(TypeKind::EnumType, std::move(name)),
1183+
value_type_(std::move(value_type)), cu_(cu) {}
1184+
1185+
std::string annotation_str_impl(TypePrinter printer = nullptr) const override {
1186+
const auto& n = name().value();
1187+
return n.qualifiedName();
1188+
}
1189+
1190+
TypePtr value_type_;
1191+
std::weak_ptr<::torch::jit::CompilationUnit> cu_;
1192+
};
1193+
1194+
1195+
// the common supertype of all Enums, only used in operator registraion.
1196+
// EnumType <: AnyEnumType for all Enums
1197+
struct AnyEnumType;
1198+
using AnyEnumTypePtr = std::shared_ptr<AnyEnumType>;
1199+
struct CAFFE2_API AnyEnumType : public Type {
1200+
static AnyEnumTypePtr create() {
1201+
return AnyEnumTypePtr(
1202+
new AnyEnumType()); // NOLINT(modernize-make-shared)
1203+
}
1204+
bool operator==(const Type& rhs) const override {
1205+
return rhs.kind() == kind();
1206+
}
1207+
std::string str() const override {
1208+
return "AnyEnumType";
1209+
}
1210+
static const TypeKind Kind = TypeKind::AnyEnumType;
1211+
// global singleton
1212+
static AnyEnumTypePtr get();
1213+
private:
1214+
AnyEnumType()
1215+
: Type(TypeKind::AnyEnumType) {}
1216+
};
1217+
1218+
11271219
struct NumberType;
11281220
using NumberTypePtr = std::shared_ptr<NumberType>;
11291221
// This type represents a Python number

aten/src/ATen/core/type.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ AnyClassTypePtr AnyClassType::get() {
193193
return value;
194194
}
195195

196+
AnyEnumTypePtr AnyEnumType::get() {
197+
static auto value = AnyEnumType::create();
198+
return value;
199+
}
200+
196201
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
197202
// check direct subtyping relation
198203
if (t1->isSubtypeOf(t2)) {
@@ -1366,4 +1371,9 @@ SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const {
13661371
return SymbolicShape(std::move(dims));
13671372
}
13681373

1374+
bool EnumType::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const {
1375+
return rhs->kind() == TypeKind::AnyType ||
1376+
rhs->kind() == TypeKind::AnyEnumType || *this == *rhs;
1377+
}
1378+
13691379
} // namespace c10

0 commit comments

Comments
 (0)