Skip to content

Commit d96ce9b

Browse files
wanchaolfacebook-github-bot
authored andcommitted
add for in dict support (#22006)
Summary: Pull Request resolved: #22006 ghimport-source-id: d9686c0 Test Plan: Imported from OSS Differential Revision: D15948548 Pulled By: wanchaol fbshipit-source-id: 4227502ca050099085ad481aef725ac2cab06d74
1 parent c9344fc commit d96ce9b

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ namespace c10 {
142142
_(aten, list) \
143143
_(aten, wait) \
144144
_(aten, save) \
145+
_(aten, keys) \
145146
_(aten, ord) \
146147
_(aten, chr) \
147148
_(aten, hex) \

test/test_jit.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9384,6 +9384,28 @@ def test_list_strings(x):
93849384
self.checkScript(test_list_strings, (["hello", "world"],))
93859385
self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
93869386

9387+
def test_for_in_dict(self):
9388+
def test_dicts(x):
9389+
# type: (Dict[str, int]) -> int
9390+
sum = 0
9391+
for key in x:
9392+
sum += x[key]
9393+
return sum
9394+
9395+
self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
9396+
9397+
def test_dict_keys_values(x):
9398+
# type: (Dict[str, int]) -> Tuple[str, int]
9399+
key_str = ""
9400+
sum = 0
9401+
for key in x.keys():
9402+
key_str += key
9403+
for val in x.values():
9404+
sum += val
9405+
return key_str, sum
9406+
9407+
self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
9408+
93879409
def test_for_tuple_unpack(self):
93889410
def for_tuple_unpack(x, y):
93899411
for i, j in [[3, 4], [5, 6], [7, 8]]:

torch/csrc/jit/script/compiler.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,7 @@ struct to_ir {
13411341
}
13421342
// Emit loop information for builtinFunction values like range(), zip(),
13431343
// enumerate() or SimpleValue like List, Tensor, Dict, etc.
1344-
auto sv = emitSugaredExpr(itrs[0], 1);
1344+
SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1);
13451345

13461346
// We will get IterableTree for builtinFunctions zip() and enumerate(),
13471347
// RangeValue for range(), and SimpleValue for types like List/Tensor/Dict/String.
@@ -1352,6 +1352,11 @@ struct to_ir {
13521352
// For SimpleValue(except Tuple) or RanveValue/IterableTree, emit common loop
13531353
if ((siv && !siv->getValue()->type()->cast<TupleType>())
13541354
|| range_val || iterable_tree) {
1355+
// looping over a dict defaults to looping over the keys in python
1356+
if (siv && siv->getValue()->type()->cast<DictType>()) {
1357+
sv = std::make_shared<SimpleValue>(
1358+
graph->insert(aten::keys, {siv->getValue()}, {}, stmt.range()));
1359+
}
13551360
emitLoopCommon(stmt.range(), body, sv, targets, {});
13561361
return;
13571362
}

0 commit comments

Comments
 (0)