Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions torch/distributed/_tensor/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,23 +214,6 @@ def __repr__(self):
# TODO: consider all_gather the local tensors for better debugging
return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"

@classmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# if we find nn.functional name in dispatch op, dispatch to it instead,
# this allow us to override some python level behaviors that wouldn't be
# possible in __torch_dispatch__ level.
if func.__name__ in DTensor._custom_dispatch_ops:
# dispatch to the same table as the name should be different between
# torch_function and torch_dispatch
return DTensor._custom_dispatch_ops[func.__name__](*args, **kwargs)
else:
# if not, just do nothing here
return super().__torch_function__(func, types, args, kwargs)

@classmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
Expand Down