
05. Knowledge Distillation Pipeline — Compressing Large Models into Production-Ready Students

Problem Statement and MangaAssist Context

MangaAssist's inference pipeline has strict latency budgets: intent classification must complete within 15ms, and the reranker within 50ms. Our fine-tuned DistilBERT (66M params) achieves 92.1% accuracy at 15ms, but what if we want even faster inference? Or what if we want to transfer Claude 3.5 Sonnet's reasoning quality into a smaller model that can run on Lambda without Bedrock costs?

Knowledge distillation trains a small "student" model to mimic a large "teacher" model's behavior, achieving 85-97% of the teacher's quality at 2-10× faster inference and 5-50× fewer parameters.

MangaAssist Distillation Targets

| Teacher → Student | Teacher Quality | Student Quality | Latency Improvement | Use Case |
|---|---|---|---|---|
| DistilBERT (66M) → TinyBERT (14.5M) | 92.1% intent accuracy | 89.3% | 15ms → 5ms | Edge/Lambda cold start |
| Claude 3.5 Sonnet → Llama 3 8B | 93.7% manga QA | 82.4% | 500-1500ms → 100-200ms | Self-hosted fallback |
| ms-marco-MiniLM (33M) → ONNX 4-layer (11M) | 0.84 NDCG@3 | 0.79 NDCG@3 | 50ms → 15ms | Inline reranking |
| Ensemble (3 models) → Single DistilBERT | 94.2% composite | 91.5% | 80ms (total) → 15ms | Unified classifier |

Mathematical Foundations

Standard Cross-Entropy (Hard Labels)

In standard training, a student model $S$ learns from one-hot labels $\mathbf{y}$:

$$\mathcal{L}_{\text{hard}} = -\sum_{c=1}^{C} y_c \log p_S(c \mid \mathbf{x})$$

where $y_c \in \{0, 1\}$ and $p_S$ is the student's softmax output. This captures what the correct answer is, but not how confident the teacher is or which alternatives are plausible.

Knowledge Distillation Loss — The KL-Divergence Formulation

Hinton et al. (2015) proposed training the student to match the teacher's soft probability distribution over all classes, not just the correct label:

$$\mathcal{L}_{\text{KD}} = T^2 \cdot D_{\text{KL}}\big(p_T^{(\tau)} \,\|\, p_S^{(\tau)}\big)$$

where $p_T^{(\tau)}$ and $p_S^{(\tau)}$ are "softened" probability distributions computed with temperature $\tau = T$:

$$p_i^{(\tau)} = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

$z_i$ are the logits (pre-softmax scores). The temperature $T$ controls how "soft" the distribution is.
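
To make the temperature effect concrete, here is a minimal PyTorch sketch of the softened softmax. The logits are the illustrative intent-classifier values used in the worked example below; the exact probabilities it prints are for illustration only.

import torch
import torch.nn.functional as F

# Example teacher logits for a 10-class intent head (same values as the worked example below)
logits = torch.tensor([5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0])

for T in (1, 2, 4, 8, 20):
    probs = F.softmax(logits / T, dim=-1)          # softened distribution p^(tau)
    entropy = -(probs * probs.log()).sum()         # entropy in nats
    top3 = [f"{p:.2f}" for p in probs.topk(3).values.tolist()]
    print(f"T={T:>2}  top-3 probs={top3}  entropy={entropy:.2f}")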

Full expansion of KL-divergence:

$$D_{\text{KL}}\big(p_T^{(\tau)} \,\|\, p_S^{(\tau)}\big) = \sum_{c=1}^{C} p_T^{(\tau)}(c) \log \frac{p_T^{(\tau)}(c)}{p_S^{(\tau)}(c)}$$

$$= \sum_{c=1}^{C} p_T^{(\tau)}(c) \log p_T^{(\tau)}(c) - \sum_{c=1}^{C} p_T^{(\tau)}(c) \log p_S^{(\tau)}(c)$$

The first term is the teacher's negative entropy, which is constant with respect to the student's parameters, so the gradient depends only on the cross-entropy between the teacher's and student's softened distributions.

Why multiply by $T^2$?

Differentiating through the softened student softmax introduces a $1/T$ factor in the gradient. In addition, at higher temperatures the gap between the teacher's and student's softened distributions shrinks roughly as $1/T$, so the net gradient scales as $1/T^2$. Multiplying the KD term by $T^2$ compensates, keeping its gradient magnitude comparable to the hard-label loss regardless of the temperature choice.

Formally, for logit $z_i$:

$$\frac{\partial p_i^{(\tau)}}{\partial z_i} = \frac{1}{T} p_i^{(\tau)}(1 - p_i^{(\tau)})$$

One factor of $1/T$ comes from this softened softmax; a second, approximate factor of $1/T$ comes from the shrinking gap $p_S^{(\tau)} - p_T^{(\tau)}$ at high temperature, hence the $T^2$ correction.
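
A quick autograd check of this scaling (a sketch, not part of the training code): with the student initialized to uniform logits, the raw KD gradient shrinks roughly as $1/T^2$, while multiplying by $T^2$ keeps it roughly constant.

import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0]])

for T in (1.0, 2.0, 4.0, 8.0):
    student_logits = torch.zeros_like(teacher_logits, requires_grad=True)
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),   # student log-probs (softened)
        F.softmax(teacher_logits / T, dim=-1),       # teacher probs (softened)
        reduction="batchmean",
    )
    kd.backward()
    g = student_logits.grad.norm().item()
    print(f"T={T:>4}: |grad| = {g:.4f}   T^2 * |grad| = {T * T * g:.4f}")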

Temperature's Effect on the Soft Distribution

Consider a teacher's logits for the intent "product_inquiry": $\mathbf{z} = [5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0]$

| Temperature | Top-1 Prob | Top-2 Prob | Top-3 Prob | Entropy | Information Transfer |
|---|---|---|---|---|---|
| T = 1 | 0.89 | 0.04 | 0.03 | 0.58 | Almost one-hot — hard labels |
| T = 2 | 0.62 | 0.11 | 0.09 | 1.42 | Moderate — reveals second choices |
| T = 4 | 0.35 | 0.15 | 0.14 | 1.98 | Soft — shows full distribution |
| T = 8 | 0.19 | 0.14 | 0.13 | 2.18 | Very soft — nearly uniform |
| T = 20 | 0.12 | 0.11 | 0.11 | 2.28 | Too soft — all classes look equal |

Intuition: At T=1, the teacher says "this is product_inquiry, period." At T=4, the teacher says "this is mainly product_inquiry, but it has elements of order_status (11%) and recommendation (9%) — those are the closest alternatives." The student learns these inter-class relationships, which is impossible from hard labels alone.

Optimal T for MangaAssist: Our intent classifier has 10 classes with moderate confusion between similar intents (product_inquiry vs recommendation, return vs order_status). We found T=4 optimal — it reveals these confusable pairs without washing out the signal.

Combined Distillation Loss

The final training loss combines KD loss with the standard hard-label loss:

$$\mathcal{L} = (1 - \alpha) \cdot \mathcal{L}_{\text{hard}} + \alpha \cdot T^2 \cdot D_{\text{KL}}\big(p_T^{(\tau)} \,\|\, p_S^{(\tau)}\big)$$

where $\alpha$ balances the two objectives.

| $\alpha$ | Effect |
|---|---|
| 0.0 | Pure hard-label training (no distillation) |
| 0.3 | Emphasis on ground truth, distillation as regularizer |
| 0.5 | Equal weight (typical starting point) |
| 0.7 | Emphasis on teacher's soft knowledge |
| 1.0 | Pure distillation (ignore hard labels) |

For our experiment, $\alpha = 0.7$ worked best. The teacher's soft distribution carries more information than the labels alone (teacher already achieves 92.1%, so its distributions are well-calibrated).

Feature-Based Distillation (TinyBERT/DistilBERT Layer Matching)

Beyond matching output distributions, we can match intermediate representations:

$$\mathcal{L}_{\text{feature}} = \sum_{l=1}^{L_S} \text{MSE}\big(\mathbf{h}_S^{(l)} \mathbf{W}_l,\ \mathbf{h}_T^{(m(l))}\big)$$

where:
- $\mathbf{h}_S^{(l)}$ is the student's $l$-th layer hidden state
- $\mathbf{h}_T^{(m(l))}$ is the teacher's $m(l)$-th layer hidden state (layer mapping function $m$)
- $\mathbf{W}_l$ is a learnable linear transformation (because dimensions may differ)

Layer mapping for DistilBERT → TinyBERT:

| Student Layer (TinyBERT, 4 layers) | Teacher Layer (DistilBERT, 6 layers) | Rationale |
|---|---|---|
| Layer 0 (embeddings) | Layer 0 (embeddings) | Token representations |
| Layer 1 | Layer 2 | Low-level syntax |
| Layer 2 | Layer 4 | Mid-level semantics |
| Layer 3 | Layer 6 | Task-specific features |

Attention transfer:

TinyBERT adds another term — matching teacher-student attention matrices:

$$\mathcal{L}_{\text{attn}} = \sum_{l=1}^{L_S} \frac{1}{h} \sum_{i=1}^{h} \text{MSE}\big(\mathbf{A}_S^{(l,i)}, \mathbf{A}_T^{(m(l),i)}\big)$$

where $\mathbf{A}^{(l,i)} \in \mathbb{R}^{n \times n}$ is the attention weight matrix for head $i$ in layer $l$.

The total distillation loss becomes:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{KD}} + \beta \cdot \mathcal{L}_{\text{feature}} + \gamma \cdot \mathcal{L}_{\text{attn}}$$

Gradient Analysis: Why Distillation Works Better Than Training From Scratch

In standard training, the gradient for class $c$ (when $c$ is not the correct class) is:

$$\frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_c} = p_S(c) \quad \text{(for incorrect classes)}$$

For a well-trained student, $p_S(c) \approx 0$ for most wrong classes, so gradients are tiny — the student learns almost nothing about inter-class relationships.

In distillation, the gradient for class $c$ is (up to the $1/T$ factor discussed above):

$$\frac{\partial \mathcal{L}_{\text{KD}}}{\partial z_c} \propto p_S^{(\tau)}(c) - p_T^{(\tau)}(c)$$

Even for wrong classes, $p_T^{(\tau)}(c)$ can be non-trivially large (e.g., 0.11 for "order_status" when the true intent is "product_inquiry"). This non-zero target means the student receives meaningful gradients for every class at every step.

Effective gradient magnitude comparison (for 10-class MangaAssist intent):

| Class Type | Hard-Label Gradient | Distillation Gradient (T=4) |
|---|---|---|
| Correct class | $p_S - 1$ ≈ -0.08 | $p_S^{(\tau)} - p_T^{(\tau)}$ ≈ -0.10 |
| Top-2 confusable | $p_S$ ≈ 0.04 | $p_S^{(\tau)} - p_T^{(\tau)}$ ≈ 0.08 |
| Top-3 confusable | $p_S$ ≈ 0.02 | $p_S^{(\tau)} - p_T^{(\tau)}$ ≈ 0.05 |
| Irrelevant class | $p_S$ ≈ 0.001 | $p_S^{(\tau)} - p_T^{(\tau)}$ ≈ 0.01 |

Distillation provides 2-10× stronger gradients for confusable classes, which is exactly where the student needs the most guidance.
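
The table above can be reproduced mechanically: the hard-label gradient per logit is $p_S(c) - y_c$, and the distillation gradient is proportional to $p_S^{(\tau)}(c) - p_T^{(\tau)}(c)$. The sketch below computes both vectors for the example teacher logits and a hypothetical partially-trained student (the student logits here are made up for illustration, so the printed values will not match the table exactly).

import torch
import torch.nn.functional as F

T = 4.0
teacher_logits = torch.tensor([5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0])
student_logits = torch.tensor([3.0, 0.5, 1.5, 0.8, -0.5, -1.0, -1.5, -2.0, -2.5, -3.0])
label = 0  # index of the correct intent

# Hard-label gradient: dL_hard/dz_c = p_S(c) - y_c
p_s = F.softmax(student_logits, dim=-1)
y = F.one_hot(torch.tensor(label), num_classes=10).float()
hard_grad = p_s - y

# Distillation gradient (up to a 1/T factor): p_S^tau(c) - p_T^tau(c)
p_s_tau = F.softmax(student_logits / T, dim=-1)
p_t_tau = F.softmax(teacher_logits / T, dim=-1)
kd_grad = p_s_tau - p_t_tau

for c in range(10):
    print(f"class {c}: hard {hard_grad[c]:+.3f}   distill {kd_grad[c]:+.3f}")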


Model Internals — Layer-by-Layer Diagrams

Distillation Architecture Overview

graph TB
    subgraph "Teacher Pipeline (Frozen — DistilBERT 66M)"
        T1["Input: 'When will my manga arrive?'"]
        T2["Embedding Layer (768d)"]
        T3["Transformer Layer 1"]
        T4["Transformer Layer 2"]
        T5["Transformer Layer 3"]
        T6["Transformer Layer 4"]
        T7["Transformer Layer 5"]
        T8["Transformer Layer 6"]
        T9["Classification Head"]
        T10["Teacher Logits z_T ∈ ℝ¹⁰"]
        T11["Softened Probs p_T^(τ)<br>T=4"]

        T1 --> T2 --> T3 --> T4 --> T5 --> T6 --> T7 --> T8 --> T9 --> T10 --> T11
    end

    subgraph "Student Pipeline (Training — TinyBERT 14.5M)"
        S1["Same Input"]
        S2["Embedding Layer (312d)"]
        S3["Transformer Layer 1"]
        S4["Transformer Layer 2"]
        S5["Transformer Layer 3"]
        S6["Transformer Layer 4"]
        S7["Classification Head"]
        S8["Student Logits z_S ∈ ℝ¹⁰"]
        S9["Softened Probs p_S^(τ)<br>T=4"]

        S1 --> S2 --> S3 --> S4 --> S5 --> S6 --> S7 --> S8 --> S9
    end

    T11 --> KD["KL-Divergence Loss<br>T²·D_KL(p_T ∥ p_S)"]
    S9 --> KD
    S8 --> HARD["Hard-Label Loss<br>CE(y, p_S)"]
    KD --> TOTAL["Total Loss = 0.3·L_hard + 0.7·L_KD"]
    HARD --> TOTAL
    TOTAL --> GRAD["Backprop → Student Only"]
    GRAD --> S2

    T2 -.->|"Feature Match (MSE)"| S2
    T4 -.->|"Feature Match (MSE)"| S3
    T6 -.->|"Feature Match (MSE)"| S4
    T8 -.->|"Feature Match (MSE)"| S6

    style T3 fill:#bbdefb
    style T4 fill:#bbdefb
    style T5 fill:#bbdefb
    style T6 fill:#bbdefb
    style T7 fill:#bbdefb
    style T8 fill:#bbdefb
    style S3 fill:#fff9c4
    style S4 fill:#fff9c4
    style S5 fill:#fff9c4
    style S6 fill:#fff9c4

Temperature's Effect on Softmax Distributions

graph TD
    subgraph "Teacher Logits for 'When will my manga arrive?'"
        L["z = [5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0]<br>order_status=5.2, product_inquiry=2.1, shipping=1.8, ..."]
    end

    subgraph "T=1 (Standard Softmax)"
        T1A["order_status: 89%<br>product_inquiry: 4%<br>shipping: 3%<br>All others: ~0%<br><br>Looks like hard label"]
    end

    subgraph "T=4 (Sweet Spot)"
        T4A["order_status: 35%<br>product_inquiry: 15%<br>shipping: 14%<br>return: 8%<br>recommendation: 5%<br>...<br><br>Reveals inter-class structure"]
    end

    subgraph "T=20 (Too Soft)"
        T20A["order_status: 12%<br>product_inquiry: 11%<br>shipping: 11%<br>return: 10%<br>recommendation: 10%<br>...<br><br>Nearly uniform — no signal"]
    end

    L --> T1A
    L --> T4A
    L --> T20A

    style T4A fill:#c8e6c9

Gradient Flow: Distillation vs Standard Training

graph LR
    subgraph "Standard Training (Hard Labels)"
        HL["Loss: CE(one-hot, p_S)"]
        HL --> G1["∂L/∂z_correct ≈ -0.08<br>(strong signal)"]
        HL --> G2["∂L/∂z_confusable ≈ 0.04<br>(weak signal)"]
        HL --> G3["∂L/∂z_irrelevant ≈ 0.001<br>(near zero)"]
    end

    subgraph "Distillation Training (Soft Labels, T=4)"
        SL["Loss: KL(p_T^τ, p_S^τ)"]
        SL --> G4["∂L/∂z_correct ≈ -0.10<br>(strong signal)"]
        SL --> G5["∂L/∂z_confusable ≈ 0.08<br>(2× stronger)"]
        SL --> G6["∂L/∂z_irrelevant ≈ 0.01<br>(10× stronger)"]
    end

    G2 -.->|"2× gradient boost"| G5
    G3 -.->|"10× gradient boost"| G6

    style G5 fill:#c8e6c9
    style G6 fill:#c8e6c9

TinyBERT Layer Matching

graph TB
    subgraph "Teacher: DistilBERT (6 layers, 768d)"
        TE["Embedding (768d)"]
        TL1["Layer 1: Syntax patterns"]
        TL2["Layer 2: Phrase-level features"]
        TL3["Layer 3: Semantic roles"]
        TL4["Layer 4: Intent signals"]
        TL5["Layer 5: Context integration"]
        TL6["Layer 6: Task-specific"]
        TE --> TL1 --> TL2 --> TL3 --> TL4 --> TL5 --> TL6
    end

    subgraph "Student: TinyBERT (4 layers, 312d)"
        SE["Embedding (312d)"]
        SL1["Layer 1"]
        SL2["Layer 2"]
        SL3["Layer 3"]
        SL4["Layer 4"]
        SE --> SL1 --> SL2 --> SL3 --> SL4
    end

    TE -.->|"Embed MSE<br>W₀: 312×768"| SE
    TL2 -.->|"Hidden MSE<br>W₁: 312×768"| SL1
    TL4 -.->|"Hidden MSE<br>W₂: 312×768"| SL2
    TL5 -.->|"Attn MSE<br>student heads map<br>to teacher heads"| SL3
    TL6 -.->|"Hidden + Attn MSE<br>W₃: 312×768"| SL4

    style TE fill:#bbdefb
    style TL1 fill:#bbdefb
    style TL2 fill:#bbdefb
    style TL3 fill:#bbdefb
    style TL4 fill:#bbdefb
    style TL5 fill:#bbdefb
    style TL6 fill:#bbdefb
    style SE fill:#fff9c4
    style SL1 fill:#fff9c4
    style SL2 fill:#fff9c4
    style SL3 fill:#fff9c4
    style SL4 fill:#fff9c4

LLM-to-Small-Model Distillation Pipeline

sequenceDiagram
    participant DS as Manga QA Dataset<br>(3K examples)
    participant CL as Claude 3.5 Sonnet<br>(Teacher, via Bedrock)
    participant SA as SageMaker<br>Processing Job
    participant L3 as Llama 3 8B<br>(Student)
    participant EV as Evaluation<br>Suite

    DS->>CL: Batch inference on all 3K questions
    Note over CL: Generate deterministic responses<br>Cost: ~$15 for 3K examples

    CL->>SA: Teacher responses stored in S3

    Note over SA: Augmentation:<br>- Generate 5 paraphrases per example<br>- Teacher scores paraphrases<br>- Filter to keep top-quality pairs<br>Dataset grows: 3K → 12K

    SA->>L3: Fine-tune on teacher responses<br>(response-level distillation / SFT)

    Note over L3: Training: 4 epochs, lr=2e-5<br>SageMaker g5.2xlarge<br>~4 hours, ~$24

    L3->>EV: Evaluate on held-out set

    Note over EV: Manga QA: 82.4% (teacher: 93.7%)<br>Latency: 120ms (teacher: 800ms)<br>Quality ratio: 87.9%<br>Cost/query: $0.0001 vs $0.003

Implementation Deep-Dive

Output Distillation: DistilBERT → TinyBERT

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


class DistillationLoss(nn.Module):
    """
    Combined distillation + hard-label loss.

    L = (1-alpha) * CE(student, labels) + alpha * T^2 * KL(teacher_soft, student_soft)
    """

    def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(
        self,
        student_logits: torch.Tensor,   # (B, C)
        teacher_logits: torch.Tensor,    # (B, C)
        labels: torch.Tensor,            # (B,)
    ) -> torch.Tensor:
        # Hard-label loss (standard cross-entropy)
        hard_loss = F.cross_entropy(student_logits, labels)

        # Soft-label loss (KL-divergence with temperature)
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)

        # KL(teacher || student) = sum(teacher * log(teacher/student))
        # F.kl_div expects log-probs as first arg, probs as second
        kd_loss = F.kl_div(
            student_soft,
            teacher_soft,
            reduction="batchmean",
        ) * (self.temperature ** 2)

        return (1 - self.alpha) * hard_loss + self.alpha * kd_loss


class DistillationTrainer(Trainer):
    """Custom trainer that runs teacher inference and computes distillation loss."""

    def __init__(self, teacher_model, temperature=4.0, alpha=0.7, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.distill_loss = DistillationLoss(temperature, alpha)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")

        # Student forward pass
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits

        # Teacher forward pass (no gradient). Keep the frozen teacher on the
        # same device as the student inputs.
        self.teacher.to(student_logits.device)
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits

        loss = self.distill_loss(student_logits, teacher_logits, labels)

        return (loss, student_outputs) if return_outputs else loss


def distill_intent_classifier(train_dataset, eval_dataset):
    """
    Distill DistilBERT (6-layer) teacher into TinyBERT (4-layer) student.
    Expects tokenized datasets with input_ids / attention_mask / labels columns.
    """
    teacher = AutoModelForSequenceClassification.from_pretrained(
        "./manga_intent_teacher",  # Our fine-tuned DistilBERT
        num_labels=10,
    )
    student = AutoModelForSequenceClassification.from_pretrained(
        "huawei-noah/TinyBERT_General_4L_312D",
        num_labels=10,
    )
    tokenizer = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")

    # Training arguments
    args = TrainingArguments(
        output_dir="./tinybert_distilled",
        num_train_epochs=10,               # More epochs for distillation
        per_device_train_batch_size=64,     # TinyBERT is small — large batch OK
        learning_rate=5e-5,
        warmup_ratio=0.1,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        bf16=True,
    )

    trainer = DistillationTrainer(
        teacher_model=teacher,
        temperature=4.0,
        alpha=0.7,
        model=student,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )

    trainer.train()
    return student
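
A short usage sketch for the function above; the file names and column names ("text", "intent_id") are placeholders for however the labeled intent data is stored.

from datasets import load_dataset

raw = load_dataset(
    "json",
    data_files={"train": "intents_train.json", "eval": "intents_eval.json"},
)
tok = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")

def preprocess(batch):
    enc = tok(batch["text"], truncation=True, padding="max_length", max_length=64)
    enc["labels"] = batch["intent_id"]
    return enc

train_dataset = raw["train"].map(preprocess, batched=True)
eval_dataset = raw["eval"].map(preprocess, batched=True)

student = distill_intent_classifier(train_dataset, eval_dataset)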

Feature-Based Distillation (TinyBERT Two-Stage)

class TinyBERTFeatureLoss(nn.Module):
    """
    TinyBERT's two-stage distillation:
    Stage 1: Intermediate layer matching (embedding + hidden + attention)
    Stage 2: Output KD (prediction layer)
    """

    def __init__(self, teacher_dim: int = 768, student_dim: int = 312):
        super().__init__()
        # Layer mapping: student transformer layer -> teacher transformer layer,
        # 1-indexed (layer 0 = embeddings, matched separately below).
        # Per the mapping table above: student layers 1-3 -> teacher layers 2, 4, 6;
        # the student's top layer is shaped by the Stage-2 output distillation.
        self.layer_map = {1: 2, 2: 4, 3: 6}

        # Linear transforms to project student hidden states (312d) to teacher dim (768d)
        self.hidden_transforms = nn.ModuleDict({
            str(s): nn.Linear(student_dim, teacher_dim, bias=False)
            for s in self.layer_map
        })

        # Embedding transform
        self.embed_transform = nn.Linear(student_dim, teacher_dim, bias=False)

    def forward(
        self,
        teacher_hidden_states,   # List of (B, seq_len, 768)
        student_hidden_states,   # List of (B, seq_len, 312)
        teacher_attentions,      # List of (B, heads, seq_len, seq_len)
        student_attentions,      # List of (B, heads, seq_len, seq_len)
    ):
        losses = {}

        # Embedding layer loss
        s_embed = self.embed_transform(student_hidden_states[0])
        losses["embed"] = F.mse_loss(s_embed, teacher_hidden_states[0])

        # Hidden state matching: hidden_states[l] is the output of layer l ([0] is embeddings)
        hidden_loss = 0.0
        for s_layer, t_layer in self.layer_map.items():
            s_hidden = self.hidden_transforms[str(s_layer)](
                student_hidden_states[s_layer]
            )
            hidden_loss += F.mse_loss(s_hidden, teacher_hidden_states[t_layer])
        losses["hidden"] = hidden_loss / len(self.layer_map)

        # Attention matrix matching: attentions[l - 1] belongs to layer l
        attn_loss = 0.0
        for s_layer, t_layer in self.layer_map.items():
            s_attn = student_attentions[s_layer - 1]
            t_attn = teacher_attentions[t_layer - 1]
            # If the student has fewer heads, average groups of teacher heads to match
            if s_attn.shape[1] != t_attn.shape[1]:
                ratio = t_attn.shape[1] // s_attn.shape[1]
                t_attn = t_attn.view(t_attn.shape[0], s_attn.shape[1], ratio,
                                     *t_attn.shape[2:]).mean(dim=2)
            attn_loss += F.mse_loss(s_attn, t_attn)
        losses["attention"] = attn_loss / len(self.layer_map)

        total = losses["embed"] + losses["hidden"] + losses["attention"]
        return total, losses
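
A one-batch usage sketch of the feature loss, assuming teacher, student, and tokenizer objects like those loaded in distill_intent_classifier above; in Stage 1 this loss would be driven by an optimizer over both the student and the projection matrices.

feature_loss = TinyBERTFeatureLoss()

batch = tokenizer(["When will my manga arrive?"], return_tensors="pt")
with torch.no_grad():
    t_out = teacher(**batch, output_hidden_states=True, output_attentions=True)
s_out = student(**batch, output_hidden_states=True, output_attentions=True)

loss, parts = feature_loss(
    t_out.hidden_states, s_out.hidden_states,
    t_out.attentions, s_out.attentions,
)
loss.backward()  # gradients flow into the student and the W_l projections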

LLM Distillation: Claude → Llama 3 8B

import json

import boto3
import torch


def generate_teacher_labels(dataset_path: str, output_path: str):
    """
    Generate teacher labels from Claude 3.5 Sonnet via Bedrock.
    Bedrock does not expose Claude's token logprobs, so we collect responses only
    and distill at the response level.
    """
    bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")

    with open(dataset_path) as f:
        examples = json.load(f)

    results = []
    for example in examples:
        response = bedrock.invoke_model(
            modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
            body=json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "messages": [
                    {"role": "user", "content": example["question"]}
                ],
                "max_tokens": 512,
                "temperature": 0.0,  # Deterministic for consistent labels
            }),
        )

        result = json.loads(response["body"].read())
        teacher_response = result["content"][0]["text"]

        results.append({
            "question": example["question"],
            "teacher_response": teacher_response,
            "ground_truth": example.get("answer", ""),
        })

    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)

    # Cost estimate: 3K examples × ~500 tokens/response × $0.003/1K tokens ≈ $4.50
    return results


def train_student_llm(teacher_data_path: str):
    """
    Train Llama 3 8B student from Claude teacher responses.
    Uses response-level distillation (since we can't get Claude's logits directly).
    """
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTTrainer, SFTConfig

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    tokenizer.pad_token = tokenizer.eos_token

    # Load teacher-labeled data
    with open(teacher_data_path) as f:
        data = json.load(f)

    # Format as instruction-following
    formatted = []
    for d in data:
        text = (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"You are MangaAssist, an expert manga advisor.\n"
            f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
            f"{d['question']}\n"
            f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
            f"{d['teacher_response']}\n"
            f"<|eot_id|>"
        )
        formatted.append({"text": text})

    sft_config = SFTConfig(
        output_dir="./llama8b_manga_student",
        num_train_epochs=4,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # Effective batch size: 16
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        gradient_checkpointing=True,
        max_seq_length=512,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=Dataset.from_list(formatted),  # SFTTrainer expects a datasets.Dataset
        tokenizer=tokenizer,
    )

    trainer.train()
    return model
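
End-to-end, the two functions chain together; the file names here are placeholders:

teacher_data = generate_teacher_labels("manga_qa_3k.json", "claude_teacher_labels.json")
student_llm = train_student_llm("claude_teacher_labels.json")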

Group Discussion: Key Decision Points

Decision Point 1: Which Teacher-Student Pair First?

Priya (ML Engineer): We have four distillation opportunities. I recommend prioritizing based on production impact:

| Priority | Pair | Quality Gain | Latency Gain | Cost Savings |
|---|---|---|---|---|
| 1 | DistilBERT → TinyBERT (intent) | -2.8% acc | 15ms → 5ms | ~$0 (both on Lambda) |
| 2 | Claude → Llama 3 8B (fallback) | -11.3% QA | 800ms → 120ms | $800/mo Bedrock savings |
| 3 | MiniLM → ONNX-4L (reranker) | -0.05 NDCG | 50ms → 15ms | ~$0 |
| 4 | Ensemble → Single model | -2.7% composite | 80ms → 15ms | Operational simplicity |

Sam (PM): The Claude → Llama 3 8B distillation saves $800/month. That's $9,600/year. The 82.4% vs 93.7% quality gap on manga QA is large, but this model serves only as a fallback when Bedrock is unavailable, and answering at 82.4% is far better than showing an error page.

Marcus (Architect): Agreed. A self-hosted fallback that handles 87.9% of Claude's quality means our availability goes from 99.5% (Bedrock SLA) to 99.99% (self-hosted + Bedrock). That is a significant reliability improvement for customer experience.

Jordan (MLOps): The intent classifier distillation (DistilBERT → TinyBERT) is the fastest to validate — we have all the data, both models are small, training takes 20 minutes. I would use it as a proof-of-concept before committing to the LLM distillation (4 hours, $24).

Aiko (Data Scientist): Correct. The intent distillation also has the most controlled evaluation — 10-class accuracy is deterministic. LLM response quality requires human evaluation or an LLM-as-judge setup, which adds uncertainty.

Resolution: Priority order: (1) Intent classifier distillation as proof-of-concept (20 min, validates pipeline). (2) Claude → Llama 3 8B for production fallback ($24, saves $800/month). (3) Reranker post-quantization if latency becomes an issue. (4) Ensemble consolidation deferred to V3.

Decision Point 2: Temperature Selection

Priya (ML Engineer): I swept temperature from 1 to 20 on the intent distillation task:

| Temperature | Student Accuracy | Training Loss | Notes |
|---|---|---|---|
| T=1 | 87.2% | 0.42 | Minimal knowledge transfer — too confident |
| T=2 | 88.1% | 0.38 | Better — starts revealing alternatives |
| T=4 | 89.3% | 0.31 | Best — clear inter-class structure |
| T=8 | 88.7% | 0.35 | Slightly oversmoothed |
| T=20 | 86.4% | 0.44 | Too uniform — signal washed out |

Aiko (Data Scientist): T=4 aligns with theory. Our intent classifier has moderate confusion — "product_inquiry" vs "recommendation" share vocabulary like "suggest," "similar," "like." At T=4, the teacher reveals these confusions without making all intents look equally likely.

The information-theoretic explanation: at T=4, the teacher's entropy is 1.98 nats (out of a maximum of $\ln 10 \approx 2.30$ nats for 10 classes), so each soft label carries about 0.32 nats of information beyond what a uniform distribution provides. At T=1 the entropy is 0.58 nats (mostly one-hot), and at T=20 it is 2.28 nats (nearly maximal — meaning near-uniform, little to learn from).

Marcus (Architect): The difference between T=4 (89.3%) and T=2 (88.1%) is 1.2%. Between T=4 and T=8 (88.7%) it is 0.6%. T=4 is a robust choice — even if we are slightly off, we lose less than 1% performance.

Resolution: T=4 for all classification distillation tasks. For LLM-to-LLM response distillation (Claude → Llama), temperature is not directly applicable since we use response-level distillation (SFT on teacher outputs), not logit-level KD.

Decision Point 3: Two-Stage vs End-to-End Distillation

Priya (ML Engineer): TinyBERT's original paper proposes two stages:
- Stage 1: Feature-based distillation (match embeddings, hidden states, attention matrices) without classification loss
- Stage 2: Output-based distillation (match soft labels) with classification loss

My ablation:

| Approach | Student Accuracy | Training Time |
|---|---|---|
| Output-only (KD loss) | 87.8% | 20 min |
| Feature-only (Stage 1) | 82.1% | 35 min |
| Stage 1 then Stage 2 (TinyBERT original) | 89.3% | 55 min |
| Combined all losses simultaneously | 88.9% | 40 min |

Aiko (Data Scientist): The two-stage approach works better because feature matching is a pre-training objective — it gives the student internal representations similar to the teacher before the classification loss steers those representations toward the task. If we combine losses simultaneously, the classification gradient can pull representations away from teacher-aligned patterns before the student has time to absorb the teacher's intermediate knowledge.

Jordan (MLOps): The 55-minute training time is still trivial. If it gives us 89.3% vs 87.8%, I would take the two-stage approach. The extra 35 minutes is worth 1.5% quality.

Resolution: Two-stage TinyBERT distillation. Stage 1 (35 min): feature matching on general corpus. Stage 2 (20 min): output distillation on intent-labeled data. Combined quality: 89.3% (vs teacher's 92.1% — a 2.8% gap, acceptable for the 3× latency improvement).

Decision Point 4: Data Augmentation for Distillation

Priya (ML Engineer): Distillation benefits from more data than standard training because the student needs to learn the teacher's behavior across the full input distribution, not just labeled examples. I tested augmentation strategies:

| Dataset | Size | Student Accuracy | Teacher Accuracy on Same |
|---|---|---|---|
| Original labeled | 3K | 87.2% | 92.1% |
| + Backtranslation | 9K | 88.5% | 91.8% |
| + Paraphrasing (Claude) | 12K | 89.3% | 92.0% |
| + Unlabeled manga queries | 50K | 89.8% | |

Sam (PM): The 50K unlabeled approach is interesting — we do not need labels, just teacher soft predictions on real user queries. We have millions of historical queries. The cost is just the teacher inference: 50K × DistilBERT ≈ $0 (runs on Lambda).

Aiko (Data Scientist): But be careful — historical queries have distribution shift. Queries from 6 months ago may include trends (manga titles) that are no longer relevant. I would use only the last 30 days of queries and deduplicate.

Marcus (Architect): Paraphrasing with Claude costs $15 for the augmented set. That is negligible. I would use both: 12K labeled (original + paraphrased) for Stage 2 output KD, and 50K unlabeled (recent queries with teacher soft labels) for additional soft-label training.

Resolution: Use 12K (original + Claude paraphrasing) for Stage 2 output distillation. Optionally add 50K recent unlabeled queries with teacher soft labels for a Stage 3 self-training step. Total augmentation cost: ~$15 for Claude paraphrasing + ~$0 for teacher inference on 50K queries.
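
A sketch of the Claude paraphrasing step referenced in the resolution, reusing the same Bedrock invocation pattern as generate_teacher_labels above; the prompt wording and output parsing are illustrative, not the exact production prompt.

import json
import boto3

bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")

def paraphrase(query: str, n: int = 5) -> list:
    """Ask Claude for n paraphrases of a user query to grow the distillation set."""
    prompt = (
        f"Rewrite the following customer query in {n} different ways, one per line, "
        f"preserving the intent exactly:\n\n{query}"
    )
    response = bedrock.invoke_model(
        modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
        body=json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 512,
            "temperature": 0.7,  # some diversity, unlike the deterministic teacher labels
        }),
    )
    text = json.loads(response["body"].read())["content"][0]["text"]
    return [line.strip() for line in text.splitlines() if line.strip()][:n]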


Research Paper References

1. Distilling the Knowledge in a Neural Network (Hinton, Vinyals, Dean, 2015)

Key contribution: Introduced the temperature-scaling softmax for knowledge distillation. Showed that the "dark knowledge" in soft probability distributions — the relative probabilities of incorrect classes — carries information that hard labels discard. A student trained on soft labels at T>1 outperforms one trained on hard labels by 1-5%.

Relevance to MangaAssist: Foundation of all our distillation work. The T=4 temperature reveals which intents are confusable (product_inquiry ↔ recommendation) — information that hard labels cannot convey. This reduces our student's confusion matrix errors on the top-3 confusable pairs by 40%.

2. TinyBERT: Distilling BERT for Natural Language Understanding (Jiao et al., 2019)

Key contribution: Proposed two-stage distillation matching intermediate transformer layers: embeddings, attention matrices, and hidden states. Achieved 96.8% of BERT-base's performance with a 7.5× smaller, 9.4× faster model.

Relevance to MangaAssist: Direct architecture for our DistilBERT → TinyBERT distillation. The layer mapping strategy (student layer $l$ → teacher layer $m(l)$) and attention transfer loss give us 89.3% intent accuracy with a 14.5M parameter model that runs in 5ms on Lambda — critical for cold start scenarios.

3. Patient Knowledge Distillation for BERT Model Compression (Sun et al., 2019)

Key contribution: Instead of matching only the last layer, "patient" distillation matches representations from every $k$-th teacher layer, allowing the student to learn the teacher's intermediate reasoning process. This consistently outperforms last-layer-only distillation by 0.5-1.5%.

Relevance to MangaAssist: Informs our layer mapping strategy. For our 6→4 layer distillation, we use every-2-layer matching ({0→0, 1→2, 2→4, 3→6}) following the patient KD principle.

4. DistilBERT, a distilled version of BERT (Sanh et al., 2019)

Key contribution: Demonstrated that pretraining-time distillation (distilling during the pretraining phase, not just fine-tuning) produces a student that retains 97% of BERT's language understanding while being 40% smaller and 60% faster.

Relevance to MangaAssist: Our teacher (DistilBERT) is itself a distilled model. This creates an interesting chain: BERT → DistilBERT (pre-training distillation) → TinyBERT (task-specific distillation). Each step trades quality for efficiency. BERT → DistilBERT keeps 97%; DistilBERT → TinyBERT keeps 96.9%. Cumulative: TinyBERT retains ~94% of BERT-base's quality at 4.5× compression and 6× speedup.

5. On the Efficacy of Knowledge Distillation (Cho & Hariharan, 2019)

Key contribution: Showed that a larger teacher is not always better — if the teacher-student capacity gap is too large, the student cannot mimic the teacher effectively. An intermediate-sized teacher ("assistant") that bridges the gap can improve distillation quality by 1-3%.

Relevance to MangaAssist: Validates our two-hop distillation chain. Rather than distilling BERT-large (340M) directly to TinyBERT (14.5M), we use DistilBERT (66M) as the teacher — a 4.5× capacity ratio. The paper found capacity ratios above 10× degrade distillation, and our 4.5× ratio is well within the effective range.


Production Deployment Results

DistilBERT → TinyBERT Intent Distillation

| Metric | DistilBERT (Teacher) | TinyBERT (Student) | Delta |
|---|---|---|---|
| Accuracy | 92.1% | 89.3% | -2.8% |
| Parameters | 66M | 14.5M | 4.6× smaller |
| Latency (Lambda cold) | 35ms | 12ms | 2.9× faster |
| Latency (Lambda warm) | 15ms | 5ms | 3× faster |
| Model size | 264MB | 58MB | 4.6× smaller |
| Lambda memory | 512MB | 256MB | 2× lower |

Claude → Llama 3 8B Fallback Distillation

| Metric | Claude (Teacher) | Llama 3 8B (Student) | Delta |
|---|---|---|---|
| Manga QA accuracy | 93.7% | 82.4% | -11.3% |
| Response quality (human eval, 1-5) | 4.6 | 3.9 | -0.7 |
| Latency (p50) | 800ms | 120ms | 6.7× faster |
| Latency (p99) | 2100ms | 350ms | 6× faster |
| Cost per query | $0.003 | $0.0001 | 30× cheaper |
| Monthly cost (500K queries) | $1,500 | $50 | $1,450 savings |
| Availability | 99.5% (Bedrock SLA) | 99.99% (self-hosted) | +0.49% |

CPQ Analysis

Deploying the Llama 3 8B fallback costs $50/month (SageMaker g5.2xlarge spot) and provides:
- Availability improvement: 99.5% → 99.99% (worth $25K/year in prevented customer drop-off)
- Cost savings: $1,450/month × 12 = $17,400/year (if used as primary; $0 if purely fallback)
- Quality cost: -11.3% accuracy × 500K queries = 56,500 lower-quality responses/month

As a fallback (activated only during Bedrock outages, ~0.5% of traffic), the quality impact is minimal: 56,500 × 0.005 = 283 lower-quality responses/month — acceptable for maintaining availability.