@glenliu21 (Contributor)

Motivation

In #14190, I tried out a method to make LoRA weight loading asynchronous (see #8712). However, that approach required allocating extra GPU memory for LoRA weight storage, which is not ideal. This PR instead makes LoRA weight loading asynchronous without any extra memory cost by pipelining the loading of LoRA weights with forward compute. To illustrate:

In the current implementation, loading in a new batch of LoRA adapters blocks all forward computation:

GPU Compute
R1:  │            idle                                                       │████████ RUN R1..R4 ████████████│
R2:  │            idle                                                       │████████ RUN R1..R4 ████████████│
R3:  │            idle                                                       │████████ RUN R1..R4 ████████████│
R4:  │            idle                                                       │████████ RUN R1..R4 ████████████│

PCIe / LoRA Load
     │[ LOAD R1, R2, R3, R4 (blocking before run) ]                          │
     │─████████████████████──────────────────────────────────────────────────│

With this PR, we pipeline this process so that forward compute can overlap with LoRA weight loading:

GPU Compute
R1:  │      idle       │█████ RUN R1 ████│███ RUN R1+R2 ███│█ RUN R1+R2+R3 ██│RUN R1..R4 █████│
R2:  │      idle       │      idle       │███ RUN R1+R2 ███│█ RUN R1+R2+R3 ██│RUN R1..R4 █████│
R3:  │      idle       │      idle       │      idle       │█ RUN R1+R2+R3 ██│RUN R1..R4 █████│
R4:  │      idle       │      idle       │      idle       │       idle      │RUN R1..R4 █████│

PCIe / LoRA Load
     │ [ LOAD R1 ]     │ [ LOAD R2 ]     │ [ LOAD R3 ]     │ [ LOAD R4 ]     │
     │█████████────────│█████████────────│█████████────────│█████████────────│

Modifications

  • Introduce a ThreadPoolExecutor in the scheduler to run the CPU-side work of loading LoRA adapters asynchronously
  • Introduce a separate CUDA stream (load_stream) so that adapter uploads happen on a different stream from forward_stream (a minimal sketch of this pattern follows the list)
  • Modify the existing batch selection logic and add the required bookkeeping
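
To make the executor/stream interplay concrete, here is a minimal sketch of the pattern under stated assumptions: load_stream and the executor-based scheduling mirror the bullet points above, while fake_lora_weights and upload_adapter are hypothetical stand-ins for illustration, not the actual SGLang code.

# Minimal sketch: overlap LoRA weight uploads (load_stream) with forward
# compute (forward_stream). fake_lora_weights / upload_adapter are hypothetical
# stand-ins, not SGLang APIs.
from concurrent.futures import ThreadPoolExecutor

import torch

forward_stream = torch.cuda.current_stream()
load_stream = torch.cuda.Stream()           # separate stream for H2D copies
executor = ThreadPoolExecutor(max_workers=1)

def fake_lora_weights() -> torch.Tensor:
    # Stand-in for reading adapter weights from disk into pinned host memory.
    return torch.randn(4096, 4096, pin_memory=True)

def upload_adapter():
    # Runs on the executor thread: copy weights to the GPU on load_stream and
    # record an event so forward_stream can wait only when the adapter is used.
    host_weights = fake_lora_weights()
    with torch.cuda.stream(load_stream):
        device_weights = host_weights.to("cuda", non_blocking=True)
        ready = torch.cuda.Event()
        ready.record(load_stream)
    return device_weights, ready

# Kick off the load without blocking the scheduler or the current forward pass.
future = executor.submit(upload_adapter)

# Forward compute for already-loaded adapters keeps running on forward_stream...
x = torch.randn(8, 4096, device="cuda")
y = x @ x.T                                 # placeholder for the model forward

# ...and only when the new adapter is actually scheduled do the streams sync.
device_weights, ready = future.result()
forward_stream.wait_event(ready)            # GPU-side wait; the CPU is not blocked
y = x @ device_weights

The key point is that the host-to-device copy on load_stream proceeds while forward_stream is busy, and synchronization happens via a CUDA event only at the point where the newly loaded adapter is first used in a batch.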

Accuracy Tests

I'm not too sure how to test this change, as it's mainly internal. Happy to hear suggestions!

Benchmarking and Profiling

The hardware I used was a single H200 GPU. I benchmarked with the following scripts:

python3 -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --max-loaded-loras 16 \
    --max-loras-per-batch 8 \
    --lora-paths \
        adapter0=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter1=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter2=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter3=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter4=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter5=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter6=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter7=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter8=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter9=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter10=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter11=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter12=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter13=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter14=mkopecki/chess-lora-adapter-llama-3.1-8b \
        adapter15=mkopecki/chess-lora-adapter-llama-3.1-8b

Note that mkopecki/chess-lora-adapter-llama-3.1-8b is a large adapter (>1 GB); I used it mainly to see how well this implementation can perform in the best case and to establish the approach as a proof of concept.

python3 -m sglang.bench_serving \
  --backend sglang \
  --base-url http://localhost:30000 \
  --dataset-name random \
  --num-prompts 100 \
  --request-rate 4 \
  --random-input-len 2048 \
  --random-output-len 1024 \
  --lora-name \
    adapter0 \
    adapter1 \
    adapter2 \
    adapter3 \
    adapter4 \
    adapter5 \
    adapter6 \
    adapter7 \
    adapter8 \
    adapter9 \
    adapter10 \
    adapter11 \
    adapter12 \
    adapter13 \
    adapter14 \
    adapter15

I ran the benchmark twice so that any cold-start latency would not skew the comparison:

Benchmark Comparison: Main vs PR

Metric (ms)          main Run 1   PR Run 1   % decrease
Mean E2E Latency     12955.76     5929.72    54.4%
Median E2E Latency   11294.58     5813.74    48.5%
Mean TTFT            5721.71      1853.25    67.6%
Median TTFT          495.27       213.42     56.9%
P99 TTFT             21236.25     9827.03    53.7%
Median TPOT          13.63        7.76       43.0%

Metric (ms)          main Run 2   PR Run 2   % decrease
Mean E2E Latency     10505.25     5409.31    48.5%
Median E2E Latency   9045.03      5005.13    44.7%
Mean TTFT            4135.48      1618.01    60.9%
Median TTFT          193.50       208.14     -7.7%
P99 TTFT             15808.89     9436.19    40.3%
Median TPOT          11.85        7.28       38.6%

Overall, we see large decreases in E2E latency (roughly 48-54% mean) and TTFT (roughly 61-68% mean) across both runs.


