[torchao] fix safetensors and enable loading from sharded files #41998
Conversation
Thanks, looks good mostly; had one more inline comment.
Thanks for your work! Left a couple of comments. Btw, we will soon refactor how quantization is applied as we move to dynamic weight loading like vLLM's. This should help us support features like TP (tensor parallelism).
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata):
    # Recombine the flattened tensor-subclass components collected from all shards
    updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), metadata)

    weights_to_register = set(updated_state_dict.keys())

    for name, param in list(model.named_parameters()):
        module_fqn, weight_name = name.rsplit(".", 1)
        module = model.get_submodule(module_fqn)
        weight = getattr(module, weight_name)

        device = weight.device
        requires_grad = weight.requires_grad

        # Drop the placeholder parameters that held individual subclass attributes
        if "_weight_" in weight_name:
            delattr(module, weight_name)

        # Re-register the reconstructed tensor subclass as the module's parameter
        if name in weights_to_register:
            new_param_value = updated_state_dict[name]
            new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad)
            module.register_parameter(weight_name, new_param)

            weights_to_register.remove(name)

    model.load_state_dict(updated_state_dict, strict=False)
So instead of performing unflatten_tensor_state_dict in create_quantized_param, we do it here at the very end, and we just store the flattened weights in the module?
Yeah, we don't want to do it in create_quantized_param since, at that point, we'd only have access to at most one shard file, and we want to handle the case where tensor subclass attributes are split across multiple files.
We call unflatten_tensor_state_dict at the very end to get the recovered state dict, then iterate through the model and replace the weights that represent tensor attributes with the full tensor subclass.
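For intuition, here is a minimal, self-contained sketch of that accumulate-then-unflatten idea. The flattened key format and the unflatten helper below are simplified illustrations, not torchao's actual implementation:

import torch

# Hypothetical flattened key format "<fqn>._weight_<attr>"; torchao's real
# format and reconstruction live in unflatten_tensor_state_dict.
def unflatten_sketch(flat_sd):
    """Recombine '<fqn>._weight_<attr>' entries into one weight per module.
    As a stand-in for rebuilding a Float8Tensor, we simply dequantize."""
    out, groups = {}, {}
    for name, tensor in flat_sd.items():
        if "._weight_" in name:
            fqn, attr = name.split("._weight_")
            groups.setdefault(fqn, {})[attr] = tensor
        else:
            out[name] = tensor
    for fqn, attrs in groups.items():
        out[f"{fqn}.weight"] = attrs["qdata"].float() * attrs["scale"]
    return out

# The two components of a single quantized weight may come from different
# shard files, so neither shard alone is enough to rebuild the weight.
shard_a = {"linear._weight_qdata": torch.randint(-8, 8, (4, 4), dtype=torch.int8)}
shard_b = {"linear._weight_scale": torch.full((4, 1), 0.5)}

flat = {**shard_a, **shard_b}        # accumulated across all shards
state_dict = unflatten_sketch(flat)  # only now can the weight be rebuilt
print(state_dict["linear.weight"].shape)  # torch.Size([4, 4])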
[For maintainers] Suggested jobs to run (before merge): run-slow: torchao_integration
Context
This PR is a follow-up to #40735 and #41138. Previously, we enabled safetensors support in torchao for a single shard file. This PR fixes some errors introduced in #41138 and handles checkpoints sharded across more than one file, including the edge case where a single quantized tensor (e.g. Float8Tensor) is sharded across two different files (e.g. qdata in one and scale in another).
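As a concrete (hypothetical) illustration of that edge case, a sharded checkpoint's safetensors index could map the two components of one quantized weight to different shard files; the component key names here are illustrative only, while weight_map is the standard safetensors index structure:

# Hypothetical excerpt of model.safetensors.index.json, shown as a Python dict
index = {
    "weight_map": {
        # one Float8Tensor's components land in different shard files:
        "model.layers.0.mlp.up_proj._weight_qdata": "model-00001-of-00002.safetensors",
        "model.layers.0.mlp.up_proj._weight_scale": "model-00002-of-00002.safetensors",
    }
}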
Summary

If we are loading a component of a tensor subclass in create_quantized_param() (called by _load_state_dict_into_meta_model()), we add it to the model as a new parameter. Then, after all parameters are loaded, we unflatten the state_dict and reassign the model parameters.
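A minimal sketch of that first phase, using hypothetical key names: each flat component is registered as an ordinary parameter so it can be loaded shard by shard, and the unflatten pass shown earlier later swaps these placeholders for the reconstructed tensor subclass:

import torch

# Phase 1 sketch (hypothetical names): register each flat component of a
# quantized weight as a plain Parameter so shards can be loaded one at a time.
module = torch.nn.Linear(4, 4, bias=False)
qdata = torch.randint(-8, 8, (4, 4), dtype=torch.int8)
scale = torch.full((4, 1), 0.5)

# integer payloads cannot require grad, so requires_grad=False
module.register_parameter("_weight_qdata", torch.nn.Parameter(qdata, requires_grad=False))
module.register_parameter("_weight_scale", torch.nn.Parameter(scale, requires_grad=False))

# Both components now appear in named_parameters(), matching the "_weight_"
# names that the final unflatten pass looks for and deletes.
print([n for n, _ in module.named_parameters()])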
Testing

Modified unit tests to cover all tensor subclasses:
python tests/quantization/torchao_integration/test_torchao.py -k TorchAoSafeSerializationTest