|
3 | 3 | import logging |
4 | 4 | from typing import Any |
5 | 5 |
|
| 6 | +from transformers.utils.import_utils import is_flash_attn_2_available |
| 7 | + |
6 | 8 |
|
7 | 9 | KERNELIZATION_AVAILABLE = False |
8 | 10 | try: |
|
18 | 20 | class BenchmarkConfig: |
19 | 21 |     """Configuration for a single benchmark scenario.""" |
20 | 22 |
|
| 23 | +    all_attn_implementations = [ |
| 24 | +        ("flash_attention_2", None), |
| 25 | +        ("eager", None), |
| 26 | +        ("sdpa", "math"), |
| 27 | +        ("sdpa", "flash_attention"), |
| 28 | +        ("flex_attention", None), |
| 29 | +    ] |
| 30 | + |
| 31 | +    all_compiled_modes = [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"] |
| 32 | + |
21 | 33 |     def __init__( |
22 | 34 |         self, |
23 | 35 |         warmup_iterations: int = 5, |
@@ -59,6 +71,13 @@ def __init__( |
59 | 71 |     def check_validity(self, skip_validity_check: bool = False) -> None: |
60 | 72 |         if skip_validity_check: |
61 | 73 |             return |
| 74 | +        # Check that flash attention is installed; otherwise fall back to SDPA's flash backend |
| 75 | +        if self.attn_implementation == "flash_attention_2" and not is_flash_attn_2_available(): |
| 76 | +            logger.warning( |
| 77 | +                "Flash attention 2 is not installed. Defaulting to SDPA w/ flash attention backend." |
| 78 | +            ) |
| 79 | +            self.attn_implementation = "sdpa" |
| 80 | +            self.sdpa_backend = "flash_attention" |
62 | 81 |         # Flash attention does not support compile mode, so we turn it off  # FIXME: it would be better to support it |
63 | 82 |         is_fa = self.attn_implementation == "flash_attention_2" |
64 | 83 |         is_fa |= self.attn_implementation == "sdpa" and self.sdpa_backend == "flash_attention" |
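
The hunk above makes `check_validity` degrade gracefully when flash-attn is missing instead of failing later. A minimal sketch of the resulting behaviour, assuming `attn_implementation` is accepted as a constructor argument (its definition sits outside this diff):

```python
# Hypothetical usage on a machine where flash-attn is NOT installed.
cfg = BenchmarkConfig(attn_implementation="flash_attention_2")
cfg.check_validity()

# The requested implementation is downgraded rather than raising:
assert cfg.attn_implementation == "sdpa"
assert cfg.sdpa_backend == "flash_attention"
```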
@@ -163,34 +182,6 @@ def cross_generate_configs( |
163 | 182 |     return configs |
164 | 183 |
|
165 | 184 |
|
166 | | -def generate_all_configs( |
167 | | -    warmup_iterations: int = 5, |
168 | | -    measurement_iterations: int = 20, |
169 | | -    batch_size: int = 1, |
170 | | -    sequence_length: int = 128, |
171 | | -    num_tokens_to_generate: int = 128, |
172 | | -    gpu_monitoring: bool = True, |
173 | | -) -> list[BenchmarkConfig]: |
174 | | -    all_attn_implementations = [ |
175 | | -        ("flash_attention_2", None), |
176 | | -        ("eager", None), |
177 | | -        ("sdpa", "math"), |
178 | | -        ("sdpa", "flash_attention"), |
179 | | -        ("flex_attention", None), |
180 | | -    ] |
181 | | -    return cross_generate_configs( |
182 | | -        attn_impl_and_sdpa_backend=all_attn_implementations, |
183 | | -        compiled_mode=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], |
184 | | -        kernelized=[False, KERNELIZATION_AVAILABLE], |
185 | | -        warmup_iterations=warmup_iterations, |
186 | | -        measurement_iterations=measurement_iterations, |
187 | | -        batch_size=batch_size, |
188 | | -        sequence_length=sequence_length, |
189 | | -        num_tokens_to_generate=num_tokens_to_generate, |
190 | | -        gpu_monitoring=gpu_monitoring, |
191 | | -    ) |
192 | | - |
193 | | - |
194 | 185 | def generate_main_configs( |
195 | 186 |     warmup_iterations: int = 5, |
196 | 187 |     measurement_iterations: int = 20, |
|
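With `generate_all_configs` removed, the exhaustive sweep it used to build can presumably still be reproduced by passing the new class-level lists straight to `cross_generate_configs`. A minimal sketch under that assumption, with defaults mirroring the deleted helper (the import path is hypothetical, not a confirmed API):

```python
from benchmark_config import BenchmarkConfig, cross_generate_configs, KERNELIZATION_AVAILABLE  # hypothetical import path

# Cross-product of every attention implementation, compile mode, and kernelization
# setting, matching what the removed generate_all_configs() used to return.
configs = cross_generate_configs(
    attn_impl_and_sdpa_backend=BenchmarkConfig.all_attn_implementations,
    compiled_mode=BenchmarkConfig.all_compiled_modes,
    kernelized=[False, KERNELIZATION_AVAILABLE],
    warmup_iterations=5,
    measurement_iterations=20,
    batch_size=1,
    sequence_length=128,
    num_tokens_to_generate=128,
    gpu_monitoring=True,
)
```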