Skip to content

Commit 2ed6607

Browse files
committed
[jit] register __getitem__ builtin
Summary: Follow up of #21990, I am switching index select operations to a standard __getitem__ builtin, rather than bunch of different builtins according to the type, such as prim::DictIndex, prim::ListIndex, etc. This will also aligned with the some other magic methods that we already use
1 parent 687a6ca commit 2ed6607

File tree

9 files changed

+145
-109
lines changed

9 files changed

+145
-109
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ namespace c10 {
132132
_(aten, ne_) \
133133
_(aten, transpose_) \
134134
_(aten, unsqueeze_) \
135+
_(aten, __getitem__) \
135136
_(aten, _set_item) \
136137
_(aten, manual_seed) \
137138
_(aten, set_) \
@@ -142,7 +143,7 @@ namespace c10 {
142143
_(aten, list) \
143144
_(aten, wait) \
144145
_(aten, save) \
145-
_(aten, keys) \
146+
_(aten, keys) \
146147
_(aten, ord) \
147148
_(aten, chr) \
148149
_(aten, hex) \
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
def empty_int_list_test(y: Tensor) -> int:
22
x = annotate(List[int], [])
3-
return torch.select(x, 0)
3+
return x[0]

test/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11612,7 +11612,7 @@ def foo(a):
1161211612
i += 1
1161311613
return b
1161411614

11615-
FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::select") \
11615+
FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \
1161611616
.check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
1161711617

1161811618
def test_mutable_dce_wildcards(self):

torch/csrc/jit/ir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ Node* Graph::createDictIndex(Value* dict, Value* index) {
13941394
auto dict_type = dict->type()->expect<DictType>();
13951395
AT_ASSERT(index->type()->isSubtypeOf(dict_type->getKeyType()));
13961396

1397-
auto n = create(prim::DictIndex, {dict, index});
1397+
auto n = create(aten::__getitem__, {dict, index});
13981398
n->output()->setType(dict_type->getValueType());
13991399
return n;
14001400
}

torch/csrc/jit/passes/python_print.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,13 @@ struct PythonPrintPass {
463463
stmt << end;
464464
}
465465

466+
void printValueIndex(std::ostream& stmt, at::ArrayRef<Value*> inputs) {
467+
stmt << useOf(inputs[0]);
468+
stmt << "[";
469+
stmt << useOf(inputs[1]);
470+
stmt << "]";
471+
}
472+
466473
void printDict(
467474
std::ostream& stmt,
468475
at::ArrayRef<Value*> key_value_pairs,
@@ -889,6 +896,9 @@ struct PythonPrintPass {
889896
case aten::Str: {
890897
printValueList(stmt, node->inputs(), "str(", ")");
891898
} break;
899+
case aten::__getitem__: {
900+
printValueIndex(stmt, node->inputs());
901+
} break;
892902
case prim::Print: {
893903
printValueList(stmt, node->inputs(), "print(", ")");
894904
} break;

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 108 additions & 87 deletions
Large diffs are not rendered by default.

torch/csrc/jit/script/compiler.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,7 +1312,7 @@ struct to_ir {
13121312

13131313
// if the FOR iters and targets are present, emit FOR target assignments
13141314
if (iter_val != nullptr && targets) {
1315-
Value* cur_elem = iter_val->getelem(range, method, trip_count);
1315+
Value* cur_elem = iter_val->getitem(range, method, trip_count);
13161316
SugaredValuePtr sv = std::make_shared<SimpleValue>(cur_elem);
13171317
List<Expr> target_exprs = targets.value();
13181318
validateAssignLhsExpr(target_exprs, range);
@@ -1644,7 +1644,7 @@ struct to_ir {
16441644
NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));
16451645

16461646
const auto getItem =
1647-
graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
1647+
graph->insert(aten::__getitem__, {listArg, idxArg}, {}, stmt.range());
16481648
const auto augmentedItem = graph->insert(
16491649
getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
16501650
graph->insert(
@@ -2880,7 +2880,13 @@ struct to_ir {
28802880
// if it's a list, emit a regular index selection op
28812881
auto* idx = emitExpr(subscript_exprs[0]);
28822882
return emitBuiltinCall(
2883-
loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
2883+
loc,
2884+
*graph,
2885+
aten::__getitem__,
2886+
c10::nullopt,
2887+
{gatherable, idx},
2888+
{},
2889+
true);
28842890
} else if (gatherable->type()->isSubtypeOf(TensorType::get())) {
28852891
return emitMultidimSlicing(loc, gatherable, subscript_exprs);
28862892
} else if (auto tuple_type = gatherable->type()->cast<TupleType>()) {
@@ -2894,7 +2900,7 @@ struct to_ir {
28942900
return emitBuiltinCall(
28952901
loc,
28962902
*graph,
2897-
prim::StringIndex,
2903+
aten::__getitem__,
28982904
c10::nullopt,
28992905
{gatherable, idx},
29002906
{},

torch/csrc/jit/script/sugared_value.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -274,15 +274,13 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
274274
}
275275
}
276276

277-
Value* SimpleValue::getelem(const SourceRange&loc, Function& m, Value* i) {
277+
Value* SimpleValue::getitem(const SourceRange&loc, Function& m, Value* i) {
278278
Value* val = getValue();
279279
TypePtr val_type = val->type();
280280
Graph& g = *m.graph();
281281
Value* cur_elem = nullptr;
282-
if (val_type->cast<ListType>()) {
283-
cur_elem = g.insert(aten::select, {val, i}, {}, loc);
284-
} else if (val_type->cast<StringType>()) {
285-
cur_elem = g.insert(prim::StringIndex, {val, i}, {}, loc);
282+
if (val_type->cast<ListType>() || val_type->cast<StringType>()) {
283+
cur_elem = g.insert(aten::__getitem__, {val, i}, {}, loc);
286284
} else if (val_type->isSubtypeOf(TensorType::get())) {
287285
cur_elem = g.insert(aten::select, {val, 0, i}, {}, loc);
288286
} else {
@@ -326,7 +324,7 @@ Value* RangeValue::len(const SourceRange& loc, Function& m) {
326324
}
327325
}
328326

329-
Value* RangeValue::getelem(const SourceRange&loc, Function& m, Value* i) {
327+
Value* RangeValue::getitem(const SourceRange&loc, Function& m, Value* i) {
330328
if (has_only_end_) {
331329
return i;
332330
} else {
@@ -367,12 +365,12 @@ Value* IterableTree::len(const SourceRange& loc, Function& m) {
367365
return g.insert(prim::min, {list_node->output()}, {}, loc);
368366
}
369367

370-
Value* IterableTree::getelem(const SourceRange&loc, Function& m, Value* i) {
368+
Value* IterableTree::getitem(const SourceRange&loc, Function& m, Value* i) {
371369
std::vector<Value*> child_items;
372370
for(const SugaredValuePtr& child: children_) {
373-
child_items.emplace_back(child->getelem(loc, m, i));
371+
child_items.emplace_back(child->getitem(loc, m, i));
374372
}
375-
// If you call getelem() on a IterableTree sugared value, we will create Tuple
373+
// If you call getitem() on a IterableTree sugared value, we will create Tuple
376374
// from the children items, and make the Tuple value as the element
377375
Graph& g = *m.graph();
378376
return g.insertNode(g.createTuple(child_items))->output();

torch/csrc/jit/script/sugared_value.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ struct TORCH_API SugaredValue
100100
<< " object is not iterable";
101101
}
102102
// expression for ith elemement for iterable value
103-
virtual Value* getelem(const SourceRange&loc, Function& m, Value* i) {
103+
virtual Value* getitem(const SourceRange&loc, Function& m, Value* i) {
104104
throw ErrorReport(loc) << " cannot get the element of value " << kind();
105105
}
106106

@@ -153,7 +153,7 @@ struct TORCH_API SimpleValue : public SugaredValue {
153153
}
154154

155155
Value* len(const SourceRange& loc, Function& m) override;
156-
Value* getelem(const SourceRange&loc, Function& m, Value* i) override;
156+
Value* getitem(const SourceRange&loc, Function& m, Value* i) override;
157157

158158
private:
159159
Value* value_;
@@ -444,7 +444,7 @@ struct TORCH_API RangeValue : SugaredValue {
444444
return "range";
445445
}
446446
Value* len(const SourceRange& loc, Function& m) override;
447-
Value* getelem(const SourceRange&loc, Function& m, Value* i) override;
447+
Value* getitem(const SourceRange&loc, Function& m, Value* i) override;
448448

449449
private:
450450
Value* start_;
@@ -492,11 +492,11 @@ struct TORCH_API IterableTree : SugaredValue {
492492
// given a IterableTree node, get all the base iterables/leaves under the
493493
// IterableTree node, which are either SimpleValue or RangeValue. This enable
494494
// us to get all the basic SugaredValues that contains valid loop information
495-
// with len() and getelem()
495+
// with len() and getitem()
496496
std::vector<SugaredValuePtr> get_base_iterables();
497497

498498
Value* len(const SourceRange& loc, Function& m) override;
499-
Value* getelem(const SourceRange&loc, Function& m, Value* i) override;
499+
Value* getitem(const SourceRange&loc, Function& m, Value* i) override;
500500

501501
private:
502502
std::vector<SugaredValuePtr> children_;

0 commit comments

Comments
 (0)