
Mechanistic Interpretability Dynamic

Production-ready skill that provides guidance for interpreting and manipulating model internals. Includes structured workflows, validation checks, and reusable patterns for AI research.

Skill · Cliptics · ai research · v1.0.0 · MIT


Overview

NNsight is a Python library that enables researchers to interpret and manipulate the internals of any PyTorch model through a deferred execution paradigm. Its core value proposition is "write once, run anywhere": the same interpretability code works on GPT-2 running locally or Llama-3.1-405B running remotely via NDIF (National Deep Inference Fabric). NNsight uses a tracing context manager that records operations as a computation graph rather than executing them immediately, allowing efficient batching, remote execution, and complex multi-prompt interventions. The library supports activation extraction, activation patching, causal interventions, gradient-based analysis, and cross-prompt activation sharing for any PyTorch architecture including transformers, state space models, and vision models. This template covers dynamic interpretability workflows: experiments where you interactively probe, modify, and analyze neural network behavior to understand how models process information.

When to Use

  • Activation extraction and analysis: Extract hidden states, attention patterns, and intermediate representations from any layer of any PyTorch model.
  • Activation patching experiments: Swap activations between clean and corrupted runs to identify which components are causally responsible for model behavior.
  • Remote execution on large models: Run interpretability experiments on 70B+ parameter models without local GPU access via NDIF.
  • Cross-prompt interventions: Share and transplant activations between different input prompts in a single trace.
  • Gradient-based feature analysis: Compute gradients of outputs with respect to intermediate activations to identify influential components.
  • Architecture-agnostic research: Work with transformers, Mamba, ViT, or any custom PyTorch model using a unified API.

Quick Start

Installation

```bash
# Basic installation
pip install nnsight

# For vLLM support
pip install "nnsight[vllm]"

# For remote execution, get API key at login.ndif.us
export NDIF_API_KEY="your_key"
```

First Trace

```python
from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto")

with model.trace("The capital of France is") as tracer:
    # Extract hidden states from layer 5
    hidden = model.transformer.h[5].output[0].save()
    # Get final logits
    logits = model.output.save()

# Access saved values outside the trace
print(f"Hidden shape: {hidden.shape}")
print(f"Logits shape: {logits.shape}")
```

Core Concepts

Tracing and Proxy Objects

Inside a trace context, module accesses return Proxy objects that record operations without executing them:

```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

with model.trace("The Eiffel Tower is in") as tracer:
    # All operations are deferred (Proxy objects)
    h5_out = model.transformer.h[5].output[0]  # Proxy
    h5_mean = h5_out.mean(dim=-1)              # Proxy
    h5_saved = h5_mean.save()                  # Mark for retrieval

    # Modify activations in-place
    model.transformer.h[8].output[0][:] = 0    # Zero out layer 8

    # Get final output
    logits = model.output.save()

# After trace exits, access concrete values
print(h5_saved)      # Actual tensor
print(logits.shape)  # Actual shape
```
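The deferred-execution idea can be illustrated without NNsight at all. The following toy class (a hypothetical sketch, not NNsight's actual `Proxy` implementation) records operations as data instead of executing them, then replays the recorded graph on a concrete value later:

```python
# Toy sketch of deferred execution: each method call records an operation
# instead of running it; run() replays the recorded graph on real data.
# This is illustrative only and does not mirror NNsight's internal design.

class ToyProxy:
    def __init__(self, ops=None):
        self.ops = ops or []  # recorded operations, oldest first

    def mean(self):
        # Record the operation; nothing is computed here
        return ToyProxy(self.ops + [("mean", ())])

    def add(self, x):
        return ToyProxy(self.ops + [("add", (x,))])

    def run(self, value):
        # Replay the recorded operations on a concrete value
        for name, args in self.ops:
            if name == "mean":
                value = sum(value) / len(value)
            elif name == "add":
                value = value + args[0]
        return value

proxy = ToyProxy().mean().add(10)   # nothing computed yet
print(proxy.run([1.0, 2.0, 3.0]))   # replay on concrete data -> 12.0
```

This is why saved proxies only hold real tensors after the trace context exits: the recorded graph is executed as a whole, which is what enables batching and remote execution.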

Multi-Layer Activation Analysis

```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")
prompt = "The capital of France is"

with model.trace(prompt) as tracer:
    # Collect activations from all 12 layers
    layer_outputs = []
    for i in range(12):
        layer_out = model.transformer.h[i].output[0].save()
        layer_outputs.append(layer_out)

    # Collect attention patterns
    attn_patterns = []
    for i in range(12):
        attn = model.transformer.h[i].attn.attn_dropout.input[0][0].save()
        attn_patterns.append(attn)

    logits = model.output.save()

# Analyze layer norms and top predictions
for i, layer_out in enumerate(layer_outputs):
    print(f"Layer {i}: norm={layer_out.norm().item():.3f}")

probs = torch.softmax(logits[0, -1], dim=-1)
top_tokens = probs.topk(5)
for token, prob in zip(top_tokens.indices, top_tokens.values):
    print(f"  {model.tokenizer.decode(token)}: {prob.item():.3f}")
```

Activation Patching

```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

clean_prompt = "The Eiffel Tower is in"
corrupted_prompt = "The Colosseum is in"

# Step 1: Get clean activations
with model.trace(clean_prompt) as tracer:
    clean_hidden = model.transformer.h[8].output[0].save()

# Step 2: Patch clean activations into corrupted run
with model.trace(corrupted_prompt) as tracer:
    model.transformer.h[8].output[0][:] = clean_hidden
    patched_logits = model.output.save()

# Step 3: Compare predictions
paris_token = model.tokenizer.encode(" Paris")[0]
rome_token = model.tokenizer.encode(" Rome")[0]
patched_probs = torch.softmax(patched_logits[0, -1], dim=-1)
print(f"Paris: {patched_probs[paris_token].item():.3f}")
print(f"Rome: {patched_probs[rome_token].item():.3f}")
```

Systematic Patching Sweep

```python
import torch


def patch_sweep(model, clean_prompt, corrupted_prompt, target_token, n_layers=12):
    """Sweep activation patching across all layers and positions."""
    # Get clean activations
    with model.trace(clean_prompt) as tracer:
        clean_cache = {}
        for i in range(n_layers):
            clean_cache[i] = model.transformer.h[i].output[0].save()
        clean_logits = model.output.save()

    seq_len = clean_cache[0].shape[1]
    results = torch.zeros(n_layers, seq_len)

    # Patch one (layer, position) at a time into the corrupted run
    for layer in range(n_layers):
        for pos in range(seq_len):
            with model.trace(corrupted_prompt) as tracer:
                current = model.transformer.h[layer].output[0]
                current[:, pos, :] = clean_cache[layer][:, pos, :]
                logits = model.output.save()
            probs = torch.softmax(logits[0, -1], dim=-1)
            results[layer, pos] = probs[target_token].item()

    return results
```
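A common follow-up step (not shown in the source) is to normalize the raw probabilities from a sweep against the clean and corrupted baselines, so that 1.0 means the patch fully restored clean behavior and 0.0 means it had no effect. A minimal pure-Python sketch, with the function name and example numbers as illustrative assumptions:

```python
def patching_effect(patched, corrupted_baseline, clean_baseline):
    """Normalize a patched target-token probability into a recovery score:
    0.0 = patch had no effect, 1.0 = patch fully restored clean behavior."""
    denom = clean_baseline - corrupted_baseline
    if denom == 0:
        return 0.0  # degenerate case: clean and corrupted agree
    return (patched - corrupted_baseline) / denom

# Hypothetical numbers: clean P(" Paris") = 0.60, corrupted = 0.05, patched = 0.38
score = patching_effect(0.38, 0.05, 0.60)
print(round(score, 3))  # -> 0.6
```

Normalizing this way makes sweep heatmaps comparable across prompts whose baseline probabilities differ.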

Remote Execution with NDIF

```python
from nnsight import LanguageModel

# Load a 70B model (runs on NDIF servers)
model = LanguageModel("meta-llama/Llama-3.1-70B")

# Same code, just add remote=True
with model.trace("The meaning of life is", remote=True) as tracer:
    layer_40 = model.model.layers[40].output[0].save()
    logits = model.output.save()

print(f"Layer 40 shape: {layer_40.shape}")
```

Cross-Prompt Activation Sharing

```python
from nnsight import LanguageModel

model = LanguageModel("gpt2", device_map="auto")

with model.trace() as tracer:
    # First prompt
    with tracer.invoke("The cat sat on the"):
        cat_hidden = model.transformer.h[6].output[0].save()

    # Second prompt with injected activations from the first
    with tracer.invoke("The dog ran through the"):
        model.transformer.h[6].output[0][:] = cat_hidden
        dog_with_cat = model.output.save()
```

Gradient-Based Analysis

```python
from nnsight import LanguageModel
import torch

model = LanguageModel("gpt2", device_map="auto")

with model.trace("The quick brown fox") as tracer:
    hidden = model.transformer.h[5].output[0].save()
    hidden.retain_grad()

    logits = model.output
    target_token = model.tokenizer.encode(" jumps")[0]
    loss = -logits[0, -1, target_token]
    loss.backward()

grad = hidden.grad
print(f"Gradient norm: {grad.norm().item():.3f}")
```

Configuration Reference

| Parameter | Description | Default |
| --- | --- | --- |
| `device_map` | Device placement strategy | `"auto"` |
| `remote` | Execute on NDIF servers | `False` |
| `timeout` | Remote execution timeout (seconds) | `120` |
| `NDIF_API_KEY` | API key for remote execution | Environment variable |

Model Architecture Paths

| Model | Layer Access | Attention Access |
| --- | --- | --- |
| GPT-2 | `model.transformer.h[i].output[0]` | `model.transformer.h[i].attn` |
| LLaMA | `model.model.layers[i].output[0]` | `model.model.layers[i].self_attn` |
| Mistral | `model.model.layers[i].output[0]` | `model.model.layers[i].self_attn` |
| GPT-NeoX | `model.gpt_neox.layers[i].output[0]` | `model.gpt_neox.layers[i].attention` |

Key API Methods

| Method | Purpose |
| --- | --- |
| `model.trace(prompt, remote=False)` | Start tracing context |
| `proxy.save()` | Save value for access after trace |
| `proxy[:]` | Slice/index proxy (assignment patches) |
| `tracer.invoke(prompt)` | Add prompt within multi-prompt trace |
| `model.generate(...)` | Generate tokens with interventions |
| `model.output` | Access final model output logits |
| `model._model` | Access underlying HuggingFace model |

Best Practices

  1. Always call .save() on values you need: Values inside a trace context are Proxy objects. Without .save(), they are not accessible after the context exits. This is the most common beginner mistake.

  2. Check model structure before writing traces: Use print(model._model) to see the actual module hierarchy. Layer access paths differ between model architectures (GPT-2 vs LLaMA vs Mistral).

  3. Save only what you need: Saving every layer's activations for a large model consumes significant memory. Be selective about which layers and positions you save.

  4. Use remote execution for large models: Do not attempt to load 70B+ models locally. Use remote=True with NDIF to run the same code on server-grade hardware.

  5. Start with small models for debugging: Develop and debug your experimental code on GPT-2 locally, then switch to larger models once the logic is verified.

  6. Use cross-prompt traces for causal experiments: Instead of running separate traces for clean and corrupted prompts, use tracer.invoke() to share activations within a single trace context.

  7. Add iteration limits for sweeps: Patching sweeps over all layers and positions can be computationally expensive. Start with a subset and expand once you know where to focus.

  8. Retain gradients explicitly: Gradient access requires calling .retain_grad() on the saved proxy. Gradients are not available for vLLM or remote execution.
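To make practice 3 concrete, a rough back-of-envelope helper (hypothetical, not part of NNsight) can estimate what a saved activation costs in memory, using GPT-2 small's known dimensions (12 layers, hidden size 768, 1024-token context):

```python
def activation_bytes(batch, seq_len, hidden_dim, bytes_per_elem=4, n_layers=1):
    """Rough memory cost of saving activations (float32 by default)."""
    return batch * seq_len * hidden_dim * bytes_per_elem * n_layers

# GPT-2 small: 12 layers, hidden size 768, a full 1024-token sequence
full = activation_bytes(1, 1024, 768, n_layers=12)  # every position, every layer
last = activation_bytes(1, 1, 768, n_layers=12)     # only the last position
print(f"all positions: {full / 1e6:.1f} MB, last position only: {last / 1e3:.1f} KB")
# -> all positions: 37.7 MB, last position only: 36.9 KB
```

Even for a small model, saving only the positions you need (e.g. `hidden[:, -1, :].save()`) cuts memory by three orders of magnitude; for 70B-scale models the difference is decisive.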

Troubleshooting

Proxy value is not a tensor outside trace You forgot to call .save() on the value inside the trace context. Only saved values are materialized as actual tensors after the context exits.

Module path does not exist Different model architectures have different module hierarchies. Use print(model._model) to inspect the actual structure. GPT-2 uses model.transformer.h[i] while LLaMA uses model.model.layers[i].

Remote execution timeout Increase the timeout parameter: model.trace(prompt, remote=True, timeout=300). NDIF servers may have queue delays during peak usage.

Memory error with many saved activations Reduce the number of layers saved or save only specific token positions instead of full sequences. Use hidden[:, -1, :].save() to save only the last position.

Gradient is None after backward Ensure you called hidden.retain_grad() inside the trace before the backward pass. Gradient access is not supported with vLLM or remote execution.
