Skip to content

Commit 2bcf863

Browse files
janeyx99 authored and pytorchmergebot committed
[optim] include nn.Parameter as foreach supported (#95811)
This PR is a result of a realization that models are NOT subscribed to the foreach defaulting as has been claimed on our documentation for months now. BIG OOPS. Pull Request resolved: #95811 Approved by: https://github.com/albanD
1 parent 45fd1b3 commit 2bcf863

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torch/optim/optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
1616
_global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
1717
_global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
18+
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
1819

1920
class _RequiredParameter:
2021
"""Singleton class representing a required parameter for an Optimizer."""
@@ -69,10 +70,10 @@ def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
6970
for tensorlist in tensorlists:
7071
all_tensors.extend(tensorlist)
7172
fused = use_fused and all(
72-
p is None or (type(p) == torch.Tensor and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
73+
p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
7374
)
7475
foreach = not fused and all(
75-
p is None or (type(p) == torch.Tensor and p.is_cuda) for p in all_tensors
76+
p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in all_tensors
7677
)
7778
return fused, foreach
7879

0 commit comments

Comments (0)