Commit ff49256

[PT-D][TP] Fix TP API for FQN path based parallelization
ghstack-source-id: 992954f
Pull Request resolved: #93029
1 parent 5441f2c commit ff49256
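
For reference, a minimal usage sketch of the FQN-path-based plan that this commit fixes, modeled on the new test test_parallelize_mlp_with_module_api below. The two-layer MLP class, its net1/net2 names, and the device/world-size handling are illustrative assumptions, not part of the commit; only the import paths and the dict-plan call shape come from the diff itself.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel

class MLP(nn.Module):  # hypothetical stand-in for the test's MLPModule
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 16)
        self.net2 = nn.Linear(16, 12)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))

# Assumes the default process group is already initialized (e.g. via torchrun).
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device_mesh = DeviceMesh(device_type, torch.arange(dist.get_world_size()))

# Keys of the plan are FQN paths to submodules; values are ParallelStyle objects.
model_tp = parallelize_module(
    MLP().to(device_type),
    device_mesh,
    {"net1": ColwiseParallel(), "net2": ColwiseParallel()},
)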

File tree

test/distributed/tensor/parallel/test_parallelize_api.py
torch/distributed/tensor/parallel/api.py

2 files changed (+30 lines, -10 lines)

test/distributed/tensor/parallel/test_parallelize_api.py

Lines changed: 26 additions & 7 deletions
@@ -3,7 +3,7 @@
 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
 from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
-from torch.distributed.tensor.parallel.api import _parallelize_linear, _parallelize_mlp
+from torch.distributed.tensor.parallel.api import parallelize_module, _parallelize_linear, _parallelize_mlp
 from torch.distributed.tensor.parallel.style import (
     ColwiseParallel,
     make_input_replicate_1d,
@@ -77,6 +77,7 @@ def _compare_params(
         self,
         local_module,
         dist_module,
+        rank0_only,
         skip_rowwise_bias=False,
         compare_grad=False,
     ):
@@ -85,7 +86,7 @@ def _compare_params(
             dist_param = dist_module.get_parameter(name)
             param = param.grad if compare_grad else param
             dist_param = dist_param.grad if compare_grad else dist_param
-            if self.rank == 0 or (
+            if (not rank0_only) or (self.rank == 0) or (
                 name not in ["net2.bias"]
                 and not skip_rowwise_bias
                 or name not in ["bias", "net2.bias"]
@@ -95,15 +96,16 @@ def _compare_params(
                     dist_param.redistribute(
                         device_mesh=dist_param.device_mesh, placements=replicate
                     ).to_local(),
+                    f"{name} not equal between dist and non-dist"
                 )

-    def _compare_module(self, local_module, dist_module, inp_size, rowwise=False):
+    def _compare_module(self, local_module, dist_module, inp_size, rank0_only=True, rowwise=False):
         LR = 0.25  # the learning rate we use for testing
         local_optim = torch.optim.SGD(local_module.parameters(), lr=LR)
         dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR)
         torch.manual_seed(0)
         inp = torch.rand(*inp_size, device=self.device_type)
-        self._compare_params(local_module, dist_module)
+        self._compare_params(local_module, dist_module, rank0_only)

         # check forward correctness
         local_output = local_module(inp)
@@ -118,11 +120,11 @@ def _compare_module(self, local_module, dist_module, inp_size, rowwise=False):
         dist_output.sum().backward()

         # check backward and ensure gradients are same
-        self._compare_params(local_module, dist_module, rowwise, True)
+        self._compare_params(local_module, dist_module, rank0_only, rowwise, True)

         local_optim.step()
         dist_optim.step()
-        self._compare_params(local_module, dist_module, rowwise)
+        self._compare_params(local_module, dist_module, rank0_only, rowwise)

     @with_comms
     def test_parallelize_mlp(self):
@@ -141,6 +143,23 @@ def test_parallelize_mlp(self):
         model_tp = _parallelize_mlp(model_tp, device_mesh, PairwiseParallel())
         self._compare_module(model, model_tp, inp_size)

+    @with_comms
+    def test_parallelize_mlp_with_module_api(self):
+        inp_size = [12, 10]
+        model = MLPModule(self.device_type)
+        model_tp = MLPModule(self.device_type)
+
+        # Ensure model are initialized the same way.
+        self.assertEqual(model.net1.weight, model_tp.net1.weight)
+        self.assertEqual(model.net1.bias, model_tp.net1.bias)
+        self.assertEqual(model.net2.weight, model_tp.net2.weight)
+        self.assertEqual(model.net2.bias, model_tp.net2.bias)
+
+        # Parallelize module.
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        model_tp = parallelize_module(model_tp, device_mesh, {"net1": ColwiseParallel(), "net2": ColwiseParallel()})
+        self._compare_module(model, model_tp, inp_size, rank0_only=False)
+
     @with_comms
     def test_parallelize_mlp_error(self):
         class DummyParallel(ParallelStyle):
@@ -177,7 +196,7 @@ def test_linear_row_wise_parallel(self):

         # let each rank generate unique local input
         torch.manual_seed(self.rank)
-        self._compare_module(model, model_tp, inp_size, True)
+        self._compare_module(model, model_tp, inp_size, rowwise=True)

     @with_comms
     def test_linear_col_wise_parallel(self):

torch/distributed/tensor/parallel/api.py

Lines changed: 4 additions & 3 deletions
@@ -97,11 +97,12 @@ def parallelize_module(  # type: ignore[return]
         for module_path, parallelize_style in parallelize_plan.items():
             sub_module = module.get_submodule(module_path)
             module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
+                module_path,
                 parallelize_module(  # type: ignore[arg-type]
-                    module_path, sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
-                )
+                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
+                ),
             )
-            return module
+        return module
     else:
         raise RuntimeError(  # pyre-ignore[7]
             "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
