
Commit fb529c2

anijain2305 authored and pytorchmergebot committed
[dynamo] skip_guard_eval_unsafe stance for power users (#140251)
Pull Request resolved: #140251
Approved by: https://github.com/jansel
ghstack dependencies: #140223, #140250
1 parent 7392e88 commit fb529c2

12 files changed: +275 −11 lines
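The commit adds a `skip_guard_eval_unsafe` keyword to `torch.compiler.set_stance`, exercised throughout the new test file below. A minimal sketch of the intended warm-up-then-skip workflow (the function and inputs here are illustrative, not from the commit):

    import torch

    def fn(x, y):
        return x * y

    opt_fn = torch.compile(fn, backend="eager")

    # Warm-up phase: exercise every input variation expected at runtime so
    # that all needed compiled artifacts (and their guards) already exist.
    opt_fn(torch.randn(4), torch.randn(4))

    # Steady state: evaluate only the cheap differentiating guards. This is
    # unsafe if an unseen input would have required a recompile.
    with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
        opt_fn(torch.randn(4), torch.randn(4))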
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.test_case
import torch._dynamo.testing


def my_custom_function(x):
    return x + 1


class RunDiffGuardTests(torch._dynamo.test_case.TestCase):
    def test_bool_recompile(self):
        def fn(x, y, c):
            if c:
                return x * y
            else:
                return x + y

        opt_fn = torch.compile(fn, backend="inductor")
        x = 2 * torch.ones(4)
        y = 3 * torch.ones(4)

        ref1 = opt_fn(x, y, True)
        ref2 = opt_fn(x, y, False)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res2 = opt_fn(x, y, False)
            res1 = opt_fn(x, y, True)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)

    def test_tensor_recompile(self):
        def fn(x, y):
            return x * y

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4, dtype=torch.float32)
        y = torch.randn(4, dtype=torch.float32)

        ref1 = opt_fn(x, y)

        x64 = torch.randn(4, dtype=torch.float64)
        y64 = torch.randn(4, dtype=torch.float64)
        ref2 = opt_fn(x64, y64)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res1 = opt_fn(x, y)
            res2 = opt_fn(x64, y64)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)

    def test_post_recompile(self):
        class Foo:
            a = 4
            b = 5

        foo = Foo()

        def fn(x):
            return x + foo.a + foo.b

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        foo.a = 11
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            # Set it back to original value
            foo.a = 4
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

            foo.a = 11
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

        # Check that we are back to original behavior
        foo.b = 8
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    def test_fail_on_tensor_shape_change(self):
        def fn(dt):
            return dt["x"] + 1

        x = torch.randn(4)
        dt = {}
        dt["x"] = x
        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(dt)

        with self.assertRaisesRegex(
            RuntimeError, "Recompilation triggered with skip_guard_eval_unsafe stance"
        ):
            with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
                x = torch.randn(4, 4)
                dt["x"] = x
                opt_fn(dt)

    def test_cache_line_pickup(self):
        def fn(x, a=None, b=None):
            x = x * 3
            if a:
                x = x * 5
            if b:
                x = x * 7
            return x

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.ones(4)

        ref1 = opt_fn(x, a=None, b=None)
        ref2 = opt_fn(x, a=1, b=None)
        ref3 = opt_fn(x, a=1, b=1)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res1 = opt_fn(x, a=None, b=None)
            res2 = opt_fn(x, a=1, b=None)
            res3 = opt_fn(x, a=1, b=1)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref3, res3)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

torch/_C/_dynamo/eval_frame.pyi

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ skip_code_recursive_flag: SkipCodeRecursiveFlag
 cache_limit_hit_flag: CacheLimitHitFlag

 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
+def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
 def get_eval_frame_callback() -> DynamoCallback: ...
 def reset_code(code: types.CodeType) -> None: ...
 def unsupported(obj1: object, obj2: object) -> object: ...

torch/_dynamo/decorators.py

Lines changed: 8 additions & 2 deletions
@@ -97,11 +97,17 @@ class set_stance(_DecoratorContextManager):

     _dynamo_forbidden = True

-    def __init__(self, stance: str, force_backend=None) -> None:
+    def __init__(
+        self,
+        stance: str = "default",
+        *,
+        skip_guard_eval_unsafe: bool = False,
+        force_backend=None,
+    ) -> None:
         if force_backend is not None and stance != "default":
             raise RuntimeError("non-default stance cannot have force_backend set")

-        self.stance = DynamoStance(stance, force_backend)
+        self.stance = DynamoStance(stance, skip_guard_eval_unsafe, force_backend)
         self.prev = _set_stance(self.stance)

     def __call__(self, fn):
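Since `stance` now defaults to `"default"` and the new flag is keyword-only, it composes with the existing calling conventions of `set_stance`. A small sketch of the context-manager and plain-call forms (the compiled function here is illustrative):

    import torch

    @torch.compile(backend="eager")
    def opt_fn(x):
        return x + 1

    x = torch.randn(4)
    opt_fn(x)  # warm-up

    # As a context manager, as the tests do:
    with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
        opt_fn(x)

    # As a plain call, paired with an explicit reset afterwards:
    torch.compiler.set_stance("default", skip_guard_eval_unsafe=True)
    opt_fn(x)
    torch.compiler.set_stance("default")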

torch/_dynamo/eval_frame.py

Lines changed: 18 additions & 0 deletions
@@ -52,6 +52,7 @@
 from torch._C._dynamo.eval_frame import (  # noqa: F401
     reset_code,
     set_guard_error_hook,
+    set_skip_guard_eval_unsafe,
     skip_code,
     unsupported,
 )

@@ -122,6 +123,7 @@ def _maybe_set_eval_frame(callback: DynamoCallback):
 @dataclass
 class DynamoStance:
     stance: str = "default"
+    skip_guard_eval_unsafe: bool = False
     backend: Union[str, Callable[..., Any], None] = None


@@ -183,6 +185,10 @@ def fail_callback(*args, **kwargs):
     raise RuntimeError(f"invalid torch.compile stance '{_stance}'")


+def _is_skip_guard_eval_unsafe_stance():
+    return _stance.skip_guard_eval_unsafe
+
+
 def _reset_guarded_backend_cache():
     global cached_backends
     for backend in cached_backends.values():

@@ -446,10 +452,14 @@ def __enter__(self):
         )
         self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
         self.prior = _maybe_set_eval_frame(_callback_from_stance(self.callback))
+        self.prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
+            _is_skip_guard_eval_unsafe_stance()
+        )

     def __exit__(self, exc_type, exc_val, exc_tb):
         assert self.prior is not unset
         _maybe_set_eval_frame(self.prior)
+        set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe)
         self.prior = unset
         for cleanup in self.cleanup_fns:
             cleanup()

@@ -541,6 +551,9 @@ def _fn(*args, **kwargs):

             cleanups = [enter() for enter in self.enter_exit_hooks]
             prior = _maybe_set_eval_frame(_callback_from_stance(callback))
+            prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
+                _is_skip_guard_eval_unsafe_stance()
+            )

             # Ensure that if an assertion occurs after graph pushes
             # something onto the DynamicLayerStack then we pop it off (the

@@ -561,6 +574,7 @@ def _fn(*args, **kwargs):
                 )

                 _maybe_set_eval_frame(prior)
+                set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)
                 for cleanup in cleanups:
                     cleanup()

@@ -717,10 +731,14 @@ def __call__(self, fn):
         @functools.wraps(fn)
         def _fn(*args, **kwargs):
             prior = _maybe_set_eval_frame(_callback_from_stance(self.callback))
+            prior_skip_guard_eval_unsafe = set_skip_guard_eval_unsafe(
+                _is_skip_guard_eval_unsafe_stance()
+            )
             try:
                 return fn(*args, **kwargs)
             finally:
                 _maybe_set_eval_frame(prior)
+                set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe)

         _fn._torchdynamo_disable = True  # type: ignore[attr-defined]

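Every entry point above follows the same save/restore idiom: `set_skip_guard_eval_unsafe` returns the previous value, which is restored on exit so nested stances behave correctly. A simplified, self-contained sketch of that pattern (the module-level `_flag` is a stand-in for the real interpreter-side state, not part of the commit):

    _flag = False  # stand-in for the interpreter-side flag

    def set_skip_guard_eval_unsafe_sketch(value: bool) -> bool:
        # Set the flag and return the previous value so callers can restore it.
        global _flag
        prior, _flag = _flag, value
        return prior

    def call_under_stance(fn, stance_value, *args):
        prior = set_skip_guard_eval_unsafe_sketch(stance_value)
        try:
            return fn(*args)
        finally:
            # Restoring the prior value (rather than hard-coding False)
            # keeps nested uses well-behaved.
            set_skip_guard_eval_unsafe_sketch(prior)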
torch/_dynamo/guards.py

Lines changed: 13 additions & 0 deletions
@@ -253,6 +253,16 @@ def finalize(self):
     def populate_diff_guard_manager(self):
         self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources)

+        # Ensure that the C++ side points to the updated diff guard manager.
+        # When a new GuardManagerWrapper is created, it does not have a
+        # cache_entry attribute, so it relies on the CacheEntry constructor to
+        # set the diff_guard_root in C++. But once it is saved in the Dynamo
+        # cache, the C++ side adds a cache_entry attribute. On recompiles, this
+        # cache_entry is visible, so we update the C++ side to point to the
+        # updated guard manager.
+        if self.cache_entry:
+            self.cache_entry.update_diff_guard_root_manager()
+
     def clone_with_chosen_sources(self, chosen_sources):
         def filter_fn(node_mgr):
             return node_mgr.get_source() in chosen_sources

@@ -2205,6 +2215,9 @@ def __init__(self, reason):
         super().__init__()
         self.invalidation_reason = reason

+    def populate_diff_guard_manager(self):
+        self.diff_guard_root = None
+

 # NB: Naively, you'd expect this to only be a function that produces
 # the callable that constitutes the guard. However, there is some
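`populate_diff_guard_manager` rebuilds a pruned guard tree that keeps only the sources that have actually differed across recompiles, so steady-state checking touches far fewer guards. A toy sketch of that pruning idea (`GuardNode` is a hypothetical stand-in for the real guard manager classes):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class GuardNode:
        source: str
        children: list = field(default_factory=list)

    def clone_with_chosen_sources(node: GuardNode, chosen: set) -> Optional[GuardNode]:
        # Keep a node only if its source was chosen; prune everything else.
        if node.source not in chosen:
            return None
        kept = []
        for child in node.children:
            cloned = clone_with_chosen_sources(child, chosen)
            if cloned is not None:
                kept.append(cloned)
        return GuardNode(node.source, kept)

The pruned root plays the role of `diff_guard_root`: a much smaller tree that re-checks only the conditions known to distinguish cache entries.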

torch/compiler/__init__.py

Lines changed: 23 additions & 2 deletions
@@ -230,7 +230,9 @@ def disable(fn=None, recursive=True):
     return torch._dynamo.disable(fn, recursive)


-def set_stance(stance: str, force_backend=None):
+def set_stance(
+    stance: str = "default", *, skip_guard_eval_unsafe=False, force_backend=None
+):
     """
     Set the current stance of the compiler.
     Can be used as a function, context manager, or decorator.

@@ -270,12 +272,31 @@ def bar():
           If there is cached compiled code valid for the input, it will still be used.
         - "fail_on_recompile": Raise an error when recompiling a function.

+        skip_guard_eval_unsafe: A flag to run only differentiating guards.
+          CAUTION - This flag is unsafe and should only be used if your setup
+          meets the following conditions.
+
+          torch.compile uses a guard system to support recompilations and
+          choose which compiled artifact to run at runtime. These guards,
+          though efficient, add some overhead, which may impact performance in
+          scenarios where you need to optimize for minimal guard processing
+          time. This API enables you to disable guard evaluation, assuming
+          that you have warmed up the compiled model with a sufficient variety
+          of inputs. This assumption means that, after the warmup phase, no
+          further recompilations will be necessary. If this assumption fails,
+          there is a risk of silently producing incorrect results (hence the
+          term "unsafe" in the API name).
+
         force_backend: If `stance` is "default", this argument can be used to force `torch.compile`
             to use a specific backend. Otherwise, an error is raised.
     """
     import torch._dynamo

-    return torch._dynamo.set_stance(stance, force_backend=force_backend)
+    return torch._dynamo.set_stance(
+        stance,
+        skip_guard_eval_unsafe=skip_guard_eval_unsafe,
+        force_backend=force_backend,
+    )


 # forbid in graph
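Not every violation of the warm-up assumption is silent, though: as `test_fail_on_tensor_shape_change` above shows, a tensor shape change that would force a recompile is detected under the stance and raises instead of returning a wrong result. A condensed version of that scenario:

    import torch

    opt_fn = torch.compile(lambda d: d["x"] + 1, backend="eager")
    opt_fn({"x": torch.randn(4)})  # warm up with a 1-D input only

    with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
        # A 2-D input needs a recompile, which the stance forbids:
        # RuntimeError: Recompilation triggered with skip_guard_eval_unsafe stance
        opt_fn({"x": torch.randn(4, 4)})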

torch/csrc/dynamo/cache_entry.cpp

Lines changed: 7 additions & 0 deletions
@@ -18,6 +18,8 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
   }
   this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
       this->guard_manager.attr("root"));
+  this->diff_guard_root_mgr = torch::dynamo::convert_to_root_guard_manager(
+      this->guard_manager.attr("diff_guard_root"));
 }

 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(

@@ -52,6 +54,11 @@ void CacheEntry::invalidate(py::object deleted_guard_manager) {
   this->trace_annotation = "Invalidated";
 }

+void CacheEntry::update_diff_guard_root_manager() {
+  this->diff_guard_root_mgr = torch::dynamo::convert_to_root_guard_manager(
+      this->guard_manager.attr("diff_guard_root"));
+}
+
 PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
   return (PyCodeObject*)e->code.ptr();
 }

torch/csrc/dynamo/cache_entry.h

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
   py::object compile_id;
   // root guard manager if exists
   void* root_mgr{nullptr};
+  // diff guard root guard manager if exists
+  void* diff_guard_root_mgr{nullptr};
   // backend used to create this cache entry
   PyObject* backend{nullptr};
   // Reference to owning ExtraState

@@ -70,6 +72,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
   py::object next();

   void invalidate(py::object deleted_guard_manager);
+  // Called from the Python side to update the diff guard root manager
+  void update_diff_guard_root_manager();
 } CacheEntry;
 C10_DIAGNOSTIC_POP()
 C10_DIAGNOSTIC_POP()
