Mechanistic Interpretability SAELens Kit
Overview
SAELens is the primary Python library for training and analyzing Sparse Autoencoders (SAEs) for mechanistic interpretability research. SAEs address the fundamental challenge of polysemanticity in neural networks: individual neurons activate in multiple, semantically distinct contexts because models use superposition to represent more features than they have dimensions. SAELens decomposes dense model activations into sparse, interpretable features, where each feature corresponds to a human-understandable concept.

Built on Anthropic's "Towards Monosemanticity" research, SAELens provides tools for loading pre-trained SAEs, training custom SAEs on any TransformerLens-compatible model, analyzing individual features, performing feature-based model steering, and computing feature attribution for specific predictions. The library integrates with Neuronpedia for visual feature exploration and supports multiple SAE architectures, including standard ReLU, gated, and TopK variants. With over 1,100 GitHub stars, SAELens is the community standard for SAE-based interpretability research.
When to Use
- Discovering interpretable features: Decompose model activations into sparse, monosemantic features that correspond to human-understandable concepts.
- Understanding superposition: Study how models represent more features than they have dimensions, and how SAEs recover those features.
- Feature-based model steering: Use SAE decoder directions to steer model behavior toward or away from specific concepts.
- Safety-relevant feature analysis: Identify features related to deception, harmful content, bias, or other safety-relevant behaviors.
- Training custom SAEs: Train SAEs on specific layers of specific models with custom hyperparameters for your research needs.
- Feature attribution: Determine which features are most responsible for specific model predictions.
Quick Start
Installation
```bash
pip install sae-lens  # Requires: Python 3.10+, transformer-lens >= 2.0.0
```
Loading and Using a Pre-trained SAE
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Load model and pre-trained SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda",
)

# Encode model activations to SAE features
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8]     # [batch, pos, d_model]
sae_features = sae.encode(activations)  # [batch, pos, d_sae]
print(f"Active features: {(sae_features > 0).sum()}")

# Reconstruct and measure quality
reconstructed = sae.decode(sae_features)
error = (activations - reconstructed).norm()
print(f"Reconstruction error: {error.item():.3f}")
```
Core Concepts
How SAEs Work
SAEs are trained to reconstruct model activations through a sparse bottleneck:
```
Input Activation --> Encoder --> Sparse Features   --> Decoder --> Reconstructed
    (d_model)                   (d_sae >> d_model)                  (d_model)
```
Loss Function: MSE(original, reconstructed) + L1_coefficient * L1(features)
The L1 penalty enforces sparsity: only a small number of features activate for any given input, making each feature more interpretable.
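The objective above can be sketched as a toy standalone module. This is illustrative PyTorch, not SAELens's actual implementation; the `ToySAE` name, the 0.02 init scale, and the pre-decoder-bias subtraction in `encode` are assumptions made for the sketch:

```python
import torch
import torch.nn as nn

class ToySAE(nn.Module):
    """Minimal standard SAE: ReLU encoder, linear decoder (illustrative only)."""
    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_sae) * 0.02)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.02)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x):
        # ReLU keeps feature activations non-negative and sparse
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f):
        return f @ self.W_dec + self.b_dec

def sae_loss(sae, x, l1_coefficient=8e-5):
    """MSE reconstruction term plus L1 sparsity penalty on the features."""
    f = sae.encode(x)
    x_hat = sae.decode(f)
    mse = (x - x_hat).pow(2).mean()
    l1 = f.abs().sum(dim=-1).mean()
    return mse + l1_coefficient * l1

sae = ToySAE(d_model=768, d_sae=768 * 8)
x = torch.randn(32, 768)  # stand-in for a batch of model activations
loss = sae_loss(sae, x)
```

Trading off the two loss terms via `l1_coefficient` is exactly the sparsity-versus-reconstruction tension discussed throughout this document.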
Analyzing Features Per Token
```python
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda",
)

text = "The capital of France is Paris"
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])

# Find top features for each token
for pos in range(tokens.shape[1]):
    top_features = features[0, pos].topk(5)
    token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
    feature_ids = top_features.indices.tolist()
    activations = [f"{v:.2f}" for v in top_features.values.tolist()]
    print(f"Token '{token}': features {feature_ids} ({activations})")
```
Training a Custom SAE
```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

cfg = LanguageModelSAERunnerConfig(
    # Model configuration
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,
    # SAE architecture
    architecture="standard",  # "standard", "gated", or "topk"
    d_sae=768 * 8,            # Expansion factor of 8
    activation_fn="relu",
    # Training hyperparameters
    lr=4e-4,
    l1_coefficient=8e-5,      # Sparsity penalty
    l1_warm_up_steps=1000,    # Prevent early feature death
    train_batch_size_tokens=4096,
    training_tokens=100_000_000,
    # Data
    dataset_path="monology/pile-uncopyrighted",
    context_size=128,
    # Logging
    log_to_wandb=True,
    wandb_project="sae-training",
    # Checkpointing
    checkpoint_path="checkpoints",
    n_checkpoints=5,
)

trainer = SAETrainingRunner(cfg)
sae = trainer.run()

# Evaluate
print(f"L0 (avg active features): {trainer.metrics['l0']}")
print(f"CE Loss Recovered: {trainer.metrics['ce_loss_score']}")
```
Feature Steering
```python
def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
    """Add an SAE feature's decoder direction to the residual stream."""
    tokens = model.to_tokens(prompt)

    # Get the feature direction from the SAE decoder
    feature_direction = sae.W_dec[feature_idx]  # [d_model]

    def steering_hook(activation, hook):
        return activation + strength * feature_direction

    # generate() does not take fwd_hooks directly; attach them with the
    # hooks() context manager so they apply to every forward pass
    with model.hooks(fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]):
        output = model.generate(tokens, max_new_tokens=50)
    return model.to_string(output[0])

# Example: steer with a specific feature
result = steer_with_feature(model, sae, "I think the best", feature_idx=1234)
print(result)
```
Feature Attribution for Predictions
```python
import torch

text = "The capital of France is"
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)

# Get features at the final token position
features = sae.encode(cache["resid_pre", 8])[0, -1]  # [d_sae]

# Compute each feature's contribution to the "Paris" logit
W_dec = sae.W_dec  # [d_sae, d_model]
W_U = model.W_U    # [d_model, vocab]
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])

top_features = feature_contributions.topk(10)
print("Top features contributing to 'Paris' prediction:")
for idx, val in zip(top_features.indices, top_features.values):
    print(f"  Feature {idx.item()}: contribution={val.item():.3f}, "
          f"activation={features[idx].item():.3f}")
```
Using Different SAE Architectures
```python
# Standard SAE (ReLU + L1)
cfg_standard = LanguageModelSAERunnerConfig(
    architecture="standard",
    activation_fn="relu",
    l1_coefficient=8e-5,
    # ...
)

# Gated SAE (learned gating mechanism)
cfg_gated = LanguageModelSAERunnerConfig(
    architecture="gated",
    # L1 applied to the gate, not the activations
    l1_coefficient=5e-5,
    # ...
)

# TopK SAE (exactly K features active)
cfg_topk = LanguageModelSAERunnerConfig(
    architecture="topk",
    activation_fn="topk",
    activation_fn_kwargs={"k": 50},  # Exactly 50 features per input
    # No L1 needed - sparsity is structural
    # ...
)
```
Configuration Reference
Training Hyperparameters
| Parameter | Typical Value | Effect |
|---|---|---|
| `d_sae` | 4-16x `d_model` | More features, higher capacity |
| `l1_coefficient` | 5e-5 to 1e-4 | Higher = sparser, lower reconstruction quality |
| `lr` | 1e-4 to 1e-3 | Standard optimizer learning rate |
| `l1_warm_up_steps` | 500-2000 | Prevents early feature death |
| `training_tokens` | 50M-500M | More tokens = better features |
| `train_batch_size_tokens` | 2048-8192 | Larger batches = more stable training |
Evaluation Metrics
| Metric | Target | Meaning |
|---|---|---|
| L0 | 50-200 | Average active features per token |
| CE Loss Score | 80-95% | Cross-entropy loss recovered vs original model |
| Dead Features | < 5% | Features that never activate |
| Explained Variance | > 90% | Reconstruction quality |
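L0, dead features, and explained variance can all be computed directly from a batch of activations and SAE outputs. A hedged sketch (the function name and exact formulas are illustrative, not SAELens's built-in evaluation code):

```python
import torch

def sae_eval_metrics(activations, features, reconstruction):
    """Illustrative L0 / dead-feature / explained-variance computation.

    activations:    [batch, pos, d_model] original model activations
    features:       [batch, pos, d_sae] SAE feature activations
    reconstruction: [batch, pos, d_model] output of sae.decode(features)
    """
    flat_feats = features.reshape(-1, features.shape[-1])
    # L0: average number of nonzero features per token
    l0 = (flat_feats > 0).float().sum(-1).mean().item()
    # Dead features: fraction that never fire on this batch
    dead_frac = ((flat_feats > 0).sum(0) == 0).float().mean().item()
    # Explained variance: 1 - residual variance / total variance
    resid_var = (activations - reconstruction).pow(2).sum()
    total_var = (activations - activations.mean(dim=(0, 1))).pow(2).sum()
    explained = (1 - resid_var / total_var).item()
    return {"l0": l0, "dead_fraction": dead_frac, "explained_variance": explained}
```

The CE loss score is the one metric this sketch cannot compute from activations alone: it requires a second forward pass with the reconstruction spliced in at the hook point, comparing the resulting cross-entropy loss to the original model's.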
Available Pre-trained SAEs
| Release | Model | Layers |
|---|---|---|
| `gpt2-small-res-jb` | GPT-2 Small | Multiple residual streams |
| `gemma-2b-res` | Gemma 2B | Residual streams |
| Various on HuggingFace | Search tag `saelens` | Various |
SAE Architecture Comparison
| Architecture | Description | Best For |
|---|---|---|
| Standard (ReLU + L1) | Classic approach | General purpose |
| Gated | Learned gating mechanism | Better sparsity control |
| TopK | Fixed K active features | Consistent, predictable sparsity |
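The structural sparsity behind TopK can be illustrated in a few lines of PyTorch. This is a standalone sketch of the activation function, not SAELens's implementation:

```python
import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations per input; zero the rest.

    Sparsity is structural: at most k features are active, so no L1
    penalty (and no l1_coefficient tuning) is needed.
    """
    values, indices = pre_acts.topk(k, dim=-1)
    out = torch.zeros_like(pre_acts)
    # ReLU on the kept values so features stay non-negative
    out.scatter_(-1, indices, torch.relu(values))
    return out

features = topk_activation(torch.randn(4, 512), k=50)
```

Because the active-feature count is fixed by construction, L0 is constant across inputs, which is why the table above describes TopK's sparsity as "consistent, predictable".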
Best Practices
- Start with pre-trained SAEs: Before training your own, load and analyze existing SAEs to understand what good features look like. Use `SAE.from_pretrained()` with the `gpt2-small-res-jb` release as a starting point.
- Always use L1 warm-up: Set `l1_warm_up_steps` to 500-2000. Without warm-up, the L1 penalty kills features before they have a chance to specialize, leading to high dead-feature ratios.
- Enable ghost gradients for dead features: Set `use_ghost_grads=True` in the training config to revive features that have stopped activating during training.
- Monitor L0 and CE loss score together: High L0 with high CE recovery means features are informative but not sparse enough. Low L0 with low CE recovery means you are too sparse. Aim for an L0 of 50-200 with CE recovery above 85%.
- Use the TopK architecture for consistent sparsity: If you want exact control over how many features activate per input, use the TopK architecture with `activation_fn_kwargs={"k": 50}`. This eliminates the L1-coefficient tuning problem.
- Validate features with Neuronpedia: After training, upload your SAE to Neuronpedia to browse features visually and verify interpretability. Features should activate consistently for semantically coherent inputs.
- Check reconstruction quality before using features: Always verify that `sae.decode(sae.encode(activations))` closely matches the original activations. Poor reconstruction means the features are unreliable.
- Use feature attribution, not just activation: A feature that activates strongly is not necessarily important for the prediction. Compute attribution scores (activation times logit contribution) to find causally relevant features.
- Train on diverse data: Use broad training corpora like The Pile. Training on narrow domains produces features that only cover that domain, missing general model capabilities.
- Save checkpoints and log to W&B: SAE training runs can take hours to days. Save multiple checkpoints and log metrics to Weights & Biases so you can monitor progress and recover from failures.
Troubleshooting
High dead feature ratio (>10%)
Increase `l1_warm_up_steps` to 2000+. Enable `use_ghost_grads=True`. Reduce `l1_coefficient`. Dead features mean the sparsity penalty is too aggressive too early.
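To measure the dead-feature ratio on held-out data, track which features ever fire across a few batches. An illustrative diagnostic helper, not part of the SAELens API (it only assumes `sae.encode` behaves as in the examples above):

```python
import torch

@torch.no_grad()
def dead_feature_mask(sae, activation_batches, threshold: float = 0.0):
    """Boolean mask over SAE features that never fire above `threshold`
    across the given batches of [batch, pos, d_model] activations."""
    fired = None
    for acts in activation_batches:
        feats = sae.encode(acts)
        flat = feats.reshape(-1, feats.shape[-1])  # [tokens, d_sae]
        batch_fired = (flat > threshold).any(dim=0)
        fired = batch_fired if fired is None else fired | batch_fired
    return ~fired  # True where a feature is dead

# Usage sketch: dead = dead_feature_mask(sae, [activations])
# print(f"Dead feature ratio: {dead.float().mean():.1%}")
```

Run this on activations the SAE was not trained on; a feature that is silent on one narrow batch is not necessarily dead overall.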
Poor reconstruction (CE loss score < 80%)
Reduce `l1_coefficient` to allow less sparsity. Increase `d_sae` for more capacity (try 16x expansion). Train for more tokens. Check that the hook point matches your intended layer.
Features not interpretable
Increase `l1_coefficient` for more sparsity. Try the TopK architecture with a lower K (e.g., 30). Features need to be sparse to be interpretable; over-dense features blend multiple concepts.
Memory errors during training
Reduce `train_batch_size_tokens` to 2048. Lower `n_batches_in_buffer`. Reduce `store_batch_size_prompts`. Use gradient checkpointing if available.
SAE features do not improve with more training
Check the learning rate (it may be too low). Verify the hook point captures meaningful activations (earlier layers have simpler features). Ensure the training data is sufficiently diverse.