Skip to content

Commit 00df49c

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
Fix Trace inlining of graphs with optional inputs (#22686)
Summary: Previously in tracing when we called a script function we would inline the graph and set the graph inputs equal to the types the graph was invoked with. This breaks for optional arguments invoked with None since we rely on None being set to Optional[T] in schema matching. Pull Request resolved: #22686 Differential Revision: D16186372 Pulled By: eellison fbshipit-source-id: e25c807c63527bf442eb8b31122d50689c7822f5
1 parent 3e3e6ee commit 00df49c

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

test/test_jit.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8325,6 +8325,32 @@ def forward(self, x):
83258325
# constants not baked in
83268326
self.assertEqual(g(x), f(x))
83278327

8328+
def test_trace_optional(self):
8329+
@torch.jit.script
8330+
def test(x):
8331+
# type: (Optional[Tensor])
8332+
if x is None:
8333+
return torch.zeros(1)
8334+
else:
8335+
return x
8336+
8337+
def test_none():
8338+
return test(None)
8339+
8340+
def test_tensor():
8341+
return test(torch.zeros(2))
8342+
8343+
f_none = torch.jit.trace(test_none, ())
8344+
self.assertEqual(f_none(), torch.zeros(1))
8345+
8346+
f_tensor = torch.jit.trace(test_tensor, ())
8347+
self.assertEqual(f_tensor(), torch.zeros(2))
8348+
8349+
graph = f_tensor.graph
8350+
f = str(graph)
8351+
# tensor type correctly set as graph input
8352+
FileCheck().check("Double(2) = prim:").run(f)
8353+
83288354
def test_trace_nested_datatypes(self):
83298355
@torch.jit.script
83308356
def foo(x):

torch/csrc/jit/graph_executor_impl.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,17 @@ struct GraphExecutorImplBase {
8989
// been set.
9090
auto local_graph = this->graph->copy();
9191
for (size_t i = 0; i < input_values.size(); ++i) {
92-
local_graph->inputs().at(i)->setType(input_values.at(i)->type());
92+
// propagate tensor types
93+
if (input_values.at(i)->type()->cast<TensorType>()) {
94+
local_graph->inputs().at(i)->setType(input_values.at(i)->type());
95+
}
96+
97+
// None does not subtype Optional[T], schema matching relies on
98+
// None values being set to Optional[T] so update type here
99+
// see logic for Nones in tryConvertToType
100+
if (input_values.at(i)->type() == NoneType::get()) {
101+
input_values.at(i)->setType(local_graph->inputs().at(i)->type());
102+
}
93103
}
94104
PropagateInputShapes(local_graph);
95105
auto output_values =

0 commit comments

Comments
 (0)