@david6666666
Contributor

@david6666666 david6666666 commented Aug 29, 2025

Purpose

issue #23780

Test Plan

prefill instance:

CUDA_VISIBLE_DEVICES=0 VLLM_NIXL_SIDE_CHANNEL_PORT=5567 vllm serve /workspace/models/qwen2.5_7B \
  --port 20001 \
  --tensor-parallel-size 1 \
  --enforce-eager \
  --block-size 16 \
  --enable-log-requests \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode instance:

CUDA_VISIBLE_DEVICES=1 VLLM_NIXL_SIDE_CHANNEL_PORT=5667 viztracer -o decode.json vllm serve /workspace/models/qwen2.5_7B \
  --port 20002 \
  --tensor-parallel-size 1 \
  --enforce-eager \
  --block-size 16 \
  --enable-log-requests \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

proxy:

python ./tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
  --port 40000 \
  --prefiller-port 20001 \
  --decoder-port 20002

benchmark:

python benchmarks/benchmark_serving.py \
    --backend vllm \
    --endpoint /v1/completions \
    --model /workspace/models/qwen2.5_7B \
    --dataset-name random \
    --random-input 800 \
    --random-output 100 \
    --num-prompts 100 \
    --request-rate 5 \
    --host localhost \
    --port 40000

Test Result

before this pr:
image

after this pr:
image

Time spent in _get_block_descs_ids drops from ~2 ms to ~0.05 ms (roughly 40x faster).



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization in the NixlConnector by vectorizing the computation in the _get_block_descs_ids method. The change replaces nested Python loops with NumPy broadcasting operations, which significantly reduces the execution time for generating block descriptor IDs, as demonstrated by the performance results in the description. The implementation is correct and effectively leverages NumPy for better performance. The addition of the numpy import is necessary and appropriate for this change.
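The loop-to-broadcasting change described above can be illustrated with a minimal sketch. This is not the PR's actual code: the function names, the parameters (num_regions, num_blocks, block_ids), and the ID formula region_id * num_blocks + block_id are assumptions chosen only to show the technique of replacing nested Python loops with a NumPy broadcast.

```python
import numpy as np


def block_desc_ids_loop(num_regions, num_blocks, block_ids):
    """Hypothetical baseline: nested Python loops, one append per (region, block)."""
    ids = []
    for region_id in range(num_regions):
        for block_id in block_ids:
            # Assumed layout: each region owns a contiguous run of num_blocks IDs.
            ids.append(region_id * num_blocks + block_id)
    return ids


def block_desc_ids_vectorized(num_regions, num_blocks, block_ids):
    """Same result via broadcasting: (R, 1) * scalar + (1, B) -> (R, B)."""
    region_ids = np.arange(num_regions, dtype=np.int64)[:, None]  # shape (R, 1)
    blocks = np.asarray(block_ids, dtype=np.int64)[None, :]       # shape (1, B)
    # ravel() flattens row by row, matching the loop order (region-major).
    return (region_ids * num_blocks + blocks).ravel()


if __name__ == "__main__":
    loop_ids = block_desc_ids_loop(2, 10, [3, 7])
    vec_ids = block_desc_ids_vectorized(2, 10, [3, 7])
    print(loop_ids, vec_ids.tolist())
```

Both functions return IDs in the same region-major order, so the vectorized version could stand in for the loop wherever a flat sequence of IDs is expected. The speedup depends on the number of regions and blocks and is not measured here.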

@robertgshaw2-redhat
Collaborator

nice job

Signed-off-by: ycyaw66 <497410282@qq.com>
@david6666666
Contributor Author

hi @robertgshaw2-redhat, PTAL thanks.

@robertgshaw2-redhat
Collaborator

really nice work

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) September 3, 2025 17:19
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 3, 2025
@robertgshaw2-redhat robertgshaw2-redhat merged commit 6adaed4 into vllm-project:main Sep 3, 2025
52 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
