Skip to content

Commit f3c97a4

Browse files
authored
Raise error on 3.11 dynamo export (#95088) (#95396)
For #94914. Realized that `dynamo.export` doesn't immediately raise an error when dynamo is trying to run on 3.11/windows. Pull Request resolved: #95088 Approved by: https://github.com/weiwangmeta
1 parent 30cf0e7 commit f3c97a4

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

torch/_dynamo/eval_frame.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,13 @@ def __call__(self, fn):
370370
return fn
371371

372372

373+
def check_if_dynamo_supported():
374+
if sys.platform == "win32":
375+
raise RuntimeError("Windows not yet supported for torch.compile")
376+
if sys.version_info >= (3, 11):
377+
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
378+
379+
373380
def optimize(
374381
backend="inductor",
375382
*,
@@ -403,6 +410,7 @@ def optimize(
403410
def toy_example(a, b):
404411
...
405412
"""
413+
check_if_dynamo_supported()
406414
# Note: The hooks object could be global instead of passed around, *however* that would make
407415
# for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
408416
# There is some prior art around this, w/r/t nesting backend calls are enforced to be the same
@@ -412,10 +420,6 @@ def toy_example(a, b):
412420
torch._C._log_api_usage_once("torch._dynamo.optimize")
413421
if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
414422
return _NullDecorator()
415-
if sys.platform == "win32":
416-
raise RuntimeError("Windows not yet supported for torch.compile")
417-
if sys.version_info >= (3, 11):
418-
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
419423

420424
backend = get_compiler_fn(backend)
421425

@@ -517,6 +521,7 @@ def guard_export_print(guards):
517521
def export(
518522
f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
519523
):
524+
check_if_dynamo_supported()
520525
torch._C._log_api_usage_once("torch._dynamo.export")
521526
if decomposition_table is not None or tracing_mode != "real":
522527
assert (

0 commit comments

Comments
 (0)