Skip to content

Commit 8556cf2

Browse files
ezyang authored and pytorchmergebot committed
Make backend_accuracy_fails suppress errors in same_two_models (#100324)
The basic idea is that if we're trying to match for an accuracy error, we don't want to switch to a compile/runtime error, because that's probably us breaking things in a different way.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #100324
Approved by: https://github.com/voznesenskym
1 parent 054a254 commit 8556cf2

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

torch/_dynamo/debug_utils.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -309,7 +309,7 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False):
309309
return collect_results(gm, out, None, args)
310310

311311

312-
def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
312+
def same_two_models(gm, opt_gm, example_inputs, only_fwd=False, *, require_fp64=False):
313313
"""
314314
Check two models have same accuracy.
315315
"""
@@ -336,6 +336,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
336336
)
337337
fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
338338
except Exception:
339+
if require_fp64:
340+
raise RuntimeError("Could not generate fp64 outputs")
339341
log.warning("Could not generate fp64 outputs")
340342
fp64_ref = None
341343

@@ -393,11 +395,16 @@ def cast_to_fp64(model, inputs):
393395
return cast_to(torch.float64, model, inputs)
394396

395397

396-
def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False):
398+
def backend_accuracy_fails(
399+
gm, example_inputs, compiler_fn, only_fwd=False, *, require_fp64=False
400+
):
397401
try:
398402
compiled_gm = compiler_fn(
399403
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
400404
)
405+
return not same_two_models(
406+
gm, compiled_gm, example_inputs, only_fwd, require_fp64=require_fp64
407+
)
401408
except Exception as e:
402409
# This means that the minified graph is bad/exposes a different problem.
403410
# As we are checking accuracy here, let's log the exception and return False.
@@ -409,5 +416,3 @@ def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False):
409416
)
410417
)
411418
return False
412-
413-
return not same_two_models(gm, compiled_gm, example_inputs, only_fwd)

0 commit comments

Comments
 (0)