
Commit d6a958b

[dtensor][5/N] change to a better/safer op registration
This PR changes op registration to a better, safer mechanism: rules are now registered against the OpOverload directly instead of an op key string. This has several benefits:

1. We ensure that each registration targets the correct op, so a wrong registration fails immediately (switching to direct OpOverload registration already uncovered and fixed several registration errors in this PR).
2. If an overload name is changed or deleted, we find out immediately at import time rather than silently at dispatch time, which is safer.
3. It keeps the op registration mechanism consistent with the other tensor subclasses in PyTorch.

ghstack-source-id: 3c4d812
Pull Request resolved: #90735
1 parent 9a8632b commit d6a958b
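
To make benefits (1) and (2) concrete: torch.ops.aten.<op>.<overload> resolves to an OpOverload object, and a misspelled overload name raises immediately instead of silently registering a rule that never matches. A minimal, standalone illustration (not part of this PR; exact types and error messages may vary by PyTorch version):

    import torch

    aten = torch.ops.aten

    op = aten.sum.dim_IntList        # a real OpOverload object
    print(type(op))                  # e.g. <class 'torch._ops.OpOverload'>

    try:
        aten.sum.dim_intlist         # wrong overload name
    except AttributeError as err:
        # Fails as soon as this line runs, i.e. at module import time for
        # decorator-style registrations like the ones in this PR.
        print("bad overload rejected:", err)

    # With the old string keys, a typo such as "aten.sum.dim_intlist" would
    # typically be accepted and the rule would simply never be looked up.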


7 files changed: +419 −420 lines changed


torch/distributed/_tensor/ops/math_ops.py

Lines changed: 23 additions & 33 deletions
@@ -1,6 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from typing import cast, Optional, Sequence
 
+import torch
+
 from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
 from torch.distributed._tensor.ops.common_rules import pointwise_rule, reduction_rule
 from torch.distributed._tensor.ops.utils import (
@@ -11,6 +13,9 @@
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 
+aten = torch.ops.aten
+
+
 def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[Sequence[int]]:
     if dims_arg is None:
         return None
@@ -22,11 +27,17 @@ def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[Sequence[int]]:
     return dims
 
 
-@register_prop_rule("aten.all.default")
+@register_prop_rule(aten.all.default)
 def default_reduction_rule(op_schema: OpSchema) -> OutputSharding:
     return reduction_rule(op_schema, reduction_linear=True)
 
 
+@register_prop_rule(
+    [
+        aten.sum.default,
+        aten.sum.dim_IntList,
+    ]
+)
 def sum_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
@@ -40,15 +51,7 @@ def sum_rule(op_schema: OpSchema) -> OutputSharding:
     )
 
 
-sum_ops = [
-    "aten.sum.default",
-    "aten.sum.dim_IntList",
-]
-for sum_op in sum_ops:
-    register_prop_rule(sum_op)(sum_rule)
-
-
-@register_prop_rule("aten._softmax.default")
+@register_prop_rule(aten._softmax.default)
 def softmax_rule(op_schema: OpSchema) -> OutputSharding:
     input_spec, softmax_dim, _ = op_schema.args_schema
     input_spec = cast(DTensorSpec, input_spec)
@@ -59,7 +62,7 @@ def softmax_rule(op_schema: OpSchema) -> OutputSharding:
     return OutputSharding(input_spec)
 
 
-@register_prop_rule("aten._softmax_backward_data.default")
+@register_prop_rule(aten._softmax_backward_data.default)
 def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
     grad_out_spec, out_spec, softmax_dim, _ = op_schema.args_schema
     grad_out_spec = cast(DTensorSpec, grad_out_spec)
@@ -74,6 +77,7 @@ def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
     return pointwise_rule(op_schema)
 
 
+@register_prop_rule([aten.mean.default, aten.mean.dim, aten.mean.out])
 def mean_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
@@ -88,16 +92,13 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
     )
 
 
-mean_ops = [
-    "aten.mean.default",
-    "aten.mean.dim",
-    "aten.mean.out",
-]
-
-for mean_op in mean_ops:
-    register_prop_rule(mean_op)(mean_rule)
-
-
+@register_prop_rule(
+    [
+        aten.var.default,
+        aten.var.dim,
+        aten.var.out,
+    ]
+)
 def var_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
@@ -114,18 +115,7 @@ def var_rule(op_schema: OpSchema) -> OutputSharding:
     )
 
 
-var_ops = [
-    "aten.var.default",
-    "aten.var.dim",
-    "aten.var.out",
-]
-
-for var_op in var_ops:
-    register_prop_rule(var_op)(var_rule)
-
-
-@register_prop_rule("aten.var.correction")
-@register_prop_rule("aten.var.correction_out")
+@register_prop_rule([aten.var.correction, aten.var.correction_out])
 def var_correction_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
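
The list arguments used above (e.g. [aten.sum.default, aten.sum.dim_IntList]) replace the old for-loops over string keys. As a rough sketch of how a decorator can accept either one OpOverload or a list of them, here is a hypothetical register_rule helper; it is not the actual torch.distributed._tensor.ops.utils.register_prop_rule implementation, only an illustration of the mechanism:

    from typing import Callable, Dict, List, Union

    import torch
    from torch._ops import OpOverload

    aten = torch.ops.aten

    # Illustrative registry keyed by OpOverload objects rather than strings.
    _rule_table: Dict[OpOverload, Callable] = {}

    def register_rule(op: Union[OpOverload, List[OpOverload]]) -> Callable:
        """Register a propagation rule for one or more OpOverloads."""
        def wrapper(fn: Callable) -> Callable:
            overloads = op if isinstance(op, list) else [op]
            for overload in overloads:
                _rule_table[overload] = fn
            return fn
        return wrapper

    @register_rule([aten.sum.default, aten.sum.dim_IntList])
    def sum_rule(op_schema):
        ...  # sharding propagation logic elided

Looking up a rule later is then a plain dictionary access on the dispatched OpOverload, with no string formatting or parsing involved.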

torch/distributed/_tensor/ops/matrix_ops.py

Lines changed: 10 additions & 5 deletions
@@ -1,9 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor
+
+import torch
+
 from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
 from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
 from torch.distributed._tensor.ops.utils import register_prop_rule
 
+aten = torch.ops.aten
+
 
 def _update_schema_suggestion_for_addmm(
     output_sharding: OutputSharding,
@@ -41,12 +46,12 @@ def _update_schema_suggestion_for_addmm(
     return output_sharding
 
 
-@register_prop_rule("aten.mm.default")
+@register_prop_rule(aten.mm.default)
 def mm_rules(op_schema: OpSchema) -> OutputSharding:
     return einop_rule("mk,kn->mn", op_schema, linearity=False)
 
 
-@register_prop_rule("aten.addmm.default")
+@register_prop_rule(aten.addmm.default)
 def addmm_rules(op_schema: OpSchema) -> OutputSharding:
     input_spec, mat1_spec, mat2_spec = op_schema.args_spec
     mm_out_sharding = mm_rules(
@@ -80,17 +85,17 @@ def addmm_rules(op_schema: OpSchema) -> OutputSharding:
     return output_sharding
 
 
-@register_prop_rule("aten.t.default")
+@register_prop_rule(aten.t.default)
 def transpose_rule(op_schema: OpSchema) -> OutputSharding:
     return einop_rule("ij->ji", op_schema, linearity=True)
 
 
-@register_prop_rule("aten.bmm.default")
+@register_prop_rule(aten.bmm.default)
 def bmm_rules(op_schema: OpSchema) -> OutputSharding:
     return einop_rule("bmk,bkn->bmn", op_schema, linearity=False)
 
 
-@register_prop_rule("aten.baddbmm.default")
+@register_prop_rule(aten.baddbmm.default)
 def baddbmm_rules(op_schema: OpSchema) -> OutputSharding:
     input_spec, mat1_spec, mat2_spec = op_schema.args_spec
     bmm_output_sharding = bmm_rules(
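
Benefit (3) in the commit message refers to the fact that tensor-subclass dispatch (__torch_dispatch__) hands handlers the OpOverload object itself, so a table keyed by OpOverloads can be indexed directly by the dispatched op. A small, self-contained sketch of that lookup pattern (illustrative names only, not DTensor's actual tables):

    import torch

    aten = torch.ops.aten

    # Rules keyed by OpOverload objects, mirroring the registrations above.
    prop_rules = {
        aten.mm.default: lambda op_schema: "mm sharding rule",
        aten.t.default: lambda op_schema: "transpose sharding rule",
    }

    # In a __torch_dispatch__ handler the incoming func is an OpOverload,
    # so the lookup is a direct dictionary access with no string parsing.
    func = aten.mm.default
    print(prop_rules[func](None))    # -> "mm sharding rule"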
