Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 48 additions & 41 deletions torch/_dynamo/variables/ctx_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
import torch._C
from torch._guards import Guard

from .. import variables
from .. import graph_break_hints, variables
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalStateSource
from .base import VariableTracker
Expand Down Expand Up @@ -173,40 +173,27 @@ def fn_name(self):

def enter(self, tx):
source = None if self.source is None else AttrSource(self.source, "__enter__")
try:
return variables.UserMethodVariable(
self.cm_obj.__enter__.__func__,
self,
source=source,
).call_function(tx, [], {})
except Unsupported as e:
unimplemented(
f"Unsupported context manager {self.cm_obj}'s __enter__ function",
from_exc=e,
)
return variables.UserMethodVariable(
self.cm_obj.__enter__.__func__,
self,
source=source,
).call_function(tx, [], {})

def exit(self, tx: "InstructionTranslator", *args):
source = None if self.source is None else AttrSource(self.source, "__exit__")
try:
x = variables.UserMethodVariable(
self.cm_obj.__exit__.__func__,
self,
source=source,
).call_function(
tx,
[
variables.ConstantVariable.create(None),
variables.ConstantVariable.create(None),
variables.ConstantVariable.create(None),
],
{},
)
except Unsupported as e:
unimplemented(
f"Unsupported context manager {self.cm_obj}'s __exit__ function",
from_exc=e,
)

x = variables.UserMethodVariable(
self.cm_obj.__exit__.__func__,
self,
source=source,
).call_function(
tx,
[
variables.ConstantVariable.create(None),
variables.ConstantVariable.create(None),
variables.ConstantVariable.create(None),
],
{},
)
tx.active_generic_context_managers.pop()
return x

Expand Down Expand Up @@ -921,11 +908,13 @@ def fn_name(self):
return "nullcontext"

def reconstruct(self, cg):
unimplemented(
"""
Dynamo doesn't support compiling a region that leaks torch profiler context
objects which will be used outside the region
"""
unimplemented_v2(
gb_type="torch.profiler object escaped from compiled region",
context=str(self),
explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.",
hints=[
*graph_break_hints.SUPPORTABLE,
],
)


Expand Down Expand Up @@ -1043,8 +1032,16 @@ def exit(self, tx: "InstructionTranslator", *args):
).call_function(tx, [self.tensors, self.prev_versions], {})

def reconstruct(self, codegen):
unimplemented(
"torch.autograd._unsafe_preserve_version_counter with graph break"
unimplemented_v2(
gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region",
context=str(self),
explanation=(
"Dynamo doesn't support compiling a region that returns "
"a torch.autograd._unsafe_preserve_version_counter context manager."
),
hints=[
*graph_break_hints.SUPPORTABLE,
],
)


Expand Down Expand Up @@ -1292,7 +1289,17 @@ def call_method(
),
)
else:
unimplemented(f"event method {name} unsupported")
unimplemented_v2(
gb_type="Unsupported torch.cuda.Event method",
context=str(name),
explanation=(
f"Dynamo doesn't support tracing the torch.cuda.Event.{name} method. "
f"We currently support wait, record, synchronize, and query.",
),
hints=[
*graph_break_hints.SUPPORTABLE,
],
)

def as_proxy(self):
return self.proxy
Expand Down
Loading