Skip to content

Commit 377a728

Browse files
committed
dictKeys and dictItems ops on typed dicts return typed lists
Differential Revision: [D16448942](https://our.internmc.facebook.com/intern/diff/D16448942/) ghstack-source-id: 87040601 Pull Request resolved: #23270
1 parent f112c52 commit 377a728

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

aten/src/ATen/core/Dict.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,12 @@ class Dict final {
363363
* having to reallocate or rehash.
364364
*/
365365
void reserve(size_type count);
366+
367+
368+
// private API for now because the return type will change to TypePtr
369+
// instead of optional<TypePtr> once types are mandatory.
370+
optional<TypePtr> _keyType() const;
371+
optional<TypePtr> _valueType() const;
366372
};
367373

368374
namespace impl {

aten/src/ATen/core/Dict_inl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,20 @@ void Dict<Key, Value>::reserve(size_type count) {
212212
impl_->dict.reserve(count);
213213
}
214214

215+
template<class Key, class Value>
216+
optional<TypePtr> Dict<Key, Value>::_keyType() const {
217+
if (!impl_->elementTypes.has_value()) {
218+
return c10::nullopt;
219+
}
220+
return impl_->elementTypes->keyType;
221+
}
222+
223+
template<class Key, class Value>
224+
optional<TypePtr> Dict<Key, Value>::_valueType() const {
225+
if (!impl_->elementTypes.has_value()) {
226+
return c10::nullopt;
227+
}
228+
return impl_->elementTypes->valueType;
229+
}
230+
215231
}

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,12 +1693,22 @@ int dictLen(Stack& stack) {
16931693

16941694
int dictKeys(Stack& stack) {
16951695
auto dict = pop(stack).toGenericDict();
1696-
auto keys = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
1697-
keys.reserve(dict.size());
1698-
for (auto& item : dict) {
1699-
keys.push_back(item.key());
1696+
auto addKeysFromDict = [&] (c10::impl::GenericList& keys) {
1697+
keys.reserve(dict.size());
1698+
for (auto& item : dict) {
1699+
keys.push_back(item.key());
1700+
}
1701+
};
1702+
auto key_type = dict._keyType();
1703+
if (key_type.has_value()) {
1704+
auto keys = c10::impl::GenericList(*key_type);
1705+
addKeysFromDict(keys);
1706+
push(stack, std::move(keys));
1707+
} else {
1708+
auto keys = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
1709+
addKeysFromDict(keys);
1710+
push(stack, std::move(keys));
17001711
}
1701-
push(stack, IValue(keys));
17021712
return 0;
17031713
}
17041714

@@ -1851,12 +1861,23 @@ int dictUpdate(Stack& stack) {
18511861

18521862
int dictItems(Stack& stack) {
18531863
auto dict = pop(stack).toGenericDict();
1854-
auto items = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
1855-
items.reserve(dict.size());
1856-
for (const auto& item : iterationOrder(dict)) {
1857-
items.emplace_back(c10::ivalue::Tuple::create({item.first, item.second}));
1864+
auto addItemsFromDict = [&] (c10::impl::GenericList& items) {
1865+
items.reserve(dict.size());
1866+
for (const auto& item : iterationOrder(dict)) {
1867+
items.emplace_back(c10::ivalue::Tuple::create({item.first, item.second}));
1868+
}
1869+
};
1870+
auto key_type = dict._keyType();
1871+
auto value_type = dict._valueType();
1872+
if (key_type.has_value() && value_type.has_value()) {
1873+
auto items = c10::impl::GenericList(TupleType::create({*key_type, *value_type}));
1874+
addItemsFromDict(items);
1875+
push(stack, std::move(items));
1876+
} else {
1877+
auto items = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
1878+
addItemsFromDict(items);
1879+
push(stack, std::move(items));
18581880
}
1859-
push(stack, std::move(items));
18601881
return 0;
18611882
}
18621883

0 commit comments

Comments
 (0)