85 changes: 72 additions & 13 deletions test/distributed/_tensor/test_dtensor.py
@@ -2,6 +2,11 @@
# Owner(s): ["oncall: distributed"]

import torch
import torch.nn.functional as F
from torch.distributed.tensor.parallel import (
PairwiseParallel,
parallelize_module,
)
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard

@@ -11,22 +16,24 @@
with_comms,
)

class DummyMLP(torch.nn.Module):
    def __init__(self, device):
        super(DummyMLP, self).__init__()
        self.net1 = torch.nn.Linear(5, 1024, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(1024, 4, device=device)

    def forward(self, x):
        return self.net2(F.relu(self.net1(x)))

    def reset_parameters(self, *args, **kwargs):
        with torch.no_grad():
            self.net1.weight.fill_(0.5)
            self.net2.weight.fill_(1)
            self.net1.bias.fill_(1.5)
            self.net2.bias.fill_(1.2)


class DTensorTest(DTensorTestBase):
    # @with_comms
    # def test_tensor_constructor(self):
    #     import torch.distributed._tensor as dist_tensor
    #     shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)])
    #     empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec)
    #     zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec)
    #     one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec)

    #     zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec)

    #     dist_tensor.empty_like(empty_tensor)
    #     dist_tensor.zero_like(empty_tensor)
    #     dist_tensor.one_like(empty_tensor)
@with_comms
def test_dtensor_constructor(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -55,6 +62,58 @@ def test_dtensor_constructor(self):
requires_grad=True,
)

@with_comms
def test_meta_dtensor(self):
device_mesh = self.build_device_mesh()
dist_specs = [[Shard(0)], [Replicate()]]
meta_tensor = torch.randn(1024, 2048, device="meta")
for dist_spec in dist_specs:
# Test distribute_tensor on meta tensor
meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
self.assertTrue(meta_dtensor.is_meta)
meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
torch.nn.init.constant_(meta_dtensor, 1.2)
value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
self.assertFalse(meta_dtensor.is_meta)
self.assertEqual(meta_dtensor.device.type, self.device_type)
self.assertEqual(meta_dtensor.to_local(), value_tensor)
# Test from_local on meta tensor
meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
torch.nn.init.constant_(meta_dtensor, 1.5)
self.assertEqual(meta_dtensor.device.type, self.device_type)
value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
self.assertEqual(meta_dtensor.to_local(), value_tensor)

@with_comms
def test_modules_w_meta_dtensor(self):
model = DummyMLP("meta")
device_mesh = self.build_device_mesh()
model_tp = parallelize_module(model, device_mesh, PairwiseParallel())
model_tp.to_empty(device=self.device_type)
model_tp.reset_parameters()
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
model_regular = DummyMLP(self.device_type)
model_regular_tp = parallelize_module(model_regular, device_mesh, PairwiseParallel())
optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
model_regular_tp.reset_parameters()
torch.manual_seed(0)
inp = torch.randn(20, 5, device=self.device_type)

output = model_tp(inp)
output_regular = model_regular_tp(inp)
self.assertEqual(output, output_regular)

output.sum().backward()
output_regular.sum().backward()

optim.step()
optim_regular.step()

torch.manual_seed(1)
inp = torch.randn(20, 5, device=self.device_type)
self.assertEqual(model_tp(inp), model_regular_tp(inp))

@with_comms
def test_dtensor_stride(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
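For readers skimming the diff: the deferred-initialization flow exercised by test_modules_w_meta_dtensor above reduces to the sketch below, re-using the DummyMLP module defined in this test file (device_mesh and device_type stand in for the test fixtures self.build_device_mesh() and self.device_type).

model = DummyMLP("meta")                                               # parameters live on the meta device, no memory allocated
model_tp = parallelize_module(model, device_mesh, PairwiseParallel())  # plan sharding on the meta parameters
model_tp.to_empty(device=device_type)                                  # materialize uninitialized storage on the real device
model_tp.reset_parameters()                                            # fill the materialized parameters with actual values
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)                 # from here on, training proceeds as usual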
3 changes: 2 additions & 1 deletion torch/distributed/_tensor/__init__.py
@@ -40,7 +40,8 @@ def distribute_tensor(
# get default device mesh if there's nothing specified
device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
# convert tensor to the corresponding device type if it is not already on it;
# meta tensors carry no data to move, so they stay on the meta device
tensor = tensor.to(device_mesh.device_type)
if not tensor.is_meta:
tensor = tensor.to(device_mesh.device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(device_mesh.ndim)]
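This guard is what makes the new meta-tensor tests possible: a meta tensor has no storage, so calling tensor.to(device) on it fails (the usual "Cannot copy out of meta tensor; no data!" error). With the check in place, distribute_tensor only computes shapes and placements for meta input, and the caller materializes the shards later, as in test_meta_dtensor:

big = torch.randn(1024, 2048, device="meta")           # no memory allocated
dt = distribute_tensor(big, device_mesh, [Shard(0)])   # still a meta DTensor; no data moved
dt = torch.empty_like(dt, device=device_type)          # allocate the local shard on the real device
torch.nn.init.constant_(dt, 1.2)                       # initialize in place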
3 changes: 2 additions & 1 deletion torch/distributed/_tensor/api.py
@@ -277,7 +277,8 @@ def from_local(
# in the mesh dimension
device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
# convert the local tensor to the desired device based on the device mesh's device_type;
# meta tensors carry no data to move, so they are left on the meta device
local_tensor = local_tensor.to(device_mesh.device_type)
if not local_tensor.is_meta:
local_tensor = local_tensor.to(device_mesh.device_type)

# set default placements to replicated if not specified
if placements is None:
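DTensor.from_local gets the same treatment, so a meta local tensor can be wrapped first and materialized afterwards; a minimal sketch mirroring the second half of test_meta_dtensor (device_mesh and device_type again stand in for the test fixtures):

meta_local = torch.randn(1024, 2048, device="meta")
dt = DTensor.from_local(meta_local, device_mesh, [Replicate()])  # wraps without moving any data
dt = torch.empty_like(dt, device=device_type)                    # materialize on the real device
torch.nn.init.constant_(dt, 1.5)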
12 changes: 12 additions & 0 deletions torch/distributed/_tensor/device_mesh.py
@@ -322,6 +322,12 @@ def scatter(
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should register a meta kernel for the collective
# op so that meta tensors skip the communication automatically.
# Remove the check below once that is done.
if output.is_meta:
return None
dim_group = self._dim_groups[mesh_dim]
# src need to be global rank
src_for_dim = 0
@@ -369,6 +375,12 @@ def broadcast(
Returns:
A :class:`Work` object
"""
# TODO: Ideally we should register a meta kernel for the collective
# op so that meta tensors skip the communication automatically.
# Remove the check below once that is done.
if tensor.is_meta:
return None
dim_group = self._dim_groups[mesh_dim]
# src need to be global rank
src_for_dim = 0
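The TODO comments above point at the longer-term fix: register a Meta-dispatch kernel for the collective ops so that meta tensors get pure shape propagation and these explicit is_meta early-returns can be dropped. A self-contained toy illustration of that registration pattern follows; the demo::fake_broadcast op is hypothetical and stands in for the real c10d collectives.

import torch
from torch.library import Library

demo_lib = Library("demo", "DEF")
demo_lib.define("fake_broadcast(Tensor self) -> Tensor")

def fake_broadcast_cpu(self):
    # A real backend kernel would perform the communication here.
    return self.clone()

def fake_broadcast_meta(self):
    # Meta kernel: propagate shape/dtype only, no data and no communication.
    return torch.empty_like(self)

demo_lib.impl("fake_broadcast", fake_broadcast_cpu, "CPU")
demo_lib.impl("fake_broadcast", fake_broadcast_meta, "Meta")

out = torch.ops.demo.fake_broadcast(torch.empty(4, device="meta"))
assert out.is_meta  # dispatched to the Meta kernel, no collective issued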