# Owner(s): ["oncall: distributed"]

import torch
-from torch.testing._internal.common_utils import run_tests
-from torchgen.model import FunctionSchema
+from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.dispatch import OpSchema

from torch.distributed._tensor.ops.common_rules import (
    einop_rule,
-    reduction_rule,
    pointwise_rule,
+    reduction_rule,
)
from torch.distributed._tensor.placement_types import DTensorSpec
+from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
-from torch.distributed._tensor import DeviceMesh
+from torchgen.model import FunctionSchema


class CommonRulesTest(DTensorTestBase):
@@ -34,17 +34,11 @@ def test_einop_basic_propagation(self):
        # plain einsum, mm
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

-        func_schema = self.parse_schema(
-            "aten::mm(Tensor self, Tensor mat2) -> Tensor"
-        )
+        func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
        # propagate col-wise sharding
        mat1, mat2 = [-1, -1], [-1, 0]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -55,12 +49,8 @@ def test_einop_basic_propagation(self):

        # propagate row-wise sharding
        mat1, mat2 = [0, -1], [-1, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -71,12 +61,8 @@ def test_einop_basic_propagation(self):

        # generate partial
        mat1, mat2 = [-1, 0], [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -94,9 +80,7 @@ def test_einop_pointwise_propagation(self):
        )
        # addition
        mat1 = [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 8]))
        output_sharding = einop_rule(
            "ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat1_spec), {})
        )
@@ -110,9 +94,7 @@ def test_einop_pointwise_propagation(self):
        mat1_spec = DTensorSpec.from_dim_map(
            mesh, mat1, [], shape=torch.Size([8, 4, 2])
        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, [-1], [], shape=torch.Size([2])
-        )
+        mat2_spec = DTensorSpec.from_dim_map(mesh, [-1], [], shape=torch.Size([2]))
        output_sharding = einop_rule(
            "ijk,k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -144,17 +126,11 @@ def test_einop_merge_sharding(self):
        )
        mesh = DeviceMesh(self.device_type, mesh_shape)

-        func_schema = self.parse_schema(
-            "aten::mm(Tensor self, Tensor mat2) -> Tensor"
-        )
+        func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")

        mat1, mat2 = [0, -1], [-1, 1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -175,12 +151,8 @@ def test_einop_linearity(self):
        )

        mat1, mat2 = [0, -1], [-1, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [1], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([4, 8])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
        # if not turn on linearity, partial sum is not eligible to propagate, we return
        # suggestion to reshard inputs with no partial sum (i.e. all_reduce one input)
        output_sharding = einop_rule(
@@ -212,12 +184,8 @@ def test_einop_linearity(self):
212184 "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
213185 )
214186 mat1 , mat2 = [0 , - 1 ], [0 , - 1 ]
215- mat1_spec = DTensorSpec .from_dim_map (
216- mesh , mat1 , [1 ], shape = torch .Size ([8 , 6 ])
217- )
218- mat2_spec = DTensorSpec .from_dim_map (
219- mesh , mat2 , [], shape = torch .Size ([8 , 6 ])
220- )
187+ mat1_spec = DTensorSpec .from_dim_map (mesh , mat1 , [1 ], shape = torch .Size ([8 , 6 ]))
188+ mat2_spec = DTensorSpec .from_dim_map (mesh , mat2 , [], shape = torch .Size ([8 , 6 ]))
221189
222190 output_sharding = einop_rule (
223191 "ij,ij->ij" ,
@@ -237,16 +205,10 @@ def test_einop_multi_sharding_on_mesh_dim(self):
        mesh_shape = torch.arange(self.world_size)
        mesh = DeviceMesh(self.device_type, mesh_shape)

-        func_schema = self.parse_schema(
-            "aten::mm(Tensor self, Tensor mat2) -> Tensor"
-        )
+        func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
        mat1, mat2 = [0, -1], [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 12])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([12, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 12]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([12, 4]))
        output_sharding = einop_rule(
            "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -271,19 +233,11 @@ def test_einop_errors(self):
271233 "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
272234 )
273235 mat1 , mat2 = [0 , - 1 ], [1 , - 1 ]
274- mat1_spec = DTensorSpec .from_dim_map (
275- mesh , mat1 , [], shape = torch .Size ([8 , 4 ])
276- )
277- mat2_spec = DTensorSpec .from_dim_map (
278- mesh , mat2 , [], shape = torch .Size ([8 , 4 ])
279- )
236+ mat1_spec = DTensorSpec .from_dim_map (mesh , mat1 , [], shape = torch .Size ([8 , 4 ]))
237+ mat2_spec = DTensorSpec .from_dim_map (mesh , mat2 , [], shape = torch .Size ([8 , 4 ]))
280238
281- with self .assertRaisesRegex (
282- RuntimeError , "sharded two different ways:"
283- ):
284- einop_rule (
285- "ij,ij->ij" , OpSchema (func_schema , (mat1_spec , mat2_spec ), {})
286- )
239+ with self .assertRaisesRegex (RuntimeError , "sharded two different ways:" ):
240+ einop_rule ("ij,ij->ij" , OpSchema (func_schema , (mat1_spec , mat2_spec ), {}))
287241
288242 @with_comms
289243 def test_pointwise_rules_broadcasting (self ):
@@ -293,12 +247,8 @@ def test_pointwise_rules_broadcasting(self):
293247 "where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"
294248 )
295249 inp1 , inp2 , inp3 = [0 ], [], [- 1 , - 1 ]
296- condition = DTensorSpec .from_dim_map (
297- mesh , inp1 , [], shape = torch .Size ([8 ])
298- )
299- self_tensor = DTensorSpec .from_dim_map (
300- mesh , inp2 , [], shape = torch .Size ([])
301- )
250+ condition = DTensorSpec .from_dim_map (mesh , inp1 , [], shape = torch .Size ([8 ]))
251+ self_tensor = DTensorSpec .from_dim_map (mesh , inp2 , [], shape = torch .Size ([]))
302252 other_tensor = DTensorSpec .from_dim_map (
303253 mesh , inp3 , [], shape = torch .Size ([1 , 1 ])
304254 )
@@ -320,12 +270,8 @@ def test_pointwise_rules_suggestion(self):
        )
        # propagate point-wise sharding
        inp1, inp2 = [-1, -1], [-1, 0]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, inp1, [], shape=torch.Size([8, 4])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, inp2, [], shape=torch.Size([8, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8, 4]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([8, 4]))
        # adding a positional argument -1 to arg schema
        output_sharding = pointwise_rule(
            OpSchema(func_schema, (mat1_spec, mat2_spec, -1), {})
@@ -353,12 +299,8 @@ def test_pointwise_multi_sharding_on_mesh_dim(self):

        # basic case to test implicit broadcasting shape alignment
        mat1, mat2 = [-1, 0], [0]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([20, 6])
-        )
-        mat2_spec = DTensorSpec.from_dim_map(
-            mesh, mat2, [], shape=torch.Size([6])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([20, 6]))
+        mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([6]))
        output_sharding = pointwise_rule(
            OpSchema(func_schema, (mat1_spec, mat2_spec), {})
        )
@@ -384,9 +326,7 @@ def test_pointwise_multi_sharding_on_mesh_dim(self):
        # ensure that the suggestion is to reshard the first
        # arg by all_gather first tensor dim sharding
        schema_suggestion = output_sharding.schema_suggestions[0]
-        self.assertEqual(
-            schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]
-        )
+        self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
        self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)

    @with_comms
@@ -431,9 +371,7 @@ def test_reduction_rule(self):
        )
        # reduction on a 2d mat
        mat1 = [0, -1]
-        mat1_spec = DTensorSpec.from_dim_map(
-            mesh, mat1, [], shape=torch.Size([8, 4])
-        )
+        mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
        # reduction on dim 0
        output_sharding_0 = reduction_rule(
            OpSchema(func_schema, (mat1_spec, 0), {}),
@@ -467,9 +405,7 @@ def test_reduction_rule(self):
        self.assertEqual(output_sharding_all_dim.output_spec.dim_map, [])
        # pending sum on mesh
        self.assertEqual(output_sharding_all_dim.output_spec.sums, [0])
-        self.assertEqual(
-            output_sharding_all_dim.output_spec.shape, torch.Size([])
-        )
+        self.assertEqual(output_sharding_all_dim.output_spec.shape, torch.Size([]))


if __name__ == "__main__":