15. Mixture of Experts (MoE) — Specialized Sub-Model Routing for MangaAssist
Problem Statement and MangaAssist Context
MangaAssist serves diverse query types — product lookup, recommendation, order tracking, complaint handling, manga knowledge Q&A — through a single monolithic LLM. The problem: a single model must be expert in everything, leading to:
- Diluted expertise (8% quality drop on tail queries): the model is mediocre at everything rather than excellent at specific domains
- Wasted compute (all 8B parameters activate for every query): a product price lookup shouldn't require the full model
- No specialization path: improving manga recommendations may degrade order-tracking quality
Mixture of Experts (MoE) routes each query to specialized sub-networks, activating only the relevant parameters. This gives us the capacity of a large model with the inference cost of a small one.
Before vs After MoE
| Metric | Dense 8B | MoE (8 experts, top-2) | Improvement |
|---|---|---|---|
| Total parameters | 8B | 32B | 4× capacity |
| Active parameters per query | 8B | 8B | Same cost |
| Product query accuracy | 88% | 94% | +6% |
| Recommendation quality (NDCG) | 0.76 | 0.84 | +10.5% |
| Complaint handling CSAT | 3.9/5 | 4.4/5 | +12.8% |
| Inference latency | 480ms | 510ms | +6% (routing overhead) |
| Training cost | $200 | $340 | +70% (one-time) |
Mathematical Foundations
Gating Network (Shazeer et al., 2017)
At the core of MoE is the gating network $G(x)$, which decides which experts to activate for each input token $x$:
$$G(x) = \text{softmax}\left(\text{TopK}(W_g \cdot x + \epsilon)\right)$$
where:
- $W_g \in \mathbb{R}^{E \times d}$ is the gating weight matrix ($E$ = number of experts, $d$ = hidden dimension)
- $\epsilon \sim \mathcal{N}(0, \frac{1}{E^2})$ is noise added during training for exploration
- $\text{TopK}(\cdot)$ keeps only the top-$K$ logits and sets the rest to $-\infty$, so the softmax re-normalizes over the selected experts only
The output of the MoE layer:
$$\text{MoE}(x) = \sum_{i=1}^{E} G(x)_i \cdot E_i(x)$$
where $E_i(x)$ is expert $i$'s output and $G(x)_i$ is the routing weight (zero for non-selected experts).
For MangaAssist: $E = 8$ experts (product, recommendation, order, complaint, knowledge, chitchat, multi-turn, fallback), $K = 2$ active per token.
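A minimal sketch of the gating computation for a single token, using made-up logits rather than real gate weights:

```python
import torch
import torch.nn.functional as F

# Illustrative logits z = W_g · x + ε for one token (8 experts, top-2).
logits = torch.tensor([2.1, -0.5, -1.2, -1.8, 2.3, -0.9, 0.1, -0.7])

top_vals, top_idx = torch.topk(logits, k=2)  # TopK: keep 2 logits, rest → -inf
gates = F.softmax(top_vals, dim=-1)          # re-normalize over the survivors

print(top_idx.tolist())  # [4, 0] — e.g. knowledge + product experts
print(gates.tolist())    # routing weights, sum to 1
# This token's MoE output: gates[0]·E_4(x) + gates[1]·E_0(x)
```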
Expert Architecture
Each expert is a standard FFN sub-layer:
$$E_i(x) = W_i^{\text{down}} \left( \text{SiLU}(W_i^{\text{gate}} \cdot x) \odot (W_i^{\text{up}} \cdot x) \right)$$
where $W_i^{\text{gate}}, W_i^{\text{up}} \in \mathbb{R}^{d_{\text{ff}} \times d}$ and $W_i^{\text{down}} \in \mathbb{R}^{d \times d_{\text{ff}}}$, with no bias terms (the SwiGLU form used by Llama and by the implementation below).
In a standard Llama 3 8B, the FFN has $d = 4096$, $d_{\text{ff}} = 14336$. In MoE, each expert has the same structure, but we have 8 of them:
| Component | Dense | MoE (8 experts) |
|---|---|---|
| Attention layers | 4096 × 4096 × 4 = 67M per layer | Same (shared) |
| FFN per expert | 4096 × 14336 × 2 = 117M | 117M × 8 = 939M |
| Active FFN per token | 117M | 117M × 2 = 234M (top-2) |
| Total FFN params | 117M × 32 layers = 3.75B | 939M × 32 = 30B |
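The table's arithmetic as a quick sanity check. Note it counts two $d \times d_{\text{ff}}$ matrices per FFN; the SwiGLU experts in the implementation below carry a third (down) projection, so exact counts run roughly 1.5× higher:

```python
# Back-of-envelope accounting behind the table above.
d, d_ff, E, layers, top_k = 4096, 14336, 8, 32, 2

attn_per_layer = 4 * d * d            # Q, K, V, O projections ≈ 67M
ffn_dense = 2 * d * d_ff              # ≈ 117M per layer (table's 2-matrix count)
ffn_moe_total = E * ffn_dense         # ≈ 939M per MoE layer
ffn_moe_active = top_k * ffn_dense    # ≈ 234M actually used per token

print(f"dense FFN total: {ffn_dense * layers / 1e9:.2f}B")      # ≈ 3.76B
print(f"MoE FFN total:   {ffn_moe_total * layers / 1e9:.1f}B")  # ≈ 30.1B
```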
Load Balancing Loss
Without regularization, the gating network collapses: it routes all tokens to 1-2 experts, leaving the rest unused. The load balancing loss prevents this:
$$\mathcal{L}_{\text{balance}} = \alpha \cdot E \cdot \sum_{i=1}^{E} f_i \cdot p_i$$
where:
- $f_i = \frac{\text{number of tokens routed to expert } i}{\text{total tokens in batch}}$ (fraction of tokens)
- $p_i = \frac{1}{T}\sum_{t=1}^{T} G(x_t)_i$ (average routing probability for expert $i$)
- $\alpha$ is the balancing coefficient (typically 0.01)
Intuition: The sum $\sum_i f_i \cdot p_i$ is minimized when tokens are spread evenly across experts. If expert $i$ receives too many tokens ($f_i$ high) and also has high routing probability ($p_i$ high), its term dominates the loss, which pushes the gate to spread tokens more evenly.
Ideal state: $f_i = \frac{1}{E}$ for all $i$ (uniform distribution). The loss achieves its minimum when $f_i = p_i = \frac{1}{E}$.
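A tiny check of both claims, plugging a uniform distribution and the collapsed distribution from the "Load Balancing Dynamics" diagram below into the loss (a sketch; real training computes $f$ and $p$ per batch):

```python
import torch

def balance_loss(f: torch.Tensor, p: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """L_balance = alpha * E * sum_i(f_i * p_i)."""
    return alpha * f.numel() * (f * p).sum()

E = 8
uniform = torch.full((E,), 1 / E)
collapsed = torch.tensor([0.62, 0.01, 0.31, 0.01, 0.01, 0.01, 0.02, 0.01])

print(balance_loss(uniform, uniform).item())      # 0.0100 — the minimum (= alpha)
print(balance_loss(collapsed, collapsed).item())  # ≈ 0.0385 — ~4× penalty
```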
Capacity Factor and Token Dropping
Each expert has a capacity factor $C$ that limits how many tokens it can process:
$$\text{capacity}_i = C \cdot \frac{T}{E}$$
where $T$ is the total number of tokens in the batch. Tokens exceeding capacity are either:
1. Dropped (Switch Transformer): skipped entirely and passed through the residual connection
2. Rerouted (GShard): sent to the next-best expert

Typical settings:
- $C = 1.0$: each expert handles exactly its fair share (risk of dropping)
- $C = 1.25$: 25% buffer (recommended for training)
- $C = 2.0$: generous buffer (fewer drops, more compute)
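The capacity arithmetic for the batch size used throughout this section ($T = 512$ tokens, matching the "58/64 slots" in the routing sequence diagram below):

```python
# Per-expert token budget for each capacity factor.
T, E = 512, 8
for C in (1.0, 1.25, 2.0):
    capacity = int(C * T / E)
    print(f"C={C:>4}: each expert accepts at most {capacity} tokens")
# C= 1.0: 64, C=1.25: 80, C= 2.0: 128 — tokens beyond the cap are
# dropped (Switch Transformer) or rerouted to the next-best expert (GShard).
```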
Router Z-Loss (Zoph et al., 2022)
Large routing logits cause numerical instability. The router z-loss penalizes large logits:
$$\mathcal{L}_{z} = \frac{1}{T} \sum_{t=1}^{T} \left(\log \sum_{i=1}^{E} e^{z_i^{(t)}}\right)^2$$
where $z_i^{(t)} = W_g \cdot x_t$ are the raw logits before softmax. This keeps logits in a numerically stable range, improving training convergence.
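A two-line demonstration of what the z-loss measures, using random logits at two scales (illustrative only):

```python
import torch

def router_z_loss(logits: torch.Tensor) -> torch.Tensor:
    """L_z = mean over tokens of (log sum_i exp(z_i))^2."""
    return (torch.logsumexp(logits, dim=-1) ** 2).mean()

well_scaled = torch.randn(512, 8)         # healthy router logits
drifting = well_scaled * 20               # logits growing toward instability
print(router_z_loss(well_scaled).item())  # small (single digits)
print(router_z_loss(drifting).item())     # orders of magnitude larger → strong pushback
```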
Expert Specialization Analysis
To verify experts specialize, we measure the expert utilization matrix $U \in \mathbb{R}^{E \times Q}$ where $Q$ is the number of query types:
$$U_{i,j} = \frac{\text{tokens from query type } j \text{ routed to expert } i}{\text{total tokens from query type } j}$$
Ideal specialization: each row of $U$ shows peaks for specific query types, not uniform activation. We measure this with the specialization score:
$$\text{Spec}_i = 1 - \frac{H(U_i)}{\log Q}$$
where $H(U_i) = -\sum_j U_{i,j} \log U_{i,j}$ is the entropy of expert $i$'s utilization row (renormalized to sum to 1 across query types). $\text{Spec}_i = 1$ means perfect specialization (all of the expert's tokens come from a single query type); $\text{Spec}_i = 0$ means uniform utilization (no specialization).
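Applying the score to Expert 1's row from the utilization heatmap below (a worked example over the document's illustrative numbers):

```python
import math

def specialization(row: list[float]) -> float:
    """Spec_i = 1 - H(U_i) / log(Q), with the row renormalized to sum to 1."""
    total = sum(row)
    probs = [u / total for u in row if u > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return 1 - entropy / math.log(len(row))

expert_1 = [0.42, 0.12, 0.08, 0.05, 0.28, 0.05]  # 6 query types
print(f"Spec = {specialization(expert_1):.2f}")   # ≈ 0.18: moderate specialization
```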
Model Internals — Layer-by-Layer Diagrams
Token Routing Through Experts
```mermaid
graph TB
subgraph "MoE Layer Processing"
INPUT["Input tokens from attention layer<br>x ∈ ℝ^{T×d}, T=512, d=4096"]
GATE["Gating Network G(x)<br>W_g ∈ ℝ^{8×4096}<br>Computes: softmax(W_g·x + ε)<br>Selects top-2 experts per token"]
subgraph "8 Expert FFNs (only 2 active per token)"
E1["Expert 1: Product<br>W_gate ∈ ℝ^{14336×4096}<br>W_up ∈ ℝ^{4096×14336}<br>117M params"]
E2["Expert 2: Recommend<br>117M params"]
E3["Expert 3: Order<br>117M params"]
E4["Expert 4: Complaint<br>117M params"]
E5["Expert 5: Knowledge<br>117M params"]
E6["Expert 6: Chitchat<br>117M params"]
E7["Expert 7: Multi-turn<br>117M params"]
E8["Expert 8: Fallback<br>117M params"]
end
COMBINE["Weighted combination:<br>MoE(x) = G(x)₁·E₁(x) + G(x)₅·E₅(x)<br>(example: product + knowledge)"]
OUTPUT["Output: same shape ℝ^{T×d}<br>Only 234M params activated<br>(2× expert = 2×117M)"]
INPUT --> GATE
GATE -->|"w=0.65"| E1
GATE -->|"w=0"| E2
GATE -->|"w=0"| E3
GATE -->|"w=0"| E4
GATE -->|"w=0.35"| E5
GATE -->|"w=0"| E6
GATE -->|"w=0"| E7
GATE -->|"w=0"| E8
E1 --> COMBINE
E5 --> COMBINE
COMBINE --> OUTPUT
end
style E1 fill:#c8e6c9
style E5 fill:#c8e6c9
style E2 fill:#eeeeee
style E3 fill:#eeeeee
style E4 fill:#eeeeee
style E6 fill:#eeeeee
style E7 fill:#eeeeee
style E8 fill:#eeeeee
```
Expert Specialization Heatmap
```mermaid
graph TB
subgraph "Expert Utilization Matrix U (after training)"
direction TB
HEADER["Query Types →<br>Product | Recommend | Order | Complaint | Knowledge | Chitchat"]
E1_ROW["Expert 1: 0.42 | 0.12 | 0.08 | 0.05 | 0.28 | 0.05<br>Specialization: Product + Knowledge lookups"]
E2_ROW["Expert 2: 0.10 | 0.45 | 0.05 | 0.08 | 0.22 | 0.10<br>Specialization: Recommendations"]
E3_ROW["Expert 3: 0.05 | 0.05 | 0.48 | 0.15 | 0.07 | 0.20<br>Specialization: Order tracking + general"]
E4_ROW["Expert 4: 0.08 | 0.06 | 0.12 | 0.52 | 0.10 | 0.12<br>Specialization: Complaint handling"]
E5_ROW["Expert 5: 0.15 | 0.18 | 0.05 | 0.05 | 0.47 | 0.10<br>Specialization: Manga knowledge"]
E6_ROW["Expert 6: 0.07 | 0.10 | 0.08 | 0.12 | 0.08 | 0.55<br>Specialization: Chitchat + conversational"]
E7_ROW["Expert 7: 0.15 | 0.20 | 0.18 | 0.15 | 0.17 | 0.15<br>Specialization: Multi-turn context (generalist)"]
E8_ROW["Expert 8: 0.12 | 0.08 | 0.15 | 0.08 | 0.10 | 0.47<br>Specialization: Fallback / edge cases"]
HEADER --> E1_ROW --> E2_ROW --> E3_ROW --> E4_ROW --> E5_ROW --> E6_ROW --> E7_ROW --> E8_ROW
end
style E1_ROW fill:#c8e6c9
style E2_ROW fill:#c8e6c9
style E3_ROW fill:#c8e6c9
style E4_ROW fill:#bbdefb
style E5_ROW fill:#c8e6c9
style E6_ROW fill:#fff9c4
style E7_ROW fill:#eeeeee
style E8_ROW fill:#fff9c4
```
Load Balancing Dynamics
```mermaid
graph LR
subgraph "Without Load Balancing Loss"
WO_INIT["Epoch 0: Uniform<br>12.5% per expert"]
WO_MID["Epoch 5: Collapse begins<br>Expert 1: 35%<br>Expert 3: 28%<br>Others: 5-8%"]
WO_END["Epoch 20: Collapsed<br>Expert 1: 62%<br>Expert 3: 31%<br>Others: ~1% (dead)"]
WO_INIT --> WO_MID --> WO_END
end
subgraph "With Load Balancing Loss (α=0.01)"
W_INIT["Epoch 0: Uniform<br>12.5% per expert"]
W_MID["Epoch 5: Slight specialization<br>Range: 9-18%"]
W_END["Epoch 20: Balanced<br>Range: 10-16%<br>All experts active"]
W_INIT --> W_MID --> W_END
end
style WO_END fill:#ffcdd2
style W_END fill:#c8e6c9
```
MoE Layer Position in Transformer
```mermaid
graph TB
subgraph "Llama 3 8B MoE Architecture (32 layers)"
INPUT["Input embeddings"]
L1["Layer 1-8: Dense FFN<br>(shared low-level features)<br>No MoE — pattern extraction"]
L2["Layer 9-24: MoE FFN<br>8 experts per layer, top-2 routing<br>16 MoE layers × 939M = 15B params<br>Active: 16 × 234M = 3.7B"]
L3["Layer 25-32: Dense FFN<br>(shared output projection)<br>No MoE — output generation"]
HEAD["LM Head → vocabulary logits"]
INPUT --> L1 --> L2 --> L3 --> HEAD
end
subgraph "Why This Split?"
R1["Early layers learn universal<br>features (tokenization, syntax)<br>→ No need for specialization"]
R2["Middle layers learn<br>task-specific representations<br>→ Maximum benefit from routing"]
R3["Late layers unify<br>representations for generation<br>→ Shared output space"]
end
style L2 fill:#c8e6c9
style L1 fill:#e3f2fd
style L3 fill:#e3f2fd
```
Routing Decision Flow
```mermaid
sequenceDiagram
participant T as Token "One Piece volume 107"
participant G as Gating Network
participant LB as Load Balancer
participant E1 as Expert 1 (Product)
participant E5 as Expert 5 (Knowledge)
participant C as Combiner
T->>G: x = hidden state (d=4096)
G->>G: z = W_g · x + ε (8 logits)
G->>G: p = softmax(z) = [0.38, 0.05, 0.02, 0.01, 0.42, 0.03, 0.06, 0.03]
G->>G: top-2 = {Expert 5: 0.42, Expert 1: 0.38}
G->>LB: Check capacity for E1, E5
LB->>LB: E1: 58/64 slots used → OK
LB->>LB: E5: 61/64 slots used → OK
G->>E1: Route with weight 0.38/(0.38+0.42) = 0.475
G->>E5: Route with weight 0.42/(0.38+0.42) = 0.525
E1->>E1: FFN: SiLU(W_gate · x) ⊙ (W_up · x)
E5->>E5: FFN: SiLU(W_gate · x) ⊙ (W_up · x)
E1->>C: Product expert output
E5->>C: Knowledge expert output
C->>C: 0.475 × E1(x) + 0.525 × E5(x)
C->>T: Combined MoE output
```
Implementation Deep-Dive
MoE Layer Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class ExpertFFN(nn.Module):
"""Single expert feed-forward network (SwiGLU architecture)."""
def __init__(self, d_model: int = 4096, d_ff: int = 14336):
super().__init__()
self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
self.up_proj = nn.Linear(d_model, d_ff, bias=False)
self.down_proj = nn.Linear(d_ff, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class TopKGating(nn.Module):
"""Top-K gating network with load balancing loss."""
def __init__(
self, d_model: int, num_experts: int, top_k: int = 2,
noise_std: float = 0.1,
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.noise_std = noise_std
self.gate = nn.Linear(d_model, num_experts, bias=False)
def forward(
self, x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x: (batch, seq_len, d_model)
Returns:
            gates: (batch, seq_len, top_k) — routing weights over the selected experts
            indices: (batch, seq_len, top_k) — selected expert indices
            balance_loss: scalar auxiliary loss (load balance + 0.001 × router z-loss)
"""
# Compute logits
logits = self.gate(x) # (B, T, E)
# Add noise during training for exploration
if self.training:
noise = torch.randn_like(logits) * self.noise_std
logits = logits + noise
# Top-K selection
top_k_logits, indices = torch.topk(logits, self.top_k, dim=-1)
gates = F.softmax(top_k_logits, dim=-1) # (B, T, K)
# Load balancing loss
# f_i: fraction of tokens routed to expert i
# p_i: average routing probability for expert i
routing_probs = F.softmax(logits, dim=-1) # (B, T, E)
p = routing_probs.mean(dim=[0, 1]) # (E,)
# One-hot for selected experts
mask = torch.zeros_like(routing_probs)
mask.scatter_(-1, indices, 1.0)
f = mask.mean(dim=[0, 1]) # (E,)
balance_loss = self.num_experts * (f * p).sum()
# Router z-loss
z_loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()
return gates, indices, balance_loss + 0.001 * z_loss
class MoELayer(nn.Module):
"""Mixture of Experts layer replacing standard FFN."""
def __init__(
self,
d_model: int = 4096,
d_ff: int = 14336,
num_experts: int = 8,
top_k: int = 2,
capacity_factor: float = 1.25,
):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.capacity_factor = capacity_factor
self.gating = TopKGating(d_model, num_experts, top_k)
self.experts = nn.ModuleList([
ExpertFFN(d_model, d_ff) for _ in range(num_experts)
])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)
        The auxiliary loss is stored in self._last_aux_loss so the trainer
        can collect it; returning a plain tensor keeps this layer a drop-in
        replacement for the dense `mlp` module in HF Llama blocks.
        """
B, T, D = x.shape
gates, indices, aux_loss = self.gating(x)
# Compute capacity
capacity = int(self.capacity_factor * T / self.num_experts)
# Gather expert outputs
output = torch.zeros_like(x)
expert_counts = torch.zeros(self.num_experts, device=x.device)
for k in range(self.top_k):
expert_idx = indices[:, :, k] # (B, T)
gate_weight = gates[:, :, k] # (B, T)
for e_idx in range(self.num_experts):
mask = expert_idx == e_idx # (B, T)
if not mask.any():
continue
# Capacity check
count = mask.sum().item()
expert_counts[e_idx] += count
if count > capacity * B:
# Drop excess tokens (keep first `capacity * B`)
flat_mask = mask.reshape(-1)
indices_true = flat_mask.nonzero().squeeze(-1)
drop = indices_true[capacity * B:]
flat_mask[drop] = False
mask = flat_mask.reshape(B, T)
# Route to expert
expert_input = x[mask] # (num_tokens, D)
if expert_input.numel() > 0:
expert_output = self.experts[e_idx](expert_input)
output[mask] += gate_weight[mask].unsqueeze(-1) * expert_output
        # Stash the aux loss for the trainer (the HF block expects mlp(x) -> tensor)
        self._last_aux_loss = aux_loss
        return output
```
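A quick smoke test of the layer at toy dimensions (random weights on CPU), confirming the output shape and that the auxiliary loss is captured for the trainer:

```python
torch.manual_seed(0)
layer = MoELayer(d_model=64, d_ff=256, num_experts=8, top_k=2)
x = torch.randn(2, 16, 64)            # (batch, seq_len, d_model)
out = layer(x)
print(out.shape)                      # torch.Size([2, 16, 64])
print(layer._last_aux_loss.item())    # balance + z-loss from this forward
```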
MoE Trainer for MangaAssist
```python
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
class MoETrainer:
"""
Train a Mixture of Experts model for MangaAssist.
Uses LoRA on the gating network + expert selection.
"""
    def __init__(self, base_model: str = "meta-llama/Meta-Llama-3-8B"):
self.tokenizer = AutoTokenizer.from_pretrained(base_model)
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load base model and replace FFN layers with MoE
self.model = AutoModelForCausalLM.from_pretrained(
base_model,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Replace middle FFN layers (9-24) with MoE
self._replace_ffn_with_moe(layers=range(8, 24))
def _replace_ffn_with_moe(self, layers: range):
"""Replace selected FFN layers with MoE layers."""
for idx in layers:
layer = self.model.model.layers[idx]
d_model = layer.mlp.gate_proj.in_features
d_ff = layer.mlp.gate_proj.out_features
# Initialize first expert from the original FFN weights
            moe = MoELayer(d_model=d_model, d_ff=d_ff, num_experts=8, top_k=2)
            # Match the base model's dtype/device before copying weights
            moe = moe.to(
                dtype=layer.mlp.gate_proj.weight.dtype,
                device=layer.mlp.gate_proj.weight.device,
            )
# Copy original FFN weights to expert 0
with torch.no_grad():
moe.experts[0].gate_proj.weight.copy_(layer.mlp.gate_proj.weight)
moe.experts[0].up_proj.weight.copy_(layer.mlp.up_proj.weight)
moe.experts[0].down_proj.weight.copy_(layer.mlp.down_proj.weight)
# Initialize other experts with noise around expert 0
for e in range(1, 8):
for param_name in ["gate_proj", "up_proj", "down_proj"]:
src = getattr(moe.experts[0], param_name).weight
tgt = getattr(moe.experts[e], param_name).weight
tgt.copy_(src + 0.01 * torch.randn_like(src))
layer.mlp = moe
def train(
self,
train_data: list[dict],
epochs: int = 5,
lr: float = 1e-5,
balance_coeff: float = 0.01,
):
"""Train with combined language modeling and load balancing loss."""
loader = DataLoader(train_data, batch_size=2, shuffle=True)
optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
for epoch in range(epochs):
total_lm_loss = 0
total_aux_loss = 0
for batch in loader:
inputs = self.tokenizer(
batch["text"], return_tensors="pt", padding=True,
truncation=True, max_length=2048,
).to(self.model.device)
                # Mask padding out of the LM loss
                labels = inputs["input_ids"].masked_fill(
                    inputs["attention_mask"] == 0, -100
                )
                outputs = self.model(**inputs, labels=labels)
                lm_loss = outputs.loss
                # Collect auxiliary losses stashed by each MoE layer's forward
                aux_loss = torch.tensor(0.0, device=lm_loss.device)
                for layer in self.model.model.layers:
                    if hasattr(layer.mlp, "_last_aux_loss"):
                        aux_loss += layer.mlp._last_aux_loss.to(aux_loss.device)
total_loss = lm_loss + balance_coeff * aux_loss
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
total_lm_loss += lm_loss.item()
total_aux_loss += aux_loss.item()
avg_lm = total_lm_loss / len(loader)
avg_aux = total_aux_loss / len(loader)
print(
f"Epoch {epoch+1}: LM loss={avg_lm:.4f}, "
f"Aux loss={avg_aux:.4f}"
)
def analyze_specialization(self, test_data: list[dict]) -> dict:
"""Analyze which experts specialize in which query types."""
self.model.eval()
utilization = {} # {query_type: {expert_id: count}}
hooks = []
routing_log = []
def make_hook(layer_idx):
def hook_fn(module, input, output):
                # `output` is the combined MoE tensor; recompute the gating
                # decision on the captured input to log which experts fired.
                _, indices, _ = module.gating(input[0])
routing_log.append({
"layer": layer_idx,
"indices": indices.detach().cpu(),
})
return hook_fn
for idx, layer in enumerate(self.model.model.layers):
if hasattr(layer.mlp, 'gating'):
hooks.append(
layer.mlp.register_forward_hook(make_hook(idx))
)
for ex in test_data:
routing_log.clear()
query_type = ex["query_type"]
inputs = self.tokenizer(
ex["text"], return_tensors="pt",
).to(self.model.device)
with torch.no_grad():
self.model(**inputs)
# Aggregate routing decisions
if query_type not in utilization:
utilization[query_type] = {i: 0 for i in range(8)}
for entry in routing_log:
for expert_id in entry["indices"].flatten().tolist():
utilization[query_type][expert_id] += 1
for hook in hooks:
hook.remove()
# Normalize to percentages
for qt in utilization:
total = sum(utilization[qt].values())
if total > 0:
utilization[qt] = {
k: v / total for k, v in utilization[qt].items()
}
        return utilization
```
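A hypothetical end-to-end sketch of the trainer (assumes access to the gated meta-llama weights and a GPU large enough for the bf16 base model plus seven extra expert copies per MoE layer; the example records are made up):

```python
trainer = MoETrainer()
train_data = [
    {"text": "User: Is One Piece volume 107 in stock?\nAssistant: Yes, ..."},
    # ... more {"text": ...} records from MangaAssist logs
]
trainer.train(train_data, epochs=5, lr=1e-5, balance_coeff=0.01)

test_data = [
    {"text": "Recommend a seinen like Vinland Saga", "query_type": "recommendation"},
]
print(trainer.analyze_specialization(test_data)["recommendation"])
# → {expert_id: fraction of routed tokens}, ideally peaked on the recommend expert
```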
SageMaker Deployment for MoE
```python
import sagemaker
from sagemaker.huggingface import HuggingFaceModel
def deploy_moe_model(model_path: str, instance_type: str = "ml.g5.2xlarge"):
"""Deploy MoE model to SageMaker endpoint."""
# MoE requires more memory but same inference FLOPS
# g5.2xlarge: 24GB GPU, sufficient for 32B total / 8B active
hub = {
"HF_MODEL_ID": model_path,
"SM_NUM_GPUS": "1",
"MAX_INPUT_LENGTH": "2048",
"MAX_TOTAL_TOKENS": "4096",
}
model = HuggingFaceModel(
env=hub,
role=sagemaker.get_execution_role(),
image_uri=sagemaker.image_uris.retrieve(
framework="huggingface-llm",
region="us-west-2",
version="2.3.0",
instance_type=instance_type,
),
)
predictor = model.deploy(
initial_instance_count=1,
instance_type=instance_type,
endpoint_name="mangaassist-moe-endpoint",
)
    return predictor
```
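A hypothetical invocation of the deployed endpoint, using the TGI-style request schema the huggingface-llm container expects (the S3 path is a placeholder):

```python
predictor = deploy_moe_model("s3://mangaassist-models/moe-v2/")
response = predictor.predict({
    "inputs": "Is One Piece volume 107 in stock?",
    "parameters": {"max_new_tokens": 128, "temperature": 0.2},
})
print(response)
```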
Group Discussion: Key Decision Points
Decision Point 1: Dense vs MoE — When It's Worth It
Marcus (Architect): Is MoE warranted for our scale?
| Factor | Dense 8B | MoE 32B (top-2) | Verdict |
|---|---|---|---|
| Active compute per query | 8B params | 8B params | Tie |
| Total model memory | 16GB (FP16) | 64GB (FP16) → 16GB (INT4) | MoE needs quantization |
| Quality on specialized tasks | 88% avg | 93% avg | MoE +5% |
| Training complexity | Low | High (gating, balancing) | Dense easier |
| Serving complexity | Low | Medium (expert routing) | Dense easier |
Jordan (MLOps): MoE is over-engineered for MangaAssist V1. Our 6 query types can be handled by a well-fine-tuned dense model. MoE becomes valuable once we scale to 15+ distinct specializations.
Sam (PM): What about the Mixtral approach — using a pre-trained MoE model instead of building our own?
Priya (ML Engineer): Mixtral 8x7B is pre-trained with 8 experts. We could fine-tune it with LoRA instead of building MoE from scratch. The experts are already specialized in different language patterns; fine-tuning adapts them to our domain.
| Approach | Training Cost | Quality | Complexity |
|---|---|---|---|
| Build MoE from Llama 3 8B | $340 | 93% | Very High |
| Fine-tune Mixtral 8x7B with LoRA | $85 | 91% | Low |
| Fine-tune dense Llama 3 8B with LoRA | $48 | 88% | Very Low |
Resolution: For MangaAssist V1, continue with dense Llama 3 8B + LoRA (doc 04). For V2 (15+ query types, multi-language), evaluate Mixtral 8x7B + LoRA as a drop-in upgrade. Building custom MoE is only justified at V3 scale (50+ specializations, 100K+ queries/day).
Decision Point 2: Number of Experts and Top-K
Aiko (Data Scientist): Ablation study on expert count:
| Config | Quality | Balance | Memory | Specialization Score |
|---|---|---|---|---|
| 4 experts, top-1 | 89% | 0.92 | 28GB | 0.41 |
| 4 experts, top-2 | 91% | 0.95 | 28GB | 0.38 |
| 8 experts, top-1 | 90% | 0.87 | 48GB | 0.52 |
| 8 experts, top-2 | 93% | 0.91 | 48GB | 0.48 |
| 16 experts, top-2 | 93.5% | 0.82 | 80GB | 0.56 |
| 16 experts, top-4 | 94% | 0.88 | 80GB | 0.45 |
Priya (ML Engineer): 8 experts, top-2 is optimal: matches our query type count (6 types + 2 generalist), maintains good balance (0.91), and fits in 48GB (16GB with INT4 quantization). Going to 16 experts gives diminishing returns (+0.5%) at 67% more memory.
Marcus (Architect): Top-K=2 ensures redundancy: if expert 1 (product) makes a factual error, expert 5 (knowledge) can correct it through the weighted combination.
Resolution: 8 experts, top-2. This matches our 6 query types with 2 generalist experts. One expert per major domain + buffer.
Decision Point 3: Load Balancing Coefficient α
Aiko (Data Scientist): The balance coefficient $\alpha$ controls specialization vs uniformity:
| α | Expert Utilization Range | Max Quality | Dead Experts | Convergence |
|---|---|---|---|---|
| 0 (no balancing) | 1-62% | 94% | 4/8 dead | Unstable |
| 0.001 | 3-35% | 93.5% | 1/8 dead | Stable |
| 0.01 | 8-18% | 93% | 0 dead | Stable |
| 0.1 | 11-14% | 90% | 0 dead | Very stable |
| 1.0 | 12.4-12.6% | 86% | 0 dead | Uniform (defeats MoE) |
Jordan (MLOps): $\alpha = 0.01$ is the standard. No dead experts, reasonable specialization range (8-18%), and 93% quality. Lower values risk dead experts, higher values force uniformity.
Resolution: $\alpha = 0.01$ as default. Monitor expert utilization during training; if any expert drops below 5%, increase to 0.02.
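A minimal sketch of that monitoring rule (hypothetical helper; utilization fractions would come from `analyze_specialization` or training-time logging):

```python
def adjust_alpha(utilization: list[float], alpha: float,
                 floor: float = 0.05, bumped: float = 0.02) -> float:
    """Bump the balance coefficient if any expert is starved below `floor`."""
    return bumped if min(utilization) < floor else alpha

# Expert 3 starved at 3% → returns 0.02
print(adjust_alpha([0.14, 0.11, 0.03, 0.16, 0.13, 0.15, 0.12, 0.16], alpha=0.01))
```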
Research Paper References
1. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity (Fedus et al., 2021)
Key contribution: Simplified MoE by using top-1 routing (only one expert per token) instead of top-2, reducing communication costs. Introduced the concept of capacity factor and the simplified load balancing loss. Showed that MoE models can scale to 1.6 trillion parameters while maintaining training efficiency. The key insight: top-1 routing with proper capacity factor works just as well as top-2 for many tasks, while being 2× faster.
Relevance to MangaAssist: Switch Transformer's load balancing loss and capacity factor directly apply to our gating network design. While we use top-2 for quality reasons (MangaAssist needs redundancy), the monitoring and balancing techniques from this paper are essential.
2. Mixtral of Experts (Jiang et al., 2024)
Key contribution: Released a pre-trained MoE model (8 experts, top-2 routing) that achieves GPT-3.5-level performance at 5× less inference compute. Mixtral 8x7B has 46.7B total parameters but activates only 12.9B per token. The paper showed that pre-trained MoE models can be effective for downstream fine-tuning, with experts naturally specializing in different language patterns (coding, reasoning, multilingual).
Relevance to MangaAssist: Mixtral is our prime candidate for the V2 upgrade. Instead of building custom MoE, we can fine-tune Mixtral with LoRA, benefiting from pre-trained expert specialization. The 12.9B active parameters provide better quality than our current 8B dense model at comparable inference cost.
3. GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding (Lepikhin et al., 2020)
Key contribution: Introduced expert parallelism for distributed training of MoE models across multiple devices. Key innovations: (1) auxiliary loss for load balancing, (2) random routing for the second-choice expert to improve exploration, (3) group-level top-2 routing that ensures each group of tokens uses at most $C$ capacity per expert. Scaled to 600B parameters on 2048 TPU chips.
Relevance to MangaAssist: GShard's capacity management techniques matter for production deployment. When we exceed single-GPU memory (V3 scale), expert parallelism from GShard enables splitting experts across devices, keeping inference latency constant while scaling model capacity.
Production Results
MoE Evaluation on MangaAssist Test Set (V2 Projection)
| Metric | Dense Llama 3 8B + LoRA | Mixtral 8x7B + LoRA | Custom MoE (8 experts) |
|---|---|---|---|
| Product accuracy | 88% | 92% | 94% |
| Recommendation NDCG | 0.76 | 0.82 | 0.84 |
| Order tracking accuracy | 91% | 93% | 95% |
| Complaint CSAT | 3.9/5 | 4.2/5 | 4.4/5 |
| Knowledge accuracy | 85% | 90% | 92% |
| Overall weighted | 88% | 91% | 93% |
| Inference latency (p50) | 480ms | 520ms | 510ms |
| Memory (INT4) | 4.1GB | 12GB | 16GB |
| Monthly cost (SageMaker) | $170 | $340 | $510 |
Cost
| Item | Dense LoRA | Mixtral LoRA | Custom MoE |
|---|---|---|---|
| Training cost | $48 | $85 | $340 |
| Monthly inference | $170 | $340 | $510 |
| Quality improvement | baseline | +3% | +5% |
| CPQ (cost per quality point) | - | $56/point | $68/point |
Verdict: Mixtral + LoRA gives the best CPQ ($56/point) and is the V2 recommendation. Custom MoE is only justified if quality requirements exceed what Mixtral can achieve.
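For transparency, CPQ here is the marginal monthly inference cost over the dense baseline, divided by the quality points gained (the table reports Mixtral's $56.7 as $56):

```python
baseline_cost = 170                      # dense LoRA monthly inference ($)
options = {"Mixtral LoRA": (340, 3), "Custom MoE": (510, 5)}
for name, (monthly, quality_pts) in options.items():
    print(f"{name}: ${(monthly - baseline_cost) / quality_pts:.0f}/point")
# Mixtral LoRA: $57/point (≈ the table's $56), Custom MoE: $68/point
```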