Skip to content

Commit f368f26

Browse files
author
Meghan Lele
committed
[JIT] Check mergeability of return types for functions that return Any
**Summary** This commit modifies the type checking performed during IR generation so that it checks that all return types for a function marked as returning `Any` can be unified. Issue #41962 reported that the use of an `Any` return type in combination with different code paths returning values of different types causes a segmentation fault. During IR generation, the type of each return is checked against the annotated return type of the function, or unified with all other return types if there is no annotation. Because every type is a subtype of `Any`, this commit modifies the IR generation to do the latter (i.e. check if all return types can be unified) if the annotated return type is `Any`. **Test Plan** This commit adds a unit test that checks that an exception with an appropriate error message is thrown when a function and module with an annotated return type of `Any` is compiled and the possible return types cannot be unified. **Fixes** This commit fixes #41962. ghstack-source-id: bc3ca78 Pull Request resolved: #42259
1 parent 27b03d6 commit f368f26

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

test/test_jit_py3.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch.testing._internal.jit_utils
1212
import torch.nn as nn
1313
import types
14+
import inspect
1415

1516
class TestScriptPy3(JitTestCase):
1617
def test_joined_str(self):
@@ -533,6 +534,38 @@ def forward(self, x):
533534
# Check that ignored method is still intact.
534535
self.assertEqual(inp, n.ignored_method(inp))
535536

537+
def test_do_not_use_any_as_return_supertype(self):
538+
"""
539+
Check that if a function is annotated as returning any,
540+
all return paths are checked against each other (to see if
541+
they have a common supertype) instead of against the
542+
return type.
543+
"""
544+
class Mod(nn.Module):
545+
def __init__(self):
546+
super().__init__()
547+
548+
def forward(self, inp: torch.Tensor) -> Any:
549+
if inp.shape[0] == 1:
550+
return inp * inp
551+
else:
552+
return 3
553+
554+
def any_function(inp: torch.Tensor) -> Any:
555+
if inp.shape[0] == 1:
556+
return inp * inp
557+
else:
558+
return 3
559+
560+
with self.assertRaisesRegex(RuntimeError, "Previous return statement returned a value of type Tensor but this return statement returns a value of type int"):
561+
torch.jit.script(Mod())
562+
563+
with self.assertRaisesRegex(RuntimeError, "Previous return statement returned a value of type Tensor but this return statement returns a value of type int"):
564+
torch.jit.script(any_function)
565+
566+
with self.assertRaisesRegex(RuntimeError, "Previous return statement returned a value of type Tensor but this return statement returns a value of type int"):
567+
cu = torch.jit.CompilationUnit(inspect.getsource(any_function))
568+
536569
def test_export_opnames_interface(self):
537570
global OneTwoModule
538571

torch/csrc/jit/frontend/ir_emitter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ struct to_ir {
956956
Value* result = emitExpr(stmt.expr());
957957
TypePtr result_type = def_stack_.back().declared_return_type_;
958958
// result type is annotated, every return must convert to that type
959-
if (result_type) {
959+
if (result_type && result_type != AnyType::get()) {
960960
// this guard skips implicit conversion from None -> Tensor for the return
961961
// type. otherwise forgetting a return a function returning a tensor will
962962
// cause a None to be converted to a tensor.

0 commit comments

Comments
 (0)