
15. Mixture of Experts (MoE) — Specialized Sub-Model Routing for MangaAssist

Problem Statement and MangaAssist Context

MangaAssist serves diverse query types — product lookup, recommendation, order tracking, complaint handling, manga knowledge Q&A — through a single monolithic LLM. The problem: a single model must be expert in everything, leading to:

  1. Diluted expertise (8% quality drop on tail queries): the model is mediocre at everything rather than excellent at specific domains
  2. Wasted compute (the FFN layers, roughly 70% of parameters, activate for every query): a product price lookup shouldn't require the full 8B-parameter model
  3. No specialization path: improving manga recommendations may degrade order-tracking quality

Mixture of Experts (MoE) routes each query to specialized sub-networks, activating only the relevant parameters. This gives us the capacity of a large model with the inference cost of a small one.

Before vs After MoE

| Metric | Dense 8B | MoE (8 experts, top-2) | Improvement |
| --- | --- | --- | --- |
| Total parameters | 8B | 32B | 4× capacity |
| Active parameters per query | 8B | 8B | Same cost |
| Product query accuracy | 88% | 94% | +6% |
| Recommendation quality (NDCG) | 0.76 | 0.84 | +10.5% |
| Complaint handling CSAT | 3.9/5 | 4.4/5 | +12.8% |
| Inference latency | 480ms | 510ms | +6% (routing overhead) |
| Training cost | $200 | $340 | +70% (one-time) |

Mathematical Foundations

Gating Network (Shazeer et al., 2017)

At the core of MoE is the gating network $G(x)$, which decides which experts to activate for each input token $x$:

$$G(x) = \text{TopK}\left(\text{softmax}(W_g \cdot x + \epsilon)\right)$$

where:

- $W_g \in \mathbb{R}^{E \times d}$ is the gating weight matrix ($E$ = number of experts, $d$ = hidden dimension)
- $\epsilon \sim \mathcal{N}(0, \frac{1}{E^2})$ is noise for exploration during training
- $\text{TopK}(\cdot)$ keeps only the top-$K$ probabilities and renormalizes them to sum to 1 (equivalently, the non-selected logits are set to $-\infty$ before the softmax)

The output of the MoE layer:

$$\text{MoE}(x) = \sum_{i=1}^{E} G(x)_i \cdot E_i(x)$$

where $E_i(x)$ is expert $i$'s output and $G(x)_i$ is the routing weight (zero for non-selected experts).

For MangaAssist: $E = 8$ experts (product, recommendation, order, complaint, knowledge, chitchat, multi-turn, fallback), $K = 2$ active per token.
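A worked example of the gate at these settings ($E = 8$, $K = 2$). The logits are illustrative values, the training noise $\epsilon$ is omitted, and `top_k_gate` is a hypothetical helper, not part of any library:

```python
import numpy as np

def top_k_gate(logits: np.ndarray, k: int = 2):
    """Softmax over expert logits, keep the top-k, renormalize to sum to 1."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top = np.argsort(probs)[::-1][:k]         # indices of the k largest probabilities
    weights = probs[top] / probs[top].sum()   # renormalize over the survivors
    return top, weights

# Illustrative logits for the 8 MangaAssist experts
logits = np.array([2.0, 0.1, -1.0, -2.0, 2.2, -0.5, 0.3, -0.4])
experts, w = top_k_gate(logits)
# experts: [4, 0] — knowledge and product share the token, weights ≈ [0.55, 0.45]
```

All other experts receive weight zero, so only the two selected FFNs run for this token.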

Expert Architecture

Each expert is a standard FFN sub-layer, shown here in a simplified two-matrix form (the implementation below uses the SwiGLU variant, which adds a third, gated projection):

$$E_i(x) = W_i^{\text{up}} \cdot \text{SiLU}(W_i^{\text{gate}} \cdot x)$$

where $W_i^{\text{gate}} \in \mathbb{R}^{d_{\text{ff}} \times d}$ and $W_i^{\text{up}} \in \mathbb{R}^{d \times d_{\text{ff}}}$ (biases omitted).

In a standard Llama 3 8B, the FFN has $d = 4096$, $d_{\text{ff}} = 14336$. In MoE, each expert has the same structure, but we have 8 of them:

| Component | Dense | MoE (8 experts) |
| --- | --- | --- |
| Attention layers | 4096 × 4096 × 4 = 67M per layer | Same (shared) |
| FFN per expert | 4096 × 14336 × 2 = 117M | 117M × 8 = 939M |
| Active FFN per token | 117M | 117M × 2 = 234M (top-2) |
| Total FFN params | 117M × 32 layers = 3.75B | 939M × 32 = 30B |

Load Balancing Loss

Without regularization, the gating network collapses: it routes all tokens to 1-2 experts, leaving the rest unused. The load balancing loss prevents this:

$$\mathcal{L}_{\text{balance}} = \alpha \cdot E \cdot \sum_{i=1}^{E} f_i \cdot p_i$$

where:

- $f_i = \frac{\text{number of tokens routed to expert } i}{\text{total tokens in batch}}$ (fraction of tokens)
- $p_i = \frac{1}{T}\sum_{t=1}^{T} G(x_t)_i$ (average routing probability for expert $i$)
- $\alpha$ is the balancing coefficient (typically 0.01)

Intuition: The sum $\sum_i f_i \cdot p_i$ is minimized when tokens are evenly distributed. If expert $i$ receives too many tokens ($f_i$ high) AND has a high routing probability ($p_i$ high), its term dominates the loss. This pushes the gate to spread tokens more evenly.

Ideal state: $f_i = \frac{1}{E}$ for all $i$ (uniform distribution). The loss achieves its minimum when $f_i = p_i = \frac{1}{E}$.
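The minimum can be checked numerically. This sketch assumes $f_i = p_i$ for simplicity and compares the uniform distribution against a collapsed one (the 62%/31% values mirror the load-balancing dynamics diagram later in this section); `balance_loss` is a hypothetical helper:

```python
import numpy as np

def balance_loss(f: np.ndarray, p: np.ndarray, alpha: float = 0.01) -> float:
    """L_balance = alpha * E * sum_i f_i * p_i."""
    E = len(f)
    return alpha * E * float((f * p).sum())

uniform = np.full(8, 1 / 8)
collapsed = np.array([0.62, 0.01, 0.31, 0.01, 0.01, 0.01, 0.02, 0.01])

loss_uniform = balance_loss(uniform, uniform)        # exactly alpha = 0.01 at the minimum
loss_collapsed = balance_loss(collapsed, collapsed)  # ≈ 0.0385, roughly 3.9× larger
```

The collapsed routing pays nearly four times the penalty, so gradient descent on the combined loss pulls the gate back toward uniform.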

Capacity Factor and Token Dropping

Each expert has a capacity factor $C$ that limits how many tokens it can process:

$$\text{capacity}_i = C \cdot \frac{T}{E}$$

where $T$ is the total tokens in the batch. Tokens exceeding capacity are either:

  1. Dropped (Switch Transformer): skipped entirely, passed through a residual connection
  2. Rerouted (GShard): sent to the next-best expert

Typical settings:

- $C = 1.0$: each expert handles exactly its fair share (risk of dropping)
- $C = 1.25$: 25% buffer (recommended for training)
- $C = 2.0$: generous buffer (fewer drops, more compute)
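Plugging in an illustrative MangaAssist batch shape ($T = 512$ tokens, $E = 8$ experts) shows how these settings translate into per-expert token slots; `expert_capacity` is a hypothetical helper:

```python
def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float) -> int:
    """capacity_i = floor(C * T / E) token slots per expert."""
    return int(capacity_factor * num_tokens / num_experts)

# T = 512 tokens, E = 8 experts
for c in (1.0, 1.25, 2.0):
    print(c, expert_capacity(512, 8, c))  # 64, 80, and 128 slots respectively
```

Any token routed to an expert whose slots are full is dropped or rerouted per the two strategies above.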

Router Z-Loss (Zoph et al., 2022)

Large routing logits cause numerical instability. The router z-loss penalizes large logits:

$$\mathcal{L}_{z} = \frac{1}{T} \sum_{t=1}^{T} \left(\log \sum_{i=1}^{E} e^{z_i^{(t)}}\right)^2$$

where $z^{(t)} = W_g \cdot x_t \in \mathbb{R}^E$ are the raw logits before softmax. This keeps logits in a numerically stable range, improving training convergence.
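A small numeric check of the formula, using a numerically stable log-sum-exp; the logit values and the `router_z_loss` helper are illustrative:

```python
import numpy as np

def router_z_loss(logits: np.ndarray) -> float:
    """Mean over tokens of (log-sum-exp of router logits)^2; logits shape (T, E)."""
    m = logits.max(axis=-1, keepdims=True)
    lse = m.squeeze(-1) + np.log(np.exp(logits - m).sum(axis=-1))
    return float((lse ** 2).mean())

calm = np.zeros((1, 8))                                          # all logits at zero
spiky = np.array([[12.0, -3.0, 0.5, 0.0, -1.0, 2.0, 0.0, -4.0]])  # one huge logit

print(router_z_loss(calm))   # exactly (log 8)^2 ≈ 4.32
print(router_z_loss(spiky))  # much larger — the spike is penalized quadratically
```

Because the penalty grows with the square of the log-sum-exp, a single runaway logit dominates the loss and is pulled back quickly.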

Expert Specialization Analysis

To verify experts specialize, we measure the expert utilization matrix $U \in \mathbb{R}^{E \times Q}$ where $Q$ is the number of query types:

$$U_{i,j} = \frac{\text{tokens from query type } j \text{ routed to expert } i}{\text{total tokens from query type } j}$$

Ideal specialization: each row of $U$ shows peaks for specific query types, not uniform activation. We measure this with the specialization score:

$$\text{Spec}_i = 1 - \frac{H(U_i)}{\log Q}$$

where $H(U_i) = -\sum_j U_{i,j} \log U_{i,j}$ is the entropy of expert $i$'s utilization. Spec = 1 means perfect specialization (only one query type), Spec = 0 means uniform (no specialization).
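The score is easy to compute per expert. This sketch uses the one-hot and uniform edge cases plus Expert 4's row from the heatmap below; `specialization_score` is a hypothetical helper:

```python
import numpy as np

def specialization_score(u_row: np.ndarray) -> float:
    """Spec_i = 1 - H(U_i) / log(Q), treating 0·log 0 as 0."""
    nz = u_row[u_row > 0]
    entropy = float(-(nz * np.log(nz)).sum())
    return 1.0 - entropy / np.log(len(u_row))

one_hot = np.array([1.0, 0, 0, 0, 0, 0])                   # perfect specialization
uniform = np.full(6, 1 / 6)                                # no specialization
expert4 = np.array([0.08, 0.06, 0.12, 0.52, 0.10, 0.12])   # complaint expert's row

print(specialization_score(one_hot))  # 1.0
print(specialization_score(uniform))  # ≈ 0.0
print(specialization_score(expert4))  # ≈ 0.19 — peaked on complaints, but not one-hot
```

A realistic, healthy expert lands between the two extremes: clearly peaked on its domain while still handling some overflow traffic.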


Model Internals — Layer-by-Layer Diagrams

Token Routing Through Experts

graph TB
    subgraph "MoE Layer Processing"
        INPUT["Input tokens from attention layer<br>x ∈ ℝ^{T×d}, T=512, d=4096"]

        GATE["Gating Network G(x)<br>W_g ∈ ℝ^{8×4096}<br>Computes: softmax(W_g·x + ε)<br>Selects top-2 experts per token"]

        subgraph "8 Expert FFNs (only 2 active per token)"
            E1["Expert 1: Product<br>W_gate ∈ ℝ^{14336×4096}<br>W_up ∈ ℝ^{4096×14336}<br>117M params"]
            E2["Expert 2: Recommend<br>117M params"]
            E3["Expert 3: Order<br>117M params"]
            E4["Expert 4: Complaint<br>117M params"]
            E5["Expert 5: Knowledge<br>117M params"]
            E6["Expert 6: Chitchat<br>117M params"]
            E7["Expert 7: Multi-turn<br>117M params"]
            E8["Expert 8: Fallback<br>117M params"]
        end

        COMBINE["Weighted combination:<br>MoE(x) = G(x)₁·E₁(x) + G(x)₅·E₅(x)<br>(example: product + knowledge)"]

        OUTPUT["Output: same shape ℝ^{T×d}<br>Only 234M params activated<br>(2× expert = 2×117M)"]

        INPUT --> GATE
        GATE -->|"w=0.65"| E1
        GATE -->|"w=0"| E2
        GATE -->|"w=0"| E3
        GATE -->|"w=0"| E4
        GATE -->|"w=0.35"| E5
        GATE -->|"w=0"| E6
        GATE -->|"w=0"| E7
        GATE -->|"w=0"| E8
        E1 --> COMBINE
        E5 --> COMBINE
        COMBINE --> OUTPUT
    end

    style E1 fill:#c8e6c9
    style E5 fill:#c8e6c9
    style E2 fill:#eeeeee
    style E3 fill:#eeeeee
    style E4 fill:#eeeeee
    style E6 fill:#eeeeee
    style E7 fill:#eeeeee
    style E8 fill:#eeeeee

Expert Specialization Heatmap

graph TB
    subgraph "Expert Utilization Matrix U (after training)"
        direction TB
        HEADER["Query Types →<br>Product | Recommend | Order | Complaint | Knowledge | Chitchat"]

        E1_ROW["Expert 1: 0.42 | 0.12 | 0.08 | 0.05 | 0.28 | 0.05<br>Specialization: Product + Knowledge lookups"]
        E2_ROW["Expert 2: 0.10 | 0.45 | 0.05 | 0.08 | 0.22 | 0.10<br>Specialization: Recommendations"]
        E3_ROW["Expert 3: 0.05 | 0.05 | 0.48 | 0.15 | 0.07 | 0.20<br>Specialization: Order tracking + general"]
        E4_ROW["Expert 4: 0.08 | 0.06 | 0.12 | 0.52 | 0.10 | 0.12<br>Specialization: Complaint handling"]
        E5_ROW["Expert 5: 0.15 | 0.18 | 0.05 | 0.05 | 0.47 | 0.10<br>Specialization: Manga knowledge"]
        E6_ROW["Expert 6: 0.07 | 0.10 | 0.08 | 0.12 | 0.08 | 0.55<br>Specialization: Chitchat + conversational"]
        E7_ROW["Expert 7: 0.15 | 0.20 | 0.18 | 0.15 | 0.17 | 0.15<br>Specialization: Multi-turn context (generalist)"]
        E8_ROW["Expert 8: 0.12 | 0.08 | 0.15 | 0.08 | 0.10 | 0.47<br>Specialization: Fallback / edge cases"]

        HEADER --> E1_ROW --> E2_ROW --> E3_ROW --> E4_ROW --> E5_ROW --> E6_ROW --> E7_ROW --> E8_ROW
    end

    style E1_ROW fill:#c8e6c9
    style E2_ROW fill:#c8e6c9
    style E3_ROW fill:#c8e6c9
    style E4_ROW fill:#bbdefb
    style E5_ROW fill:#c8e6c9
    style E6_ROW fill:#fff9c4
    style E7_ROW fill:#eeeeee
    style E8_ROW fill:#fff9c4

Load Balancing Dynamics

graph LR
    subgraph "Without Load Balancing Loss"
        WO_INIT["Epoch 0: Uniform<br>12.5% per expert"]
        WO_MID["Epoch 5: Collapse begins<br>Expert 1: 35%<br>Expert 3: 28%<br>Others: 5-8%"]
        WO_END["Epoch 20: Collapsed<br>Expert 1: 62%<br>Expert 3: 31%<br>Others: ~1% (dead)"]

        WO_INIT --> WO_MID --> WO_END
    end

    subgraph "With Load Balancing Loss (α=0.01)"
        W_INIT["Epoch 0: Uniform<br>12.5% per expert"]
        W_MID["Epoch 5: Slight specialization<br>Range: 9-18%"]
        W_END["Epoch 20: Balanced<br>Range: 10-16%<br>All experts active"]

        W_INIT --> W_MID --> W_END
    end

    style WO_END fill:#ffcdd2
    style W_END fill:#c8e6c9

MoE Layer Position in Transformer

graph TB
    subgraph "Llama 3 8B MoE Architecture (32 layers)"
        INPUT["Input embeddings"]

        L1["Layer 1-8: Dense FFN<br>(shared low-level features)<br>No MoE — pattern extraction"]

        L2["Layer 9-24: MoE FFN<br>8 experts per layer, top-2 routing<br>16 MoE layers × 939M = 15B params<br>Active: 16 × 234M = 3.7B"]

        L3["Layer 25-32: Dense FFN<br>(shared output projection)<br>No MoE — output generation"]

        HEAD["LM Head → vocabulary logits"]

        INPUT --> L1 --> L2 --> L3 --> HEAD
    end

    subgraph "Why This Split?"
        R1["Early layers learn universal<br>features (tokenization, syntax)<br>→ No need for specialization"]
        R2["Middle layers learn<br>task-specific representations<br>→ Maximum benefit from routing"]
        R3["Late layers unify<br>representations for generation<br>→ Shared output space"]
    end

    style L2 fill:#c8e6c9
    style L1 fill:#e3f2fd
    style L3 fill:#e3f2fd

Routing Decision Flow

sequenceDiagram
    participant T as Token "One Piece volume 107"
    participant G as Gating Network
    participant LB as Load Balancer
    participant E1 as Expert 1 (Product)
    participant E5 as Expert 5 (Knowledge)
    participant C as Combiner

    T->>G: x = hidden state (d=4096)
    G->>G: z = W_g · x + ε (8 logits)
    G->>G: p = softmax(z) = [0.38, 0.05, 0.02, 0.01, 0.42, 0.03, 0.06, 0.03]
    G->>G: top-2 = {Expert 5: 0.42, Expert 1: 0.38}

    G->>LB: Check capacity for E1, E5
    LB->>LB: E1: 58/64 slots used → OK
    LB->>LB: E5: 61/64 slots used → OK

    G->>E1: Route with weight 0.38/(0.38+0.42) = 0.475
    G->>E5: Route with weight 0.42/(0.38+0.42) = 0.525

    E1->>E1: FFN: SiLU(W_gate · x) ⊙ (W_up · x)
    E5->>E5: FFN: SiLU(W_gate · x) ⊙ (W_up · x)

    E1->>C: Product expert output
    E5->>C: Knowledge expert output
    C->>C: 0.475 × E1(x) + 0.525 × E5(x)
    C->>T: Combined MoE output

Implementation Deep-Dive

MoE Layer Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class ExpertFFN(nn.Module):
    """Single expert feed-forward network (SwiGLU architecture)."""

    def __init__(self, d_model: int = 4096, d_ff: int = 14336):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TopKGating(nn.Module):
    """Top-K gating network with load balancing loss."""

    def __init__(
        self, d_model: int, num_experts: int, top_k: int = 2,
        noise_std: float = 0.1,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(
        self, x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            gates: (batch, seq_len, num_experts) — routing weights
            indices: (batch, seq_len, top_k) — selected expert indices
            balance_loss: scalar load balancing loss
        """
        # Compute logits
        logits = self.gate(x)  # (B, T, E)

        # Add noise during training for exploration
        if self.training:
            noise = torch.randn_like(logits) * self.noise_std
            logits = logits + noise

        # Top-K selection
        top_k_logits, indices = torch.topk(logits, self.top_k, dim=-1)
        gates = F.softmax(top_k_logits, dim=-1)  # (B, T, K)

        # Load balancing loss
        # f_i: fraction of tokens routed to expert i
        # p_i: average routing probability for expert i
        routing_probs = F.softmax(logits, dim=-1)  # (B, T, E)
        p = routing_probs.mean(dim=[0, 1])  # (E,)

        # One-hot for selected experts
        mask = torch.zeros_like(routing_probs)
        mask.scatter_(-1, indices, 1.0)
        f = mask.mean(dim=[0, 1])  # (E,)

        balance_loss = self.num_experts * (f * p).sum()

        # Router z-loss
        z_loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()

        return gates, indices, balance_loss + 0.001 * z_loss


class MoELayer(nn.Module):
    """Mixture of Experts layer replacing standard FFN."""

    def __init__(
        self,
        d_model: int = 4096,
        d_ff: int = 14336,
        num_experts: int = 8,
        top_k: int = 2,
        capacity_factor: float = 1.25,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        self.gating = TopKGating(d_model, num_experts, top_k)
        self.experts = nn.ModuleList([
            ExpertFFN(d_model, d_ff) for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)

        The auxiliary load balancing loss is stored in `self._last_aux_loss`
        so the layer stays drop-in compatible with a dense FFN (it returns a
        single tensor); the trainer collects the stored losses after each
        forward pass.
        """
        B, T, D = x.shape
        gates, indices, aux_loss = self.gating(x)
        self._last_aux_loss = aux_loss

        # Compute capacity
        capacity = int(self.capacity_factor * T / self.num_experts)

        # Gather expert outputs
        output = torch.zeros_like(x)
        expert_counts = torch.zeros(self.num_experts, device=x.device)

        for k in range(self.top_k):
            expert_idx = indices[:, :, k]  # (B, T)
            gate_weight = gates[:, :, k]   # (B, T)

            for e_idx in range(self.num_experts):
                mask = expert_idx == e_idx  # (B, T)
                if not mask.any():
                    continue

                # Capacity check
                count = mask.sum().item()
                expert_counts[e_idx] += count

                if count > capacity * B:
                    # Drop excess tokens (keep first `capacity * B`)
                    flat_mask = mask.reshape(-1)
                    indices_true = flat_mask.nonzero().squeeze(-1)
                    drop = indices_true[capacity * B:]
                    flat_mask[drop] = False
                    mask = flat_mask.reshape(B, T)

                # Route to expert
                expert_input = x[mask]  # (num_tokens, D)
                if expert_input.numel() > 0:
                    expert_output = self.experts[e_idx](expert_input)
                    output[mask] += gate_weight[mask].unsqueeze(-1) * expert_output

        return output
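The dispatch loop above can be exercised end-to-end at toy scale. This self-contained sketch (all sizes illustrative, no capacity limit or balancing loss, single-matrix gate as in the math section) reproduces the top-2 routing and weighted combination and confirms the output shape matches the input:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes for illustration only — not the production 4096/14336/8-expert config
B, T, D, E, K = 2, 16, 32, 4, 2

gate = nn.Linear(D, E, bias=False)
experts = nn.ModuleList(
    nn.Sequential(nn.Linear(D, 64, bias=False), nn.SiLU(), nn.Linear(64, D, bias=False))
    for _ in range(E)
)

x = torch.randn(B, T, D)
top_vals, top_idx = gate(x).topk(K, dim=-1)  # (B, T, K)
w = F.softmax(top_vals, dim=-1)              # renormalized top-K routing weights

out = torch.zeros_like(x)
for k in range(K):
    for e in range(E):
        m = top_idx[..., k] == e             # tokens whose k-th choice is expert e
        if m.any():
            out[m] += w[..., k][m].unsqueeze(-1) * experts[e](x[m])

print(out.shape)  # same shape as the input: (B, T, D)
```

Each token touches exactly two of the four expert MLPs, and the per-token routing weights always sum to 1.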

MoE Trainer for MangaAssist

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader


class MoETrainer:
    """
    Train a Mixture of Experts model for MangaAssist.
    Uses LoRA on the gating network + expert selection.
    """

    def __init__(self, base_model: str = "meta-llama/Meta-Llama-3-8B"):
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load base model and replace FFN layers with MoE
        self.model = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Replace middle FFN layers (9-24) with MoE
        self._replace_ffn_with_moe(layers=range(8, 24))

    def _replace_ffn_with_moe(self, layers: range):
        """Replace selected FFN layers with MoE layers."""
        for idx in layers:
            layer = self.model.model.layers[idx]
            d_model = layer.mlp.gate_proj.in_features
            d_ff = layer.mlp.gate_proj.out_features

            # Initialize first expert from the original FFN weights
            moe = MoELayer(d_model=d_model, d_ff=d_ff, num_experts=8, top_k=2)

            # Copy original FFN weights to expert 0
            with torch.no_grad():
                moe.experts[0].gate_proj.weight.copy_(layer.mlp.gate_proj.weight)
                moe.experts[0].up_proj.weight.copy_(layer.mlp.up_proj.weight)
                moe.experts[0].down_proj.weight.copy_(layer.mlp.down_proj.weight)

                # Initialize other experts with noise around expert 0
                for e in range(1, 8):
                    for param_name in ["gate_proj", "up_proj", "down_proj"]:
                        src = getattr(moe.experts[0], param_name).weight
                        tgt = getattr(moe.experts[e], param_name).weight
                        tgt.copy_(src + 0.01 * torch.randn_like(src))

            layer.mlp = moe

    def train(
        self,
        train_data: list[dict],
        epochs: int = 5,
        lr: float = 1e-5,
        balance_coeff: float = 0.01,
    ):
        """Train with combined language modeling and load balancing loss."""
        loader = DataLoader(train_data, batch_size=2, shuffle=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)

        for epoch in range(epochs):
            total_lm_loss = 0
            total_aux_loss = 0

            for batch in loader:
                inputs = self.tokenizer(
                    batch["text"], return_tensors="pt", padding=True,
                    truncation=True, max_length=2048,
                ).to(self.model.device)

                # Mask padding so the LM loss ignores pad tokens
                labels = inputs["input_ids"].clone()
                labels[inputs["attention_mask"] == 0] = -100
                outputs = self.model(**inputs, labels=labels)
                lm_loss = outputs.loss

                # Collect auxiliary losses from all MoE layers
                aux_loss = torch.tensor(0.0, device=self.model.device)
                for layer in self.model.model.layers:
                    if hasattr(layer.mlp, 'gating'):
                        # MoE layer stores aux_loss from last forward
                        if hasattr(layer.mlp, '_last_aux_loss'):
                            aux_loss += layer.mlp._last_aux_loss

                total_loss = lm_loss + balance_coeff * aux_loss

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

                total_lm_loss += lm_loss.item()
                total_aux_loss += aux_loss.item()

            avg_lm = total_lm_loss / len(loader)
            avg_aux = total_aux_loss / len(loader)
            print(
                f"Epoch {epoch+1}: LM loss={avg_lm:.4f}, "
                f"Aux loss={avg_aux:.4f}"
            )

    def analyze_specialization(self, test_data: list[dict]) -> dict:
        """Analyze which experts specialize in which query types."""
        self.model.eval()
        utilization = {}  # {query_type: {expert_id: count}}

        hooks = []
        routing_log = []

        def make_hook(layer_idx):
            def hook_fn(module, input, output):
                # Re-run the gate on the layer input to log routing decisions
                _, indices, _ = module.gating(input[0])
                routing_log.append({
                    "layer": layer_idx,
                    "indices": indices.detach().cpu(),
                })
            return hook_fn

        for idx, layer in enumerate(self.model.model.layers):
            if hasattr(layer.mlp, 'gating'):
                hooks.append(
                    layer.mlp.register_forward_hook(make_hook(idx))
                )

        for ex in test_data:
            routing_log.clear()
            query_type = ex["query_type"]

            inputs = self.tokenizer(
                ex["text"], return_tensors="pt",
            ).to(self.model.device)

            with torch.no_grad():
                self.model(**inputs)

            # Aggregate routing decisions
            if query_type not in utilization:
                utilization[query_type] = {i: 0 for i in range(8)}

            for entry in routing_log:
                for expert_id in entry["indices"].flatten().tolist():
                    utilization[query_type][expert_id] += 1

        for hook in hooks:
            hook.remove()

        # Normalize to percentages
        for qt in utilization:
            total = sum(utilization[qt].values())
            if total > 0:
                utilization[qt] = {
                    k: v / total for k, v in utilization[qt].items()
                }

        return utilization

SageMaker Deployment for MoE

import sagemaker
from sagemaker.huggingface import HuggingFaceModel


def deploy_moe_model(model_path: str, instance_type: str = "ml.g5.2xlarge"):
    """Deploy MoE model to SageMaker endpoint."""

    # MoE requires more memory but same inference FLOPS
    # g5.2xlarge: 24GB GPU, sufficient for 32B total / 8B active
    hub = {
        "HF_MODEL_ID": model_path,
        "SM_NUM_GPUS": "1",
        "MAX_INPUT_LENGTH": "2048",
        "MAX_TOTAL_TOKENS": "4096",
    }

    model = HuggingFaceModel(
        env=hub,
        role=sagemaker.get_execution_role(),
        image_uri=sagemaker.image_uris.retrieve(
            framework="huggingface-llm",
            region="us-west-2",
            version="2.3.0",
            instance_type=instance_type,
        ),
    )

    predictor = model.deploy(
        initial_instance_count=1,
        instance_type=instance_type,
        endpoint_name="mangaassist-moe-endpoint",
    )

    return predictor

Group Discussion: Key Decision Points

Decision Point 1: Dense vs MoE — When It's Worth It

Marcus (Architect): Is MoE warranted for our scale?

| Factor | Dense 8B | MoE 32B (top-2) | Verdict |
| --- | --- | --- | --- |
| Active compute per query | 8B params | 8B params | Tie |
| Total model memory | 16GB (FP16) | 64GB (FP16) → 16GB (INT4) | MoE needs quantization |
| Quality on specialized tasks | 88% avg | 93% avg | MoE +5% |
| Training complexity | Low | High (gating, balancing) | Dense easier |
| Serving complexity | Low | Medium (expert routing) | Dense easier |

Jordan (MLOps): MoE is over-engineered for MangaAssist V1. Our 6 query types can be handled by a well-fine-tuned dense model. MoE becomes valuable when we scale to 15+ distinct specializations.

Sam (PM): What about the Mixtral approach — using a pre-trained MoE model instead of building our own?

Priya (ML Engineer): Mixtral 8x7B is pre-trained with 8 experts. We could fine-tune it with LoRA instead of building MoE from scratch. The experts are already specialized in different language patterns; fine-tuning adapts them to our domain.

| Approach | Training Cost | Quality | Complexity |
| --- | --- | --- | --- |
| Build MoE from Llama 3 8B | $340 | 93% | Very High |
| Fine-tune Mixtral 8x7B with LoRA | $85 | 91% | Low |
| Fine-tune dense Llama 3 8B with LoRA | $48 | 88% | Very Low |

Resolution: For MangaAssist V1, continue with dense Llama 3 8B + LoRA (doc 04). For V2 (15+ query types, multi-language), evaluate Mixtral 8x7B + LoRA as a drop-in upgrade. Building custom MoE is only justified at V3 scale (50+ specializations, 100K+ queries/day).

Decision Point 2: Number of Experts and Top-K

Aiko (Data Scientist): Ablation study on expert count:

| Config | Quality | Balance | Memory | Specialization Score |
| --- | --- | --- | --- | --- |
| 4 experts, top-1 | 89% | 0.92 | 28GB | 0.41 |
| 4 experts, top-2 | 91% | 0.95 | 28GB | 0.38 |
| 8 experts, top-1 | 90% | 0.87 | 48GB | 0.52 |
| 8 experts, top-2 | 93% | 0.91 | 48GB | 0.48 |
| 16 experts, top-2 | 93.5% | 0.82 | 80GB | 0.56 |
| 16 experts, top-4 | 94% | 0.88 | 80GB | 0.45 |

Priya (ML Engineer): 8 experts, top-2 is optimal: matches our query type count (6 types + 2 generalist), maintains good balance (0.91), and fits in 48GB (16GB with INT4 quantization). Going to 16 experts gives diminishing returns (+0.5%) at 67% more memory.

Marcus (Architect): Top-K=2 ensures redundancy: if expert 1 (product) makes a factual error, expert 5 (knowledge) can correct it through the weighted combination.

Resolution: 8 experts, top-2. This matches our 6 query types with 2 generalist experts. One expert per major domain + buffer.

Decision Point 3: Load Balancing Coefficient α

Aiko (Data Scientist): The balance coefficient $\alpha$ controls specialization vs uniformity:

| α | Expert Utilization Range | Max Quality | Dead Experts | Convergence |
| --- | --- | --- | --- | --- |
| 0 (no balancing) | 1-62% | 94% | 4/8 dead | Unstable |
| 0.001 | 3-35% | 93.5% | 1/8 dead | Stable |
| 0.01 | 8-18% | 93% | 0 dead | Stable |
| 0.1 | 11-14% | 90% | 0 dead | Very stable |
| 1.0 | 12.4-12.6% | 86% | 0 dead | Uniform (defeats MoE) |

Jordan (MLOps): $\alpha = 0.01$ is the standard. No dead experts, reasonable specialization range (8-18%), and 93% quality. Lower values risk dead experts, higher values force uniformity.

Resolution: $\alpha = 0.01$ as default. Monitor expert utilization during training; if any expert drops below 5%, increase to 0.02.


Research Paper References

1. Switch Transformers: Scaling to Trillion Parameter Models (Fedus et al., 2021)

Key contribution: Simplified MoE by using top-1 routing (only one expert per token) instead of top-2, reducing communication costs. Introduced the concept of capacity factor and the simplified load balancing loss. Showed that MoE models can scale to 1.6 trillion parameters while maintaining training efficiency. The key insight: top-1 routing with proper capacity factor works just as well as top-2 for many tasks, while being 2× faster.

Relevance to MangaAssist: Switch Transformer's load balancing loss and capacity factor directly apply to our gating network design. While we use top-2 for quality reasons (MangaAssist needs redundancy), the monitoring and balancing techniques from this paper are essential.

2. Mixtral of Experts (Jiang et al., 2024)

Key contribution: Released a pre-trained MoE model (8 experts, top-2 routing) that achieves GPT-3.5-level performance at 5× less inference compute. Mixtral 8x7B has 46.7B total parameters but activates only 12.9B per token. The paper showed that pre-trained MoE models can be effective for downstream fine-tuning, with experts naturally specializing in different language patterns (coding, reasoning, multilingual).

Relevance to MangaAssist: Mixtral is our prime candidate for the V2 upgrade. Instead of building custom MoE, we can fine-tune Mixtral with LoRA, benefiting from pre-trained expert specialization. The 12.9B active parameters provide better quality than our current 8B dense model at comparable inference cost.

3. GShard: Scaling Giant Models with Conditional Computation (Lepikhin et al., 2020)

Key contribution: Introduced expert parallelism for distributed training of MoE models across multiple devices. Key innovations: (1) auxiliary loss for load balancing, (2) random routing for the second-choice expert to improve exploration, (3) group-level top-2 routing that ensures each group of tokens uses at most $C$ capacity per expert. Scaled to 600B parameters on 2048 TPU chips.

Relevance to MangaAssist: GShard's capacity management techniques matter for production deployment. When we exceed single-GPU memory (V3 scale), expert parallelism from GShard enables splitting experts across devices, keeping inference latency constant while scaling model capacity.


Production Results

MoE Evaluation on MangaAssist Test Set (V2 Projection)

| Metric | Dense Llama 3 8B + LoRA | Mixtral 8x7B + LoRA | Custom MoE (8 experts) |
| --- | --- | --- | --- |
| Product accuracy | 88% | 92% | 94% |
| Recommendation NDCG | 0.76 | 0.82 | 0.84 |
| Order tracking accuracy | 91% | 93% | 95% |
| Complaint CSAT | 3.9/5 | 4.2/5 | 4.4/5 |
| Knowledge accuracy | 85% | 90% | 92% |
| Overall weighted | 88% | 91% | 93% |
| Inference latency (p50) | 480ms | 520ms | 510ms |
| Memory (INT4) | 4.1GB | 12GB | 16GB |
| Monthly cost (SageMaker) | $170 | $340 | $510 |

Cost

| Item | Dense LoRA | Mixtral LoRA | Custom MoE |
| --- | --- | --- | --- |
| Training cost | $48 | $85 | $340 |
| Monthly inference | $170 | $340 | $510 |
| Quality improvement | baseline | +3% | +5% |
| CPQ (cost per quality point) | - | $56/point | $68/point |
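The CPQ row is derived from the two rows above it; a one-line helper (hypothetical name, same inputs as the table) makes the bookkeeping explicit:

```python
def cost_per_quality_point(monthly_cost: float, baseline_cost: float, quality_gain_pts: float) -> float:
    """Incremental monthly inference cost divided by quality points gained over baseline."""
    return (monthly_cost - baseline_cost) / quality_gain_pts

mixtral_cpq = cost_per_quality_point(340, 170, 3)  # ≈ 56.7, reported as $56/point
custom_cpq = cost_per_quality_point(510, 170, 5)   # 68.0, reported as $68/point
```

Note that CPQ uses monthly inference cost only; the one-time training cost is amortized separately.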

Verdict: Mixtral + LoRA gives the best CPQ ($56/point) and is the V2 recommendation. Custom MoE is only justified if quality requirements exceed what Mixtral can achieve.