Skip to content

Commit 0c0ffcc

Browse files
smessmerfacebook-github-bot
authored andcommitted
deepCopy also copies type information of lists (#23271)
Summary: Pull Request resolved: #23271 ghstack-source-id: 87088503 Differential Revision: D16449220 fbshipit-source-id: 551b7cef8f6d0d2d5a56b24ddbe2e0bb2c0c3dbe
1 parent 895e79a commit 0c0ffcc

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

aten/src/ATen/core/List.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ class List final {
435435
// TODO Test use_count
436436
size_t use_count() const;
437437

438+
// private API for now because the return type will change to TypePtr
439+
// instead of optional<TypePtr> once types are mandatory.
440+
optional<TypePtr> _elementType() const;
441+
438442
private:
439443
explicit List(c10::intrusive_ptr<detail::ListImpl<StorageT>>&& elements);
440444
friend struct IValue;

aten/src/ATen/core/List_inl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,4 +289,9 @@ size_t List<T>::use_count() const {
289289
return impl_.use_count();
290290
}
291291

292+
template<class T>
293+
optional<TypePtr> List<T>::_elementType() const {
294+
return impl_->elementType;
295+
}
296+
292297
}

torch/csrc/jit/passes/utils/check_alias_annotation.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ IValue deepCopy(const IValue& self) {
2525

2626
// Lists of ivalues should recursively deep copy their contents
2727
if (self.isGenericList()) {
28-
auto newList = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
29-
newList.reserve(self.toGenericListRef().size());
30-
for (const IValue& value : self.toGenericListRef()) {
28+
auto source = std::move(self).toGenericList();
29+
auto newList = source._elementType().has_value()
30+
? c10::impl::GenericList(*source._elementType())
31+
: c10::impl::GenericList(c10::impl::deprecatedUntypedList());
32+
newList.reserve(source.size());
33+
for (const IValue& value : source) {
3134
newList.push_back(deepCopy(value));
3235
}
3336
return newList;

0 commit comments

Comments
 (0)