@@ -595,7 +595,7 @@ def combine_scopes(left, right):
595595 """
596596
597597 def _parse_symbolic_shape_expressions (self , tensor_check_names , tensor_check_ids ):
598- # Output
598+ # Pre join output
599599 finished_expressions = []
600600
601601 # A mapping of tensor_ids to tensor names
@@ -604,7 +604,7 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids
604604 # We should not have a shape env, or guards if we are not in config.dynamic shapes
605605 # But check it anyway.
606606 if not config .dynamic_shapes :
607- return finished_expressions
607+ return None
608608
609609 expr_to_tensor_ref = {}
610610 guard_printer = DynamoGuardPrinter (expr_to_tensor_ref , id_to_name_map )
@@ -625,24 +625,14 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids
625625 expr_to_tensor_ref [obj_expr ] = {}
626626 expr_to_tensor_ref [obj_expr ][tensor_ref ] = ""
627627 finished_expressions .append (f"isinstance({ name } , torch.Tensor)" )
628- # Extract all the guard elements out of guards
629- # The guard format, atm, uses tuple position 0 for the expression
630- # and tuple position 1 for a negation. Eventually, these will be collapsed together.
631- expression_and_evaluation = [
632- (guard [0 ], guard [1 ]) for guard in self .output_graph .shape_env .guards
633- ]
634- for expression , evaluation in expression_and_evaluation :
635- expr_as_str = guard_printer .doprint (expression )
636- # We may get into a state where symbolic shape keys (all should be found in replacements)
637- # Have not been removed from the expression. This is a serious enough error state that we need to assert.
638- for key in self .output_graph .shape_env .var_to_val .keys ():
639- assert str (key ) not in expr_as_str , f"Unknown shape symbol { key } . "
640-
641- # Certain expressions are negated in their guards.
642- if not evaluation :
643- expr_as_str = f"not ({ expr_as_str } )"
644-
645- finished_expressions .append (expr_as_str )
628+
629+ guard_expression = self .output_graph .shape_env .get_guard_expr ()
630+ expr_as_str = guard_printer .doprint (guard_expression )
631+ # We may get into a state where symbolic shape keys (all should be found in replacements)
632+ # Have not been removed from the expression. This is a serious enough error state that we need to assert.
633+ for key in self .output_graph .shape_env .var_to_val .keys ():
634+ assert str (key ) not in expr_as_str , f"Unknown shape symbol { key } . "
635+ finished_expressions .append (expr_as_str )
646636
647637 for expr in expr_to_tensor_ref .keys ():
648638 tensor_refs = expr_to_tensor_ref [expr ].keys ()
@@ -653,9 +643,15 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids
653643
654644 if len (equality_candidates ) > 1 :
655645 equality_expr = " == " .join (equality_candidates )
646+ # breakpoint()
656647 finished_expressions .append (equality_expr )
657648
658- return finished_expressions
649+ # Redundant with code_parts, but allows us to wrap it with parens nicely.
650+ if len (finished_expressions ) == 0 :
651+ return None
652+
653+ expression = " and " .join (finished_expressions )
654+ return f"({ expression } )"
659655
660656 def compile_check_fn (self , local_builder , global_builder ):
661657 assert not (set (local_builder .argnames ) & set (global_builder .argnames ))
@@ -683,12 +679,12 @@ def compile_check_fn(self, local_builder, global_builder):
683679 check_tensors_fn = None
684680 check_tensors_verbose_fn = None
685681 if tensor_check_names :
686- finished_expressions = self ._parse_symbolic_shape_expressions (
682+ symbolic_shape_expression = self ._parse_symbolic_shape_expressions (
687683 tensor_check_names , tensor_check_ids
688684 )
689- for expression in finished_expressions :
690- code_parts .append (expression )
691- verbose_code_parts .append (expression )
685+ if symbolic_shape_expression :
686+ code_parts .append (symbolic_shape_expression )
687+ verbose_code_parts .append (symbolic_shape_expression )
692688
693689 tensor_check_examples = (
694690 local_builder .tensor_check_examples
@@ -708,6 +704,9 @@ def compile_check_fn(self, local_builder, global_builder):
708704 def direct_equality (a , b ):
709705 return a == b
710706
707+ def direct_negation (a , b ):
708+ return not direct_equality (a , b )
709+
711710 code = " and " .join (unique (code_parts ))
712711 closure_vars = collections .OrderedDict (
713712 [
@@ -716,6 +715,7 @@ def direct_equality(a, b):
716715 ("___check_tensors_verbose" , check_tensors_verbose_fn ),
717716 ("tensor_check_names" , tensor_check_names ),
718717 ("Eq" , direct_equality ),
718+ ("Ne" , direct_negation ),
719719 ("Mod" , sympy .Mod ),
720720 ("FloorDiv" , FloorDiv ),
721721 ]
0 commit comments