112 changes: 105 additions & 7 deletions test/distributed/fsdp/test_fsdp_pure_fp16.py
@@ -2,8 +2,14 @@

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
@@ -13,7 +19,6 @@
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
@@ -37,13 +42,20 @@ def world_size(self):
        return min(4, super().world_size)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload: CPUOffload):
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training."""
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ]
            },
            self._test_pure_fp16_training,
        )

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
@@ -54,6 +66,92 @@ def test_pure_fp16(self, cpu_offload: CPUOffload):
            use_pure_fp16=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both user-facing parameter/gradient dtypes and internal
        saved dtype attributes are as expected when using an FP16 model
        possibly with explicit mixed precision enabled.
        """
        self.run_subtests(
            {
                "to_half_before_fsdp_init": [False, True],
                "use_orig_params": [False, True],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        reduce_dtype=torch.float32,
                    ),
                    MixedPrecision(
                        param_dtype=torch.float32,
                    ),
                ],
            },
            self._test_fp16_dtypes,
        )

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_NEVER,
            {},
        )
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "device_id": torch.cuda.current_device(),
            "mixed_precision": mixed_precision,
        }
        if to_half_before_fsdp_init:
            model = model.half()
        fsdp_model = FSDP(model, **fsdp_kwargs)
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
        inp = tuple(
            t.half() if torch.is_tensor(t) else t
            for t in fsdp_model.module.get_input(torch.device("cuda"))
        )
        out = fsdp_model(*inp)
        out.sum().backward()

        # Check handle dtype attributes
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.dtype, torch.float16)
            self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
            self.assertEqual(handle._orig_param_dtype, torch.float16)
            # Specifying `mixed_precision` takes precedence over the model
            # dtype for both `param_dtype` and `reduce_dtype`
            if mixed_precision.param_dtype is not None:
                self.assertEqual(
                    handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
                )
            else:
                self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
            if mixed_precision.reduce_dtype is not None:
                self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
            elif (
                mixed_precision.reduce_dtype is None
                and mixed_precision.param_dtype is not None
            ):
                # Special case: infer reduce dtype from parameter dtype
                self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(handle._reduce_dtype, torch.float16)

        # Check parameter/gradient dtypes
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float16)


instantiate_parametrized_tests(TestPureFP16)

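To summarize the assertions in `_test_fp16_dtypes` above: an explicitly configured `MixedPrecision` dtype takes precedence over the model's FP16 parameter dtype, and when only `param_dtype` is given, the reduce dtype is inferred from it. A minimal, self-contained sketch of that precedence rule (the `resolve_dtypes` helper is illustrative only, not part of this PR):

import torch

def resolve_dtypes(model_dtype, mp_param_dtype=None, mp_reduce_dtype=None):
    # Illustrative sketch mirroring the dtype expectations asserted by
    # `_test_fp16_dtypes`; returns (fwd_bwd_param_dtype, reduce_dtype)
    fwd_bwd_param_dtype = (
        mp_param_dtype if mp_param_dtype is not None else model_dtype
    )
    if mp_reduce_dtype is not None:
        # An explicit `reduce_dtype` always wins
        reduce_dtype = mp_reduce_dtype
    elif mp_param_dtype is not None:
        # Special case: infer the reduce dtype from the specified param dtype
        reduce_dtype = mp_param_dtype
    else:
        # Otherwise both follow the model's (FP16) dtype
        reduce_dtype = model_dtype
    return fwd_bwd_param_dtype, reduce_dtype

# The three MixedPrecision configs exercised by the test, with FP16 parameters
assert resolve_dtypes(torch.float16) == (torch.float16, torch.float16)
assert resolve_dtypes(torch.float16, torch.float16, torch.float32) == (
    torch.float16,
    torch.float32,
)
assert resolve_dtypes(torch.float16, torch.float32) == (torch.float32, torch.float32)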
28 changes: 24 additions & 4 deletions torch/distributed/fsdp/flat_param.py
@@ -522,9 +522,14 @@ def _init_param_reduce_dtypes(
        is ``None``, in which case we assume the gradient reduction dtype
        matches the forward/backward parameter dtype.
        """
        low_prec_param_dtype_specified = mp_param_dtype is not None
        low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        if low_prec_param_dtype_specified and not low_prec_reduce_dtype_specified:
        # Save whether these dtypes were specified so that we permit the
        # parameter dtype to change up until the lazy initialization
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        if (
            self._low_prec_param_dtype_specified
            and not self._low_prec_reduce_dtype_specified
        ):
            # Special case: infer gradient reduction mixed precision
            self._fwd_bwd_param_dtype = mp_param_dtype
            self._reduce_dtype = self._fwd_bwd_param_dtype
@@ -770,6 +775,21 @@ def init_flat_param_attributes(self) -> None:
        reshard methods in this class for the allocation and free pattern.
        """
        flat_param = self.flat_param
        if flat_param.dtype != self._orig_param_dtype:
            # Entering this branch means that the user changed the parameter
            # dtype after FSDP initialization, in which case we may need to
            # refresh some saved dtype attributes (dtypes specified as a part
            # of mixed precision take precedence).
            if not self._low_prec_param_dtype_specified:
                self._fwd_bwd_param_dtype = flat_param.dtype
            # For `reduce_dtype`, require `param_dtype` was not specified since
            # then we infer the `reduce_dtype` from the specified `param_dtype`
            if (
                not self._low_prec_reduce_dtype_specified
                and not self._low_prec_param_dtype_specified
            ):
                self._reduce_dtype = flat_param.dtype
            self._orig_param_dtype = flat_param.dtype
        cpu_device = torch.device("cpu")
        if self._offload_params:
            p_assert(
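The branch above only refreshes dtypes that were not explicitly configured, which is why `_init_param_reduce_dtypes` now saves the `_low_prec_*_specified` flags on the handle instead of using locals. A condensed, standalone sketch of how the two pieces interact (attribute names follow the diff; the class itself is an illustration, not FSDP's actual handle):

import torch

class DtypeStateSketch:
    # Illustrative stand-in for the handle's dtype bookkeeping (not FSDP code)
    def __init__(self, orig_param_dtype, mp_param_dtype=None, mp_reduce_dtype=None):
        # Record whether each dtype was explicitly configured so that a later
        # `.half()` on the wrapped module can still update the inferred ones
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        self._orig_param_dtype = orig_param_dtype
        self._fwd_bwd_param_dtype = (
            mp_param_dtype if mp_param_dtype is not None else orig_param_dtype
        )
        if mp_reduce_dtype is not None:
            self._reduce_dtype = mp_reduce_dtype
        elif mp_param_dtype is not None:
            self._reduce_dtype = mp_param_dtype  # infer from `param_dtype`
        else:
            self._reduce_dtype = orig_param_dtype

    def lazy_refresh(self, current_param_dtype):
        # Mirrors the new branch in `init_flat_param_attributes`: refresh only
        # the dtypes that the user did not configure explicitly
        if current_param_dtype == self._orig_param_dtype:
            return
        if not self._low_prec_param_dtype_specified:
            self._fwd_bwd_param_dtype = current_param_dtype
        if (
            not self._low_prec_reduce_dtype_specified
            and not self._low_prec_param_dtype_specified
        ):
            self._reduce_dtype = current_param_dtype
        self._orig_param_dtype = current_param_dtype

# An FP32 model wrapped with the default MixedPrecision() and then converted to FP16
state = DtypeStateSketch(torch.float32)
state.lazy_refresh(torch.float16)
assert state._fwd_bwd_param_dtype == torch.float16
assert state._reduce_dtype == torch.float16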
@@ -1552,7 +1572,7 @@ def _use_sharded_views(self) -> None:
                # Allow the original data to be freed via garbage collection
                param.data = torch.empty(
                    0,
                    dtype=param.dtype,
                    dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
                    device=self.flat_param.device,
                    requires_grad=False,
                )
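Taken together, these changes target the flow of converting a model to FP16 after wrapping it with FSDP, which is what `test_fp16_dtypes` exercises. A rough end-to-end sketch of that user flow follows; the single-rank NCCL setup, address, and port are illustrative placeholders, not prescribed by this PR:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Single-rank process group purely for illustration (adapt to your launcher)
dist.init_process_group(
    backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
torch.cuda.set_device(0)

model = nn.Linear(8, 8).cuda()
fsdp_model = FSDP(model, mixed_precision=MixedPrecision())

# Converting to FP16 *after* FSDP initialization: the handle refreshes its
# inferred dtypes lazily before the first forward (see the diff above)
fsdp_model = fsdp_model.half()

out = fsdp_model(torch.randn(2, 8, device="cuda", dtype=torch.float16))
out.sum().backward()
for param in fsdp_model.parameters():
    assert param.dtype == torch.float16
    if param.grad is not None:
        assert param.grad.dtype == torch.float16

dist.destroy_process_group()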