5 changes: 3 additions & 2 deletions torch/optim/optimizer.py
@@ -15,6 +15,7 @@
__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
_global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
+_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]

Collaborator
nit: make it a tuple instead so it's not accidentally mutated?

Contributor Author
This is actually intentional! This PR came from the realization that the defaulting was far too conservative and kept models off the faster foreach path. We want users to be able to add their own subclasses here if they know that foreach will not break on those tensor subclasses.

Collaborator
@janeyx99 What is the recommended way to add to this list? Should we directly append like the following in trainer code:

torch.optim.optimizer._foreach_supported_types.append(...)

Contributor
@janeyx99 I have the same question @awgu raised. Do you know how users are adding their own subclasses to the list?
Due to a circular dependency, instead of adding DTensor here directly, I am trying out what Andrew suggested first. Appending a tensor subclass to the list seems to have no effect at a later stage, so I am wondering how users are doing this, if you are aware.

Contributor Author
You may be the first user to try this. What Andrew suggests seems okay for now, barring the fact that it's a private API. Once we want to productionize this, we should work together to expose it as a public API.
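
For reference, a minimal sketch of the registration pattern discussed above (MyTensorSubclass is a hypothetical placeholder, and _foreach_supported_types remains a private API that may change):

import torch
from torch.optim import optimizer as optim_module

# Hypothetical subclass standing in for DTensor or any user-defined tensor subclass.
class MyTensorSubclass(torch.Tensor):
    pass

# The list is intentionally mutable so trainer code can extend it;
# since it is still private, treat this as subject to change.
optim_module._foreach_supported_types.append(MyTensorSubclass)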

Collaborator
It sounds like appending is not working?

Contributor
> It sounds like appending is not working?

My bad, that was false information. I just double-checked and append works fine in the unit test locally!
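
For reference, a rough version of that local check (MyTensorSubclass is illustrative, not a type from this PR):

import torch
from torch.optim import optimizer as optim_module

class MyTensorSubclass(torch.Tensor):
    pass

optim_module._foreach_supported_types.append(MyTensorSubclass)
t = torch.randn(2).as_subclass(MyTensorSubclass)
# append mutates the same list object the optimizer module consults,
# so the membership test used by _default_to_fused_or_foreach now passes.
assert type(t) in optim_module._foreach_supported_types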


class _RequiredParameter:
"""Singleton class representing a required parameter for an Optimizer."""
@@ -69,10 +70,10 @@ def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
    for tensorlist in tensorlists:
        all_tensors.extend(tensorlist)
    fused = use_fused and all(
-        p is None or (type(p) == torch.Tensor and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
+        p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
    )
    foreach = not fused and all(
-        p is None or (type(p) == torch.Tensor and p.is_cuda) for p in all_tensors
+        p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in all_tensors
    )
    return fused, foreach
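
As a concrete illustration of the change above, here is a small sketch of the eligibility test the new default performs (it assumes a CUDA device is available and is not part of the diff itself):

import torch
from torch.optim.optimizer import _foreach_supported_types

# Plain nn.Parameter is now in the supported list, so an ordinary CUDA float
# parameter satisfies both conditions that gate the foreach default.
p = torch.nn.Parameter(torch.randn(4, device="cuda"))
print(type(p) in _foreach_supported_types)           # True
print(p.is_cuda and torch.is_floating_point(p))      # True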
