[optim] prevent problematic tensor aliasing in lr_scheduler#163098

Closed
filipviz wants to merge 4 commits into pytorch:main from filipviz:lr-scheduler-patches

Conversation

@filipviz
Contributor

Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation.

Supersedes #162360
Fixes #162359
Fixes #163093

While putting #162360 together, I noticed the class of issue I was fixing (i.e. unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR.

There are several bugs resembling the one fixed in #162360. I added a helper to fix these:

```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.
    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```

And applied it to fix bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also added it to `CyclicLR.__init__`, which was using an equivalent pattern, and `CosineAnnealingWarmRestarts.step`, which *should* have had a similar issue:

```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```

But did not, because `get_lr()` actually returns tensors when using a tensor lr (despite its `list[float]` return type annotation). Relying on this propagation seems fragile, so I conservatively added the helper here as well. I'll be fixing the type annotations and several related issues in follow-up PRs built off of this one.
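
For readers following along, here is a minimal sketch of the aliasing failure mode the helper guards against. It is not code from the PR: the variable names are illustrative, and it assumes a recent PyTorch build that accepts a tensor lr.

```python
import torch
from torch.optim import SGD

param = torch.nn.Parameter(torch.zeros(2))
opt = SGD([param], lr=torch.tensor(0.5))
group = opt.param_groups[0]

# Buggy pattern: rebinding group["lr"] to a scheduler-owned tensor makes both
# names refer to the same object, so a later in-place update to either one
# silently changes the other.
scheduler_lr = torch.tensor(0.05)
group["lr"] = scheduler_lr
scheduler_lr.fill_(123.0)
print(group["lr"])  # tensor(123.) -- the param group's lr was corrupted

# In-place pattern (what the helper does): write through the tensor the param
# group already holds instead of rebinding it.
group["lr"] = torch.tensor(0.5)  # reset for the demo
if isinstance(group["lr"], torch.Tensor):
    group["lr"].fill_(0.05)
else:
    group["lr"] = 0.05
print(group["lr"])  # tensor(0.0500), still the tensor object the group held
```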

@pytorch-bot

pytorch-bot bot commented Sep 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163098

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9f60631 with merge base e900a27:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99
Contributor

@pytorchbot merge

thanks for your detailed work!

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 16, 2025
@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 16, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Add a helper to address several lr_scheduler aliasing issues
Fix inaccurate type annotations throughout lr_scheduler.py

Fixes pytorch#163103
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Sep 16, 2025
@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team: raised by workflow job.

@filipviz
Contributor Author

filipviz commented Sep 16, 2025

Sorry! I accidentally pushed 39c342e (which was supposed to be a follow-up PR) to this branch. Let me know how you'd like me to handle this, @janeyx99. That commit fixes a number of type annotations and an issue where `LRScheduler._update_lr`, `ChainedScheduler.__init__`, `ChainedScheduler.step`, `ReduceLROnPlateau.__init__`, `ReduceLROnPlateau.step`, and `CosineAnnealingWarmRestarts.step` previously aliased `self._last_lr` with `group["lr"]` when `group["lr"]` was a tensor, due to patterns like this:

```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```

It adds another helper to address this:

```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.
    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```

The same helper fixes a bug where `LRScheduler.__init__` makes `group["initial_lr"]` a Tensor if `group["lr"]` is one, meaning that the entries of `self.base_lrs` alias the `group["initial_lr"]` tensors.
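
To illustrate the difference, here is a minimal sketch (again not code from the PR; it assumes a PyTorch build with tensor lr support) contrasting an aliased snapshot with the cloned snapshot the helper produces:

```python
import torch
from torch.optim import SGD

param = torch.nn.Parameter(torch.zeros(2))
opt = SGD([param], lr=torch.tensor(0.1))

# Aliased snapshot: each entry is the *same* tensor object the param group
# holds, so a later in-place lr update mutates this "record" too.
aliased = [group["lr"] for group in opt.param_groups]

# Cloned snapshot (the pattern the helper uses): independent copies.
cloned = [
    group["lr"].clone() if isinstance(group["lr"], torch.Tensor) else group["lr"]
    for group in opt.param_groups
]

opt.param_groups[0]["lr"].fill_(0.01)
print(aliased[0])  # tensor(0.0100) -- changed underneath us
print(cloned[0])   # tensor(0.1000) -- unaffected
```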

filipviz added a commit to filipviz/pytorch that referenced this pull request Sep 16, 2025
Also adds a test for unintended tensor aliasing in LR schedulers

Depends on pytorch#163098
Fixes pytorch#163105
@janeyx99
Contributor

Let's undo that change and land the original first!

I can review the followup PR separately. It's good to keep changes minimal per PR.

@filipviz
Contributor Author

Same page, @janeyx99. I reverted that commit but added `test_reduce_lr_on_plateau_preserves_lr_type` back in (since this PR fixes that bug). Let me know if you'd like any other changes.
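
For context, a hedged sketch of the kind of behavior such a test would guard; this is not the actual test code, and the hyperparameters are purely illustrative:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.nn.Parameter(torch.zeros(2))
opt = SGD([param], lr=torch.tensor(0.1))
sched = ReduceLROnPlateau(opt, factor=0.5, patience=0)

# Feed a non-improving metric so the scheduler reduces the lr at least once.
for _ in range(3):
    sched.step(1.0)

lr = opt.param_groups[0]["lr"]
assert isinstance(lr, torch.Tensor)  # lr should remain a tensor after reduction
assert float(lr) < 0.1               # and it should actually have been reduced
```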

@janeyx99 janeyx99 left a comment
Contributor

This commit also fixes CosineAnnealingWarmRestarts, yes?

@janeyx99
Contributor

@pytorchbot merge

no need for other changes as long as CI is green!

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 16, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@janeyx99
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
open source
release notes: optim
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

4 participants