
Mechanistic Interpretability Complete

Boost productivity with this guide to extracting activations, analyzing internals, and performing causal interventions on neural networks. Includes structured workflows, validation checks, and reusable patterns for AI research.

Skill · Cliptics · ai research · v1.0.0 · MIT


Overview

Mechanistic interpretability is the discipline of reverse-engineering neural networks to understand the computational algorithms they implement. Rather than treating models as black boxes, mechanistic interpretability researchers decompose model behavior into understandable circuits, features, and mechanisms. This complete guide covers the NNsight library for transparent access to neural network internals, providing a comprehensive workflow from basic activation extraction through advanced causal interventions, systematic patching experiments, and remote execution on massive models. NNsight wraps any PyTorch model and provides a deferred execution tracing system where operations are recorded as a computation graph, enabling the same experimental code to run locally on small models or remotely on 70B-405B parameter models via NDIF. This template provides the full interpretability researcher's toolkit: activation analysis, attention pattern visualization, activation patching at arbitrary granularity, cross-prompt transfer experiments, and gradient-based attribution.

When to Use

  • Full interpretability research pipeline: When you need a comprehensive toolkit covering extraction, patching, intervention, and analysis in a single framework.
  • Circuit discovery: Systematically identify which model components (layers, attention heads, MLPs) are causally responsible for specific behaviors.
  • Feature localization: Find where specific knowledge or capabilities are stored in the model's parameters and activations.
  • Intervention experiments at scale: Run patching sweeps, ablation studies, and causal scrubbing across many layers, heads, and positions.
  • Cross-architecture research: Apply the same experimental methodology across GPT-2, LLaMA, Mistral, Mamba, and custom architectures.
  • Scaling interpretability to large models: Use NDIF to run the same experiments on models too large for local hardware.

Quick Start

```shell
pip install nnsight torch transformers
# For remote execution: register at login.ndif.us
```
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("openai-community/gpt2", device_map="auto")

# Extract activations and analyze predictions
with model.trace("The capital of France is") as tracer:
    hidden_states = {}
    for i in range(12):
        hidden_states[i] = model.transformer.h[i].output[0].save()
    logits = model.output.save()

# Find where "Paris" prediction emerges
paris_id = model.tokenizer.encode(" Paris")[0]
for i in range(12):
    # Project layer output through unembedding
    layer_logits = hidden_states[i] @ model.transformer.wte.weight.T
    paris_prob = torch.softmax(layer_logits[0, -1], dim=-1)[paris_id]
    print(f"Layer {i}: P(Paris) = {paris_prob.item():.4f}")
```

Core Concepts

The Interpretability Research Loop

A typical mechanistic interpretability investigation follows this loop:

1. Observe behavior → What does the model do?
2. Hypothesize mechanism → Which components might cause it?
3. Extract activations → What are the internal representations?
4. Intervene (patch/ablate) → Does changing component X change behavior?
5. Analyze results → Confirm or refute hypothesis
6. Iterate → Refine understanding
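The extract-intervene-analyze core of this loop (steps 3-5) can be sketched framework-free on a toy network before involving nnsight or a real model. Everything here is a hypothetical stand-in: two linear layers play the role of a transformer, and patching the full hidden state from the clean run makes the corrupt run reproduce the clean output exactly; real experiments patch finer slices to localize the mechanism.

```python
import torch

torch.manual_seed(0)

# A toy two-layer network standing in for a transformer (hypothetical).
layer1 = torch.nn.Linear(8, 8)
layer2 = torch.nn.Linear(8, 4)

def run(x, patch_hidden=None):
    """Steps 3-4 of the loop: extract the hidden state, optionally intervene."""
    hidden = torch.relu(layer1(x))
    if patch_hidden is not None:
        hidden = patch_hidden  # replace the representation wholesale
    return layer2(hidden)

clean_x, corrupt_x = torch.randn(1, 8), torch.randn(1, 8)

clean_hidden = torch.relu(layer1(clean_x))  # cached from the clean run
clean_out = run(clean_x)
corrupt_out = run(corrupt_x)
patched_out = run(corrupt_x, patch_hidden=clean_hidden)

# Step 5: a full patch makes the corrupt run reproduce the clean output,
# confirming the hidden layer carries the behavior.
print(torch.allclose(patched_out, clean_out))  # prints True
```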

Comprehensive Activation Extraction

```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

def extract_all_activations(model, prompt, layers=None):
    """Extract residual stream, attention, and MLP outputs."""
    n_layers = len(model._model.transformer.h)
    layers = layers or range(n_layers)
    with model.trace(prompt) as tracer:
        results = {"residual": {}, "attn": {}, "mlp": {}}
        for i in layers:
            block = model.transformer.h[i]
            results["residual"][i] = block.output[0].save()
            results["attn"][i] = block.attn.attn_dropout.input[0][0].save()
            results["mlp"][i] = block.mlp.output.save()
        results["logits"] = model.output.save()
        results["embeddings"] = model.transformer.wte.output.save()
    return results

activations = extract_all_activations(model, "The Eiffel Tower is in")
```

Head-Level Activation Patching

```python
def patch_attention_head(model, clean_prompt, corrupt_prompt, layer, head, n_heads=12):
    """Patch a single attention head from clean to corrupt run."""
    d_head = model._model.config.n_embd // n_heads
    with model.trace(clean_prompt) as tracer:
        clean_attn_out = model.transformer.h[layer].attn.c_proj.input[0][0].save()
    with model.trace(corrupt_prompt) as tracer:
        corrupt_attn_out = model.transformer.h[layer].attn.c_proj.input[0][0]
        # Patch only the target head
        start = head * d_head
        end = (head + 1) * d_head
        corrupt_attn_out[:, :, start:end] = clean_attn_out[:, :, start:end]
        patched_logits = model.output.save()
    return patched_logits

# Sweep over all heads
n_layers, n_heads = 12, 12
head_results = torch.zeros(n_layers, n_heads)
for layer in range(n_layers):
    for head in range(n_heads):
        logits = patch_attention_head(
            model, "The Eiffel Tower is in", "The Colosseum is in", layer, head
        )
        paris_id = model.tokenizer.encode(" Paris")[0]
        probs = torch.softmax(logits[0, -1], dim=-1)
        head_results[layer, head] = probs[paris_id].item()
```
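Layer-by-head sweep results like `head_results` are easiest to read as a heatmap (see Best Practices below). A minimal plotting sketch, assuming matplotlib is installed; the demo tensor is random so the snippet runs on its own, and the function name and output path are illustrative:

```python
import torch
import matplotlib
matplotlib.use("Agg")  # headless backend for scripted runs
import matplotlib.pyplot as plt

def plot_patching_heatmap(head_results, out_path="head_patching.png"):
    """Render a [n_layers, n_heads] patching sweep as a heatmap."""
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(head_results.numpy(), aspect="auto", cmap="viridis")
    ax.set_xlabel("Head")
    ax.set_ylabel("Layer")
    ax.set_title("P(target token) after patching each head")
    fig.colorbar(im, ax=ax, label="P(target)")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return out_path

# Stand-in data so the sketch runs without the sweep itself.
demo = torch.rand(12, 12)
path = plot_patching_heatmap(demo)
print(path)
```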

Logit Lens Analysis

```python
def logit_lens(model, prompt, top_k=5):
    """Apply the logit lens: project each layer's residual stream through
    the unembedding matrix to see predictions at each layer."""
    with model.trace(prompt) as tracer:
        layer_outputs = []
        for i in range(12):
            layer_outputs.append(model.transformer.h[i].output[0].save())
        ln_f_weight = model.transformer.ln_f.weight.save()
        ln_f_bias = model.transformer.ln_f.bias.save()

    unembed = model._model.transformer.wte.weight  # [vocab, d_model]

    print(f"Prompt: {prompt}")
    print(f"{'Layer':<8} {'Top Predictions':<60}")
    print("-" * 68)
    for i, hidden in enumerate(layer_outputs):
        # Apply final layer norm then project
        normalized = torch.nn.functional.layer_norm(
            hidden[0, -1], [hidden.shape[-1]],
            weight=ln_f_weight, bias=ln_f_bias
        )
        logits = normalized @ unembed.T
        probs = torch.softmax(logits, dim=-1)
        top = probs.topk(top_k)
        tokens = [model.tokenizer.decode(t) for t in top.indices]
        probs_str = [f"{t}({p:.2f})" for t, p in zip(tokens, top.values)]
        print(f"Layer {i:<3} {', '.join(probs_str)}")

logit_lens(model, "The capital of France is")
```

Causal Scrubbing Pattern

```python
def ablation_study(model, prompt, target_token_id):
    """Zero-ablation study: zero out each layer and measure impact."""
    # Baseline (no ablation)
    with model.trace(prompt) as tracer:
        baseline_logits = model.output.save()
    baseline_prob = torch.softmax(
        baseline_logits[0, -1], dim=-1
    )[target_token_id].item()

    results = {}
    for layer in range(12):
        with model.trace(prompt) as tracer:
            # Zero out this layer's contribution
            model.transformer.h[layer].output[0][:] = 0
            ablated_logits = model.output.save()
        ablated_prob = torch.softmax(
            ablated_logits[0, -1], dim=-1
        )[target_token_id].item()
        results[layer] = {
            "original": baseline_prob,
            "ablated": ablated_prob,
            "impact": baseline_prob - ablated_prob,
        }
    return results
```

Configuration Reference

| Setting | Description | Recommendation |
| --- | --- | --- |
| `device_map` | PyTorch device placement | `"auto"` for GPU, `"cpu"` for small models |
| `remote` | Use NDIF for large models | `True` for 70B+ models |
| `torch_dtype` | Model precision | `torch.float16` for speed, `torch.float32` for gradients |
| `NDIF_API_KEY` | Remote execution key | Set via environment variable |
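Putting these settings together, a remote setup might look like the sketch below. The model name is an assumption (check which models NDIF actually hosts), and the env-var placeholder is illustrative; `remote=True` defers the trace to NDIF rather than running locally.

```python
import os
import torch
from nnsight import LanguageModel

# Assumed: key obtained from login.ndif.us and exported in the environment.
os.environ.setdefault("NDIF_API_KEY", "<your-ndif-key>")

# Hypothetical 70B model; too large for most local hardware.
model = LanguageModel(
    "meta-llama/Meta-Llama-3.1-70B",
    device_map="auto",
    torch_dtype=torch.float16,  # use float32 if you need gradients
)

with model.trace("The capital of France is", remote=True):
    logits = model.output.save()
```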

Experiment Configuration

| Parameter | Typical Range | Notes |
| --- | --- | --- |
| Layers to sweep | All (12 for GPT-2, 32 for LLaMA-8B) | Start with every 4th layer |
| Positions to patch | All or last N | Last position most important |
| Heads to analyze | All | Focus on heads with high impact |
| Number of examples | 20-100 | More for statistical significance |
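These parameters can be captured in a plain sweep configuration so experiments stay reproducible. The dictionary keys and helper below are illustrative conventions, not an nnsight API:

```python
# Hypothetical sweep configuration mirroring the table above.
sweep_config = {
    "layers": list(range(0, 12, 4)),  # start coarse: every 4th GPT-2 layer
    "positions": "last",              # or "all" for a full position sweep
    "heads": "all",
    "n_examples": 50,                 # 20-100 for aggregate statistics
}

def expand_layers(config, n_layers):
    """Refine a coarse layer sweep to every layer once a region looks interesting."""
    return {**config, "layers": list(range(n_layers))}

full = expand_layers(sweep_config, 12)
print(sweep_config["layers"], "->", full["layers"])
```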

Best Practices

  1. Start with the logit lens: Before running expensive patching experiments, use the logit lens to see at which layer the model's prediction changes. This narrows your search space dramatically.

  2. Use mean ablation instead of zero ablation: Replacing activations with zeros introduces distributional shift. Replace with the mean activation across a batch of diverse inputs for more reliable causal conclusions.

  3. Control for position effects: When patching, be aware that different prompts may tokenize differently. Align positions carefully or use position-specific patching.

  4. Run multiple examples per hypothesis: A single prompt is anecdotal. Run your experiment across 50-100 examples and report aggregate statistics.

  5. Build from small to large: Debug your full experimental pipeline on GPT-2, then scale to LLaMA-8B, then to LLaMA-70B via NDIF. This catches bugs cheaply.

  6. Save intermediate results: Large sweeps take hours. Save results to disk after each layer or batch so you can resume if interrupted.

  7. Visualize with heatmaps: Layer-by-position or layer-by-head patching results are best understood as heatmaps. Use matplotlib or seaborn for publication-quality figures.

  8. Document model-specific module paths: Keep a reference table of module paths for each architecture you work with. The path from GPT-2 to LLaMA to Mistral changes at every level.
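The mean ablation recommended in practice 2 can be sketched in plain PyTorch, independent of nnsight; the batch size and activation shapes here are illustrative stand-ins for activations cached from diverse prompts:

```python
import torch

torch.manual_seed(0)

# Activations cached from a batch of diverse prompts: [batch, seq, d_model].
batch_acts = torch.randn(32, 10, 64)

# Mean ablation value: average over batch and position, one value per unit.
mean_act = batch_acts.mean(dim=(0, 1))  # [d_model]

def mean_ablate(layer_output):
    """Replace a layer's output with the batch-mean activation,
    broadcast across batch and sequence positions."""
    return mean_act.expand_as(layer_output)

single_run = torch.randn(1, 10, 64)
ablated = mean_ablate(single_run)

# Unlike zero ablation, the replacement keeps an on-distribution scale.
print(float(ablated.norm()), float(torch.zeros_like(single_run).norm()))
```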

Troubleshooting

**Different results between local and remote execution.** Model precision may differ, and NDIF may use different quantization. Pin `torch_dtype` explicitly and verify model versions match.

**Activation patching shows no effect.** The patched component may not be causally relevant to the behavior. Try patching at different granularities (full layer, attention only, MLP only, individual heads). Verify your target metric is correct.

**Memory overflow during large sweeps.** Save results incrementally and clear GPU cache between iterations with `torch.cuda.empty_cache()`. Process one layer at a time rather than loading all activations simultaneously.

**Tokenization mismatch between prompts.** When patching between clean and corrupted prompts, verify they have the same number of tokens. Use `model.tokenizer(prompt, return_tensors="pt")` to check token counts before running experiments.
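A quick pre-flight check along these lines can save a wasted sweep. The helper and the token splits below are illustrative, not an nnsight utility; in practice the lists would come from the model's tokenizer.

```python
def check_alignment(clean_tokens, corrupt_tokens):
    """Report whether two tokenized prompts can be patched position-by-position.

    Takes lists of token strings; returns (aligned, message)."""
    if len(clean_tokens) != len(corrupt_tokens):
        return False, f"length mismatch: {len(clean_tokens)} vs {len(corrupt_tokens)}"
    diffs = [i for i, (a, b) in enumerate(zip(clean_tokens, corrupt_tokens)) if a != b]
    return True, f"aligned; positions that differ: {diffs}"

# Illustrative token splits (an actual tokenizer may split differently).
ok, msg = check_alignment(
    ["The", " E", "iff", "el", " Tower", " is", " in"],
    ["The", " Col", "osse", "um", " is", " in"],
)
print(ok, msg)  # prints: False length mismatch: 7 vs 6
```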

**Logit lens shows garbage predictions at early layers.** This is expected: early layers have not yet computed enough to produce meaningful predictions. The logit lens is most informative at middle-to-late layers.
