
Pupking/05_flash_attention_lite


05_flash_attention_lite - single-head attention on Ampere

A three-step walk from a naive multi-launch attention pipeline to a fused-softmax, occupancy-tuned variant, with cuDNN's SDPA kernel as the reference. Each step composes primitives from earlier layers: WMMA + cp.async (Layer 1), warp_reduce_sum (Layer 2) and online softmax (Layer 3). The audit finding "the online softmax bug is triply present" (§L4.1.2) is fixed here: the branchless merge, with a guard for the case where both operands are −INF, lives in attention_common.h and is used by every variant's softmax pass.
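The guarded merge can be sketched on the host as follows. This is a minimal C++ sketch assuming the online-softmax state is a (running max m, running sum l) pair; SoftmaxState, merge and push are hypothetical names, not the repo's attn_merge, and a CUDA version would keep the state in per-thread registers. The guard is a select rather than a branch: when both maxima are −INF it substitutes a finite reference so no exponent becomes −INF − (−INF) = NaN.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical online-softmax state: m = running max, l = sum of exp(score - m).
struct SoftmaxState {
    float m;
    float l;
};

// Identity element: no scores seen yet (or every score so far was -INF).
inline SoftmaxState empty_state() {
    return {-std::numeric_limits<float>::infinity(), 0.0f};
}

// Merge two partial states without branching on the data.
inline SoftmaxState merge(SoftmaxState a, SoftmaxState b) {
    const float m = std::fmax(a.m, b.m);
    const bool both_empty = (m == -std::numeric_limits<float>::infinity());
    // Finite reference when both operands are -INF, so exp() sees -INF, never NaN.
    const float m_ref = both_empty ? 0.0f : m;
    // An empty operand contributes exp(-INF - m_ref) = 0 and drops out cleanly.
    const float l = a.l * std::exp(a.m - m_ref) + b.l * std::exp(b.m - m_ref);
    return {m, l};  // both operands -INF -> (-INF, 0)
}

// Fold one raw score into a state.
inline SoftmaxState push(SoftmaxState s, float x) {
    return merge(s, SoftmaxState{x, 1.0f});
}
```

Because merge is associative (up to rounding), the same function can combine per-thread partial states inside a warp reduction or fold scores one at a time.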

Results

|  | naive | fused_sdp | fused_v2 | cuDNN SDPA |
|---|---|---|---|---|
| ms @ N=1024, d=64 | 0.354 | 0.249 | 0.122 | 0.104 |
| GFLOPS (4·N²·d) | 757 | 1,078 | 2,199 | 2,581 |
| % cuDNN | 29 | 42 | 85 | 100 |

| step | what changed | counter that moved | step gain |
|---|---|---|---|
| 00_naive.cu | 7 kernel launches, N² scores materialised in GMEM | | baseline |
| 01_fused_sdp.cu | fuse scale + softmax + f32→f16 (one kernel instead of three) | 2 DRAM passes × N² dropped | 1.42× |
| 02_fused_v2.cu | block 1024→256 (6 blocks/SM), float4 loads, __expf, half2 stores | Theoretical occupancy 66.7%→100% | 2.04× |
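The derived rows of both tables follow directly from the timings. A small sketch, assuming the 4·N²·d FLOP convention stated in the table (QK^T and SV are each 2·N²·d, softmax work ignored), which at N = 1024, d = 64 is 2^28 ≈ 268 M FLOPs; attention_flops, gflops and print_results_table are hypothetical helpers.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// QK^T and SV are each an N x N x d matmul (2*N*N*d flops), so 4*N^2*d in total.
inline double attention_flops(long long N, long long d) {
    return 4.0 * double(N) * double(N) * double(d);
}

// Achieved GFLOP/s for one forward pass that took `ms` milliseconds.
inline double gflops(long long N, long long d, double ms) {
    return attention_flops(N, d) / (ms * 1e-3) / 1e9;
}

void print_results_table() {
    const long long N = 1024, d = 64;
    const char* names[] = {"naive", "fused_sdp", "fused_v2", "cuDNN SDPA"};
    const double ms[] = {0.354, 0.249, 0.122, 0.104};  // timings from the Results table
    const double ref = gflops(N, d, ms[3]);
    for (int i = 0; i < 4; ++i) {
        const double g = gflops(N, d, ms[i]);
        std::printf("%-11s %6.3f ms %7.0f GFLOPS %5.1f%% of cuDNN",
                    names[i], ms[i], g, 100.0 * g / ref);
        // Ratio to the previous column: step gain, or the remaining gap for cuDNN.
        if (i > 0) std::printf("  x%.2f vs previous", ms[i - 1] / ms[i]);
        std::printf("\n");
    }
}
```

Recomputing from the rounded millisecond values gives about 758 GFLOPS for naive and 2,200 for fused_v2, one off from the table, which is consistent with the table having been computed from unrounded timings.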

The cuDNN baseline dispatches to cudnn_generated_fort_native_sdpa_sm80_flash_fprop_wmma_f16_knob_6_128x64x64_4x1x1_cga1x1x1_kernel0_0, cuDNN's single-kernel SDPA implementation: no intermediate materialisation and no N²-sized workspace.

Experimental Setup

Hardware values as reported by cudaGetDeviceProperties / cudaDeviceGetAttribute:
  • GPU: NVIDIA GeForce RTX 3050 Laptop GPU (GA107), sm_86, 16 SMs
  • Per-SM: 65,536 registers, 1,536 threads, 100 KB shared memory, 128 KB unified L1/TEX
  • Tensor Cores: HMMA.16816; 512 dense FP16×FP16 MACs / SM / cycle with FP16 accumulation, halved to 256 with FP32 accumulation on GeForce GA10x parts
  • Toolkit / driver: CUDA 13.0.88, driver 580.82.09, compiled -O3 --gpu-architecture=sm_86
  • cuDNN: 9.x (vendored at ../cuda-kernel-portfolio/cudnn)
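  • Shape: N = 1024 (sequence length), d = 64 (head dimension). One head, no mask. 4·N²·d = 2^28 ≈ 268 M FLOPs per forward pass.
  • Workspace: 4·N² + 2·d·N + 2·N² bytes ≈ 6 MB per call. That is larger than GA107's 2 MB L2 (the FP32 score matrix alone is 4 MB), so the N² buffers round-trip through DRAM.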
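The workspace formula can be checked in a few lines. A sketch under our own reading of the three terms (an assumption based on the kernel list in the Summary): an FP32 score matrix, an FP16 transposed K, and an FP16 probability matrix. workspace_bytes and to_mib are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>

// Workspace per call, following 4*N^2 + 2*d*N + 2*N^2 bytes. Assumed meaning of terms:
//   4*N^2 : FP32 scores S_float (QK^T output)
//   2*d*N : FP16 transposed K
//   2*N^2 : FP16 probabilities S_half (input to the SV matmul)
inline std::size_t workspace_bytes(std::size_t N, std::size_t d) {
    return 4 * N * N + 2 * d * N + 2 * N * N;
}

inline double to_mib(std::size_t bytes) {
    return double(bytes) / (1024.0 * 1024.0);
}
```

At N = 1024, d = 64 this gives 6,422,528 bytes (6.125 MiB); at the N = 2048, d = 128 shape listed under Scope it grows to 24.5 MiB, the quadratic term dominating.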

Summary

Row 0 → Row 1: remove unnecessary GMEM round-trips. naive launches 7 kernels: transpose K, WMMA QK^T, scale (N²), online softmax (reads + writes N²), f32→f16 (N²), WMMA SV, f32→f16 (N·d). The scale, softmax and f32→f16 kernels walk the same N² working set three times through GMEM: 3 passes × 4 MB of FP32 scores = 12 MB of reads alone, plus 3 launch latencies. fused_sdp fuses those three passes into one kernel that reads S_float once, computes max + sum + normalise + cast, and writes S_half once, dropping 2 passes × 4 MB = 8 MB of FP32 reads and cutting the launch count from 7 to 5. Step gain 1.42×.
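The traffic bookkeeping for the three element-wise passes can be written out explicitly. A sketch that assumes each kernel reads its N² input exactly once and writes its output exactly once (a lower bound: a two-pass softmax that re-reads the row would add more); ElementwiseTraffic, naive_traffic and fused_traffic are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

struct ElementwiseTraffic {
    int launches;
    std::size_t reads;   // bytes of scores read from global memory
    std::size_t writes;  // bytes written back to global memory
};

// naive: scale (FP32 -> FP32), online softmax (FP32 -> FP32), cast (FP32 -> FP16).
inline ElementwiseTraffic naive_traffic(std::size_t N) {
    const std::size_t f32 = N * N * 4, f16 = N * N * 2;
    return {3, 3 * f32, f32 + f32 + f16};
}

// fused_sdp: one kernel, one FP32 read of S_float, one FP16 write of S_half.
inline ElementwiseTraffic fused_traffic(std::size_t N) {
    const std::size_t f32 = N * N * 4, f16 = N * N * 2;
    return {1, f32, f16};
}
```

At N = 1024 the reads go from 12 MiB to 4 MiB (the "2 DRAM passes × N² dropped" counter), and under the same assumption the writes shrink from 10 MiB to 2 MiB, since the FP32 intermediates are never stored.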

Row 1 → Row 2: unblock the scheduler. fused_sdp runs 1024 threads per block, which against sm_86's 1,536-threads-per-SM budget allows exactly 1 resident block. The scheduler sees 32 warps/SM, but many of those warps stall on the same exponential pipe during the softmax pass. fused_v2 drops to 256 threads/block (8 warps/block), which fits 6 blocks/SM = 48 warps/SM: 1.5× the 32 warps of v1 (the 66.7% → 100% theoretical occupancy in the table), and now 6 independent rows per SM, each with its own softmax state. When one row stalls on its __expf chain, the scheduler has 6× as many blocks to choose from. float4 loads cut the S_float read to 1/4 of the load instructions; __expf (fast math) trades 2-3 bits of mantissa for roughly 2× exp throughput; half2 stores likewise pack the f32→f16 output into 4-byte transactions. Step gain 2.04×.
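The occupancy figures follow from the per-SM thread budget alone. A sketch that assumes threads are the binding limit (registers and shared memory are ignored; with 1,536 resident threads the 65,536-register file allows at most 42 registers per thread) and uses sm_86's cap of 16 resident blocks per SM. thread_limited_occupancy is a hypothetical helper; on the device, cudaOccupancyMaxActiveBlocksPerMultiprocessor gives the authoritative answer including register and shared-memory limits.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// sm_86 per-SM limits.
constexpr int kMaxThreadsPerSM = 1536;
constexpr int kMaxBlocksPerSM = 16;
constexpr int kWarpSize = 32;

struct Occupancy {
    int blocks;          // resident blocks per SM
    int warps;           // resident warps per SM
    double theoretical;  // warps / max warps per SM
};

// Thread-limited occupancy only (assumes registers and smem are not the bottleneck).
inline Occupancy thread_limited_occupancy(int threads_per_block) {
    const int blocks = std::min(kMaxThreadsPerSM / threads_per_block, kMaxBlocksPerSM);
    const int warps = blocks * threads_per_block / kWarpSize;
    const int max_warps = kMaxThreadsPerSM / kWarpSize;  // 48 on sm_86
    return {blocks, warps, double(warps) / max_warps};
}
```

For 1024 threads/block this yields 1 block, 32 warps and 66.7%; for 256 threads/block, 6 blocks, 48 warps and 100%, matching the counter in the step table.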

Row 2 → cuDNN. cuDNN's flash_fprop_wmma_f16 kernel fuses the entire attention into one kernel: no N² materialisation in DRAM and no intermediate workspace. Softmax runs row by row as each QK^T tile streams through SMEM, and the SV multiply consumes the softmaxed row before it leaves SMEM. Avoiding that N² traffic is exactly what FlashAttention does; cuDNN's kernel is a polished implementation of the pattern. Our fused_v2 still carries N² scores through the workspace: 4 MB of FP32 scores per call at N = 1024, against the ≈512 KB of Q, K, V and O (4 × N·d × 2 B in FP16) that make up a fully fused kernel's minimum DRAM footprint. We attribute the remaining 1.17× gap (0.122 ms vs 0.104 ms) to that extra traffic.
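To make the contrast concrete, here is a minimal host-side C++ sketch of the FlashAttention-style pattern for a single query row: keys and values are consumed in tiles while only a running max, a running sum and an unnormalised d-vector are kept, so the row's N scores are never written out. It illustrates the general technique under our own assumptions (row-major K and V, arbitrary tile size, no masking); it is not cuDNN's kernel, which also tiles over query rows and runs the matmuls on Tensor Cores. attention_row_tiled is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// softmax(q K^T * scale) V for one query row, streaming over key/value tiles.
// K and V are row-major N x d. Scores are assumed finite (no mask).
std::vector<float> attention_row_tiled(const std::vector<float>& q,
                                       const std::vector<float>& K,
                                       const std::vector<float>& V,
                                       int N, int d, float scale, int tile) {
    float m = -std::numeric_limits<float>::infinity();  // running max
    float l = 0.0f;                                      // running sum of exp(s - m)
    std::vector<float> acc(d, 0.0f);                     // unnormalised output
    std::vector<float> s(tile);                          // scores for the current tile only
    for (int j0 = 0; j0 < N; j0 += tile) {
        const int jn = std::min(tile, N - j0);
        float tile_max = -std::numeric_limits<float>::infinity();
        for (int j = 0; j < jn; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) dot += q[k] * K[(j0 + j) * d + k];
            s[j] = dot * scale;
            tile_max = std::max(tile_max, s[j]);
        }
        const float m_new = std::max(m, tile_max);
        // Rescale earlier partial results to the new max (exp(-INF) = 0 on the first tile).
        const float correction = std::exp(m - m_new);
        l *= correction;
        for (int k = 0; k < d; ++k) acc[k] *= correction;
        for (int j = 0; j < jn; ++j) {
            const float p = std::exp(s[j] - m_new);
            l += p;
            for (int k = 0; k < d; ++k) acc[k] += p * V[(j0 + j) * d + k];
        }
        m = m_new;
    }
    for (int k = 0; k < d; ++k) acc[k] /= l;  // normalise once at the end
    return acc;
}
```

Any tile size gives the same result up to rounding; the correction factor is what lets a later tile raise the running max without revisiting earlier tiles.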

Verification

  • Cross-checked element-wise against a Kahan-summed, FP64-accumulated CPU attention reference (the audit §L2.1.2 pattern extended to attention). For d = 64, N = 1024 and inputs in [0, 0.5], even a plain FP32 forward sum of the d = 64 dot product would carry only ~d·ε ≈ 64 × 1.19e-7 ≈ 8e-6 relative error, far below the verify tolerance.
  • verify_close<half> with atol = 2e-2, rtol = 1e-2. The softmax probabilities are ~1/N ≈ 1e-3 in magnitude, and each FP16 rounding contributes O(1e-3) relative error (unit roundoff 2^-11 ≈ 4.9e-4); the tolerance accommodates one matmul's accumulation noise plus the softmax's FP16 down-cast.
  • Output poisoned with NaN before each launch (§0.1). If a kernel silently skips a write, the poisoned NaN reaches the verify step and fails it.
  • Audit §L4.1.2 branchless online softmax applied in 3 places: attention_common.h::online_softmax_rows_kernel, 01_fused_sdp.cu::fused_scale_softmax_f16_kernel, 02_fused_v2.cu::fused_scale_softmax_f16_v2_kernel. All three use attn_merge, which short-circuits the case where both operands are −INF to (−INF, 0) instead of evaluating exp(−INF − (−INF)) = exp(NaN).
  • DEDUP note: the WMMA cp.async GEMM kernel is duplicated in attention_common.h from 02_wmma_cp_async.cu. It is kept local to avoid pulling an anonymous-namespace symbol across TUs. Any correctness fix to the Layer-1 kernel must be mirrored here (audit §0.2 flag, accepted trade-off).
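The verification scheme (a compensated FP64 reference, a combined absolute/relative tolerance, NaN-poisoned outputs) can be sketched as follows. kahan_dot and this verify_close signature are assumptions for illustration; the repo's verify_close<half> works on FP16 buffers and may differ in detail.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Kahan-compensated dot product accumulated in double.
inline double kahan_dot(const float* a, const float* b, std::size_t n) {
    double sum = 0.0, c = 0.0;  // c carries the low-order bits lost by each add
    for (std::size_t i = 0; i < n; ++i) {
        const double y = double(a[i]) * double(b[i]) - c;
        const double t = sum + y;
        c = (t - sum) - y;
        sum = t;
    }
    return sum;
}

// Pass iff |got - want| <= atol + rtol * |want| for every element.
// A NaN left over from output poisoning fails, because comparisons with NaN are false.
inline bool verify_close(const std::vector<float>& got, const std::vector<double>& want,
                         double atol = 2e-2, double rtol = 1e-2) {
    if (got.size() != want.size()) return false;
    for (std::size_t i = 0; i < got.size(); ++i) {
        const double err = std::fabs(double(got[i]) - want[i]);
        if (!(err <= atol + rtol * std::fabs(want[i]))) return false;  // NaN -> false
    }
    return true;
}
```

Writing the test as !(err <= tol) rather than err > tol is what makes a poisoned NaN fail instead of slipping through.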

Reproducing

Build:

rm -rf build && mkdir build && cd build
cmake .. && cmake --build . --parallel
cd ..

Run the Layer-4 sweep:

export LD_LIBRARY_PATH=PATH_TO_CUDNN/lib:$LD_LIBRARY_PATH
./build/bin/attention_bench --N 1024 --d 64 --iters 20 --runs 5 --warmup 3

Capture profiles:

./scripts/profile_layer5.sh

Scope

  • Single head, unmasked, N = 1024, d = 64. No batch dimension, no attention mask, no rotary / ALiBi / sliding-window. All three custom variants can be extended to masking by setting masked scores to −INF; the branchless online merge already handles rows that end up fully masked.
  • FP16 only, FP32 accumulators. cuDNN SDPA supports the same dtype mix; BF16 and FP8 paths are cuDNN-only in this repo.
  • No gradient path. Only forward.
  • Larger shapes not tested. N = 2048, d = 128 should fit (the workspace formula gives ≈ 24.5 MB) but may hit workspace allocation limits; cuDNN's fused kernel scales linearly in memory with N, ours quadratically.
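A minimal sketch of the masking extension, assuming masked scores have already been overwritten with −INF. It uses a streaming per-element (max, sum) update rather than the repo's attn_merge, and returning all-zero probabilities for a fully masked row is our own choice; masked_softmax_row is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Softmax of one score row in which masked positions hold -INF.
std::vector<float> masked_softmax_row(const std::vector<float>& scores) {
    const float kNegInf = -std::numeric_limits<float>::infinity();
    float m = kNegInf, l = 0.0f;
    for (float x : scores) {
        const float m_new = std::fmax(m, x);
        // While nothing unmasked has been seen, use a finite reference so that
        // exp(-INF - (-INF)) = NaN is never evaluated.
        const float ref = (m_new == kNegInf) ? 0.0f : m_new;
        l = l * std::exp(m - ref) + std::exp(x - ref);
        m = m_new;
    }
    std::vector<float> p(scores.size(), 0.0f);  // fully masked row -> all zeros
    if (l > 0.0f) {
        for (std::size_t j = 0; j < scores.size(); ++j) p[j] = std::exp(scores[j] - m) / l;
    }
    return p;
}
```

Masked entries come out as exactly 0 because exp(−INF − m) = 0 for any finite m.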

About

Single-head CUDA attention kernel: naive SDPA --> fused softmax --> occupancy-tuned variants, benchmarked against cuDNN SDPA with Nsight Compute profiling.
