Skip to content

Commit b8b7480

Browse files
wz337 authored and pytorchmergebot committed
[Checkpoint][2D][6/N] Add optimizer and update default_planner to core distributed (#90212)
This is the last PR for integrating 2D into core distributed. This PR does the following: 1. Add optimizer.py: this adds ability to load a state_dict in conjunction with FSDP sharded optimizer state. 2. Update default_planner.py to support 2D checkpoint. 3. Add test_fsdp_optim_state.py as a unit test for No. 1. 4. Fix bug in torch/testing/_internal/distributed/checkpoint_utils.py 5. Rename the filename for the APIs that should be private. Will organize and cleanup further in following PRs. #90328 Docstring and integration test will be added in the following PRs. Pull Request resolved: #90212 Approved by: https://github.com/wanchaol
1 parent 36ac095 commit b8b7480

File tree

11 files changed

+663
-39
lines changed

11 files changed

+663
-39
lines changed

test/distributed/checkpoint/test_dedup_tensors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
import torch
5-
from torch.distributed.checkpoint.dedup_tensors import dedup_tensors
5+
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
66
from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
77
from torch.distributed.checkpoint.planner_helpers import (
88
_create_write_item_for_tensor,
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Owner(s): ["oncall: distributed"]
2+
3+
import torch
4+
5+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
6+
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
7+
import torch.distributed.checkpoint as dist_cp
8+
import torch.distributed as dist
9+
10+
from torch.distributed.checkpoint.default_planner import (
11+
DefaultSavePlanner,
12+
DefaultLoadPlanner,
13+
)
14+
from torch.distributed.checkpoint.optimizer import (
15+
load_sharded_optimizer_state_dict,
16+
)
17+
18+
from torch.testing._internal.distributed._tensor.common_dtensor import (
19+
DTensorTestBase,
20+
with_comms,
21+
)
22+
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
23+
from torch.testing._internal.common_utils import run_tests
24+
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
25+
26+
27+
class FsdpOptimStateCheckpoint(DTensorTestBase):
28+
@with_comms
29+
@skip_if_lt_x_gpu(4)
30+
@with_temp_dir
31+
def test_distributed_tensor_planner(self) -> None:
32+
CHECKPOINT_DIR = self.temp_dir
33+
34+
model = FSDP(torch.nn.Linear(8, 8, device="meta"))
35+
optim = torch.optim.Adam(model.parameters(), lr=0.1)
36+
37+
model(torch.rand(8, 8, device=dist.get_rank())).sum().backward()
38+
optim.step()
39+
40+
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
41+
state_dict = {
42+
"model": model.state_dict(),
43+
"optim": FSDP.sharded_optim_state_dict(model, optim),
44+
}
45+
46+
dist_cp.save_state_dict(
47+
state_dict=state_dict,
48+
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
49+
planner=DefaultSavePlanner(
50+
flatten_state_dict=True,
51+
flatten_sharded_tensors=True,
52+
),
53+
)
54+
55+
# now load the model and ensure the values are the same
56+
model_2 = FSDP(torch.nn.Linear(8, 8, device="meta"))
57+
optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)
58+
59+
with FSDP.summon_full_params(model):
60+
with FSDP.summon_full_params(model_2):
61+
self.assertNotEqual(model.weight, model_2.weight)
62+
self.assertNotEqual(model.bias, model_2.bias)
63+
64+
# Adam lazily creates its state
65+
self.assertEqual(0, len(optim_2.state))
66+
67+
with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
68+
state_dict = {
69+
"model": model_2.state_dict(),
70+
# cannot load the optimizer together with the model
71+
}
72+
73+
dist_cp.load_state_dict(
74+
state_dict=state_dict,
75+
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
76+
planner=DefaultLoadPlanner(
77+
flatten_state_dict=True,
78+
flatten_sharded_tensors=True,
79+
),
80+
)
81+
model_2.load_state_dict(state_dict["model"])
82+
83+
optim_state = load_sharded_optimizer_state_dict(
84+
model_state_dict=state_dict["model"],
85+
optimizer_key="optim",
86+
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
87+
)
88+
89+
flattened_osd = FSDP.flatten_sharded_optim_state_dict(
90+
optim_state["optim"], model_2
91+
)
92+
optim_2.load_state_dict(flattened_osd)
93+
94+
with FSDP.summon_full_params(model):
95+
with FSDP.summon_full_params(model_2):
96+
self.assertEqual(model.weight, model_2.weight)
97+
self.assertEqual(model.bias, model_2.bias)
98+
99+
def opt_at(opt, idx):
100+
return list(iter(opt.state.values()))[idx]
101+
102+
# Adam lazily creates its state
103+
self.assertEqual(
104+
opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"]
105+
)
106+
self.assertEqual(
107+
opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
108+
)
109+
110+
111+
if __name__ == "__main__":
112+
run_tests()

test/distributed/checkpoint/test_nested_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
from torch.testing._internal.common_utils import run_tests, TestCase
5-
from torch.distributed.checkpoint.nested_dict import (
5+
from torch.distributed.checkpoint._nested_dict import (
66
flatten_state_dict,
77
unflatten_state_dict,
88
)

test/distributed/checkpoint/test_traverse.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import OrderedDict
44
import torch
55

6-
import torch.distributed.checkpoint.traverse as traverse
6+
import torch.distributed.checkpoint._traverse as _traverse
77
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
88
from torch.testing._internal.common_utils import run_tests, TestCase
99

@@ -24,7 +24,7 @@ def collect_data(path, value):
2424
nonlocal data
2525
data[path] = value
2626

27-
traverse.traverse_state_dict(state_dict, collect_data)
27+
_traverse.traverse_state_dict(state_dict, collect_data)
2828

2929
self.assertIn(("key0",), data)
3030
self.assertEqual(data[("key0",)], 1)
@@ -53,7 +53,7 @@ def collect_data(path, value):
5353
nonlocal data
5454
data[path] = value
5555

56-
traverse.traverse_state_dict(state_dict, collect_data)
56+
_traverse.traverse_state_dict(state_dict, collect_data)
5757

5858
self.assertNotIn(("key1"), data)
5959

@@ -84,7 +84,7 @@ def collect_data(path, value):
8484
nonlocal data
8585
data[path] = value
8686

87-
traverse.traverse_state_dict(state_dict, collect_data)
87+
_traverse.traverse_state_dict(state_dict, collect_data)
8888

8989
self.assertNotIn(("key0",), data)
9090

@@ -105,7 +105,7 @@ def collect_data(path, value):
105105
nonlocal data
106106
data[path] = value
107107

108-
traverse.traverse_state_dict(state_dict, collect_data)
108+
_traverse.traverse_state_dict(state_dict, collect_data)
109109

110110
self.assertIn(("key0", 0, "key1", "key2"), data)
111111
self.assertEqual(
@@ -129,7 +129,7 @@ def collect_data(path, value):
129129
nonlocal data
130130
data[path] = value
131131

132-
traverse.traverse_state_dict(state_dict, collect_data)
132+
_traverse.traverse_state_dict(state_dict, collect_data)
133133

134134
self.assertIn(("key0", 0), data)
135135
self.assertEqual(data[("key0", 0)], 99)
@@ -140,36 +140,36 @@ def collect_data(path, value):
140140
def test_set_element(self) -> None:
141141
state_dict: STATE_DICT_TYPE = {}
142142

143-
traverse.set_element(state_dict, ("k",), 10)
143+
_traverse.set_element(state_dict, ("k",), 10)
144144
self.assertEqual(state_dict["k"], 10)
145145

146-
traverse.set_element(state_dict, ("k1", 2), 1)
146+
_traverse.set_element(state_dict, ("k1", 2), 1)
147147
self.assertEqual(state_dict["k1"], [None, None, 1])
148148

149-
traverse.set_element(state_dict, ("k1", 1), 99)
149+
_traverse.set_element(state_dict, ("k1", 1), 99)
150150
self.assertEqual(state_dict["k1"], [None, 99, 1])
151151

152-
traverse.set_element(state_dict, ("k1", 3), 88)
152+
_traverse.set_element(state_dict, ("k1", 3), 88)
153153
self.assertEqual(state_dict["k1"], [None, 99, 1, 88])
154154

155-
traverse.set_element(state_dict, ("k2", "k3"), 3)
155+
_traverse.set_element(state_dict, ("k2", "k3"), 3)
156156
self.assertEqual(state_dict["k2"], {"k3": 3})
157157

158-
traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99)
158+
_traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99)
159159
self.assertEqual(state_dict["k2"]["k4"][0], [99])
160160

161161
def test_get_element(self) -> None:
162162
state_dict = {"a": [0, 1], "b": [2, {"c": "d"}]}
163-
self.assertEqual(traverse.get_element(state_dict, ("a",)), [0, 1])
164-
self.assertEqual(traverse.get_element(state_dict, ("b", 0)), 2)
165-
self.assertEqual(traverse.get_element(state_dict, ("b", 1, "c")), "d")
166-
167-
self.assertIsNone(traverse.get_element(state_dict, ("c",)))
168-
self.assertIsNone(traverse.get_element(state_dict, ("a", 33)))
169-
self.assertIsNone(traverse.get_element(state_dict, ("b", 88)))
170-
self.assertIsNone(traverse.get_element(state_dict, ("b", 0, 2)))
171-
self.assertIsNone(traverse.get_element(state_dict, ("b", 1, 2)))
172-
self.assertIsNone(traverse.get_element(state_dict, ("b", 1, "d")))
163+
self.assertEqual(_traverse.get_element(state_dict, ("a",)), [0, 1])
164+
self.assertEqual(_traverse.get_element(state_dict, ("b", 0)), 2)
165+
self.assertEqual(_traverse.get_element(state_dict, ("b", 1, "c")), "d")
166+
167+
self.assertIsNone(_traverse.get_element(state_dict, ("c",)))
168+
self.assertIsNone(_traverse.get_element(state_dict, ("a", 33)))
169+
self.assertIsNone(_traverse.get_element(state_dict, ("b", 88)))
170+
self.assertIsNone(_traverse.get_element(state_dict, ("b", 0, 2)))
171+
self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, 2)))
172+
self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, "d")))
173173

174174

175175
if __name__ == "__main__":
File renamed without changes.

torch/distributed/checkpoint/nested_dict.py renamed to torch/distributed/checkpoint/_nested_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
STATE_DICT_TYPE,
66
)
77

8-
from .traverse import (
8+
from ._traverse import (
99
traverse_state_dict,
1010
set_element,
1111
OBJ_PATH,

torch/distributed/checkpoint/nested_tensor.py renamed to torch/distributed/checkpoint/_nested_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121

22-
from .traverse import (
22+
from ._traverse import (
2323
OBJ_PATH,
2424
traverse_state_dict,
2525
set_element,

0 commit comments

Comments (0)