Load models much faster on accelerator devices!! #36380
Conversation
src/transformers/modeling_utils.py
Outdated
offload_index = None

if device_map is not None:
    expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
Let's make sure it's only used once; good to go for me otherwise!
if device_map is not None:
    expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
    caching_allocator_warmup(model, expanded_device_map, dtype)
@Cyrilvallez, a quick question: does it take quantization into account? Just to make sure we will not hit an OOM when allocating such a big tensor at once.
+1 we should try to use the dtype of the param if possible
Yes, I don't remember if the dtype variable is always correctly switched based on quantization or not, but if it stays torch.float16 we need to update it for sure 😉 I'll check tomorrow if nobody did it in the meantime!
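For illustration, a minimal sketch of the idea discussed above (the helper name bytes_needed_per_device is made up, and this is not necessarily how the PR resolves it): size the warm-up from each parameter's own dtype, which a quantizer may already have changed, instead of a single global dtype argument.

from collections import defaultdict

import torch


def bytes_needed_per_device(model: torch.nn.Module, accelerator_device_map: dict) -> dict:
    """Sum the bytes each accelerator device will need, using each parameter's
    own dtype rather than one global dtype, so quantized modules are not
    over-counted during the warm-up (illustrative sketch only)."""
    byte_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        param = model.get_parameter(param_name)
        # element_size() reflects the dtype the parameter actually has on the model.
        byte_count[device] += param.numel() * param.element_size()
    return byte_count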
This does seem to break an integration test on bitsandbytes (aside from allocating extra memory, as dtype here remains the original torch_dtype=torch.bfloat16):
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096, padding_idx=0)
(layers): Modul...6)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
expanded_device_map = {'lm_head.weight': 0, 'model.embed_tokens.weight': 0, 'model.layers.0.input_layernorm.weight': 0, 'model.layers.0.mlp.down_proj.weight': 0, ...}
dtype = torch.bfloat16
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
the model, which is actually the loading speed botteneck.
Calling this function allows to cut the model loading time by a very large margin.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
}
parameter_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
try:
> param = model.get_parameter(param_name)
../transformers/src/transformers/modeling_utils.py:5818:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../.local/venv/bnbdev/lib/python3.12/site-packages/torch/nn/modules/module.py:812: in get_parameter
mod: torch.nn.Module = self.get_submodule(module_path)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096, padding_idx=0)
(layers): Modul...6)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
target = 'model.layers.0.self_attn.rotary_emb'
def get_submodule(self, target: str) -> "Module":
"""Return the submodule given by ``target`` if it exists, otherwise throw an error.
For example, let's say you have an ``nn.Module`` ``A`` that
looks like this:
.. code-block:: text
A(
(net_b): Module(
(net_c): Module(
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
)
(linear): Linear(in_features=100, out_features=200, bias=True)
)
)
(The diagram shows an ``nn.Module`` ``A``. ``A`` which has a nested
submodule ``net_b``, which itself has two submodules ``net_c``
and ``linear``. ``net_c`` then has a submodule ``conv``.)
To check whether or not we have the ``linear`` submodule, we
would call ``get_submodule("net_b.linear")``. To check whether
we have the ``conv`` submodule, we would call
``get_submodule("net_b.net_c.conv")``.
The runtime of ``get_submodule`` is bounded by the degree
of module nesting in ``target``. A query against
``named_modules`` achieves the same result, but it is O(N) in
the number of transitive modules. So, for a simple check to see
if some submodule exists, ``get_submodule`` should always be
used.
Args:
target: The fully-qualified string name of the submodule
to look for. (See above example for how to specify a
fully-qualified string.)
Returns:
torch.nn.Module: The submodule referenced by ``target``
Raises:
AttributeError: If the target string references an invalid
path or resolves to something that is not an
``nn.Module``
"""
if target == "":
return self
atoms: List[str] = target.split(".")
mod: torch.nn.Module = self
for item in atoms:
if not hasattr(mod, item):
> raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
E AttributeError: LlamaAttention has no attribute `rotary_emb`
../../.local/venv/bnbdev/lib/python3.12/site-packages/torch/nn/modules/module.py:720: AttributeError
I will check if we can get a further speedup in the following days.
@matthewdouglas I assume you tried to load an old checkpoint, right? rotary_emb moved to the base model and should be skipped from the keys if present. However, the present code really goes in all directions when renaming the keys, so apparently they are still present when entering the new cache-allocating function. This will all stabilize in the near future as we keep rebuilding from_pretrained, e.g. in #36033. We are gonna simplify the inner workings a lot, which should avoid this kind of key shenanigans.
@Cyrilvallez That's correct; the checkpoints in the test suite are generally older, e.g. huggyllama/llama-7b and bigscience/bloom-1b7.
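As an illustration of the key skipping described above (a hedged sketch with a made-up helper name, not the actual fix that landed), the per-device accounting can simply ignore checkpoint keys that no longer resolve to a parameter on the instantiated model, such as the legacy model.layers.0.self_attn.rotary_emb entries from older Llama checkpoints:

from collections import defaultdict

import torch


def count_parameters_per_device(model: torch.nn.Module, accelerator_device_map: dict) -> dict:
    """Count how many elements will land on each accelerator device, skipping
    stale checkpoint keys that have no counterpart on the model (sketch only)."""
    parameter_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        try:
            param = model.get_parameter(param_name)
        except AttributeError:
            # e.g. 'model.layers.0.self_attn.rotary_emb' from an old checkpoint:
            # the module moved to the base model, so there is nothing to warm up here.
            continue
        parameter_count[device] += param.numel()
    return parameter_count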
What does this PR do?
Cut model loading time by a large margin (7x factor for an 8B model, 6x for a 32B 🚀🚀)
cudaMalloc was in fact the bottleneck in our current loading pipeline. We can pre-allocate tensors of the expected size to warm up the caching allocator, which results in far fewer calls to cudaMalloc and cuts loading time accordingly.
Timing the same loading snippet before and after this change, we basically cut loading time by 4x for an 8B model (a rough sketch of the warm-up idea is given at the end of this description).
Note that in both cases, there is basically a 15s CUDA overhead before loading the model.
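To make the mechanism concrete, here is a rough, self-contained sketch of the warm-up and timing idea (the function name, the 16 GiB figure and the timing harness are placeholders, not the code added in modeling_utils.py): one large allocation per accelerator device is requested and discarded up front, so the caching allocator serves the many small per-parameter copies made during loading from already-reserved memory instead of hitting cudaMalloc each time.

import time

import torch


def caching_allocator_warmup_sketch(byte_count_per_device: dict) -> None:
    """Ask the CUDA caching allocator for one big block per device up front.
    The tensor is discarded immediately, but the allocator keeps the freed block
    cached, so subsequent small allocations reuse it (illustrative sketch only)."""
    for device, byte_count in byte_count_per_device.items():
        _ = torch.empty(byte_count, dtype=torch.uint8, device=device)


if __name__ == "__main__":
    # Placeholder size: roughly what an 8B model needs in bfloat16 (~16 GiB).
    sizes = {torch.device("cuda:0"): 16 * 1024**3}
    torch.cuda.synchronize()
    start = time.perf_counter()
    caching_allocator_warmup_sketch(sizes)
    torch.cuda.synchronize()
    print(f"Warm-up took {time.perf_counter() - start:.2f}s")

Loading the checkpoint after such a warm-up triggers far fewer cudaMalloc calls, which is where the reported 4x-7x reduction in loading time comes from.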