
Commit c37c516

wanchaol authored and pytorchmergebot committed

[dtensor] ufmt test/distributed/_tensor (#89968)

cmd: `ufmt format test/distributed/_tensor`

Pull Request resolved: #89968
Approved by: https://github.com/fduwjj

1 parent bf23e0b commit c37c516

12 files changed, +217 -445 lines changed
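The change is mechanical: ufmt combines µsort (import sorting) with black (code formatting), which is why the diffs below only reorder imports and collapse unnecessarily wrapped calls. As a minimal sketch, assuming the `ufmt` CLI is installed and the working directory is the repository root, the same pass from the commit message could be re-run from Python like this:

# Minimal sketch: re-run the formatting command quoted in the commit message.
# Assumes the `ufmt` CLI is installed and we are at the pytorch repo root.
import subprocess

subprocess.run(["ufmt", "format", "test/distributed/_tensor"], check=True)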

test/distributed/_tensor/test_api.py

Lines changed: 18 additions & 36 deletions
@@ -3,26 +3,26 @@
 
 import torch
 import torch.nn as nn
-from torch.testing._internal.common_utils import run_tests
-from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms
 from torch.distributed._tensor import (
-    distribute_tensor,
-    distribute_module,
     DeviceMesh,
+    distribute_module,
+    distribute_tensor,
     DTensor,
-    Shard,
     Replicate,
+    Shard,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
 )
 
 
 class MyModel(nn.Module):
     def __init__(self, n_features, n_layers, device):
         super().__init__()
         self.seq = nn.Sequential(
-            *[
-                nn.Linear(n_features, n_features, device=device)
-                for _ in range(n_layers)
-            ]
+            *[nn.Linear(n_features, n_features, device=device) for _ in range(n_layers)]
         )
 
     def forward(self, x):
@@ -50,12 +50,8 @@ def test_distribute_tensor(self):
             tensor_to_shard = torch.randn(
                 3 * self.world_size, 3, requires_grad=requires_grad
             )
-            dist_tensor = distribute_tensor(
-                tensor_to_shard, device_mesh, shard_spec
-            )
-            self.assertEqual(
-                dist_tensor.size(), torch.Size([3 * self.world_size, 3])
-            )
+            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
+            self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))
             local_tensor = dist_tensor.to_local()
             self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
             if requires_grad:
@@ -78,9 +74,7 @@ def test_distribute_tensor_errors(self):
         dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec)
 
         with self.assertRaisesRegex(ValueError, "to a different device mesh"):
-            new_mesh = DeviceMesh(
-                self.device_type, torch.arange(self.world_size)
-            )
+            new_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
             distribute_tensor(dtensor, new_mesh, [Shard(0)])
 
         with self.assertRaisesRegex(ValueError, "to a different placements"):
@@ -104,9 +98,7 @@ def test_distribute_tensor_uneven_sharding(self):
            splitted_tensor_list = tensor_to_shard.tensor_split(
                 self.world_size, dim=shard_dim
             )
-            dist_tensor = distribute_tensor(
-                tensor_to_shard, device_mesh, shard_spec
-            )
+            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
             self.assertEqual(dist_tensor.size(), torch.Size(input_size))
             local_tensor = dist_tensor.to_local()
             self.assertEqual(local_tensor, splitted_tensor_list[self.rank])
@@ -115,9 +107,7 @@ def test_distribute_tensor_uneven_sharding(self):
     def test_distribute_module(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
         # fully shard all linear modules on dim 0
-        module_to_shard = MyModel(
-            5 * self.world_size, 20, device=self.device_type
-        )
+        module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type)
         shard_spec = [Shard(0)]
 
         def shard_fn(name, module, device_mesh):
@@ -128,9 +118,7 @@ def shard_fn(name, module, device_mesh):
                     )
                     module.register_parameter(name, dist_param)
 
-        sharded_module = distribute_module(
-            module_to_shard, device_mesh, shard_fn
-        )
+        sharded_module = distribute_module(module_to_shard, device_mesh, shard_fn)
         for param in sharded_module.parameters():
             self.assertIsInstance(param, DTensor)
             self.assertEqual(param.placements, shard_spec)
@@ -162,21 +150,15 @@ def replicate_fn(name, module, device_mesh):
 
         # only shard part of module, and rest of module should be replicate
         def shard_fn(name, module, device_mesh):
-            if isinstance(module, nn.Linear) and (
-                name == "seq.0" or name == "seq.8"
-            ):
+            if isinstance(module, nn.Linear) and (name == "seq.0" or name == "seq.8"):
                 for name, param in module.named_parameters():
                     dist_param = torch.nn.Parameter(
                         distribute_tensor(param, device_mesh, shard_spec)
                     )
                     module.register_parameter(name, dist_param)
 
-        module_to_distribute = MyModel(
-            5 * self.world_size, 20, device=self.device_type
-        )
-        dist_module = distribute_module(
-            module_to_distribute, device_mesh, shard_fn
-        )
+        module_to_distribute = MyModel(5 * self.world_size, 20, device=self.device_type)
+        dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
         for name, param in dist_module.named_parameters():
             self.assertIsInstance(param, DTensor)
             if name.startswith("seq.0") or name.startswith("seq.8"):
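For readers skimming the reformatted test: the calls it exercises are the public DTensor entry points imported at the top of the file. A minimal, illustrative sketch of the same pattern, assuming a default process group is already initialized (e.g. launched with torchrun on 4 processes) and using a CPU mesh, might look like this:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Shard,
    distribute_module,
    distribute_tensor,
)

# Assumes torch.distributed is already initialized (e.g. `torchrun --nproc_per_node=4 ...`).
world_size = dist.get_world_size()
mesh = DeviceMesh("cpu", list(range(world_size)))
shard_spec = [Shard(0)]  # shard along the first tensor dimension

# Row-wise shard a global (3 * world_size, 3) tensor: each rank keeps a (3, 3) local piece.
global_tensor = torch.randn(3 * world_size, 3)
dtensor = distribute_tensor(global_tensor, mesh, shard_spec)
assert dtensor.to_local().size() == torch.Size([3, 3])

# Shard a module's parameters with a custom partition function, mirroring shard_fn above.
def shard_fn(name, module, device_mesh):
    if isinstance(module, nn.Linear):
        for pname, param in module.named_parameters():
            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, shard_spec))
            module.register_parameter(pname, dist_param)

linear = nn.Linear(4 * world_size, 4 * world_size)
sharded = distribute_module(linear, mesh, shard_fn)
assert all(isinstance(p, DTensor) for p in sharded.parameters())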

test/distributed/_tensor/test_common_rules.py

Lines changed: 36 additions & 100 deletions
@@ -2,21 +2,21 @@
 # Owner(s): ["oncall: distributed"]
 
 import torch
-from torch.testing._internal.common_utils import run_tests
-from torchgen.model import FunctionSchema
+from torch.distributed._tensor import DeviceMesh
 from torch.distributed._tensor.dispatch import OpSchema
 
 from torch.distributed._tensor.ops.common_rules import (
     einop_rule,
-    reduction_rule,
     pointwise_rule,
+    reduction_rule,
 )
 from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
-from torch.distributed._tensor import DeviceMesh
+from torchgen.model import FunctionSchema
 
 
 class CommonRulesTest(DTensorTestBase):
@@ -34,17 +34,11 @@ def test_einop_basic_propagation(self):
         # plain einsum, mm
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
-        func_schema = self.parse_schema(
-            "aten::mm(Tensor self, Tensor mat2) -> Tensor"
-        )
+        func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
         # propagate col-wise sharding
         mat1, mat2 = [-1, -1], [-1, 0]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
         output_sharding = einop_rule(
             "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -55,12 +49,8 @@ def test_einop_basic_propagation(self):
 
         # propagate row-wise sharding
         mat1, mat2 = [0, -1], [-1, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
         output_sharding = einop_rule(
             "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -71,12 +61,8 @@ def test_einop_basic_propagation(self):
 
         # generate partial
         mat1, mat2 = [-1, 0], [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
         output_sharding = einop_rule(
             "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -94,9 +80,7 @@ def test_einop_pointwise_propagation(self):
         )
         # addition
         mat1 = [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 8]))
         output_sharding = einop_rule(
             "ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat1_spec), {})
         )
@@ -110,9 +94,7 @@ def test_einop_pointwise_propagation(self):
         mat1_spec = DTensorSpec.from_dim_map(
             mesh, mat1, [], shape=torch.Size([8, 4, 2])
         )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, [-1], [], shape=torch.Size([2])
-        )
+        mat2_spec = DTensorSpec.from_dim_map(mesh, [-1], [], shape=torch.Size([2]))
         output_sharding = einop_rule(
             "ijk,k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -144,17 +126,11 @@ def test_einop_merge_sharding(self):
         )
         mesh = DeviceMesh(self.device_type, mesh_shape)
 
-        func_schema = self.parse_schema(
-            "aten::mm(Tensor self, Tensor mat2) -> Tensor"
-        )
+        func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
 
         mat1, mat2 = [0, -1], [-1, 1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
         output_sharding = einop_rule(
             "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -175,12 +151,8 @@ def test_einop_linearity(self):
         )
 
         mat1, mat2 = [0, -1], [-1, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [1], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
         # if not turn on linearity, partial sum is not eligible to propagate, we return
         # suggestion to reshard inputs with no partial sum (i.e. all_reduce one input)
         output_sharding = einop_rule(
@@ -212,12 +184,8 @@ def test_einop_linearity(self):
             "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
         )
         mat1, mat2 = [0, -1], [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [1], shape=torch.Size([8, 6])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([8, 6])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 6]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([8, 6]))
 
         output_sharding = einop_rule(
             "ij,ij->ij",
@@ -237,16 +205,10 @@ def test_einop_multi_sharding_on_mesh_dim(self):
         mesh_shape = torch.arange(self.world_size)
         mesh = DeviceMesh(self.device_type, mesh_shape)
 
-        func_schema = self.parse_schema(
-            "aten::mm(Tensor self, Tensor mat2) -> Tensor"
-        )
+        func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
         mat1, mat2 = [0, -1], [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 12])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([12, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 12]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([12, 4]))
         output_sharding = einop_rule(
             "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -271,19 +233,11 @@ def test_einop_errors(self):
             "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
         )
         mat1, mat2 = [0, -1], [1, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([8, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([8, 4]))
 
-        with self.assertRaisesRegex(
-            RuntimeError, "sharded two different ways:"
-        ):
-            einop_rule(
-                "ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
-            )
+        with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
+            einop_rule("ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {}))
 
     @with_comms
     def test_pointwise_rules_broadcasting(self):
@@ -293,12 +247,8 @@ def test_pointwise_rules_broadcasting(self):
             "where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"
         )
         inp1, inp2, inp3 = [0], [], [-1, -1]
-        condition = DTensorSpec.from_dim_map(
-            mesh, inp1, [], shape=torch.Size([8])
-        )
-        self_tensor = DTensorSpec.from_dim_map(
-            mesh, inp2, [], shape=torch.Size([])
-        )
+        condition = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8]))
+        self_tensor = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([]))
         other_tensor = DTensorSpec.from_dim_map(
             mesh, inp3, [], shape=torch.Size([1, 1])
         )
@@ -320,12 +270,8 @@ def test_pointwise_rules_suggestion(self):
         )
         # propagate point-wise sharding
         inp1, inp2 = [-1, -1], [-1, 0]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, inp1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, inp2, [], shape=torch.Size([8, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([8, 4]))
         # adding a positional argument -1 to arg schema
         output_sharding = pointwise_rule(
             OpSchema(func_schema, (mat1_spec, mat2_spec, -1), {})
@@ -353,12 +299,8 @@ def test_pointwise_multi_sharding_on_mesh_dim(self):
 
         # basic case to test implicit broadcasting shape alignment
         mat1, mat2 = [-1, 0], [0]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([20, 6])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([6])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([20, 6]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([6]))
         output_sharding = pointwise_rule(
             OpSchema(func_schema, (mat1_spec, mat2_spec), {})
         )
@@ -384,9 +326,7 @@ def test_pointwise_multi_sharding_on_mesh_dim(self):
         # ensure that the suggestion is to reshard the first
         # arg by all_gather first tensor dim sharding
         schema_suggestion = output_sharding.schema_suggestions[0]
-        self.assertEqual(
-            schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]
-        )
+        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
         self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)
 
     @with_comms
@@ -431,9 +371,7 @@ def test_reduction_rule(self):
         )
         # reduction on a 2d mat
         mat1 = [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
         # reduction on dim 0
         output_sharding_0 = reduction_rule(
             OpSchema(func_schema, (mat1_spec, 0), {}),
@@ -467,9 +405,7 @@ def test_reduction_rule(self):
         self.assertEqual(output_sharding_all_dim.output_spec.dim_map, [])
         # pending sum on mesh
         self.assertEqual(output_sharding_all_dim.output_spec.sums, [0])
-        self.assertEqual(
-            output_sharding_all_dim.output_spec.shape, torch.Size([])
-        )
+        self.assertEqual(output_sharding_all_dim.output_spec.shape, torch.Size([]))
 
 
 if __name__ == "__main__":
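A note on the dim_map lists that appear throughout this file ([-1, -1], [-1, 0], [0, -1], ...): each entry describes one tensor dimension, where -1 marks a dimension as replicated across the mesh and a value k >= 0 marks it as sharded over mesh dimension k, and the list passed as the third argument to DTensorSpec.from_dim_map ([] here, [1] in the linearity test) appears to name mesh dimensions carrying pending partial sums. A hedged sketch of building the two matmul operand specs from test_einop_basic_propagation, assuming a process group with 4 ranks is already initialized (the DTensorTestBase fixture does this in the tests), would be:

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec

# Assumes a 4-rank process group is already initialized, as DTensorTestBase does.
mesh = DeviceMesh("cpu", torch.arange(4))

# dim_map semantics: -1 -> replicated, k >= 0 -> sharded on mesh dim k.
# mat1 is fully replicated; mat2 is sharded column-wise on mesh dim 0.
mat1_spec = DTensorSpec.from_dim_map(mesh, [-1, -1], [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, [-1, 0], [], shape=torch.Size([4, 8]))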
