Skip to content

Commit cc06d8d

Browse files
committed
[jit] register __getitem__ builtin
Summary: Follow-up to #21990. I am switching index-select operations to a standard __getitem__ builtin, rather than a bunch of different builtins chosen according to the type, such as prim::DictIndex, prim::ListIndex, etc. This also aligns with some of the other magic methods that we already use. gh-metadata: pytorch pytorch 22276 gh/wanchaol/28/head
1 parent cb24b61 commit cc06d8d

File tree

11 files changed

+187
-211
lines changed

11 files changed

+187
-211
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ namespace c10 {
6060
_(prim, ListConstruct) \
6161
_(prim, ListUnpack) \
6262
_(prim, DictConstruct) \
63-
_(prim, DictIndex) \
6463
_(prim, StringIndex) \
6564
_(prim, NumToTensor) \
6665
_(prim, Uninitialized) \
@@ -132,6 +131,7 @@ namespace c10 {
132131
_(aten, ne_) \
133132
_(aten, transpose_) \
134133
_(aten, unsqueeze_) \
134+
_(aten, __getitem__) \
135135
_(aten, _set_item) \
136136
_(aten, manual_seed) \
137137
_(aten, set_) \
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
def empty_int_list_test(y: Tensor) -> int:
22
x = annotate(List[int], [])
3-
return torch.select(x, 0)
3+
return x[0]

test/test_jit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10389,7 +10389,7 @@ def test_no_type():
1038910389
return torch.jit._unwrap_optional(None)
1039010390

1039110391
def test_indexing_error(self):
10392-
with self.assertRaisesRegex(RuntimeError, "only supported on List, Dict, Tensor, Tuple, and str"):
10392+
with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"):
1039310393
@torch.jit.script
1039410394
def test_wrong_type():
1039510395
a = 8
@@ -11688,7 +11688,7 @@ def foo(a):
1168811688
i += 1
1168911689
return b
1169011690

11691-
FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::select") \
11691+
FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \
1169211692
.check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
1169311693

1169411694
def test_mutable_dce_wildcards(self):

torch/csrc/jit/ir.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,15 +1391,6 @@ Node* Graph::createDict(
13911391
return n;
13921392
}
13931393

1394-
Node* Graph::createDictIndex(Value* dict, Value* index) {
1395-
auto dict_type = dict->type()->expect<DictType>();
1396-
AT_ASSERT(index->type()->isSubtypeOf(dict_type->getKeyType()));
1397-
1398-
auto n = create(prim::DictIndex, {dict, index});
1399-
n->output()->setType(dict_type->getValueType());
1400-
return n;
1401-
}
1402-
14031394
Node* Graph::createNumToTensor(Value* value) {
14041395
auto typ = value->type();
14051396
Node* result = create(prim::NumToTensor, {value});

torch/csrc/jit/ir.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1080,7 +1080,6 @@ struct Graph {
10801080
const TypePtr& value_type,
10811081
at::ArrayRef<Value*> keys,
10821082
at::ArrayRef<Value*> values);
1083-
TORCH_API Node* createDictIndex(Value* dict, Value* index);
10841083
TORCH_API Node* createNumToTensor(Value* value);
10851084
TORCH_API Node* createImplicitTensorToNum(const TypePtr& type, Value* value);
10861085
TORCH_API Node* createObject(const ClassTypePtr& type);

torch/csrc/jit/passes/alias_analysis.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ void AliasDb::analyzeImpl(Node* node) {
275275
return analyzeContainerConstruct(node);
276276
case prim::TupleUnpack:
277277
case prim::TupleIndex:
278-
case prim::DictIndex:
279278
case prim::TupleSlice:
280279
case prim::ListUnpack:
281280
case prim::PythonOp:
@@ -1118,7 +1117,6 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
11181117
prim::Function,
11191118
prim::TupleUnpack,
11201119
prim::TupleIndex,
1121-
prim::DictIndex,
11221120
prim::TupleSlice,
11231121
prim::ListUnpack,
11241122
prim::PythonOp,

torch/csrc/jit/passes/python_print.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,18 @@ struct PythonPrintPass {
582582
stmt << end;
583583
}
584584

585+
void printValueIndex(TaggedStringStream& stmt, at::ArrayRef<Value*> inputs) {
586+
const std::string val_name = useOf(inputs[0])->str();
587+
if (isValidIdentifier(val_name)) {
588+
stmt << val_name;
589+
} else {
590+
stmt << "(" << val_name << ")";
591+
}
592+
stmt << "[";
593+
stmt << useOf(inputs[1]);
594+
stmt << "]";
595+
}
596+
585597
void printDict(
586598
TaggedStringStream& stmt,
587599
at::ArrayRef<Value*> key_value_pairs,
@@ -1014,6 +1026,9 @@ struct PythonPrintPass {
10141026
case aten::str: {
10151027
printValueList(stmt, node->inputs(), "str(", ")");
10161028
} break;
1029+
case aten::__getitem__: {
1030+
printValueIndex(stmt, node->inputs());
1031+
} break;
10171032
case prim::Print: {
10181033
printValueList(stmt, node->inputs(), "print(", ")");
10191034
} break;
@@ -1059,10 +1074,6 @@ struct PythonPrintPass {
10591074
printDict(stmt, node->inputs());
10601075
}
10611076
} break;
1062-
case prim::DictIndex: {
1063-
stmt << "(" << useOf(node->inputs().at(0)) << ")["
1064-
<< useOf(node->inputs().at(1)) << "]";
1065-
} break;
10661077
case prim::CreateObject: {
10671078
const auto classType = node->output()->type()->expect<ClassType>();
10681079
stmt << classType->python_str() << ".__new__("
@@ -1337,7 +1348,6 @@ bool printerHasSpecialCaseFor(Symbol sym) {
13371348
prim::PythonOp,
13381349
prim::TupleConstruct,
13391350
prim::TupleIndex,
1340-
prim::DictIndex,
13411351
prim::TupleSlice,
13421352
prim::TupleUnpack,
13431353
prim::CreateObject,

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 104 additions & 88 deletions
Large diffs are not rendered by default.

torch/csrc/jit/script/compiler.cpp

Lines changed: 22 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,7 +1312,7 @@ struct to_ir {
13121312

13131313
// if the FOR iters and targets are present, emit FOR target assignments
13141314
if (iter_val != nullptr && targets) {
1315-
Value* cur_elem = iter_val->getelem(range, method, trip_count);
1315+
Value* cur_elem = iter_val->getitem(range, method, trip_count);
13161316
SugaredValuePtr sv = std::make_shared<SimpleValue>(cur_elem);
13171317
List<Expr> target_exprs = targets.value();
13181318
validateAssignLhsExpr(target_exprs, range);
@@ -1651,7 +1651,7 @@ struct to_ir {
16511651
NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));
16521652

16531653
const auto getItem =
1654-
graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
1654+
graph->insert(aten::__getitem__, {listArg, idxArg}, {}, stmt.range());
16551655
const auto augmentedItem = graph->insert(
16561656
getAugOp(stmt, elementType), {getItem, valueArg}, {}, stmt.range());
16571657
graph->insert(
@@ -2808,24 +2808,6 @@ struct to_ir {
28082808
->output();
28092809
}
28102810

2811-
Value* emitDictIndex(
2812-
const SourceRange& loc,
2813-
Value* dict_val,
2814-
Value* key_val) {
2815-
auto dict_type = dict_val->type()->cast<DictType>();
2816-
2817-
if (!key_val->type()->isSubtypeOf(dict_type->getKeyType())) {
2818-
throw ErrorReport(loc)
2819-
<< "Expected key type '" << key_val->type()->python_str()
2820-
<< "' to subtype the key type '"
2821-
<< dict_type->getKeyType()->python_str() << "' of the dict '"
2822-
<< dict_type->python_str() << "'";
2823-
}
2824-
2825-
return graph->insertNode(graph->createDictIndex(dict_val, key_val))
2826-
->output();
2827-
}
2828-
28292811
int64_t getSliceInd(Value* idx_val, const SourceRange& loc) {
28302812
auto ivalue = toIValue(idx_val);
28312813
if (ivalue && ivalue->isInt()) {
@@ -2864,60 +2846,30 @@ struct to_ir {
28642846
}
28652847

28662848
Value* emitSubscript(const Subscript& subscript) {
2867-
return emitSubscript(
2868-
subscript.range(),
2869-
emitExpr(subscript.value()),
2870-
subscript.subscript_exprs());
2871-
}
2872-
2873-
Value* emitSubscript(
2874-
const SourceRange& loc,
2875-
Value* sliceable,
2876-
const List<Expr>& subscript_exprs) {
2849+
const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1);
2850+
const List<Expr>& subscript_exprs = subscript.subscript_exprs();
2851+
const SourceRange& range = subscript.range();
2852+
const SourceRange& val_range = subscript.value().range();
28772853
if (subscript_exprs.size() != 1) {
2878-
return emitMultidimSlicing(loc, sliceable, subscript_exprs);
2854+
return emitMultidimSlicing(
2855+
range, sv->asValue(val_range, method), subscript_exprs);
28792856
}
28802857
if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
2881-
return emitBasicSlice(loc, sliceable, subscript_exprs);
2858+
return emitBasicSlice(
2859+
range, sv->asValue(val_range, method), subscript_exprs);
28822860
} else {
2883-
return emitBasicGather(loc, sliceable, subscript_exprs);
2884-
}
2885-
}
2886-
2887-
// Desugars gather syntactic sugar foo[i]
2888-
Value* emitBasicGather(
2889-
const SourceRange& loc,
2890-
Value* gatherable,
2891-
const List<Expr>& subscript_exprs) {
2892-
AT_ASSERT(subscript_exprs.size() == 1);
2893-
2894-
if (gatherable->type()->kind() == TypeKind::ListType) {
2895-
// if it's a list, emit a regular index selection op
2896-
auto* idx = emitExpr(subscript_exprs[0]);
2897-
return emitBuiltinCall(
2898-
loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
2899-
} else if (gatherable->type()->isSubtypeOf(TensorType::get())) {
2900-
return emitMultidimSlicing(loc, gatherable, subscript_exprs);
2901-
} else if (auto tuple_type = gatherable->type()->cast<TupleType>()) {
2902-
auto* idx = emitExpr(subscript_exprs[0]);
2903-
return emitTupleIndex(loc, gatherable, idx);
2904-
} else if (auto dict_type = gatherable->type()->cast<DictType>()) {
2905-
auto* idx = emitExpr(subscript_exprs[0]);
2906-
return emitDictIndex(loc, gatherable, idx);
2907-
} else if (auto string_type = gatherable->type()->cast<StringType>()) {
2908-
auto* idx = emitExpr(subscript_exprs[0]);
2909-
return emitBuiltinCall(
2910-
loc,
2911-
*graph,
2912-
prim::StringIndex,
2913-
c10::nullopt,
2914-
{gatherable, idx},
2915-
{},
2916-
true);
2917-
} else {
2918-
throw ErrorReport(loc) << "Indexing only supported on List, Dict, "
2919-
"Tensor, Tuple, and str but got type '"
2920-
<< gatherable->type()->python_str() << "'";
2861+
// Desugars gather syntactic sugar foo[i]
2862+
Value* idx = emitExpr(subscript_exprs[0]);
2863+
Value* val = sv->asValue(val_range, method);
2864+
AT_ASSERT(subscript_exprs.size() == 1);
2865+
2866+
if (val->type()->cast<TupleType>()) {
2867+
return emitTupleIndex(range, sv->asValue(val_range, method), idx);
2868+
} else if (val->type()->isSubtypeOf(TensorType::get())) {
2869+
return emitMultidimSlicing(range, val, subscript_exprs);
2870+
} else {
2871+
return sv->getitem(range, method, idx);
2872+
}
29212873
}
29222874
}
29232875
};

torch/csrc/jit/script/sugared_value.cpp

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -275,20 +275,29 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
275275
}
276276
}
277277

278-
Value* SimpleValue::getelem(const SourceRange&loc, Function& m, Value* i) {
278+
Value* SimpleValue::getitem(const SourceRange& loc, Function& m, Value* idx) {
279279
Value* val = getValue();
280280
TypePtr val_type = val->type();
281281
Graph& g = *m.graph();
282282
Value* cur_elem = nullptr;
283-
if (val_type->cast<ListType>()) {
284-
cur_elem = g.insert(aten::select, {val, i}, {}, loc);
285-
} else if (val_type->cast<StringType>()) {
286-
cur_elem = g.insert(prim::StringIndex, {val, i}, {}, loc);
283+
284+
// if it's a List/String/Dict, emit a regular __getitem__ op
285+
if (val_type->cast<ListType>() || val_type->cast<StringType>()) {
286+
cur_elem = g.insert(aten::__getitem__, {val, idx}, {}, loc);
287+
} else if (auto dict_type = val_type->cast<DictType>()) {
288+
if (!idx->type()->isSubtypeOf(dict_type->getKeyType())) {
289+
throw ErrorReport(loc)
290+
<< "Expected key type '" << idx->type()->python_str()
291+
<< "' to subtype the key type '"
292+
<< dict_type->getKeyType()->python_str() << "' of the dict '"
293+
<< dict_type->python_str() << "'";
294+
}
295+
cur_elem = g.insert(aten::__getitem__, {val, idx}, {}, loc);
287296
} else if (val_type->isSubtypeOf(TensorType::get())) {
288-
cur_elem = g.insert(aten::select, {val, 0, i}, {}, loc);
297+
cur_elem = g.insert(aten::select, {val, 0, idx}, {}, loc);
289298
} else {
290-
throw ErrorReport(loc)
291-
<< "cannot get element of the value type " << val_type->python_str();
299+
throw ErrorReport(loc) << "'" << val_type->python_str() << "'"
300+
<< " object is not subscriptable";
292301
}
293302
return cur_elem;
294303
}
@@ -327,12 +336,12 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) {
327336
}
328337
}
329338

330-
Value* RangeValue::getelem(const SourceRange&loc, Function& m, Value* i) {
339+
Value* RangeValue::getitem(const SourceRange& loc, Function& m, Value* idx) {
331340
if (has_only_end_) {
332-
return i;
341+
return idx;
333342
} else {
334343
auto& g = *m.graph();
335-
return g.insert(aten::__derive_index, {i, start_, step_}, {}, loc);
344+
return g.insert(aten::__derive_index, {idx, start_, step_}, {}, loc);
336345
}
337346
}
338347

@@ -368,12 +377,12 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) {
368377
return g.insert(prim::min, {list_node->output()}, {}, loc);
369378
}
370379

371-
Value* IterableTree::getelem(const SourceRange&loc, Function& m, Value* i) {
380+
Value* IterableTree::getitem(const SourceRange& loc, Function& m, Value* idx) {
372381
std::vector<Value*> child_items;
373382
for(const SugaredValuePtr& child: children_) {
374-
child_items.emplace_back(child->getelem(loc, m, i));
383+
child_items.emplace_back(child->getitem(loc, m, idx));
375384
}
376-
// If you call getelem() on a IterableTree sugared value, we will create Tuple
385+
// If you call getitem() on a IterableTree sugared value, we will create Tuple
377386
// from the children items, and make the Tuple value as the element
378387
Graph& g = *m.graph();
379388
return g.insertNode(g.createTuple(child_items))->output();

0 commit comments

Comments
 (0)