Flash Attention -- Fast Memory-Efficient Transformer Attention
Overview
A comprehensive skill for optimizing transformer attention computation using Flash Attention. Flash Attention is an IO-aware attention algorithm that reduces memory reads and writes between GPU HBM (high-bandwidth memory) and SRAM (on-chip cache) through tiling and recomputation, achieving 2-4x speedup and 5-20x memory reduction for the attention layer. Available natively in PyTorch 2.2+ via torch.nn.functional.scaled_dot_product_attention and through the standalone flash-attn library for advanced features like sliding window attention, multi-query attention (MQA/GQA), and FP8 support on H100 GPUs. This skill covers integration into existing models, benchmarking, and advanced configuration for production workloads.
When to Use
- Training or running inference with transformers on sequences longer than 512 tokens
- Encountering GPU out-of-memory errors during attention computation
- Need 2-4x speedup in transformer training or inference without accuracy loss
- Working with long-context models (4K-128K+ tokens) where attention is the bottleneck
- Using NVIDIA GPUs (Ampere A100, Ada RTX 4090, Hopper H100 or newer)
- Implementing multi-query or grouped-query attention (MQA/GQA)
- Deploying on H100 GPUs and want FP8 attention for maximum throughput
Quick Start
PyTorch Native (Recommended -- Zero Dependencies)
```python
import torch
import torch.nn.functional as F

# PyTorch 2.2+ automatically uses Flash Attention when available
q = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)

# This dispatches to Flash Attention on compatible hardware
out = F.scaled_dot_product_attention(q, k, v)

# With causal mask (autoregressive models)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```
flash-attn Library (Advanced Features)
```bash
pip install flash-attn --no-build-isolation
```
```python
import torch
from flash_attn import flash_attn_func

# Input shape: [batch, seqlen, nheads, headdim]
q = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```
Core Concepts
Why Flash Attention Is Faster
Standard attention computes the full NxN attention matrix, reading and writing it to GPU HBM. Flash Attention avoids this by:
```
Standard Attention (memory-bound):
  Q, K, V in HBM --read--> Compute QK^T --write--> S in HBM
  S in HBM       --read--> Softmax      --write--> P in HBM
  P in HBM       --read--> Compute PV   --write--> O in HBM
  Memory: O(N^2) for the attention matrix

Flash Attention (IO-aware tiling):
  Load Q, K, V tiles into SRAM --> Compute attention in tiles
  --> Accumulate output in SRAM --> Write final O to HBM
  Memory: O(N) -- no materialized attention matrix
  Fewer HBM reads/writes = faster execution
```
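The tiling idea can be sketched in plain PyTorch. This is an illustrative CPU/FP32 reimplementation of the online-softmax accumulation, not the real kernel (which runs tile-by-tile in SRAM); it shows how the full N x N score matrix is never materialized, yet the result matches standard attention:

```python
import torch

def tiled_attention(q, k, v, tile=64):
    # Online softmax over K/V tiles: keep a running row max and row sum,
    # rescaling previous accumulators whenever a new tile raises the max.
    scale = q.size(-1) ** -0.5
    n = k.size(0)
    out = torch.zeros_like(q)
    row_max = torch.full((q.size(0), 1), float("-inf"))
    row_sum = torch.zeros(q.size(0), 1)
    for start in range(0, n, tile):
        kt = k[start:start + tile]
        vt = v[start:start + tile]
        s = (q @ kt.T) * scale                              # scores for this tile only
        new_max = torch.maximum(row_max, s.max(-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)           # rescale old accumulators
        p = torch.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(-1, keepdim=True)
        out = out * correction + p @ vt
        row_max = new_max
    return out / row_sum

q = torch.randn(128, 32)
k = torch.randn(128, 32)
v = torch.randn(128, 32)
ref = torch.softmax((q @ k.T) * 32 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-5)
```

The peak intermediate here is one `tile x tile`-wide slab of scores per step, which is the O(N) memory claim above.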
Performance Characteristics
| Sequence Length | Speedup vs Standard | Memory Reduction | Notes |
|---|---|---|---|
| 256 | 1.0-1.2x | 2-3x | Minimal benefit |
| 512 | 1.3-1.5x | 3-5x | Noticeable improvement |
| 1024 | 1.5-2.0x | 5-10x | Clear advantage |
| 2048 | 2.0-3.0x | 10-15x | Strong advantage |
| 4096+ | 2.5-4.0x | 15-20x | Dominant advantage |
| 16K+ | 3.0-5.0x | 20x+ | Essential for long context |
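To see where the memory reductions in the table come from, a back-of-envelope helper (illustrative, not part of any library) for the bytes standard attention spends materializing the score matrix alone:

```python
def attn_matrix_bytes(batch, heads, seq_len, dtype_bytes=2):
    # Memory to materialize the full seq_len x seq_len attention
    # matrix once, in standard attention (dtype_bytes=2 for FP16).
    return batch * heads * seq_len * seq_len * dtype_bytes

# batch=4, 32 heads, 4096 tokens, FP16:
gb = attn_matrix_bytes(4, 32, 4096) / 1e9
print(f"{gb:.2f} GB")  # ~4.29 GB for the scores alone, before the softmax output P
```

Flash Attention never allocates this tensor, which is why the reduction grows quadratically with sequence length.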
Hardware Requirements
| GPU Architecture | Flash Attention Support | FP8 Support |
|---|---|---|
| Volta (V100) | No | No |
| Turing (RTX 2080) | No | No |
| Ampere (A100, RTX 3090) | Yes | No |
| Ada Lovelace (RTX 4090) | Yes | No |
| Hopper (H100, H200) | Yes | Yes (FlashAttention-3) |
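A small runtime check matching the table above. The helper names are hypothetical; the thresholds encode the table's rule of thumb (compute capability 8.x for Flash Attention, 9.x for FlashAttention-3 FP8):

```python
import torch

def flash_supported(major: int, minor: int = 0) -> bool:
    # Flash Attention needs compute capability >= 8.0 (Ampere or newer).
    return major >= 8

def fp8_flash_supported(major: int, minor: int = 0) -> bool:
    # FP8 paths (FlashAttention-3) target Hopper, compute capability 9.x.
    return major >= 9

if torch.cuda.is_available():
    cc = torch.cuda.get_device_capability()  # e.g. (8, 0) on A100
    print(f"Flash Attention: {flash_supported(*cc)}, FP8: {fp8_flash_supported(*cc)}")
```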
PyTorch SDPA Integration
Replacing Standard Attention
```python
import math
import torch
import torch.nn.functional as F

# BEFORE: Standard attention (O(N^2) memory)
def standard_attention(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn_weights = torch.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)

# AFTER: Flash Attention via SDPA (O(N) memory)
def flash_attention(q, k, v, attn_mask=None, is_causal=False):
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attn_mask,
        is_causal=is_causal,
        dropout_p=0.0,
    )
```
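The replacement is numerically a drop-in. A quick CPU sanity check (FP32, where SDPA falls back to the math backend, so results agree almost exactly):

```python
import math
import torch
import torch.nn.functional as F

def standard_attention(q, k, v):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    return torch.matmul(torch.softmax(scores, dim=-1), v)

q = torch.randn(2, 4, 128, 32)  # [batch, heads, seq, dim]
k = torch.randn(2, 4, 128, 32)
v = torch.randn(2, 4, 128, 32)

assert torch.allclose(
    standard_attention(q, k, v),
    F.scaled_dot_product_attention(q, k, v),
    atol=1e-5,
)
```

On CUDA with FP16 inputs, expect small differences (~1e-3) from the online softmax, as noted in Troubleshooting below.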
Forcing Specific Backends
```python
import torch
import torch.nn.functional as F

# Check available backends
print(f"Flash Attention: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory Efficient: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math: {torch.backends.cuda.math_sdp_enabled()}")

# Force Flash Attention only (PyTorch 2.3+)
with torch.nn.attention.sdpa_kernel(
    backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Legacy context manager (deprecated in favor of sdpa_kernel above)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    out = F.scaled_dot_product_attention(q, k, v)
```
HuggingFace Transformers Integration
```python
import torch
from transformers import AutoModelForCausalLM

# Load model with Flash Attention 2 backend
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # Enable Flash Attention
    device_map="auto",
)

# Or use SDPA (default in recent versions)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # PyTorch SDPA (auto-selects best backend)
    device_map="auto",
)
```
flash-attn Library Advanced Features
Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
```python
import torch
from flash_attn import flash_attn_func

# GQA: fewer KV heads than query heads
# q: [batch, seqlen, num_q_heads, head_dim]   e.g., 32 heads
# k: [batch, seqlen, num_kv_heads, head_dim]  e.g., 8 heads (GQA ratio 4:1)
# v: [batch, seqlen, num_kv_heads, head_dim]
q = torch.randn(2, 2048, 32, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)

# flash_attn_func handles GQA automatically
out = flash_attn_func(q, k, v, causal=True)
# Output: [batch, seqlen, num_q_heads, head_dim]
```
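The reference semantics of GQA, which the fused kernel implements implicitly, is simply repeating each KV head across its query-head group. A CPU sketch using plain SDPA (illustrative shapes, no GPU required):

```python
import torch
import torch.nn.functional as F

# 8 query heads share 2 KV heads (GQA ratio 4:1).
# SDPA layout is [batch, heads, seq, dim].
q = torch.randn(1, 8, 64, 32)
k = torch.randn(1, 2, 64, 32)
v = torch.randn(1, 2, 64, 32)

group = q.size(1) // k.size(1)  # 4 query heads per KV head
out = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(group, dim=1),  # expand KV heads to match query heads
    v.repeat_interleave(group, dim=1),
)
assert out.shape == (1, 8, 64, 32)
```

Recent PyTorch versions can also do this without the explicit repeat via the `enable_gqa` flag listed in the parameter table below; the repeat form is the portable fallback.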
Sliding Window Attention
```python
import torch
from flash_attn import flash_attn_func

# Attend only within a local window (e.g., Mistral-style)
out = flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(256, 0),  # (left_window, right_window)
    # Each token attends to 256 tokens before it (plus itself)
)
```
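PyTorch SDPA has no `window_size` parameter, but the same attention pattern can be expressed with an explicit boolean mask. This CPU sketch is for understanding the semantics only; it materializes the O(N^2) mask that the flash-attn kernel avoids:

```python
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seq_len, window):
    # True where attention is allowed: causal AND at most `window` tokens back.
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    return (j <= i) & (i - j <= window)

mask = sliding_window_causal_mask(8, window=2)
# Row 5 attends to positions 3, 4, 5 only.
assert mask[5].tolist() == [False, False, False, True, True, True, False, False]

q = torch.randn(1, 4, 8, 16)
out = F.scaled_dot_product_attention(q, q, q, attn_mask=mask)  # bool mask broadcasts
assert out.shape == (1, 4, 8, 16)
```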
Cross-Attention
```python
import torch
from flash_attn import flash_attn_func

# Encoder-decoder cross-attention
# q from decoder: [batch, decoder_len, heads, dim]
# k, v from encoder: [batch, encoder_len, heads, dim]
q_dec = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
k_enc = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v_enc = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q_dec, k_enc, v_enc, causal=False)
```
Variable-Length Sequences (Packed Attention)
```python
import torch
from flash_attn import flash_attn_varlen_func

# Pack multiple variable-length sequences into one tensor
# (eliminates padding overhead)
total_tokens = 1280  # 512 + 512 + 256
num_heads, head_dim = 8, 64
q_packed = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k_packed = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
v_packed = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)

# Cumulative sequence lengths: [0, len_seq1, len_seq1+len_seq2, ...]
cu_seqlens_q = torch.tensor([0, 512, 1024, 1280], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 512, 1024, 1280], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=512, max_seqlen_k=512,
    causal=True,
)
```
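Building the `cu_seqlens` tensors from per-sequence lengths is a common source of off-by-one errors. A small helper (hypothetical name, illustrative only):

```python
import torch

def build_cu_seqlens(seq_lens):
    # Cumulative sequence lengths in the format flash_attn_varlen_func
    # expects: [0, len_1, len_1+len_2, ...], int32, on the data's device.
    cu = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu[1:] = torch.cumsum(torch.tensor(seq_lens, dtype=torch.int32), dim=0)
    return cu

seq_lens = [512, 512, 256]
cu_seqlens = build_cu_seqlens(seq_lens)
assert cu_seqlens.tolist() == [0, 512, 1024, 1280]
max_seqlen = max(seq_lens)  # passed as max_seqlen_q / max_seqlen_k
```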
Benchmarking
```python
import time

import torch
import torch.nn.functional as F

def benchmark_attention(seq_len, num_heads=32, head_dim=64, batch=4, iterations=100):
    q = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

    # Reset stats first so the peak reflects only this configuration
    torch.cuda.reset_peak_memory_stats()

    # Warmup
    for _ in range(10):
        _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.cuda.synchronize()

    # Benchmark
    start = time.time()
    for _ in range(iterations):
        _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.cuda.synchronize()

    elapsed = (time.time() - start) / iterations * 1000
    mem = torch.cuda.max_memory_allocated() / 1e9
    print(f"SeqLen={seq_len:>5}: {elapsed:.2f} ms, Peak mem: {mem:.2f} GB")

for seq_len in [512, 1024, 2048, 4096, 8192]:
    benchmark_attention(seq_len)
```
Configuration Reference
flash_attn_func Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| q | Tensor | Required | Query tensor [batch, seqlen, heads, dim] |
| k | Tensor | Required | Key tensor [batch, seqlen, heads, dim] |
| v | Tensor | Required | Value tensor [batch, seqlen, heads, dim] |
| dropout_p | float | 0.0 | Dropout probability (0.0 for inference) |
| softmax_scale | float | None | Scale factor (default: 1/sqrt(head_dim)) |
| causal | bool | False | Apply causal mask |
| window_size | tuple | (-1, -1) | Sliding window (left, right); -1 = unlimited |
| alibi_slopes | Tensor | None | ALiBi position bias slopes |
| deterministic | bool | False | Deterministic backward pass |
| return_attn_probs | bool | False | Return attention probabilities |
PyTorch SDPA Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| query | Tensor | Required | [batch, heads, seq, dim] |
| key | Tensor | Required | [batch, heads, seq, dim] |
| value | Tensor | Required | [batch, heads, seq, dim] |
| attn_mask | Tensor | None | Additive attention mask |
| dropout_p | float | 0.0 | Dropout probability |
| is_causal | bool | False | Apply causal mask |
| scale | float | None | Scale factor |
| enable_gqa | bool | False | Enable grouped-query attention |
Best Practices
- Use PyTorch SDPA as the default -- `F.scaled_dot_product_attention` automatically selects the best backend (Flash, memory-efficient, or math) based on hardware and input characteristics.
- Install `flash-attn` only when you need advanced features -- sliding window, variable-length packing, and FP8 require the standalone library. For basic attention, PyTorch SDPA is sufficient.
- Use FP16 or BF16 inputs -- Flash Attention requires half-precision inputs. FP32 inputs silently fall back to the slower math backend in SDPA.
- Set `is_causal=True` for autoregressive models -- this is more efficient than passing an explicit triangular mask, as the kernel generates the causal mask internally.
- Enable Flash Attention in HuggingFace models -- set `attn_implementation="flash_attention_2"` when loading models for automatic integration with the flash-attn library.
- Use variable-length packing for batched training -- `flash_attn_varlen_func` eliminates padding waste when sequences have different lengths, improving training throughput by 10-30%.
- Benchmark at your actual sequence length -- Flash Attention benefits scale with sequence length. If your sequences are consistently short (<256 tokens), the overhead may not justify the integration.
- Monitor memory with `torch.cuda.max_memory_allocated` -- compare peak memory before and after enabling Flash Attention to verify the expected 5-20x reduction in attention memory.
- Do not mix `is_causal=True` with `attn_mask` -- these are mutually exclusive in PyTorch SDPA. Using both raises an error or produces incorrect results.
- Keep head dimension at 64 or 128 -- Flash Attention kernels are optimized for these head dimensions. Non-standard dimensions may fall back to slower implementations.
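The causal-masking practice above can be verified on CPU: `is_causal=True` produces the same result as an explicit lower-triangular boolean mask, it just lets the kernel skip masked tiles instead of computing and discarding them:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 16, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)

# Equivalent semantics, different cost: the explicit mask is O(N^2) memory.
causal_mask = torch.ones(16, 16, dtype=torch.bool).tril()
a = F.scaled_dot_product_attention(q, k, v, is_causal=True)
b = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
assert torch.allclose(a, b, atol=1e-6)
```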
Troubleshooting
Flash Attention not being used (falling back to math backend):
Check inputs are FP16/BF16 (not FP32). Verify GPU is Ampere or newer. Ensure PyTorch >= 2.2. Use torch.backends.cuda.flash_sdp_enabled() to check availability.
ImportError when installing flash-attn:
Use pip install flash-attn --no-build-isolation. Ensure CUDA toolkit matches your PyTorch CUDA version. The library requires compilation from source.
No speedup observed for short sequences: Flash Attention overhead exceeds savings for sequences below ~256 tokens. The benefit increases superlinearly with sequence length. Benchmark at your production sequence length.
Different results compared to standard attention: Flash Attention uses online softmax with FP16 accumulation, which introduces small numerical differences (~1e-3). This is expected and does not affect model quality.
Sliding window attention not working:
The window_size parameter requires the flash-attn library, not PyTorch SDPA. Install with pip install flash-attn --no-build-isolation and use flash_attn_func directly.