Skip to content

Commit c9344fc

Browse files
wanchaolfacebook-github-bot
authored andcommitted
add for in string support (#21990)
Summary: Pull Request resolved: #21990 ghimport-source-id: 69b4882 Test Plan: Imported from OSS Differential Revision: D15948547 Pulled By: wanchaol fbshipit-source-id: 057e7f4fb67c6dca98458ceb14414368e1a86260
1 parent eab3575 commit c9344fc

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

test/test_jit.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9363,6 +9363,27 @@ def list_iterables(x):
93639363
return x
93649364
''')
93659365

9366+
def test_for_in_string(self):
9367+
def test_strings(x):
9368+
# type: (str) -> str
9369+
reverse = ""
9370+
for c in x:
9371+
reverse = c + reverse
9372+
return reverse
9373+
9374+
self.checkScript(test_strings, ("hello",))
9375+
self.checkScript(test_strings, ("",))
9376+
9377+
def test_list_strings(x):
9378+
# type: (List[str]) -> str
9379+
result = ""
9380+
for sub_str in x:
9381+
result += sub_str
9382+
return result
9383+
9384+
self.checkScript(test_list_strings, (["hello", "world"],))
9385+
self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
9386+
93669387
def test_for_tuple_unpack(self):
93679388
def for_tuple_unpack(x, y):
93689389
for i, j in [[3, 4], [5, 6], [7, 8]]:

torch/csrc/jit/script/compiler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,7 @@ struct to_ir {
13441344
auto sv = emitSugaredExpr(itrs[0], 1);
13451345

13461346
// We will get IterableTree for builtinFunctions zip() and enumerate(),
1347-
// RangeValue for range(), and SimpleValue for types like List, Tensor, Dict.
1347+
// RangeValue for range(), and SimpleValue for types like List/Tensor/Dict/String.
13481348
auto range_val = std::dynamic_pointer_cast<RangeValue>(sv);
13491349
auto siv = std::dynamic_pointer_cast<SimpleValue>(sv);
13501350
auto iterable_tree = std::dynamic_pointer_cast<IterableTree>(sv);

torch/csrc/jit/script/sugared_value.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
265265
TypePtr val_type = val->type();
266266
Graph& g = *m.graph();
267267
if (val_type->cast<ListType>() ||
268+
val_type->cast<StringType>() ||
268269
val_type->isSubtypeOf(TensorType::get())) {
269270
return g.insert(aten::len, {val}, {}, loc);
270271
} else {
@@ -280,6 +281,8 @@ Value* SimpleValue::getelem(const SourceRange&loc, Function& m, Value* i) {
280281
Value* cur_elem = nullptr;
281282
if (val_type->cast<ListType>()) {
282283
cur_elem = g.insert(aten::select, {val, i}, {}, loc);
284+
} else if (val_type->cast<StringType>()) {
285+
cur_elem = g.insert(prim::StringIndex, {val, i}, {}, loc);
283286
} else if (val_type->isSubtypeOf(TensorType::get())) {
284287
cur_elem = g.insert(aten::select, {val, 0, i}, {}, loc);
285288
} else {

0 commit comments

Comments
 (0)