@@ -3,7 +3,7 @@
 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
 from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
-from torch.distributed.tensor.parallel.api import _parallelize_linear, _parallelize_mlp
+from torch.distributed.tensor.parallel.api import parallelize_module, _parallelize_linear, _parallelize_mlp
 from torch.distributed.tensor.parallel.style import (
     ColwiseParallel,
     make_input_replicate_1d,
@@ -77,6 +77,7 @@ def _compare_params(
         self,
         local_module,
         dist_module,
+        rank0_only,
         skip_rowwise_bias=False,
         compare_grad=False,
     ):
@@ -85,7 +86,7 @@ def _compare_params(
             dist_param = dist_module.get_parameter(name)
             param = param.grad if compare_grad else param
             dist_param = dist_param.grad if compare_grad else dist_param
-            if self.rank == 0 or (
+            if (not rank0_only) or (self.rank == 0) or (
                 name not in ["net2.bias"]
                 and not skip_rowwise_bias
                 or name not in ["bias", "net2.bias"]
@@ -95,15 +96,16 @@ def _compare_params(
                     dist_param.redistribute(
                         device_mesh=dist_param.device_mesh, placements=replicate
                     ).to_local(),
+                    f"{name} not equal between dist and non-dist"
                 )
 
-    def _compare_module(self, local_module, dist_module, inp_size, rowwise=False):
+    def _compare_module(self, local_module, dist_module, inp_size, rank0_only=True, rowwise=False):
         LR = 0.25  # the learning rate we use for testing
         local_optim = torch.optim.SGD(local_module.parameters(), lr=LR)
         dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR)
         torch.manual_seed(0)
         inp = torch.rand(*inp_size, device=self.device_type)
-        self._compare_params(local_module, dist_module)
+        self._compare_params(local_module, dist_module, rank0_only)
 
         # check forward correctness
         local_output = local_module(inp)
@@ -118,11 +120,11 @@ def _compare_module(self, local_module, dist_module, inp_size, rowwise=False):
         dist_output.sum().backward()
 
         # check backward and ensure gradients are same
-        self._compare_params(local_module, dist_module, rowwise, True)
+        self._compare_params(local_module, dist_module, rank0_only, rowwise, True)
 
         local_optim.step()
         dist_optim.step()
-        self._compare_params(local_module, dist_module, rowwise)
+        self._compare_params(local_module, dist_module, rank0_only, rowwise)
 
     @with_comms
     def test_parallelize_mlp(self):
@@ -141,6 +143,23 @@ def test_parallelize_mlp(self):
         model_tp = _parallelize_mlp(model_tp, device_mesh, PairwiseParallel())
         self._compare_module(model, model_tp, inp_size)
 
+    @with_comms
+    def test_parallelize_mlp_with_module_api(self):
+        inp_size = [12, 10]
+        model = MLPModule(self.device_type)
+        model_tp = MLPModule(self.device_type)
+
+        # Ensure the two models are initialized the same way.
+        self.assertEqual(model.net1.weight, model_tp.net1.weight)
+        self.assertEqual(model.net1.bias, model_tp.net1.bias)
+        self.assertEqual(model.net2.weight, model_tp.net2.weight)
+        self.assertEqual(model.net2.bias, model_tp.net2.bias)
+
+        # Parallelize the module.
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        model_tp = parallelize_module(model_tp, device_mesh, {"net1": ColwiseParallel(), "net2": ColwiseParallel()})
+        self._compare_module(model, model_tp, inp_size, rank0_only=False)
+
     @with_comms
     def test_parallelize_mlp_error(self):
@@ -177,7 +196,7 @@ def test_linear_row_wise_parallel(self):
 
         # let each rank generate unique local input
         torch.manual_seed(self.rank)
-        self._compare_module(model, model_tp, inp_size, True)
+        self._compare_module(model, model_tp, inp_size, rowwise=True)
 
     @with_comms
     def test_linear_col_wise_parallel(self):
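For context, the new entry point exercised by `test_parallelize_mlp_with_module_api` takes a plan dict mapping submodule names to `ParallelStyle` objects. Below is a minimal sketch of driving it outside the test harness, assuming a process group is already initialized (e.g. launched with `torchrun`) and using the import paths as they appear in this diff (they may move in later releases). The `MLP` class is a hypothetical stand-in for the test suite's `MLPModule`.

```python
# Sketch only: assumes torch.distributed is initialized (e.g. under torchrun)
# and the import paths of this diff; the API location may change in later releases.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel


class MLP(torch.nn.Module):
    # Stand-in for the test suite's MLPModule: two linears around a ReLU.
    def __init__(self):
        super().__init__()
        self.net1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(16, 12)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def main():
    dist.init_process_group("gloo")  # or "nccl" with one GPU per rank
    # One 1-D mesh spanning every rank, mirroring the test above.
    mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))
    model = MLP()
    # The dict-based plan: submodule name -> parallel style.
    model = parallelize_module(
        model, mesh, {"net1": ColwiseParallel(), "net2": ColwiseParallel()}
    )
    out = model(torch.rand(12, 10))
    print(f"rank {dist.get_rank()} output shape: {out.shape}")


if __name__ == "__main__":
    main()
```

Run with, e.g., `torchrun --nproc_per_node=4 sketch.py`; the test passes `rank0_only=False` because with this plan every rank holds a comparable replicated view of the output, not just rank 0.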