Updates to Flex + VLLm integration #21416
Conversation
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs would not trigger a full CI run by default. Instead, it would only run […]. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 […]
Code Review
This pull request introduces significant performance improvements to FlexAttention by implementing a more efficient method for building the block mask. The changes are well-structured and include new helper functions for tensor manipulation. However, I've identified a critical correctness issue where __post_init__ incorrectly returns a value, and a high-severity issue regarding a hardcoded block size that limits the applicability of this backend. Addressing these points will improve the robustness and correctness of the implementation.
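For context on the __post_init__ point, here is a minimal sketch (hypothetical names, not the PR's actual code): a dataclass's __post_init__ is called only for its side effects, so any value it returns is silently discarded and derived state must be assigned to the instance.

from dataclasses import dataclass
from typing import Optional

@dataclass
class MaskMetadata:
    # Hypothetical fields for illustration only.
    num_blocks: int
    block_mask: Optional[list] = None

    def __post_init__(self) -> None:
        # The return value of __post_init__ is ignored by dataclasses, so the
        # computed mask must be stored on self rather than returned.
        self.block_mask = [True] * self.num_blocks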
Force-pushed from 8896590 to 1bcd0a9
Force-pushed from 20ab73c to 7fe7ae2
Running Flex Test:

>       assert flex_text == default_text, (
            f"FlexAttention output doesn't match default for: {prompt!r}\n"
            f"FlexAttention: {flex_text!r}\n"
            f"Default: {default_text!r}")
E       AssertionError: FlexAttention output doesn't match default for: 'Hello, my name is'
E         FlexAttention: ' John. I am a 16 year old boy. I am a student at a high school. I am a bit of a loner. I have'
E         Default: ' John. I am a 20-year-old student at the University of California, Berkeley. I am a senior in my major of Computer Science. I am'
E       assert ' John. I am ...loner. I have' == ' John. I am ...Science. I am'
E
E         - John. I am a 20-year-old student at the University of California, Berkeley. I am a senior in my major of Computer Science. I am
E         + John. I am a 16 year old boy. I am a student at a high school. I am a bit of a loner. I have

Would be curious if people have better ideas on more robust testing here.
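One possible direction (a sketch only, using placeholder names flex_llm and default_llm for two LLM instances built with each attention backend; not code from this test): compare a short greedy prefix of token ids across backends rather than the full generated string, since a tiny numerical difference can legitimately flip one token partway through and cascade into a completely different continuation.

from vllm import SamplingParams

def assert_backends_close(flex_llm, default_llm, prompt, prefix_len=8):
    # Greedy decoding, so both backends should agree unless logits diverge
    # enough to flip an argmax within the first few steps.
    params = SamplingParams(temperature=0.0, max_tokens=32)
    flex = flex_llm.generate([prompt], params)[0].outputs[0]
    ref = default_llm.generate([prompt], params)[0].outputs[0]
    assert list(flex.token_ids[:prefix_len]) == list(ref.token_ids[:prefix_len]), (
        f"Greedy prefixes diverge for {prompt!r}:\n"
        f"FlexAttention: {flex.text!r}\n"
        f"Default:       {ref.text!r}")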
Force-pushed from dbeff4d to 6d84dc5
@LucasWilkinson Are the failures related?
I don't think so, but we're holding off force merges until we can get CI green (hopefully today), so I'd just wait and rebase after that.
Force-pushed from 6d84dc5 to 2037f77
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 2037f77 to d67e708
@drisspg can you share your env info? When running the below with Python 3.10.18 on H100s I get the error at the bottom. Interestingly it works fine if doing […].

Edit: Reduced the script to just the below:

# git clone https://github.com/vllm-project/vllm.git
# cd vllm
# pip install uv
# VLLM_USE_PRECOMPILED=1 uv pip install --editable .
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

from vllm import LLM, SamplingParams

model = LLM("Qwen/Qwen2-7B-Instruct")
output = model.generate(["Hi"] * 4)
print(output)

Log (abridged)

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s]
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]   File "/opt/hpcaas/.mounts/fs-0301404b74c8d22fd/home/muennighoff/vllm/vllm/v1/attention/backends/flex_attention.py", line 511, in __post_init__
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]     self.block_mask = self.build_block_mask()
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]   File "/opt/hpcaas/.mounts/fs-0301404b74c8d22fd/home/muennighoff/vllm/vllm/v1/attention/backends/flex_attention.py", line 481, in build_block_mask
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]     return create_block_mask_compiled(
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]   ... (torch._dynamo / aot_autograd / inductor cudagraph frames) ...
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]   File "/home/muennighoff/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/_inductor/cudagraph_trees.py", line 2230, in record_function
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]     torch.cuda.synchronize()
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]   File "/home/muennighoff/miniconda3/envs/vllm/lib/python3.10/site-packages/torch/cuda/__init__.py", line 1040, in synchronize
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710]     return torch._C._cuda_synchronize()
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710] RuntimeError: CUDA error: an illegal memory access was encountered
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
(EngineCore_0 pid=542941) ERROR 08-27 00:35:20 [core.py:710] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Traceback (most recent call last):
  File "/home/muennighoff/s2/generate_simple.py", line 26, in <module>
    output = model.generate(
  File "/opt/hpcaas/.mounts/fs-0301404b74c8d22fd/home/muennighoff/vllm/vllm/entrypoints/llm.py", line 388, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/opt/hpcaas/.mounts/fs-0301404b74c8d22fd/home/muennighoff/vllm/vllm/entrypoints/llm.py", line 1448, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/opt/hpcaas/.mounts/fs-0301404b74c8d22fd/home/muennighoff/vllm/vllm/v1/engine/llm_engine.py", line 241, in step
    outputs = self.engine_core.get_output()
  File "/opt/hpcaas/.mounts/fs-0301404b74c8d22fd/home/muennighoff/vllm/vllm/v1/engine/core_client.py", line 668, in get_output
    raise self._format_exception(outputs) from None
vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
(EngineCore_0 pid=542941) Process EngineCore_0: ...
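For reference, one quick way to gather the environment details being asked about here is PyTorch's built-in collect_env report (python -m torch.utils.collect_env; vLLM also ships a similar collect_env.py in the repo root). The snippet below is only an illustrative shortcut that prints the most relevant versions.

# Quick environment summary (illustrative; the collect_env utilities above
# print a much more complete report).
import torch
import vllm

print("vllm :", vllm.__version__)
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
if torch.cuda.is_available():
    print("gpu  :", torch.cuda.get_device_name(0))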
Hi @drisspg, I am wondering if Flex can be used with pipeline parallelism? It seems that setting --pipeline-parallel-size to more than one incurs errors.
Purpose
Improve FlexAttention performance by adding a custom block-mask metadata builder for the common case (a standalone sketch of the underlying block-mask API follows below).
Also updates to the newer metadata-passing APIs.
Co-authored by Horace.
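For readers unfamiliar with the block-mask machinery, here is a rough standalone sketch of the public FlexAttention API this builds on (a plain causal example that assumes a CUDA device; it is not the custom builder added in this PR).

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# mask_mod is evaluated at block granularity to decide which (q_block, kv_block)
# tiles can be skipped; building this mask cheaply on the common (causal/prefix)
# path is where the performance win comes from.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

S = 1024
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q = k = v = torch.randn(1, 8, S, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # (1, 8, 1024, 64)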
Test Plan
pytest tests/kernels/test_flex_attention.py
Here are my perf numbers from a sweep of vLLM, using this script:
https://gist.github.com/drisspg/c983e853ba8e9d999ae429783cde3c2f
Two sources of slowdown:
FBURL for trace: https://fburl.com/3bcgf2d9
FBURL Flash Trace: https://fburl.com/0k1vpaor
Flex is performing pretty miserably here (flex decode disabled):
420 us compared to 20-25 us for Flash
We are launching some GPU kernels, but they are all CPU bound; CUDA graphs are kind of a mixed bag since the nonzero call in the direct build breaks things (illustrated below).
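A toy illustration of the nonzero issue (not code from this PR): torch.nonzero's output shape depends on the values of its input, so the host has to wait on the GPU to learn how many rows to allocate, and that data-dependent shape plus synchronization is what makes the op awkward inside CUDA graph capture or a fully compiled mask build.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([0, 3, 0, 7], device=device)

# The number of rows in the result is only known once the kernel has inspected
# the data, so on GPU this call synchronizes with the device to size the output.
idx = torch.nonzero(x)
print(idx.shape)  # torch.Size([2, 1]) for this input; (k, 1) in general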
Comparing to Flash: 3.68 ms between decode steps for Flex vs 2.3 ms for Flash.

Note
I tried to run the same sweep before this PR but kept getting IMAs (illegal memory accesses).
Online Bench
Flash:
Flex:
Lm Eval
HF_HUB_DISABLE_XET=1 VLLM_ATTENTION_BACKEND=FLEX_ATTENTION lm_eval \
    --model vllm \
    --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct", "gpu_memory_utilization": 0.8}' \
    --tasks gsm8k --batch_size auto
w/ Flash backend
limit: None, num_fewshot: None, batch_size: auto
cc @LucasWilkinson