@@ -1632,7 +1632,15 @@ struct to_ir {
16321632 // Get the appropriate builtin op for this augmented assignment
16331633 // If the RHS is a tensor, return the corresponding ATen in-place op
16341634 // If it's a list of scalars, then return the corresponding list augment op
1635- Symbol getAugOp (const AugAssign& stmt, bool isTensor) {
1635+ Symbol getAugOp (const AugAssign& stmt, TypePtr type) {
1636+ if (type->cast <ListType>()) { // Lists also have in-place ops.
1637+ std::cout<<" generating list op" <<std::endl;
1638+ switch (stmt.aug_op ()) {
1639+ case ' +' :
1640+ return aten::add_;
1641+ }
1642+ }
1643+ bool isTensor = type->isSubtypeOf (TensorType::get ());
16361644 switch (stmt.aug_op ()) {
16371645 case ' +' :
16381646 return isTensor ? aten::add_ : aten::add;
@@ -1696,7 +1704,7 @@ struct to_ir {
16961704 emitBuiltinCall (
16971705 stmt.range (),
16981706 *method.graph (),
1699- getAugOp (stmt, /* isTensor= */ true ),
1707+ getAugOp (stmt, lhsValue-> type () ),
17001708 self,
17011709 {rhs},
17021710 {},
@@ -1713,14 +1721,15 @@ struct to_ir {
17131721 const auto lhs = Var (stmt.lhs ());
17141722 const auto lhsValue = environment_stack->getSugaredVar (lhs.name ())
17151723 ->asValue (lhs.range (), method);
1716- if (lhsValue->type ()->isSubtypeOf (TensorType::get ())) {
1724+ auto lhsType = lhsValue->type ();
1725+ if (lhsType->isSubtypeOf (TensorType::get ()) || lhsType->cast <c10::ListType>()) {
17171726 // for tensors, emit the corresponding in-place op
17181727 const auto rhs = NamedValue (stmt.rhs ().range (), emitExpr (stmt.rhs ()));
17191728 const auto self = NamedValue (stmt.lhs ().range (), " self" , lhsValue);
17201729 const auto output = emitBuiltinCall (
17211730 stmt.range (),
17221731 *method.graph (),
1723- getAugOp (stmt, /* isTensor= */ true ),
1732+ getAugOp (stmt, lhsValue-> type () ),
17241733 self,
17251734 {rhs},
17261735 {},
@@ -1761,7 +1770,7 @@ struct to_ir {
17611770 emitBuiltinCall (
17621771 stmt.range (),
17631772 *method.graph (),
1764- getAugOp (stmt, /* isTensor= */ true ),
1773+ getAugOp (stmt, sliceable-> type () ),
17651774 slicedArg,
17661775 {rhs},
17671776 {},
@@ -1778,7 +1787,7 @@ struct to_ir {
17781787 const auto augmented = emitBuiltinCall (
17791788 stmt.range (),
17801789 *method.graph (),
1781- getAugOp (stmt, /* isTensor= */ true ),
1790+ getAugOp (stmt, sliceable-> type () ),
17821791 indexed,
17831792 {rhs},
17841793 {},
@@ -1796,8 +1805,7 @@ struct to_ir {
17961805 const auto listType = sliceable->type ()->cast <ListType>();
17971806 AT_ASSERT (listType != nullptr );
17981807
1799- bool isTensorList =
1800- listType->getElementType ()->isSubtypeOf (TensorType::get ());
1808+ auto elementType = listType->getElementType ();
18011809
18021810 // Get the idx to augment
18031811 const auto subscriptExprs = lhs.subscript_exprs ();
@@ -1817,7 +1825,7 @@ struct to_ir {
18171825 const auto getItem =
18181826 graph->insert (aten::select, {listArg, idxArg}, {}, stmt.range ());
18191827 const auto augmentedItem = graph->insert (
1820- getAugOp (stmt, isTensorList ), {getItem, valueArg}, {}, stmt.range ());
1828+ getAugOp (stmt, elementType ), {getItem, valueArg}, {}, stmt.range ());
18211829 graph->insert (
18221830 aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range ());
18231831 }
0 commit comments