Use torch.device instead of current device index for BnB quantizer#10069
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
sayakpaul
left a comment
Could you run some slow tests, and maybe the `tests/quantization/bnb/test_4bit.py` suite (line 489 in c96bfa5)?
If not, it will need to wait till tomorrow when I can find time to run it. LMK.
  # TODO (sayakpaul, SunMarc): remove this after model loading refactor
  elif is_quant_method_bnb:
-     param_device = torch.cuda.current_device()
+     param_device = torch.device(torch.cuda.current_device())
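For context, a minimal standalone sketch (not the diffusers code itself) of why the wrapping matters: `torch.cuda.current_device()` returns a bare `int` index, whereas downstream code expects a `torch.device`, which carries both the backend type and the index.

```python
import torch

# torch.cuda.current_device() returns a plain int (e.g. 0).
# Passing an int to torch.device() is interpreted as a CUDA device
# ordinal, so wrapping the index normalizes it into a device object.
index = 0  # stand-in for torch.cuda.current_device() on a CUDA machine
device = torch.device(index)

assert isinstance(index, int)
assert isinstance(device, torch.device)
assert device.type == "cuda" and device.index == 0
```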
Let's throw an error in `load_model_dict_into_meta` when `device` is passed as an index?
Throws a ValueError now. @yiyixuxu
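A hedged sketch of what such a guard could look like (`ensure_device` is a hypothetical helper name, not the actual diffusers implementation):

```python
import torch

def ensure_device(device):
    # Hypothetical guard mirroring the ValueError mentioned above:
    # reject bare integer device indices so callers pass a str or
    # torch.device instead.
    if isinstance(device, int):
        raise ValueError(
            f"Got device index {device}; pass a str or torch.device "
            "(e.g. torch.device(index)) instead."
        )
    return torch.device(device)
```

Usage: `ensure_device("cpu")` returns `torch.device("cpu")`, while `ensure_device(0)` raises a ValueError.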
@sayakpaul, the integration tests pass:
(nightly-venv) (nightly-venv) aryan@hf-dgx-01:~/work/diffusers$ RUN_SLOW=1 CUDA_VISIBLE_DEVICES="3" pytest -s tests/quantization/bnb/test_4bit.py::SlowBnb4BitFluxTests
========================= test session starts =========================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/aryan/work/diffusers
configfile: pyproject.toml
plugins: timeout-2.3.1, requests-mock-1.10.0, xdist-3.6.1, anyio-4.6.2.post1
collected 1 item
tests/quantization/bnb/test_4bit.py Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now default to True since model is quantized.
Loading pipeline components...:  14%|█▍        | 1/7 [00:00<00:00,  9.00it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00,  9.26it/s]
100%|██████████| 10/10 [00:14<00:00,  1.50s/it]
.
========================= 1 passed in 52.27s =========================
sayakpaul
left a comment
Merge away! Thank you very much.
Maybe we could also update the type hint of device in load_model_dict_into_meta()?
I think it's already correctly set to `str | torch.device`.
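For illustration only, the typing pattern being discussed (the function below is a simplified placeholder, not the real `load_model_dict_into_meta` signature; only the `device` annotation is meant to match):

```python
from typing import Optional, Union

import torch

# Placeholder sketch: the `device` annotation is the pattern under
# discussion; everything else is simplified for illustration.
def load_sketch(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    # Accept either a string like "cuda:0" or an existing torch.device.
    return torch.device(device) if device is not None else torch.device("cpu")
```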
…10069)

* update
* apply review suggestion

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
As discussed in #10009 (comment)
cc @sayakpaul