Skip to content

Commit 280df5d

Browse files
zou3519pytorchmergebot
authored andcommitted
[HigherOrderOp] Remove _deprecated_global_ns from some ops (#104105)
The remaining ops after this PR are: - cond - map - anything that is out of tree. These are a bit more difficult to remove. Test Plan: - existing tests Pull Request resolved: #104105 Approved by: https://github.com/ydwu4
1 parent de7b6e5 commit 280df5d

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

torch/_higher_order_ops/wrap.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Used for testing the HigherOrderOperator mechanism
77
class Wrap(HigherOrderOperator):
88
def __init__(self):
9-
super().__init__("wrap", _deprecated_global_ns=True)
9+
super().__init__("wrap")
1010

1111
def __call__(self, func, *args):
1212
# Dynamo already traces the body of HigherOrderOp beforehand when it
@@ -41,7 +41,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
4141
partitioners. See TagActivationCheckpoint for more information.
4242
"""
4343
def __init__(self):
44-
super().__init__("wrap_activation_checkpoint", _deprecated_global_ns=True)
44+
super().__init__("wrap_activation_checkpoint")
4545

4646
def __call__(self, function, *args, **kwargs):
4747
# use_reentrant is set to False because this op is going to be traced.
@@ -72,7 +72,7 @@ class TagActivationCheckpoint(HigherOrderOperator):
7272
"""
7373

7474
def __init__(self):
75-
super().__init__("tag_activation_checkpoint", _deprecated_global_ns=True)
75+
super().__init__("tag_activation_checkpoint")
7676

7777
def tag_nodes(self, gmod):
7878
# TODO - This needs major investigation. Currently, we are tagging all

torch/_prims/rng_prims.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,7 @@ def get_device(args, kwargs):
164164

165165

166166
def register_run_and_save_rng_state_op():
167-
run_and_save_rng_state = HigherOrderOperator(
168-
"run_and_save_rng_state", _deprecated_global_ns=True
169-
)
167+
run_and_save_rng_state = HigherOrderOperator("run_and_save_rng_state")
170168

171169
run_and_save_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
172170
run_and_save_rng_state.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined]
@@ -220,9 +218,7 @@ def impl_proxy_dispatch_mode(op, *args, **kwargs):
220218

221219

222220
def register_run_with_rng_state_op():
223-
run_with_rng_state = HigherOrderOperator(
224-
"run_with_rng_state", _deprecated_global_ns=True
225-
)
221+
run_with_rng_state = HigherOrderOperator("run_with_rng_state")
226222

227223
run_with_rng_state.fallthrough(DispatchKey.ADInplaceOrView)
228224
run_with_rng_state.fallthrough(DispatchKey.PythonTLSSnapshot) # type: ignore[attr-defined]

0 commit comments

Comments
 (0)