
Load models much faster on accelerator devices!! #36380

Merged
ArthurZucker merged 4 commits into main from fast-loading on Feb 25, 2025

Conversation

@Cyrilvallez (Member) commented Feb 24, 2025

What does this PR do?

Cut model loading time by a large margin (a 7x factor for an 8B model, 6x for a 32B 🚀🚀).
cudaMalloc was in fact the bottleneck in our current loading pipeline. We can pre-allocate tensors of the expected size to warm up the caching allocator, which results in far fewer calls to cudaMalloc and cuts loading time accordingly.
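
Conceptually, the warmup amounts to something like the sketch below. This is a simplified, illustrative version, not the PR's exact code: the real helper is caching_allocator_warmup in modeling_utils.py and handles more edge cases.

import torch
from collections import defaultdict

def caching_allocator_warmup_sketch(model, expanded_device_map, dtype):
    # Count how many parameter elements will end up on each accelerator device.
    elements_per_device = defaultdict(int)
    for param_name, device in expanded_device_map.items():
        if device in ("cpu", "disk"):
            continue  # only accelerator devices benefit from the warmup
        param = model.get_parameter(param_name)
        elements_per_device[torch.device(device)] += param.numel()
    # One large allocation per device primes torch's caching allocator; the memory
    # stays cached after this tensor is freed, so the per-tensor copies done while
    # loading the checkpoint reuse cached blocks instead of each hitting cudaMalloc.
    for device, numel in elements_per_device.items():
        _ = torch.empty(numel, dtype=dtype, device=device)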

The following snippet:

import time
t_ini = time.time()
import torch
from transformers import AutoModelForCausalLM
print(f"Time for the imports: {time.time() - t_ini:.2f} s")

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "codellama/CodeLlama-34b-Instruct-hf"
device = torch.device("cuda:2")


t0 = time.time()
torch.cuda.synchronize(device)
print(f"Time for cuda warmup before loading model: {time.time() - t0:.2f} s")
t0 = time.time()
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map=device)
torch.cuda.synchronize(device)
dt = time.time() - t0
print(f"time to load the model: {dt:.2f}")

max_mem = torch.cuda.max_memory_allocated(device) / 1024**3
current_mem = torch.cuda.memory_allocated(device) / 1024**3
print(f"Max: {max_mem:.2f} GiB")
print(f"Current: {current_mem:.2f} GiB")
print(f"Full time: {time.time() - t_ini:.2f} s")

returns:

Time for the imports: 4.38 s
Time for cuda warmup before loading model: 16.18 s
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.21s/it]
time to load the model: 5.92
Max: 14.96 GiB
Current: 14.96 GiB
Full time: 26.49 s

and before:

Time for the imports: 4.54 s
Time for cuda warmup before loading model: 15.23 s
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:42<00:00, 10.51s/it]
time to load the model: 42.54
Max: 14.96 GiB
Current: 14.96 GiB
Full time: 62.32 s

So we basically cut loading time by ~7x for an 8B model (from 42.5 s down to 5.9 s).

Note that in both cases there is roughly 15 s of CUDA overhead before loading the model.

@Cyrilvallez changed the title from "Load model much faster!!" to "Load models much faster on accelerator devices!!" on Feb 24, 2025
@ArthurZucker (Collaborator) left a comment

Let's fucking go 🚀

offload_index = None

if device_map is not None:
    expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
Collaborator:

Suggestion: param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix). Let's make sure it's only used once; good to go for me otherwise!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Merging!

@ArthurZucker merged commit 4b5cf54 into main on Feb 25, 2025
24 checks passed
@ArthurZucker deleted the fast-loading branch on February 25, 2025 at 08:41
Comment on lines +4822 to +4824
if device_map is not None:
    expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
    caching_allocator_warmup(model, expanded_device_map, dtype)
Contributor:

@Cyrilvallez, a quick question: does this take quantization into account? Just to make sure we will not hit an OOM when allocating such a big tensor at once.

Member:

+1 we should try to use the dtype of the param if possible

Member Author (@Cyrilvallez):

Yes, I don't remember whether the dtype variable is always correctly switched based on quantization, but if it stays torch.float16 we definitely need to update it 😉 I'll check tomorrow if nobody has done it in the meantime!
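
For illustration, a hedged sketch of what "use the dtype of the param" could look like when sizing the warmup; the function name is hypothetical, while accelerator_device_map follows the naming visible in the traceback below.

from collections import defaultdict

def warmup_bytes_per_device(model, accelerator_device_map):
    # Size the warmup with each parameter's own element size rather than one
    # global dtype, so quantized parameters are not over-allocated.
    byte_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        param = model.get_parameter(param_name)
        byte_count[device] += param.numel() * param.element_size()
    return byte_count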

Member:

This does seem to break an integration test on bitsandbytes (aside from allocating extra memory, since dtype here remains the original torch_dtype=torch.bfloat16):

  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): Modul...6)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
expanded_device_map = {'lm_head.weight': 0, 'model.embed_tokens.weight': 0, 'model.layers.0.input_layernorm.weight': 0, 'model.layers.0.mlp.down_proj.weight': 0, ...}
dtype = torch.bfloat16
    def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
        """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
        device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
        the model, which is actually the loading speed botteneck.
        Calling this function allows to cut the model loading time by a very large margin.
        """
        # Remove disk and cpu devices, and cast to proper torch.device
        accelerator_device_map = {
            param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
        }
        parameter_count = defaultdict(lambda: 0)
        for param_name, device in accelerator_device_map.items():
            try:
>               param = model.get_parameter(param_name)
../transformers/src/transformers/modeling_utils.py:5818: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../.local/venv/bnbdev/lib/python3.12/site-packages/torch/nn/modules/module.py:812: in get_parameter
    mod: torch.nn.Module = self.get_submodule(module_path)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): Modul...6)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
target = 'model.layers.0.self_attn.rotary_emb'
    def get_submodule(self, target: str) -> "Module":
        """Return the submodule given by ``target`` if it exists, otherwise throw an error.
    
        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:
    
        .. code-block:: text
    
            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )
    
        (The diagram shows an ``nn.Module`` ``A``. ``A`` which has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)
    
        To check whether or not we have the ``linear`` submodule, we
        would call ``get_submodule("net_b.linear")``. To check whether
        we have the ``conv`` submodule, we would call
        ``get_submodule("net_b.net_c.conv")``.
    
        The runtime of ``get_submodule`` is bounded by the degree
        of module nesting in ``target``. A query against
        ``named_modules`` achieves the same result, but it is O(N) in
        the number of transitive modules. So, for a simple check to see
        if some submodule exists, ``get_submodule`` should always be
        used.
    
        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)
    
        Returns:
            torch.nn.Module: The submodule referenced by ``target``
    
        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        if target == "":
            return self
    
        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self
    
        for item in atoms:
            if not hasattr(mod, item):
>               raise AttributeError(
                    mod._get_name() + " has no " "attribute `" + item + "`"
                )
E               AttributeError: LlamaAttention has no attribute `rotary_emb`
../../.local/venv/bnbdev/lib/python3.12/site-packages/torch/nn/modules/module.py:720: AttributeError
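
For context, a hedged sketch of the kind of quantized load that exercises this path; the checkpoint and settings below are illustrative, not the exact failing test.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantized load: weights land on the GPU through bitsandbytes, while torch_dtype
# stays bfloat16, which is what the warmup sizing currently picks up.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)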

Member:

For now, we will just skip quantized models: #36428

Member:

I will check if we can get a further speedup in the coming days.

Member Author (@Cyrilvallez):

@matthewdouglas I assume you tried to load an old checkpoint, right? rotary_emb moved to the base model and should be skipped from the keys if present. However, the present code really goes in all directions when renaming keys, so apparently they are still there when entering the new cache-allocating function. This will all stabilize in the near future as we keep rebuilding from_pretrained, e.g. in #36033. We are going to simplify the inner workings a lot, which should avoid this kind of key shenanigans.
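
As an illustration of the kind of guard this implies (hypothetical, not the fix that actually landed), stale keys that no longer resolve to a parameter on the instantiated model could simply be skipped during the warmup:

from collections import defaultdict

def count_params_per_device(model, accelerator_device_map):
    # Skip legacy/renamed keys from older checkpoint layouts (e.g. per-layer
    # rotary_emb buffers) that no longer map to a parameter on the model,
    # instead of letting get_parameter raise an AttributeError.
    parameter_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            continue  # nothing to warm up for this key
        parameter_count[device] += param.numel()
    return parameter_count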

Member:

@Cyrilvallez That's correct; the checkpoints in the test suite are generally older, e.g. huggyllama/llama-7b and bigscience/bloom-1b7.
