Skip to content

Commit ff0f7d6

Browse files
authored
More data in benchmarking (#41848)
* Reduce scope of cross-generate * Rm generate_all configs * Workflow benchmarks more * Prevent crash when FA is not installed
1 parent 8030536 commit ff0f7d6

File tree

3 files changed

+30
-31
lines changed

3 files changed

+30
-31
lines changed

.github/workflows/benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
commit_id=$GITHUB_SHA
5353
fi
5454
commit_msg=$(git show -s --format=%s | cut -c1-70)
55-
python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
55+
python3 benchmark_v2/run_benchmarks.py -b 32 -s 128 -n 256 --cross-generate --branch-name "$BRANCH_NAME" --commit-id "$commit_id" --commit-message "$commit_msg" --model-id "$MODEL_ID" --log-level INFO --push-result-to-dataset "$DATASET_ID"
5656
env:
5757
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
5858
PUSH_TO_HUB_TOKEN: ${{ secrets.PUSH_TO_HUB_TOKEN }}

benchmark_v2/framework/benchmark_config.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import logging
44
from typing import Any
55

6+
from transformers.utils.import_utils import is_flash_attn_2_available
7+
68

79
KERNELIZATION_AVAILABLE = False
810
try:
@@ -18,6 +20,16 @@
1820
class BenchmarkConfig:
1921
"""Configuration for a single benchmark scenario."""
2022

23+
all_attn_implementations = [
24+
("flash_attention_2", None),
25+
("eager", None),
26+
("sdpa", "math"),
27+
("sdpa", "flash_attention"),
28+
("flex_attention", None),
29+
]
30+
31+
all_compiled_modes = [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]
32+
2133
def __init__(
2234
self,
2335
warmup_iterations: int = 5,
@@ -59,6 +71,13 @@ def __init__(
5971
def check_validity(self, skip_validity_check: bool = False) -> None:
6072
if skip_validity_check:
6173
return
74+
# Check FA is installed
75+
if self.attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
76+
logger.warning(
77+
                "Flash attention 2 is not installed. Defaulting to SDPA w/ flash attention backend."
78+
)
79+
self.attn_implementation = "sdpa"
80+
self.sdpa_backend = "flash_attention"
6281
# Flash attention does not support compile mode, so we turn it off # FIXME: it would be better to support it
6382
is_fa = self.attn_implementation == "flash_attention_2"
6483
is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention"
@@ -163,34 +182,6 @@ def cross_generate_configs(
163182
return configs
164183

165184

166-
def generate_all_configs(
167-
warmup_iterations: int = 5,
168-
measurement_iterations: int = 20,
169-
batch_size: int = 1,
170-
sequence_length: int = 128,
171-
num_tokens_to_generate: int = 128,
172-
gpu_monitoring: bool = True,
173-
) -> list[BenchmarkConfig]:
174-
all_attn_implementations = [
175-
("flash_attention_2", None),
176-
("eager", None),
177-
("sdpa", "math"),
178-
("sdpa", "flash_attention"),
179-
("flex_attention", None),
180-
]
181-
return cross_generate_configs(
182-
attn_impl_and_sdpa_backend=all_attn_implementations,
183-
compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],
184-
kernelized=[False, KERNELIZATION_AVAILABLE],
185-
warmup_iterations=warmup_iterations,
186-
measurement_iterations=measurement_iterations,
187-
batch_size=batch_size,
188-
sequence_length=sequence_length,
189-
num_tokens_to_generate=num_tokens_to_generate,
190-
gpu_monitoring=gpu_monitoring,
191-
)
192-
193-
194185
def generate_main_configs(
195186
warmup_iterations: int = 5,
196187
measurement_iterations: int = 20,

benchmark_v2/run_benchmarks.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
import sys
2424
import uuid
2525

26-
from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
26+
from framework.benchmark_config import (
27+
KERNELIZATION_AVAILABLE,
28+
BenchmarkConfig,
29+
cross_generate_configs,
30+
generate_main_configs,
31+
)
2732
from framework.benchmark_runner import BenchmarkRunner
2833

2934

@@ -82,7 +87,10 @@
8287
# If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
8388
elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
8489
if args.cross_generate:
85-
benchmark_configs = generate_all_configs(
90+
benchmark_configs = cross_generate_configs(
91+
attn_impl_and_sdpa_backend=BenchmarkConfig.all_attn_implementations,
92+
compiled_mode=[None, "default"], # usually there is not much to gain by compiling with other modes
93+
kernelized=[False, KERNELIZATION_AVAILABLE],
8694
warmup_iterations=args.warmup,
8795
measurement_iterations=args.iterations,
8896
batch_size=args.batch_size[0],

0 commit comments

Comments (0)