Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torch/distributed/_tensor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ArgKwargsType = Union[Tuple[object, ...], Dict[str, object]]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should
# be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[DTensorSpec]]]
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


def unwrap_local_tensor(e: "dtensor.DTensor") -> torch.Tensor:
Expand Down Expand Up @@ -45,8 +45,13 @@ def wrap(res: object, spec: OutputSpecType) -> object:
assert spec is not None and isinstance(
spec, tuple
), f"output spec does not match with output! Expected tuple, got {spec}"

# NOTE: local results might return Optional Tensor from ATen op, so we need to
# handle that case and make sure we don't wrap None with DTensor.
# (i.e. native_layer_norm.backward)
return tuple(
dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
if e is not None and s is not None else None
for e, s in zip(res, spec)
)
else:
Expand Down