
Conversation

@liangel-02 (Contributor)

Context

This PR is a follow-up to #40735 and #41138. Previously, we enabled safetensors in torchao for a single shard file. This PR fixes some errors introduced in #41138 and handles checkpoints that are sharded across more than one file, including the edge case where a single quantized tensor (i.e., a Float8Tensor) is split across two different files (i.e., qdata in one and scale in another).
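A minimal sketch of that edge case, assuming flattened component keys of the form `<module>._weight_<attr>` (the naming is inferred from the `"_weight_"` check in the code under review; the file names, keys, and shapes below are illustrative only):

```python
import torch

# Hypothetical shard layout: the two components of one Float8Tensor end up in
# different safetensors files.
shard_1 = {  # e.g. model-00001-of-00002.safetensors
    "model.layers.0.mlp.down_proj._weight_qdata": torch.zeros(16, 16),  # quantized payload
}
shard_2 = {  # e.g. model-00002-of-00002.safetensors
    "model.layers.0.mlp.down_proj._weight_scale": torch.ones(16, 1),    # matching scale
}

# Neither file alone is enough to rebuild the Float8Tensor behind
# "model.layers.0.mlp.down_proj.weight", so reconstruction has to wait until
# every shard has been read.
```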

Summary

If we are loading a component of a tensor subclass in create_quantized_param(), called by _load_state_dict_into_meta_model(), we add it to the model as a new parameter. Then, after all parameters are loaded, we unflatten the state_dict and reassign the model parameters.
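A toy, self-contained sketch of that two-phase idea; ToyModule and rebuild_weight are made up for illustration, and a plain dequantized tensor stands in for the real Float8Tensor that the actual code rebuilds via unflatten_tensor_state_dict:

```python
import torch
import torch.nn as nn


class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Phase 1: flattened components (possibly arriving from different shard
        # files) are parked on the module as ordinary parameters under their
        # flattened names.
        self.register_parameter("_weight_qdata", nn.Parameter(torch.zeros(4, 4), requires_grad=False))
        self.register_parameter("_weight_scale", nn.Parameter(torch.ones(4, 1), requires_grad=False))


def rebuild_weight(module: nn.Module) -> None:
    # Phase 2: once everything is loaded, drop the flattened components and
    # register the recombined weight. Here the "subclass" is faked by a plain
    # product; the real code reassigns an actual tensor subclass.
    qdata, scale = module._weight_qdata, module._weight_scale
    delattr(module, "_weight_qdata")
    delattr(module, "_weight_scale")
    module.register_parameter("weight", nn.Parameter(qdata * scale, requires_grad=False))


m = ToyModule()
rebuild_weight(m)
print([name for name, _ in m.named_parameters()])  # ['weight']
```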

Testing

Modified unit tests to cover all tensor subclasses:
python tests/quantization/torchao_integration/test_torchao.py -k TorchAoSafeSerializationTest

@liangel-02 marked this pull request as draft on November 3, 2025 17:43
@liangel-02 force-pushed the torchao-safetensors-sharding branch from 8b6b802 to eeb8451 on November 3, 2025 17:54
@liangel-02 marked this pull request as ready for review on November 3, 2025 18:32
@github-actions bot requested review from MekkCyber and SunMarc on November 3, 2025 18:33
@liangel-02 force-pushed the torchao-safetensors-sharding branch 2 times, most recently from a431b9a to 5a62843, on November 3, 2025 21:12
@jerryzh168 (Contributor) left a comment


thanks, looks good mostly, had one more inline comment

@liangel-02 force-pushed the torchao-safetensors-sharding branch from 5a62843 to 1a020ed on November 3, 2025 21:19
@SunMarc (Member) left a comment


Thanks for your work! Left a couple of comments. Btw, we will soon refactor how quantization is applied as we move to dynamic weight loading like vllm. This should help with getting support for features like TP.

Comment on lines 245 to 268
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata):
    updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), metadata)

    weights_to_register = set(updated_state_dict.keys())

    for name, param in list(model.named_parameters()):
        module_fqn, weight_name = name.rsplit(".", 1)
        module = model.get_submodule(module_fqn)
        weight = getattr(module, weight_name)

        device = weight.device
        requires_grad = weight.requires_grad

        if "_weight_" in weight_name:
            delattr(module, weight_name)

        if name in weights_to_register:
            new_param_value = updated_state_dict[name]
            new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad)
            module.register_parameter(weight_name, new_param)

            weights_to_register.remove(name)

    model.load_state_dict(updated_state_dict, strict=False)
Member


so instead of performing unflatten_tensor_state_dict in create_quantized_param, we do it here at the very end and we just store the flattened weights in the module?

@liangel-02 (Contributor, Author)


yeah, we don't want to do it in create_quantized_param since we'd only have access to at most one shard file there, and we want to handle the case where tensor subclass attributes are split across multiple files

we call unflatten_tensor_state_dict at the very end to get the recovered state dict, and then iterate through the model and replace the weights that represent the tensor attributes with the entire tensor subclass.
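For illustration, a toy version of that final regrouping step (the key naming is assumed; the real reconstruction is done by torchao's unflatten_tensor_state_dict using the full state dict plus the safetensors metadata, and it rebuilds actual tensor subclasses such as Float8Tensor):

```python
from collections import defaultdict

import torch

flat_state_dict = {
    "proj._weight_qdata": torch.zeros(4, 4),  # component loaded from shard 1
    "proj._weight_scale": torch.ones(4, 1),   # component loaded from shard 2
    "norm.weight": torch.ones(4),             # ordinary, non-quantized tensor
}

grouped = defaultdict(dict)
for key, value in flat_state_dict.items():
    module_fqn, leaf = key.rsplit(".", 1)
    if leaf.startswith("_weight_"):
        # collect the attributes of one tensor subclass under its original parameter name
        grouped[f"{module_fqn}.weight"][leaf.removeprefix("_weight_")] = value
    else:
        grouped[key] = value  # passes through unchanged

print({k: type(v).__name__ for k, v in grouped.items()})
# {'proj.weight': 'dict', 'norm.weight': 'Tensor'}
```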

@liangel-02 force-pushed the torchao-safetensors-sharding branch from 1a020ed to 7cdb0c6 on November 4, 2025 15:58
@liangel-02 requested a review from SunMarc on November 4, 2025 15:59
@liangel-02 force-pushed the torchao-safetensors-sharding branch from 7cdb0c6 to e4773a5 on November 4, 2025 18:15

github-actions bot commented Nov 4, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: torchao_integration

@liangel-02 force-pushed the torchao-safetensors-sharding branch 2 times, most recently from 0e3c0d5 to a2df2ec, on November 4, 2025 18:22
@liangel-02 force-pushed the torchao-safetensors-sharding branch from a2df2ec to 3b60f7c on November 4, 2025 18:29