Skip to content

Commit 9ecc33d

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
metacompile isinstance checks (#23885)
Summary: Pull Request resolved: #23885 This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266. This PR only compiles one if branch if the condition is an isinstance check. This is consistent with what mypy does; it does not report errors if a branch can be determined statically to be unreachable. ``` def foo(x): # type: (int) -> int if isinstance(x, str): return x["1"] return x + 1 reveal_type(foo) # no error, shows int -> int ``` Test Plan: Imported from OSS Differential Revision: D16697092 Pulled By: eellison fbshipit-source-id: d3eb4925cd16d551515ac6ff620a69897dbec130
1 parent 33a1c30 commit 9ecc33d

File tree

2 files changed

+61
-6
lines changed

2 files changed

+61
-6
lines changed

test/test_jit.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12960,6 +12960,35 @@ def test({arg_str}):
1296012960

1296112961
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
1296212962

12963+
def test_isinstance_metacompile(self):
12964+
@torch.jit.script
12965+
def test_primitive_type(x):
12966+
# type: (int) -> int
12967+
if isinstance(x, int):
12968+
return x + 1
12969+
else:
12970+
return x - 1
12971+
12972+
self.assertEqual(test_primitive_type(1), 2)
12973+
with self.assertRaisesRegex(Exception, "Expected a value of type"):
12974+
test_primitive_type(1.5)
12975+
12976+
_MyNamedTuple = namedtuple('_MyNamedTuple', ['value'])
12977+
12978+
@torch.jit.script
12979+
def test_non_primitive_types(x):
12980+
# type: (_MyNamedTuple) -> Tensor
12981+
if isinstance(1, _MyNamedTuple):
12982+
return 10
12983+
12984+
if isinstance(x, _MyNamedTuple):
12985+
return x.value + 1
12986+
else:
12987+
return 1
12988+
12989+
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
12990+
self.assertEqual(out, torch.tensor(6.0))
12991+
1296312992
@unittest.skipIf(True, "Removing weak script")
1296412993
def test_overloading(self):
1296512994
@torch._jit_internal.weak_module

torch/csrc/jit/script/compiler.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,18 +1280,44 @@ struct to_ir {
12801280
}
12811281
}
12821282

1283+
bool isIsInstanceCall(const Expr& expr) {
1284+
if (expr.kind() != TK_APPLY) {
1285+
return false;
1286+
}
1287+
auto callee = Apply(expr).callee();
1288+
return callee.kind() == TK_VAR && Var(callee).name().name() == "isinstance";
1289+
}
1290+
1291+
bool isPotentialNoneCheck(const Expr& expr) {
1292+
return expr.kind() == TK_IS || expr.kind() == TK_ISNOT;
1293+
}
1294+
12831295
void emitIf(const If& stmt) {
1284-
// NOTE: emitIf checks on If stmt condition to see if the cond AST kind ==
1285-
// is/is not, for such cases we do meta programming and disable emitting the
1296+
// NOTE: emitIf checks on If stmt condition to see if the cond AST is
1297+
// a potential none check with is/is not, or an isinstance check.
1298+
// for such cases we do meta programming and disable emitting the
12861299
// corresponding branches
12871300
Expr cond = stmt.cond();
1301+
bool isinstance_call = isIsInstanceCall(cond);
1302+
bool potential_none_check = !isinstance_call && isPotentialNoneCheck(cond);
12881303

1289-
if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
1290-
// emit normal IF stmt for cases except TK_IS and TK_ISNOT
1304+
if (!isinstance_call && !potential_none_check) {
1305+
// emit normal IF stmt for cases except isinstance & none checks
12911306
Value* cond_value = emitCond(cond);
1292-
emitIfElseBlocks(cond_value, stmt);
1293-
return;
1307+
return emitIfElseBlocks(cond_value, stmt);
1308+
}
1309+
1310+
if (isinstance_call) {
1311+
auto is_instance_result = emitSugaredExpr(cond, 1);
1312+
auto ivalue = toIValue(is_instance_result->asValue(cond.range(), method));
1313+
TORCH_INTERNAL_ASSERT(ivalue); // no support for runtime checks
1314+
if (ivalue->toBool()) {
1315+
return emitStatements(stmt.trueBranch());
1316+
} else {
1317+
return emitStatements(stmt.falseBranch());
1318+
}
12941319
}
1320+
12951321
// meta programming on AST for is/is not cases and emit branches base on the
12961322
// possible output of cond
12971323
auto cond_op = BinOp(cond);

0 commit comments

Comments
 (0)