Closed
Changes from all commits (25 commits)
9cd43a9  [dtensor][3/N] refactor dispatching logic and add propagator (wanchaol, Dec 13, 2022)
53035d4  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Dec 20, 2022)
376c2f4  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Dec 29, 2022)
131e617  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Dec 29, 2022)
669a306  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Jan 3, 2023)
555c148  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Jan 3, 2023)
6127aac  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Jan 4, 2023)
1e181d6  Update on "[dtensor][3/N] refactor dispatching logic and add propagator" (wanchaol, Jan 6, 2023)
c940e37  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 6, 2023)
33893ca  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 13, 2023)
a838732  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 15, 2023)
1d892b6  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 18, 2023)
05326d1  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 19, 2023)
3bbd7c6  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 24, 2023)
55d02e8  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 24, 2023)
b261071  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 26, 2023)
1ca969c  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 26, 2023)
f287016  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 26, 2023)
2e0ad74  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 26, 2023)
64b23cc  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 30, 2023)
d961da8  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 30, 2023)
3b2426b  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 31, 2023)
f4f917e  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 31, 2023)
9b50e00  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 31, 2023)
213b785  Update on "[dtensor][4/N] refactor dispatching logic and add propagator" (wanchaol, Jan 31, 2023)
2 changes: 1 addition & 1 deletion test/distributed/_tensor/test_common_rules.py
@@ -4,7 +4,7 @@
import torch
from torch._C import parse_schema
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.dispatch import OpSchema
from torch.distributed._tensor.op_schema import OpSchema

from torch.distributed._tensor.ops.common_rules import (
einop_rule,
7 changes: 3 additions & 4 deletions torch/distributed/_tensor/api.py
@@ -14,6 +14,7 @@
Replicate,
Shard,
)
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import Redistribute
from torch.utils._pytree import tree_flatten

@@ -133,9 +134,7 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__

# class attribute that handles operator placements propagation
# rules, keyed by aten op name, value is propagation func
_op_to_rules: Dict[
str, Callable[["op_dispatch.OpSchema"], "op_dispatch.OutputSharding"]
] = {}
_propagator: ShardingPropagator = ShardingPropagator()

# class attribute that handles custom registered ops, all handled
# custom ops should appear in this table, and overriding the default
@@ -233,7 +232,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
func,
args,
kwargs,
DTensor._op_to_rules,
DTensor._propagator,
DTensor._custom_dispatch_ops,
)

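For orientation, a minimal usage sketch of the path this api.py change affects (not part of the PR's diff). It assumes an already-initialized process group with two ranks and uses the existing public helpers distribute_tensor, DeviceMesh, and Shard from torch.distributed._tensor:

```python
# Hedged sketch, not part of this diff: any aten op on a DTensor lands in
# __torch_dispatch__, which now hands DTensor._propagator (instead of the old
# _op_to_rules dict) to operator_dispatch. Assumes torch.distributed is already
# initialized with 2 ranks (e.g. launched via torchrun).
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", [0, 1])
dt = distribute_tensor(torch.randn(8, 4), mesh, placements=[Shard(0)])

# aten.sum.default has a propagation rule registered via register_prop_rule
# (see math_ops.py below), so this call resolves its output sharding through
# the new ShardingPropagator before running the local sum.
out = torch.sum(dt)
```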
135 changes: 24 additions & 111 deletions torch/distributed/_tensor/dispatch.py
@@ -1,20 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, cast, Dict, Optional, Tuple, Union
from typing import Callable, cast, Dict, Tuple, Union, Optional

import torch

import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.op_schema import (
ArgsType,
KwargsType,
OpSchema,
OutputSharding,
OutputSpecType,
)
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import redistribute_dtensor
from torch.distributed._tensor.utils import unwrap_local_tensor
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import tree_flatten, tree_unflatten


"""
@@ -24,15 +22,6 @@
_ENABLE_FALLBACK = False


"""
Print information on ops input shape and sharding for debugging purposes.
"""
_DEBUG_VERBOSE = False

def unwrap_schema(e: object) -> object:
return e._spec if isinstance(e, dtensor.DTensor) else e


def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor):
assert spec is not None and isinstance(
@@ -105,133 +94,57 @@ def _reshape_alias(
}


def propagate_input_sharding(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]],
) -> Tuple[OpSchema, bool, Optional[OutputSharding]]:
# unwrap the args/kwargs schema
args_schema = tree_map(unwrap_schema, args)
kwargs_schema = tree_map(unwrap_schema, kwargs)

op_schema = OpSchema(op_call._schema, args_schema, kwargs_schema)

if _DEBUG_VERBOSE and torch.distributed.get_rank() == 0:
print(f"{op_call}({op_schema})")
local_shapes = tree_map(
lambda t: t.to_local().shape if isinstance(t, dtensor.DTensor) else None,
args,
)
print(f" local shapes: {local_shapes}")

op_key = str(op_call)
sharding_prop_func = op_to_rules.get(op_key, None)

if sharding_prop_func is None:
# step 1. If there's not even one sharding rule
# implemented for the operator, we fall back to
# local tensor compute, this is wront currently
# we will change the behavior to reshard to full
# replicate and do the computatation
if not _ENABLE_FALLBACK:
raise NotImplementedError(
f"Operator {op_key} does not have a DistributedTensor rule registered."
)
else:
return op_schema, False, None

# step 2. there's sharding propagation rule, run
# sharding propagation to get output sharding
try:
output_sharding = sharding_prop_func(op_schema)
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed on op {op_key}.\n"
f"Input schema: {op_schema}.\n"
f"Error: {e}"
) from e

# step 3. if can't get output_spec from sharding
# propagation (i.e. no rules apply for input
# placements), we do auto redistribute on inputs
# to get an eligble input, which we will pick a
# target schema base on the redistribute cost
# TODO: implement full auto distribute with a
# simple cost estimation model
if output_sharding.output_spec is None:
# do auto distributed/boxing here
if output_sharding.schema_suggestions is not None:
# pick the first suggestion for now,
target_schema = output_sharding.schema_suggestions[0]
# run sharding propagation again with target schema
output_sharding = sharding_prop_func(target_schema)

return target_schema, True, output_sharding

else:
raise RuntimeError(
f"Sharding propagation failed on op {op_key}!"
f"Input schema: {op_schema}."
f"Failed reason: {output_sharding.failed_reason}"
)
else:
return op_schema, False, output_sharding


def operator_dispatch(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]],
custom_dispatch_ops: Dict[str, Callable[..., object]],
sharding_propagator: ShardingPropagator,
custom_dispatch_ops: Optional[Dict[str, Callable[..., object]]] = None,
) -> object:
# first we need to lift some private aten aliases to public calls
if op_call in _CURRENT_DECOMPOSITION_TABLE:
return _CURRENT_DECOMPOSITION_TABLE[op_call](*args, **kwargs)

# STEP 0. See if threre're user defined custom aten operator
# STEP 0. See if there's a user defined custom aten operator
# implementations. Custom operators take the highest priority
if str(op_call) in custom_dispatch_ops:
if custom_dispatch_ops is not None and str(op_call) in custom_dispatch_ops:
# dispatch to user defined custom distributed tensor ops
return custom_dispatch_ops[str(op_call)](*args, **kwargs)
Review comment (Contributor): question: will this custom_dispatch_ops be deprecated once register_impl is no longer needed? I assume that eventually we want to get rid of register_impl and fully adopt propagation rules.

Reply (Collaborator, Author): Yes, we should deprecate this once we move all ops to use propagation rules.

target_schema, redistribute, output_sharding = propagate_input_sharding(
op_call, args, kwargs, op_to_rules
)
# unwrap the args/kwargs schema
op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)

output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)

if output_sharding is None:
# default to local tensor ops, this is wrong
# but we use it now to enable more tensor point-wise ops
# TODO: delete this and use replicate (all_gather) as
# the default fallback.
tensor_args = tree_map(unwrap_local_tensor, args)
tensor_kwargs = tree_map(unwrap_local_tensor, kwargs)
local_results = op_call(*tensor_args, **tensor_kwargs)
return wrap(local_results, target_schema.args_spec[0])
# if the schema suggestion from sharding prop is not the same instance as the
# input op_schema, it indicates a reshard, we need to redistribute the input
# tensors before calling the local op
assert output_sharding.schema_suggestions is not None
needs_redistribute = output_sharding.schema_suggestions[0] is not op_schema
suggested_input_schema = output_sharding.schema_suggestions[0]

local_tensor_args = pack_args_kwargs_with_local_tensor(
args,
target_schema.args_schema,
redistribute_with_schema=redistribute,
suggested_input_schema.args_schema,
redistribute_with_schema=needs_redistribute,
)
local_tensor_kwargs = pack_args_kwargs_with_local_tensor(
kwargs,
target_schema.kwargs_schema,
redistribute_with_schema=redistribute,
suggested_input_schema.kwargs_schema,
redistribute_with_schema=needs_redistribute,
)

# run local op computation with potentially modified args/kwargs
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs)
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)

if target_schema.is_inplace:
if suggested_input_schema.is_inplace:
# inplace op should return self instead of re-wrapping
self = cast(dtensor.DTensor, args[0])
self._spec = cast(DTensorSpec, output_sharding.output_spec)
return self
elif target_schema.is_out_variant:
elif suggested_input_schema.is_out_variant:
# out variant could possibly have multiple out args (i.e. lu_unpack.out)
output_specs = (
(output_sharding.output_spec,)
@@ -240,7 +153,7 @@
)
out_dts = []
spec_idx = 0
for arg in target_schema.func_schema.arguments:
for arg in suggested_input_schema.func_schema.arguments:
if arg.is_out:
out_dt = cast(dtensor.DTensor, kwargs[arg.name])
out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
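The ShardingPropagator class itself lives in torch/distributed/_tensor/sharding_prop.py and is not shown in this diff. The sketch below is an assumption about its interface, reconstructed from the calls above (prepare_op_schema, propagate_op_sharding) and from register_sharding_prop_rule in ops/utils.py further down; the method bodies are illustrative guesses, not the PR's actual implementation:

```python
# Assumed sketch of ShardingPropagator; only the method names and the
# OpSchema/OutputSharding types are taken from this diff.
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.utils._pytree import tree_map


class ShardingPropagator:
    def __init__(self) -> None:
        # op key (str of the OpOverload) -> sharding propagation rule
        self.op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]] = {}

    def register_sharding_prop_rule(
        self, op_key: str, rule_func: Callable[[OpSchema], OutputSharding]
    ) -> None:
        self.op_to_rules[op_key] = rule_func

    def prepare_op_schema(
        self,
        op_call: torch._ops.OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpSchema:
        # swap DTensor args/kwargs for their DTensorSpec, like the removed
        # unwrap_schema helper in dispatch.py did
        def unwrap(e: object) -> object:
            return e._spec if isinstance(e, dtensor.DTensor) else e

        return OpSchema(op_call._schema, tree_map(unwrap, args), tree_map(unwrap, kwargs))

    def propagate_op_sharding(
        self, op_call: torch._ops.OpOverload, op_schema: OpSchema
    ) -> Optional[OutputSharding]:
        rule = self.op_to_rules.get(str(op_call), None)
        if rule is None:
            # no rule registered: operator_dispatch falls back to local compute
            return None
        output_sharding = rule(op_schema)
        # operator_dispatch expects schema_suggestions[0] to be the original
        # op_schema instance when no redistribute is needed; the real module
        # also re-runs the rule on a suggested schema when output_spec is None.
        if output_sharding.schema_suggestions is None:
            output_sharding.schema_suggestions = [op_schema]
        return output_sharding
```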
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/common_rules.py
@@ -2,7 +2,7 @@
from typing import cast, Dict, List, Optional, Sequence, Tuple

import torch
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import prod
from torch.distributed._tensor.placement_types import DTensorSpec

9 changes: 4 additions & 5 deletions torch/distributed/_tensor/ops/math_ops.py
@@ -1,8 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, Optional, Sequence

from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import pointwise_rule, reduction_rule
from torch.distributed._tensor.ops.utils import (
as_list,
@@ -46,7 +45,7 @@ def sum_rule(op_schema: OpSchema) -> OutputSharding:
"aten.sum.dim_IntList",
]
for sum_op in sum_ops:
DTensor._op_to_rules[sum_op] = sum_rule
register_prop_rule(sum_op)(sum_rule)


@register_prop_rule("aten._softmax.default")
@@ -96,7 +95,7 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
]

for mean_op in mean_ops:
DTensor._op_to_rules[mean_op] = mean_rule
register_prop_rule(mean_op)(mean_rule)


def var_rule(op_schema: OpSchema) -> OutputSharding:
@@ -122,7 +121,7 @@ def var_rule(op_schema: OpSchema) -> OutputSharding:
]

for var_op in var_ops:
DTensor._op_to_rules[var_op] = var_rule
register_prop_rule(var_op)(var_rule)


@register_prop_rule("aten.var.correction")
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/matrix_ops.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule

7 changes: 3 additions & 4 deletions torch/distributed/_tensor/ops/pointwise_ops.py
@@ -1,8 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast

from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import (
linear_pointwise_rule,
pointwise_rule,
@@ -370,11 +369,11 @@


for op in linear_pointwise_ops:
DTensor._op_to_rules[op] = linear_pointwise_rule
register_prop_rule(op)(linear_pointwise_rule)


for op in pointwise_ops:
DTensor._op_to_rules[op] = pointwise_rule
register_prop_rule(op)(pointwise_rule)


def _register_non_deterministic_op(op):
11 changes: 5 additions & 6 deletions torch/distributed/_tensor/ops/tensor_ops.py
@@ -4,13 +4,12 @@
import torch
from torch.distributed._tensor.api import (
_Partial,
DTensor,
DTensorSpec,
Placement,
Replicate,
Shard,
)
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim

@@ -105,16 +104,16 @@ def new_factory_rule(op_schema: OpSchema) -> OutputSharding:
no_shard_prop_ops = ["aten._local_scalar_dense.default"]

for op in default_prop_ops:
DTensor._op_to_rules[op] = default_prop_rule
register_prop_rule(op)(default_prop_rule)

for op in create_like_ops:
DTensor._op_to_rules[op] = prop_create_like
register_prop_rule(op)(prop_create_like)

for op in no_shard_prop_ops:
DTensor._op_to_rules[op] = no_shard_prop_rule
register_prop_rule(op)(no_shard_prop_rule)

for op in new_factory_ops:
DTensor._op_to_rules[op] = new_factory_rule
register_prop_rule(op)(new_factory_rule)


@register_prop_rule("aten.bucketize.Tensor")
6 changes: 5 additions & 1 deletion torch/distributed/_tensor/ops/tp_sharding_ops.py
@@ -2,10 +2,10 @@
# implement matrix related ops for distributed tensor
from typing import List

import torch
import torch.utils._pytree as pytree
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.ops.utils import register_impl, unwrap_single_placement
from torch.distributed._tensor.utils import unwrap_local_tensor

"""
The ops below were quickly hacked and needed to be polished down the road.
@@ -15,6 +15,10 @@
"""


def unwrap_local_tensor(e: DTensor) -> torch.Tensor:
return e._local_tensor if isinstance(e, DTensor) else e


@register_impl("aten.split.Tensor")
# pyre-fixme[2]: Parameter must be annotated.
def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]:
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/utils.py
@@ -38,7 +38,7 @@ def register_prop_rule(func):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def wrapper(impl):
DTensor._op_to_rules[func] = impl
DTensor._propagator.register_sharding_prop_rule(func, impl)
return impl

return wrapper
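Tying this back to the review question in the dispatch.py thread above: once all ops are expressed as propagation rules, registration goes through this wrapper onto DTensor._propagator rather than into a class-level dict or the custom_dispatch_ops table. A hedged usage sketch; the relu rule below is hypothetical and only illustrates the two registration forms used in the ops modules in this diff:

```python
# Illustrative only: the rule body mirrors the trivial "output follows the
# first tensor argument's spec" pattern (cf. default_prop_rule in
# tensor_ops.py); aten.relu already has a real pointwise rule upstream.
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule


# decorator form, as used for "aten._softmax.default" in math_ops.py
@register_prop_rule("aten.relu.default")
def relu_rule(op_schema: OpSchema) -> OutputSharding:
    return OutputSharding(output_spec=op_schema.args_spec[0])


# functional form, as used for the sum/mean/var op lists in math_ops.py
for op in ["aten.relu_.default"]:
    register_prop_rule(op)(relu_rule)
```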
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/view_ops.py
@@ -5,7 +5,7 @@
import torch
from torch import Tensor
from torch.distributed._tensor.api import Shard
from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import (
normalize_dim,
normalize_dims,