
Raise error for 1D (size > 1) -> 0D parameter loads#166335

Closed
dsashidh wants to merge 2 commits into pytorch:main from dsashidh:fix_load_state_dict_shape_mismatch

Conversation

@dsashidh
Contributor

@dsashidh dsashidh commented Oct 27, 2025

Fixes #165873

Title

Fix load_state_dict: raise error for 1D (size > 1) -> 0D parameter loads

Summary

This PR fixes a bug where loading a 1D tensor (size > 1) into a scalar (0D) parameter would silently take the first element instead of raising an error. The fix preserves backward compatibility for 1D tensors of size 1 while catching genuine shape mismatches.

Motivation

Previously, loading a 1D tensor like torch.randn(32000) into a 0D scalar parameter would silently slice the first element, leading to silent data loss and potential bugs. This change ensures users get a clear error when there's a genuine shape mismatch.

Behavior change

Before:
1D tensor (any length) -> 0D scalar -> silently coerced using input_param[0]

After:

  • 1D tensor (size == 1) -> 0D scalar -> allowed (backward compatibility)
  • 1D tensor (size > 1) -> 0D scalar -> raises RuntimeError with size mismatch message

In torch/nn/modules/module.py, in _load_from_state_dict, added an input_param.shape[0] == 1 check to the backward-compatibility condition so that only single-element 1D tensors are allowed.
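As a rough illustration, the narrowed condition can be sketched in plain Python (names mirror the real code in torch/nn/modules/module.py, but this helper is hypothetical and simplified, not the actual implementation):

```python
def should_squeeze_for_backward_compat(param_shape, input_param_shape):
    """Return True only for the legacy case kept for backward
    compatibility: loading a single-element 1D tensor into a 0D
    (scalar) parameter."""
    return (
        len(param_shape) == 0            # destination is a scalar parameter
        and len(input_param_shape) == 1  # source is a 1D tensor
        and input_param_shape[0] == 1    # new check: only a single element
    )

# [1] -> scalar: still allowed for backward compatibility
print(should_squeeze_for_backward_compat((), (1,)))      # True
# [32000] -> scalar: no longer squeezed; falls through to the
# generic shape-mismatch check, which raises a RuntimeError
print(should_squeeze_for_backward_compat((), (32000,)))  # False
```

With the extra `input_param_shape[0] == 1` clause, only the [1] -> scalar case takes the legacy squeeze path; everything else is handled by the ordinary shape comparison.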

Tests

Added test_scalar_param_1d_tensor_raises to verify that loading 1D tensors of size > 1 raises an error, while size 1 loads successfully.

@pytorch-bot

pytorch-bot bot commented Oct 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166335

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8d2fa97 with merge base ed4aa44:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@dsashidh
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Oct 27, 2025
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 28, 2025
Collaborator


@dsashidh It looks like this case was here for backwards compatibility, but from a long time ago.

If there is a decision to no longer support this backward compatibility, the cleaner fix would be to remove this if statement and let it fall through to if not is_param_lazy and input_param.shape != param.shape, where this case will be caught; that branch currently has the same body.
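The fall-through path described above can be sketched in plain Python (the function name and error message are illustrative, not the real implementation; the condition mirrors the one quoted in the comment):

```python
def check_shapes(param_shape, input_param_shape, is_param_lazy=False):
    """Hypothetical sketch of the generic shape check that a removed
    (or narrowed) special case would fall through to."""
    if not is_param_lazy and input_param_shape != param_shape:
        raise RuntimeError(
            f"size mismatch: copying a param with shape "
            f"{input_param_shape} into a param with shape {param_shape}"
        )

check_shapes((), ())  # scalar -> scalar: OK, no error
try:
    check_shapes((), (32000,))  # 1D (size > 1) -> 0D: caught here
except RuntimeError as e:
    print(e)
```

Because the generic check compares the full shapes, a 1D-to-0D mismatch is rejected there without any dedicated special case.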

Contributor Author

@dsashidh dsashidh Nov 3, 2025


Thanks for the feedback, I commented below.

Contributor

@mikaylagawarecki mikaylagawarecki left a comment


I think my comment below might be an acceptable fix

@@ -2436,7 +2436,11 @@ def _load_from_state_dict(
            and len(param.shape) == 0
            and len(input_param.shape) == 1
Contributor

Perhaps add and input_param.shape[0] == 1, so the unexpected case will not fall into this if statement

Contributor Author

Thanks for the feedback, I commented below.

@dsashidh
Contributor Author

dsashidh commented Nov 3, 2025

Thank you for the feedback!
I see two possible approaches:

  • Option 1: Remove the entire backward-compatibility block (@morrison-turnansky's approach)
  • Option 2: Add and input_param.shape[0] == 1 to preserve backward compatibility for [1] -> scalar (@mikaylagawarecki's approach)

I wrote my test with the expectation that [1] -> scalar should raise an error (strict shape matching), which passes with approach 1 but not approach 2. Since this backward compatibility is from PyTorch 0.3 (2017), I'm inclined toward approach 1 for stricter shape matching.
@mikaylagawarecki : Is there a strong reason to preserve [1] -> scalar compatibility?

@albanD albanD removed their request for review November 3, 2025 22:05
@mikaylagawarecki
Contributor

mikaylagawarecki commented Nov 3, 2025

I think it should be fine to load a 1d tensor of size 1 into a scalar tensor, iiuc that's what adding the check I suggested would do, though do correct me if I'm wrong

@morrison-turnansky
Collaborator

@dsashidh Go with what @mikaylagawarecki is suggesting.

@dsashidh dsashidh changed the title Raise error on 1D - > 0D parameter loads in load_state_dict Raise error for 1D (size > 1) -> 0D parameter loads Nov 4, 2025
@dsashidh
Contributor Author

dsashidh commented Nov 4, 2025

I think it should be fine to load a 1d tensor of size 1 into a scalar tensor, iiuc that's what adding the check I suggested would do, though do correct me if I'm wrong

Hi @mikaylagawarecki, thanks for clarifying! I've implemented your suggested check (and input_param.shape[0] == 1) and updated my test accordingly.
The fix now:

  • Allows [1] -> scalar (backward compatibility)
  • Raises an error for [2+] -> scalar

Tests pass with this approach.

@dsashidh dsashidh force-pushed the fix_load_state_dict_shape_mismatch branch from 6b097c0 to 8d2fa97 on November 6, 2025 15:09
@dsashidh
Contributor Author

dsashidh commented Nov 6, 2025

Hi @mikaylagawarecki, I was seeing a MYPY lintrunner failure unrelated to my changes. I’ve rebased my branch on upstream/viable/strict, which should align it with a clean CI baseline.

@mikaylagawarecki
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 7, 2025
@mikaylagawarecki mikaylagawarecki added release notes: nn release notes category topic: bug fixes topic category and removed topic: not user facing topic category labels Nov 7, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

hvarfner pushed a commit to hvarfner/botorch that referenced this pull request Nov 12, 2025
Summary: Fix in one OSS notebook where the state_dict was naively expanded. This [pytorch PR](pytorch/pytorch#166335) caused an error that was previously silently ignored.

Differential Revision: D86884189
meta-codesync bot pushed a commit to meta-pytorch/botorch that referenced this pull request Nov 12, 2025
Summary:
Pull Request resolved: #3079

Fix in one OSS notebook where the state_dict was naively expanded. This [pytorch PR](pytorch/pytorch#166335) caused an error that was previously silently ignored.

Reviewed By: sdaulton

Differential Revision: D86884189

fbshipit-source-id: 2c5ded01b17800a64da53b0722cfcc1ccac5e6eb
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
Pull Request resolved: pytorch#166335
Approved by: https://github.com/mikaylagawarecki

Labels

ciflow/trunk Trigger trunk jobs on your pull request · Merged · open source · release notes: nn release notes category · topic: bug fixes topic category · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

No error raised despite shape mismatch in load_state_dict

6 participants