13 changes: 13 additions & 0 deletions test/test_optim.py
@@ -277,12 +277,25 @@ def test_adam(self):
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adam([weight, bias], lr=1e-3,
                                            amsgrad=True)
        )
        self._test_basic_cases(
            lambda weight, bias: optim.Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3, amsgrad=True)
        )

    def test_sparse_adam(self):
        self._test_rosenbrock_sparse(
            lambda params: optim.SparseAdam(params, lr=4e-2),
            True
        )
        self._test_rosenbrock_sparse(
            lambda params: optim.SparseAdam(params, lr=4e-2, amsgrad=True),
            True
        )

    def test_adadelta(self):
        self._test_rosenbrock(
1 change: 1 addition & 0 deletions torch/optim/__init__.py
@@ -21,6 +21,7 @@
del adadelta
del adagrad
del adam
del sparse_adam
del adamax
del asgd
del sgd
23 changes: 19 additions & 4 deletions torch/optim/adam.py
@@ -17,15 +17,19 @@ class Adam(Optimizer):
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
                 weight_decay=0, amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def step(self, closure=None):
@@ -46,6 +50,7 @@ def step(self, closure=None):
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

@@ -56,8 +61,13 @@
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
@@ -68,8 +78,13 @@ def step(self, closure=None):
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
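For reference, a minimal usage sketch of the new flag (not part of this diff; the toy model and data are made up for illustration, and it assumes a PyTorch build that includes this change, written against current tensor/autograd semantics):

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(10, 1)  # toy model with dense gradients
    # amsgrad=True switches the denominator to the running maximum of exp_avg_sq,
    # as in the step() hunk above; the default (False) keeps plain Adam behaviour.
    optimizer = optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)

    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()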
25 changes: 22 additions & 3 deletions torch/optim/sparse_adam.py
@@ -17,13 +17,18 @@ class SparseAdam(Optimizer):
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        defaults = dict(lr=lr, betas=betas, eps=eps)
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, amsgrad=amsgrad)
        super(SparseAdam, self).__init__(params, defaults)

    def step(self, closure=None):
@@ -44,6 +49,7 @@ def step(self, closure=None):
                grad = p.grad.data
                if not grad.is_sparse:
                    raise RuntimeError('SparseAdam does not support dense gradients, please consider Adam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

@@ -54,6 +60,9 @@
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                state['step'] += 1

@@ -69,6 +78,9 @@ def make_sparse(values):
                    return constructor(grad_indices, values, size)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    old_max_exp_avg_sq_values = max_exp_avg_sq._sparse_mask(grad)._values()
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
@@ -83,7 +95,14 @@ def make_sparse(values):

                # Dense addition again is intended, avoiding another _sparse_mask
                numer = exp_avg_update_values.add_(old_exp_avg_values)
                denom = exp_avg_sq_update_values.add_(old_exp_avg_sq_values).sqrt_().add_(group['eps'])
                exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
                if amsgrad:
                    torch.max(old_max_exp_avg_sq_values, exp_avg_sq_update_values, out=old_max_exp_avg_sq_values)
                    denom = old_max_exp_avg_sq_values.sqrt_().add_(group['eps'])
                    max_exp_avg_sq = make_sparse(old_max_exp_avg_sq_values)
                    del old_max_exp_avg_sq_values
                else:
                    denom = exp_avg_sq_update_values.sqrt_().add_(group['eps'])
                del exp_avg_update_values, exp_avg_sq_update_values

                bias_correction1 = 1 - beta1 ** state['step']
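Similarly, a minimal sketch for the sparse path, mirroring the new test_sparse_adam case (the embedding sizes are hypothetical, and again this assumes a build that includes this change):

    import torch
    import torch.nn as nn
    import torch.optim as optim

    emb = nn.Embedding(1000, 16, sparse=True)  # sparse=True makes the embedding produce sparse gradients
    optimizer = optim.SparseAdam(emb.parameters(), lr=4e-2, amsgrad=True)

    ids = torch.randint(0, 1000, (32,))
    loss = emb(ids).pow(2).sum()
    optimizer.zero_grad()
    loss.backward()   # emb.weight.grad is a sparse tensor, which SparseAdam requires
    optimizer.step()  # moments are updated only at the touched indices; with amsgrad=True the masked max feeds the denominator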