17 changes: 10 additions & 7 deletions torch/optim/adam.py
@@ -316,13 +316,16 @@ def adam(params: List[Tensor],
     # and when differentiable=False.
     # We still respect when the user inputs False for fused.
     if fused is None:
-        if not differentiable and all(
-            p.is_cuda and torch.is_floating_point(p)
-            for p in params + grads + exp_avgs + exp_avg_sqs + max_exp_avg_sqs + state_steps
-        ):
-            fused = True
-        else:
-            fused = False
+        all_tensors = []
+        all_tensors.extend(params)
+        all_tensors.extend(grads)
+        all_tensors.extend(exp_avgs)
+        all_tensors.extend(exp_avg_sqs)
+        all_tensors.extend(max_exp_avg_sqs)
+        all_tensors.extend(state_steps)
Comment on lines +319 to +325

Collaborator
Couldn't you refactor this into a single itertools.chain call? That way you wouldn't need any temporary list at all.
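For reference, a minimal sketch of what that chain-based check might look like, reusing the local names from the diff above (illustrative only, not the code the PR ended up with):

import itertools

if fused is None:
    # Lazily chain the tensor lists instead of materializing a temporary list.
    fused = not differentiable and all(
        p.is_cuda and torch.is_floating_point(p)
        for p in itertools.chain(params, grads, exp_avgs, exp_avg_sqs,
                                 max_exp_avg_sqs, state_steps)
    )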

Contributor Author
I started with that, but it does not play well with JIT bindings: #91896 (comment)

Collaborator
I see. That's annoying, because the JIT bindings shouldn't be necessary here since this path isn't possible under JIT.

Contributor
Those seven lines can be collapsed into a single line, no?
all_tensors = [*params, *grads, *exp_avgs, *exp_avg_sqs, *max_exp_avg_sqs, *state_steps]

fused = not torch.jit.is_scripting() and not differentiable and all(
    p.is_cuda and torch.is_floating_point(p) for p in all_tensors
)

     if not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
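For context on the check above: the functional adam API stores each step counter as a singleton tensor rather than a plain Python int, which is what the isinstance check enforces. A minimal illustrative sketch of constructing such a list (the zero-dimensional float tensors here are an assumption about the expected format):

import torch

params = [torch.randn(2, 2)]
# One singleton step-counter tensor per parameter, as the check above expects.
state_steps = [torch.tensor(0.0) for _ in params]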