
Emerging Techniques Speculative Kit

A powerful skill for accelerating LLM inference using speculative decoding. Includes structured workflows, validation checks, and reusable patterns for AI research.

Skill · Cliptics · AI research · v1.0.0 · MIT

Speculative Decoding for Fast LLM Inference

Overview

A comprehensive skill for implementing speculative decoding — the technique that accelerates LLM inference by 2-3x without any quality loss. Uses a small, fast "draft" model to propose candidate tokens that a larger "target" model verifies in a single forward pass, achieving mathematically guaranteed identical output distribution to standard autoregressive decoding.
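
The quality guarantee comes from the verification rule, stated here in the standard notation of the speculative decoding literature: with target distribution p and draft distribution q at a given position, a drafted token x is accepted with probability min(1, p(x)/q(x)); on rejection, a replacement token is drawn from the normalized residual, and the combined procedure provably samples from p exactly.

  P(\text{accept } x) = \min\left(1, \frac{p(x)}{q(x)}\right), \qquad
  p'(x) = \frac{\max(0,\ p(x) - q(x))}{\sum_{x'} \max(0,\ p(x') - q(x'))}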

When to Use

  • Need faster LLM inference without quality degradation
  • Serving latency-sensitive applications (chatbots, real-time)
  • Have access to both a large and small model of the same family
  • Want to reduce time-to-first-token and token generation latency
  • Optimizing cost by reducing GPU-seconds per request
  • Batch inference where throughput matters

Quick Start

# Using HuggingFace Transformers (built-in support)
pip install transformers accelerate torch

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load target (large) and draft (small) models
target = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-70B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
draft = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-70B-Instruct")

# Generate with speculative decoding - identical quality, 2-3x faster
inputs = tokenizer("Explain quantum computing", return_tensors="pt").to(target.device)
outputs = target.generate(
    **inputs,
    max_new_tokens=200,
    assistant_model=draft,  # Enable speculative decoding
)
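
To sanity-check the speedup on your own hardware, a minimal timing sketch follows; it assumes the target, draft, tokenizer, and inputs objects from the snippet above, and the exact numbers will depend on your GPU and prompt.

# Minimal timing sketch: same prompt with and without the draft model.
# Assumes `target`, `draft`, `tokenizer`, and `inputs` defined above.
import time

def timed_generate(**extra):
    start = time.time()
    out = target.generate(**inputs, max_new_tokens=200, do_sample=False, **extra)
    return out, time.time() - start

baseline_ids, t_base = timed_generate()                    # standard decoding
spec_ids, t_spec = timed_generate(assistant_model=draft)   # speculative decoding

print(f"baseline: {t_base:.2f}s | speculative: {t_spec:.2f}s | speedup: {t_base / t_spec:.2f}x")
print(tokenizer.decode(spec_ids[0], skip_special_tokens=True))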

How Speculative Decoding Works

Standard Decoding (slow):
  Token 1 → [Large Model Forward] → Token 2 → [Large Model Forward] → Token 3 → ...
  Each token requires a full forward pass of the large model

Speculative Decoding (fast):
  Step 1: Draft model proposes K tokens quickly
    [Small Model] → t1, t2, t3, t4, t5  (5 draft tokens, very fast)

  Step 2: Target model verifies ALL K tokens in ONE forward pass
    [Large Model](t1, t2, t3, t4, t5) → verify each

  Step 3: Accept tokens where target agrees, reject where it disagrees
    t1 ✓, t2 ✓, t3 ✓, t4 ✗ → Accept 3 tokens + sample correction for t4

  Result: 4 tokens generated in ~1 large model forward pass (vs 4 passes normally)
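
How large the win is depends on the acceptance rate. Under the simplifying assumption that each draft token is accepted independently with probability alpha (the standard analysis in the speculative decoding papers), the expected number of tokens emitted per target forward pass is (1 - alpha^(gamma+1)) / (1 - alpha). A small sketch of that arithmetic:

# Back-of-the-envelope estimate of speculative decoding speedup.
# Assumption: each draft token is accepted independently with probability `alpha`
# (a simplification; real acceptance is correlated across positions and prompts).

def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    # Expected tokens emitted per verification step (accepted drafts + 1 bonus/correction)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def estimated_speedup(alpha: float, gamma: int, draft_cost_ratio: float) -> float:
    # draft_cost_ratio: draft forward-pass time as a fraction of a target forward pass
    cost_per_step = 1.0 + gamma * draft_cost_ratio  # one target pass + gamma draft passes
    return expected_tokens_per_step(alpha, gamma) / cost_per_step

print(expected_tokens_per_step(0.75, 5))   # ~3.3 tokens per target forward pass
print(estimated_speedup(0.75, 5, 0.1))     # ~2.2x with a draft ~10x faster than the target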

Core Algorithm

import torch
import torch.nn.functional as F


def speculative_decode(target_model, draft_model, input_ids, gamma=5, max_tokens=100):
    """
    Speculative decoding with rejection sampling.
    Generates tokens from the target_model distribution using draft_model for speedup.
    """
    generated = input_ids.clone()
    total_accepted = 0
    total_proposed = 0

    while generated.shape[-1] - input_ids.shape[-1] < max_tokens:
        # Phase 1: Draft model proposes gamma tokens autoregressively
        draft_tokens = []
        draft_probs_list = []
        current_seq = generated.clone()
        for _ in range(gamma):
            with torch.no_grad():
                draft_logits = draft_model(current_seq).logits[:, -1, :]
            draft_probs = F.softmax(draft_logits, dim=-1)
            sampled_token = torch.multinomial(draft_probs, 1)
            draft_tokens.append(sampled_token)
            draft_probs_list.append(draft_probs)
            current_seq = torch.cat([current_seq, sampled_token], dim=-1)

        # Phase 2: Target model scores all draft tokens in ONE pass
        with torch.no_grad():
            target_logits = target_model(current_seq).logits

        # Phase 3: Rejection sampling - accept or reject each draft token
        n_accepted = 0
        for i in range(gamma):
            pos = generated.shape[-1] + i
            target_probs = F.softmax(target_logits[:, pos - 1, :], dim=-1)
            draft_prob = draft_probs_list[i]
            token = draft_tokens[i]
            # Acceptance probability: min(1, p_target / p_draft)
            p_target = target_probs.gather(-1, token)
            p_draft = draft_prob.gather(-1, token)
            accept_prob = (p_target / p_draft).clamp(max=1.0)
            if torch.rand(1, device=accept_prob.device) < accept_prob:
                generated = torch.cat([generated, token], dim=-1)
                n_accepted += 1
            else:
                # Rejection: sample from adjusted distribution
                adjusted = F.relu(target_probs - draft_prob)
                adjusted = adjusted / adjusted.sum(dim=-1, keepdim=True)
                corrected_token = torch.multinomial(adjusted, 1)
                generated = torch.cat([generated, corrected_token], dim=-1)
                break

        # If all accepted, sample one bonus token from target
        if n_accepted == gamma:
            bonus_probs = F.softmax(target_logits[:, -1, :], dim=-1)
            bonus_token = torch.multinomial(bonus_probs, 1)
            generated = torch.cat([generated, bonus_token], dim=-1)

        total_accepted += n_accepted
        total_proposed += gamma

    acceptance_rate = total_accepted / max(total_proposed, 1)
    return generated, acceptance_rate
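
A usage sketch for the reference implementation above, assuming the target, draft, and tokenizer objects loaded in the Quick Start section:

# Run speculative_decode and report its measured acceptance rate.
# Assumes `target`, `draft`, and `tokenizer` from the Quick Start snippet.
prompt = "Explain quantum computing"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(target.device)

output_ids, acceptance_rate = speculative_decode(
    target_model=target,
    draft_model=draft,
    input_ids=input_ids,
    gamma=5,
    max_tokens=100,
)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
print(f"Acceptance rate: {acceptance_rate:.1%}")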

Draft Model Selection Guide

Target Model    Recommended Draft   Acceptance Rate   Speedup
Llama-3 70B     Llama-3 8B          70-80%            2.5-3x
Llama-3 8B      Llama-3 1B          60-75%            1.8-2.5x
Mixtral 8x7B    Mistral 7B          65-75%            2-2.5x
GPT-4           GPT-3.5 (API)       60-70%            2-3x
CodeLlama 34B   CodeLlama 7B        75-85%            2.5-3.5x

Configuration Reference

Parameter              Range   Impact
gamma (draft length)   3-10    Higher = more draft tokens proposed per step
temperature            0-2     Higher temperature reduces acceptance rate
top_k                  1-100   Smaller top_k improves acceptance
top_p                  0-1     Lower top_p improves acceptance
draft_model_size       -       Smaller draft = faster proposals but lower acceptance
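
These parameters map onto the transformers generate() call from the Quick Start. A sketch follows; the num_assistant_tokens field for setting the draft length exists in recent transformers releases, so treat that exact name as an assumption about your installed version.

# Sampling settings are applied to the TARGET distribution, so output quality is unchanged,
# but lower temperature / tighter top_k / lower top_p tend to raise the acceptance rate.
# The draft length (gamma) is read from the assistant model's generation config in recent
# transformers versions - field name assumed, check your installed release.
draft.generation_config.num_assistant_tokens = 5

outputs = target.generate(
    **inputs,
    assistant_model=draft,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
)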

Best Practices

  1. Choose draft models from the same family — Same tokenizer is mandatory; similar architecture helps acceptance rate (a quick compatibility check is sketched after this list)
  2. Tune gamma based on acceptance rate — Higher acceptance rate → increase gamma; lower → decrease
  3. Use greedy decoding for highest speedup — Acceptance rates are highest with temperature=0
  4. Profile both models — Draft model should be 5-10x faster than target for good speedup
  5. Use on GPU, not CPU — Speculative decoding benefits from fast forward passes
  6. Batch draft and target — Run draft on separate GPU stream for overlap
  7. Monitor acceptance rate — Below 50% means your draft model is too different
  8. Consider self-speculative decoding — Use early layers of the target as the draft
  9. Use KV cache — Both models should reuse KV caches for maximum efficiency
  10. Test with representative prompts — Acceptance rate varies by domain
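
For practice #1, here is a rough compatibility check; it is only a sketch that compares vocabularies with AutoTokenizer and reuses the model names from the Quick Start.

# Sanity check before pairing models: vanilla speculative decoding requires that
# draft and target share the same tokenizer/vocabulary.
from transformers import AutoTokenizer

def check_tokenizer_compatibility(target_name: str, draft_name: str) -> bool:
    target_tok = AutoTokenizer.from_pretrained(target_name)
    draft_tok = AutoTokenizer.from_pretrained(draft_name)
    same_vocab = target_tok.get_vocab() == draft_tok.get_vocab()
    if not same_vocab:
        print("Vocabularies differ: these models cannot be paired for speculative decoding.")
    return same_vocab

check_tokenizer_compatibility(
    "meta-llama/Llama-3-70B-Instruct",
    "meta-llama/Llama-3-8B-Instruct",
)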

Troubleshooting

Low acceptance rate (<50%)

# Draft model too different from target
# Option 1: Fine-tune draft to match target distribution
# Option 2: Use a larger draft model
# Option 3: Reduce gamma to minimize wasted computation
gamma = 3  # Instead of 5

No speedup despite high acceptance

# Draft model overhead too large
# Rough check: gamma * draft_time should stay well below accepted_tokens * target_time,
# otherwise the extra draft passes eat the savings
# Solution: Use a smaller draft or run it on a separate GPU
# (for accurate GPU timings, call torch.cuda.synchronize() before reading the clock)
import time

t0 = time.time()
draft_model(input_ids)
draft_time = time.time() - t0

t0 = time.time()
target_model(input_ids)
target_time = time.time() - t0

print(f"Draft: {draft_time*1000:.1f}ms, Target: {target_time*1000:.1f}ms")
print(f"Ratio: {target_time/draft_time:.1f}x (need >5x for good speedup)")

Memory issues loading both models

# Use quantization for one or both models
draft = AutoModelForCausalLM.from_pretrained(
    "Llama-3-8B",
    load_in_4bit=True,  # 4-bit quantized draft
)
target = AutoModelForCausalLM.from_pretrained(
    "Llama-3-70B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)