
Commit 77f3366

fduwjj authored and pytorchmergebot committed
[PT-D] Enable Meta Tensor Support for DTensor (#92652)
Pull Request resolved: #92652
Approved by: https://github.com/XilunWu, https://github.com/wanchaol
1 parent e714e37 commit 77f3366

4 files changed: +88 / -15 lines

test/distributed/_tensor/test_dtensor.py

Lines changed: 72 additions & 13 deletions
@@ -2,6 +2,11 @@
 # Owner(s): ["oncall: distributed"]

 import torch
+import torch.nn.functional as F
+from torch.distributed.tensor.parallel import (
+    PairwiseParallel,
+    parallelize_module,
+)
 from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard

@@ -11,22 +16,24 @@
     with_comms,
 )

+class DummyMLP(torch.nn.Module):
+    def __init__(self, device):
+        super(DummyMLP, self).__init__()
+        self.net1 = torch.nn.Linear(5, 1024, device=device)
+        self.relu = torch.nn.ReLU()
+        self.net2 = torch.nn.Linear(1024, 4, device=device)

-class DTensorTest(DTensorTestBase):
-    # @with_comms
-    # def test_tensor_constructor(self):
-    #     import torch.distributed._tensor as dist_tensor
-    #     shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)])
-    #     empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec)
-    #     zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec)
-    #     one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec)
-
-    #     zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec)
+    def forward(self, x):
+        return self.net2(F.relu(self.net1(x)))

-    #     dist_tensor.empty_like(empty_tensor)
-    #     dist_tensor.zero_like(empty_tensor)
-    #     dist_tensor.one_like(empty_tensor)
+    def reset_parameters(self, *args, **kwargs):
+        with torch.no_grad():
+            self.net1.weight.fill_(0.5)
+            self.net2.weight.fill_(1)
+            self.net1.bias.fill_(1.5)
+            self.net2.bias.fill_(1.2)

+class DTensorTest(DTensorTestBase):
     @with_comms
     def test_dtensor_constructor(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -55,6 +62,58 @@ def test_dtensor_constructor(self):
             requires_grad=True,
         )

+    @with_comms
+    def test_meta_dtensor(self):
+        device_mesh = self.build_device_mesh()
+        dist_specs = [[Shard(0)], [Replicate()]]
+        meta_tensor = torch.randn(1024, 2048, device="meta")
+        for dist_spec in dist_specs:
+            # Test distribute_tensor on meta tensor
+            meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
+            self.assertTrue(meta_dtensor.is_meta)
+            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
+            torch.nn.init.constant_(meta_dtensor, 1.2)
+            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
+            self.assertFalse(meta_dtensor.is_meta)
+            self.assertEqual(meta_dtensor.device.type, self.device_type)
+            self.assertEqual(meta_dtensor.to_local(), value_tensor)
+            # Test from_local on meta tensor
+            meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
+            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
+            torch.nn.init.constant_(meta_dtensor, 1.5)
+            self.assertEqual(meta_dtensor.device.type, self.device_type)
+            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
+            self.assertEqual(meta_dtensor.to_local(), value_tensor)
+
+    @with_comms
+    def test_modules_w_meta_dtensor(self):
+        model = DummyMLP("meta")
+        device_mesh = self.build_device_mesh()
+        model_tp = parallelize_module(model, device_mesh, PairwiseParallel())
+        model_tp.to_empty(device=self.device_type)
+        model_tp.reset_parameters()
+        optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
+        model_regular = DummyMLP(self.device_type)
+        model_regular_tp = parallelize_module(model_regular, device_mesh, PairwiseParallel())
+        optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
+        model_regular_tp.reset_parameters()
+        torch.manual_seed(0)
+        inp = torch.randn(20, 5, device=self.device_type)
+
+        output = model_tp(inp)
+        output_regular = model_regular_tp(inp)
+        self.assertEqual(output, output_regular)
+
+        output.sum().backward()
+        output_regular.sum().backward()
+
+        optim.step()
+        optim_regular.step()
+
+        torch.manual_seed(1)
+        inp = torch.randn(20, 5, device=self.device_type)
+        self.assertEqual(model_tp(inp), model_regular_tp(inp))
+
     @with_comms
     def test_dtensor_stride(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
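
The second new test exercises the deferred-initialization workflow this commit enables: build a module on the meta device, apply tensor parallelism while the parameters are still shape-only, then materialize and initialize them on the real device. A minimal standalone sketch of that flow (assuming torch.distributed is already initialized, e.g. via torchrun, with a CUDA device per rank; TinyMLP is a stand-in for the DummyMLP defined in the test):

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module


class TinyMLP(torch.nn.Module):
    # Stand-in for the DummyMLP used in the test above.
    def __init__(self, device):
        super().__init__()
        self.net1 = torch.nn.Linear(5, 1024, device=device)
        self.net2 = torch.nn.Linear(1024, 4, device=device)

    def forward(self, x):
        return self.net2(F.relu(self.net1(x)))

    def reset_parameters(self, *args, **kwargs):
        with torch.no_grad():
            self.net1.weight.fill_(0.5)
            self.net1.bias.fill_(1.5)
            self.net2.weight.fill_(1.0)
            self.net2.bias.fill_(1.2)


mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))

model = TinyMLP("meta")                                          # parameters carry shapes only, no storage
model_tp = parallelize_module(model, mesh, PairwiseParallel())   # shard while still on the meta device
model_tp.to_empty(device="cuda")                                 # allocate uninitialized storage per rank
model_tp.reset_parameters()                                      # write real values only after sharding
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
out = model_tp(torch.randn(20, 5, device="cuda"))

The point of the meta detour is that the full-size parameters are never allocated or communicated before sharding, which matters for models too large to construct on a single device.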

torch/distributed/_tensor/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def distribute_tensor(
     # get default device mesh if there's nothing specified
     device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
     # convert tensor to the correponding device type if it's not in that device type
-    tensor = tensor.to(device_mesh.device_type)
+    if not tensor.is_meta:
+        tensor = tensor.to(device_mesh.device_type)
     # set default placements to replicated if not specified
     if placements is None:
         placements = [Replicate() for _ in range(device_mesh.ndim)]
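
With this guard in place, distribute_tensor leaves a meta tensor on the meta device instead of copying it to the mesh's device, so a sharding layout can be planned without allocating real memory. A minimal sketch, assuming an initialized process group and mirroring the calls in the test above:

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Shard

mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))

big = torch.randn(1024, 2048, device="meta")        # shape/dtype only, no storage
meta_dt = distribute_tensor(big, mesh, [Shard(0)])  # stays on meta; no device copy, no scatter
assert meta_dt.is_meta

# Materialize and initialize once the layout is settled.
dt = torch.empty_like(meta_dt, device="cuda")
torch.nn.init.constant_(dt, 1.2)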

torch/distributed/_tensor/api.py

Lines changed: 2 additions & 1 deletion
@@ -277,7 +277,8 @@ def from_local(
     # in the mesh dimension
     device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
     # convert the local tensor to desired device base on device mesh's device_type
-    local_tensor = local_tensor.to(device_mesh.device_type)
+    if not local_tensor.is_meta:
+        local_tensor = local_tensor.to(device_mesh.device_type)

     # set default placements to replicated if not specified
     if placements is None:
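
DTensor.from_local gets the same treatment, so a local meta shard can be wrapped into a DTensor without being moved. A sketch under the same assumptions as above:

import torch
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Replicate

mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))

local_meta = torch.randn(1024, 2048, device="meta")
meta_dt = DTensor.from_local(local_meta, mesh, [Replicate()])  # no .to() issued for a meta tensor
dt = torch.empty_like(meta_dt, device="cuda")                  # materialize when needed
torch.nn.init.constant_(dt, 1.5)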

torch/distributed/_tensor/device_mesh.py

Lines changed: 12 additions & 0 deletions
@@ -322,6 +322,12 @@ def scatter(
         Returns:
             A :class:`Work` object
         """
+        # TODO: Ideally we should use the meta tensor way
+        # (to register a meta kernel for the collective op)
+        # so that it would avoid the communication. Need to
+        # remove the check below once that is done.
+        if output.is_meta:
+            return None
         dim_group = self._dim_groups[mesh_dim]
         # src need to be global rank
         src_for_dim = 0
@@ -369,6 +375,12 @@ def broadcast(
         Returns:
             A :class:`Work` object
         """
+        # TODO: Ideally we should use the meta tensor way
+        # (to register a meta kernel for the collective op)
+        # so that it would avoid the communication. Need to
+        # remove the check below once that is done.
+        if tensor.is_meta:
+            return None
         dim_group = self._dim_groups[mesh_dim]
         # src need to be global rank
         src_for_dim = 0
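
Both collectives now return early for meta tensors, since there is no data to move; the TODO notes that the longer-term plan is to register meta kernels for the collective ops instead of special-casing here. The shape of the guard is the same in both methods, roughly as in this hypothetical standalone helper (not the actual DeviceMesh code):

import torch
import torch.distributed as dist


def broadcast_unless_meta(tensor: torch.Tensor, src: int = 0):
    # Meta tensors carry shapes but no data, so there is nothing to communicate;
    # mirror the patched methods and return no Work handle.
    if tensor.is_meta:
        return None
    return dist.broadcast(tensor, src=src, async_op=True)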
