Skip to content

Commit 1de44a6

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
fix specialized list from dict keys (#23267)
Summary: Previously we weren't specializing the list returned from `dict.keys()` Pull Request resolved: #23267 Differential Revision: D16448512 Pulled By: eellison fbshipit-source-id: fcd2a37ac680bdf90219b099a94aa36a80f4067c
1 parent a6ccd62 commit 1de44a6

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

test/test_jit.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16453,6 +16453,14 @@ def keys(x):
1645316453

1645416454
self.assertEqual(set(keys(self.dict())), set(self.dict().keys()))
1645516455

16456+
@torch.jit.script
16457+
def specialized_list():
16458+
li = {1: 1, 2: 2}.keys()
16459+
li.append(3)
16460+
return li
16461+
16462+
self.assertTrue(set(specialized_list()) == set([1, 2, 3]))
16463+
1645616464
def test_values(self):
1645716465
@torch.jit.script
1645816466
def values(x):

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,58 +1691,58 @@ int dictLen(Stack& stack) {
16911691
return 0;
16921692
}
16931693

1694-
int dictKeys(Stack& stack) {
1695-
auto dict = pop(stack).toGenericDict();
1696-
auto keys = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
1697-
keys.reserve(dict.size());
1698-
for (auto& item : dict) {
1699-
keys.push_back(item.key());
1700-
}
1701-
push(stack, IValue(keys));
1702-
return 0;
1703-
}
1704-
1705-
template <typename Elem>
1706-
c10::List<Elem> makeListForDictValues(
1694+
template <unsigned int Index, typename Elem>
1695+
c10::List<Elem> makeListForDictKeysOrValues(
17071696
const std::vector<std::pair<IValue, IValue>>& order) {
17081697
c10::List<Elem> values;
17091698
values.reserve(order.size());
17101699
for (const auto& item : order) {
1711-
values.push_back(item.second.to<Elem>());
1700+
values.push_back(std::get<Index>(item).template to<Elem>());
17121701
}
17131702
return values;
17141703
}
17151704

1716-
template <>
1717-
c10::impl::GenericList makeListForDictValues<IValue>(
1705+
template <unsigned int Index>
1706+
c10::impl::GenericList makeGenericListForDictKeysOrValues(
17181707
const std::vector<std::pair<IValue, IValue>>& order) {
17191708
auto values = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
17201709
values.reserve(order.size());
17211710
for (const auto& item : order) {
1722-
values.push_back(item.second);
1711+
values.push_back(std::get<Index>(item));
17231712
}
17241713
return values;
17251714
}
17261715

1727-
Operation dictValues(const Node* n) {
1716+
template <unsigned int Index>
1717+
Operation dictKeysOrValues(const Node* n) {
17281718
auto outputType = n->output()->type()->expect<ListType>();
17291719
return [=](Stack& stack) -> int {
17301720
const auto& order = iterationOrder(pop(stack).toGenericDict());
17311721
if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
1732-
push(stack, makeListForDictValues<at::Tensor>(order));
1722+
push(stack, makeListForDictKeysOrValues<Index, at::Tensor>(order));
17331723
} else if (outputType->getElementType() == IntType::get()) {
1734-
push(stack, makeListForDictValues<int64_t>(order));
1724+
push(stack, makeListForDictKeysOrValues<Index, int64_t>(order));
17351725
} else if (outputType->getElementType() == FloatType::get()) {
1736-
push(stack, makeListForDictValues<double>(order));
1726+
push(stack, makeListForDictKeysOrValues<Index, double>(order));
17371727
} else if (outputType->getElementType() == BoolType::get()) {
1738-
push(stack, makeListForDictValues<bool>(order));
1728+
push(stack, makeListForDictKeysOrValues<Index, bool>(order));
17391729
} else {
1740-
push(stack, makeListForDictValues<IValue>(order));
1730+
push(stack, makeGenericListForDictKeysOrValues<Index>(order));
17411731
}
17421732
return 0;
17431733
};
17441734
}
17451735

1736+
Operation dictKeys(const Node* n) {
1737+
// getting first dict pair
1738+
return dictKeysOrValues<0>(n);
1739+
}
1740+
1741+
Operation dictValues(const Node* n) {
1742+
// getting second of dict pair
1743+
return dictKeysOrValues<1>(n);
1744+
}
1745+
17461746
int dictIndex(Stack& stack) {
17471747
auto key = pop(stack);
17481748
auto dict = pop(stack).toGenericDict();

0 commit comments

Comments
 (0)