Skip to content

Commit fab06a5

Browse files
committed
Made a += b for lists do an in place add
1 parent b858f42 commit fab06a5

File tree

4 files changed

+51
-9
lines changed

4 files changed

+51
-9
lines changed

aten/src/ATen/core/List.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,12 @@ class ListPtr final {
349349
*/
350350
void push_back(T&& value) const;
351351

352+
/**
353+
* Appends the given list to the end of the container. Uses at most one memory allocation.
354+
* May invalidate any references, pointers, or iterators referring to contained elements. Any past-the-end iterators may also be invalidated.
355+
*/
356+
void append(ListPtr<T> lst) const;
357+
352358
/**
353359
* Appends the given element value to the end of the container.
354360
* The new element is constructed with the given arguments.

aten/src/ATen/core/List_inl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,21 @@ void ListPtr<T>::push_back(T&& value) const {
178178
impl_->list.push_back(detail::list_element_from<T, StorageT>(std::move(value)));
179179
}
180180

181+
template<class T>
182+
void ListPtr<T>::append(ListPtr<T> b) const {
183+
size_type neededSize = this->size() + b.size();
184+
if (impl_->list.capacity() < neededSize) {
185+
this->reserve(std::max(impl_->list.capacity() * 2, neededSize));
186+
}
187+
if (b.use_count() == 1) {
188+
std::move(b.begin(), b.end(), this->begin());
189+
} else {
190+
for (const auto& el: b) {
191+
this->push_back(el);
192+
}
193+
}
194+
}
195+
181196
template<class T>
182197
template<class... Args>
183198
void ListPtr<T>::emplace_back(Args&&... args) const {

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,15 @@ int listAdd(Stack& stack) {
15541554
push(stack, std::move(ret));
15551555
return 0;
15561556
}
1557+
template <class T>
1558+
int listInplaceAdd(Stack& stack) {
1559+
c10::ListPtr<T> a = c10::make_list<T>();
1560+
c10::ListPtr<T> b = c10::make_list<T>();
1561+
pop(stack, a, b);
1562+
a.append(b);
1563+
push(stack, std::move(a));
1564+
return 0;
1565+
}
15571566

15581567
template <class T>
15591568
int listMulIntLeft(Stack& stack) {
@@ -1928,6 +1937,10 @@ RegisterOperators reg2({
19281937
"aten::add(" decl_type "[] a, " decl_type "[] b) -> " decl_type \
19291938
"[]", \
19301939
listAdd<c_type::value_type>), \
1940+
Operator( \
1941+
"aten::add_(" decl_type "[](a!) self, " decl_type "[] b) -> " decl_type \
1942+
"[]", \
1943+
listInplaceAdd<c_type::value_type>), \
19311944
Operator( \
19321945
"aten::slice(" decl_type \
19331946
"[] l, int start, int end=9223372036854775807, int step=1) -> " decl_type \

torch/csrc/jit/script/compiler.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,7 +1632,15 @@ struct to_ir {
16321632
// Get the appropriate builtin op for this augmented assignment
16331633
// If the RHS is a tensor, return the corresponding ATen in-place op
16341634
// If it's a list of scalars, then return the corresponding list augment op
1635-
Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
1635+
Symbol getAugOp(const AugAssign& stmt, TypePtr type) {
1636+
if (type->cast<ListType>()) { // Lists also have in-place ops.
1637+
std::cout<<"generating list op"<<std::endl;
1638+
switch (stmt.aug_op()) {
1639+
case '+':
1640+
return aten::add_;
1641+
}
1642+
}
1643+
bool isTensor = type->isSubtypeOf(TensorType::get());
16361644
switch (stmt.aug_op()) {
16371645
case '+':
16381646
return isTensor ? aten::add_ : aten::add;
@@ -1696,7 +1704,7 @@ struct to_ir {
16961704
emitBuiltinCall(
16971705
stmt.range(),
16981706
*method.graph(),
1699-
getAugOp(stmt, /*isTensor=*/true),
1707+
getAugOp(stmt, lhsValue->type()),
17001708
self,
17011709
{rhs},
17021710
{},
@@ -1713,14 +1721,15 @@ struct to_ir {
17131721
const auto lhs = Var(stmt.lhs());
17141722
const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
17151723
->asValue(lhs.range(), method);
1716-
if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
1724+
auto lhsType = lhsValue->type();
1725+
if (lhsType->isSubtypeOf(TensorType::get()) || lhsType->cast<c10::ListType>()) {
17171726
// for tensors, emit the corresponding in-place op
17181727
const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
17191728
const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
17201729
const auto output = emitBuiltinCall(
17211730
stmt.range(),
17221731
*method.graph(),
1723-
getAugOp(stmt, /*isTensor=*/true),
1732+
getAugOp(stmt, lhsValue->type()),
17241733
self,
17251734
{rhs},
17261735
{},
@@ -1761,7 +1770,7 @@ struct to_ir {
17611770
emitBuiltinCall(
17621771
stmt.range(),
17631772
*method.graph(),
1764-
getAugOp(stmt, /*isTensor=*/true),
1773+
getAugOp(stmt, sliceable->type()),
17651774
slicedArg,
17661775
{rhs},
17671776
{},
@@ -1778,7 +1787,7 @@ struct to_ir {
17781787
const auto augmented = emitBuiltinCall(
17791788
stmt.range(),
17801789
*method.graph(),
1781-
getAugOp(stmt, /*isTensor=*/true),
1790+
getAugOp(stmt, sliceable->type()),
17821791
indexed,
17831792
{rhs},
17841793
{},
@@ -1796,8 +1805,7 @@ struct to_ir {
17961805
const auto listType = sliceable->type()->cast<ListType>();
17971806
AT_ASSERT(listType != nullptr);
17981807

1799-
bool isTensorList =
1800-
listType->getElementType()->isSubtypeOf(TensorType::get());
1808+
auto elementType = listType->getElementType();
18011809

18021810
// Get the idx to augment
18031811
const auto subscriptExprs = lhs.subscript_exprs();
@@ -1817,7 +1825,7 @@ struct to_ir {
18171825
const auto getItem =
18181826
graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
18191827
const auto augmentedItem = graph->insert(
1820-
getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
1828+
getAugOp(stmt, elementType), {getItem, valueArg}, {}, stmt.range());
18211829
graph->insert(
18221830
aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range());
18231831
}

0 commit comments

Comments
 (0)