functorch/experimental/_cond.py (1 addition, 1 deletion)
@@ -32,7 +32,7 @@ class UnsupportedAliasMutationException(RuntimeError):
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond = HigherOrderOperator("cond", _deprecated_global_ns=True)
cond = HigherOrderOperator("cond")

def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
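Note (context, not part of the diff): with `_deprecated_global_ns=True` removed, `cond` is registered only under the `higher_order` namespace, so it resolves as `torch.ops.higher_order.cond` rather than the old `torch.ops.cond`, which is exactly what the updated test expectations below check. A minimal sketch, assuming a PyTorch build that includes this change and that importing `functorch.experimental.control_flow` is what registers the op:

```python
import torch
from functorch.experimental import control_flow  # noqa: F401  (importing registers the "cond" op)

# After this PR the op lives only in the higher_order namespace:
cond_op = torch.ops.higher_order.cond
print(cond_op)             # the HigherOrderOperator named "cond"
print(cond_op.__module__)  # "torch._ops.higher_order" (see the torch/_ops.py change below)

# The old global-namespace spelling, torch.ops.cond, is no longer registered.
```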
test/dynamo/test_export.py (1 addition, 1 deletion)
@@ -3209,7 +3209,7 @@ def forward(self, pred, x):
ones_3 = torch.ones(6, 4)
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.cond(arg0, cond_true_0, cond_false_0, [arg1, ones, ones_1, ones_3, ones, ones_1, ones_2]); arg0 = cond_true_0 = cond_false_0 = arg1 = ones = ones_1 = ones_3 = ones_2 = None
cond = torch.ops.higher_order.cond(arg0, cond_true_0, cond_false_0, [arg1, ones, ones_1, ones_3, ones, ones_1, ones_2]); arg0 = cond_true_0 = cond_false_0 = arg1 = ones = ones_1 = ones_3 = ones_2 = None
return pytree.tree_unflatten([cond], self._out_spec)""", # noqa: B950,E122
)

test/functorch/test_control_flow.py (4 additions, 4 deletions)
@@ -613,11 +613,11 @@ def f(x, pred, pred2):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
pred_1 = true_graph_0 = false_graph_0 = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
return add
@@ -780,11 +780,11 @@ def f(x, pred, pred2):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
conditional = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);
pred_1 = true_graph_0 = false_graph_0 = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
conditional_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1, x_1]);
pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
return add
torch/_ops.py (16 additions, 1 deletion)
@@ -214,6 +214,14 @@ def __init__(self, name, *, _deprecated_global_ns=False):
else:
_higher_order_ops[name] = self
self._ns = "higher_order"

# For a normal HigherOrderOperator instance, we will change its __module__ from torch._ops to
# torch._ops.higher_order.
# For an instance of a subclass of HigherOrderOperator (e.g. a customized higher order op),
# the __module__ attribute will be kept unchanged.
if self.__class__ is HigherOrderOperator:
self_name_space = "." + self.namespace if self.namespace else ""
self.__module__ = self.__module__ + self_name_space
self.non_fallthrough_keys = torch._C._dispatch_keyset_full()

@property
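For illustration (not part of the diff): the effect of the new `__module__` logic above is that a plain `HigherOrderOperator` reports `torch._ops.higher_order` as its module, while subclasses keep the module they were defined in. A minimal sketch, using hypothetical op names and assuming this version of torch/_ops.py:

```python
from torch._ops import HigherOrderOperator

# Plain instance: __init__ rewrites self.__module__ from "torch._ops"
# to "torch._ops.higher_order".
demo_op = HigherOrderOperator("demo_op")        # hypothetical op name
print(demo_op.__module__)                       # -> "torch._ops.higher_order"

# Subclass instance: the `self.__class__ is HigherOrderOperator` check fails,
# so __module__ falls back to the class's defining module (e.g. "__main__").
class MyCustomOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("my_custom_op")        # hypothetical op name

print(MyCustomOp().__module__)
```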
@@ -768,7 +776,14 @@ def __init__(self, name, ops):
self._ops = ops

def __getattr__(self, name):
return self._ops[name]
# Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
op = self._ops.get(name, None)
if op is None:
raise AttributeError(
f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
)
setattr(self, name, op)
return op
Comment on lines +779 to +786

Contributor:
Is this necessary for this PR? Or is it just an optimization?

Contributor (Author):
Yeah, it's just an optimization. I was looking at how other operators' __module__ is set and found this logic, and the error message seems pretty good, so I "copied" it over.

Contributor:
Sounds good, thanks for explaining.

class _Ops(types.ModuleType):
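As the review thread above notes, the new `__getattr__` body is an optimization modeled on `_OpNamespace.__getattr__`: `setattr` stores the resolved op in the namespace object's `__dict__`, so subsequent lookups (e.g. repeated `torch.ops.higher_order.cond` accesses) are ordinary attribute hits and never re-enter `__getattr__`. A standalone sketch of the pattern (generic code, not the PyTorch implementation):

```python
class LazyNamespace:
    """Generic sketch of the caching pattern used in _PyOpNamespace.__getattr__."""

    def __init__(self, name, ops):
        self.name = name
        self._ops = ops  # mapping of op name -> op object

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, i.e. on the first access.
        op = self._ops.get(name, None)
        if op is None:
            raise AttributeError(
                f"'{type(self).__name__}' '{self.name}' object has no attribute '{name}'"
            )
        # Cache on the instance so the next access bypasses __getattr__ entirely.
        setattr(self, name, op)
        return op


ns = LazyNamespace("demo", {"cond": object()})
first = ns.cond   # resolved through __getattr__, then cached
second = ns.cond  # plain instance-dict lookup, __getattr__ not called
assert first is second
```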