Skip to content

Commit 64d973e

Browse files
author
Meghan Lele
committed
[JIT] Cast return values of functions returning Any
**Summary** This commit modifies IR generation to insert explicit cast that cast each return value to `Any` when a function is annotated as returning `Any`. This precludes the failure in type unification (see below) that caused this issue. Issue #41962 reported that the use of an `Any` return type in combination with different code paths returning values of different types causes a segmentation fault. This is because the exit transform pass tries to unify the different return types, fails, but silently sets the type of the if node to c10::nullopt. This causes problems later in shape analysis when that type object is dereferenced. **Test Plan** This commit adds a unit test that checks that a function similar to the one in #41962 can be scripted and executed. **Fixes** This commit fixes #41962. ghstack-source-id: 77d7dbf Pull Request resolved: #42259
1 parent 4e964f3 commit 64d973e

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/test_jit_py3.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,20 @@ def forward(self, x):
542542
# Check that ignored method is still intact.
543543
self.assertEqual(inp, n.ignored_method(inp))
544544

545+
def test_if_returning_any(self):
546+
"""
547+
Check that an if statement can return different
548+
types early from each branch when the return
549+
type of the function is Any.
550+
"""
551+
def if_function(inp: torch.Tensor) -> Any:
552+
if inp.shape[0] == 1:
553+
return inp * inp
554+
else:
555+
return "str"
556+
557+
self.checkScript(if_function, (torch.randn(5),))
558+
545559
def test_export_opnames_interface(self):
546560
global OneTwoModule
547561

torch/csrc/jit/frontend/ir_emitter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,13 @@ struct to_ir {
996996
result_type = merged_result_type.value();
997997
}
998998
AT_ASSERT(result_type);
999+
9991000
def_stack_.back().merged_return_type_ = result_type;
1001+
1002+
if (result_type == AnyType::get() && result->type() != AnyType::get()) {
1003+
result = graph->insertUncheckedCast(result, result_type);
1004+
}
1005+
10001006
graph->insertNode(graph->create(prim::ReturnStmt, {result}, 0));
10011007
exit_blocks.insert(environment_stack->block());
10021008
}

0 commit comments

Comments
 (0)