
Commit daf1eb4

wanchaol authored and pytorchmergebot committed
try to fix the warning in distribute_tensor (#125476)
Pull Request resolved: #125476
Approved by: https://github.com/albanD, https://github.com/awgu
ghstack dependencies: #125475
1 parent 7ffa555 commit daf1eb4

2 files changed (+7, -6 lines)


torch/distributed/_tensor/api.py

Lines changed: 3 additions & 3 deletions
@@ -573,7 +573,7 @@ def distribute_tensor(
     # OffsetBasedRNGTracker to perform random operators.
     # TODO: the value assignment to global variable is not the ideal solution
     # we can replace it in future.
-    if is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
+    if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
         random._rng_tracker = OffsetBasedRNGTracker(device_type)
 
     if not tensor.is_leaf:
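The reorder above relies on Python's short-circuit evaluation of `and`: `is_rng_supported_mesh(device_mesh)` is now only evaluated when no RNG tracker exists yet. A minimal, self-contained illustration of that behavior (the names below are placeholders, not the DTensor code):

def noisy_check():
    # stands in for a check that does real work (and may emit a warning)
    print("check ran")
    return True

tracker = object()  # pretend a tracker was already constructed

# `not tracker` is False, so noisy_check() never runs and nothing is printed
if not tracker and noisy_check():
    pass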
@@ -612,7 +612,7 @@ def distribute_tensor(
         )
         return tensor
 
-    local_tensor = tensor
+    local_tensor = tensor.detach()
 
     # distribute the tensor according to the placements.
     placements = list(placements)
@@ -637,7 +637,7 @@ def distribute_tensor(
     # detach the local tensor passed to DTensor since after the construction
     # of DTensor, autograd would work on top of DTensor instead of local tensor
     return DTensor(
-        local_tensor.detach().requires_grad_(tensor.requires_grad),
+        local_tensor.requires_grad_(tensor.requires_grad),
         device_mesh,
         placements,
         shape=tensor.size(),
torch/distributed/_tensor/dispatch.py

Lines changed: 4 additions & 3 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import contextlib
 import functools
 import operator
 from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
@@ -181,15 +182,15 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
 
         # run local op computation with potentially modified args/kwargs
         local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
-        if op_call in self._random_ops and is_rng_supported_mesh(mesh):
-            if not random._rng_tracker:
+        if op_call in self._random_ops:
+            if not random._rng_tracker and is_rng_supported_mesh(mesh):
                 # Default to `OffsetBasedRNGTracker` if the parallelism API
                 # did not already construct one
                 random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
             # For DTensor random operator, run it within a distribute region
             with random._rng_tracker._distribute_region(
                 cast(dtensor.DTensor, args[0])._spec
-            ):
+            ) if random._rng_tracker else contextlib.nullcontext():
                 local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
         else:
             local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
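The dispatch.py hunk uses a standard Python idiom: when no RNG tracker is present, fall back to `contextlib.nullcontext()` so the same `with` statement works with or without a distribute region. A small, self-contained sketch of the pattern (the `Tracker` class and `region()` method below are made up for illustration, not the DTensor API):

import contextlib

class Tracker:
    @contextlib.contextmanager
    def region(self):
        # stand-in for the tracker's distribute region
        print("enter region")
        yield
        print("exit region")

def run_op(op, tracker=None):
    # pick the tracker's region if one exists, otherwise a no-op context
    with tracker.region() if tracker else contextlib.nullcontext():
        return op()

run_op(lambda: 1 + 1)             # no tracker: runs under nullcontext
run_op(lambda: 1 + 1, Tracker())  # tracker present: prints enter/exit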
