Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ namespace c10 {
_(aten, list) \
_(aten, wait) \
_(aten, save) \
_(aten, keys) \
_(aten, ord) \
_(aten, chr) \
_(aten, hex) \
Expand Down
22 changes: 22 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9384,6 +9384,28 @@ def test_list_strings(x):
self.checkScript(test_list_strings, (["hello", "world"],))
self.checkScript(test_list_strings, (["hello", " ", "world", ""],))

def test_for_in_dict(self):
def test_dicts(x):
# type: (Dict[str, int]) -> int
sum = 0
for key in x:
sum += x[key]
return sum

self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))

def test_dict_keys_values(x):
# type: (Dict[str, int]) -> Tuple[str, int]
key_str = ""
sum = 0
for key in x.keys():
key_str += key
for val in x.values():
sum += val
return key_str, sum

self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))

def test_for_tuple_unpack(self):
def for_tuple_unpack(x, y):
for i, j in [[3, 4], [5, 6], [7, 8]]:
Expand Down
7 changes: 6 additions & 1 deletion torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ struct to_ir {
}
// Emit loop information for builtinFunction values like range(), zip(),
// enumerate() or SimpleValue like List, Tensor, Dict, etc.
auto sv = emitSugaredExpr(itrs[0], 1);
SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1);

// We will get IterableTree for builtinFunctions zip() and enumerate(),
// RangeValue for range(), and SimpleValue for types like List/Tensor/Dict/String.
Expand All @@ -1352,6 +1352,11 @@ struct to_ir {
// For SimpleValue(except Tuple) or RanveValue/IterableTree, emit common loop
if ((siv && !siv->getValue()->type()->cast<TupleType>())
|| range_val || iterable_tree) {
// looping over a dict defaults to looping over the keys in python
if (siv && siv->getValue()->type()->cast<DictType>()) {
sv = std::make_shared<SimpleValue>(
graph->insert(aten::keys, {siv->getValue()}, {}, stmt.range()));
}
emitLoopCommon(stmt.range(), body, sv, targets, {});
return;
}
Expand Down