2 changes: 1 addition & 1 deletion test/dynamo/test_export.py
@@ -98,7 +98,7 @@ def func(x):
         for guard in out_guards:
             if guard.source == GuardSource.SHAPE_ENV:
                 hit = True
-                self.assertTrue("x.size()[0] <= 10" in guard.code_list[0])
+                self.assertTrue("x.size()[0] <= 10" in guard.code_list)
 
         self.assertTrue(hit)

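The updated assertion reflects the new guard representation: `produce_guards` returns one string per shape guard, so `guard.code_list` now holds individual expressions rather than a single conjoined string. A minimal sketch of the difference, with illustrative guard strings (not taken from an actual run):

    # Before: all shape guards were folded into one string, so the test had
    # to substring-match inside code_list[0].
    code_list_old = ["x.size()[0] <= 10 and x.stride()[0] == 1"]
    assert "x.size()[0] <= 10" in code_list_old[0]

    # After: each guard is its own list entry, so plain membership works.
    code_list_new = ["x.size()[0] <= 10", "x.stride()[0] == 1"]
    assert "x.size()[0] <= 10" in code_list_new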
33 changes: 33 additions & 0 deletions test/dynamo/test_misc.py
@@ -3282,6 +3282,39 @@ def guard_failures(failure):
         self.assertTrue(guard_failure is not None)
         self.assertEqual(guard_failure[0], "k == 3")
 
+    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+    def test_guard_failure_fn_shape_control(self):
+        def fn(x, y):
+            if x.shape[0] < 3:
+                if y.shape[0] < 3:
+                    return x * y
+                else:
+                    return x + y
+            else:
+                return -1
+
+        x = torch.randn([2, 2])
+        y = torch.randn([2, 2])
+
+        guard_failure = None
+
+        def guard_failures(failure):
+            nonlocal guard_failure
+            guard_failure = failure
+
+        opt_fn = torch._dynamo.optimize(
+            "eager", nopython=True, guard_fail_fn=guard_failures
+        )(fn)
+
+        x2 = torch.randn([5, 5])
+        y2 = torch.randn([5, 5])
+
+        opt_fn(x, y)
+        opt_fn(x2, y2)
+
+        self.assertTrue(guard_failure is not None)
+        self.assertEqual(guard_failure[0], "x.size()[0] < 3")
+
     def test_guard_failure_fn2(self):
         def fn(x, y):
             x = x + 1
4 changes: 2 additions & 2 deletions test/test_proxy_tensor.py
@@ -1081,8 +1081,8 @@ def f(a, b):
         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
         from torch._dynamo.source import LocalSource
         self.assertExpectedInline(
-            fx_g.shape_env.codegen_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")]),
-            """a.size()[0] == 2*b.size()[0] and a.stride()[0] == 1 and a.storage_offset() == 0 and b.stride()[0] == 1 and b.storage_offset() == 0 and b.size()[0] != 0 and b.size()[0] != 1""" # noqa: B950
+            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")])),
+            """['a.size()[0] == 2*b.size()[0]', 'a.stride()[0] == 1', 'a.storage_offset() == 0', 'b.stride()[0] == 1', 'b.storage_offset() == 0', 'b.size()[0] != 0 and b.size()[0] != 1']""" # noqa: B950
         )
 
     def test_sym_storage_offset(self):
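The expected value changes from a single `" and "`-joined expression to the `str()` of a list, which renders each guard as a quoted element. A tiny sketch of that formatting, with made-up guard strings:

    guards = ["a.size()[0] == 2*b.size()[0]", "a.stride()[0] == 1"]
    print(str(guards))
    # ['a.size()[0] == 2*b.size()[0]', 'a.stride()[0] == 1']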
6 changes: 3 additions & 3 deletions torch/_dynamo/guards.py
@@ -401,13 +401,13 @@ def SHAPE_ENV(self, guard: Guard):
         output_graph = self.guarded_code.output_graph
         # NB: self.output_graph can be None in the debug_nops tests
         fs = output_graph.tracked_fakes
-        code = output_graph.shape_env.codegen_guards(
+        guards = output_graph.shape_env.produce_guards(
             [a.fake for a in fs],
             [a.source for a in fs],
             source_ref=self.source_ref,
         )
-        if code != "True":
-            self._produce_guard_code(guard, [code], shape_env=True)
+        for shape_guard in guards:
+            self._produce_guard_code(guard, [shape_guard], shape_env=True)
 
     def TENSOR_MATCH(self, guard: Guard):
         if guard.is_nn_module():
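Registering one guard-code entry per expression is what lets the new `test_guard_failure_fn_shape_control` test above report the precise failing guard (`"x.size()[0] < 3"`) instead of one long conjunction; an empty list simply installs nothing, replacing the old `code != "True"` check. A rough sketch of the idea, with a hypothetical `install` callback standing in for `self._produce_guard_code`:

    def emit_shape_guards(guards, install):
        # One entry per guard string: a later failure report can then name
        # exactly the expression that evaluated to False.
        for shape_guard in guards:
            install([shape_guard])

    emit_shape_guards(
        ["x.size()[0] < 3", "x.stride()[0] == 1"],  # illustrative produce_guards output
        lambda code_list: print("installed:", code_list),
    )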
20 changes: 10 additions & 10 deletions torch/fx/experimental/symbolic_shapes.py
@@ -816,13 +816,13 @@ def duck_int(self, val):
         )
         return self.val_to_var[val]
 
-    # Generates a Python string which, when evaluated in a context that
+    # Generates a list of guard strings which, when evaluated in a context that
     # defines tensors for all the sources, returns True or False depending
-    # on if the guards evaluated to True or not. Primarily used by Dynamo,
+    # on if the guards in the list evaluated to True or not. Primarily used by Dynamo,
     # but this is also helpful for manual testing of guards (see
     # evaluate_guards_for_args)
-    def codegen_guards(self, placeholders, sources,
-                       source_ref=lambda n: n.name()):
+    def produce_guards(self, placeholders, sources,
+                       source_ref=lambda n: n.name()) -> List[str]:
         # It took a lot of sweat to figure out the algorithm here. Let's
         # explain how it works.
         #
@@ -963,16 +963,16 @@ def track_symint(source, val):
             # negative inferences on shape variables
             exprs.append(f"{source_ref(sources[0])} != 0 and {source_ref(sources[0])} != 1")
 
-        if exprs:
-            return " and ".join(exprs)
-        else:
-            return "True"
+        return exprs
 
     def evaluate_guards_for_args(self, placeholders, args):
        from torch._dynamo.source import GlobalSource
        arg_names = [f"t{i}" for i in range(len(args))]
-        code = self.codegen_guards(placeholders, [GlobalSource(a) for a in arg_names])
-        return eval(code, {}, dict(zip(arg_names, args)))
+        guards = self.produce_guards(placeholders, [GlobalSource(a) for a in arg_names])
+        if guards:
+            code = " and ".join(guards)
+            return eval(code, {}, dict(zip(arg_names, args)))
+        return True
 
     def bind_symbols(self, placeholders, args):
         # Given a paired list of placeholders (fake tensors with
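Callers that still want a single boolean can re-join the list themselves, exactly as `evaluate_guards_for_args` now does. A self-contained sketch of that consumption pattern, using made-up guard strings in the style `produce_guards` emits for two 1-D tensors (not verbatim shape-env output):

    import torch

    # Illustrative guards for t0 = randn(16), t1 = randn(8).
    guards = [
        "t0.size()[0] == 2*t1.size()[0]",
        "t1.size()[0] != 0 and t1.size()[0] != 1",
    ]

    t0, t1 = torch.randn(16), torch.randn(8)
    if guards:
        code = " and ".join(guards)
        result = eval(code, {}, {"t0": t0, "t1": t1})
    else:
        result = True  # no guards to check: trivially satisfied
    print(result)  # True for these shapes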