
Commit 17fd1ff

[WIP][dist_optim] introduce distributed functional optimizer
This PR introduces a distributed functional optimizer, so that the distributed optimizer can reuse the functional optimizer APIs while maintaining its own states. This could enable a TorchScript-compatible functional optimizer when using the distributed optimizer, which helps get rid of the GIL and improves the overall performance of training, especially distributed model-parallel training.

ghstack-source-id: 0c75d84
Pull Request resolved: #45221
1 parent e935d1e commit 17fd1ff
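
To make the calling convention concrete, here is a minimal illustrative sketch (not code from this commit) of the difference the message describes: a regular torch.optim optimizer reads gradients from each param.grad inside step(), while a functional optimizer is handed the gradients explicitly, so multiple trainer threads never have to accumulate into the same .grad attribute.

import torch

w = torch.randn(3, 3, requires_grad=True)

# Regular optimizer: backward() writes into w.grad and step() reads it implicitly.
opt = torch.optim.Adagrad([w], lr=0.05)
w.sum().backward()
opt.step()

# Functional style: gradients are returned explicitly and handed to the optimizer,
# so nothing mutates the shared w.grad attribute.
grad, = torch.autograd.grad(w.sum(), [w])
# a functional optimizer such as the FunctionalAdagrad added below would then be
# driven as: functional_opt.step([grad])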

4 files changed: +197 lines added, −14 lines removed

torch/distributed/optim/adagrad.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+from typing import List, Dict, Optional
+import torch
+import torch.optim.functional as F
+
+from torch import Tensor
+
+# Define a TorchScript compatible Functional Adagrad Optimizer
+# where we use this optimizer in a functional way.
+# Instead of using the `param.grad` when updating parameters,
+# we explicitly let the user pass gradients to the `step` function;
+# this is so that we can separate the gradients and parameters
+# and allow a multithreaded trainer to update the parameters
+# without data races when accumulating into the same .grad.
+# NOTE: This should only be used by distributed optimizer internals
+# and is not meant to be exposed to the user.
+@torch.jit.script
+class FunctionalAdagrad(object):
+    def __init__(
+        self,
+        params: List[Tensor],
+        lr: float = 1e-2,
+        lr_decay: float = 0.0,
+        weight_decay: float = 0.0,
+        initial_accumulator_value: float = 0.0,
+        warmup_lr_multiplier: float = 1.0,
+        warmup_num_iters: float = 0.0,
+        eps: float = 1e-10,
+        coalesce_grad: bool = True,
+    ):
+        self.defaults = {
+            "lr": lr,
+            "lr_decay": lr_decay,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "initial_accumulator_value": initial_accumulator_value,
+            "warmup_lr_multiplier": warmup_lr_multiplier,
+            "warmup_num_iters": warmup_num_iters,
+        }
+        self.coalesce_grad = coalesce_grad
+        self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
+
+        if len(params) == 0:
+            raise ValueError("optimizer got an empty parameter list")
+
+        # NOTE: we only have one param_group and don't allow the user to add
+        # additional param groups as it's not a common use case.
+        self.param_group = {"params": params}
+
+        # TODO: no union or any types in TorchScript, make step a scalar tensor instead.
+        # This is also needed if we want to share_memory on the step across processes.
+        for p in self.param_group["params"]:
+            self.state[p] = {
+                "sum": torch.full_like(p.data, initial_accumulator_value),
+                "step": torch.tensor(0.0),
+            }
+
+    def step(self, gradients: List[Optional[Tensor]]):
+        params = self.param_group['params']
+        params_with_grad = []
+        grads = []
+        state_sums = []
+        state_steps: List[int] = []
+
+        if len(params) != len(gradients):
+            raise ValueError(
+                "the number of gradients passed in does not equal the number of parameters! "
+                + f"Params length: {len(params)}. "
+                + f"Gradients length: {len(gradients)}"
+            )
+
+        for param, gradient in zip(self.param_group['params'], gradients):
+            if gradient is not None:
+                params_with_grad.append(param)
+                grads.append(gradient)
+                state = self.state[param]
+                state_sums.append(state['sum'])
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'].item())
+
+        with torch.no_grad():
+            F.adagrad(params_with_grad,
+                      grads,
+                      state_sums,
+                      state_steps,
+                      self.defaults['lr'],
+                      self.defaults['weight_decay'],
+                      self.defaults['lr_decay'],
+                      self.defaults['eps'])
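
A minimal usage sketch for the class above (illustrative, not part of the commit): the distributed optimizer is expected to construct it once per worker with the local parameters and then call step() with a gradient list aligned to those parameters, where entries may be None. The import path follows the file added in this commit and may differ in later releases.

import torch
from torch.distributed.optim.adagrad import FunctionalAdagrad  # path introduced by this commit

w1 = torch.randn(3, 3)
w2 = torch.randn(3, 3)
opt = FunctionalAdagrad([w1, w2], lr=0.05)

# Gradients come from elsewhere (e.g. a distributed autograd context) rather
# than from .grad; the list must line up with the parameter list and may
# contain None for parameters that received no gradient.
g1 = torch.ones_like(w1)
opt.step([g1, None])  # updates w1 and its accumulator state; w2 is left untouched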

torch/distributed/optim/optimizer.py

Lines changed: 28 additions & 7 deletions
@@ -1,11 +1,16 @@
+from typing import List, Optional
+
 import torch.distributed.rpc as rpc
+import torch.optim as optim
+from .adagrad import FunctionalAdagrad
 import torch.distributed.autograd as dist_autograd

+
 from collections import defaultdict
 from threading import Lock


-class _LocalOptimizer:
+class _LocalOptimizer(object):
     # Ideally we would only need to share a lock for instances of
     # _LocalOptimizer that deal with the same parameters. We are
     # making a simplifying assumption here that if there is more
@@ -14,20 +19,36 @@ class _LocalOptimizer:
     # trainer will create its own instance of _LocalOptimizer but
     # they will all optimize the same parameters on each worker)
     global_lock = Lock()
+    functional_optim_map = {
+        optim.Adagrad: FunctionalAdagrad,
+        # torch.optim.Adam: torch.distributed.optim.Adam
+    }

     def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
-        self.optim = optim_cls(
-            [rref.local_value() for rref in local_params_rref],
+        optim_ctor = _LocalOptimizer.functional_optim_map.get(optim_cls, optim_cls)
+        self.is_functional_optim = (optim_ctor != optim_cls)
+        self._local_params = [rref.local_value() for rref in local_params_rref]
+        self.optim = optim_ctor(
+            self._local_params,
             *args,
             **kwargs)

     def step(self, autograd_ctx_id):
         all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

-        with _LocalOptimizer.global_lock:
-            for param, grad in all_local_grads.items():
-                param.grad = grad
-            self.optim.step()
+        if self.is_functional_optim:
+            # apply functional optimizer step with a list of gradients
+            grads: List[Optional[torch.Tensor]] = [
+                all_local_grads[p] if p in all_local_grads else None
+                for p in self._local_params
+            ]
+
+            self.optim.step(grads)
+        else:
+            with _LocalOptimizer.global_lock:
+                for param, grad in all_local_grads.items():
+                    param.grad = grad
+                self.optim.step()


 def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
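
The heart of the change above is (a) mapping a torch.optim class to its functional counterpart via functional_optim_map and (b) realigning the param-to-grad dict returned by dist_autograd.get_gradients into a list ordered like the local parameters. A small self-contained sketch of that realignment (illustrative; the align_grads helper is hypothetical, and a plain dict stands in for the distributed autograd result):

from typing import Dict, List, Optional
import torch

def align_grads(local_params: List[torch.Tensor],
                grads_by_param: Dict[torch.Tensor, torch.Tensor]) -> List[Optional[torch.Tensor]]:
    # Mirrors the list comprehension in _LocalOptimizer.step: one entry per
    # local parameter, None when no gradient was recorded for it.
    return [grads_by_param[p] if p in grads_by_param else None for p in local_params]

p1, p2 = torch.randn(2, 2), torch.randn(2, 2)
grads = {p1: torch.ones_like(p1)}        # shape of what get_gradients would return
print(align_grads([p1, p2], grads))      # [tensor([...]), None]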

torch/optim/functional.py

Lines changed: 12 additions & 7 deletions
@@ -6,6 +6,16 @@

 # TODO: use foreach API in optim.functional to do all the computation

+def _make_sparse(grad, grad_indices, values):
+    size = grad.size()
+    if grad_indices.numel() == 0 or values.numel() == 0:
+        return torch.empty_like(grad)
+    return torch.sparse_coo_tensor(grad_indices, values, size)
+    # constructor = grad.new
+    # if grad_indices.dim() == 0 or values.dim() == 0:
+    #     return constructor().resize_as_(grad)
+    # return constructor(grad_indices, values, size)
+
 def adagrad(params: List[Tensor],
             grads: List[Tensor],
             state_sums: List[Tensor],
@@ -33,15 +43,10 @@ def adagrad(params: List[Tensor],
             grad_values = grad._values()
             size = grad.size()

-            def make_sparse(values):
-                constructor = grad.new
-                if grad_indices.dim() == 0 or values.dim() == 0:
-                    return constructor().resize_as_(grad)
-                return constructor(grad_indices, values, size)
-            state_sum.add_(make_sparse(grad_values.pow(2)))
+            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
             std = state_sum.sparse_mask(grad)
             std_values = std._values().sqrt_().add_(eps)
-            param.add_(make_sparse(grad_values / std_values), alpha=-clr)
+            param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
         else:
             state_sum.addcmul_(grad, grad, value=1)
             std = state_sum.sqrt().add_(eps)
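
For reference, the update these dense-path lines implement is the standard Adagrad rule: the state accumulates the squared gradient and the parameter moves by -clr * grad / (sqrt(state_sum) + eps); the sparse branch above does the same thing restricted to the non-zero entries via _make_sparse. A tiny worked dense example (illustrative; clr stands for the learning rate after lr_decay has been applied):

import torch

param = torch.tensor([1.0, 2.0])
grad = torch.tensor([0.5, -1.0])
state_sum = torch.zeros_like(param)
clr, eps = 0.05, 1e-10

state_sum.addcmul_(grad, grad, value=1)   # state_sum += grad * grad
std = state_sum.sqrt().add_(eps)          # sqrt of the accumulated squared gradients
param.addcdiv_(grad, std, value=-clr)     # param -= clr * grad / std
print(param)                              # tensor([0.9500, 2.0500])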

torch/testing/_internal/distributed/rpc/dist_optimizer_test.py

Lines changed: 67 additions & 0 deletions
@@ -198,3 +198,70 @@ def test_dist_optim(self):
             # ensure local equals remote
             self.assertEqual(new_w1, module1.get_w())
             self.assertEqual(new_w2, module2.get_w())
+
+
+    @dist_init
+    def test_dist_optim_functional(self):
+        # local version
+        module1 = MyModule()
+        module2 = MyModule()
+        params = [module1.get_w(), module2.get_w()]
+        local_optim = optim.Adagrad(params, lr=0.05)
+
+        old_w1 = module1.w.clone().detach()
+        old_w2 = module2.w.clone().detach()
+
+        g_cpu = torch.Generator()
+        g_cpu.manual_seed(0)
+        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+        output1 = module1.forward(t2)
+        output2 = module2.forward(output1)
+        loss = torch.add(output2, t1).sum()
+
+        loss.backward()
+        local_optim.step()
+
+        # distributed version
+        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
+        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
+
+        remote_module1 = rpc.remote(owner1, MyModule)
+        remote_module2 = rpc.remote(owner2, MyModule)
+        remote_param1 = remote_method(MyModule.get_w, remote_module1)
+        remote_param2 = remote_method(MyModule.get_w, remote_module2)
+
+        old_w1_remote = remote_param1.to_here()
+
+        # sanity check: local and remote initial weights should match
+        self.assertEqual(old_w1, remote_param1.to_here())
+        self.assertEqual(old_w2, remote_param2.to_here())
+
+        dist_optim = DistributedOptimizer(
+            optim.Adagrad, [remote_param1, remote_param2], lr=0.05
+        )
+
+        with dist_autograd.context() as context_id:
+            g_cpu.manual_seed(0)
+            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
+            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
+            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
+            loss = torch.add(output2.wait(), t1)
+
+            dist_autograd.backward(context_id, [loss.sum()])
+            dist_optim.step(context_id)
+
+            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
+            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()
+            print("old w1: ")
+            print(old_w1)
+            print("new w1: ")
+            print(new_w1)
+
+            # ensure optimizer changed weights
+            self.assertNotEqual(old_w1, new_w1)
+            self.assertNotEqual(old_w2, new_w2)
+            # ensure local equals remote
+            self.assertEqual(new_w1, module1.get_w())
+            self.assertEqual(new_w2, module2.get_w())
