
06. Continual Learning and Catastrophic Forgetting — Retraining Without Losing Knowledge

Problem Statement and MangaAssist Context

MangaAssist's models must evolve continuously. Every month, new manga titles launch, seasonal trends shift (holiday gift guides, anime adaptation spikes), and user vocabulary drifts ("rizz" in manga reviews, new genre-fusions like "isekai-horror"). If we naively fine-tune on new data, the model "forgets" old knowledge — a phenomenon called catastrophic forgetting.

The Forgetting Problem in Production

Our intent classifier was fine-tuned in January 2024 on 5K labeled examples covering 10 intents. In March, we collected 1.2K new examples with a new pattern: "gift recommendation" queries surged 340% (White Day in Japan). Naively fine-tuning on these 1.2K new examples:

| Intent | Accuracy Before | After Naive Retrain | Delta |
|---|---|---|---|
| product_inquiry | 93.4% | 88.1% | -5.3% |
| order_status | 95.1% | 86.2% | -8.9% |
| recommendation | 88.7% | 94.2% | +5.5% |
| return | 91.2% | 82.4% | -8.8% |
| gift_recommendation (new) | | 91.5% | New |
| Overall | 92.1% | 87.3% | -4.8% |

The model improved on recommendations and gifts but degraded severely on return and order_status — exactly the catastrophic forgetting pattern. We need techniques that learn new knowledge without destroying old knowledge.

The Stability-Plasticity Dilemma

Every learning system faces a fundamental tradeoff:

- Stability: Preserve existing knowledge (good at old tasks)
- Plasticity: Adapt to new information (good at new tasks)

Full plasticity (naive fine-tuning) → catastrophic forgetting. Full stability (freeze everything) → cannot learn new patterns. The solution lies in methods that find the optimal balance.


Mathematical Foundations

Why Catastrophic Forgetting Happens — The Geometric Perspective

Consider the model's parameters $\boldsymbol{\theta} \in \mathbb{R}^n$. After training on Task A (original 5K data), the model sits at $\boldsymbol{\theta}_A^*$ — a local minimum of the Task A loss surface.

When we train on Task B (new 1.2K data), gradient descent moves $\boldsymbol{\theta}$ downhill on the Task B loss surface:

$$\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \nabla_{\boldsymbol{\theta}} \mathcal{L}_B(\boldsymbol{\theta}_t)$$

Problem: the gradient $\nabla \mathcal{L}_B$ has no knowledge of Task A. If the Task B gradient pushes parameters away from the region where Task A performs well, we forget Task A.

Why some parameters matter more: Not all parameters are equally important for Task A. Some parameters encode critical decision boundaries (e.g., the weight separating "order_status" from "return"), while others encode minor variations. Moving a critical parameter by even 0.01 can flip predictions; moving a non-critical parameter by 1.0 has no effect.

Elastic Weight Consolidation (EWC) — The Fisher Information Approach

EWC (Kirkpatrick et al., 2017) adds a regularization term that penalizes changes to parameters that are important for the old task:

$$\mathcal{L}_{\text{EWC}} = \mathcal{L}_B(\boldsymbol{\theta}) + \frac{\lambda}{2} \sum_{i=1}^{n} F_i (\theta_i - \theta_{A,i}^*)^2$$

where:

- $\mathcal{L}_B(\boldsymbol{\theta})$ is the new task loss (standard cross-entropy on new data)
- $\theta_{A,i}^*$ is the $i$-th parameter's value after training on the old task
- $F_i$ is the Fisher Information of parameter $i$ (how important it is for the old task)
- $\lambda$ controls the strength of consolidation

The Fisher Information Matrix:

The Fisher Information for parameter $i$ measures the curvature of the loss surface at $\boldsymbol{\theta}_A^*$ along the $i$-th dimension:

$$F_i = \mathbb{E}_{\mathbf{x} \sim \mathcal{D}_A}\left[\left(\frac{\partial \log p(\mathbf{y} | \mathbf{x}, \boldsymbol{\theta})}{\partial \theta_i}\right)^2\right]$$

Intuition: $F_i$ measures how sensitive the model's predictions are to changes in $\theta_i$.

- High $F_i$: The loss surface is sharply curved around $\theta_i$ — small changes cause large quality drops. EWC will strongly resist moving this parameter.
- Low $F_i$: The loss surface is flat around $\theta_i$ — this parameter can be changed freely without hurting old task quality. EWC allows adaptation here.
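To make the scales concrete, here is a worked example using illustrative Fisher values from the per-layer table below. With $\lambda = 1000$, moving a classification-head weight ($F_i = 10^{-1}$) by $0.1$ incurs a penalty of

$$\frac{\lambda}{2} F_i \, \Delta\theta_i^2 = \frac{1000}{2} \times 10^{-1} \times (0.1)^2 = 0.5$$

while the same move on an embedding weight ($F_i = 10^{-5}$) costs only $5 \times 10^{-5}$, four orders of magnitude less resistance.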

Practical computation: Computing the exact Fisher requires iterating over all training data. We approximate it using a sample of old-task data (typically 200-500 examples):

$$\hat{F}_i = \frac{1}{N} \sum_{n=1}^{N} \left(\frac{\partial \log p(\mathbf{y}_n | \mathbf{x}_n, \boldsymbol{\theta}_A^*)}{\partial \theta_i}\right)^2$$

For DistilBERT with 66M parameters, computing Fisher takes ~5 minutes on a single GPU using 500 examples.

EWC Gradient Analysis

The gradient of the EWC loss with respect to parameter $\theta_i$:

$$\frac{\partial \mathcal{L}_{\text{EWC}}}{\partial \theta_i} = \frac{\partial \mathcal{L}_B}{\partial \theta_i} + \lambda F_i (\theta_i - \theta_{A,i}^*)$$

The second term acts as a "spring" that pulls parameters back toward their old-task-optimal values. The spring constant $\lambda F_i$ is proportional to the parameter's importance:

| Parameter Type | Typical $F_i$ | EWC Effect |
|---|---|---|
| Classification head (final layer) | $10^{-2}$ to $10^{-1}$ | Strong resistance to change |
| Last transformer layer attention | $10^{-3}$ to $10^{-2}$ | Moderate resistance |
| Middle transformer layers | $10^{-4}$ to $10^{-3}$ | Mild resistance |
| Embedding layer | $10^{-5}$ to $10^{-4}$ | Weak resistance — adapts freely |

This creates a natural layer-wise learning rate effect: lower layers (general features) adapt more freely, while upper layers (task-specific features) are more conserved.

Fisher Information as Loss Curvature

The connection between Fisher Information and loss curvature is precise. For the empirical Fisher:

$$\hat{F} = \frac{1}{N}\sum_{n=1}^{N} \nabla_{\boldsymbol{\theta}} \log p(\mathbf{y}_n|\mathbf{x}_n) \, \nabla_{\boldsymbol{\theta}} \log p(\mathbf{y}_n|\mathbf{x}_n)^T$$

This outer-product matrix approximates the Hessian of the negative log-likelihood:

$$\hat{F} \approx -\frac{1}{N}\sum_{n=1}^{N} \nabla_{\boldsymbol{\theta}}^2 \log p(\mathbf{y}_n|\mathbf{x}_n) = \hat{\mathbf{H}}$$

Geometric meaning: The Hessian describes the local curvature of the loss surface. High curvature in dimension $i$ means the loss function is a narrow valley along $\theta_i$ — indicating the current value is precisely tuned and should not be changed. Low curvature means a flat plain — the parameter can vary without affecting the loss.
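The step connecting this back to the EWC penalty is a second-order Taylor expansion of the old-task loss around the optimum $\boldsymbol{\theta}_A^*$, where the gradient term vanishes:

$$\mathcal{L}_A(\boldsymbol{\theta}) \approx \mathcal{L}_A(\boldsymbol{\theta}_A^*) + \frac{1}{2} (\boldsymbol{\theta} - \boldsymbol{\theta}_A^*)^T \mathbf{H} \, (\boldsymbol{\theta} - \boldsymbol{\theta}_A^*)$$

Substituting the diagonal Fisher approximation $\mathbf{H} \approx \mathrm{diag}(F_1, \ldots, F_n)$ yields $\frac{1}{2}\sum_i F_i (\theta_i - \theta_{A,i}^*)^2$: the EWC penalty (up to the factor $\lambda$) is an estimate of how much old-task loss a given parameter change would incur.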

Experience Replay — Data-Level Approach

Instead of regularizing weights, experience replay simply mixes old data with new data during training:

$$\mathcal{L}_{\text{replay}} = \frac{1}{|\mathcal{B}_B|} \sum_{(\mathbf{x},y) \in \mathcal{B}_B} \ell(\mathbf{x}, y) + \frac{\mu}{|\mathcal{B}_A|} \sum_{(\mathbf{x},y) \in \mathcal{B}_A} \ell(\mathbf{x}, y)$$

where $\mathcal{B}_A$ is a replay buffer of old examples and $\mu$ controls the replay weight.
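As a minimal sketch of this objective, assuming a Hugging Face-style classifier whose output exposes `.logits` (the full trainer later in this document instead merges the two sub-batches and weights them implicitly through batch composition):

```python
import torch
import torch.nn.functional as F


def replay_loss(model, new_batch, replay_batch, mu: float = 1.0):
    """Mean CE over the new sub-batch plus mu times mean CE over the
    replay sub-batch, mirroring the L_replay formula above."""
    new_logits = model(input_ids=new_batch["input_ids"],
                       attention_mask=new_batch["attention_mask"]).logits
    old_logits = model(input_ids=replay_batch["input_ids"],
                       attention_mask=replay_batch["attention_mask"]).logits
    return (F.cross_entropy(new_logits, new_batch["labels"])
            + mu * F.cross_entropy(old_logits, replay_batch["labels"]))
```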

Buffer selection strategies:

| Strategy | Buffer Size | Quality | Bias Risk |
|---|---|---|---|
| Random sampling | 10% of old data | Good | Low |
| Herding (exemplars per class) | $K$ per class | Better | Balanced |
| Uncertainty-based (highest loss) | 10% of old data | Best for boundaries | May overweight hard examples |
| Stratified random | Equal per class | Good | None |

For our 5K → 1.2K data scenario, a replay buffer of 500 examples (10% of old data, ~50 per intent) combined with all 1.2K new examples gives an effective training set of 1.7K with balanced representation.

Online EWC — Multiple Sequential Tasks

When we face Task A → Task B → Task C → ..., standard EWC accumulates terms:

$$\mathcal{L} = \mathcal{L}_C + \frac{\lambda}{2}\left[\sum_i F_i^A (\theta_i - \theta_{A,i}^*)^2 + \sum_i F_i^B (\theta_i - \theta_{B,i}^*)^2\right]$$

This grows linearly with the number of tasks. Online EWC consolidates all previous tasks into a single running estimate:

$$\hat{F}_i^{\text{online}} \leftarrow \gamma\, \hat{F}_i^{\text{online}} + F_i^{\text{current}}$$

where $\gamma \in [0,1)$ is a decay factor that downweights older tasks. This keeps memory constant regardless of the number of tasks.

For MangaAssist's monthly retraining cycle, $\gamma = 0.9$ means the model retains 90% of the previous Fisher estimate (which already accounts for all older tasks) plus the new task's Fisher.
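Unrolling the recursion makes the downweighting explicit. After consolidating tasks $1, \ldots, k$:

$$\hat{F}^{\text{online}} = \sum_{j=1}^{k} \gamma^{\,k-j} F^{(j)}$$

With $\gamma = 0.9$, the Fisher from a task five retraining cycles old still carries $0.9^5 \approx 0.59$ of its original weight, so protection fades gradually rather than disappearing.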


Model Internals — Layer-by-Layer Diagrams

Catastrophic Forgetting: Weight Space Trajectory

```mermaid
graph TB
    subgraph "Parameter Space ℝⁿ"
        A["θ_A*<br>Task A optimum<br>(92.1% on 10 intents)"]
        B["θ_B*<br>Task B optimum<br>(91.5% on new data but<br>87.3% on old intents)"]
        C["θ_EWC<br>EWC compromise<br>(91.8% old + 89.5% new<br>= 91.1% overall)"]
        D["θ_replay<br>Replay compromise<br>(91.5% old + 90.2% new<br>= 91.2% overall)"]

        A -->|"Naive fine-tune<br>(forgetting path)"| B
        A -->|"EWC λ=1000<br>(constrained path)"| C
        A -->|"Experience Replay<br>(data-balanced path)"| D
    end

    subgraph "Task A Performance Along Path"
        E["Start: 92.1%"]
        F["Naive: drops to 87.3%<br>(-4.8%)"]
        G["EWC: drops to 91.8%<br>(-0.3%)"]
        H["Replay: drops to 91.5%<br>(-0.6%)"]
    end

    style A fill:#c8e6c9
    style B fill:#ffcdd2
    style C fill:#c8e6c9
    style D fill:#c8e6c9
```

Fisher Information per Layer

```mermaid
graph TB
    subgraph "Fisher Information Magnitude by Layer (DistilBERT)"
        EMB["Embedding Layer<br>F_avg = 2.3×10⁻⁵<br>→ LOW importance<br>General token representations<br>EWC: allows free movement"]
        L1["Layer 1<br>F_avg = 8.1×10⁻⁵<br>→ LOW importance<br>Basic syntax patterns<br>EWC: mild constraint"]
        L2["Layer 2<br>F_avg = 3.4×10⁻⁴<br>→ MEDIUM importance<br>Phrase-level features<br>EWC: moderate constraint"]
        L3["Layer 3<br>F_avg = 1.2×10⁻³<br>→ MEDIUM-HIGH<br>Semantic grouping<br>EWC: significant constraint"]
        L4["Layer 4<br>F_avg = 5.7×10⁻³<br>→ HIGH importance<br>Intent-discriminative features<br>EWC: strong constraint"]
        L5["Layer 5<br>F_avg = 1.8×10⁻²<br>→ HIGH importance<br>Task-specific representations<br>EWC: strong constraint"]
        L6["Layer 6<br>F_avg = 4.2×10⁻²<br>→ VERY HIGH<br>Classification patterns<br>EWC: very strong constraint"]
        CLS["Classification Head<br>F_avg = 8.5×10⁻²<br>→ CRITICAL<br>Decision boundaries<br>EWC: maximum constraint"]

        EMB --> L1 --> L2 --> L3 --> L4 --> L5 --> L6 --> CLS
    end

    style EMB fill:#e8f5e9
    style L1 fill:#e8f5e9
    style L2 fill:#fff9c4
    style L3 fill:#fff9c4
    style L4 fill:#ffe0b2
    style L5 fill:#ffccbc
    style L6 fill:#ffcdd2
    style CLS fill:#ef9a9a
```

EWC vs Naive Training: Gradient Comparison

```mermaid
sequenceDiagram
    participant Batch as New Data Batch<br>"Gift me a manga<br>recommendation"
    participant Model as DistilBERT<br>Parameters θ
    participant CE as CrossEntropy Loss<br>L_B(θ)
    participant EWC as EWC Regularizer<br>λ/2 · Σ Fᵢ(θᵢ-θᵢ*)²
    participant Update as Parameter Update

    Batch->>Model: Forward pass
    Model->>CE: Logits → CE Loss
    Note over CE: ∂L_B/∂θ pushes toward<br>new task optimum<br>(may destroy old knowledge)

    Model->>EWC: Current θ vs stored θ_A*
    Note over EWC: ∂L_EWC/∂θᵢ = λ·Fᵢ·(θᵢ-θᵢ*)<br>Spring force pulling toward<br>old task optimum<br>Strength ∝ Fisher Fᵢ

    CE->>Update: ∂L_B/∂θ (new task gradient)
    EWC->>Update: λ·F·(θ-θ*) (consolidation gradient)

    Note over Update: Total gradient:<br>∂L_B/∂θ + λ·F·(θ-θ*)<br><br>For high-F params: EWC dominates → minimal change<br>For low-F params: CE dominates → free to adapt

    Update->>Model: θ ← θ - η · total_gradient
```

Experience Replay: Training Batch Construction

```mermaid
graph LR
    subgraph "New Task Data (1.2K examples)"
        N1["recommendation: 450"]
        N2["gift_recommendation: 380"]
        N3["product_inquiry: 210"]
        N4["others: 160"]
    end

    subgraph "Replay Buffer (500 from old task)"
        O1["order_status: 55"]
        O2["return: 52"]
        O3["product_inquiry: 50"]
        O4["shipping: 48"]
        O5["complaint: 50"]
        O6["recommendation: 50"]
        O7["account: 48"]
        O8["cancel: 50"]
        O9["exchange: 49"]
        O10["greeting: 48"]
    end

    subgraph "Training Batch (size 32)"
        B1["20 from new task (62.5%)"]
        B2["12 from replay buffer (37.5%)"]
    end

    N1 --> B1
    N2 --> B1
    O1 --> B2
    O2 --> B2
    O3 --> B2

    B1 --> LOSS["Combined CE Loss<br>= L_new + μ·L_replay<br>μ = 1.0 (equal weight)"]
    B2 --> LOSS

    style B1 fill:#fff9c4
    style B2 fill:#bbdefb
```

Progressive Training Schedule

```mermaid
graph TB
    subgraph "Month 1: Initial Training"
        M1["Train on 5K labeled examples<br>10 intents, 92.1% accuracy<br>Save θ₁* and F₁"]
    end

    subgraph "Month 2: Seasonal Update"
        M2A["New data: 1.2K (gift reco spike)"]
        M2B["Replay buffer: 500 (from Month 1)"]
        M2C["EWC: λ₁·F₁·(θ-θ₁*)²"]
        M2D["Result: 91.2% overall + gift_reco<br>Save θ₂* and F₂ = 0.9·F₁ + F₂_new"]
        M2A --> M2D
        M2B --> M2D
        M2C --> M2D
    end

    subgraph "Month 3: Trend Shift"
        M3A["New data: 800 (anime adaptation queries)"]
        M3B["Replay buffer: 500 (from Months 1+2)"]
        M3C["Online EWC: λ·F₂·(θ-θ₂*)²"]
        M3D["Result: 90.8% on old + new patterns<br>Save θ₃* and F₃ = 0.9·F₂ + F₃_new"]
        M3A --> M3D
        M3B --> M3D
        M3C --> M3D
    end

    M1 --> M2A
    M2D --> M3A

    style M1 fill:#c8e6c9
    style M2D fill:#c8e6c9
    style M3D fill:#c8e6c9
```

Implementation Deep-Dive

EWC Implementation

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class EWCTrainer:
    """
    Elastic Weight Consolidation for continual learning.

    Computes Fisher Information from old-task data, then adds a
    quadratic penalty when training on new data.
    """

    def __init__(
        self,
        model: nn.Module,
        fisher_sample_size: int = 500,
        ewc_lambda: float = 1000.0,
        online_gamma: float = 0.9,
    ):
        self.model = model
        self.fisher_sample_size = fisher_sample_size
        self.ewc_lambda = ewc_lambda
        self.online_gamma = online_gamma

        # Store consolidated Fisher and optimal params
        self.fisher: dict[str, torch.Tensor] = {}
        self.optimal_params: dict[str, torch.Tensor] = {}
        self.task_count = 0

    def compute_fisher(self, dataloader: DataLoader):
        """
        Compute empirical Fisher Information Matrix (diagonal approximation).

        F_i = E[(∂ log p(y|x,θ) / ∂θ_i)²]

        Uses squared gradients from negative log-likelihood over old data samples.
        """
        self.model.eval()
        fisher = {
            name: torch.zeros_like(param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        n_samples = 0
        n_batches = 0
        for batch in dataloader:
            if n_samples >= self.fisher_sample_size:
                break

            input_ids = batch["input_ids"].to(self.model.device)
            attention_mask = batch["attention_mask"].to(self.model.device)
            labels = batch["labels"].to(self.model.device)

            self.model.zero_grad()
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            loss.backward()

            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    fisher[name] += param.grad.data ** 2

            n_samples += input_ids.shape[0]
            n_batches += 1

        # Average the squared batch-mean gradients over batches. Per-example
        # gradients would give the exact empirical Fisher; this batch-level
        # approximation is much cheaper and works well in practice.
        for name in fisher:
            fisher[name] /= max(n_batches, 1)

        return fisher

    def consolidate(self, dataloader: DataLoader):
        """
        After training on a task, consolidate knowledge:
        1. Compute Fisher Information on current task
        2. Store optimal parameters
        3. If online EWC, merge with previous Fisher
        """
        new_fisher = self.compute_fisher(dataloader)

        if self.task_count == 0:
            # First task — just store
            self.fisher = new_fisher
        else:
            # Online EWC: merge with running estimate
            for name in new_fisher:
                self.fisher[name] = (
                    self.online_gamma * self.fisher[name] + new_fisher[name]
                )

        # Store current optimal parameters
        self.optimal_params = {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        self.task_count += 1

    def ewc_loss(self) -> torch.Tensor:
        """
        Compute EWC penalty: (λ/2) Σ_i F_i (θ_i - θ_i*)²

        Returns 0 if no task has been consolidated yet.
        """
        if self.task_count == 0:
            return torch.tensor(0.0, device=self.model.device)

        loss = torch.tensor(0.0, device=self.model.device)
        for name, param in self.model.named_parameters():
            if name in self.fisher:
                loss += (
                    self.fisher[name] * (param - self.optimal_params[name]) ** 2
                ).sum()

        return (self.ewc_lambda / 2) * loss

    def train_step(self, batch, optimizer):
        """Single training step with EWC regularization."""
        self.model.train()

        input_ids = batch["input_ids"].to(self.model.device)
        attention_mask = batch["attention_mask"].to(self.model.device)
        labels = batch["labels"].to(self.model.device)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        task_loss = outputs.loss
        ewc_penalty = self.ewc_loss()
        total_loss = task_loss + ewc_penalty

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        optimizer.step()

        return {
            "total_loss": total_loss.item(),
            "task_loss": task_loss.item(),
            "ewc_penalty": ewc_penalty.item(),
        }
```
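A minimal usage sketch, assuming hypothetical `old_loader` and `new_loader` DataLoaders that yield tokenized batches with `input_ids`, `attention_mask`, and `labels`:

```python
# Hypothetical two-task wiring of EWCTrainer.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=10
)
trainer = EWCTrainer(model, ewc_lambda=1000.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# ... after standard training on Task A, consolidate:
trainer.consolidate(old_loader)  # computes Fisher, snapshots θ_A*

# Task B training now carries the EWC penalty in every step.
for batch in new_loader:
    metrics = trainer.train_step(batch, optimizer)
```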

Experience Replay Implementation

```python
import random
from collections import defaultdict


class ReplayBuffer:
    """
    Stratified replay buffer that maintains balanced class representation.
    Uses reservoir sampling for efficient memory-bounded updates.
    """

    def __init__(self, max_per_class: int = 50):
        self.max_per_class = max_per_class
        self.buffer: dict[int, list] = defaultdict(list)
        self.counts: dict[int, int] = defaultdict(int)

    def add_examples(self, examples: list[dict]):
        """
        Add examples using reservoir sampling.
        Ensures each class has at most max_per_class examples.
        """
        for example in examples:
            label = example["label"]
            self.counts[label] += 1

            if len(self.buffer[label]) < self.max_per_class:
                self.buffer[label].append(example)
            else:
                # Reservoir sampling: replace with probability max_per_class / count
                idx = random.randint(0, self.counts[label] - 1)
                if idx < self.max_per_class:
                    self.buffer[label][idx] = example

    def sample(self, n: int) -> list[dict]:
        """Sample n examples with uniform class distribution."""
        all_classes = list(self.buffer.keys())
        per_class = max(1, n // len(all_classes))
        samples = []
        for cls in all_classes:
            cls_samples = random.sample(
                self.buffer[cls],
                min(per_class, len(self.buffer[cls])),
            )
            samples.extend(cls_samples)
        random.shuffle(samples)
        return samples[:n]

    @property
    def total_size(self) -> int:
        return sum(len(v) for v in self.buffer.values())

    def stats(self) -> dict:
        return {cls: len(examples) for cls, examples in self.buffer.items()}
```
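A quick sanity check that the buffer's memory stays bounded (toy examples with a hypothetical `text` field; real entries hold tokenized tensors):

```python
# 5,000 streamed examples across 10 classes; the buffer keeps only 50 each.
buf = ReplayBuffer(max_per_class=50)
buf.add_examples([{"label": i % 10, "text": f"example {i}"} for i in range(5000)])
print(buf.total_size)  # 500
print(buf.stats())     # {0: 50, 1: 50, ..., 9: 50}
```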


```python
class ContinualTrainer:
    """
    Combines EWC regularization with experience replay.
    This hybrid approach outperforms either technique alone.
    """

    def __init__(
        self,
        model,
        ewc_lambda: float = 500.0,
        replay_ratio: float = 0.3,
        replay_buffer_per_class: int = 50,
    ):
        self.model = model
        self.ewc = EWCTrainer(model, ewc_lambda=ewc_lambda)
        self.replay_buffer = ReplayBuffer(max_per_class=replay_buffer_per_class)
        self.replay_ratio = replay_ratio

    def train_on_new_task(
        self,
        train_data,
        epochs: int = 5,
        lr: float = 2e-5,
    ):
        """
        Train on new task data with EWC + replay.

        Each batch is constructed as:
        - (1 - replay_ratio) from new data
        - replay_ratio from replay buffer
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)

        new_batch_size = int(32 * (1 - self.replay_ratio))  # 22
        replay_batch_size = 32 - new_batch_size              # 10

        for epoch in range(epochs):
            epoch_loss = {"total": 0, "task": 0, "ewc": 0}
            n_steps = 0

            new_loader = DataLoader(train_data, batch_size=new_batch_size, shuffle=True)

            for new_batch in new_loader:
                # Construct mixed batch
                if self.replay_buffer.total_size > 0:
                    replay_samples = self.replay_buffer.sample(replay_batch_size)
                    batch = self._merge_batches(new_batch, replay_samples)
                else:
                    batch = new_batch

                metrics = self.ewc.train_step(batch, optimizer)
                epoch_loss["total"] += metrics["total_loss"]
                epoch_loss["task"] += metrics["task_loss"]
                epoch_loss["ewc"] += metrics["ewc_penalty"]
                n_steps += 1

            for k in epoch_loss:
                epoch_loss[k] /= max(n_steps, 1)

            print(
                f"Epoch {epoch+1}: total={epoch_loss['total']:.4f} "
                f"task={epoch_loss['task']:.4f} ewc={epoch_loss['ewc']:.4f}"
            )

        # After training, consolidate and update replay buffer
        self.ewc.consolidate(DataLoader(train_data, batch_size=32))
        self.replay_buffer.add_examples(
            [{"input_ids": x["input_ids"], "attention_mask": x["attention_mask"],
              "label": x["labels"]} for x in train_data]
        )

    def _merge_batches(self, new_batch, replay_samples):
        """Merge new data batch with replay buffer samples."""
        replay_ids = torch.stack([s["input_ids"] for s in replay_samples])
        replay_masks = torch.stack([s["attention_mask"] for s in replay_samples])
        replay_labels = torch.tensor([s["label"] for s in replay_samples])

        return {
            "input_ids": torch.cat([new_batch["input_ids"], replay_ids]),
            "attention_mask": torch.cat([new_batch["attention_mask"], replay_masks]),
            "labels": torch.cat([new_batch["labels"], replay_labels]),
        }
```

Monthly Retraining Pipeline

```python
import json
from datetime import datetime
import mlflow


def monthly_retrain_pipeline(
    model_path: str,
    new_data_path: str,
    old_data_path: str,
    fisher_path: str | None = None,
):
    """
    MangaAssist monthly retraining pipeline with catastrophic forgetting prevention.

    Steps:
    1. Load current production model and Fisher (if exists)
    2. Detect if new intents have emerged
    3. Train with EWC + replay
    4. Validate against old-task benchmarks
    5. Deploy only if old-task regression < 1%
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Initialize continual trainer
    trainer = ContinualTrainer(
        model=model,
        ewc_lambda=500.0,
        replay_ratio=0.3,
        replay_buffer_per_class=50,
    )

    # Load previous Fisher if available
    if fisher_path:
        fisher_data = torch.load(fisher_path)
        trainer.ewc.fisher = fisher_data["fisher"]
        trainer.ewc.optimal_params = fisher_data["optimal_params"]
        trainer.ewc.task_count = fisher_data["task_count"]

    # Load replay buffer from previous tasks
    if old_data_path:
        old_data = load_dataset(old_data_path)
        trainer.replay_buffer.add_examples(old_data)

    # Train on new data
    new_data = load_dataset(new_data_path)
    trainer.train_on_new_task(new_data, epochs=5, lr=2e-5)

    # Validation gate: check old-task regression
    old_eval = evaluate_model(model, tokenizer, old_data_path)
    new_eval = evaluate_model(model, tokenizer, new_data_path)

    with mlflow.start_run():
        mlflow.log_metric("old_task_accuracy", old_eval["accuracy"])
        mlflow.log_metric("new_task_accuracy", new_eval["accuracy"])
        mlflow.log_metric("combined_accuracy",
                         0.8 * old_eval["accuracy"] + 0.2 * new_eval["accuracy"])

    # Deployment gate
    REGRESSION_THRESHOLD = 0.01  # Max 1% old-task accuracy drop
    if old_eval["accuracy"] >= (0.921 - REGRESSION_THRESHOLD):
        print(f"✓ Old-task accuracy {old_eval['accuracy']:.4f} within threshold")
        model.save_pretrained(f"./models/intent_v{datetime.now():%Y%m}")
        # Save Fisher for next cycle
        torch.save({
            "fisher": trainer.ewc.fisher,
            "optimal_params": trainer.ewc.optimal_params,
            "task_count": trainer.ewc.task_count,
        }, f"./models/fisher_v{datetime.now():%Y%m}.pt")
    else:
        raise ValueError(
            f"✗ Old-task accuracy {old_eval['accuracy']:.4f} below threshold "
            f"{0.921 - REGRESSION_THRESHOLD:.4f}. Rejecting model update."
        )
```
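A hypothetical monthly invocation (paths are illustrative; `load_dataset` and `evaluate_model` are project helpers assumed to exist elsewhere in the codebase):

```python
# Example: the March 2024 retraining run described above.
monthly_retrain_pipeline(
    model_path="./models/intent_v202402",
    new_data_path="./data/labeled_202403.jsonl",
    old_data_path="./data/labeled_202401_202402.jsonl",
    fisher_path="./models/fisher_v202402.pt",
)
```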

Group Discussion: Key Decision Points

Decision Point 1: EWC vs Experience Replay vs Combined

Priya (ML Engineer): I tested all three approaches on our January→March retraining scenario:

| Method | Old-Task Accuracy | New-Task Accuracy | Combined | Training Time |
|---|---|---|---|---|
| Naive fine-tune | 87.3% (-4.8%) | 93.1% | 88.5% | 15 min |
| EWC only (λ=1000) | 91.8% (-0.3%) | 89.5% | 91.3% | 20 min |
| Replay only (30%) | 91.5% (-0.6%) | 90.2% | 91.2% | 18 min |
| EWC + Replay (hybrid) | 91.9% (-0.2%) | 90.5% | 91.6% | 25 min |
| Full retrain (all data) | 92.0% (-0.1%) | 91.0% | 91.8% | 45 min |

Aiko (Data Scientist): The hybrid (EWC + replay) comes within 0.2% of the full retrain in 25 minutes vs 45 minutes. More importantly, full retrain requires storing and processing all historical data, which grows over time. The hybrid approach needs only a 500-example replay buffer (fixed) plus the Fisher matrix (66M floats = 264MB, fixed).

Marcus (Architect): From an infrastructure perspective, the hybrid is clearly superior at scale. After 12 months of monthly retraining, full retrain processes 12K+ examples. The hybrid processes only ~1.7K (new data + 500 replay) regardless of how many months have passed.

Jordan (MLOps): Full retrain also re-introduces risk of old data quality issues. If there was a labeling error in month 3, full retrain keeps propagating it. The hybrid naturally downweights old data through the Fisher's online decay ($\gamma = 0.9$).

Sam (PM): The 0.2% quality gap between hybrid (91.6%) and full retrain (91.8%) is negligible. But the operational simplicity of not needing all historical data is a significant advantage.

Resolution: Hybrid EWC + Replay for monthly retraining. EWC provides parameter-level protection (prevents overwriting critical weights). Replay provides data-level coverage (ensures all classes are represented in training). Combined, they achieve 99.8% of full retrain quality at 56% of the training time with constant memory.

Decision Point 2: EWC Lambda Selection

Priya (ML Engineer): Lambda controls the strength of consolidation. I swept it:

| Lambda | Old Accuracy | New Accuracy | Combined | Behavior |
|---|---|---|---|---|
| 0 | 87.3% | 93.1% | 88.5% | No consolidation (naive) |
| 100 | 89.5% | 91.2% | 89.8% | Mild — some forgetting persists |
| 500 | 91.2% | 90.8% | 91.1% | Moderate — good balance |
| 1000 | 91.8% | 89.5% | 91.3% | Strong — favors old task |
| 5000 | 92.0% | 83.4% | 90.3% | Too strong — blocks new learning |
| 10000 | 92.1% | 72.1% | 88.1% | Nearly frozen — cannot adapt |

Aiko (Data Scientist): The optimal lambda depends on the old/new data ratio. With 5K old and 1.2K new, the new data is 19% of the total. If the ratio were 50/50, we would need a lower lambda (around 200) because the new task gradient is stronger and needs less damping.

A useful heuristic: $\lambda \approx \frac{N_{\text{old}}}{N_{\text{new}}} \times 200$. For us: $\frac{5000}{1200} \times 200 \approx 833$. We round to $\lambda = 1000$ for the combined metric optimum.

Jordan (MLOps): Can we auto-tune lambda? In production, we cannot afford a sweep every month.

Priya (ML Engineer): Yes — we can use a validation set from the old task (200 held-out examples from the replay buffer). Train for 1 epoch at lambdas [100, 500, 1000, 2000], evaluate on old-task validation, and pick the lambda with <0.5% degradation. This adds 10 minutes to the pipeline.

Resolution: Default $\lambda = 1000$ with auto-tuning via 1-epoch validation sweep. The sweep adds minimal overhead and adapts lambda to the specific new/old data ratio each month.
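A sketch of that auto-tuning loop, assuming a hypothetical `evaluate_on` helper and a held-out `old_val_loader`; one reasonable selection policy is the smallest lambda that passes the degradation check, since it preserves the most plasticity:

```python
import copy


def autotune_lambda(base_model, fisher, optimal_params, new_loader,
                    old_val_loader, candidates=(100, 500, 1000, 2000),
                    baseline_acc=0.921, max_drop=0.005):
    """Return the smallest lambda whose 1-epoch old-task degradation stays
    under max_drop; fall back to the most conservative candidate."""
    for lam in sorted(candidates):
        model = copy.deepcopy(base_model)  # fresh trial copy per lambda
        trial = EWCTrainer(model, ewc_lambda=lam)
        trial.fisher = fisher
        trial.optimal_params = optimal_params
        trial.task_count = 1  # enables the penalty
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        for batch in new_loader:  # exactly one epoch
            trial.train_step(batch, optimizer)
        acc = evaluate_on(model, old_val_loader)  # assumed eval helper
        if baseline_acc - acc <= max_drop:
            return lam  # smallest qualifying lambda keeps maximum plasticity
    return max(candidates)
```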

Decision Point 3: When to Add New Intents vs Expand Existing

Sam (PM): The "gift_recommendation" pattern: do we add an 11th intent class, or expand the "recommendation" intent to cover gifts?

Aiko (Data Scientist): This is a taxonomy decision with ML implications. Adding a new class is cleaner if the new intent has a distinct distribution — different vocabulary, different expected response. Expanding an existing class is safer if the new pattern is a sub-type.

I analyzed the "gift_recommendation" queries against "recommendation":

| Feature | recommendation | gift_recommendation |
|---|---|---|
| Mean query length | 14.2 tokens | 18.7 tokens |
| Unique vocabulary | 412 terms | 287 terms (83% overlap with reco) |
| Typical structure | "Recommend manga like X" | "Gift for someone who likes X" |
| Expected response | Ranked list | List + age-appropriateness + gift-wrapping |

Marcus (Architect): The 83% vocabulary overlap suggests it is a sub-type of recommendation, but the response format differs significantly. The downstream system needs to know if it should add gift-wrapping info or not.

Priya (ML Engineer): From an ML perspective, adding an 11th class is lower risk for continual learning — the new class occupies its own region of the embedding space. Expanding "recommendation" risks confusing the boundary between regular and gift recommendations, which could degrade both.

Jordan (MLOps): A new class also means updating the classification head: 10→11 outputs. This requires re-initializing the last layer's 11th row (random) and fine-tuning it while keeping the other 10 stable. EWC handles this naturally — the Fisher for the 11th class starts at zero, allowing free adaptation.

Resolution: Add as 11th intent class. The response format difference justifies a separate category. EWC naturally handles the expanding head: high Fisher on existing 10 dimensions (preserve), zero Fisher on the 11th (adapt freely). Re-evaluate quarterly whether to further split or merge intents.
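A sketch of the head expansion Jordan describes, assuming a DistilBERT-style model whose classifier is a single `nn.Linear` exposed as `model.classifier` (the Hugging Face convention); the stored EWC state must also be padded so tensor shapes match:

```python
import torch
import torch.nn as nn


def expand_classification_head(model, new_num_labels: int = 11):
    """Grow the classifier from 10 to 11 outputs, copying the trained rows."""
    old_head: nn.Linear = model.classifier
    new_head = nn.Linear(old_head.in_features, new_num_labels)
    with torch.no_grad():
        # Preserve decision boundaries for the existing 10 intents;
        # row 11 keeps its random init and is free to adapt.
        new_head.weight[:old_head.out_features] = old_head.weight
        new_head.bias[:old_head.out_features] = old_head.bias
    model.classifier = new_head
    model.config.num_labels = new_num_labels
    return model


def expand_ewc_state(ewc: EWCTrainer, head_name: str = "classifier"):
    """Zero-pad stored Fisher and optimal params for the new row, so the
    11th output carries zero EWC penalty and adapts freely."""
    for store in (ewc.fisher, ewc.optimal_params):
        for key in (f"{head_name}.weight", f"{head_name}.bias"):
            old = store[key]
            pad = torch.zeros((1, *old.shape[1:]), dtype=old.dtype,
                              device=old.device)
            store[key] = torch.cat([old, pad])
```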

Decision Point 4: Validation Gate Threshold

Jordan (MLOps): Our deployment gate rejects model updates if old-task accuracy drops by more than X%. What should X be?

Sam (PM): From a product perspective, our SLA guarantees 90% routing accuracy. Current production is 92.1%, giving us 2.1% headroom before SLA breach.

Priya (ML Engineer): But accuracy is not the only metric. A 1% overall accuracy drop could hide a 10% drop on one specific intent (like "return" in our naive retrain example). I would add per-intent thresholds in addition to the overall threshold.

Aiko (Data Scientist): Proposed gate:

| Metric | Threshold | Rationale |
|---|---|---|
| Overall accuracy vs previous | ≤ 1% drop | Headroom preservation |
| Any single intent accuracy | ≤ 3% drop | Prevents catastrophic single-class degradation |
| New intent accuracy | ≥ 85% | Ensures new pattern is actually learned |
| Confidence calibration (ECE) | ≤ 0.05 | Prevents overconfident wrong predictions |

Marcus (Architect): I agree with Aiko's multi-metric gate. A model that passes on overall accuracy but fails on one intent would be caught by the per-intent check. This is defense in depth.

Resolution: Multi-metric validation gate. Reject deployment if any metric exceeds threshold. Log all metrics to MLflow with automated alerting. If rejected, fall back to the previous month's model (which is always kept as rollback candidate).
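As a sketch, the gate from Aiko's table might look like this (hypothetical metric dictionaries; ECE is assumed to be computed upstream):

```python
def passes_deployment_gate(prev: dict, cand: dict) -> tuple[bool, list[str]]:
    """prev/cand: {"overall": float, "per_intent": {intent: acc},
    "new_intent_acc": float, "ece": float}. Returns (ok, failure reasons)."""
    failures = []
    if prev["overall"] - cand["overall"] > 0.01:
        failures.append("overall accuracy dropped by more than 1%")
    for intent, old_acc in prev["per_intent"].items():
        if old_acc - cand["per_intent"].get(intent, 0.0) > 0.03:
            failures.append(f"intent '{intent}' dropped by more than 3%")
    if cand["new_intent_acc"] < 0.85:
        failures.append("new intent accuracy below 85%")
    if cand["ece"] > 0.05:
        failures.append("ECE above 0.05 (poor calibration)")
    return len(failures) == 0, failures
```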


Research Paper References

1. Overcoming Catastrophic Forgetting in Neural Networks — EWC (Kirkpatrick et al., 2017)

Key contribution: Introduced Elastic Weight Consolidation using the diagonal Fisher Information Matrix as an importance measure for each parameter. Demonstrated that protecting high-Fisher parameters while allowing low-Fisher parameters to adapt preserves old-task knowledge while learning new tasks.

Relevance to MangaAssist: Core of our continual learning strategy. The Fisher reveals that our DistilBERT's upper layers (classification head, layer 6) have 1000× higher importance than lower layers (embeddings), creating a natural per-layer protection gradient.

2. Progressive Neural Networks (Rusu et al., 2016)

Key contribution: Proposed adding new lateral columns for each new task while freezing old columns. This completely prevents forgetting (zero interference) at the cost of linearly growing model size. Introduced lateral connections for knowledge transfer between columns.

Relevance to MangaAssist: We considered progressive nets for our intent classifier but rejected them due to latency concerns — each new column adds ~5ms inference time. However, the lateral connection concept informs our adapter-based approach (docs 04, 11) where task-specific adapters are lightweight "columns" that do not affect base model latency.

3. Avalanche: An End-to-End Library for Continual Learning (Lomonaco et al., 2021)

Key contribution: Open-source framework implementing 20+ continual learning strategies (EWC, SI, LwF, replay, progressive nets) with standardized benchmarks and evaluation protocols. Provides plug-and-play strategies that can be combined.

Relevance to MangaAssist: We use Avalanche for benchmarking and ablation studies. Its standardized evaluation protocol (Average Accuracy, Forward Transfer, Backward Transfer metrics) gives us comparable metrics across strategies.

4. Continual Lifelong Learning with Neural Networks (Parisi et al., 2019)

Key contribution: Comprehensive survey of catastrophic forgetting in neural networks. Categories: regularization-based (EWC, SI), replay-based (experience replay, generative replay), and architecture-based (progressive nets, PackNet). Key finding: hybrid approaches outperform pure strategies.

Relevance to MangaAssist: Informed our decision to combine EWC + replay. The survey's analysis shows that regularization alone can be brittle (sensitive to lambda), and replay alone suffers from buffer bias. The hybrid addresses both weaknesses.


Production Deployment Results

Monthly Retrain Results (6-Month Trajectory)

| Month | New Data | Method | Old Accuracy | New Accuracy | Combined | EWC λ |
|---|---|---|---|---|---|---|
| Jan (baseline) | 5K (initial) | Full train | 92.1% | | 92.1% | |
| Feb | 400 (minor) | EWC + Replay | 91.9% | 88.4% | 91.2% | 1000 |
| Mar | 1.2K (gift spike) | EWC + Replay | 91.6% | 90.5% | 91.4% | 833 |
| Apr | 600 (anime season) | EWC + Replay | 91.4% | 89.8% | 91.1% | 1000 |
| May | 300 (slow month) | EWC + Replay | 91.3% | 87.2% | 90.5% | 1500 |
| Jun | 900 (summer sale) | EWC + Replay | 91.1% | 90.1% | 90.9% | 800 |

Key observations:

- Old-task accuracy degrades only 1.0% over 6 months (92.1% → 91.1%) — 5× better than naive fine-tuning, which would degrade ~4% per cycle
- New-task accuracy consistently above 87% — the model successfully learns new patterns
- Combined accuracy stable around 91% — well above the 90% SLA threshold
- Online EWC Fisher consolidation keeps memory constant (264MB) regardless of task count

Per-Intent Stability

| Intent | Jan Acc | Jun Acc | Max Monthly Drop |
|---|---|---|---|
| product_inquiry | 93.4% | 92.8% | -0.4% |
| order_status | 95.1% | 94.2% | -0.5% |
| recommendation | 88.7% | 91.2% | +2.5% (improved) |
| return | 91.2% | 90.4% | -0.8% |
| shipping | 90.8% | 90.1% | -0.6% |
| gift_recommendation | | 89.5% | N/A (new) |

No intent exceeded a 1% monthly drop.