05. Knowledge Distillation Pipeline — Compressing Large Models into Production-Ready Students
Problem Statement and MangaAssist Context
MangaAssist's inference pipeline has strict latency budgets: intent classification must complete within 15ms, and the reranker within 50ms. Our fine-tuned DistilBERT (66M params) achieves 92.1% accuracy at 15ms, but what if we want even faster inference? Or what if we want to transfer Claude 3.5 Sonnet's reasoning quality into a smaller model that can run on Lambda without Bedrock costs?
Knowledge distillation trains a small "student" model to mimic a large "teacher" model's behavior, achieving 85-97% of the teacher's quality at 2-10× faster inference and 5-50× fewer parameters.
MangaAssist Distillation Targets
| Teacher → Student | Teacher Quality | Student Quality | Latency Improvement | Use Case |
|---|---|---|---|---|
| DistilBERT (66M) → TinyBERT (14.5M) | 92.1% intent accuracy | 89.3% | 15ms → 5ms | Edge/Lambda cold start |
| Claude 3.5 Sonnet → Llama 3 8B | 93.7% manga QA | 82.4% | 500-1500ms → 100-200ms | Self-hosted fallback |
| ms-marco-MiniLM (33M) → ONNX 4-layer (11M) | 0.84 NDCG@3 | 0.79 NDCG@3 | 50ms → 15ms | Inline reranking |
| Ensemble (3 models) → Single DistilBERT | 94.2% composite | 91.5% | 80ms (total) → 15ms | Unified classifier |
Mathematical Foundations
Standard Cross-Entropy (Hard Labels)
In standard training, a student model $S$ learns from one-hot labels $\mathbf{y}$:
$$\mathcal{L}_{\text{hard}} = -\sum_{c=1}^{C} y_c \log p_S(c \mid \mathbf{x})$$
where $y_c \in \{0, 1\}$ and $p_S$ is the student's softmax output. This captures what the correct answer is, but not how confident the teacher is or which alternatives are plausible.
Knowledge Distillation Loss — The KL-Divergence Formulation
Hinton et al. (2015) proposed training the student to match the teacher's soft probability distribution over all classes, not just the correct label:
$$\mathcal{L}_{\text{KD}} = T^2 \cdot D_{\text{KL}}\!\left(p_T^{(\tau)} \,\|\, p_S^{(\tau)}\right)$$
where $p_T^{(\tau)}$ and $p_S^{(\tau)}$ are "softened" probability distributions computed with temperature $\tau = T$:
$$p_i^{(\tau)} = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
$z_i$ are the logits (pre-softmax scores). The temperature $T$ controls how "soft" the distribution is.
Full expansion of KL-divergence:
$$D_{\text{KL}}\!\left(p_T^{(\tau)} \,\|\, p_S^{(\tau)}\right) = \sum_{c=1}^{C} p_T^{(\tau)}(c) \log \frac{p_T^{(\tau)}(c)}{p_S^{(\tau)}(c)}$$
$$= \sum_{c=1}^{C} p_T^{(\tau)}(c) \log p_T^{(\tau)}(c) - \sum_{c=1}^{C} p_T^{(\tau)}(c) \log p_S^{(\tau)}(c)$$
The first term is the teacher's entropy (constant w.r.t. student params), so the gradient only depends on the cross-entropy between teacher and student soft distributions.
Why multiply by $T^2$?
When we compute gradients of the softened logits, the temperature introduces a $1/T$ factor in each gradient. Since the KD loss involves both teacher and student softened probs, the net effect is a $1/T^2$ scaling. Multiplying by $T^2$ compensates, ensuring the gradient magnitude is comparable to the hard-label loss regardless of temperature choice.
Formally, for logit $z_i$:
$$\frac{\partial p_i^{(\tau)}}{\partial z_i} = \frac{1}{T} p_i^{(\tau)}(1 - p_i^{(\tau)})$$
The $1/T$ factor appears twice (once from teacher, once from student) in the KL gradient, hence $T^2$ correction.
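This scaling can be checked numerically. A minimal sketch in plain Python (the logit values are made up for illustration): it evaluates the analytic KD gradient $\frac{1}{T}(p_S^{(\tau)} - p_T^{(\tau)})$ at very small logits, where the high-temperature linearization is accurate, and shows the raw gradient shrinking by a factor of $T^2$.

```python
import math

def softmax(z, T):
    e = [math.exp(x / T) for x in z]
    s = sum(e)
    return [x / s for x in e]

def kd_grad(z_student, z_teacher, T):
    """Gradient of D_KL(p_T^(tau) || p_S^(tau)) w.r.t. student logits:
    (1/T) * (p_S^(tau) - p_T^(tau)) per class, before the T^2 multiplier."""
    p_s = softmax(z_student, T)
    p_t = softmax(z_teacher, T)
    return [(ps - pt) / T for ps, pt in zip(p_s, p_t)]

# Tiny hypothetical logits: the small-logit regime where 1/T^2 scaling is clean
z_s = [0.010, -0.005, 0.002]
z_t = [0.020, 0.000, -0.010]

g1 = kd_grad(z_s, z_t, T=1)[0]
g4 = kd_grad(z_s, z_t, T=4)[0]
print(g1 / g4)  # ratio close to 16 = T^2; the T^2 loss multiplier restores it
```

The raw gradient at T=4 is roughly 16× smaller than at T=1, which is exactly what the $T^2$ multiplier in the loss compensates for.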
Temperature's Effect on the Soft Distribution
Consider a teacher's logits for the intent "product_inquiry": $\mathbf{z} = [5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0]$
| Temperature | Top-1 Prob | Top-2 Prob | Top-3 Prob | Entropy (nats) | Information Transfer |
|---|---|---|---|---|---|
| T = 1 | 0.89 | 0.04 | 0.03 | 0.58 | Almost one-hot — hard labels |
| T = 2 | 0.62 | 0.11 | 0.09 | 1.42 | Moderate — reveals second choices |
| T = 4 | 0.35 | 0.15 | 0.14 | 1.98 | Soft — shows full distribution |
| T = 8 | 0.19 | 0.14 | 0.13 | 2.18 | Very soft — nearly uniform |
| T = 20 | 0.12 | 0.11 | 0.11 | 2.28 | Too soft — all classes look equal |
Intuition: At T=1, the teacher says "this is product_inquiry, period." At T=4, the teacher says "this is mainly product_inquiry, but it has elements of order_status (11%) and recommendation (9%) — those are the closest alternatives." The student learns these inter-class relationships, which is impossible from hard labels alone.
Optimal T for MangaAssist: Our intent classifier has 10 classes with moderate confusion between similar intents (product_inquiry vs recommendation, return vs order_status). We found T=4 optimal — it reveals these confusable pairs without washing out the signal.
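The trend in the table can be reproduced with a few lines of plain Python (a sketch; exact probabilities depend on rounding, and entropy here is reported in bits rather than nats):

```python
import math

def soft_probs(logits, T):
    """Temperature-softened softmax: p_i = exp(z_i/T) / sum_j exp(z_j/T)."""
    e = [math.exp(z / T) for z in logits]
    s = sum(e)
    return [x / s for x in e]

def entropy_bits(p):
    return -sum(x * math.log2(x) for x in p if x > 0)

# Teacher logits from the example above
z = [5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0]

for T in (1, 2, 4, 8, 20):
    p = soft_probs(z, T)
    print(f"T={T:>2}  top-1={max(p):.2f}  entropy={entropy_bits(p):.2f} bits")
# Raising T lowers the top-1 probability and pushes entropy toward
# the uniform maximum of log2(10) ≈ 3.32 bits
```

As T grows, the top-1 probability falls monotonically and the entropy climbs toward the uniform limit, which is the washing-out effect the table describes.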
Combined Distillation Loss
The final training loss combines KD loss with the standard hard-label loss:
$$\mathcal{L} = (1 - \alpha) \cdot \mathcal{L}_{\text{hard}} + \alpha \cdot T^2 \cdot D_{\text{KL}}\!\left(p_T^{(\tau)} \,\|\, p_S^{(\tau)}\right)$$
where $\alpha$ balances the two objectives.
| $\alpha$ | Effect |
|---|---|
| 0.0 | Pure hard-label training (no distillation) |
| 0.3 | Emphasis on ground truth, distillation as regularizer |
| 0.5 | Equal weight (typical starting point) |
| 0.7 | Emphasis on teacher's soft knowledge |
| 1.0 | Pure distillation (ignore hard labels) |
For our experiment, $\alpha = 0.7$ worked best. The teacher's soft distribution carries more information than the labels alone (teacher already achieves 92.1%, so its distributions are well-calibrated).
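To make the combination concrete, here is the full loss evaluated for one toy 3-class example in plain Python (the logits are hypothetical, not taken from the production models):

```python
import math

def softmax(z, T=1.0):
    e = [math.exp(x / T) for x in z]
    s = sum(e)
    return [x / s for x in e]

def combined_loss(z_student, z_teacher, label, T=4.0, alpha=0.7):
    """(1 - alpha) * CE(y, p_S)  +  alpha * T^2 * KL(p_T^(tau) || p_S^(tau))."""
    p_hard = softmax(z_student)            # hard loss uses T=1
    hard = -math.log(p_hard[label])
    p_s = softmax(z_student, T)
    p_t = softmax(z_teacher, T)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    return (1 - alpha) * hard + alpha * (T ** 2) * kl

# Hypothetical logits for one example, true class 0
z_teacher = [5.2, 2.1, 1.8]
z_student = [3.0, 1.5, 0.2]
print(f"{combined_loss(z_student, z_teacher, label=0):.4f}")
```

Note that if the student's logits matched the teacher's exactly, the KL term would vanish and only the (down-weighted) hard loss would remain.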
Feature-Based Distillation (TinyBERT/DistilBERT Layer Matching)
Beyond matching output distributions, we can match intermediate representations:
$$\mathcal{L}_{\text{feature}} = \sum_{l=1}^{L_S} \text{MSE}\!\left(\mathbf{h}_S^{(l)} \mathbf{W}_l,\; \mathbf{h}_T^{(m(l))}\right)$$
where:
- $\mathbf{h}_S^{(l)}$ is the student's $l$-th layer hidden state
- $\mathbf{h}_T^{(m(l))}$ is the teacher's $m(l)$-th layer hidden state (layer mapping function $m$)
- $\mathbf{W}_l$ is a learnable linear transformation (because dimensions may differ)
Layer mapping for DistilBERT → TinyBERT:
| Student Layer (TinyBERT, 4 layers) | Teacher Layer (DistilBERT, 6 layers) | Rationale |
|---|---|---|
| Layer 0 (embeddings) | Layer 0 (embeddings) | Token representations |
| Layer 1 | Layer 2 | Low-level syntax |
| Layer 2 | Layer 4 | Mid-level semantics |
| Layer 3 | Layer 6 | Task-specific features |
Attention transfer:
TinyBERT adds another term — matching teacher-student attention matrices:
$$\mathcal{L}_{\text{attn}} = \sum_{l=1}^{L_S} \frac{1}{h} \sum_{i=1}^{h} \text{MSE}\!\left(\mathbf{A}_S^{(l,i)}, \mathbf{A}_T^{(m(l),i)}\right)$$
where $\mathbf{A}^{(l,i)} \in \mathbb{R}^{n \times n}$ is the attention weight matrix for head $i$ in layer $l$.
The total distillation loss becomes:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{KD}} + \beta \cdot \mathcal{L}_{\text{feature}} + \gamma \cdot \mathcal{L}_{\text{attn}}$$
Gradient Analysis: Why Distillation Works Better Than Training From Scratch
In standard training, the gradient for class $c$ (when $c$ is not the correct class) is:
$$\frac{\partial \mathcal{L}_{\text{hard}}}{\partial z_c} = p_S(c) \quad \text{(for incorrect classes)}$$
For a well-trained student, $p_S(c) \approx 0$ for most wrong classes, so gradients are tiny — the student learns almost nothing about inter-class relationships.
In distillation, the gradient for class $c$ is:
$$\frac{\partial \mathcal{L}_{\text{KD}}}{\partial z_c} \propto p_S^{(\tau)}(c) - p_T^{(\tau)}(c)$$
(the raw KL gradient carries an extra $1/T$ factor, which the $T^2$ multiplier in the loss compensates for). Even for wrong classes, $p_T^{(\tau)}(c)$ can be non-trivially large (e.g., 0.11 for "order_status" when the true intent is "product_inquiry"). This non-zero target means the student receives meaningful gradients for every class at every step.
Effective gradient magnitude comparison (for 10-class MangaAssist intent):
| Class Type | Hard-Label Gradient | Distillation Gradient (T=4) |
|---|---|---|
| Correct class | $p_S - 1$ ≈ -0.08 | $p_S^{(\tau)} - p_T^{(\tau)}$ ≈ -0.10 |
| Top-2 confusable | $p_S$ ≈ 0.04 | $(p_S^{(\tau)} - p_T^{(\tau)})$ ≈ 0.08 |
| Top-3 confusable | $p_S$ ≈ 0.02 | $(p_S^{(\tau)} - p_T^{(\tau)})$ ≈ 0.05 |
| Irrelevant class | $p_S$ ≈ 0.001 | $(p_S^{(\tau)} - p_T^{(\tau)})$ ≈ 0.01 |
Distillation provides 2-10× stronger gradients for confusable classes, which is exactly where the student needs the most guidance.
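A small numeric check reproduces this pattern. The sketch below uses hypothetical 3-class logits for a near-converged student, following the table's convention that the distillation gradient is $p_S^{(\tau)} - p_T^{(\tau)}$:

```python
import math

def softmax(z, T=1.0):
    e = [math.exp(x / T) for x in z]
    s = sum(e)
    return [x / s for x in e]

# 3 classes: correct, confusable, irrelevant (hypothetical logits)
z_teacher = [5.2, 2.1, 1.8]
z_student = [5.0, 1.0, 0.5]   # near-converged student
T = 4.0

# Hard-label gradient for a wrong class c is just p_S(c): tiny once trained
hard_grad_confusable = softmax(z_student)[1]

# Distillation gradient for class c is proportional to p_S^(tau) - p_T^(tau)
p_s, p_t = softmax(z_student, T), softmax(z_teacher, T)
kd_grad_confusable = p_s[1] - p_t[1]

print(f"hard: {hard_grad_confusable:.4f}  KD: {kd_grad_confusable:.4f}")
# The KD gradient on the confusable class stays non-trivial even after the
# hard-label gradient has nearly vanished
```

With these numbers the hard-label gradient on the confusable class has almost vanished, while the soft target keeps pulling the student toward the teacher's ranking of alternatives.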
Model Internals — Layer-by-Layer Diagrams
Distillation Architecture Overview
```mermaid
graph TB
    subgraph "Teacher Pipeline (Frozen — DistilBERT 66M)"
        T1["Input: 'When will my manga arrive?'"]
        T2["Embedding Layer (768d)"]
        T3["Transformer Layer 1"]
        T4["Transformer Layer 2"]
        T5["Transformer Layer 3"]
        T6["Transformer Layer 4"]
        T7["Transformer Layer 5"]
        T8["Transformer Layer 6"]
        T9["Classification Head"]
        T10["Teacher Logits z_T ∈ ℝ¹⁰"]
        T11["Softened Probs p_T^(τ)<br>T=4"]
        T1 --> T2 --> T3 --> T4 --> T5 --> T6 --> T7 --> T8 --> T9 --> T10 --> T11
    end
    subgraph "Student Pipeline (Training — TinyBERT 14.5M)"
        S1["Same Input"]
        S2["Embedding Layer (312d)"]
        S3["Transformer Layer 1"]
        S4["Transformer Layer 2"]
        S5["Transformer Layer 3"]
        S6["Transformer Layer 4"]
        S7["Classification Head"]
        S8["Student Logits z_S ∈ ℝ¹⁰"]
        S9["Softened Probs p_S^(τ)<br>T=4"]
        S1 --> S2 --> S3 --> S4 --> S5 --> S6 --> S7 --> S8 --> S9
    end
    T11 --> KD["KL-Divergence Loss<br>T²·D_KL(p_T ∥ p_S)"]
    S9 --> KD
    S8 --> HARD["Hard-Label Loss<br>CE(y, p_S)"]
    KD --> TOTAL["Total Loss = 0.3·L_hard + 0.7·L_KD"]
    HARD --> TOTAL
    TOTAL --> GRAD["Backprop → Student Only"]
    GRAD --> S2
    T2 -.->|"Feature Match (MSE)"| S2
    T4 -.->|"Feature Match (MSE)"| S3
    T6 -.->|"Feature Match (MSE)"| S4
    T8 -.->|"Feature Match (MSE)"| S6
    style T3 fill:#bbdefb
    style T4 fill:#bbdefb
    style T5 fill:#bbdefb
    style T6 fill:#bbdefb
    style T7 fill:#bbdefb
    style T8 fill:#bbdefb
    style S3 fill:#fff9c4
    style S4 fill:#fff9c4
    style S5 fill:#fff9c4
    style S6 fill:#fff9c4
```
Temperature's Effect on Softmax Distributions
```mermaid
graph TD
    subgraph "Teacher Logits for 'When will my manga arrive?'"
        L["z = [5.2, 2.1, 1.8, 0.3, -1.5, -2.0, -2.4, -3.1, -3.5, -4.0]<br>order_status=5.2, product_inquiry=2.1, shipping=1.8, ..."]
    end
    subgraph "T=1 (Standard Softmax)"
        T1A["order_status: 89%<br>product_inquiry: 4%<br>shipping: 3%<br>All others: ~0%<br><br>Looks like hard label"]
    end
    subgraph "T=4 (Sweet Spot)"
        T4A["order_status: 35%<br>product_inquiry: 15%<br>shipping: 14%<br>return: 8%<br>recommendation: 5%<br>...<br><br>Reveals inter-class structure"]
    end
    subgraph "T=20 (Too Soft)"
        T20A["order_status: 12%<br>product_inquiry: 11%<br>shipping: 11%<br>return: 10%<br>recommendation: 10%<br>...<br><br>Nearly uniform — no signal"]
    end
    L --> T1A
    L --> T4A
    L --> T20A
    style T4A fill:#c8e6c9
```
Gradient Flow: Distillation vs Standard Training
```mermaid
graph LR
    subgraph "Standard Training (Hard Labels)"
        HL["Loss: CE(one-hot, p_S)"]
        HL --> G1["∂L/∂z_correct ≈ -0.08<br>(strong signal)"]
        HL --> G2["∂L/∂z_confusable ≈ 0.04<br>(weak signal)"]
        HL --> G3["∂L/∂z_irrelevant ≈ 0.001<br>(near zero)"]
    end
    subgraph "Distillation Training (Soft Labels, T=4)"
        SL["Loss: KL(p_T^τ, p_S^τ)"]
        SL --> G4["∂L/∂z_correct ≈ -0.10<br>(strong signal)"]
        SL --> G5["∂L/∂z_confusable ≈ 0.08<br>(2× stronger)"]
        SL --> G6["∂L/∂z_irrelevant ≈ 0.01<br>(10× stronger)"]
    end
    G2 -.->|"2× gradient boost"| G5
    G3 -.->|"10× gradient boost"| G6
    style G5 fill:#c8e6c9
    style G6 fill:#c8e6c9
```
TinyBERT Layer Matching
```mermaid
graph TB
    subgraph "Teacher: DistilBERT (6 layers, 768d)"
        TE["Embedding (768d)"]
        TL1["Layer 1: Syntax patterns"]
        TL2["Layer 2: Phrase-level features"]
        TL3["Layer 3: Semantic roles"]
        TL4["Layer 4: Intent signals"]
        TL5["Layer 5: Context integration"]
        TL6["Layer 6: Task-specific"]
        TE --> TL1 --> TL2 --> TL3 --> TL4 --> TL5 --> TL6
    end
    subgraph "Student: TinyBERT (4 layers, 312d)"
        SE["Embedding (312d)"]
        SL1["Layer 1"]
        SL2["Layer 2"]
        SL3["Layer 3"]
        SL4["Layer 4"]
        SE --> SL1 --> SL2 --> SL3 --> SL4
    end
    TE -.->|"Embed MSE<br>W₀: 312×768"| SE
    TL2 -.->|"Hidden MSE<br>W₁: 312×768"| SL1
    TL4 -.->|"Hidden MSE<br>W₂: 312×768"| SL2
    TL5 -.->|"Attn MSE<br>student heads map<br>to teacher heads"| SL3
    TL6 -.->|"Hidden + Attn MSE<br>W₃: 312×768"| SL4
    style TE fill:#bbdefb
    style TL1 fill:#bbdefb
    style TL2 fill:#bbdefb
    style TL3 fill:#bbdefb
    style TL4 fill:#bbdefb
    style TL5 fill:#bbdefb
    style TL6 fill:#bbdefb
    style SE fill:#fff9c4
    style SL1 fill:#fff9c4
    style SL2 fill:#fff9c4
    style SL3 fill:#fff9c4
    style SL4 fill:#fff9c4
```
LLM-to-Small-Model Distillation Pipeline
```mermaid
sequenceDiagram
    participant DS as Manga QA Dataset<br>(3K examples)
    participant CL as Claude 3.5 Sonnet<br>(Teacher, via Bedrock)
    participant SA as SageMaker<br>Processing Job
    participant L3 as Llama 3 8B<br>(Student)
    participant EV as Evaluation<br>Suite
    DS->>CL: Batch inference on all 3K questions
    Note over CL: Generate responses<br>(no logits/logprobs exposed)<br>Cost: ~$15 for 3K examples
    CL->>SA: Teacher responses stored in S3
    Note over SA: Augmentation:<br>- Generate 5 paraphrases per example<br>- Teacher scores paraphrases<br>- Filter to keep top-quality pairs<br>Dataset grows: 3K → 12K
    SA->>L3: Train with combined loss:<br>0.3·CE(label, student) + 0.7·KL(teacher, student)
    Note over L3: Training: 4 epochs, lr=2e-5<br>SageMaker g5.2xlarge<br>~4 hours, ~$24
    L3->>EV: Evaluate on held-out set
    Note over EV: Manga QA: 82.4% (teacher: 93.7%)<br>Latency: 120ms (teacher: 800ms)<br>Quality ratio: 87.9%<br>Cost/query: $0.0001 vs $0.003
```
Implementation Deep-Dive
Output Distillation: DistilBERT → TinyBERT
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


class DistillationLoss(nn.Module):
    """
    Combined distillation + hard-label loss.
    L = (1-alpha) * CE(student, labels) + alpha * T^2 * KL(teacher_soft, student_soft)
    """

    def __init__(self, temperature: float = 4.0, alpha: float = 0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(
        self,
        student_logits: torch.Tensor,  # (B, C)
        teacher_logits: torch.Tensor,  # (B, C)
        labels: torch.Tensor,          # (B,)
    ) -> torch.Tensor:
        # Hard-label loss (standard cross-entropy)
        hard_loss = F.cross_entropy(student_logits, labels)
        # Soft-label loss (KL-divergence with temperature)
        student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
        # KL(teacher || student) = sum(teacher * log(teacher/student))
        # F.kl_div expects log-probs as first arg, probs as second
        kd_loss = F.kl_div(
            student_soft,
            teacher_soft,
            reduction="batchmean",
        ) * (self.temperature ** 2)
        return (1 - self.alpha) * hard_loss + self.alpha * kd_loss


class DistillationTrainer(Trainer):
    """Custom trainer that runs teacher inference and computes distillation loss."""

    def __init__(self, teacher_model, temperature=4.0, alpha=0.7, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.distill_loss = DistillationLoss(temperature, alpha)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        # Student forward pass
        student_outputs = model(**inputs)
        student_logits = student_outputs.logits
        # Teacher forward pass (no gradient)
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
            teacher_logits = teacher_outputs.logits
        loss = self.distill_loss(student_logits, teacher_logits, labels)
        return (loss, student_outputs) if return_outputs else loss


def distill_intent_classifier():
    """
    Distill DistilBERT (6-layer) teacher into TinyBERT (4-layer) student.
    """
    teacher = AutoModelForSequenceClassification.from_pretrained(
        "./manga_intent_teacher",  # Our fine-tuned DistilBERT
        num_labels=10,
    )
    student = AutoModelForSequenceClassification.from_pretrained(
        "huawei-noah/TinyBERT_General_4L_312D",
        num_labels=10,
    )
    tokenizer = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")
    # Training arguments
    args = TrainingArguments(
        output_dir="./tinybert_distilled",
        num_train_epochs=10,             # More epochs for distillation
        per_device_train_batch_size=64,  # TinyBERT is small — large batch OK
        learning_rate=5e-5,
        warmup_ratio=0.1,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        bf16=True,
    )
    trainer = DistillationTrainer(
        teacher_model=teacher,
        temperature=4.0,
        alpha=0.7,
        model=student,
        args=args,
        train_dataset=train_dataset,  # tokenized intent datasets, prepared elsewhere
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    return student
```
Feature-Based Distillation (TinyBERT Two-Stage)
```python
class TinyBERTFeatureLoss(nn.Module):
    """
    TinyBERT's two-stage distillation:
    Stage 1: Intermediate layer matching (embedding + hidden + attention)
    Stage 2: Output KD (prediction layer)
    """

    def __init__(self, teacher_dim: int = 768, student_dim: int = 312):
        super().__init__()
        # Layer mapping: student transformer layer -> teacher transformer layer
        # (1-based layer numbers; embeddings are handled separately below)
        self.layer_map = {1: 2, 2: 4, 3: 6}
        # Linear transforms to project student hidden states to teacher dim
        self.hidden_transforms = nn.ModuleDict({
            str(s): nn.Linear(student_dim, teacher_dim, bias=False)
            for s in self.layer_map
        })
        # Embedding transform
        self.embed_transform = nn.Linear(student_dim, teacher_dim, bias=False)

    def forward(
        self,
        teacher_hidden_states,  # list of (B, seq_len, 768); [0] is embeddings
        student_hidden_states,  # list of (B, seq_len, 312); [0] is embeddings
        teacher_attentions,     # list of (B, heads, seq_len, seq_len)
        student_attentions,     # list of (B, heads, seq_len, seq_len)
    ):
        losses = {}
        # Embedding layer loss
        s_embed = self.embed_transform(student_hidden_states[0])
        losses["embed"] = F.mse_loss(s_embed, teacher_hidden_states[0])
        # Hidden state matching (hidden_states[k] is the output of layer k)
        hidden_loss = 0.0
        for s_layer, t_layer in self.layer_map.items():
            s_hidden = self.hidden_transforms[str(s_layer)](
                student_hidden_states[s_layer]
            )
            hidden_loss += F.mse_loss(s_hidden, teacher_hidden_states[t_layer])
        losses["hidden"] = hidden_loss / len(self.layer_map)
        # Attention matrix matching (attentions[k] belongs to layer k+1)
        attn_loss = 0.0
        for s_layer, t_layer in self.layer_map.items():
            s_attn = student_attentions[s_layer - 1]
            t_attn = teacher_attentions[t_layer - 1]
            # Student has fewer heads — average teacher heads to match
            if s_attn.shape[1] != t_attn.shape[1]:
                ratio = t_attn.shape[1] // s_attn.shape[1]
                t_attn = t_attn.view(
                    t_attn.shape[0], s_attn.shape[1], ratio, *t_attn.shape[2:]
                ).mean(dim=2)
            attn_loss += F.mse_loss(s_attn, t_attn)
        losses["attention"] = attn_loss / len(self.layer_map)
        total = losses["embed"] + losses["hidden"] + losses["attention"]
        return total, losses
```
LLM Distillation: Claude → Llama 3 8B
```python
import json

import boto3
import torch


def generate_teacher_labels(dataset_path: str, output_path: str):
    """
    Generate teacher responses from Claude 3.5 Sonnet via Bedrock.
    Claude does not expose logits/logprobs, so we collect response-level
    labels for response-level distillation.
    """
    bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
    with open(dataset_path) as f:
        examples = json.load(f)
    results = []
    for example in examples:
        response = bedrock.invoke_model(
            modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
            body=json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "messages": [
                    {"role": "user", "content": example["question"]}
                ],
                "max_tokens": 512,
                "temperature": 0.0,  # Deterministic for consistent labels
            }),
        )
        result = json.loads(response["body"].read())
        teacher_response = result["content"][0]["text"]
        results.append({
            "question": example["question"],
            "teacher_response": teacher_response,
            "ground_truth": example.get("answer", ""),
        })
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
    # Cost estimate: 3K examples × ~500 tokens/response × $0.003/1K tokens ≈ $4.50
    return results


def train_student_llm(teacher_data_path: str):
    """
    Train Llama 3 8B student from Claude teacher responses.
    Uses response-level distillation (since we can't get Claude's logits directly).
    """
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTTrainer, SFTConfig

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    tokenizer.pad_token = tokenizer.eos_token
    # Load teacher-labeled data
    with open(teacher_data_path) as f:
        data = json.load(f)
    # Format as instruction-following
    formatted = []
    for d in data:
        text = (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            f"You are MangaAssist, an expert manga advisor.\n"
            f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
            f"{d['question']}\n"
            f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
            f"{d['teacher_response']}\n"
            f"<|eot_id|>"
        )
        formatted.append({"text": text})
    sft_config = SFTConfig(
        output_dir="./llama8b_manga_student",
        num_train_epochs=4,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # Effective batch size: 16
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        gradient_checkpointing=True,
        max_seq_length=512,
    )
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=Dataset.from_list(formatted),  # SFTTrainer expects a Dataset
        tokenizer=tokenizer,
    )
    trainer.train()
    return model
```
Group Discussion: Key Decision Points
Decision Point 1: Which Teacher-Student Pair First?
Priya (ML Engineer): We have four distillation opportunities. I recommend prioritizing based on production impact:
| Priority | Pair | Quality Gain | Latency Gain | Cost Savings |
|---|---|---|---|---|
| 1 | DistilBERT → TinyBERT (intent) | -2.8% acc | 15ms → 5ms | ~$0 (both on Lambda) |
| 2 | Claude → Llama 3 8B (fallback) | -11.3% QA | 800ms → 120ms | $800/mo Bedrock savings |
| 3 | MiniLM → ONNX-4L (reranker) | -0.05 NDCG | 50ms → 15ms | ~$0 |
| 4 | Ensemble → Single model | -2.7% composite | 80ms → 15ms | Operational simplicity |
Sam (PM): The Claude → Llama 3 8B distillation saves $800/month. That's $9,600/year. At 82.4% manga QA vs 93.7%, the quality gap is large, but this is a fallback model for when Bedrock is unavailable. As a fallback at 82.4%, it is much better than showing an error page.
Marcus (Architect): Agreed. A self-hosted fallback that handles 87.9% of Claude's quality means our availability goes from 99.5% (Bedrock SLA) to 99.99% (self-hosted + Bedrock). That is a significant reliability improvement for customer experience.
Jordan (MLOps): The intent classifier distillation (DistilBERT → TinyBERT) is the fastest to validate — we have all the data, both models are small, training takes 20 minutes. I would use it as a proof-of-concept before committing to the LLM distillation (4 hours, $24).
Aiko (Data Scientist): Correct. The intent distillation also has the most controlled evaluation — 10-class accuracy is deterministic. LLM response quality requires human evaluation or an LLM-as-judge setup, which adds uncertainty.
Resolution: Priority order: (1) Intent classifier distillation as proof-of-concept (20 min, validates pipeline). (2) Claude → Llama 3 8B for production fallback ($24, saves $800/month). (3) Reranker post-quantization if latency becomes an issue. (4) Ensemble consolidation deferred to V3.
Decision Point 2: Temperature Selection
Priya (ML Engineer): I swept temperature from 1 to 20 on the intent distillation task:
| Temperature | Student Accuracy | Training Loss | Notes |
|---|---|---|---|
| T=1 | 87.2% | 0.42 | Minimal knowledge transfer — too confident |
| T=2 | 88.1% | 0.38 | Better — starts revealing alternatives |
| T=4 | 89.3% | 0.31 | Best — clear inter-class structure |
| T=8 | 88.7% | 0.35 | Slightly oversmoothed |
| T=20 | 86.4% | 0.44 | Too uniform — signal washed out |
Aiko (Data Scientist): T=4 aligns with theory. Our intent classifier has moderate confusion — "product_inquiry" vs "recommendation" share vocabulary like "suggest," "similar," "like." At T=4, the teacher reveals these confusions without making all intents look equally likely.
The information-theoretic explanation: at T=4, the teacher's soft-label entropy is 1.98 nats, against a maximum of $\ln 10 \approx 2.30$ nats for a uniform distribution over 10 classes. The soft labels expose the alternatives but still sit 0.32 nats below uniform, so real structure remains. At T=1 the entropy is 0.58 nats (mostly one-hot), and at T=20 it is 2.28 nats (nearly maximal, i.e. near-uniform, with little left to learn from).
Marcus (Architect): The difference between T=4 (89.3%) and T=2 (88.1%) is 1.2%. Between T=4 and T=8 (88.7%) it is 0.6%. T=4 is a robust choice — even if we are slightly off, we lose less than 1% performance.
Resolution: T=4 for all classification distillation tasks. For LLM-to-LLM response distillation (Claude → Llama), temperature is not directly applicable since we use response-level distillation (SFT on teacher outputs), not logit-level KD.
Decision Point 3: Two-Stage vs End-to-End Distillation
Priya (ML Engineer): TinyBERT's original paper proposes two stages:
- Stage 1: Feature-based distillation (match embeddings, hidden states, attention matrices) without classification loss
- Stage 2: Output-based distillation (match soft labels) with classification loss
My ablation:
| Approach | Student Accuracy | Training Time |
|---|---|---|
| Output-only (KD loss) | 87.8% | 20 min |
| Feature-only (Stage 1) | 82.1% | 35 min |
| Stage 1 then Stage 2 (TinyBERT original) | 89.3% | 55 min |
| Combined all losses simultaneously | 88.9% | 40 min |
Aiko (Data Scientist): The two-stage approach works better because feature matching is a pre-training objective — it gives the student internal representations similar to the teacher before the classification loss steers those representations toward the task. If we combine losses simultaneously, the classification gradient can pull representations away from teacher-aligned patterns before the student has time to absorb the teacher's intermediate knowledge.
Jordan (MLOps): The 55-minute training time is still trivial. If it gives us 89.3% vs 87.8%, I would take the two-stage approach. The extra 35 minutes is worth 1.5% quality.
Resolution: Two-stage TinyBERT distillation. Stage 1 (35 min): feature matching on general corpus. Stage 2 (20 min): output distillation on intent-labeled data. Combined quality: 89.3% (vs teacher's 92.1% — a 3.0% gap, acceptable for the 3× latency improvement).
Decision Point 4: Data Augmentation for Distillation
Priya (ML Engineer): Distillation benefits from more data than standard training because the student needs to learn the teacher's behavior across the full input distribution, not just labeled examples. I tested augmentation strategies:
| Dataset | Size | Student Accuracy | Teacher Accuracy on Same |
|---|---|---|---|
| Original labeled | 3K | 87.2% | 92.1% |
| + Backtranslation | 9K | 88.5% | 91.8% |
| + Paraphrasing (Claude) | 12K | 89.3% | 92.0% |
| + Unlabeled manga queries | 50K | 89.8% | — |
Sam (PM): The 50K unlabeled approach is interesting — we do not need labels, just teacher soft predictions on real user queries. We have millions of historical queries. The cost is just the teacher inference: 50K × DistilBERT ≈ $0 (runs on Lambda).
Aiko (Data Scientist): But be careful — historical queries have distribution shift. Queries from 6 months ago may include trends (manga titles) that are no longer relevant. I would use only the last 30 days of queries and deduplicate.
Marcus (Architect): Paraphrasing with Claude costs $15 for the augmented set. That is negligible. I would use both: 12K labeled (original + paraphrased) for Stage 2 output KD, and 50K unlabeled (recent queries with teacher soft labels) for additional soft-label training.
Resolution: Use 12K (original + Claude paraphrasing) for Stage 2 output distillation. Optionally add 50K recent unlabeled queries with teacher soft labels for a Stage 3 self-training step. Total augmentation cost: ~$15 for Claude paraphrasing + ~$0 for teacher inference on 50K queries.
Research Paper References
1. Distilling the Knowledge in a Neural Network (Hinton, Vinyals, Dean, 2015)
Key contribution: Introduced the temperature-scaling softmax for knowledge distillation. Showed that the "dark knowledge" in soft probability distributions — the relative probabilities of incorrect classes — carries information that hard labels discard. A student trained on soft labels at T>1 outperforms one trained on hard labels by 1-5%.
Relevance to MangaAssist: Foundation of all our distillation work. The T=4 temperature reveals which intents are confusable (product_inquiry ↔ recommendation) — information that hard labels cannot convey. This reduces our student's confusion matrix errors on the top-3 confusable pairs by 40%.
2. TinyBERT: Distilling BERT for Natural Language Understanding (Jiao et al., 2019)
Key contribution: Proposed two-stage distillation matching intermediate transformer layers: embeddings, attention matrices, and hidden states. Achieved 96.8% of BERT-base's performance with a 7.5× smaller, 9.4× faster model.
Relevance to MangaAssist: Direct architecture for our DistilBERT → TinyBERT distillation. The layer mapping strategy (student layer $l$ → teacher layer $m(l)$) and attention transfer loss give us 89.3% intent accuracy with a 14.5M parameter model that runs in 5ms on Lambda — critical for cold start scenarios.
3. Patient Knowledge Distillation for BERT Model Compression (Sun et al., 2019)
Key contribution: Instead of matching only the last layer, "patient" distillation matches representations from every $k$-th teacher layer, allowing the student to learn the teacher's intermediate reasoning process. This consistently outperforms last-layer-only distillation by 0.5-1.5%.
Relevance to MangaAssist: Informs our layer mapping strategy. For our 6→4 layer distillation, we use every-2-layer matching ({0→0, 1→2, 2→4, 3→6}) following the patient KD principle.
4. DistilBERT, a distilled version of BERT (Sanh et al., 2019)
Key contribution: Demonstrated that pretraining-time distillation (distilling during the pretraining phase, not just fine-tuning) produces a student that retains 97% of BERT's language understanding while being 60% smaller and 60% faster.
Relevance to MangaAssist: Our teacher (DistilBERT) is itself a distilled model. This creates an interesting chain: BERT → DistilBERT (pre-training distillation) → TinyBERT (task-specific distillation). Each step trades quality for efficiency. BERT → DistilBERT keeps 97%; DistilBERT → TinyBERT keeps 96.9%. Cumulative: TinyBERT retains ~94% of BERT-base's quality at 4.5× compression and 6× speedup.
5. On the Efficacy of Knowledge Distillation (Cho & Hariharan, 2019)
Key contribution: Showed that a larger teacher is not always better — if the teacher-student capacity gap is too large, the student cannot mimic the teacher effectively. An intermediate-sized teacher ("assistant") that bridges the gap can improve distillation quality by 1-3%.
Relevance to MangaAssist: Validates our two-hop distillation chain. Rather than distilling BERT-large (340M) directly to TinyBERT (14.5M), we use DistilBERT (66M) as the teacher — a 4.5× capacity ratio. The paper found capacity ratios above 10× degrade distillation, and our 4.5× ratio is well within the effective range.
Production Deployment Results
DistilBERT → TinyBERT Intent Distillation
| Metric | DistilBERT (Teacher) | TinyBERT (Student) | Delta |
|---|---|---|---|
| Accuracy | 92.1% | 89.3% | -2.8% |
| Parameters | 66M | 14.5M | 4.6× smaller |
| Latency (Lambda cold) | 35ms | 12ms | 2.9× faster |
| Latency (Lambda warm) | 15ms | 5ms | 3× faster |
| Model size | 264MB | 58MB | 4.6× smaller |
| Lambda memory | 512MB | 256MB | 2× smaller |
Claude → Llama 3 8B Fallback Distillation
| Metric | Claude (Teacher) | Llama 3 8B (Student) | Delta |
|---|---|---|---|
| Manga QA accuracy | 93.7% | 82.4% | -11.3% |
| Response quality (human eval, 1-5) | 4.6 | 3.9 | -0.7 |
| Latency (p50) | 800ms | 120ms | 6.7× faster |
| Latency (p99) | 2100ms | 350ms | 6× faster |
| Cost per query | $0.003 | $0.0001 | 30× cheaper |
| Monthly cost (500K queries) | $1,500 | $50 | $1,450 savings |
| Availability | 99.5% (Bedrock SLA) | 99.99% (self-hosted) | +0.49% |
CPQ Analysis
Deploying the Llama 3 8B fallback costs $50/month (SageMaker g5.2xlarge spot) and provides:
- Availability improvement: 99.5% → 99.99% (worth $25K/year in prevented customer drop-off)
- Cost savings: $1,450/month × 12 = $17,400/year (if used as primary; $0 if purely fallback)
- Quality cost: -11.3% accuracy × 500K queries = 56,500 lower-quality responses/month
As a fallback (activated only during Bedrock outages, ~0.5% of traffic), the quality impact is minimal: 56,500 × 0.005 = 283 lower-quality responses/month — acceptable for maintaining availability.
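The arithmetic above can be sketched as a quick back-of-the-envelope check (plain Python; the traffic, quality-gap, and fallback-rate figures come from this section):

```python
monthly_queries = 500_000
quality_gap = 0.113     # 93.7% teacher accuracy minus 82.4% student accuracy
fallback_rate = 0.005   # Bedrock unavailable ~0.5% of the time

# If the student served ALL traffic
as_primary = monthly_queries * quality_gap
# If the student serves only fallback traffic
as_fallback = as_primary * fallback_rate

print(f"as primary:  {as_primary:,.0f} lower-quality responses/month")
print(f"as fallback: {as_fallback:.1f} lower-quality responses/month")
```

The fallback-only exposure is roughly three hundred degraded responses a month, which is the basis for accepting the 11.3% accuracy gap.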