54 changes: 18 additions & 36 deletions test/distributed/_tensor/test_api.py
@@ -3,26 +3,26 @@

import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms
from torch.distributed._tensor import (
distribute_tensor,
distribute_module,
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Shard,
Replicate,
Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)


class MyModel(nn.Module):
def __init__(self, n_features, n_layers, device):
super().__init__()
self.seq = nn.Sequential(
*[
nn.Linear(n_features, n_features, device=device)
for _ in range(n_layers)
]
*[nn.Linear(n_features, n_features, device=device) for _ in range(n_layers)]
)

def forward(self, x):
@@ -50,12 +50,8 @@ def test_distribute_tensor(self):
tensor_to_shard = torch.randn(
3 * self.world_size, 3, requires_grad=requires_grad
)
dist_tensor = distribute_tensor(
tensor_to_shard, device_mesh, shard_spec
)
self.assertEqual(
dist_tensor.size(), torch.Size([3 * self.world_size, 3])
)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))
local_tensor = dist_tensor.to_local()
self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
if requires_grad:
@@ -78,9 +74,7 @@ def test_distribute_tensor_errors(self):
dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec)

with self.assertRaisesRegex(ValueError, "to a different device mesh"):
new_mesh = DeviceMesh(
self.device_type, torch.arange(self.world_size)
)
new_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
distribute_tensor(dtensor, new_mesh, [Shard(0)])

with self.assertRaisesRegex(ValueError, "to a different placements"):
@@ -104,9 +98,7 @@ def test_distribute_tensor_uneven_sharding(self):
splitted_tensor_list = tensor_to_shard.tensor_split(
self.world_size, dim=shard_dim
)
dist_tensor = distribute_tensor(
tensor_to_shard, device_mesh, shard_spec
)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
self.assertEqual(dist_tensor.size(), torch.Size(input_size))
local_tensor = dist_tensor.to_local()
self.assertEqual(local_tensor, splitted_tensor_list[self.rank])
@@ -115,9 +107,7 @@ def test_distribute_tensor_uneven_sharding(self):
def test_distribute_module(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
# fully shard all linear modules on dim 0
module_to_shard = MyModel(
5 * self.world_size, 20, device=self.device_type
)
module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type)
shard_spec = [Shard(0)]

def shard_fn(name, module, device_mesh):
@@ -128,9 +118,7 @@ def shard_fn(name, module, device_mesh):
)
module.register_parameter(name, dist_param)

sharded_module = distribute_module(
module_to_shard, device_mesh, shard_fn
)
sharded_module = distribute_module(module_to_shard, device_mesh, shard_fn)
for param in sharded_module.parameters():
self.assertIsInstance(param, DTensor)
self.assertEqual(param.placements, shard_spec)
@@ -162,21 +150,15 @@ def replicate_fn(name, module, device_mesh):

# only shard part of module, and rest of module should be replicate
def shard_fn(name, module, device_mesh):
if isinstance(module, nn.Linear) and (
name == "seq.0" or name == "seq.8"
):
if isinstance(module, nn.Linear) and (name == "seq.0" or name == "seq.8"):
for name, param in module.named_parameters():
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, shard_spec)
)
module.register_parameter(name, dist_param)

module_to_distribute = MyModel(
5 * self.world_size, 20, device=self.device_type
)
dist_module = distribute_module(
module_to_distribute, device_mesh, shard_fn
)
module_to_distribute = MyModel(5 * self.world_size, 20, device=self.device_type)
dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn)
for name, param in dist_module.named_parameters():
self.assertIsInstance(param, DTensor)
if name.startswith("seq.0") or name.startswith("seq.8"):
136 changes: 36 additions & 100 deletions test/distributed/_tensor/test_common_rules.py
@@ -2,21 +2,21 @@
# Owner(s): ["oncall: distributed"]

import torch
from torch.testing._internal.common_utils import run_tests
from torchgen.model import FunctionSchema
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.dispatch import OpSchema

from torch.distributed._tensor.ops.common_rules import (
einop_rule,
reduction_rule,
pointwise_rule,
reduction_rule,
)
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.distributed._tensor import DeviceMesh
from torchgen.model import FunctionSchema


class CommonRulesTest(DTensorTestBase):
@@ -34,17 +34,11 @@ def test_einop_basic_propagation(self):
# plain einsum, mm
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

func_schema = self.parse_schema(
"aten::mm(Tensor self, Tensor mat2) -> Tensor"
)
func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
# propagate col-wise sharding
mat1, mat2 = [-1, -1], [-1, 0]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([4, 8])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -55,12 +49,8 @@ def test_einop_basic_propagation(self):

# propagate row-wise sharding
mat1, mat2 = [0, -1], [-1, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([4, 8])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -71,12 +61,8 @@ def test_einop_basic_propagation(self):

# generate partial
mat1, mat2 = [-1, 0], [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([4, 8])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -94,9 +80,7 @@ def test_einop_pointwise_propagation(self):
)
# addition
mat1 = [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 8])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 8]))
output_sharding = einop_rule(
"ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat1_spec), {})
)
@@ -110,9 +94,7 @@ def test_einop_pointwise_propagation(self):
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4, 2])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, [-1], [], shape=torch.Size([2])
)
mat2_spec = DTensorSpec.from_dim_map(mesh, [-1], [], shape=torch.Size([2]))
output_sharding = einop_rule(
"ijk,k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -144,17 +126,11 @@ def test_einop_merge_sharding(self):
)
mesh = DeviceMesh(self.device_type, mesh_shape)

func_schema = self.parse_schema(
"aten::mm(Tensor self, Tensor mat2) -> Tensor"
)
func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")

mat1, mat2 = [0, -1], [-1, 1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([4, 8])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -175,12 +151,8 @@ def test_einop_linearity(self):
)

mat1, mat2 = [0, -1], [-1, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [1], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([4, 8])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8]))
# if not turn on linearity, partial sum is not eligible to propagate, we return
# suggestion to reshard inputs with no partial sum (i.e. all_reduce one input)
output_sharding = einop_rule(
@@ -212,12 +184,8 @@ def test_einop_linearity(self):
"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
)
mat1, mat2 = [0, -1], [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [1], shape=torch.Size([8, 6])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([8, 6])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 6]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([8, 6]))

output_sharding = einop_rule(
"ij,ij->ij",
@@ -237,16 +205,10 @@ def test_einop_multi_sharding_on_mesh_dim(self):
mesh_shape = torch.arange(self.world_size)
mesh = DeviceMesh(self.device_type, mesh_shape)

func_schema = self.parse_schema(
"aten::mm(Tensor self, Tensor mat2) -> Tensor"
)
func_schema = self.parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor")
mat1, mat2 = [0, -1], [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 12])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([12, 4])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 12]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([12, 4]))
output_sharding = einop_rule(
"mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -271,19 +233,11 @@ def test_einop_errors(self):
"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
)
mat1, mat2 = [0, -1], [1, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([8, 4])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([8, 4]))

with self.assertRaisesRegex(
RuntimeError, "sharded two different ways:"
):
einop_rule(
"ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"):
einop_rule("ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {}))

@with_comms
def test_pointwise_rules_broadcasting(self):
Expand All @@ -293,12 +247,8 @@ def test_pointwise_rules_broadcasting(self):
"where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"
)
inp1, inp2, inp3 = [0], [], [-1, -1]
condition = DTensorSpec.from_dim_map(
mesh, inp1, [], shape=torch.Size([8])
)
self_tensor = DTensorSpec.from_dim_map(
mesh, inp2, [], shape=torch.Size([])
)
condition = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8]))
self_tensor = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([]))
other_tensor = DTensorSpec.from_dim_map(
mesh, inp3, [], shape=torch.Size([1, 1])
)
@@ -320,12 +270,8 @@ def test_pointwise_rules_suggestion(self):
)
# propagate point-wise sharding
inp1, inp2 = [-1, -1], [-1, 0]
mat1_spec = DTensorSpec.from_dim_map(
mesh, inp1, [], shape=torch.Size([8, 4])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, inp2, [], shape=torch.Size([8, 4])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8, 4]))
mat2_spec = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([8, 4]))
# adding a positional argument -1 to arg schema
output_sharding = pointwise_rule(
OpSchema(func_schema, (mat1_spec, mat2_spec, -1), {})
@@ -353,12 +299,8 @@ def test_pointwise_multi_sharding_on_mesh_dim(self):

# basic case to test implicit broadcasting shape alignment
mat1, mat2 = [-1, 0], [0]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([20, 6])
)
mat2_spec = DTensorSpec.from_dim_map(
mesh, mat2, [], shape=torch.Size([6])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([20, 6]))
mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([6]))
output_sharding = pointwise_rule(
OpSchema(func_schema, (mat1_spec, mat2_spec), {})
)
@@ -384,9 +326,7 @@ def test_pointwise_multi_sharding_on_mesh_dim(self):
# ensure that the suggestion is to reshard the first
# arg by all_gather first tensor dim sharding
schema_suggestion = output_sharding.schema_suggestions[0]
self.assertEqual(
schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]
)
self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1])
self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2)

@with_comms
Expand Down Expand Up @@ -431,9 +371,7 @@ def test_reduction_rule(self):
)
# reduction on a 2d mat
mat1 = [0, -1]
mat1_spec = DTensorSpec.from_dim_map(
mesh, mat1, [], shape=torch.Size([8, 4])
)
mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4]))
# reduction on dim 0
output_sharding_0 = reduction_rule(
OpSchema(func_schema, (mat1_spec, 0), {}),
Expand Down Expand Up @@ -467,9 +405,7 @@ def test_reduction_rule(self):
self.assertEqual(output_sharding_all_dim.output_spec.dim_map, [])
# pending sum on mesh
self.assertEqual(output_sharding_all_dim.output_spec.sums, [0])
self.assertEqual(
output_sharding_all_dim.output_spec.shape, torch.Size([])
)
self.assertEqual(output_sharding_all_dim.output_spec.shape, torch.Size([]))


if __name__ == "__main__":