
torch.save w/ _use_new_zipfile_serialization=True corrupts state_dict #46020

Description

@stas00

🐛 Bug

In one of the transformers tests, torch.save with the new _use_new_zipfile_serialization=True corrupts the data in the state_dict, so that a subsequent torch.load returns bogus data that wasn't there at save time.

To Reproduce

Steps to reproduce the behavior:

It seems to happen only with specific data, so at the moment reproducing it requires a specific sub-test:

git clone https://github.com/huggingface/transformers
cd transformers
pip install -e .[dev]
CUDA_LAUNCH_BLOCKING=1 USE_CUDA=1 pytest tests/test_modeling_bert.py::BertModelTest::test_head_pruning_integration
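
For reference, the entry at the center of this is the BERT position_ids buffer. A standalone round-trip of just that tensor looks like the following minimal sketch (it does not reproduce the corruption on its own; the bug needs the full state_dict from the test above):

import torch

# BertEmbeddings registers a position_ids buffer holding 0, 1, ..., 511
# (shape (1, 512)); this stand-in mimics it.
state_dict = {"embeddings.position_ids": torch.arange(512).unsqueeze(0)}

torch.save(state_dict, "model.bin")  # zipfile format is the default on 1.6+
loaded = torch.load("model.bin")
print(loaded["embeddings.position_ids"])  # expected: 0 ... 511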

Specifically, we have https://github.com/huggingface/transformers/blob/960faaaf28b198c0dd2fcb288fa336a846aed398/src/transformers/modeling_utils.py#L729

        torch.save(state_dict, output_model_file)  # , _use_new_zipfile_serialization=False)

        # debug: print the suspect tensor before and after the round-trip
        k = "embeddings.position_ids"
        if k in state_dict:
            print("SAVE\n", state_dict[k])

        state_dict = torch.load(output_model_file)
        if k in state_dict:
            print("LOAD\n", state_dict[k])

(I added the debug prints and the load call as part of debugging, since the problem wasn't manifesting under the debugger.)
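
For completeness, the commented-out kwarg above is the workaround: since PyTorch 1.5, which only has the legacy format, is unaffected, forcing the legacy format should sidestep the corruption (this is the documented torch.save flag, nothing new):

# Opt back into the legacy (pre-1.6, non-zipfile) serialization format.
torch.save(state_dict, output_model_file, _use_new_zipfile_serialization=False)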

What goes in when saving the key embeddings.position_ids:

0, 1, 2,..., 511

What comes out after load:

512, 0, 1, 2, ..., 510

That 512 shouldn't be there. It leads to a CUDA assert and the whole test suite goes kaboom.
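
For what it's worth, the corruption can be caught programmatically right after the round-trip; a sketch using plain torch.equal, with state_dict and output_model_file as in the snippet above:

import torch

k = "embeddings.position_ids"
before = state_dict[k].clone()  # copy the tensor before torch.save runs

torch.save(state_dict, output_model_file)
after = torch.load(output_model_file)[k]

# On the failing run this assert fires: `after` begins with a stray 512
# and the original values 0..510 are shifted one position to the right.
assert torch.equal(before.cpu(), after.cpu()), (before, after)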

Note 1: One more important detail: if I replace output_model_file with a manually set path directly inside the save/load code above, everything works fine. If the path is set by the test (regardless of whether it comes from tempfile.TemporaryDirectory() or is hardcoded), it fails.

Note 2: If I run the code under a debugger, the problem doesn't manifest itself.

These two peculiarities are very odd. Is there a race condition here, with one or the other change affecting the timing?

Also, if I remove all keys but "embeddings.position_ids", the data doesn't get corrupted. So it has to do with the specific data this state_dict happens to contain.
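
That key-dependence can be probed mechanically. Roughly the kind of elimination loop used to narrow it down (a sketch; mismatched_keys is a hypothetical helper, not an existing API):

import os
import tempfile
import torch

def mismatched_keys(sd, path):
    # Round-trip sd through torch.save/torch.load and report the keys
    # whose tensors no longer equal the originals.
    torch.save(sd, path)
    loaded = torch.load(path)
    return [k for k in sd if not torch.equal(loaded[k].cpu(), sd[k].cpu())]

# Drop one key at a time: if removing a key makes the mismatch vanish,
# that key's data is part of whatever layout triggers the corruption.
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "subset.bin")
    for drop in list(state_dict):
        subset = {k: v for k, v in state_dict.items() if k != drop}
        print(drop, "->", mismatched_keys(subset, path))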

Many thanks to @ptrblck for doing most of the heavy lifting, debugging this on Slack.

t/a @jamesr66a

Environment

The problem happens with any PyTorch 1.6+ (i.e., since _use_new_zipfile_serialization=True became the default). Everything is fine with PyTorch 1.5. A variety of NVIDIA drivers were tried, with no difference in behavior.

PyTorch version: 1.8.0.dev20201007
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: GeForce GTX TITAN X
GPU 1: GeForce GTX TITAN X

Nvidia driver version: 455.23.05
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] pytorch-lightning==0.9.1rc3
[pip3] torch==1.8.0.dev20201007
[pip3] torchtext==0.6.0
[pip3] torchvision==0.8.0.dev20201007
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.0.221             h6bb024c_0
[conda] mkl                       2020.2                      256
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.2.0            py38h23d657b_0
[conda] mkl_random                1.1.1            py38h0573a6f_0
[conda] numpy                     1.18.5                   pypi_0    pypi
[conda] numpy-base                1.19.1           py38hfa32c7d_0
[conda] pytorch                   1.8.0.dev20201007 py3.8_cuda11.0.221_cudnn8.0.3_0    pytorch-nightly
[conda] pytorch-lightning         0.9.1rc3                  dev_0    <develop>
[conda] torchtext                 0.6.0                    pypi_0    pypi
[conda] torchvision               0.8.0.dev20201007      py38_cu110    pytorch-nightly

cc @ezyang @gchanan @zou3519 @mruberry

Labels: high priority, module: serialization, triaged
