 # Owner(s): ["oncall: distributed"]
 
 import torch
+import torch.nn.functional as F
+from torch.distributed.tensor.parallel import (
+    PairwiseParallel,
+    parallelize_module,
+)
 from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
 
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
 
+class DummyMLP(torch.nn.Module):
+    def __init__(self, device):
+        super(DummyMLP, self).__init__()
+        self.net1 = torch.nn.Linear(5, 1024, device=device)
+        self.relu = torch.nn.ReLU()
+        self.net2 = torch.nn.Linear(1024, 4, device=device)
 
-class DTensorTest(DTensorTestBase):
-    # @with_comms
-    # def test_tensor_constructor(self):
-    #     import torch.distributed._tensor as dist_tensor
-    #     shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)])
-    #     empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec)
-    #     zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec)
-    #     one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec)
-
-    #     zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec)
+    def forward(self, x):
+        return self.net2(F.relu(self.net1(x)))
 
-    #     dist_tensor.empty_like(empty_tensor)
-    #     dist_tensor.zero_like(empty_tensor)
-    #     dist_tensor.one_like(empty_tensor)
+    def reset_parameters(self, *args, **kwargs):
+        with torch.no_grad():
+            self.net1.weight.fill_(0.5)
+            self.net2.weight.fill_(1)
+            self.net1.bias.fill_(1.5)
+            self.net2.bias.fill_(1.2)
 
+class DTensorTest(DTensorTestBase):
     @with_comms
     def test_dtensor_constructor(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -55,6 +62,58 @@ def test_dtensor_constructor(self):
             requires_grad=True,
         )
 
+    @with_comms
+    def test_meta_dtensor(self):
+        device_mesh = self.build_device_mesh()
+        dist_specs = [[Shard(0)], [Replicate()]]
+        meta_tensor = torch.randn(1024, 2048, device="meta")
+        for dist_spec in dist_specs:
+            # Test distribute_tensor on meta tensor
+            meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
+            self.assertTrue(meta_dtensor.is_meta)
+            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
+            torch.nn.init.constant_(meta_dtensor, 1.2)
+            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
+            self.assertFalse(meta_dtensor.is_meta)
+            self.assertEqual(meta_dtensor.device.type, self.device_type)
+            self.assertEqual(meta_dtensor.to_local(), value_tensor)
+            # Test from_local on meta tensor
+            meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
+            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
+            torch.nn.init.constant_(meta_dtensor, 1.5)
+            self.assertEqual(meta_dtensor.device.type, self.device_type)
+            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
+            self.assertEqual(meta_dtensor.to_local(), value_tensor)
+
+    @with_comms
+    def test_modules_w_meta_dtensor(self):
+        model = DummyMLP("meta")
+        device_mesh = self.build_device_mesh()
+        model_tp = parallelize_module(model, device_mesh, PairwiseParallel())
+        model_tp.to_empty(device=self.device_type)
+        model_tp.reset_parameters()
+        optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
+        model_regular = DummyMLP(self.device_type)
+        model_regular_tp = parallelize_module(model_regular, device_mesh, PairwiseParallel())
+        optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
+        model_regular_tp.reset_parameters()
+        torch.manual_seed(0)
+        inp = torch.randn(20, 5, device=self.device_type)
+
+        output = model_tp(inp)
+        output_regular = model_regular_tp(inp)
+        self.assertEqual(output, output_regular)
+
+        output.sum().backward()
+        output_regular.sum().backward()
+
+        optim.step()
+        optim_regular.step()
+
+        torch.manual_seed(1)
+        inp = torch.randn(20, 5, device=self.device_type)
+        self.assertEqual(model_tp(inp), model_regular_tp(inp))
+
     @with_comms
     def test_dtensor_stride(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
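Both new tests exercise the same deferred-initialization pattern: build tensors or modules on the "meta" device (shape and dtype metadata only, no storage), attach the DTensor placement while still on meta, and only then materialize and initialize real memory on the target device. Because meta parameters own no storage, initialization only takes effect after `empty_like` / `to_empty`. The sketch below restates that pattern outside the test harness; it assumes the default process group is already initialized (for example via torchrun), reuses the DummyMLP helper added in this diff, and the function names and the `world_size` / `device_type` arguments are illustrative, not part of the PR.

```python
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module


def sharded_param_from_meta(world_size, device_type="cuda"):
    mesh = DeviceMesh(device_type, list(range(world_size)))
    # Meta tensor: carries shape/dtype but allocates no memory.
    meta_param = torch.randn(1024, 2048, device="meta")
    # Sharding a meta tensor yields a DTensor whose local shard is still meta.
    dparam = distribute_tensor(meta_param, mesh, [Shard(0)])
    assert dparam.is_meta
    # Materialize real storage on the target device, then initialize it in
    # place, mirroring what test_meta_dtensor asserts.
    dparam = torch.empty_like(dparam, device=device_type)
    torch.nn.init.constant_(dparam, 1.2)
    return dparam


def tp_module_from_meta(world_size, device_type="cuda"):
    mesh = DeviceMesh(device_type, list(range(world_size)))
    # DummyMLP is the test helper defined in this diff; on "meta" its
    # parameters are placeholders until storage is allocated.
    model = parallelize_module(DummyMLP("meta"), mesh, PairwiseParallel())
    model.to_empty(device=device_type)  # allocate real, uninitialized storage
    model.reset_parameters()            # filling values only has effect now
    return model
```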