@@ -1312,7 +1312,7 @@ struct to_ir {
13121312
13131313 // if the FOR iters and targets are present, emit FOR target assignments
13141314 if (iter_val != nullptr && targets) {
1315- Value* cur_elem = iter_val->getelem (range, method, trip_count);
1315+ Value* cur_elem = iter_val->getitem (range, method, trip_count);
13161316 SugaredValuePtr sv = std::make_shared<SimpleValue>(cur_elem);
13171317 List<Expr> target_exprs = targets.value ();
13181318 validateAssignLhsExpr (target_exprs, range);
@@ -1651,7 +1651,7 @@ struct to_ir {
16511651 NamedValue (stmt.rhs ().range (), " value" , emitExpr (stmt.rhs ()));
16521652
16531653 const auto getItem =
1654- graph->insert (aten::select , {listArg, idxArg}, {}, stmt.range ());
1654+ graph->insert (aten::__getitem__ , {listArg, idxArg}, {}, stmt.range ());
16551655 const auto augmentedItem = graph->insert (
16561656 getAugOp (stmt, elementType), {getItem, valueArg}, {}, stmt.range ());
16571657 graph->insert (
@@ -2808,24 +2808,6 @@ struct to_ir {
28082808 ->output ();
28092809 }
28102810
2811- Value* emitDictIndex (
2812- const SourceRange& loc,
2813- Value* dict_val,
2814- Value* key_val) {
2815- auto dict_type = dict_val->type ()->cast <DictType>();
2816-
2817- if (!key_val->type ()->isSubtypeOf (dict_type->getKeyType ())) {
2818- throw ErrorReport (loc)
2819- << " Expected key type '" << key_val->type ()->python_str ()
2820- << " ' to subtype the key type '"
2821- << dict_type->getKeyType ()->python_str () << " ' of the dict '"
2822- << dict_type->python_str () << " '" ;
2823- }
2824-
2825- return graph->insertNode (graph->createDictIndex (dict_val, key_val))
2826- ->output ();
2827- }
2828-
28292811 int64_t getSliceInd (Value* idx_val, const SourceRange& loc) {
28302812 auto ivalue = toIValue (idx_val);
28312813 if (ivalue && ivalue->isInt ()) {
@@ -2864,60 +2846,30 @@ struct to_ir {
28642846 }
28652847
28662848 Value* emitSubscript (const Subscript& subscript) {
2867- return emitSubscript (
2868- subscript.range (),
2869- emitExpr (subscript.value ()),
2870- subscript.subscript_exprs ());
2871- }
2872-
2873- Value* emitSubscript (
2874- const SourceRange& loc,
2875- Value* sliceable,
2876- const List<Expr>& subscript_exprs) {
2849+ const SugaredValuePtr sv = emitSugaredExpr (subscript.value (), 1 );
2850+ const List<Expr>& subscript_exprs = subscript.subscript_exprs ();
2851+ const SourceRange& range = subscript.range ();
2852+ const SourceRange& val_range = subscript.value ().range ();
28772853 if (subscript_exprs.size () != 1 ) {
2878- return emitMultidimSlicing (loc, sliceable, subscript_exprs);
2854+ return emitMultidimSlicing (
2855+ range, sv->asValue (val_range, method), subscript_exprs);
28792856 }
28802857 if (subscript_exprs[0 ].kind () == TK_SLICE_EXPR) {
2881- return emitBasicSlice (loc, sliceable, subscript_exprs);
2858+ return emitBasicSlice (
2859+ range, sv->asValue (val_range, method), subscript_exprs);
28822860 } else {
2883- return emitBasicGather (loc, sliceable, subscript_exprs);
2884- }
2885- }
2886-
2887- // Desugars gather syntactic sugar foo[i]
2888- Value* emitBasicGather (
2889- const SourceRange& loc,
2890- Value* gatherable,
2891- const List<Expr>& subscript_exprs) {
2892- AT_ASSERT (subscript_exprs.size () == 1 );
2893-
2894- if (gatherable->type ()->kind () == TypeKind::ListType) {
2895- // if it's a list, emit a regular index selection op
2896- auto * idx = emitExpr (subscript_exprs[0 ]);
2897- return emitBuiltinCall (
2898- loc, *graph, aten::select, c10::nullopt , {gatherable, idx}, {}, true );
2899- } else if (gatherable->type ()->isSubtypeOf (TensorType::get ())) {
2900- return emitMultidimSlicing (loc, gatherable, subscript_exprs);
2901- } else if (auto tuple_type = gatherable->type ()->cast <TupleType>()) {
2902- auto * idx = emitExpr (subscript_exprs[0 ]);
2903- return emitTupleIndex (loc, gatherable, idx);
2904- } else if (auto dict_type = gatherable->type ()->cast <DictType>()) {
2905- auto * idx = emitExpr (subscript_exprs[0 ]);
2906- return emitDictIndex (loc, gatherable, idx);
2907- } else if (auto string_type = gatherable->type ()->cast <StringType>()) {
2908- auto * idx = emitExpr (subscript_exprs[0 ]);
2909- return emitBuiltinCall (
2910- loc,
2911- *graph,
2912- prim::StringIndex,
2913- c10::nullopt ,
2914- {gatherable, idx},
2915- {},
2916- true );
2917- } else {
2918- throw ErrorReport (loc) << " Indexing only supported on List, Dict, "
2919- " Tensor, Tuple, and str but got type '"
2920- << gatherable->type ()->python_str () << " '" ;
2861+ // Desugars gather syntactic sugar foo[i]
2862+ Value* idx = emitExpr (subscript_exprs[0 ]);
2863+ Value* val = sv->asValue (val_range, method);
2864+ AT_ASSERT (subscript_exprs.size () == 1 );
2865+
2866+ if (val->type ()->cast <TupleType>()) {
2867+ return emitTupleIndex (range, sv->asValue (val_range, method), idx);
2868+ } else if (val->type ()->isSubtypeOf (TensorType::get ())) {
2869+ return emitMultidimSlicing (range, val, subscript_exprs);
2870+ } else {
2871+ return sv->getitem (range, method, idx);
2872+ }
29212873 }
29222874 }
29232875};
0 commit comments