
Commit de1ad0a

[PT-D] Enable Meta Tensor Support for DTensor
ghstack-source-id: 7b2b65c
Pull Request resolved: #92652

4 files changed: +31 −16 lines


test/distributed/_tensor/test_dtensor.py

Lines changed: 15 additions & 14 deletions
@@ -13,20 +13,6 @@
 
 
 class DTensorTest(DTensorTestBase):
-    # @with_comms
-    # def test_tensor_constructor(self):
-    #     import torch.distributed._tensor as dist_tensor
-    #     shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)])
-    #     empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec)
-    #     zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec)
-    #     one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec)
-
-    #     zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec)
-
-    #     dist_tensor.empty_like(empty_tensor)
-    #     dist_tensor.zero_like(empty_tensor)
-    #     dist_tensor.one_like(empty_tensor)
-
     @with_comms
     def test_dtensor_constructor(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -55,6 +41,21 @@ def test_dtensor_constructor(self):
             requires_grad=True,
         )
 
+    @with_comms
+    def test_meta_dtensor(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        dist_specs = [[Shard(0)], [Replicate()]]
+        meta_tensor = torch.randn(1024, 2048, device="meta")
+        for dist_spec in dist_specs:
+            # Test distribute_tensor on meta tensor
+            meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
+            torch.nn.init.constant_(meta_dtensor, 1.2)
+            self.assertEqual(meta_dtensor.device.type, self.device_type)
+            # Test from_local on meta tensor
+            meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
+            torch.nn.init.constant_(meta_dtensor, 1.5)
+            self.assertEqual(meta_dtensor.device.type, self.device_type)
+
     @with_comms
     def test_dtensor_stride(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
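
For context, the new test_meta_dtensor exercises the two entry points that now accept meta tensors. Below is a minimal standalone sketch of the same usage pattern, not part of the commit; it assumes a process group is already initialized and uses the torch.distributed._tensor API the test imports.

import torch
import torch.distributed as dist
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Replicate,
    Shard,
    distribute_tensor,
)

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

# Allocate on the meta device: shape and dtype only, no storage, no data movement.
big = torch.randn(1024, 2048, device="meta")

# Both paths exercised by the test keep the meta tensor as-is instead of
# copying it to the mesh's device type up front.
sharded = distribute_tensor(big, mesh, [Shard(0)])
replicated = DTensor.from_local(big, mesh, [Replicate()])

# The test then materializes the meta DTensor with an in-place init and
# checks that it ends up on the mesh's real device type.
torch.nn.init.constant_(sharded, 1.2)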

torch/distributed/_tensor/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def distribute_tensor(
     # get default device mesh if there's nothing specified
     device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
     # convert tensor to the correponding device type if it's not in that device type
-    tensor = tensor.to(device_mesh.device_type)
+    if not tensor.is_meta:
+        tensor = tensor.to(device_mesh.device_type)
     # set default placements to replicated if not specified
     if placements is None:
         placements = [Replicate() for _ in range(device_mesh.ndim)]
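
The guard matters because a meta tensor carries no storage: copying it onto the mesh's device with .to() would fail since there is no data to copy, so the conversion is skipped and the tensor stays on the meta device. A small standalone illustration of the guarded branch, where device_type is just a stand-in for device_mesh.device_type:

import torch

t = torch.empty(4, 4, device="meta")
print(t.is_meta)  # True

device_type = "cpu"  # stand-in for device_mesh.device_type
if not t.is_meta:
    # Never taken for meta tensors; t.to(device_type) on a meta tensor
    # would raise because there is no underlying data to copy out.
    t = t.to(device_type)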

torch/distributed/_tensor/api.py

Lines changed: 2 additions & 1 deletion
@@ -277,7 +277,8 @@ def from_local(
         # in the mesh dimension
         device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
         # convert the local tensor to desired device base on device mesh's device_type
-        local_tensor = local_tensor.to(device_mesh.device_type)
+        if not local_tensor.is_meta:
+            local_tensor = local_tensor.to(device_mesh.device_type)
 
         # set default placements to replicated if not specified
         if placements is None:
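
DTensor.from_local gets the same treatment as distribute_tensor: the local shard is only moved to the mesh's device type when it is not a meta tensor. One way to read the shared pattern, as a hypothetical helper (maybe_to_mesh_device is illustrative only, not part of the commit):

import torch

def maybe_to_mesh_device(t: torch.Tensor, device_type: str) -> torch.Tensor:
    # Hypothetical helper mirroring the guard added in both call sites:
    # leave meta tensors untouched, otherwise move to the mesh's device type.
    if t.is_meta:
        return t
    return t.to(device_type)

local = torch.randn(8, 8, device="meta")
local = maybe_to_mesh_device(local, "cuda")  # stays on meta, no allocation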

torch/distributed/_tensor/device_mesh.py

Lines changed: 12 additions & 0 deletions
@@ -283,6 +283,12 @@ def scatter(
         Returns:
             A :class:`Work` object
         """
+        # TODO: Ideally we should use the meta tensor way
+        # (to register a meta kernel for the collective op)
+        # so that it would avoid the communication. Need to
+        # remove the check below once that is done.
+        if output.is_meta:
+            return None
         dim_group = self._dim_groups[mesh_dim]
         # src need to be global rank
         src_for_dim = 0
@@ -330,6 +336,12 @@ def broadcast(
         Returns:
             A :class:`Work` object
         """
+        # TODO: Ideally we should use the meta tensor way
+        # (to register a meta kernel for the collective op)
+        # so that it would avoid the communication. Need to
+        # remove the check below once that is done.
+        if tensor.is_meta:
+            return None
         dim_group = self._dim_groups[mesh_dim]
         # src need to be global rank
         src_for_dim = 0
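
The scatter and broadcast changes short-circuit the collectives entirely when the tensor lives on the meta device, since there is nothing to communicate; the TODO notes that registering a meta kernel for the collective op would be the cleaner long-term fix. A simplified standalone sketch of the same early-return pattern (broadcast_unless_meta is illustrative, not part of the DeviceMesh API):

import torch
import torch.distributed as dist

def broadcast_unless_meta(tensor: torch.Tensor, src: int, group=None):
    # Mirror the check the commit adds: a meta tensor has no data, so the
    # collective is skipped and no Work handle is returned.
    if tensor.is_meta:
        return None
    return dist.broadcast(tensor, src=src, group=group, async_op=True)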
