Skip to content

Commit 82e8cae

Browse files
committed
deepCopy also copies type information of lists
Differential Revision: [D16449220](https://our.internmc.facebook.com/intern/diff/D16449220/) ghstack-source-id: 87041700 Pull Request resolved: #23271
1 parent f112c52 commit 82e8cae

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

aten/src/ATen/core/List.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ class List final {
435435
// TODO Test use_count
436436
size_t use_count() const;
437437

438+
// private API for now because the return type will change to TypePtr
439+
// instead of optional<TypePtr> once types are mandatory.
440+
optional<TypePtr> _elementType() const;
441+
438442
private:
439443
explicit List(c10::intrusive_ptr<detail::ListImpl<StorageT>>&& elements);
440444
friend struct IValue;

aten/src/ATen/core/List_inl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,4 +289,9 @@ size_t List<T>::use_count() const {
289289
return impl_.use_count();
290290
}
291291

292+
template<class T>
293+
optional<TypePtr> List<T>::_elementType() const {
294+
return impl_->elementType;
295+
}
296+
292297
}

torch/csrc/jit/passes/utils/check_alias_annotation.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,22 @@ IValue deepCopy(const IValue& self) {
2525

2626
// Lists of ivalues should recursively deep copy their contents
2727
if (self.isGenericList()) {
28-
auto newList = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
29-
newList.reserve(self.toGenericListRef().size());
30-
for (const IValue& value : self.toGenericListRef()) {
31-
newList.push_back(deepCopy(value));
28+
auto source = std::move(self).toGenericList();
29+
auto deepCopyGenericList = [&] (c10::impl::GenericList& dest) {
30+
dest.reserve(source.size());
31+
for (const IValue& value : source) {
32+
dest.push_back(deepCopy(value));
33+
}
34+
};
35+
if (source._elementType().has_value()) {
36+
auto newList = c10::impl::GenericList(*source._elementType());
37+
deepCopyGenericList(newList);
38+
return newList;
39+
} else {
40+
auto newList = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
41+
deepCopyGenericList(newList);
42+
return newList;
3243
}
33-
return newList;
3444
}
3545

3646
// Regular lists can copy assign

0 commit comments

Comments
 (0)