06. Continual Learning and Catastrophic Forgetting — Retraining Without Losing Knowledge
Problem Statement and MangaAssist Context
MangaAssist's models must evolve continuously. Every month, new manga titles launch, seasonal trends shift (holiday gift guides, anime adaptation spikes), and user vocabulary drifts ("rizz" in manga reviews, new genre-fusions like "isekai-horror"). If we naively fine-tune on new data, the model "forgets" old knowledge — a phenomenon called catastrophic forgetting.
The Forgetting Problem in Production
Our intent classifier was fine-tuned in January 2024 on 5K labeled examples covering 10 intents. In March, we collected 1.2K new examples with a new pattern: "gift recommendation" queries surged 340% (White Day in Japan). Naively fine-tuning on these 1.2K new examples:
| Intent | Accuracy Before | After Naive Retrain | Delta |
|---|---|---|---|
| product_inquiry | 93.4% | 88.1% | -5.3% |
| order_status | 95.1% | 86.2% | -8.9% |
| recommendation | 88.7% | 94.2% | +5.5% |
| return | 91.2% | 82.4% | -8.8% |
| gift_recommendation (new) | — | 91.5% | New |
| Overall | 92.1% | 87.3% | -4.8% |
The model improved on recommendation and the new gift intent but degraded severely on product_inquiry, order_status, and return — exactly the catastrophic forgetting pattern. We need techniques that learn new knowledge without destroying old knowledge.
The Stability-Plasticity Dilemma
Every learning system faces a fundamental tradeoff:
- Stability: Preserve existing knowledge (good at old tasks)
- Plasticity: Adapt to new information (good at new tasks)
Full plasticity (naive fine-tuning) → catastrophic forgetting. Full stability (freeze everything) → cannot learn new patterns. The solution lies in methods that find the optimal balance.
Mathematical Foundations
Why Catastrophic Forgetting Happens — The Geometric Perspective
Consider the model's parameters $\boldsymbol{\theta} \in \mathbb{R}^n$. After training on Task A (original 5K data), the model sits at $\boldsymbol{\theta}_A^*$ — a local minimum of the Task A loss surface.
When we train on Task B (new 1.2K data), gradient descent moves $\boldsymbol{\theta}$ downhill on the Task B loss surface:
$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \nabla_{\boldsymbol{\theta}} \mathcal{L}_B(\boldsymbol{\theta}_t)$$
Problem: the gradient $\nabla \mathcal{L}_B$ has no knowledge of Task A. If the Task B gradient pushes parameters away from the region where Task A performs well, we forget Task A.
Why some parameters matter more: Not all parameters are equally important for Task A. Some parameters encode critical decision boundaries (e.g., the weight separating "order_status" from "return"), while others encode minor variations. Moving a critical parameter by even 0.01 can flip predictions; moving a non-critical parameter by 1.0 has no effect.
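The gap between critical and non-critical parameters is easy to see on a toy loss surface (the coefficients below are illustrative, not measured from the model): the same 0.1 perturbation costs 5000× more loss along the sharply curved direction than along the flat one.

```python
# Toy Task-A loss: sharply curved along theta1, nearly flat along theta2.
# (Illustrative coefficients only — not measured from the real model.)
def loss_A(theta1, theta2):
    return 50.0 * (theta1 - 1.0) ** 2 + 0.01 * (theta2 - 2.0) ** 2

base = loss_A(1.0, 2.0)  # loss at the Task-A optimum: 0.0

# Same-size perturbation, very different damage:
sharp = loss_A(1.0 + 0.1, 2.0) - base   # 50 * 0.01  = 0.5
flat  = loss_A(1.0, 2.0 + 0.1) - base   # 0.01 * 0.01 = 0.0001

print(f"perturbing the sharp direction by 0.1 -> +{sharp:.4f} loss")
print(f"perturbing the flat direction by 0.1  -> +{flat:.4f} loss")
```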
Elastic Weight Consolidation (EWC) — The Fisher Information Approach
EWC (Kirkpatrick et al., 2017) adds a regularization term that penalizes changes to parameters that are important for the old task:
$$\mathcal{L}_{\text{EWC}} = \mathcal{L}_B(\boldsymbol{\theta}) + \frac{\lambda}{2} \sum_{i=1}^{n} F_i (\theta_i - \theta_{A,i}^*)^2$$
where:
- $\mathcal{L}_B(\boldsymbol{\theta})$ is the new-task loss (standard cross-entropy on new data)
- $\theta_{A,i}^*$ is the $i$-th parameter's value after training on the old task
- $F_i$ is the Fisher Information of parameter $i$ (how important it is for the old task)
- $\lambda$ controls the strength of consolidation
The Fisher Information Matrix:
The Fisher Information for parameter $i$ measures the curvature of the loss surface at $\boldsymbol{\theta}_A^*$ along the $i$-th dimension:
$$F_i = \mathbb{E}_{\mathbf{x} \sim \mathcal{D}_A}\left[\left(\frac{\partial \log p(\mathbf{y} | \mathbf{x}, \boldsymbol{\theta})}{\partial \theta_i}\right)^2\right]$$
Intuition: $F_i$ measures how sensitive the model's predictions are to changes in $\theta_i$.
- High $F_i$: The loss surface is sharply curved around $\theta_i$ — small changes cause large quality drops. EWC will strongly resist moving this parameter.
- Low $F_i$: The loss surface is flat around $\theta_i$ — this parameter can be changed freely without hurting old task quality. EWC allows adaptation here.
Practical computation: Computing the exact Fisher requires iterating over all training data. We approximate it using a sample of old-task data (typically 200-500 examples):
$$\hat{F}_i = \frac{1}{N} \sum_{n=1}^{N} \left(\frac{\partial \log p(\mathbf{y}_n | \mathbf{x}_n, \boldsymbol{\theta}_A^*)}{\partial \theta_i}\right)^2$$
For DistilBERT with 66M parameters, computing Fisher takes ~5 minutes on a single GPU using 500 examples.
EWC Gradient Analysis
The gradient of the EWC loss with respect to parameter $\theta_i$:
$$\frac{\partial \mathcal{L}_{\text{EWC}}}{\partial \theta_i} = \frac{\partial \mathcal{L}_B}{\partial \theta_i} + \lambda F_i (\theta_i - \theta_{A,i}^*)$$
The second term acts as a "spring" that pulls parameters back toward their old-task-optimal values. The spring constant $\lambda F_i$ is proportional to the parameter's importance:
| Parameter Type | Typical $F_i$ | EWC Effect |
|---|---|---|
| Classification head (final layer) | $10^{-2}$ to $10^{-1}$ | Strong resistance to change |
| Last transformer layer attention | $10^{-3}$ to $10^{-2}$ | Moderate resistance |
| Middle transformer layers | $10^{-4}$ to $10^{-3}$ | Mild resistance |
| Embedding layer | $10^{-5}$ to $10^{-4}$ | Weak resistance — adapts freely |
This creates a natural layer-wise learning rate effect: lower layers (general features) adapt more freely, while upper layers (task-specific features) are more conserved.
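The spring dynamics can be simulated directly. A minimal sketch with two parameters and hand-picked Fisher values (all numbers are illustrative): running gradient descent on the new-task loss plus the EWC penalty leaves the high-Fisher parameter anchored near its old value, while the low-Fisher parameter adapts almost all the way to the new-task optimum.

```python
# Gradient descent on L_B(theta) + (lambda/2) * sum_i F_i * (theta_i - theta_star_i)^2
# with one high-Fisher and one low-Fisher parameter (illustrative values).
lam = 1.0
theta_star = [1.0, 1.0]          # old-task optimum theta_A*
fisher = [10.0, 0.001]           # high vs low importance
theta = theta_star[:]            # start from the old optimum

def new_task_grad(t):
    # L_B = sum_i (theta_i - 3)^2, so dL_B/dtheta_i = 2 * (theta_i - 3)
    return [2.0 * (x - 3.0) for x in t]

lr = 0.05
for _ in range(2000):
    g = new_task_grad(theta)
    theta = [
        t - lr * (g_i + lam * f * (t - ts))   # task gradient + EWC "spring" force
        for t, g_i, f, ts in zip(theta, g, fisher, theta_star)
    ]

# high-F ends near 1.33 (anchored); low-F ends near 3.0 (free to adapt)
print(f"high-F parameter: {theta[0]:.3f}, low-F parameter: {theta[1]:.3f}")
```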
Fisher Information as Loss Curvature
The connection between Fisher Information and loss curvature is precise. For the empirical Fisher:
$$\hat{F} = \frac{1}{N}\sum_{n=1}^{N} \nabla_{\boldsymbol{\theta}} \log p(\mathbf{y}_n|\mathbf{x}_n) \, \nabla_{\boldsymbol{\theta}} \log p(\mathbf{y}_n|\mathbf{x}_n)^T$$
This outer-product matrix approximates the Hessian of the negative log-likelihood:
$$\hat{F} \approx -\frac{1}{N}\sum_{n=1}^{N} \nabla_{\boldsymbol{\theta}}^2 \log p(\mathbf{y}_n|\mathbf{x}_n) = \hat{\mathbf{H}}$$
Geometric meaning: The Hessian describes the local curvature of the loss surface. High curvature in dimension $i$ means the loss function is a narrow valley along $\theta_i$ — indicating the current value is precisely tuned and should not be changed. Low curvature means a flat plain — the parameter can vary without affecting the loss.
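For a model simple enough to check by hand (a single Bernoulli output with logit $\theta$), the score is $y - p$ and the exact curvature of the negative log-likelihood is $p(1-p)$, so a Monte Carlo average of squared scores should match $p(1-p)$:

```python
import math
import random

random.seed(0)
theta = 0.7                                   # single logit parameter
p = 1.0 / (1.0 + math.exp(-theta))            # sigmoid(theta)

# Empirical Fisher: E_y[(d log p(y|theta) / d theta)^2], with score = y - p
N = 200_000
fisher_mc = sum(
    ((1.0 if random.random() < p else 0.0) - p) ** 2 for _ in range(N)
) / N

# Exact curvature: -d^2/dtheta^2 log p(y|theta) = p(1-p), independent of y
curvature = p * (1.0 - p)

print(f"MC Fisher = {fisher_mc:.4f}, curvature p(1-p) = {curvature:.4f}")
```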
Experience Replay — Data-Level Approach
Instead of regularizing weights, experience replay simply mixes old data with new data during training:
$$\mathcal{L}_{\text{replay}} = \frac{1}{|\mathcal{B}_B|} \sum_{(\mathbf{x},y) \in \mathcal{B}_B} \ell(\mathbf{x}, y) + \frac{\mu}{|\mathcal{B}_A|} \sum_{(\mathbf{x},y) \in \mathcal{B}_A} \ell(\mathbf{x}, y)$$
where $\mathcal{B}_A$ is a replay buffer of old examples and $\mu$ controls the replay weight.
Buffer selection strategies:
| Strategy | Buffer Size | Quality | Bias Risk |
|---|---|---|---|
| Random sampling | 10% of old data | Good | Low |
| Herding (exemplars per class) | $K$ per class | Better | Balanced |
| Uncertainty-based (highest loss) | 10% of old data | Best for boundaries | May overweight hard examples |
| Stratified random | Equal per class | Good | None |
For our 5K → 1.2K data scenario, a replay buffer of 500 examples (10% of old data, ~50 per intent) combined with all 1.2K new examples gives an effective training set of 1.7K with balanced representation.
Online EWC — Multiple Sequential Tasks
When we face Task A → Task B → Task C → ..., standard EWC accumulates terms:
$$\mathcal{L} = \mathcal{L}_C + \frac{\lambda}{2}\left[\sum_i F_i^A (\theta_i - \theta_{A,i}^*)^2 + \sum_i F_i^B (\theta_i - \theta_{B,i}^*)^2\right]$$
This grows linearly with the number of tasks. Online EWC consolidates all previous tasks into a single running estimate:
$$\hat{F}_i^{\text{online}} = \gamma \hat{F}_i^{\text{online}} + F_i^{\text{current}}$$
where $\gamma \in [0,1)$ is a decay factor that downweights older tasks. This keeps memory constant regardless of the number of tasks.
For MangaAssist's monthly retraining cycle, $\gamma = 0.9$ means the model retains 90% of the previous Fisher estimate (which already accounts for all older tasks) plus the new task's Fisher.
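A quick sketch of the decay arithmetic (assuming, for illustration, that every task contributes a unit Fisher): after six monthly cycles the oldest task's Fisher carries weight $\gamma^5 \approx 0.59$, and the consolidated estimate stays a single running value regardless of task count.

```python
# Online EWC running estimate: F <- gamma * F + F_new after each monthly task.
# Assumes each task contributes a unit Fisher (F_new = 1.0) for illustration.
gamma = 0.9
months = 6

running = 0.0
for _ in range(months):
    running = gamma * running + 1.0

# Month k's Fisher ends up weighted by gamma^(months - k):
weights = [gamma ** (months - k) for k in range(1, months + 1)]
print(f"consolidated Fisher = {running:.3f}")
print("per-month weights:", [round(w, 3) for w in weights])
```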
Model Internals — Layer-by-Layer Diagrams
Catastrophic Forgetting: Weight Space Trajectory
graph TB
subgraph "Parameter Space ℝⁿ"
A["θ_A*<br>Task A optimum<br>(92.1% on 10 intents)"]
B["θ_B*<br>Task B optimum<br>(91.5% on new data but<br>87.3% on old intents)"]
C["θ_EWC<br>EWC compromise<br>(91.8% old + 89.5% new<br>= 91.3% overall)"]
D["θ_replay<br>Replay compromise<br>(91.5% old + 90.2% new<br>= 91.2% overall)"]
A -->|"Naive fine-tune<br>(forgetting path)"| B
A -->|"EWC λ=1000<br>(constrained path)"| C
A -->|"Experience Replay<br>(data-balanced path)"| D
end
subgraph "Task A Performance Along Path"
E["Start: 92.1%"]
F["Naive: drops to 87.3%<br>(-4.8%)"]
G["EWC: drops to 91.8%<br>(-0.3%)"]
H["Replay: drops to 91.5%<br>(-0.6%)"]
end
style A fill:#c8e6c9
style B fill:#ffcdd2
style C fill:#c8e6c9
style D fill:#c8e6c9
Fisher Information per Layer
graph TB
subgraph "Fisher Information Magnitude by Layer (DistilBERT)"
EMB["Embedding Layer<br>F_avg = 2.3×10⁻⁵<br>→ LOW importance<br>General token representations<br>EWC: allows free movement"]
L1["Layer 1<br>F_avg = 8.1×10⁻⁵<br>→ LOW importance<br>Basic syntax patterns<br>EWC: mild constraint"]
L2["Layer 2<br>F_avg = 3.4×10⁻⁴<br>→ MEDIUM importance<br>Phrase-level features<br>EWC: moderate constraint"]
L3["Layer 3<br>F_avg = 1.2×10⁻³<br>→ MEDIUM-HIGH<br>Semantic grouping<br>EWC: significant constraint"]
L4["Layer 4<br>F_avg = 5.7×10⁻³<br>→ HIGH importance<br>Intent-discriminative features<br>EWC: strong constraint"]
L5["Layer 5<br>F_avg = 1.8×10⁻²<br>→ HIGH importance<br>Task-specific representations<br>EWC: strong constraint"]
L6["Layer 6<br>F_avg = 4.2×10⁻²<br>→ VERY HIGH<br>Classification patterns<br>EWC: very strong constraint"]
CLS["Classification Head<br>F_avg = 8.5×10⁻²<br>→ CRITICAL<br>Decision boundaries<br>EWC: maximum constraint"]
EMB --> L1 --> L2 --> L3 --> L4 --> L5 --> L6 --> CLS
end
style EMB fill:#e8f5e9
style L1 fill:#e8f5e9
style L2 fill:#fff9c4
style L3 fill:#fff9c4
style L4 fill:#ffe0b2
style L5 fill:#ffccbc
style L6 fill:#ffcdd2
style CLS fill:#ef9a9a
EWC vs Naive Training: Gradient Comparison
sequenceDiagram
participant Batch as New Data Batch<br>"Gift me a manga<br>recommendation"
participant Model as DistilBERT<br>Parameters θ
participant CE as CrossEntropy Loss<br>L_B(θ)
participant EWC as EWC Regularizer<br>λ/2 · Σ Fᵢ(θᵢ-θᵢ*)²
participant Update as Parameter Update
Batch->>Model: Forward pass
Model->>CE: Logits → CE Loss
Note over CE: ∂L_B/∂θ pushes toward<br>new task optimum<br>(may destroy old knowledge)
Model->>EWC: Current θ vs stored θ_A*
Note over EWC: ∂L_EWC/∂θᵢ = λ·Fᵢ·(θᵢ-θᵢ*)<br>Spring force pulling toward<br>old task optimum<br>Strength ∝ Fisher Fᵢ
CE->>Update: ∂L_B/∂θ (new task gradient)
EWC->>Update: λ·F·(θ-θ*) (consolidation gradient)
Note over Update: Total gradient:<br>∂L_B/∂θ + λ·F·(θ-θ*)<br><br>For high-F params: EWC dominates → minimal change<br>For low-F params: CE dominates → free to adapt
Update->>Model: θ ← θ - η · total_gradient
Experience Replay: Training Batch Construction
graph LR
subgraph "New Task Data (1.2K examples)"
N1["recommendation: 450"]
N2["gift_recommendation: 380"]
N3["product_inquiry: 210"]
N4["others: 160"]
end
subgraph "Replay Buffer (500 from old task)"
O1["order_status: 55"]
O2["return: 52"]
O3["product_inquiry: 50"]
O4["shipping: 48"]
O5["complaint: 50"]
O6["recommendation: 50"]
O7["account: 48"]
O8["cancel: 50"]
O9["exchange: 49"]
O10["greeting: 48"]
end
subgraph "Training Batch (size 32)"
B1["22 from new task (~69%)"]
B2["10 from replay buffer (~31%)"]
end
N1 --> B1
N2 --> B1
O1 --> B2
O2 --> B2
O3 --> B2
B1 --> LOSS["Combined CE Loss<br>= L_new + μ·L_replay<br>μ = 1.0 (equal weight)"]
B2 --> LOSS
style B1 fill:#fff9c4
style B2 fill:#bbdefb
Progressive Training Schedule
graph TB
subgraph "Month 1: Initial Training"
M1["Train on 5K labeled examples<br>10 intents, 92.1% accuracy<br>Save θ₁* and F₁"]
end
subgraph "Month 2: Seasonal Update"
M2A["New data: 1.2K (gift reco spike)"]
M2B["Replay buffer: 500 (from Month 1)"]
M2C["EWC: λ₁·F₁·(θ-θ₁*)²"]
M2D["Result: 91.2% overall + gift_reco<br>Save θ₂* and F₂ = 0.9·F₁ + F₂_new"]
M2A --> M2D
M2B --> M2D
M2C --> M2D
end
subgraph "Month 3: Trend Shift"
M3A["New data: 800 (anime adaptation queries)"]
M3B["Replay buffer: 500 (from Months 1+2)"]
M3C["Online EWC: λ·F₂·(θ-θ₂*)²"]
M3D["Result: 90.8% on old + new patterns<br>Save θ₃* and F₃ = 0.9·F₂ + F₃_new"]
M3A --> M3D
M3B --> M3D
M3C --> M3D
end
M1 --> M2A
M2D --> M3A
style M1 fill:#c8e6c9
style M2D fill:#c8e6c9
style M3D fill:#c8e6c9
Implementation Deep-Dive
EWC Implementation
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import copy
from collections import defaultdict
class EWCTrainer:
"""
Elastic Weight Consolidation for continual learning.
Computes Fisher Information from old-task data, then adds a
quadratic penalty when training on new data.
"""
def __init__(
self,
model: nn.Module,
fisher_sample_size: int = 500,
ewc_lambda: float = 1000.0,
online_gamma: float = 0.9,
):
self.model = model
self.fisher_sample_size = fisher_sample_size
self.ewc_lambda = ewc_lambda
self.online_gamma = online_gamma
# Store consolidated Fisher and optimal params
self.fisher: dict[str, torch.Tensor] = {}
self.optimal_params: dict[str, torch.Tensor] = {}
self.task_count = 0
def compute_fisher(self, dataloader: DataLoader):
"""
Compute empirical Fisher Information Matrix (diagonal approximation).
F_i = E[(∂ log p(y|x,θ) / ∂θ_i)²]
        Uses squared gradients of the negative log-likelihood over old-data
        samples. Pass a dataloader with batch_size=1 so each squared gradient
        is a per-example score; larger batches square the batch-averaged
        gradient, which only coarsely approximates the Fisher.
"""
self.model.eval()
fisher = {
name: torch.zeros_like(param)
for name, param in self.model.named_parameters()
if param.requires_grad
}
n_samples = 0
for batch in dataloader:
if n_samples >= self.fisher_sample_size:
break
input_ids = batch["input_ids"].to(self.model.device)
attention_mask = batch["attention_mask"].to(self.model.device)
labels = batch["labels"].to(self.model.device)
self.model.zero_grad()
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
loss = outputs.loss
loss.backward()
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None:
fisher[name] += param.grad.data ** 2
n_samples += input_ids.shape[0]
# Average over samples
for name in fisher:
fisher[name] /= n_samples
return fisher
def consolidate(self, dataloader: DataLoader):
"""
After training on a task, consolidate knowledge:
1. Compute Fisher Information on current task
2. Store optimal parameters
3. If online EWC, merge with previous Fisher
"""
new_fisher = self.compute_fisher(dataloader)
if self.task_count == 0:
# First task — just store
self.fisher = new_fisher
else:
# Online EWC: merge with running estimate
for name in new_fisher:
self.fisher[name] = (
self.online_gamma * self.fisher[name] + new_fisher[name]
)
# Store current optimal parameters
self.optimal_params = {
name: param.data.clone()
for name, param in self.model.named_parameters()
if param.requires_grad
}
self.task_count += 1
def ewc_loss(self) -> torch.Tensor:
"""
Compute EWC penalty: (λ/2) Σ_i F_i (θ_i - θ_i*)²
Returns 0 if no task has been consolidated yet.
"""
if self.task_count == 0:
return torch.tensor(0.0, device=self.model.device)
loss = torch.tensor(0.0, device=self.model.device)
for name, param in self.model.named_parameters():
if name in self.fisher:
loss += (
self.fisher[name] * (param - self.optimal_params[name]) ** 2
).sum()
return (self.ewc_lambda / 2) * loss
def train_step(self, batch, optimizer):
"""Single training step with EWC regularization."""
self.model.train()
input_ids = batch["input_ids"].to(self.model.device)
attention_mask = batch["attention_mask"].to(self.model.device)
labels = batch["labels"].to(self.model.device)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
task_loss = outputs.loss
ewc_penalty = self.ewc_loss()
total_loss = task_loss + ewc_penalty
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
optimizer.step()
return {
"total_loss": total_loss.item(),
"task_loss": task_loss.item(),
"ewc_penalty": ewc_penalty.item(),
}
Experience Replay Implementation
import random
from collections import defaultdict
class ReplayBuffer:
"""
Stratified replay buffer that maintains balanced class representation.
Uses reservoir sampling for efficient memory-bounded updates.
"""
def __init__(self, max_per_class: int = 50):
self.max_per_class = max_per_class
self.buffer: dict[int, list] = defaultdict(list)
self.counts: dict[int, int] = defaultdict(int)
def add_examples(self, examples: list[dict]):
"""
Add examples using reservoir sampling.
Ensures each class has at most max_per_class examples.
"""
for example in examples:
label = example["label"]
self.counts[label] += 1
if len(self.buffer[label]) < self.max_per_class:
self.buffer[label].append(example)
else:
# Reservoir sampling: replace with probability max_per_class / count
idx = random.randint(0, self.counts[label] - 1)
if idx < self.max_per_class:
self.buffer[label][idx] = example
def sample(self, n: int) -> list[dict]:
"""Sample n examples with uniform class distribution."""
all_classes = list(self.buffer.keys())
per_class = max(1, n // len(all_classes))
samples = []
for cls in all_classes:
cls_samples = random.sample(
self.buffer[cls],
min(per_class, len(self.buffer[cls])),
)
samples.extend(cls_samples)
random.shuffle(samples)
return samples[:n]
@property
def total_size(self) -> int:
return sum(len(v) for v in self.buffer.values())
def stats(self) -> dict:
return {cls: len(examples) for cls, examples in self.buffer.items()}
class ContinualTrainer:
"""
Combines EWC regularization with experience replay.
This hybrid approach outperforms either technique alone.
"""
def __init__(
self,
model,
ewc_lambda: float = 500.0,
replay_ratio: float = 0.3,
replay_buffer_per_class: int = 50,
):
self.model = model
self.ewc = EWCTrainer(model, ewc_lambda=ewc_lambda)
self.replay_buffer = ReplayBuffer(max_per_class=replay_buffer_per_class)
self.replay_ratio = replay_ratio
def train_on_new_task(
self,
train_data,
epochs: int = 5,
lr: float = 2e-5,
):
"""
Train on new task data with EWC + replay.
Each batch is constructed as:
- (1 - replay_ratio) from new data
- replay_ratio from replay buffer
"""
optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
new_batch_size = int(32 * (1 - self.replay_ratio)) # 22
replay_batch_size = 32 - new_batch_size # 10
for epoch in range(epochs):
epoch_loss = {"total": 0, "task": 0, "ewc": 0}
n_steps = 0
new_loader = DataLoader(train_data, batch_size=new_batch_size, shuffle=True)
for new_batch in new_loader:
# Construct mixed batch
if self.replay_buffer.total_size > 0:
replay_samples = self.replay_buffer.sample(replay_batch_size)
batch = self._merge_batches(new_batch, replay_samples)
else:
batch = new_batch
metrics = self.ewc.train_step(batch, optimizer)
epoch_loss["total"] += metrics["total_loss"]
epoch_loss["task"] += metrics["task_loss"]
epoch_loss["ewc"] += metrics["ewc_penalty"]
n_steps += 1
for k in epoch_loss:
epoch_loss[k] /= max(n_steps, 1)
print(
f"Epoch {epoch+1}: total={epoch_loss['total']:.4f} "
f"task={epoch_loss['task']:.4f} ewc={epoch_loss['ewc']:.4f}"
)
# After training, consolidate and update replay buffer
        # batch_size=1 so the Fisher is estimated from per-example gradients
        self.ewc.consolidate(DataLoader(train_data, batch_size=1))
self.replay_buffer.add_examples(
[{"input_ids": x["input_ids"], "attention_mask": x["attention_mask"],
"label": x["labels"]} for x in train_data]
)
def _merge_batches(self, new_batch, replay_samples):
"""Merge new data batch with replay buffer samples."""
replay_ids = torch.stack([s["input_ids"] for s in replay_samples])
replay_masks = torch.stack([s["attention_mask"] for s in replay_samples])
replay_labels = torch.tensor([s["label"] for s in replay_samples])
return {
"input_ids": torch.cat([new_batch["input_ids"], replay_ids]),
"attention_mask": torch.cat([new_batch["attention_mask"], replay_masks]),
"labels": torch.cat([new_batch["labels"], replay_labels]),
}
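The ReplayBuffer above leans on the reservoir-sampling update (replace slot `randint(0, count - 1)` when it lands below `max_per_class`). A standalone check of the property this guarantees, namely that every stream item survives with probability $k/n$:

```python
import random

random.seed(42)

def reservoir(stream, k):
    """Classic reservoir sampling: keep a uniform k-subset of a stream."""
    buf = []
    for count, item in enumerate(stream, start=1):
        if len(buf) < k:
            buf.append(item)
        else:
            idx = random.randint(0, count - 1)   # same rule as ReplayBuffer.add_examples
            if idx < k:
                buf[idx] = item
    return buf

# Empirically: item 0 (the oldest) should survive with probability k/n = 50/500 = 0.1
trials, hits, k, n = 2000, 0, 50, 500
for _ in range(trials):
    if 0 in reservoir(range(n), k):
        hits += 1
print(f"retention rate of oldest item: {hits / trials:.3f} (expected ~0.100)")
```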
Monthly Retraining Pipeline
import json
from datetime import datetime

import mlflow
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# `load_dataset` and `evaluate_model` are project-local helpers (not shown here).
def monthly_retrain_pipeline(
model_path: str,
new_data_path: str,
old_data_path: str,
    fisher_path: str | None = None,
):
"""
MangaAssist monthly retraining pipeline with catastrophic forgetting prevention.
Steps:
1. Load current production model and Fisher (if exists)
2. Detect if new intents have emerged
3. Train with EWC + replay
4. Validate against old-task benchmarks
5. Deploy only if old-task regression < 1%
"""
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Initialize continual trainer
trainer = ContinualTrainer(
model=model,
ewc_lambda=500.0,
replay_ratio=0.3,
replay_buffer_per_class=50,
)
# Load previous Fisher if available
if fisher_path:
fisher_data = torch.load(fisher_path)
trainer.ewc.fisher = fisher_data["fisher"]
trainer.ewc.optimal_params = fisher_data["optimal_params"]
trainer.ewc.task_count = fisher_data["task_count"]
# Load replay buffer from previous tasks
if old_data_path:
old_data = load_dataset(old_data_path)
trainer.replay_buffer.add_examples(old_data)
# Train on new data
new_data = load_dataset(new_data_path)
trainer.train_on_new_task(new_data, epochs=5, lr=2e-5)
# Validation gate: check old-task regression
old_eval = evaluate_model(model, tokenizer, old_data_path)
new_eval = evaluate_model(model, tokenizer, new_data_path)
with mlflow.start_run():
mlflow.log_metric("old_task_accuracy", old_eval["accuracy"])
mlflow.log_metric("new_task_accuracy", new_eval["accuracy"])
mlflow.log_metric("combined_accuracy",
0.8 * old_eval["accuracy"] + 0.2 * new_eval["accuracy"])
# Deployment gate
    REGRESSION_THRESHOLD = 0.01   # max 1% old-task accuracy drop
    BASELINE_ACCURACY = 0.921     # previous production model's old-task accuracy
    if old_eval["accuracy"] >= BASELINE_ACCURACY - REGRESSION_THRESHOLD:
        print(f"✓ Old-task accuracy {old_eval['accuracy']:.4f} within threshold")
        model.save_pretrained(f"./models/intent_v{datetime.now():%Y%m}")
        # Save Fisher for next cycle
        torch.save({
            "fisher": trainer.ewc.fisher,
            "optimal_params": trainer.ewc.optimal_params,
            "task_count": trainer.ewc.task_count,
        }, f"./models/fisher_v{datetime.now():%Y%m}.pt")
    else:
        raise ValueError(
            f"✗ Old-task accuracy {old_eval['accuracy']:.4f} below threshold "
            f"{BASELINE_ACCURACY - REGRESSION_THRESHOLD:.4f}. Rejecting model update."
        )
Group Discussion: Key Decision Points
Decision Point 1: EWC vs Experience Replay vs Combined
Priya (ML Engineer): I tested all three approaches on our January→March retraining scenario:
| Method | Old-Task Accuracy | New-Task Accuracy | Combined | Training Time |
|---|---|---|---|---|
| Naive fine-tune | 87.3% (-4.8%) | 93.1% | 88.5% | 15 min |
| EWC only (λ=1000) | 91.8% (-0.3%) | 89.5% | 91.3% | 20 min |
| Replay only (30%) | 91.5% (-0.6%) | 90.2% | 91.2% | 18 min |
| EWC + Replay (hybrid) | 91.9% (-0.2%) | 90.5% | 91.6% | 25 min |
| Full retrain (all data) | 92.0% (-0.1%) | 91.0% | 91.8% | 45 min |
Aiko (Data Scientist): The hybrid (EWC + replay) comes within 0.2% of the full retrain in 25 minutes vs 45 minutes. More importantly, full retrain requires storing and processing all historical data, which grows over time. The hybrid approach needs only a 500-example replay buffer (fixed) plus the Fisher matrix (66M floats = 264MB, fixed).
Marcus (Architect): From an infrastructure perspective, the hybrid is clearly superior at scale. After 12 months of monthly retraining, full retrain processes 12K+ examples. The hybrid processes only ~1.7K (new data + 500 replay) regardless of how many months have passed.
Jordan (MLOps): Full retrain also re-introduces risk of old data quality issues. If there was a labeling error in month 3, full retrain keeps propagating it. The hybrid naturally downweights old data through the Fisher's online decay ($\gamma = 0.9$).
Sam (PM): The 0.2% quality gap between hybrid (91.6%) and full retrain (91.8%) is negligible. But the operational simplicity of not needing all historical data is a significant advantage.
Resolution: Hybrid EWC + Replay for monthly retraining. EWC provides parameter-level protection (prevents overwriting critical weights). Replay provides data-level coverage (ensures all classes are represented in training). Combined, they achieve 99.8% of full retrain quality at 56% of the training time with constant memory.
Decision Point 2: EWC Lambda Selection
Priya (ML Engineer): Lambda controls the strength of consolidation. I swept it:
| Lambda | Old Accuracy | New Accuracy | Combined | Behavior |
|---|---|---|---|---|
| 0 | 87.3% | 93.1% | 88.5% | No consolidation (naive) |
| 100 | 89.5% | 91.2% | 89.8% | Mild — some forgetting persists |
| 500 | 91.2% | 90.8% | 91.1% | Moderate — good balance |
| 1000 | 91.8% | 89.5% | 91.3% | Strong — favors old task |
| 5000 | 92.0% | 83.4% | 90.3% | Too strong — blocks new learning |
| 10000 | 92.1% | 72.1% | 88.1% | Nearly frozen — cannot adapt |
Aiko (Data Scientist): The optimal lambda depends on the old/new data ratio. With 5K old and 1.2K new, the new data is 19% of the total. If the ratio were 50/50, we would need a lower lambda (around 200) because the new task gradient is stronger and needs less damping.
A useful heuristic: $\lambda \approx \frac{N_{\text{old}}}{N_{\text{new}}} \times 200$. For us: $\frac{5000}{1200} \times 200 \approx 833$. We round to $\lambda = 1000$ for the combined metric optimum.
Jordan (MLOps): Can we auto-tune lambda? In production, we cannot afford a sweep every month.
Priya (ML Engineer): Yes — we can use a validation set from the old task (200 held-out examples from the replay buffer). Train for 1 epoch at lambdas [100, 500, 1000, 2000], evaluate on old-task validation, and pick the lambda with <0.5% degradation. This adds 10 minutes to the pipeline.
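Priya's selection rule can be sketched as picking the smallest lambda (maximum plasticity) whose old-task degradation stays within budget. The sweep values below reuse the table above as stand-ins for 1-epoch validation results; the λ=2000 accuracy is an assumed value, since the full sweep did not include that point.

```python
# Pick the smallest lambda whose old-task accuracy drop stays within budget.
def select_lambda(sweep, baseline, max_drop=0.005):
    """sweep: {lambda: old-task validation accuracy after 1 epoch}."""
    ok = [
        lam for lam, old_acc in sorted(sweep.items())
        if baseline - old_acc <= max_drop
    ]
    if not ok:
        raise ValueError("no lambda meets the old-task budget; investigate drift")
    return ok[0]

# Accuracies from the sweep table; the 2000 entry is an assumed stand-in.
sweep = {100: 0.895, 500: 0.912, 1000: 0.918, 2000: 0.919}
print(select_lambda(sweep, baseline=0.921))
```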
Resolution: Default $\lambda = 1000$ with auto-tuning via 1-epoch validation sweep. The sweep adds minimal overhead and adapts lambda to the specific new/old data ratio each month.
Decision Point 3: When to Add New Intents vs Expand Existing
Sam (PM): The "gift_recommendation" pattern: do we add an 11th intent class, or expand the "recommendation" intent to cover gifts?
Aiko (Data Scientist): This is a taxonomy decision with ML implications. Adding a new class is cleaner if the new intent has a distinct distribution — different vocabulary, different expected response. Expanding an existing class is safer if the new pattern is a sub-type.
I analyzed the "gift_recommendation" queries against "recommendation":
| Feature | recommendation | gift_recommendation |
|---|---|---|
| Mean query length | 14.2 tokens | 18.7 tokens |
| Unique vocabulary | 412 terms | 287 terms (83% overlap with reco) |
| Typical structure | "Recommend manga like X" | "Gift for someone who likes X" |
| Expected response | Ranked list | List + age-appropriateness + gift-wrapping |
Marcus (Architect): The 83% vocabulary overlap suggests it is a sub-type of recommendation, but the response format differs significantly. The downstream system needs to know if it should add gift-wrapping info or not.
Priya (ML Engineer): From an ML perspective, adding an 11th class is lower risk for continual learning — the new class occupies its own region of the embedding space. Expanding "recommendation" risks confusing the boundary between regular and gift recommendations, which could degrade both.
Jordan (MLOps): A new class also means updating the classification head: 10→11 outputs. This requires adding a randomly initialized 11th row to the output layer and fine-tuning it while keeping the other 10 rows stable. EWC handles this naturally — the Fisher for the 11th class starts at zero, allowing free adaptation.
Resolution: Add as 11th intent class. The response format difference justifies a separate category. EWC naturally handles the expanding head: high Fisher on existing 10 dimensions (preserve), zero Fisher on the 11th (adapt freely). Re-evaluate quarterly whether to further split or merge intents.
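Jordan's head expansion can be sketched in a few lines (the hidden size 768 matches DistilBERT; the subsequent fine-tuning step is not shown): the old 10 rows are copied over so existing logits are numerically preserved, while the 11th row keeps its fresh random initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Expand a 10-way classification head to 11 classes, preserving the old rows.
old_head = nn.Linear(768, 10)
new_head = nn.Linear(768, 11)

with torch.no_grad():
    new_head.weight[:10] = old_head.weight   # copy the 10 existing decision rows
    new_head.bias[:10] = old_head.bias       # row 11 keeps its random init

x = torch.randn(4, 768)
# Old-class logits are unchanged; only the 11th logit is new.
assert torch.allclose(new_head(x)[:, :10], old_head(x), atol=1e-5)
print("old logits preserved; 11th class starts from random init")
```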
Decision Point 4: Validation Gate Threshold
Jordan (MLOps): Our deployment gate rejects model updates if old-task accuracy drops by more than X%. What should X be?
Sam (PM): From a product perspective, our SLA guarantees 90% routing accuracy. Current production is 92.1%, giving us 2.1% headroom before SLA breach.
Priya (ML Engineer): But accuracy is not the only metric. A 1% overall accuracy drop could hide a 10% drop on one specific intent (like "return" in our naive retrain example). I would add per-intent thresholds in addition to the overall threshold.
Aiko (Data Scientist): Proposed gate:
| Metric | Threshold | Rationale |
|---|---|---|
| Overall accuracy vs previous | ≤ 1% drop | Headroom preservation |
| Any single intent accuracy | ≤ 3% drop | Prevents catastrophic single-class degradation |
| New intent accuracy | ≥ 85% | Ensures new pattern is actually learned |
| Confidence calibration (ECE) | ≤ 0.05 | Prevents overconfident wrong predictions |
Marcus (Architect): I agree with Aiko's multi-metric gate. A model that passes on overall accuracy but fails on one intent would be caught by the per-intent check. This is defense in depth.
Resolution: Multi-metric validation gate. Reject deployment if any metric exceeds threshold. Log all metrics to MLflow with automated alerting. If rejected, fall back to the previous month's model (which is always kept as rollback candidate).
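Aiko's multi-metric gate reduces to a single boolean check per candidate model. A minimal sketch (the function name, metric layout, and sample numbers are illustrative, not the production implementation):

```python
# Multi-metric deployment gate: reject if ANY threshold is exceeded.
def passes_gate(prev, cur, new_intent, ece,
                max_overall_drop=0.01, max_intent_drop=0.03,
                min_new_intent=0.85, max_ece=0.05):
    reasons = []
    if prev["overall"] - cur["overall"] > max_overall_drop:
        reasons.append("overall accuracy regression")
    for intent, acc in prev["per_intent"].items():
        if acc - cur["per_intent"][intent] > max_intent_drop:
            reasons.append(f"regression on intent '{intent}'")
    if new_intent < min_new_intent:
        reasons.append("new intent under-learned")
    if ece > max_ece:
        reasons.append("calibration degraded")
    return (len(reasons) == 0, reasons)

# Illustrative numbers: small regressions everywhere, all within budget.
prev = {"overall": 0.921, "per_intent": {"return": 0.912, "order_status": 0.951}}
cur  = {"overall": 0.916, "per_intent": {"return": 0.904, "order_status": 0.942}}
ok, why = passes_gate(prev, cur, new_intent=0.895, ece=0.03)
print(ok, why)
```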
Research Paper References
1. Overcoming Catastrophic Forgetting in Neural Networks — EWC (Kirkpatrick et al., 2017)
Key contribution: Introduced Elastic Weight Consolidation using the diagonal Fisher Information Matrix as an importance measure for each parameter. Demonstrated that protecting high-Fisher parameters while allowing low-Fisher parameters to adapt preserves old-task knowledge while learning new tasks.
Relevance to MangaAssist: Core of our continual learning strategy. The Fisher reveals that our DistilBERT's upper layers (classification head, layer 6) have 1000× higher importance than lower layers (embeddings), creating a natural per-layer protection gradient.
2. Progressive Neural Networks (Rusu et al., 2016)
Key contribution: Proposed adding new lateral columns for each new task while freezing old columns. This completely prevents forgetting (zero interference) at the cost of linearly growing model size. Introduced lateral connections for knowledge transfer between columns.
Relevance to MangaAssist: We considered progressive nets for our intent classifier but rejected them due to latency concerns — each new column adds ~5ms inference time. However, the lateral connection concept informs our adapter-based approach (docs 04, 11) where task-specific adapters are lightweight "columns" that do not affect base model latency.
3. Avalanche: An End-to-End Library for Continual Learning (Lomonaco et al., 2021)
Key contribution: Open-source framework implementing 20+ continual learning strategies (EWC, SI, LwF, replay, progressive nets) with standardized benchmarks and evaluation protocols. Provides plug-and-play strategies that can be combined.
Relevance to MangaAssist: We use Avalanche for benchmarking and ablation studies. Its standardized evaluation protocol (Average Accuracy, Forward Transfer, Backward Transfer metrics) gives us comparable metrics across strategies.
4. Continual Lifelong Learning with Neural Networks (Parisi et al., 2019)
Key contribution: Comprehensive survey of catastrophic forgetting in neural networks. Categories: regularization-based (EWC, SI), replay-based (experience replay, generative replay), and architecture-based (progressive nets, PackNet). Key finding: hybrid approaches outperform pure strategies.
Relevance to MangaAssist: Informed our decision to combine EWC + replay. The survey's analysis shows that regularization alone can be brittle (sensitive to lambda), and replay alone suffers from buffer bias. The hybrid addresses both weaknesses.
Production Deployment Results
Monthly Retrain Results (6-Month Trajectory)
| Month | New Data | Method | Old Accuracy | New Accuracy | Combined | EWC λ |
|---|---|---|---|---|---|---|
| Jan (baseline) | 5K (initial) | Full train | — | 92.1% | 92.1% | — |
| Feb | 400 (minor) | EWC + Replay | 91.9% | 88.4% | 91.2% | 1000 |
| Mar | 1.2K (gift spike) | EWC + Replay | 91.6% | 90.5% | 91.4% | 833 |
| Apr | 600 (anime season) | EWC + Replay | 91.4% | 89.8% | 91.1% | 1000 |
| May | 300 (slow month) | EWC + Replay | 91.3% | 87.2% | 90.5% | 1500 |
| Jun | 900 (summer sale) | EWC + Replay | 91.1% | 90.1% | 90.9% | 800 |
Key observations:
- Old-task accuracy degrades only 1.0% over 6 months (92.1% → 91.1%) — 5× better than naive fine-tuning, which would degrade ~4% per cycle
- New-task accuracy stays consistently above 87% — the model successfully learns new patterns
- Combined accuracy holds around 91% — well above the 90% SLA threshold
- Online EWC Fisher consolidation keeps memory constant (264MB) regardless of task count
Per-Intent Stability
| Intent | Jan Acc | Jun Acc | Max Monthly Drop |
|---|---|---|---|
| product_inquiry | 93.4% | 92.8% | -0.4% |
| order_status | 95.1% | 94.2% | -0.5% |
| recommendation | 88.7% | 91.2% | +2.5% (improved) |
| return | 91.2% | 90.4% | -0.8% |
| shipping | 90.8% | 90.1% | -0.6% |
| gift_recommendation | — | 89.5% | N/A (new) |
No intent exceeded a 1% accuracy drop in any single month.