
Commit 6a3d5f1

ydwu4 authored and pytorchmergebot committed
[HigherOrderOp] Remove _deprecated_global_ns from cond (#104380)
Remove _deprecated_global_ns from cond, following #104105. In the HigherOrderOperator constructor, we change the __module__ attribute of instances from torch.ops to torch.ops.higher_order when self.namespace is "higher_order". For subclasses (e.g., a customized higher-order operator), we leave their __module__ unchanged. Will import this PR to fix internal tests.

Pull Request resolved: #104380
Approved by: https://github.com/zhxchen17, https://github.com/zou3519
1 parent d5a83a5 commit 6a3d5f1
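
As a rough illustration of the constructor behavior described above (a minimal sketch; the op names "my_op" and "my_custom_op" are hypothetical, and the exact __module__ strings are assumptions based on the description):

import torch
from torch._ops import HigherOrderOperator

# A plain HigherOrderOperator lands in the "higher_order" namespace, so the
# constructor rewrites its __module__ to end in ".higher_order".
my_op = HigherOrderOperator("my_op")  # hypothetical op name
print(my_op.namespace)   # "higher_order"
print(my_op.__module__)  # expected: "torch._ops.higher_order"

# An instance of a subclass keeps whatever __module__ its class already has.
class MyCustomOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("my_custom_op")  # hypothetical op name

custom = MyCustomOp()
print(custom.__module__)  # unchanged, e.g. "__main__" here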

File tree

4 files changed: +22 −7 lines changed

functorch/experimental/_cond.py

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ class UnsupportedAliasMutationException(RuntimeError):
 We're going to define a `cond` operation.
 In order to do this, we need implementations for each of the dispatch keys.
 """
-cond = HigherOrderOperator("cond", _deprecated_global_ns=True)
+cond = HigherOrderOperator("cond")
 
 def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
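
For context, here is a minimal usage sketch of the cond operator defined above (assuming the functorch.experimental.control_flow API of this era, where cond takes a predicate, two branch functions, and a list of operands):

import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
# Runs true_fn or false_fn depending on the boolean predicate tensor.
out = cond(torch.tensor(True), true_fn, false_fn, [x])
print(torch.allclose(out, x.sin()))  # True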

test/dynamo/test_export.py

Lines changed: 1 addition & 1 deletion

@@ -3209,7 +3209,7 @@ def forward(self, pred, x):
 ones_3 = torch.ones(6, 4)
 cond_true_0 = self.cond_true_0
 cond_false_0 = self.cond_false_0
-cond = torch.ops.cond(arg0, cond_true_0, cond_false_0, [arg1, ones, ones_1, ones_3, ones, ones_1, ones_2]); arg0 = cond_true_0 = cond_false_0 = arg1 = ones = ones_1 = ones_3 = ones_2 = None
+cond = torch.ops.higher_order.cond(arg0, cond_true_0, cond_false_0, [arg1, ones, ones_1, ones_3, ones, ones_1, ones_2]); arg0 = cond_true_0 = cond_false_0 = arg1 = ones = ones_1 = ones_3 = ones_2 = None
 return pytree.tree_unflatten([cond], self._out_spec)""",  # noqa: B950,E122
 )
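
A sketch of how the new qualified name surfaces when exporting through dynamo, mirroring the test above (the function f is illustrative, and the torch._dynamo.export call signature is assumed from this era):

import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def f(pred, x):
    def true_fn(x):
        return x + 1
    def false_fn(x):
        return x - 1
    return cond(pred, true_fn, false_fn, [x])

gm, _ = torch._dynamo.export(f, torch.tensor(True), torch.randn(3))
# The graph should now call torch.ops.higher_order.cond(...) instead of
# the old global torch.ops.cond(...).
print(gm.code)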

test/functorch/test_control_flow.py

Lines changed: 4 additions & 4 deletions

@@ -613,11 +613,11 @@ def f(x, pred, pred2):
 def forward(self, x_1, pred_1, pred2_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
-    conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
+    conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
     pred_1 = true_graph_0 = false_graph_0 = None
     true_graph_1 = self.true_graph_1
     false_graph_1 = self.false_graph_1
-    conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
+    conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
     pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
     add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
     return add

@@ -780,11 +780,11 @@ def f(x, pred, pred2):
 def forward(self, x_1, pred_1, pred2_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
-    conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
+    conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
     pred_1 = true_graph_0 = false_graph_0 = None
     true_graph_1 = self.true_graph_1
     false_graph_1 = self.false_graph_1
-    conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
+    conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
     pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
     add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
     return add
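
Similarly, tracing with make_fx (as these tests do) should now render cond calls under the higher_order namespace; a minimal sketch, with an illustrative function f:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond

def f(x, pred):
    return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])

gm = make_fx(f)(torch.randn(4), torch.tensor(True))
# Expect a line like: conditional = torch.ops.higher_order.cond(...)
print(gm.code)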

torch/_ops.py

Lines changed: 16 additions & 1 deletion

@@ -214,6 +214,14 @@ def __init__(self, name, *, _deprecated_global_ns=False):
         else:
             _higher_order_ops[name] = self
             self._ns = "higher_order"
+
+            # For a normal HigherOrderOperator instance, we will change its __module__ from torch._ops to
+            # torch._ops.higher_order.
+            # For an instance of a subclass of HigherOrderOperator (e.g., a customized higher-order op),
+            # the __module__ attribute is kept unchanged.
+            if self.__class__ is HigherOrderOperator:
+                self_name_space = "." + self.namespace if self.namespace else ""
+                self.__module__ = self.__module__ + self_name_space
         self.non_fallthrough_keys = torch._C._dispatch_keyset_full()

     @property

@@ -768,7 +776,14 @@ def __init__(self, name, ops):
         self._ops = ops

     def __getattr__(self, name):
-        return self._ops[name]
+        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
+        op = self._ops.get(name, None)
+        if op is None:
+            raise AttributeError(
+                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
+            )
+        setattr(self, name, op)
+        return op


 class _Ops(types.ModuleType):
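
The __getattr__ change above follows a common lazy-resolution-plus-caching idiom; here is a self-contained sketch of the pattern (the OpNamespace class is illustrative, not the real torch internals):

class OpNamespace:
    """Resolves ops lazily on first attribute access, then caches them."""

    def __init__(self, name, ops):
        self.name = name
        self._ops = ops

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails, i.e. on first access.
        op = self._ops.get(attr, None)
        if op is None:
            raise AttributeError(
                f"'OpNamespace' '{self.name}' object has no attribute '{attr}'"
            )
        # Cache on the instance so later lookups bypass __getattr__ entirely.
        setattr(self, attr, op)
        return op

ns = OpNamespace("demo", {"double": lambda x: 2 * x})
first = ns.double          # resolved via __getattr__, then cached
assert ns.double is first  # now served straight from the instance __dict__
print(ns.double(21))       # 42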
