[Feature] overlap LoRA weight loading with compute #15512
Motivation
In #14190, I tried out a method to make LoRA weight loading asynchronous (see #8712). However, that method required allocating more GPU memory for LoRA weight storage, which is not ideal. This PR instead makes LoRA weight loading asynchronous without any extra GPU memory by pipelining the loading of LoRA weights so that it overlaps with forward compute. To illustrate:
In the current implementation, loading a new batch of LoRA adapters blocks all forward computation. With this PR, the process is pipelined so that forward compute can overlap with LoRA weight loading:
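Conceptually, the pipelined schedule looks something like the sketch below. This is plain PyTorch, not the PR's actual code: the tensor shapes, the `prefetch_adapter` helper, and the toy forward pass are made up purely for illustration.

```python
import torch

# The default stream runs forward passes; a side stream handles LoRA weight copies.
forward_stream = torch.cuda.current_stream()
load_stream = torch.cuda.Stream()

# Stand-ins for LoRA weights: a pinned host copy and a preallocated GPU buffer.
adapter_cpu = torch.randn(4096, 64, pin_memory=True)
adapter_gpu = torch.empty_like(adapter_cpu, device="cuda")

def prefetch_adapter() -> torch.cuda.Event:
    """Issue the H2D copy on load_stream and return an event marking its completion."""
    done = torch.cuda.Event()
    with torch.cuda.stream(load_stream):
        adapter_gpu.copy_(adapter_cpu, non_blocking=True)
        done.record(load_stream)
    return done

# Toy "forward pass" so the copy has compute to overlap with.
x = torch.randn(8, 4096, device="cuda")
w = torch.randn(4096, 4096, device="cuda")

ready = None
for step in range(4):
    if ready is not None:
        forward_stream.wait_event(ready)  # stalls only if the copy hasn't finished yet
    ready = prefetch_adapter()            # next adapter starts loading in the background...
    y = x @ w                             # ...while this batch's compute runs on forward_stream
torch.cuda.synchronize()
```

The key point is that the forward stream only waits on the copy event when the newly loaded weights are actually needed, so a well-timed prefetch never stalls compute.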
Modifications
- A `ThreadPoolExecutor` in the scheduler to handle CPU computation for loading LoRAs asynchronously
- A new CUDA stream (`load_stream`) to ensure adapter loading happens on a separate stream from `forward_stream` (see the sketch below)
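A minimal sketch of how these two pieces could fit together, assuming PyTorch and the standard-library executor (this is not the PR's actual code: `load_lora_from_disk` and `async_load` are hypothetical helpers, and the scheduler wiring is omitted):

```python
from concurrent.futures import ThreadPoolExecutor
import torch

executor = ThreadPoolExecutor(max_workers=1)   # CPU-side loading runs off the critical path
load_stream = torch.cuda.Stream()              # adapter copies go here, not on the forward stream
forward_stream = torch.cuda.current_stream()

def load_lora_from_disk(adapter_id: str) -> torch.Tensor:
    """Hypothetical CPU work: read/deserialize the adapter into pinned host memory."""
    return torch.randn(4096, 64, pin_memory=True)

def async_load(adapter_id: str, gpu_buffer: torch.Tensor):
    """Submit CPU loading to the executor; issue the H2D copy on load_stream when it completes."""
    def _load_and_copy() -> torch.cuda.Event:
        host_weights = load_lora_from_disk(adapter_id)         # CPU work in a worker thread
        done = torch.cuda.Event()
        with torch.cuda.stream(load_stream):
            gpu_buffer.copy_(host_weights, non_blocking=True)  # async copy on the side stream
            done.record(load_stream)
        return done
    return executor.submit(_load_and_copy)                     # Future[torch.cuda.Event]

gpu_buffer = torch.empty(4096, 64, device="cuda")
pending = async_load("adapter-0", gpu_buffer)
# ... forward passes for batches using already-resident adapters keep running here ...
forward_stream.wait_event(pending.result())  # sync only when the new adapter is actually needed
```

Pinned host memory plus `non_blocking=True` is what lets the copy proceed on `load_stream` without serializing against the forward stream.

Accuracy Tests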
I'm not too sure how to test this change, as it's mainly internal. Happy to hear suggestions!
Benchmarking and Profiling
The hardware I used was a single H200 GPU. I benchmarked with the following scripts:
Note that `mkopecki/chess-lora-adapter-llama-3.1-8b` is a large adapter (>1 GB). However, I used it mainly to see how well this implementation can perform in the best case and to establish a proof of concept.

I ran the benchmark twice to account for possible performance drops due to cold-start latencies:
Benchmark Comparison: Main vs PR
[Results table: main vs. this PR, Run 1 and Run 2]

Overall, we see very large decreases in E2E latency and TTFT.
Checklist