
Comprehensive Emerging Techniques: Mixture of Experts (MoE)

Enterprise-grade skill for training Mixture of Experts models. Includes structured workflows, validation checks, and reusable patterns for AI research.

Skill Β· Cliptics Β· ai research Β· v1.0.0 Β· MIT

Mixture of Experts (MoE) Training

Overview

A comprehensive skill for implementing Mixture of Experts architectures β€” the technique behind models like Mixtral and Switch Transformer (and, reportedly, GPT-4). MoE scales model capacity without proportionally scaling compute by routing each token to a small subset of specialized "expert" sub-networks, enabling trillion-parameter models that run at the inference cost of much smaller dense models.

When to Use

  • Building models that need large capacity without proportional compute
  • Training domain-specific expert networks
  • Scaling beyond what dense models can achieve on your hardware
  • Needing sparse activation for efficient inference
  • Implementing Switch Transformer or Mixtral-style architectures
  • Research on routing strategies and expert specialization

Quick Start

```bash
# Using HuggingFace Transformers
pip install transformers accelerate

# vLLM: best option for MoE inference (e.g. Mixtral)
pip install vllm
```

```python
# Load and use Mixtral (MoE model)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Despite 46.7B total params, only 12.9B are active per forward pass
inputs = tokenizer("Explain MoE architectures", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

MoE Architecture

Input Token
    β”‚
    β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Router  β”‚ β†’ Computes expert scores for this token
β”‚ (Gating) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    β”‚ Top-K selection (usually K=2)
    β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β–Ό              β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚Expert 1β”‚   β”‚Expert 5β”‚   ← Only 2 of 8 experts activated
β”‚  FFN   β”‚   β”‚  FFN   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β””β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    β”‚              β”‚
    β–Ό              β–Ό
  w₁·output₁ + wβ‚…Β·outputβ‚…  ← Weighted combination
    β”‚
    β–Ό
  Output
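The routing step in the diagram can be sketched in plain Python for a single token. The router logits below are made-up illustrative values; the point is the top-K selection and the renormalized mixture weights:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

# Router logits for one token over 8 experts (illustrative values)
logits = [0.1, 2.3, -0.5, 0.4, 1.9, -1.2, 0.0, 0.7]

# Top-K selection (K=2): pick the two highest-scoring experts
k = 2
top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

# Softmax over only the selected logits, so the two weights sum to 1
weights = softmax([logits[i] for i in top])

print(top)      # indices of the two selected experts
print(weights)  # their mixture weights w_1, w_2 (sum to 1)
```

The final output is then `weights[0] * expert_out[top[0]] + weights[1] * expert_out[top[1]]`, exactly the weighted combination shown above.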

Custom MoE Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKRouter(nn.Module):
    def __init__(self, hidden_size, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq, hidden)
        logits = self.gate(x)  # (batch, seq, num_experts)
        # Top-K gating
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1)
        return top_k_weights, top_k_indices, logits


class MoEBlock(nn.Module):
    def __init__(self, hidden_size, ffn_size, num_experts=8, top_k=2):
        super().__init__()
        self.router = TopKRouter(hidden_size, num_experts, top_k)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, ffn_size),
                nn.SiLU(),
                nn.Linear(ffn_size, hidden_size),
            )
            for _ in range(num_experts)
        ])
        self.num_experts = num_experts

    def forward(self, x):
        weights, indices, gate_logits = self.router(x)

        # Flatten (batch, seq) so each expert processes its tokens in one batch
        batch, seq, hidden = x.shape
        flat_x = x.reshape(-1, hidden)
        flat_indices = indices.reshape(-1, indices.shape[-1])
        flat_weights = weights.reshape(-1, weights.shape[-1])

        output = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # Find tokens assigned to this expert in any of their top-k slots
            mask = (flat_indices == i).any(dim=-1)
            if mask.any():
                expert_input = flat_x[mask]
                expert_output = expert(expert_input)
                # Weight each token's expert output by its router probability
                token_weights = flat_weights[mask]
                expert_idx = (flat_indices[mask] == i).float()
                weight = (token_weights * expert_idx).sum(dim=-1, keepdim=True)
                output[mask] += expert_output * weight

        return output.reshape(batch, seq, hidden), gate_logits
```

Load Balancing Loss

```python
def load_balancing_loss(gate_logits, num_experts, top_k=2):
    """Auxiliary loss to ensure balanced expert utilization."""
    # gate_logits: (batch, seq, num_experts)
    routing_probs = F.softmax(gate_logits, dim=-1)

    # Fraction of routing slots assigned to each expert
    _, indices = torch.topk(gate_logits, top_k, dim=-1)
    mask = F.one_hot(indices, num_experts).sum(dim=-2).float()  # (batch, seq, num_experts)
    tokens_per_expert = mask.mean(dim=[0, 1])

    # Average routing probability per expert
    avg_prob = routing_probs.mean(dim=[0, 1])

    # Balance loss: minimize the dot product of assignment fraction and probability
    return num_experts * (tokens_per_expert * avg_prob).sum()
```
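To see why this auxiliary loss discourages collapse, here is a toy top-1 version in plain Python (the probability vectors are illustrative). Balanced routing attains the minimum value of 1.0; routing every token to one expert is penalized:

```python
def aux_loss(probs_per_token, num_experts):
    """Toy Switch-style balancing loss for top-1 routing.

    f_i = fraction of tokens whose argmax is expert i
    P_i = mean routing probability assigned to expert i
    loss = num_experts * sum_i f_i * P_i
    """
    n = len(probs_per_token)
    frac = [0.0] * num_experts
    mean_p = [0.0] * num_experts
    for p in probs_per_token:
        frac[p.index(max(p))] += 1.0 / n
        for i, pi in enumerate(p):
            mean_p[i] += pi / n
    return num_experts * sum(f * q for f, q in zip(frac, mean_p))

# Balanced: each token favors a different expert
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
# Collapsed: every token routed to expert 0
collapsed = [[0.7, 0.1, 0.1, 0.1]] * 4

print(aux_loss(balanced, 4))   # ~1.0, the minimum
print(aux_loss(collapsed, 4))  # ~2.8, imbalance is penalized
```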

MoE Model Comparison

| Model | Experts | Active | Total Params | Active Params | Performance |
|---|---|---|---|---|---|
| Mixtral 8x7B | 8 | 2 | 46.7B | 12.9B | β‰ˆ Llama-2 70B |
| Mixtral 8x22B | 8 | 2 | 141B | 39B | > GPT-3.5 |
| Switch-Base | 128 | 1 | 7.4B | 0.2B | Research baseline |
| DeepSeek-V2 | 160 | 6 | 236B | 21B | Competitive |
| Grok-1 | 8 | 2 | 314B | ~86B | Open source |

Configuration Reference

| Parameter | Typical Range | Description |
|---|---|---|
| num_experts | 4-128 | Total number of expert networks |
| top_k | 1-4 | Experts activated per token |
| expert_capacity | 1.0-1.5 | Max tokens per expert (capacity factor) |
| aux_loss_weight | 0.01-0.1 | Weight of load balancing loss |
| expert_ffn_size | hiddenΓ—4 | FFN hidden dimension per expert |
| shared_expert | 0-2 | Number of always-active shared experts |
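The parameters above might be collected into a config like the following (the values are illustrative defaults, not tuned recommendations). The arithmetic at the end shows how the capacity factor translates into a per-expert token budget:

```python
moe_config = {
    "num_experts": 8,             # total expert FFNs per MoE layer
    "top_k": 2,                   # experts activated per token
    "expert_capacity": 1.25,      # capacity factor (>1.0 avoids token dropping)
    "aux_loss_weight": 0.01,      # weight on the load balancing loss
    "expert_ffn_size": 4 * 4096,  # FFN hidden dim, assuming hidden_size=4096
    "shared_expert": 0,           # always-active shared experts
}

# Tokens each expert can accept per batch under the capacity factor:
# capacity = factor * (tokens * top_k) / num_experts
tokens_per_batch = 4096  # e.g. batch_size * seq_len
capacity = int(moe_config["expert_capacity"] * tokens_per_batch
               * moe_config["top_k"] / moe_config["num_experts"])
print(capacity)  # 1280 tokens per expert
```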

Best Practices

  1. Use top-2 routing β€” Best balance of quality and efficiency for most architectures
  2. Always include load balancing loss β€” Without it, the router collapses, sending nearly all tokens to 1-2 experts
  3. Set capacity factor > 1.0 β€” Prevents token dropping; 1.25 is a good default
  4. Use expert parallelism β€” Distribute experts across GPUs for large MoE models
  5. Shared experts improve stability β€” 1-2 always-active experts help with common patterns
  6. Monitor expert utilization β€” Log per-expert token counts to detect routing imbalances
  7. Start with fewer experts β€” 8 experts is the sweet spot; more experts need more data
  8. Use larger training datasets β€” MoE models need more data than dense models of equivalent quality
  9. Enable flash attention β€” MoE models have same attention bottleneck as dense models
  10. Serve with vLLM β€” vLLM has optimized MoE kernels for efficient inference
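Practice 6 (monitoring expert utilization) can start as simply as counting routed tokens per expert. A pure-Python sketch follows; `assignments` is assumed to be the router's top-k expert indices, flattened into one list:

```python
from collections import Counter

def expert_utilization(assignments, num_experts):
    """Fraction of routing slots handled by each expert.

    assignments: flat list of expert indices chosen by the router
    (one entry per token per top-k slot).
    """
    counts = Counter(assignments)
    total = len(assignments)
    return [counts.get(i, 0) / total for i in range(num_experts)]

def imbalance_warnings(util, threshold=2.0):
    """Return experts receiving more than `threshold`x their fair share."""
    fair = 1.0 / len(util)
    return [i for i, u in enumerate(util) if u > threshold * fair]

# Example: expert 0 is overloaded in an 8-expert layer
util = expert_utilization([0, 0, 0, 0, 0, 1, 2, 3], num_experts=8)
print(imbalance_warnings(util))  # expert 0 exceeds 2x its 12.5% fair share
```

Logging these fractions every few hundred steps makes routing collapse visible long before it shows up in the loss curve.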

Troubleshooting

Expert collapse β€” all tokens to one expert

```python
# Increase aux_loss_weight
aux_loss_weight = 0.1  # the default 0.01 may be too small

# Add noise to routing for exploration (training only);
# avoid in-place += on the gate output, which can break autograd
gate_logits = gate_logits + torch.randn_like(gate_logits) * 0.1
```

High memory usage despite sparse activation

```python
# Use expert parallelism: split experts across GPUs.
# With 8 experts on 4 GPUs: 2 experts per GPU.
# All-to-all communication routes each token to the GPU
# holding its assigned expert.
from torch.distributed import all_to_all
```

Slow inference due to expert routing overhead

```bash
# Use vLLM with tensor parallelism
python -m vllm.entrypoints.openai.api_server \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tensor-parallel-size 2
```