
OOM when unwrap_model_for_generation #2250

@hlnchen

Description


System Info

torch==2.4.0
transformers==4.43.4
trl==0.9.6
tokenizers==0.19.1
accelerate==0.32.0
peft==0.12.0
datasets==2.20.0
deepspeed==0.15.0
bitsandbytes==0.43.3
sentencepiece==0.2.0
flash-attn==2.6.3

gcc version 11.4.0 (Ubuntu 11.4.0-1ubuntu1~22.04)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Hi TRL team,

I am hitting OOM errors when fine-tuning a Llama-3.1-70B model with my modified RL trainer.
The error happens when unwrapping the model for generation: I use an on-policy algorithm, so each training step generates some sequences.

My machine has 8 H100 80GB GPUs and I use LoRA, but it looks like unwrap_model_for_generation gathers the entire model into each GPU's memory, causing the OOM. Any suggestions?
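The training step follows roughly this pattern (a minimal sketch, not the actual trainer code; prompt_ids, tokenizer, and the generation arguments are placeholders):

```python
import torch
from trl.models.utils import unwrap_model_for_generation

def generate_rollouts(model, accelerator, prompt_ids, tokenizer):
    # Under DeepSpeed ZeRO-3, unwrap_model_for_generation enters
    # deepspeed.zero.GatheredParameters(model.parameters()), i.e. it
    # all-gathers every parameter shard onto each GPU before generating.
    with torch.no_grad(), unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        return unwrapped_model.generate(
            prompt_ids,
            max_new_tokens=128,
            pad_token_id=tokenizer.pad_token_id,
        )
```

Here is the traceback: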

[rank7]: Traceback (most recent call last):
[rank7]:   File "/export/scripts/training.py", line 243, in <module>
[rank7]:     trainer.train(resume_from_checkpoint=config.checkpoint_path)
[rank7]:   File "/export/trainer/trainer_simple_rloo.py", line 195, in train
[rank7]:     with torch.no_grad(), unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
[rank7]:   File "/opt/conda/lib/python3.12/contextlib.py", line 137, in __enter__
[rank7]:     return next(self.gen)
[rank7]:            ^^^^^^^^^^^^^^
[rank7]:   File "/export/venv/lib/python3.12/site-packages/trl/models/utils.py", line 162, in unwrap_model_for_generation
[rank7]:     with deepspeed.zero.GatheredParameters(model.parameters()):
[rank7]:   File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 2224, in __enter__
[rank7]:     self.params[0].all_gather(param_list=self.params)
[rank7]:   File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1143, in all_gather
[rank7]:     return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/venv/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank7]:     ret_val = func(*args, **kwargs)
[rank7]:               ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1511, in _all_gather
[rank7]:     self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False)
[rank7]:   File "/export/venv/lib/python3.12/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 1799, in _allgather_params_coalesced
[rank7]:     flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype,
[rank7]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 448.00 MiB. GPU 7 has a total capacity of 79.11 GiB of which 256.56 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 74.68 GiB is allocated by PyTorch, and 7.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
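For context, a back-of-the-envelope estimate (assuming the base weights are in bf16, which the report does not state) shows why gathering the full 70B model cannot fit next to the ~74 GiB already allocated on each 80 GB GPU:

```python
# Rough size of the fully gathered Llama-3.1-70B weights on one GPU.
# bf16 is an assumption; adjust bytes_per_param for other dtypes.
params = 70e9
bytes_per_param = 2                 # bf16 / fp16
gathered_gib = params * bytes_per_param / 1024**3
print(f"~{gathered_gib:.0f} GiB of weights per GPU once gathered")  # ~130 GiB > 80 GiB on an H100
```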

Expected behavior

Generation during training completes without running out of GPU memory.

Labels

🐛 bug (Something isn't working)
