[optim] prevent problematic tensor aliasing in lr_scheduler #163098
filipviz wants to merge 4 commits into pytorch:main
Conversation
Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation. Supersedes pytorch#162360 Fixes pytorch#162359 Fixes pytorch#163093
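For context, here is a minimal, hedged illustration (not from this PR) of the hazard the description refers to: with a Tensor lr, a param_group entry and user code can hold the same tensor object, so an in-place update propagates to every alias, while a plain re-assignment silently breaks the link and changes the value's type.

```python
import torch

# Hedged illustration (not part of this PR): the aliasing hazard with Tensor lrs.
# A param_group-style dict and user code may hold the *same* tensor object.
lr = torch.tensor(0.1)
param_group = {"lr": lr}        # stand-in for optimizer.param_groups[0]

param_group["lr"].fill_(0.05)   # in-place update: visible through every alias
print(lr)                       # tensor(0.0500)

param_group["lr"] = 0.01        # plain assignment: the key now holds a float
print(lr)                       # still tensor(0.0500) -- the user's tensor went
                                # stale, and a float/Tensor type change can
                                # trigger recompilation under torch.compile
```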
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163098
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 9f60631 with merge base e900a27.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge
thanks for your detailed work!
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add a helper to address several lr_scheduler aliasing issues
Fix inaccurate type annotations throughout lr_scheduler.py
Fixes pytorch#163103
Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command. (Details for Dev Infra team: raised by workflow job.)
Sorry! I accidentally pushed 39c342e (which was supposed to be a follow-up PR) to this branch. Let me know how you'd like me to handle this @janeyx99.
That commit fixes a number of type annotations and an issue where
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
can alias the optimizer's tensor learning rates. It adds another helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.

    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
The same helper fixes a bug where …
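To make the helper's effect concrete, here is a hedged usage sketch. It assumes the `_param_groups_val_list` definition above is in scope along with the imports its annotations need (`Tensor`, `Any`, `Optimizer`), and it uses a toy stand-in object exposing only `.param_groups` rather than a real optimizer. The `.clone()` is what keeps the snapshot from changing retroactively when the live lr tensor is later updated in place.

```python
import torch
from types import SimpleNamespace

# Toy stand-in: only .param_groups is needed by the helper above.
fake_opt = SimpleNamespace(param_groups=[{"lr": torch.tensor(0.1)}, {"lr": 0.2}])

snapshot = _param_groups_val_list(fake_opt, "lr")   # clones the Tensor entry
fake_opt.param_groups[0]["lr"].fill_(0.01)          # later in-place lr update

print(snapshot[0])   # tensor(0.1000) -- the snapshot kept its own copy
print(snapshot[1])   # 0.2 -- plain floats need no clone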
Also adds a test for unintended tensor aliasing in LR schedulers
Depends on pytorch#163098
Fixes pytorch#163105
Let's undo that change and land the original first! I can review the followup PR separately. It's good to keep changes minimal per PR.
Same page @janeyx99. Reverted that commit but added …
janeyx99 left a comment:
This commit also fixes CosineAnnealingWarmRestarts, yes?
@pytorchbot merge
no need for other changes as long as CI is green!
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge |
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[optim] prevent problematic tensor aliasing in lr_scheduler (#163098)
Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation.
Supersedes pytorch#162360
Fixes pytorch#162359
Fixes pytorch#163093
While putting pytorch#162360 together, I noticed the class of issue I was fixing (i.e. unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR.
There are several bugs resembling the one fixed in pytorch#162360. I added a helper to fix these:
```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.

    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```
And applied it to fix bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also added it to `CyclicLR.__init__`, which was using an equivalent pattern, and `CosineAnnealingWarmRestarts.step`, which *should* have had a similar issue:
```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```
But did not, because `get_lr()` actually returns tensors when using a tensor lr (despite its `list[float]` return type annotation). Relying on this propagation seems fragile, so I conservatively added the method here as well. I'll be fixing the type annotations and several related issues in followup PRs built off of this one.
Pull Request resolved: pytorch#163098
Approved by: https://github.com/janeyx99
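As a rough, self-contained sketch (stand-in dicts only, not the real lr_scheduler.py call sites) of the pattern `_update_param_group_val` enforces in the scheduler update loops: filling the existing tensor in place keeps every alias valid, whereas re-assigning the key would swap in a new object.

```python
import torch
from torch import Tensor

# Stand-ins only; not the actual torch.optim.lr_scheduler code.
lr_tensor = torch.tensor(0.1)
param_groups = [{"lr": lr_tensor}]
new_lrs = [0.05]

for group, lr in zip(param_groups, new_lrs):
    if isinstance(group["lr"], Tensor):
        group["lr"].fill_(lr)   # in-place: same tensor object, new value
    else:
        group["lr"] = lr        # plain floats can simply be re-bound

print(param_groups[0]["lr"] is lr_tensor)  # True -- aliasing preserved
print(lr_tensor)                           # tensor(0.0500)
```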