@@ -1691,58 +1691,58 @@ int dictLen(Stack& stack) {
16911691 return 0 ;
16921692}
16931693
1694- int dictKeys (Stack& stack) {
1695- auto dict = pop (stack).toGenericDict ();
1696- auto keys = c10::impl::GenericList (c10::impl::deprecatedUntypedList ());
1697- keys.reserve (dict.size ());
1698- for (auto & item : dict) {
1699- keys.push_back (item.key ());
1700- }
1701- push (stack, IValue (keys));
1702- return 0 ;
1703- }
1704-
1705- template <typename Elem>
1706- c10::List<Elem> makeListForDictValues (
1694+ template <unsigned int Index, typename Elem>
1695+ c10::List<Elem> makeListForDictKeysOrValues (
17071696 const std::vector<std::pair<IValue, IValue>>& order) {
17081697 c10::List<Elem> values;
17091698 values.reserve (order.size ());
17101699 for (const auto & item : order) {
1711- values.push_back (item. second . to <Elem>());
1700+ values.push_back (std::get<Index>( item). template to <Elem>());
17121701 }
17131702 return values;
17141703}
17151704
1716- template <>
1717- c10::impl::GenericList makeListForDictValues<IValue> (
1705+ template <unsigned int Index >
1706+ c10::impl::GenericList makeGenericListForDictKeysOrValues (
17181707 const std::vector<std::pair<IValue, IValue>>& order) {
17191708 auto values = c10::impl::GenericList (c10::impl::deprecatedUntypedList ());
17201709 values.reserve (order.size ());
17211710 for (const auto & item : order) {
1722- values.push_back (item. second );
1711+ values.push_back (std::get<Index>( item) );
17231712 }
17241713 return values;
17251714}
17261715
1727- Operation dictValues (const Node* n) {
1716+ template <unsigned int Index>
1717+ Operation dictKeysOrValues (const Node* n) {
17281718 auto outputType = n->output ()->type ()->expect <ListType>();
17291719 return [=](Stack& stack) -> int {
17301720 const auto & order = iterationOrder (pop (stack).toGenericDict ());
17311721 if (outputType->getElementType ()->isSubtypeOf (TensorType::get ())) {
1732- push (stack, makeListForDictValues< at::Tensor>(order));
1722+ push (stack, makeListForDictKeysOrValues<Index, at::Tensor>(order));
17331723 } else if (outputType->getElementType () == IntType::get ()) {
1734- push (stack, makeListForDictValues< int64_t >(order));
1724+ push (stack, makeListForDictKeysOrValues<Index, int64_t >(order));
17351725 } else if (outputType->getElementType () == FloatType::get ()) {
1736- push (stack, makeListForDictValues< double >(order));
1726+ push (stack, makeListForDictKeysOrValues<Index, double >(order));
17371727 } else if (outputType->getElementType () == BoolType::get ()) {
1738- push (stack, makeListForDictValues< bool >(order));
1728+ push (stack, makeListForDictKeysOrValues<Index, bool >(order));
17391729 } else {
1740- push (stack, makeListForDictValues<IValue >(order));
1730+ push (stack, makeGenericListForDictKeysOrValues<Index >(order));
17411731 }
17421732 return 0 ;
17431733 };
17441734}
17451735
1736+ Operation dictKeys (const Node* n) {
1737+ // getting first dict pair
1738+ return dictKeysOrValues<0 >(n);
1739+ }
1740+
1741+ Operation dictValues (const Node* n) {
1742+ // getting second of dict pair
1743+ return dictKeysOrValues<1 >(n);
1744+ }
1745+
17461746int dictIndex (Stack& stack) {
17471747 auto key = pop (stack);
17481748 auto dict = pop (stack).toGenericDict ();
0 commit comments