Mechanistic Interpretability Dynamic
Production-ready skill that provides guidance for interpreting and manipulating neural network internals. Includes structured workflows, validation checks, and reusable patterns for AI research.
Overview
NNsight is a Python library that enables researchers to interpret and manipulate the internals of any PyTorch model through a deferred execution paradigm. Its core value proposition is "write once, run anywhere": the same interpretability code works on GPT-2 running locally or Llama-3.1-405B running remotely via NDIF (National Deep Inference Fabric). NNsight uses a tracing context manager that records operations as a computation graph rather than executing them immediately, allowing efficient batching, remote execution, and complex multi-prompt interventions. The library supports activation extraction, activation patching, causal interventions, gradient-based analysis, and cross-prompt activation sharing for any PyTorch architecture including transformers, state space models, and vision models. This template covers dynamic interpretability workflows: experiments where you interactively probe, modify, and analyze neural network behavior to understand how models process information.
When to Use
- Activation extraction and analysis: Extract hidden states, attention patterns, and intermediate representations from any layer of any PyTorch model.
- Activation patching experiments: Swap activations between clean and corrupted runs to identify which components are causally responsible for model behavior.
- Remote execution on large models: Run interpretability experiments on 70B+ parameter models without local GPU access via NDIF.
- Cross-prompt interventions: Share and transplant activations between different input prompts in a single trace.
- Gradient-based feature analysis: Compute gradients of outputs with respect to intermediate activations to identify influential components.
- Architecture-agnostic research: Work with transformers, Mamba, ViT, or any custom PyTorch model using a unified API.
Quick Start
Installation
```bash
# Basic installation
pip install nnsight

# For vLLM support
pip install "nnsight[vllm]"

# For remote execution, get an API key at login.ndif.us
export NDIF_API_KEY="your_key"
```
First Trace
```python
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto")

with model.trace("The capital of France is") as tracer:
    # Extract hidden states from layer 5
    hidden = model.transformer.h[5].output[0].save()
    # Get final logits
    logits = model.output.logits.save()

# Access saved values outside the trace
print(f"Hidden shape: {hidden.shape}")
print(f"Logits shape: {logits.shape}")
```
Core Concepts
Tracing and Proxy Objects
Inside a trace context, module accesses return Proxy objects that record operations without executing them:
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

with model.trace("The Eiffel Tower is in") as tracer:
    # All operations are deferred (Proxy objects)
    h5_out = model.transformer.h[5].output[0]  # Proxy
    h5_mean = h5_out.mean(dim=-1)              # Proxy
    h5_saved = h5_mean.save()                  # Mark for retrieval

    # Modify activations in-place
    model.transformer.h[8].output[0][:] = 0    # Zero out layer 8

    # Get final logits
    logits = model.output.logits.save()

# After trace exits, access concrete values
print(h5_saved)       # Actual tensor
print(logits.shape)   # Actual shape
```
Multi-Layer Activation Analysis
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")
prompt = "The capital of France is"

with model.trace(prompt) as tracer:
    # Collect activations from all 12 layers
    layer_outputs = []
    for i in range(12):
        layer_outputs.append(model.transformer.h[i].output[0].save())

    # Collect attention patterns
    attn_patterns = []
    for i in range(12):
        attn_patterns.append(model.transformer.h[i].attn.attn_dropout.input[0][0].save())

    logits = model.output.logits.save()

# Analyze layer norms and top predictions
for i, layer_out in enumerate(layer_outputs):
    print(f"Layer {i}: norm={layer_out.norm().item():.3f}")

probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
for token, prob in zip(top_tokens.indices, top_tokens.values):
    print(f"  {model.tokenizer.decode(token)}: {prob.item():.3f}")
```
Activation Patching
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

clean_prompt = "The Eiffel Tower is in"
corrupted_prompt = "The Colosseum is in"

# Step 1: Get clean activations
with model.trace(clean_prompt) as tracer:
    clean_hidden = model.transformer.h[8].output[0].save()

# Step 2: Patch clean activations into the corrupted run
with model.trace(corrupted_prompt) as tracer:
    model.transformer.h[8].output[0][:] = clean_hidden
    patched_logits = model.output.logits.save()

# Step 3: Compare predictions
paris_token = model.tokenizer.encode(" Paris")[0]
rome_token = model.tokenizer.encode(" Rome")[0]
patched_probs = torch.softmax(patched_logits[0, -1], dim=-1)
print(f"Paris: {patched_probs[paris_token].item():.3f}")
print(f"Rome: {patched_probs[rome_token].item():.3f}")
```
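Beyond raw probabilities, patching results are often summarized as a logit difference between the target token and its competitor, which is less sensitive to the softmax over the full vocabulary. A minimal sketch of that metric (the helper name is ours, not part of NNsight):

```python
def logit_diff(last_logits, target_id, alt_id):
    """Patching metric: logit(target) - logit(alt) at the final position.

    Positive values mean the intervention pushed the model toward the
    target token (e.g. " Paris" over " Rome" in the example above).
    """
    return last_logits[target_id] - last_logits[alt_id]
```

With the variables from the patching example above, this would be called as `logit_diff(patched_logits[0, -1], paris_token, rome_token)`.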
Systematic Patching Sweep
```python
import torch

def patch_sweep(model, clean_prompt, corrupted_prompt, target_token, n_layers=12):
    """Sweep activation patching across all layers and positions."""
    # Get clean activations
    with model.trace(clean_prompt) as tracer:
        clean_cache = {}
        for i in range(n_layers):
            clean_cache[i] = model.transformer.h[i].output[0].save()
        clean_logits = model.output.logits.save()

    seq_len = clean_cache[0].shape[1]
    results = torch.zeros(n_layers, seq_len)

    for layer in range(n_layers):
        for pos in range(seq_len):
            with model.trace(corrupted_prompt) as tracer:
                current = model.transformer.h[layer].output[0]
                current[:, pos, :] = clean_cache[layer][:, pos, :]
                logits = model.output.logits.save()
            probs = torch.softmax(logits[0, -1], dim=-1)
            results[layer, pos] = probs[target_token].item()

    return results
```
Remote Execution with NDIF
```python
from nnsight import LanguageModel

# Load a 70B model (runs on NDIF servers)
model = LanguageModel("meta-llama/Llama-3.1-70B")

# Same code, just add remote=True
with model.trace("The meaning of life is", remote=True) as tracer:
    layer_40 = model.model.layers[40].output[0].save()
    logits = model.output.logits.save()

print(f"Layer 40 shape: {layer_40.shape}")
```
Cross-Prompt Activation Sharing
```python
from nnsight import LanguageModel

model = LanguageModel("gpt2", device_map="auto")

with model.trace() as tracer:
    # First prompt
    with tracer.invoke("The cat sat on the"):
        cat_hidden = model.transformer.h[6].output[0].save()

    # Second prompt with injected activations from the first
    with tracer.invoke("The dog ran through the"):
        model.transformer.h[6].output[0][:] = cat_hidden
        dog_with_cat = model.output.save()
```
Gradient-Based Analysis
```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

with model.trace("The quick brown fox") as tracer:
    hidden = model.transformer.h[5].output[0].save()
    hidden.retain_grad()

    logits = model.output.logits
    target_token = model.tokenizer.encode(" jumps")[0]
    loss = -logits[0, -1, target_token]
    loss.backward()

grad = hidden.grad
print(f"Gradient norm: {grad.norm().item():.3f}")
```
Configuration Reference
| Parameter | Description | Default |
|---|---|---|
| device_map | Device placement strategy | "auto" |
| remote | Execute on NDIF servers | False |
| timeout | Remote execution timeout (seconds) | 120 |
| NDIF_API_KEY | API key for remote execution | Set via environment variable |
Model Architecture Paths
| Model | Layer Access | Attention Access |
|---|---|---|
| GPT-2 | model.transformer.h[i].output[0] | model.transformer.h[i].attn |
| LLaMA | model.model.layers[i].output[0] | model.model.layers[i].self_attn |
| Mistral | model.model.layers[i].output[0] | model.model.layers[i].self_attn |
| GPT-NeoX | model.gpt_neox.layers[i].output[0] | model.gpt_neox.layers[i].attention |
Key API Methods
| Method | Purpose |
|---|---|
| model.trace(prompt, remote=False) | Start tracing context |
| proxy.save() | Save value for access after trace |
| proxy[:] | Slice/index proxy (assignment patches) |
| tracer.invoke(prompt) | Add prompt within multi-prompt trace |
| model.generate(...) | Generate tokens with interventions |
| model.output | Access final model output logits |
| model._model | Access underlying HuggingFace model |
Best Practices
- Always call .save() on values you need: Values inside a trace context are Proxy objects. Without .save(), they are not accessible after the context exits. This is the most common beginner mistake.
- Check model structure before writing traces: Use print(model._model) to see the actual module hierarchy. Layer access paths differ between model architectures (GPT-2 vs LLaMA vs Mistral).
- Save only what you need: Saving every layer's activations for a large model consumes significant memory. Be selective about which layers and positions you save.
- Use remote execution for large models: Do not attempt to load 70B+ models locally. Use remote=True with NDIF to run the same code on server-grade hardware.
- Start with small models for debugging: Develop and debug your experimental code on GPT-2 locally, then switch to larger models once the logic is verified.
- Use cross-prompt traces for causal experiments: Instead of running separate traces for clean and corrupted prompts, use tracer.invoke() to share activations within a single trace context.
- Add iteration limits for sweeps: Patching sweeps over all layers and positions can be computationally expensive. Start with a subset and expand once you know where to focus.
- Retain gradients explicitly: Gradient access requires calling .retain_grad() on the saved proxy. Gradients are not available for vLLM or remote execution.
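The iteration-limit advice for sweeps can be made concrete with a small grid generator that caps the number of (layer, position) runs before launching traces (a stdlib-only sketch; the helper is ours, not part of NNsight):

```python
from itertools import islice, product

def sweep_grid(n_layers, seq_len, max_runs=None):
    """Enumerate (layer, position) pairs for a patching sweep, optionally capped."""
    grid = product(range(n_layers), range(seq_len))
    return list(grid) if max_runs is None else list(islice(grid, max_runs))
```

Iterating over `sweep_grid(12, seq_len, max_runs=20)` instead of the full double loop keeps an exploratory run cheap; drop the cap once you know which layers matter.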
Troubleshooting
Proxy value is not a tensor outside trace
You forgot to call .save() on the value inside the trace context. Only saved values are materialized as actual tensors after the context exits.
Module path does not exist
Different model architectures have different module hierarchies. Use print(model._model) to inspect the actual structure. GPT-2 uses model.transformer.h[i] while LLaMA uses model.model.layers[i].
Remote execution timeout
Increase the timeout parameter: model.trace(prompt, remote=True, timeout=300). NDIF servers may have queue delays during peak usage.
Memory error with many saved activations
Reduce the number of layers saved or save only specific token positions instead of full sequences. Use hidden[:, -1, :].save() to save only the last position.
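A back-of-the-envelope estimate helps decide what to save before hitting this error. GPT-2 small has 12 layers and a 768-dimensional hidden state; fp32 tensors use 4 bytes per element (the helper itself is ours, for illustration):

```python
def saved_activation_bytes(n_layers, seq_len, d_model, batch=1, dtype_bytes=4):
    """Approximate memory for saving full hidden states at every layer."""
    return n_layers * batch * seq_len * d_model * dtype_bytes

# GPT-2 small, a 1024-token prompt, fp32: roughly 36 MiB per trace
full = saved_activation_bytes(12, 1024, 768)

# Saving only the last position cuts this by a factor of seq_len
last_only = saved_activation_bytes(12, 1, 768)
```

For a 70B-class model the same arithmetic scales with ~80 layers and an 8192-dimensional hidden state, which is why position-selective saving matters far more there.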
Gradient is None after backward
Ensure you called hidden.retain_grad() inside the trace before the backward pass. Gradient access is not supported with vLLM or remote execution.