
Commit babaac3

zou3519 authored and facebook-github-bot committed

Fix bug with named tensors and (no) tracer support (#26106)

Summary:
Pull Request resolved: #26106

Previously, in the named tensors build, an operator was marked as non-traceable if ANY of its overloads were named tensor overloads. This broke the tracer for operators like torch.full (which has a names= overload for named tensors) and tensor.sum (which has a Dimname overload for named tensors).

This PR fixes the problem by moving the "no tracer support" logic to the point where the tracer attempts to construct a graph by adding a Dimname/DimnameList argument to a node.

Test Plan:
- New test in test_jit.py that checks that torch.full is traceable.
- New test in test_namedtensor.py that checks what happens when someone tries to trace a function that uses named tensor APIs.
- [namedtensor ci]

Differential Revision: D17353452

Pulled By: zou3519

fbshipit-source-id: b0b843c8357ffe54baee6e8df86db914f0b1ece4

1 parent 33221b1 · commit babaac3
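In user-facing terms, the change makes the first of the following snippets traceable while turning the second into an explicit trace-time error. This is a minimal sketch distilled from the tests below; the helper names are hypothetical and a named-tensor-enabled (BUILD_NAMEDTENSOR) build from around this commit is assumed.

import torch

# Traceable after this fix: torch.full's plain overload, even though
# torch.full also has a names= overload in the named tensors build.
def full_with_shape_like(x):
    return torch.full(x.shape, 2)

traced = torch.jit.trace(full_with_shape_like, example_inputs=torch.randn(3, 4))

# Still not traceable, but now the failure is an explicit RuntimeError
# raised when the tracer meets the DimnameList argument.
def named_full(x):
    return torch.full(x.shape, 2, names=('N',))

try:
    torch.jit.trace(named_full, example_inputs=torch.randn(3))
except RuntimeError as e:
    print(e)  # mentions "not supported with the tracer"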

File tree

5 files changed (+39, -6 lines)

test/test_jit.py

Lines changed: 11 additions & 0 deletions

@@ -1650,6 +1650,17 @@ def test_trace_arange(self):
     def test_trace_arange_with_grad(self):
         self.do_trace_arange(True)

+    # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant
+    def test_trace_full_dynamic_shape(self):
+        def full_with_shape_like(x):
+            return torch.full(x.shape, 2)
+
+        x = torch.randn(3, 4)
+        ge = torch.jit.trace(full_with_shape_like, example_inputs=x)
+        y = torch.randn(2, 7)
+        self.assertEqual(ge(y).shape, y.shape)
+        self.assertEqual(ge(x).shape, x.shape)
+
     def test_trace_casts(self):
         casts = [
             lambda x: x.byte(),
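The dynamic-shape assertion is the heart of this test: when the tracer records torch.full(x.shape, 2), x.shape should flow into the graph through size queries on the input rather than being frozen as the constant [3, 4], which is why the trace must also work for the (2, 7) input. A quick manual check (a sketch; the exact graph text varies across PyTorch versions):

import torch

x = torch.randn(3, 4)
ge = torch.jit.trace(lambda t: torch.full(t.shape, 2), example_inputs=x)
# If the shape were stored as a constant, the printed graph would embed a
# fixed [3, 4] list; instead the sizes should be read from the input.
print(ge.graph)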

test/test_namedtensor.py

Lines changed: 16 additions & 1 deletion

@@ -1088,7 +1088,22 @@ def test_as_strided(self):
     def test_as_strided_cuda(self):
         self._test_as_strided('cuda')

-    def test_no_jit_support(self):
+    def test_no_jit_tracer_support(self):
+        def foo(x):
+            return torch.full(x.shape, 2, names=('N',))
+
+        with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
+            x = torch.randn(3)
+            torch.jit.trace(foo, example_inputs=x)
+
+        def bar(x):
+            return x.select('N', 1)
+
+        with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
+            x = torch.randn(3)
+            torch.jit.trace(bar, example_inputs=x)
+
+    def test_no_jit_script_support(self):
         @torch.jit.script
         def foo(x):
             return x + 1
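Both assertions match the substring 'not supported with the tracer', which pins these tests to the TORCH_CHECK message added in torch/csrc/jit/tracer.cpp below. To run just the affected tests locally, something like the following should work (the unittest class names are assumptions about the test files' layout):

python test/test_jit.py TestJit.test_trace_full_dynamic_shape
python test/test_namedtensor.py TestNamedTensor.test_no_jit_tracer_support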

tools/autograd/gen_variable_type.py

Lines changed: 0 additions & 5 deletions

@@ -281,11 +281,6 @@ def find_factory_functions(declarations):


 def should_trace(declaration):
-    # Short-term plan: Don't support tracing Dimname.
-    # Long-term plan: Add Dimname as a first-class type to the JIT.
-    if any('Dimname' in arg['simple_type'] for arg in declaration['arguments']):
-        return False
-
     # Operations involving Storage or Type are not traceable at the moment
     if any(arg['simple_type'] in {'Storage', 'Type', 'ConstQuantizerPtr'} for arg in declaration['arguments']):
         return False
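The deleted check ran per code-generation declaration, i.e. per overload, but its effect was to mark the whole operator as non-traceable, which is exactly the bug described in the summary. A rough illustration of the old predicate against two simplified torch.full declarations (the dict contents are assumptions modeled on the surrounding code, not real declarations):

def old_dimname_check(declaration):
    # The predicate removed above: reject any overload mentioning Dimname.
    return any('Dimname' in arg['simple_type'] for arg in declaration['arguments'])

full_plain = {'arguments': [{'simple_type': 'IntArrayRef'},
                            {'simple_type': 'Scalar'}]}
full_named = {'arguments': [{'simple_type': 'IntArrayRef'},
                            {'simple_type': 'Scalar'},
                            {'simple_type': 'DimnameList'}]}

print(old_dimname_check(full_plain))  # False: harmless on its own
print(old_dimname_check(full_named))  # True: this overload disabled tracing
                                      # for torch.full as a whole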

torch/csrc/jit/tracer.cpp

Lines changed: 8 additions & 0 deletions

@@ -477,6 +477,14 @@ void addInputs(
     n->addInput(none);
   }
 }
+#ifdef BUILD_NAMEDTENSOR
+void addInputs(
+    Node* n,
+    const char* name,
+    c10::optional<at::DimnameList> value) {
+  TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer");
+}
+#endif
 void addInputs(
     Node* n,
     const char* name,
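The new overload turns the old build-time exclusion into a lazy runtime failure: the generated tracing wrappers record every argument through an addInputs call, and C++ overload resolution only selects this DimnameList overload when a named-tensor argument is actually being recorded, so overloads without Dimname arguments keep tracing normally. A hypothetical, heavily simplified sketch of such a wrapper (the real code comes out of gen_variable_type.py and differs in detail; createNode is a stand-in helper):

// Illustrative only: roughly the shape of a generated tracing wrapper for
// torch.full's names= overload.
Tensor full(IntArrayRef size, Scalar fill_value,
            c10::optional<DimnameList> names, const TensorOptions& options) {
  Node* node = createNode("aten::full");  // stand-in for the real node setup
  tracer::addInputs(node, "size", size);
  tracer::addInputs(node, "fill_value", fill_value);
  // Overload resolution lands on the new DimnameList overload here, so the
  // trace aborts with "NYI: Named tensors are not supported with the tracer".
  tracer::addInputs(node, "names", names);
  Tensor result = at::full(size, fill_value, names, options);
  tracer::addOutput(node, result);
  return result;
}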

torch/csrc/jit/tracer.h

Lines changed: 4 additions & 0 deletions

@@ -4,6 +4,7 @@
 #include <c10/util/Exception.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/core/jit_type.h>
+#include <ATen/core/Dimname.h>

 #include <torch/csrc/utils/variadic.h>

@@ -280,6 +281,9 @@ TORCH_API void addInputs(
     const char* name,
     const c10::optional<at::ScalarType>& value);
 TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
+#ifdef BUILD_NAMEDTENSOR
+TORCH_API void addInputs(Node* n, const char* name, c10::optional<at::DimnameList> value);
+#endif
 TORCH_API void addInputs(
     Node* n,
     const char* name,
