Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9363,6 +9363,27 @@ def list_iterables(x):
return x
''')

def test_for_in_string(self):
def test_strings(x):
# type: (str) -> str
reverse = ""
for c in x:
reverse = c + reverse
return reverse

self.checkScript(test_strings, ("hello",))
self.checkScript(test_strings, ("",))

def test_list_strings(x):
# type: (List[str]) -> str
result = ""
for sub_str in x:
result += sub_str
return result

self.checkScript(test_list_strings, (["hello", "world"],))
self.checkScript(test_list_strings, (["hello", " ", "world", ""],))

def test_for_tuple_unpack(self):
def for_tuple_unpack(x, y):
for i, j in [[3, 4], [5, 6], [7, 8]]:
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ struct to_ir {
auto sv = emitSugaredExpr(itrs[0], 1);

// We will get IterableTree for builtinFunctions zip() and enumerate(),
// RangeValue for range(), and SimpleValue for types like List, Tensor, Dict.
// RangeValue for range(), and SimpleValue for types like List/Tensor/Dict/String.
auto range_val = std::dynamic_pointer_cast<RangeValue>(sv);
auto siv = std::dynamic_pointer_cast<SimpleValue>(sv);
auto iterable_tree = std::dynamic_pointer_cast<IterableTree>(sv);
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/jit/script/sugared_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
TypePtr val_type = val->type();
Graph& g = *m.graph();
if (val_type->cast<ListType>() ||
val_type->cast<StringType>() ||
val_type->isSubtypeOf(TensorType::get())) {
return g.insert(aten::len, {val}, {}, loc);
} else {
Expand All @@ -280,6 +281,8 @@ Value* SimpleValue::getelem(const SourceRange&loc, Function& m, Value* i) {
Value* cur_elem = nullptr;
if (val_type->cast<ListType>()) {
cur_elem = g.insert(aten::select, {val, i}, {}, loc);
} else if (val_type->cast<StringType>()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to standardize a builtin like "getelem" to be used here rather than have a big if branch that emits different code, similar to our other magic methods. Furthermore, we should make serialization emit this as foo[i] so that we do not end up getting committed to preserving these operator names.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this, will make a separate PR to define it. We will need to follow up on design the deprecation strategy of those prim operators that should be deleted.

cur_elem = g.insert(prim::StringIndex, {val, i}, {}, loc);
} else if (val_type->isSubtypeOf(TensorType::get())) {
cur_elem = g.insert(aten::select, {val, 0, i}, {}, loc);
} else {
Expand Down