
Flash Attention Optimization

Production-ready skill for optimizing transformer attention with Flash Attention. Includes structured workflows, validation checks, and reusable patterns for AI research.


Flash Attention -- Fast Memory-Efficient Transformer Attention

Overview

A comprehensive skill for optimizing transformer attention computation using Flash Attention. Flash Attention is an IO-aware attention algorithm that reduces memory reads and writes between GPU HBM (high-bandwidth memory) and SRAM (on-chip cache) through tiling and recomputation, achieving 2-4x speedup and 5-20x memory reduction for the attention layer. It is available natively in PyTorch 2.2+ via torch.nn.functional.scaled_dot_product_attention, and through the standalone flash-attn library for advanced features like sliding window attention, multi-query attention (MQA/GQA), and FP8 support on H100 GPUs. This skill covers integration into existing models, benchmarking, and advanced configuration for production workloads.

When to Use

  • Training or running inference with transformers on sequences longer than 512 tokens
  • Encountering GPU out-of-memory errors during attention computation
  • Needing a 2-4x speedup in transformer training or inference without accuracy loss
  • Working with long-context models (4K-128K+ tokens) where attention is the bottleneck
  • Using NVIDIA GPUs (Ampere A100, Ada RTX 4090, Hopper H100 or newer)
  • Implementing multi-query or grouped-query attention (MQA/GQA)
  • Deploying on H100 GPUs where FP8 attention maximizes throughput

Quick Start

```python
import torch
import torch.nn.functional as F

# PyTorch 2.2+ automatically uses Flash Attention when available
q = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)

# This dispatches to Flash Attention on compatible hardware
out = F.scaled_dot_product_attention(q, k, v)

# With causal mask (autoregressive models)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```

flash-attn Library (Advanced Features)

```shell
pip install flash-attn --no-build-isolation
```

```python
import torch
from flash_attn import flash_attn_func

# Input shape: [batch, seqlen, nheads, headdim]
q = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```

Core Concepts

Why Flash Attention Is Faster

Standard attention computes the full NxN attention matrix, reading and writing it to GPU HBM. Flash Attention avoids this by:

Standard Attention (memory-bound):
  Q, K, V in HBM  ──read──>  Compute QK^T  ──write──>  S in HBM
  S in HBM         ──read──>  Softmax       ──write──>  P in HBM
  P in HBM         ──read──>  Compute PV    ──write──>  O in HBM
  Memory: O(N^2) for attention matrix

Flash Attention (IO-aware tiling):
  Load Q, K, V tiles into SRAM ──> Compute attention in tiles
  ──> Accumulate output in SRAM ──> Write final O to HBM
  Memory: O(N) -- no materialized attention matrix
  Fewer HBM reads/writes = faster execution
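The tiled flow above can be sketched in plain PyTorch. This is a didactic single-head CPU sketch of the online-softmax recurrence, not the fused kernel; the `tiled_attention` name and tile size are illustrative:

```python
import torch

def tiled_attention(q, k, v, tile=128):
    """Online-softmax attention over K/V tiles; never materializes the N x N matrix."""
    scale = q.size(-1) ** -0.5
    n = k.size(0)
    # Running per-query-row statistics: max, softmax normalizer, output accumulator
    m = torch.full((q.size(0),), float("-inf"))
    l = torch.zeros(q.size(0))
    o = torch.zeros_like(q)
    for start in range(0, n, tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        s = (q @ kt.T) * scale                 # scores for this tile only
        m_new = torch.maximum(m, s.max(dim=-1).values)
        p = torch.exp(s - m_new[:, None])      # tile-local softmax numerator
        correction = torch.exp(m - m_new)      # rescale previously accumulated stats
        l = l * correction + p.sum(dim=-1)
        o = o * correction[:, None] + p @ vt
        m = m_new
    return o / l[:, None]

q = torch.randn(64, 32)
k = torch.randn(256, 32)
v = torch.randn(256, 32)
ref = torch.softmax((q @ k.T) * 32 ** -0.5, dim=-1) @ v
out = tiled_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-4))
```

Only the running max, normalizer, and output live across tiles, which is why memory stays O(N) per row instead of O(N^2).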

Performance Characteristics

| Sequence Length | Speedup vs Standard | Memory Reduction | Notes |
|---|---|---|---|
| 256 | 1.0-1.2x | 2-3x | Minimal benefit |
| 512 | 1.3-1.5x | 3-5x | Noticeable improvement |
| 1024 | 1.5-2.0x | 5-10x | Clear advantage |
| 2048 | 2.0-3.0x | 10-15x | Strong advantage |
| 4096+ | 2.5-4.0x | 15-20x | Dominant advantage |
| 16K+ | 3.0-5.0x | 20x+ | Essential for long context |

Hardware Requirements

| GPU Architecture | Flash Attention Support | FP8 Support |
|---|---|---|
| Volta (V100) | No | No |
| Turing (RTX 2080) | No | No |
| Ampere (A100, RTX 3090) | Yes | No |
| Ada Lovelace (RTX 4090) | Yes | No |
| Hopper (H100, H200) | Yes | Yes (FlashAttention-3) |
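A quick runtime check against the table above (a hedged helper sketch; the `supports_flash_attention` and `supports_fp8_attention` names are ours, and the thresholds mirror the table: Ampere is compute capability 8.0, Hopper is 9.0):

```python
import torch

def supports_flash_attention(major: int, minor: int) -> bool:
    """Ampere (SM 8.0) and newer support Flash Attention."""
    return (major, minor) >= (8, 0)

def supports_fp8_attention(major: int, minor: int) -> bool:
    """FP8 attention (FlashAttention-3) requires Hopper (SM 9.0) or newer."""
    return (major, minor) >= (9, 0)

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Flash Attention: {supports_flash_attention(major, minor)}")
    print(f"FP8 attention:   {supports_fp8_attention(major, minor)}")
else:
    print("No CUDA device detected")
```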

PyTorch SDPA Integration

Replacing Standard Attention

```python
import math
import torch
import torch.nn.functional as F

# BEFORE: Standard attention (O(N^2) memory)
def standard_attention(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn_weights = torch.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)

# AFTER: Flash Attention via SDPA (O(N) memory)
def flash_attention(q, k, v, attn_mask=None, is_causal=False):
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attn_mask,
        is_causal=is_causal,
        dropout_p=0.0,
    )
```
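The drop-in replacement can be sanity-checked on CPU, where SDPA dispatches to the math backend and should agree with a manual implementation to float32 precision (a small verification sketch; shapes and tolerances are illustrative):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 128, 64)  # [batch, heads, seq, dim], float32 on CPU
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# Manual reference attention
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
ref = torch.softmax(scores, dim=-1) @ v

# SDPA (math backend on CPU; Flash Attention on supported GPUs with FP16/BF16)
out = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out, ref, atol=1e-5))
```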

Forcing Specific Backends

```python
import torch
import torch.nn.functional as F

# Check available backends
print(f"Flash Attention: {torch.backends.cuda.flash_sdp_enabled()}")
print(f"Memory Efficient: {torch.backends.cuda.mem_efficient_sdp_enabled()}")
print(f"Math: {torch.backends.cuda.math_sdp_enabled()}")

# Force Flash Attention only (PyTorch 2.3+; q, k, v as in Quick Start)
with torch.nn.attention.sdpa_kernel(
    backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Older context manager (deprecated in favor of sdpa_kernel)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
):
    out = F.scaled_dot_product_attention(q, k, v)
```

HuggingFace Transformers Integration

```python
import torch
from transformers import AutoModelForCausalLM

# Load model with Flash Attention 2 backend
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # Enable Flash Attention
    device_map="auto",
)

# Or use SDPA (default in recent versions)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # PyTorch SDPA (auto-selects best backend)
    device_map="auto",
)
```

flash-attn Library Advanced Features

Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)

```python
import torch
from flash_attn import flash_attn_func

# GQA: fewer KV heads than query heads
# q: [batch, seqlen, num_q_heads, head_dim]   e.g., 32 heads
# k: [batch, seqlen, num_kv_heads, head_dim]  e.g., 8 heads (GQA ratio 4:1)
# v: [batch, seqlen, num_kv_heads, head_dim]
q = torch.randn(2, 2048, 32, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 2048, 8, 64, device="cuda", dtype=torch.float16)

# flash_attn_func handles GQA automatically
out = flash_attn_func(q, k, v, causal=True)
# Output: [batch, seqlen, num_q_heads, head_dim]
```
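Without the flash-attn library, GQA can be emulated by expanding each KV head across its query-head group before calling SDPA (a CPU-runnable sketch; the explicit `repeat_interleave` mirrors the head sharing the fused kernel performs implicitly, at the cost of materializing the expanded KV tensors):

```python
import torch
import torch.nn.functional as F

batch, seqlen, n_q_heads, n_kv_heads, head_dim = 2, 64, 8, 2, 32
group = n_q_heads // n_kv_heads  # 4 query heads share each KV head

# SDPA layout: [batch, heads, seq, dim]
q = torch.randn(batch, n_q_heads, seqlen, head_dim)
k = torch.randn(batch, n_kv_heads, seqlen, head_dim)
v = torch.randn(batch, n_kv_heads, seqlen, head_dim)

# Expand each KV head across its query-head group
k_exp = k.repeat_interleave(group, dim=1)
v_exp = v.repeat_interleave(group, dim=1)

out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
print(out.shape)  # [batch, n_q_heads, seqlen, head_dim]
```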

Sliding Window Attention

```python
from flash_attn import flash_attn_func

# Attend only within a local window (e.g., Mistral-style)
# q, k, v as in the GQA example above
out = flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(256, 0),  # (left_window, right_window)
    # Each token attends to the 256 tokens before it (plus itself)
)
```
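PyTorch SDPA has no `window_size` argument, but the same attention pattern can be expressed as an explicit boolean mask (an illustrative sketch that reproduces the semantics while giving up Flash Attention's memory benefit, since the mask is O(N^2); the `sliding_window_causal_mask` helper is our name):

```python
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(seqlen: int, window: int) -> torch.Tensor:
    """True where attention is allowed: causal, and at most `window` tokens back."""
    pos = torch.arange(seqlen)
    rel = pos[:, None] - pos[None, :]  # query index minus key index
    return (rel >= 0) & (rel <= window)

mask = sliding_window_causal_mask(seqlen=512, window=256)
q = torch.randn(1, 4, 512, 32)
k = torch.randn(1, 4, 512, 32)
v = torch.randn(1, 4, 512, 32)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)
```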

Cross-Attention

```python
import torch
from flash_attn import flash_attn_func

# Encoder-decoder cross-attention
# q from decoder:    [batch, decoder_len, heads, dim]
# k, v from encoder: [batch, encoder_len, heads, dim]
q_dec = torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16)
k_enc = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v_enc = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q_dec, k_enc, v_enc, causal=False)
```

Variable-Length Sequences (Packed Attention)

```python
import torch
from flash_attn import flash_attn_varlen_func

# Pack multiple variable-length sequences into one tensor
# (eliminates padding overhead)
total_tokens, num_heads, head_dim = 1280, 8, 64  # e.g., sequences of 512 + 512 + 256 tokens
q_packed = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k_packed = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
v_packed = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)

# Cumulative sequence lengths: [0, len_seq1, len_seq1+len_seq2, ...]
cu_seqlens_q = torch.tensor([0, 512, 1024, 1280], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 512, 1024, 1280], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=512,
    max_seqlen_k=512,
    causal=True,
)
```
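The cumulative-length bookkeeping can be generated from a list of per-sequence lengths (a small helper sketch; `build_cu_seqlens` is our name, not part of flash-attn):

```python
import torch

def build_cu_seqlens(lengths):
    """[0, l1, l1+l2, ...] as int32, the format flash_attn_varlen_func expects."""
    cu = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    cu[1:] = torch.cumsum(torch.tensor(lengths, dtype=torch.int32), dim=0)
    return cu

lengths = [512, 512, 256]
cu_seqlens = build_cu_seqlens(lengths)
print(cu_seqlens.tolist())  # [0, 512, 1024, 1280]
print(max(lengths))         # max_seqlen for the call: 512
```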

Benchmarking

```python
import time
import torch
import torch.nn.functional as F

def benchmark_attention(seq_len, num_heads=32, head_dim=64, batch=4, iterations=100):
    q = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)

    # Warmup
    for _ in range(10):
        _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.cuda.synchronize()

    # Benchmark
    start = time.time()
    for _ in range(iterations):
        _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    torch.cuda.synchronize()

    elapsed = (time.time() - start) / iterations * 1000
    mem = torch.cuda.max_memory_allocated() / 1e9
    print(f"SeqLen={seq_len:>5}: {elapsed:.2f} ms, Peak mem: {mem:.2f} GB")
    torch.cuda.reset_peak_memory_stats()

for seq_len in [512, 1024, 2048, 4096, 8192]:
    benchmark_attention(seq_len)
```

Configuration Reference

flash_attn_func Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| q | Tensor | Required | Query tensor [batch, seqlen, heads, dim] |
| k | Tensor | Required | Key tensor [batch, seqlen, heads, dim] |
| v | Tensor | Required | Value tensor [batch, seqlen, heads, dim] |
| dropout_p | float | 0.0 | Dropout probability (0.0 for inference) |
| softmax_scale | float | None | Scale factor (default: 1/sqrt(head_dim)) |
| causal | bool | False | Apply causal mask |
| window_size | tuple | (-1, -1) | Sliding window (left, right); -1 = unlimited |
| alibi_slopes | Tensor | None | ALiBi position bias slopes |
| deterministic | bool | False | Deterministic backward pass |
| return_attn_probs | bool | False | Return attention probabilities |

PyTorch SDPA Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| query | Tensor | Required | [batch, heads, seq, dim] |
| key | Tensor | Required | [batch, heads, seq, dim] |
| value | Tensor | Required | [batch, heads, seq, dim] |
| attn_mask | Tensor | None | Additive attention mask |
| dropout_p | float | 0.0 | Dropout probability |
| is_causal | bool | False | Apply causal mask |
| scale | float | None | Scale factor |
| enable_gqa | bool | False | Enable grouped-query attention |

Best Practices

  1. Use PyTorch SDPA as the default -- F.scaled_dot_product_attention automatically selects the best backend (Flash, memory-efficient, or math) based on hardware and input characteristics.
  2. Install flash-attn only when you need advanced features -- Sliding window, variable-length packing, and FP8 require the standalone library. For basic attention, PyTorch SDPA is sufficient.
  3. Use FP16 or BF16 inputs -- Flash Attention requires half-precision inputs. FP32 inputs silently fall back to the slower math backend in SDPA.
  4. Set is_causal=True for autoregressive models -- This is more efficient than passing an explicit triangular mask, as the kernel generates the causal mask internally.
  5. Enable Flash Attention in HuggingFace models -- Set attn_implementation="flash_attention_2" when loading models for automatic integration with the flash-attn library.
  6. Use variable-length packing for batch training -- flash_attn_varlen_func eliminates padding waste when sequences have different lengths, improving training throughput by 10-30%.
  7. Benchmark at your actual sequence length -- Flash Attention benefits scale with sequence length. If your sequences are consistently short (<256 tokens), the overhead may not justify the integration.
  8. Monitor memory with torch.cuda.max_memory_allocated -- Compare peak memory before and after enabling Flash Attention to verify the expected 5-20x reduction in attention memory.
  9. Do not mix is_causal=True with attn_mask -- These are mutually exclusive in PyTorch SDPA. Using both raises an error or produces incorrect results.
  10. Keep head dimension at 64 or 128 -- Flash Attention kernels are optimized for these head dimensions. Non-standard dimensions may fall back to slower implementations.
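Practices 4 and 9 can be verified on CPU: `is_causal=True` should produce the same output as an explicit lower-triangular boolean mask, and the two must be passed one at a time, never together (a small check sketch; shapes and tolerance are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# Kernel-generated causal mask (preferred: no N x N mask tensor needed)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit boolean mask (True = attend); pass one OR the other
tri = torch.ones(128, 128, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=tri)

print(torch.allclose(out_causal, out_masked, atol=1e-5))
```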

Troubleshooting

Flash Attention not being used (falling back to math backend): Check inputs are FP16/BF16 (not FP32). Verify GPU is Ampere or newer. Ensure PyTorch >= 2.2. Use torch.backends.cuda.flash_sdp_enabled() to check availability.

ImportError when installing flash-attn: Use pip install flash-attn --no-build-isolation. Ensure CUDA toolkit matches your PyTorch CUDA version. The library requires compilation from source.

No speedup observed for short sequences: Flash Attention overhead exceeds savings for sequences below ~256 tokens. The benefit increases superlinearly with sequence length. Benchmark at your production sequence length.

Different results compared to standard attention: Flash Attention uses online softmax with FP16 accumulation, which introduces small numerical differences (~1e-3). This is expected and does not affect model quality.

Sliding window attention not working: The window_size parameter requires the flash-attn library, not PyTorch SDPA. Install with pip install flash-attn --no-build-isolation and use flash_attn_func directly.
