@@ -49,10 +49,8 @@ def test_colwise_parallel_style(self):
4949 model = nn .Linear (16 , 16 , device = self .device_type )
5050
5151 default_col_parallel = ColwiseParallel ()
52+ colwise_mod = parallelize_module (deepcopy (model ), mesh , default_col_parallel )
5253 with comm_mode :
53- colwise_mod = parallelize_module (
54- deepcopy (model ), mesh , default_col_parallel
55- )
5654 out = colwise_mod (tensor )
5755 # ensure output shard on the last dim
5856 self .assertEqual (out .shape , (8 , 16 // self .world_size ))
@@ -65,10 +63,8 @@ def test_colwise_parallel_style(self):
6563 self .assertEqual (comm_mode .get_total_counts (), 1 )
6664
6765 sharded_col_parallel = ColwiseParallel (input_layouts = Shard (0 ))
66+ colwise_mod = parallelize_module (deepcopy (model ), mesh , sharded_col_parallel )
6867 with comm_mode :
69- colwise_mod = parallelize_module (
70- deepcopy (model ), mesh , sharded_col_parallel
71- )
7268 out = colwise_mod (tensor )
7369 # ensure output shard on the last dim
7470 self .assertEqual (out .shape , (8 * self .world_size , 16 // self .world_size ))
@@ -94,10 +90,8 @@ def test_colwise_parallel_embedding(self):
9490 model = nn .Embedding (16 , 16 , device = self .device_type )
9591
9692 default_col_parallel = ColwiseParallel ()
93+ colwise_mod = parallelize_module (deepcopy (model ), mesh , default_col_parallel )
9794 with comm_mode :
98- colwise_mod = parallelize_module (
99- deepcopy (model ), mesh , default_col_parallel
100- )
10195 out = colwise_mod (tensor )
10296 # ensure output shard on the last dim
10397 self .assertEqual (out .shape , (4 , 2 , 16 // self .world_size ))
@@ -119,10 +113,8 @@ def test_rowwise_parallel_style(self):
119113 model = nn .Linear (16 , 16 , device = self .device_type )
120114
121115 default_row_parallel = RowwiseParallel ()
116+ rowwise_mod = parallelize_module (deepcopy (model ), mesh , default_row_parallel )
122117 with comm_mode :
123- rowwise_mod = parallelize_module (
124- deepcopy (model ), mesh , default_row_parallel
125- )
126118 out = rowwise_mod (tensor )
127119 # ensure output replicated
128120 self .assertEqual (out .shape , (8 , 16 ))
@@ -135,10 +127,8 @@ def test_rowwise_parallel_style(self):
135127 self .assertEqual (comm_mode .get_total_counts (), 1 )
136128
137129 sharded_row_parallel = RowwiseParallel (output_layouts = Shard (0 ))
130+ rowwise_mod = parallelize_module (deepcopy (model ), mesh , sharded_row_parallel )
138131 with comm_mode :
139- rowwise_mod = parallelize_module (
140- deepcopy (model ), mesh , sharded_row_parallel
141- )
142132 out = rowwise_mod (tensor )
143133 # ensure output is sharded on the first dim
144134 self .assertEqual (out .shape , (8 // self .world_size , 16 ))
@@ -163,10 +153,10 @@ def test_rowwise_parallel_embedding(self):
163153 tensor = torch .arange (8 , device = self .device_type ).reshape (4 , 2 )
164154 model = nn .Embedding (16 , 16 , device = self .device_type )
165155
156+ rowwise_mod = parallelize_module (
157+ deepcopy (model ), mesh , RowwiseParallel (input_layouts = Replicate ())
158+ )
166159 with comm_mode :
167- rowwise_mod = parallelize_module (
168- deepcopy (model ), mesh , RowwiseParallel (input_layouts = Replicate ())
169- )
170160 out = rowwise_mod (tensor )
171161 # ensure output keeps the full shape (not sharded on the last dim)
172162 self .assertEqual (out .shape , (4 , 2 , 16 ))
0 commit comments