Skip to content

Commit ad0f43f

Browse files
committed
Use new megaguard, optimize some checks
1 parent a12c61d commit ad0f43f

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

torch/_dynamo/guards.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def combine_scopes(left, right):
595595
"""
596596

597597
def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids):
598-
# Output
598+
# Pre join output
599599
finished_expressions = []
600600

601601
# A mapping of tensor_ids to tensor names
@@ -604,7 +604,7 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids
604604
# We should not have a shape env, or guards if we are not in config.dynamic shapes
605605
# But check it anyway.
606606
if not config.dynamic_shapes:
607-
return finished_expressions
607+
return None
608608

609609
expr_to_tensor_ref = {}
610610
guard_printer = DynamoGuardPrinter(expr_to_tensor_ref, id_to_name_map)
@@ -625,24 +625,14 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids
625625
expr_to_tensor_ref[obj_expr] = {}
626626
expr_to_tensor_ref[obj_expr][tensor_ref] = ""
627627
finished_expressions.append(f"isinstance({name}, torch.Tensor)")
628-
# Extract all the guard elements out of guards
629-
# The guard format, atm, uses tuple position 0 for the expression
630-
# and tuple position 1 for a negation. Eventually, these will be collapsed together.
631-
expression_and_evaluation = [
632-
(guard[0], guard[1]) for guard in self.output_graph.shape_env.guards
633-
]
634-
for expression, evaluation in expression_and_evaluation:
635-
expr_as_str = guard_printer.doprint(expression)
636-
# We may get into a state where symbolic shape keys (all should be found in replacements)
637-
# Have not been removed from the expression. This is a serious enough error state that we need to assert.
638-
for key in self.output_graph.shape_env.var_to_val.keys():
639-
assert str(key) not in expr_as_str, f"Unknown shape symbol {key}. "
640-
641-
# Certain expressions are negated in their guards.
642-
if not evaluation:
643-
expr_as_str = f"not ({expr_as_str})"
644-
645-
finished_expressions.append(expr_as_str)
628+
629+
guard_expression = self.output_graph.shape_env.get_guard_expr()
630+
expr_as_str = guard_printer.doprint(guard_expression)
631+
# We may get into a state where symbolic shape keys (all should be found in replacements)
632+
# Have not been removed from the expression. This is a serious enough error state that we need to assert.
633+
for key in self.output_graph.shape_env.var_to_val.keys():
634+
assert str(key) not in expr_as_str, f"Unknown shape symbol {key}. "
635+
finished_expressions.append(expr_as_str)
646636

647637
for expr in expr_to_tensor_ref.keys():
648638
tensor_refs = expr_to_tensor_ref[expr].keys()
@@ -653,9 +643,15 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids
653643

654644
if len(equality_candidates) > 1:
655645
equality_expr = " == ".join(equality_candidates)
646+
# breakpoint()
656647
finished_expressions.append(equality_expr)
657648

658-
return finished_expressions
649+
# Redundant with code_parts, but allows us to wrap it with parens nicely.
650+
if len(finished_expressions) == 0:
651+
return None
652+
653+
expression = " and ".join(finished_expressions)
654+
return f"({expression})"
659655

660656
def compile_check_fn(self, local_builder, global_builder):
661657
assert not (set(local_builder.argnames) & set(global_builder.argnames))
@@ -683,12 +679,12 @@ def compile_check_fn(self, local_builder, global_builder):
683679
check_tensors_fn = None
684680
check_tensors_verbose_fn = None
685681
if tensor_check_names:
686-
finished_expressions = self._parse_symbolic_shape_expressions(
682+
symbolic_shape_expression = self._parse_symbolic_shape_expressions(
687683
tensor_check_names, tensor_check_ids
688684
)
689-
for expression in finished_expressions:
690-
code_parts.append(expression)
691-
verbose_code_parts.append(expression)
685+
if symbolic_shape_expression:
686+
code_parts.append(symbolic_shape_expression)
687+
verbose_code_parts.append(symbolic_shape_expression)
692688

693689
tensor_check_examples = (
694690
local_builder.tensor_check_examples
@@ -708,6 +704,9 @@ def compile_check_fn(self, local_builder, global_builder):
708704
def direct_equality(a, b):
709705
return a == b
710706

707+
def direct_negation(a, b):
708+
return not direct_equality(a, b)
709+
711710
code = " and ".join(unique(code_parts))
712711
closure_vars = collections.OrderedDict(
713712
[
@@ -716,6 +715,7 @@ def direct_equality(a, b):
716715
("___check_tensors_verbose", check_tensors_verbose_fn),
717716
("tensor_check_names", tensor_check_names),
718717
("Eq", direct_equality),
718+
("Ne", direct_negation),
719719
("Mod", sympy.Mod),
720720
("FloorDiv", FloorDiv),
721721
]

0 commit comments

Comments
 (0)