@@ -11,6 +11,9 @@
    DefaultSavePlanner,
    DefaultLoadPlanner,
)
from torch.distributed.checkpoint.optimizer import (
    load_sharded_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor.parallel import (
@@ -68,8 +71,27 @@ def _distribute_and_fsdp_wrap_module(
    return FSDP(module, process_group=pg, use_orig_params=use_orig_params)


def create_new_dist_group():
    world_size = dist.get_world_size()
    group1 = [i for i in range(world_size) if i % 2 == 0]
    group2 = [i for i in range(world_size) if i % 2 != 0]

    # Create new FSDP groups (even ranks, odd ranks) for resharding.
    fsdp_0 = dist.new_group(ranks=group1)
    fsdp_1 = dist.new_group(ranks=group2)
    if dist.get_rank() % 2 == 0:
        my_fsdp = fsdp_0
    else:
        my_fsdp = fsdp_1

    return my_fsdp
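
Note that `dist.new_group` is a collective call: every rank in the default group must invoke it for every subgroup being created, including subgroups that rank does not belong to, which is why both `fsdp_0` and `fsdp_1` are constructed unconditionally above. A minimal standalone sketch of the same even/odd split (hypothetical script, assuming a `torchrun --nproc_per_node=4` launch):

```python
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("gloo")  # "nccl" on GPU hosts
    my_fsdp = create_new_dist_group()  # even/odd subgroup, as defined above

    # Collectives scoped to the subgroup only see half the world: with 4
    # ranks, the even group sums 0 + 2 = 2 and the odd group sums 1 + 3 = 4.
    t = torch.tensor([dist.get_rank()])
    dist.all_reduce(t, group=my_fsdp)
    print(f"rank {dist.get_rank()}: subgroup sum = {t.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```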


def init_model(
    model_parallel_size=TP_DEGREE, use_orig_params=False, fsdp_nested=False
    model_parallel_size=TP_DEGREE,
    use_orig_params=False,
    fsdp_nested=False,
    fsdp_pg=None,
):
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
@@ -83,7 +105,10 @@ def init_model(
        mesh=torch.arange(0, world_size).view(model_parallel_size, -1),
    )

    fsdp_pg = twod_mesh.get_dim_groups()[0]
    if not fsdp_pg:
        fsdp_pg = twod_mesh.get_dim_groups()[0]
    else:
        fsdp_pg = create_new_dist_group()

    # Create Input
    model = _distribute_and_fsdp_wrap_module(
@@ -93,28 +118,28 @@
    return model, fsdp_pg
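
As a sanity check on the mesh arithmetic (assuming the 4-GPU minimum this test runs with, i.e. `world_size=4` and `TP_DEGREE=2`):

```python
import torch

# The mesh init_model builds: ranks laid out as a 2 x 2 grid.
mesh = torch.arange(0, 4).view(2, -1)
print(mesh)
# tensor([[0, 1],
#         [2, 3]])

# get_dim_groups()[0] returns the process groups along mesh dim 0, i.e.
# the columns {0, 2} and {1, 3}; these serve as the default FSDP groups,
# while the other mesh dimension is used for tensor parallelism.
```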


class Test2dModelStateCheckpoint(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_2d_model_state_checkpoint(self) -> None:
class Test2dFsdpDtCheckpoint(DTensorTestBase):
    def _test_fsdp_dt_checkpoint(self, fsdp_pg=None) -> None:
        if not is_available():
            self.skipTest("FSDP 2d parallel integration not available")

        CHECKPOINT_DIR = self.temp_dir

        model = init_model()[0]
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        # Create Input
        input_seed = self.rank
        torch.manual_seed(input_seed + 1)
        input = torch.rand(4, 5).cuda(self.rank)

        model(input).sum().backward()
        optim.step()

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model.state_dict(),
                "optim": FSDP.sharded_optim_state_dict(model, optim),
            }

            dist_cp.save_state_dict(
@@ -127,7 +152,8 @@ def test_2d_model_state_checkpoint(self) -> None:
                ),
            )
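
The save path above follows the standard `torch.distributed.checkpoint` recipe: switch the FSDP wrapper to `SHARDED_STATE_DICT`, collect model and optimizer shards into a single dictionary, and write it with a `FileSystemWriter`. A condensed sketch of that recipe (the helper name is ours; the APIs match the test):

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

def save_sharded_checkpoint(model, optim, checkpoint_dir):
    # SHARDED_STATE_DICT keeps each rank's shards as ShardedTensors instead
    # of gathering full parameters onto a single rank.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model.state_dict(),
            # Re-keys optimizer state by parameter FQN so it can be
            # resharded when loaded under a different process group.
            "optim": FSDP.sharded_optim_state_dict(model, optim),
        }
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
        )
```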

        model_2 = init_model()[0]
        model_2 = init_model(fsdp_pg=fsdp_pg)[0]
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)

        # Ensure the parameters are different before loading
        with FSDP.summon_full_params(model):
@@ -157,6 +183,17 @@ def test_2d_model_state_checkpoint(self) -> None:
            )
            model_2.load_state_dict(state_dict["model"])

            optim_state = load_sharded_optimizer_state_dict(
                model_state_dict=state_dict["model"],
                optimizer_key="optim",
                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            )

            flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                optim_state["optim"], model_2, optim_2
            )
            optim_2.load_state_dict(flattened_osd)
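
Restoring the optimizer is a three-step process: read the sharded optimizer values with `load_sharded_optimizer_state_dict` (resharding them against the new model's layout), flatten them back into FSDP's flat-parameter format, and hand the result to the optimizer. A condensed sketch (the helper name is ours; the call sequence mirrors the test):

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import (
    load_sharded_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

def load_sharded_checkpoint(model, optim, checkpoint_dir):
    reader = dist_cp.FileSystemReader(checkpoint_dir)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        # 1) Load model shards in place, then apply them.
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(state_dict=state_dict, storage_reader=reader)
        model.load_state_dict(state_dict["model"])

        # 2) Read optimizer shards, resharded to match the (possibly new)
        #    process group this model instance was wrapped with.
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=state_dict["model"],
            optimizer_key="optim",
            storage_reader=reader,
        )

        # 3) Flatten per-FQN state back into FSDP's flat-parameter layout.
        flattened_osd = FSDP.flatten_sharded_optim_state_dict(
            optim_state["optim"], model, optim
        )
        optim.load_state_dict(flattened_osd)
```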

        # Ensure the parameters are the same after loading
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
@@ -171,6 +208,29 @@ def test_2d_model_state_checkpoint(self) -> None:
                    else:
                        self.assertEqual(n_p1[1], n_p2[1])

        def opt_at(opt, idx):
            return list(opt.state.values())[idx]

        # Adam lazily creates its state
        self.assertEqual(
            opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"]
        )
        self.assertEqual(
            opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
        )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_2d_fsdp_dt_checkpoint_no_resharding(self) -> None:
        self._test_fsdp_dt_checkpoint()

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_2d_fsdp_dt_checkpoint_resharding(self) -> None:
        self._test_fsdp_dt_checkpoint(fsdp_pg=create_new_dist_group())
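
Taken together: the first test exercises the save/load path with the same FSDP process group on both sides, while the second rewraps `model_2` with the even/odd subgroup from `create_new_dist_group()`, forcing the load to reshard model and optimizer state that was saved under the original group.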


if __name__ == "__main__":
    run_tests()
6 changes: 4 additions & 2 deletions torch/distributed/checkpoint/optimizer.py
@@ -123,13 +123,15 @@ def _get_state_dict_2d_layout(
            assert (
                len(value.local_shards()) == 1
            ), "Cannot handle ST with multiple shards"
            assert isinstance(ShardedTensor, value)
            assert isinstance(
                value, ShardedTensor
            ), "Can only handle nested ShardedTensor"
            shard = value.local_shards()[0]
            specs[key] = (
                shard.metadata.shard_offsets,
                shard.metadata.shard_sizes,
            )
            dp_pg = shard.tensor._process_group
            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]

    return (
        specs,
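
The `optimizer.py` change above fixes swapped `isinstance` arguments: the object comes first and the type second, so the original assertion raised `TypeError` instead of validating the value. A minimal illustration (toy class; the real `ShardedTensor` lives in `torch.distributed._shard.sharded_tensor`):

```python
class ShardedTensor:  # stand-in for the real class
    pass

value = ShardedTensor()

assert isinstance(value, ShardedTensor)  # correct: (object, type)
# isinstance(ShardedTensor, value) raises:
# TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union
```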