
08. Sentiment Classifier Fine-Tuning — Frustration Detection for Escalation

Problem Statement and MangaAssist Context

MangaAssist needs a sentiment classifier that detects frustrated customers in real time so the system can escalate to a human agent before the customer abandons. The default DistilBERT sentiment model (fine-tuned on SST-2) classifies positive/negative but misses manga-domain frustration signals: "I've been waiting 3 weeks for volume 14 and still nothing" reads as neutral to SST-2, but any manga buyer recognizes the frustration. This document covers fine-tuning DistilBERT for multi-label sentiment on manga customer queries, including gradual unfreezing, multi-label classification with BCE loss, threshold optimization, and the unique challenge of domain-specific emotional language.

Current Sentiment Performance

| Metric | SST-2 Pretrained | Target (Manga Fine-Tuned) |
|---|---|---|
| Frustration recall | 62.4% | ≥ 90% |
| Satisfaction precision | 78.1% | ≥ 85% |
| Multi-label accuracy (exact match) | 41.2% | ≥ 75% |
| Latency (P50) | 8ms | ≤ 12ms |

Why Multi-Label?

A single message can express multiple sentiments:

- "The manga arrived damaged — but the story is amazing" → frustration + satisfaction
- "I love this series but when will volume 5 restock?" → satisfaction + urgency
- "Never buying from here again" → frustration + churn_risk

Labels: frustration, satisfaction, urgency, confusion, churn_risk, neutral.


Mathematical Foundations

Binary Cross-Entropy for Multi-Label Classification

In multi-class, we use softmax + cross-entropy (labels are mutually exclusive). In multi-label, each label is an independent binary prediction — we use sigmoid + Binary Cross-Entropy (BCE):

$$\mathcal{L}_{\text{BCE}} = -\frac{1}{L} \sum_{l=1}^{L} \left[ y_l \log(\sigma(z_l)) + (1 - y_l) \log(1 - \sigma(z_l)) \right]$$

where:

- $L = 6$ (number of sentiment labels)
- $y_l \in \{0, 1\}$ is the ground truth for label $l$
- $z_l$ is the logit (raw model output) for label $l$
- $\sigma(z) = \frac{1}{1 + e^{-z}}$ is the sigmoid function

Why BCE, not cross-entropy? Cross-entropy forces a probability distribution (sums to 1 via softmax). BCE treats each label independently — sigmoid outputs are independent probabilities in $[0, 1]$ that need not sum to 1. A message can be 90% frustration AND 85% urgency simultaneously.
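
A quick numeric sketch of this difference, using toy logits (illustrative values, not real model output):

```python
import math

logits = [2.2, 1.7, -2.0, -2.4, -0.8, -2.9]  # one illustrative score per label

def softmax(zs):
    exps = [math.exp(z) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

soft = softmax(logits)              # competes: always sums to 1
sig = [sigmoid(z) for z in logits]  # independent: each label judged on its own

print(round(sum(soft), 6))                 # 1.0
print(round(sig[0], 2), round(sig[1], 2))  # 0.9 0.85 -- both high at once
```

With softmax, raising one label's probability necessarily lowers the others; with sigmoid, the first two labels can both stay above 0.8.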

Gradient Analysis of BCE

The gradient of BCE with respect to logit $z_l$:

$$\frac{\partial \mathcal{L}}{\partial z_l} = \frac{1}{L} (\sigma(z_l) - y_l)$$

This is clean and intuitive:

- If $y_l = 1$ (label is active): gradient $= \sigma(z_l) - 1 < 0$ → pushes the logit up
- If $y_l = 0$ (label inactive): gradient $= \sigma(z_l) - 0 > 0$ → pushes the logit down
- Gradient magnitude = prediction error for that label
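
The closed-form gradient can be verified with a central finite-difference check on toy logits (values below are illustrative):

```python
import math

L = 6  # number of sentiment labels

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce(zs, ys):
    # mean binary cross-entropy over the L labels
    total = 0.0
    for z, y in zip(zs, ys):
        p = sigmoid(z)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / L

zs = [1.5, -0.7, 0.2, 2.0, -1.1, 0.4]  # illustrative logits
ys = [1, 0, 1, 0, 0, 1]                # illustrative targets

# central finite difference on z_0 vs the closed-form (sigma(z_0) - y_0) / L
eps = 1e-6
zp, zm = zs.copy(), zs.copy()
zp[0] += eps
zm[0] -= eps
numeric = (bce(zp, ys) - bce(zm, ys)) / (2 * eps)
analytic = (sigmoid(zs[0]) - ys[0]) / L

print(abs(numeric - analytic) < 1e-6)  # True
```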

Class imbalance problem: In our dataset, "neutral" appears in 45% of messages, "frustration" in 18%, "churn_risk" in 3%. BCE weights false positives and false negatives equally. A model that predicts "never churn_risk" achieves 97% label accuracy for that label — but misses every actual churn risk.

Focal Loss for Imbalanced Multi-Label

To handle class imbalance, we use focal loss (Lin et al., 2017) adapted for BCE:

$$\mathcal{L}_{\text{focal}} = -\frac{1}{L} \sum_{l=1}^{L} \alpha_l \left[ y_l (1-p_l)^\gamma \log(p_l) + (1-y_l) p_l^\gamma \log(1-p_l) \right]$$

where $p_l = \sigma(z_l)$, $\gamma$ is the focusing parameter, and $\alpha_l$ is the class weight.

How $\gamma$ works:

- When $\gamma = 0$: standard BCE.
- When $\gamma = 2$ (common): if $p_l = 0.9$ (confident correct), the $(1-0.9)^2 = 0.01$ scaling reduces the loss by 100×. Rare but important errors get amplified.

| $\gamma$ | Easy example ($p=0.95$) | Hard example ($p=0.3$) | Focus ratio |
|---|---|---|---|
| 0 | 0.051 | 1.204 | 23.5× |
| 1 | 0.0026 | 0.843 | 329× |
| 2 | 0.00013 | 0.590 | ≈4,600× |

At $\gamma = 2$, the model focuses roughly 4,600× more on hard examples relative to easy ones.
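
These per-example numbers can be reproduced directly from the positive-label focal term; exact ratios vary slightly with how the small "easy" losses are rounded:

```python
import math

def focal_pos(p, gamma):
    # focal loss contribution for an active label (y=1) with probability p
    return -((1 - p) ** gamma) * math.log(p)

for gamma in (0, 1, 2):
    easy = focal_pos(0.95, gamma)
    hard = focal_pos(0.30, gamma)
    print(gamma, f"{easy:.5f}", f"{hard:.3f}", f"{hard / easy:.0f}x")
```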

Per-label weights $\alpha_l$:

$$\alpha_l = \frac{\text{median frequency}}{\text{frequency of label } l}$$

| Label | Frequency | $\alpha_l$ |
|---|---|---|
| neutral | 45% | 0.40 |
| frustration | 18% | 1.00 (median) |
| satisfaction | 15% | 1.20 |
| urgency | 12% | 1.50 |
| confusion | 7% | 2.57 |
| churn_risk | 3% | 6.00 |

churn_risk gets 6× the gradient weight, compensating for its 15× lower frequency relative to neutral.
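
A few lines recompute the $\alpha_l$ column; note the table uses the frustration frequency (18%) as its 1.0 reference point:

```python
# Label frequencies from the table above
freqs = {
    "neutral": 0.45, "frustration": 0.18, "satisfaction": 0.15,
    "urgency": 0.12, "confusion": 0.07, "churn_risk": 0.03,
}
ref = freqs["frustration"]  # the table's 1.0 reference point
alphas = {label: round(ref / f, 2) for label, f in freqs.items()}
print(alphas)
```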

Gradual Unfreezing — ULMFiT Strategy

Howard & Ruder (2018) showed that unfreezing transformer layers gradually — from top to bottom — prevents catastrophic forgetting of pretrained features while allowing domain adaptation.

DistilBERT has 6 transformer layers:

| Layer Group | Captures | When to Unfreeze | Learning Rate |
|---|---|---|---|
| Embedding + Layers 0-1 | Token/positional, basic syntax | Epoch 4+ | $\eta_{\text{base}} / 2.6^4$ (embeddings), $\eta_{\text{base}} / 2.6^3$ (layers 0-1) |
| Layers 2-3 | Phrase-level patterns, negation | Epoch 3 | $\eta_{\text{base}} / 2.6^2$ |
| Layers 4-5 | Sentence-level semantics | Epoch 2 | $\eta_{\text{base}} / 2.6$ |
| Classification head | Sentiment-specific features | Epoch 1 | $\eta_{\text{base}}$ |

Discriminative learning rates: Each layer group gets a progressively lower learning rate. With $\eta_{\text{base}} = 3 \times 10^{-5}$ and decay factor 2.6:

| Layer Group | Learning Rate | Relative |
|---|---|---|
| Classification head | $3.0 \times 10^{-5}$ | 1.0× |
| Layers 4-5 | $1.15 \times 10^{-5}$ | 0.38× |
| Layers 2-3 | $4.44 \times 10^{-6}$ | 0.15× |
| Layers 0-1 | $1.71 \times 10^{-6}$ | 0.057× |
| Embeddings | $6.57 \times 10^{-7}$ | 0.022× |

Why 2.6? In ULMFiT, Howard & Ruder found 2.6 empirically optimal across tasks. The intuition: lower layers capture universal features (syntax, morphology) that should change minimally. Higher layers capture task-specific features that must adapt significantly.
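
The discriminative-LR table is just repeated division by 2.6; a few lines reproduce it:

```python
base_lr, decay = 3e-5, 2.6
groups = ["head", "layer_4_5", "layer_2_3", "layer_0_1", "embeddings"]

# lr for a group at depth d below the head is base_lr / decay**d
lrs = {g: base_lr / decay**d for d, g in enumerate(groups)}
for g, lr in lrs.items():
    print(f"{g:11s} {lr:.2e} ({lr / base_lr:.3f}x)")
```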

Threshold Optimization

Unlike multi-class (argmax), multi-label classification requires choosing a threshold per label. The default 0.5 is rarely optimal.

Per-label threshold search:

For each label $l$, sweep thresholds $t \in [0.1, 0.9]$ and optimize a metric:

$$t_l^* = \arg\max_t F_\beta(l, t)$$

where $F_\beta$ balances precision and recall:

$$F_\beta = (1 + \beta^2) \frac{\text{precision} \cdot \text{recall}}{\beta^2 \cdot \text{precision} + \text{recall}}$$

For frustration and churn_risk, we use $\beta = 2$ (recall-focused — better to false-positive than miss a frustrated customer). For satisfaction, $\beta = 0.5$ (precision-focused — avoid false attribution of happiness).
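
A minimal $F_\beta$ helper shows how $\beta$ shifts the balance; the precision/recall pairs come from the threshold-sweep diagram later in this document and are illustrative:

```python
def f_beta(precision, recall, beta):
    # F_beta: beta > 1 weights recall more; beta < 1 weights precision more
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# frustration sweep point t=0.35 (recall-focused, beta=2)
print(round(f_beta(0.71, 0.86, beta=2.0), 2))
# satisfaction sweep point t=0.55 (precision-focused, beta=0.5)
print(round(f_beta(0.83, 0.72, beta=0.5), 2))
```

With $\beta = 2$ the recall-heavy operating point scores well despite weaker precision; with $\beta = 0.5$ the reverse holds.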

Optimized thresholds (from our validation set):

| Label | F1 at default (0.5) | Optimal $t^*$ | Optimized F1 | Change |
|---|---|---|---|---|
| frustration | 0.72 | 0.35 | 0.81 | +12.5% |
| satisfaction | 0.74 | 0.55 | 0.77 | +4.1% |
| urgency | 0.68 | 0.40 | 0.76 | +11.8% |
| confusion | 0.61 | 0.32 | 0.73 | +19.7% |
| churn_risk | 0.53 | 0.28 | 0.69 | +30.2% |
| neutral | 0.82 | 0.48 | 0.83 | +1.2% |

Lower thresholds for rare labels (churn_risk: 0.28) dramatically improve recall at the cost of some precision — exactly the tradeoff we want for escalation-critical signals.


Model Internals — Layer-by-Layer Diagrams

Gradual Unfreezing Schedule

graph TB
    subgraph "Epoch 1: Only head trainable"
        E1_EMB["Embeddings ❄️"]
        E1_L01["Layers 0-1 ❄️"]
        E1_L23["Layers 2-3 ❄️"]
        E1_L45["Layers 4-5 ❄️"]
        E1_HEAD["Classification Head 🔥<br>lr = 3e-5"]
        E1_EMB --> E1_L01 --> E1_L23 --> E1_L45 --> E1_HEAD
    end

    subgraph "Epoch 2: Top layers unfrozen"
        E2_EMB["Embeddings ❄️"]
        E2_L01["Layers 0-1 ❄️"]
        E2_L23["Layers 2-3 ❄️"]
        E2_L45["Layers 4-5 🔥<br>lr = 1.15e-5"]
        E2_HEAD["Classification Head 🔥<br>lr = 3e-5"]
        E2_EMB --> E2_L01 --> E2_L23 --> E2_L45 --> E2_HEAD
    end

    subgraph "Epoch 3: Middle layers unfrozen"
        E3_EMB["Embeddings ❄️"]
        E3_L01["Layers 0-1 ❄️"]
        E3_L23["Layers 2-3 🔥<br>lr = 4.4e-6"]
        E3_L45["Layers 4-5 🔥<br>lr = 1.15e-5"]
        E3_HEAD["Classification Head 🔥<br>lr = 3e-5"]
        E3_EMB --> E3_L01 --> E3_L23 --> E3_L45 --> E3_HEAD
    end

    subgraph "Epoch 4+: All layers unfrozen"
        E4_EMB["Embeddings 🔥<br>lr = 6.6e-7"]
        E4_L01["Layers 0-1 🔥<br>lr = 1.7e-6"]
        E4_L23["Layers 2-3 🔥<br>lr = 4.4e-6"]
        E4_L45["Layers 4-5 🔥<br>lr = 1.15e-5"]
        E4_HEAD["Classification Head 🔥<br>lr = 3e-5"]
        E4_EMB --> E4_L01 --> E4_L23 --> E4_L45 --> E4_HEAD
    end

    style E1_HEAD fill:#ffcdd2
    style E2_L45 fill:#fff9c4
    style E2_HEAD fill:#ffcdd2
    style E3_L23 fill:#c8e6c9
    style E3_L45 fill:#fff9c4
    style E3_HEAD fill:#ffcdd2
    style E4_EMB fill:#e1bee7
    style E4_L01 fill:#e1bee7
    style E4_L23 fill:#c8e6c9
    style E4_L45 fill:#fff9c4
    style E4_HEAD fill:#ffcdd2

Gradient Magnitude Heatmap Across Layers

graph LR
    subgraph "Gradient Magnitude During Training"
        subgraph "Epoch 1 (head only)"
            G1_E["Embed: 0.000"]
            G1_L0["L0: 0.000"]
            G1_L2["L2: 0.000"]
            G1_L4["L4: 0.000"]
            G1_H["Head: 0.042"]
        end

        subgraph "Epoch 2 (+ top layers)"
            G2_E["Embed: 0.000"]
            G2_L0["L0: 0.000"]
            G2_L2["L2: 0.000"]
            G2_L4["L4: 0.008"]
            G2_H["Head: 0.031"]
        end

        subgraph "Epoch 4 (all layers)"
            G4_E["Embed: 0.001"]
            G4_L0["L0: 0.002"]
            G4_L2["L2: 0.005"]
            G4_L4["L4: 0.012"]
            G4_H["Head: 0.018"]
        end
    end

    style G1_H fill:#ff5252
    style G2_L4 fill:#ffab40
    style G2_H fill:#ff5252
    style G4_E fill:#e8eaf6
    style G4_L0 fill:#c5cae9
    style G4_L2 fill:#fff9c4
    style G4_L4 fill:#ffab40
    style G4_H fill:#ff5252

Multi-Label Classification Architecture

graph TB
    INPUT["Input: 'The manga arrived damaged<br>but the story is amazing'"]

    subgraph "DistilBERT Encoder (66M params)"
        TOK["Tokenizer → [CLS] The manga arrived damaged ..."]
        EMB["Embedding Layer (23M)"]
        L0["Transformer Layer 0"]
        L1["Transformer Layer 1"]
        L2["Transformer Layer 2"]
        L3["Transformer Layer 3"]
        L4["Transformer Layer 4"]
        L5["Transformer Layer 5"]
        CLS["[CLS] representation ∈ ℝ⁷⁶⁸"]
    end

    subgraph "Classification Head (multi-label)"
        DROP["Dropout (0.2)"]
        FC1["Linear: 768 → 256 + ReLU"]
        FC2["Linear: 256 → 6 (one per label)"]
        SIG["Sigmoid (independent per label)"]
    end

    subgraph "Output Probabilities"
        O1["frustration: 0.89 ⬆️"]
        O2["satisfaction: 0.74 ⬆️"]
        O3["urgency: 0.12"]
        O4["confusion: 0.08"]
        O5["churn_risk: 0.31"]
        O6["neutral: 0.05"]
    end

    subgraph "Threshold Gate"
        T1["frustration: 0.89 > 0.35 ✅"]
        T2["satisfaction: 0.74 > 0.55 ✅"]
        T3["urgency: 0.12 < 0.40 ❌"]
        T4["confusion: 0.08 < 0.32 ❌"]
        T5["churn_risk: 0.31 > 0.28 ✅"]
        T6["neutral: 0.05 < 0.48 ❌"]
    end

    INPUT --> TOK --> EMB --> L0 --> L1 --> L2 --> L3 --> L4 --> L5 --> CLS
    CLS --> DROP --> FC1 --> FC2 --> SIG
    SIG --> O1 & O2 & O3 & O4 & O5 & O6
    O1 --> T1
    O2 --> T2
    O3 --> T3
    O4 --> T4
    O5 --> T5
    O6 --> T6

    style O1 fill:#ffcdd2
    style O2 fill:#c8e6c9
    style O5 fill:#fff9c4
    style T1 fill:#ffcdd2
    style T2 fill:#c8e6c9
    style T5 fill:#fff9c4

Focal Loss Effect on Gradient Distribution

graph TB
    subgraph "Standard BCE: All examples weighted equally"
        BCE_EASY["Easy examples (p > 0.8)<br>70% of data<br>Total gradient share: 70%"]
        BCE_HARD["Hard examples (p < 0.5)<br>15% of data<br>Total gradient share: 15%"]
        BCE_RARE["Rare class hard examples<br>3% of data (churn_risk)<br>Total gradient share: 3%"]
    end

    subgraph "Focal Loss (γ=2): Hard examples amplified"
        FL_EASY["Easy examples (p > 0.8)<br>70% of data<br>Gradient share: 12% ⬇️"]
        FL_HARD["Hard examples (p < 0.5)<br>15% of data<br>Gradient share: 48% ⬆️"]
        FL_RARE["Rare class hard examples<br>3% of data (churn_risk)<br>Gradient share: 40% ⬆️"]
    end

    BCE_EASY -->|"Focal<br>reweighting"| FL_EASY
    BCE_HARD -->|"γ=2"| FL_HARD
    BCE_RARE -->|"γ=2 + α=6.0"| FL_RARE

    style BCE_RARE fill:#ffcdd2
    style FL_RARE fill:#c8e6c9
    style FL_HARD fill:#fff9c4

Threshold Optimization Visualization

graph TD
    subgraph "Per-Label Threshold Search"
        S["sweep t ∈ [0.1, 0.9]<br>step = 0.05"]

        subgraph "frustration (β=2, recall-focused)"
            F1["t=0.50 → P=0.82 R=0.64 F2=0.67"]
            F2["t=0.35 → P=0.71 R=0.86 F2=0.82 ◀ OPTIMAL"]
            F3["t=0.20 → P=0.58 R=0.93 F2=0.83"]
        end

        subgraph "churn_risk (β=2, recall-focused)"
            C1["t=0.50 → P=0.78 R=0.42 F2=0.46"]
            C2["t=0.28 → P=0.55 R=0.81 F2=0.74 ◀ OPTIMAL"]
            C3["t=0.15 → P=0.32 R=0.91 F2=0.67"]
        end

        subgraph "satisfaction (β=0.5, precision-focused)"
            SA1["t=0.40 → P=0.69 R=0.85 F0.5=0.71"]
            SA2["t=0.55 → P=0.83 R=0.72 F0.5=0.81 ◀ OPTIMAL"]
            SA3["t=0.70 → P=0.91 R=0.54 F0.5=0.82"]
        end
    end

    S --> F1 & F2 & F3
    S --> C1 & C2 & C3
    S --> SA1 & SA2 & SA3

    style F2 fill:#c8e6c9
    style C2 fill:#c8e6c9
    style SA2 fill:#c8e6c9

Implementation Deep-Dive

Multi-Label DistilBERT with Gradual Unfreezing

import torch
import torch.nn as nn
from transformers import (
    DistilBertModel,
    DistilBertTokenizer,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import f1_score, precision_recall_fscore_support


LABELS = ["frustration", "satisfaction", "urgency", "confusion", "churn_risk", "neutral"]
NUM_LABELS = len(LABELS)


class MultiLabelSentimentModel(nn.Module):
    """Multi-label sentiment classifier with per-label sigmoid outputs."""

    def __init__(self, model_name: str = "distilbert-base-uncased", hidden_dim: int = 256):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained(model_name)
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(768, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, NUM_LABELS),
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        logits = self.head(cls_emb)
        return logits  # (batch, NUM_LABELS) — raw logits, apply sigmoid externally


class FocalBCELoss(nn.Module):
    """Focal Binary Cross-Entropy for multi-label with class imbalance."""

    def __init__(self, gamma: float = 2.0, alpha: list[float] | None = None):
        super().__init__()
        self.gamma = gamma
        if alpha is not None:
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        else:
            self.alpha = None

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(logits)

        # Standard BCE components
        pos_loss = targets * torch.log(probs + 1e-8)
        neg_loss = (1 - targets) * torch.log(1 - probs + 1e-8)

        # Focal modulation
        pos_focal = (1 - probs) ** self.gamma
        neg_focal = probs ** self.gamma

        loss = -(pos_focal * pos_loss + neg_focal * neg_loss)

        # Apply per-label weights
        if self.alpha is not None:
            alpha = self.alpha.to(logits.device)
            loss = loss * alpha.unsqueeze(0)

        return loss.mean()


class SentimentDataset(Dataset):
    def __init__(self, texts: list[str], labels: list[list[int]], tokenizer, max_length: int = 128):
        self.encodings = tokenizer(
            texts, padding=True, truncation=True,
            max_length=max_length, return_tensors="pt",
        )
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "labels": self.labels[idx],
        }

Gradual Unfreezing Trainer

class GradualUnfreezingTrainer:
    """
    Implements ULMFiT-style gradual unfreezing with discriminative LRs.
    """

    def __init__(
        self,
        model: MultiLabelSentimentModel,
        train_loader: DataLoader,
        val_loader: DataLoader,
        base_lr: float = 3e-5,
        decay_factor: float = 2.6,
        num_epochs: int = 8,
        gamma: float = 2.0,
        label_weights: list[float] | None = None,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.base_lr = base_lr
        self.decay_factor = decay_factor
        self.num_epochs = num_epochs
        self.criterion = FocalBCELoss(gamma=gamma, alpha=label_weights)

        # Layer groups for unfreezing schedule
        self.layer_groups = self._get_layer_groups()

    def _get_layer_groups(self):
        """Group model parameters by layer depth."""
        encoder = self.model.encoder
        groups = {
            "embeddings": list(encoder.embeddings.parameters()),
            "layer_0_1": (
                list(encoder.transformer.layer[0].parameters())
                + list(encoder.transformer.layer[1].parameters())
            ),
            "layer_2_3": (
                list(encoder.transformer.layer[2].parameters())
                + list(encoder.transformer.layer[3].parameters())
            ),
            "layer_4_5": (
                list(encoder.transformer.layer[4].parameters())
                + list(encoder.transformer.layer[5].parameters())
            ),
            "head": list(self.model.head.parameters()),
        }
        return groups

    def _freeze_all_except(self, active_groups: list[str]):
        """Freeze all parameters except those in active groups."""
        for name, params in self.layer_groups.items():
            requires_grad = name in active_groups
            for p in params:
                p.requires_grad = requires_grad

    def _build_optimizer(self, active_groups: list[str]):
        """Build optimizer with discriminative learning rates."""
        lr_multipliers = {
            "head": 1.0,
            "layer_4_5": 1.0 / self.decay_factor,
            "layer_2_3": 1.0 / (self.decay_factor ** 2),
            "layer_0_1": 1.0 / (self.decay_factor ** 3),
            "embeddings": 1.0 / (self.decay_factor ** 4),
        }

        param_groups = []
        for name in active_groups:
            param_groups.append({
                "params": self.layer_groups[name],
                "lr": self.base_lr * lr_multipliers[name],
            })

        return torch.optim.AdamW(param_groups, weight_decay=0.01)

    def _get_unfreeze_schedule(self) -> dict[int, list[str]]:
        """Define which layer groups are active at each epoch."""
        return {
            0: ["head"],                                           # Epoch 1
            1: ["head", "layer_4_5"],                             # Epoch 2
            2: ["head", "layer_4_5", "layer_2_3"],                # Epoch 3
            3: ["head", "layer_4_5", "layer_2_3", "layer_0_1", "embeddings"],  # Epoch 4+
        }

    def train(self):
        schedule = self._get_unfreeze_schedule()
        best_f1 = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        for epoch in range(self.num_epochs):
            # Determine active groups
            active_key = min(epoch, max(schedule.keys()))
            active_groups = schedule[active_key]
            self._freeze_all_except(active_groups)

            # Rebuild optimizer with current unfreezing
            optimizer = self._build_optimizer(active_groups)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=len(self.train_loader) // 10,
                num_training_steps=len(self.train_loader),
            )

            # Training loop
            self.model.train()
            total_loss = 0
            for batch in self.train_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                logits = self.model(batch["input_ids"], batch["attention_mask"])
                loss = self.criterion(logits, batch["labels"])

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                total_loss += loss.item()

            # Validation
            val_metrics = self._evaluate(device)
            f1_macro = val_metrics["f1_macro"]

            active_str = ", ".join(active_groups)
            print(
                f"Epoch {epoch+1}/{self.num_epochs} | "
                f"Active: [{active_str}] | "
                f"Loss: {total_loss/len(self.train_loader):.4f} | "
                f"F1: {f1_macro:.4f}"
            )

            if f1_macro > best_f1:
                best_f1 = f1_macro
                torch.save(self.model.state_dict(), "best_sentiment_model.pt")

        return best_f1

    def _evaluate(self, device) -> dict:
        """Evaluate with optimized per-label thresholds."""
        self.model.eval()
        all_logits, all_labels = [], []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                logits = self.model(batch["input_ids"], batch["attention_mask"])
                all_logits.append(logits.cpu())
                all_labels.append(batch["labels"].cpu())

        all_logits = torch.cat(all_logits)
        all_labels = torch.cat(all_labels)
        probs = torch.sigmoid(all_logits)

        # Apply default threshold
        preds = (probs > 0.5).int().numpy()
        labels = all_labels.int().numpy()

        p, r, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")

        per_label = {}
        for i, label in enumerate(LABELS):
            lp, lr, lf, _ = precision_recall_fscore_support(
                labels[:, i], preds[:, i], average="binary"
            )
            per_label[label] = {"precision": lp, "recall": lr, "f1": lf}

        return {"f1_macro": f1, "precision": p, "recall": r, "per_label": per_label}

Threshold Optimization

from sklearn.metrics import fbeta_score


def optimize_thresholds(
    probs: np.ndarray,    # (N, L)
    labels: np.ndarray,   # (N, L)
    betas: dict[str, float] | None = None,
) -> dict[str, float]:
    """
    Find optimal per-label thresholds using F-beta search.

    betas: {label_name: beta_value}. Higher beta = more recall-focused.
    """
    if betas is None:
        betas = {
            "frustration": 2.0,    # Recall-focused: catch frustrated users
            "satisfaction": 0.5,    # Precision-focused: avoid false positives
            "urgency": 2.0,        # Recall-focused
            "confusion": 1.5,      # Slightly recall-focused
            "churn_risk": 2.0,     # Recall-focused: never miss churn risk
            "neutral": 1.0,        # Balanced
        }

    optimal_thresholds = {}

    for i, label_name in enumerate(LABELS):
        beta = betas[label_name]
        best_score = 0
        best_threshold = 0.5

        for threshold in np.arange(0.1, 0.91, 0.05):
            preds = (probs[:, i] > threshold).astype(int)
            score = fbeta_score(labels[:, i], preds, beta=beta)

            if score > best_score:
                best_score = score
                best_threshold = threshold

        optimal_thresholds[label_name] = {
            "threshold": round(best_threshold, 2),
            "fbeta_score": round(best_score, 4),
            "beta": beta,
        }

    return optimal_thresholds


# After training
thresholds = optimize_thresholds(val_probs, val_labels)
# {
#   "frustration": {"threshold": 0.35, "fbeta_score": 0.82, "beta": 2.0},
#   "churn_risk":  {"threshold": 0.28, "fbeta_score": 0.74, "beta": 2.0},
#   ...
# }

SageMaker Training Pipeline

import sagemaker
from sagemaker.huggingface import HuggingFace


def launch_sentiment_training(
    train_s3: str,
    val_s3: str,
    instance_type: str = "ml.g4dn.xlarge",
):
    """Launch multi-label sentiment fine-tuning on SageMaker."""
    hyperparameters = {
        "model_name": "distilbert-base-uncased",
        "num_epochs": 8,
        "batch_size": 32,
        "base_lr": 3e-5,
        "decay_factor": 2.6,
        "gamma": 2.0,
        "label_weights": "0.4,1.0,1.2,1.5,2.57,6.0",
        "gradual_unfreezing": True,
    }

    estimator = HuggingFace(
        entry_point="train_sentiment.py",
        source_dir="./src/sentiment/",
        instance_type=instance_type,
        instance_count=1,
        role=sagemaker.get_execution_role(),
        transformers_version="4.37",
        pytorch_version="2.1",
        py_version="py310",
        hyperparameters=hyperparameters,
    )

    estimator.fit({
        "train": train_s3,
        "validation": val_s3,
    })

    return estimator

Group Discussion: Key Decision Points

Decision Point 1: Multi-Label vs Multi-Class

Priya (ML Engineer): The fundamental question: should we use multi-class (one label per message) or multi-label (any combination)?

I tested both on our 5,000-message labeled dataset:

| Approach | Exact Match | Hamming Loss | Frustration Recall |
|---|---|---|---|
| Multi-class (6 exclusive) | 68.4% | n/a | 74.2% |
| Multi-class (21 combinations) | 52.1% | n/a | 71.8% |
| Multi-label BCE | 76.2% | 0.089 | 86.1% |
| Multi-label Focal BCE | 74.8% | 0.082 | 89.7% |

Aiko (Data Scientist): 23% of our messages have 2+ labels. Multi-class with combinations creates a sparse classification problem (21 classes, some with <50 examples). Multi-label with independent per-label predictions is both simpler and more accurate.

Marcus (Architect): Multi-label also allows independent threshold tuning per label. We want aggressive frustration detection (low threshold, high recall) but conservative satisfaction attribution (high threshold, high precision). Multi-class makes this impossible.

Sam (PM): One concern: the escalation trigger uses frustration + churn_risk. With multi-label, do we OR them (either triggers escalation) or AND them (both needed)?

Jordan (MLOps): We use a priority-weighted OR. If frustration > 0.35 OR churn_risk > 0.28, escalate. But the urgency of escalation increases if both are active (different SLA for the human agent response).

Resolution: Multi-label with focal BCE. Independent thresholds per label, priority-weighted OR for escalation triggers. This achieves 89.7% frustration recall vs 74.2% with multi-class.
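
The priority-weighted OR described above can be sketched as a small gate; the thresholds are the optimized values from earlier, while the tier names are illustrative rather than production values:

```python
def escalation(probs, t_frustration=0.35, t_churn=0.28):
    # thresholds are the optimized per-label values from the table above
    frustrated = probs.get("frustration", 0.0) > t_frustration
    churning = probs.get("churn_risk", 0.0) > t_churn
    if frustrated and churning:
        return "escalate-urgent"  # tighter human-response SLA when both fire
    if frustrated or churning:
        return "escalate"
    return "no-escalation"

print(escalation({"frustration": 0.89, "churn_risk": 0.31}))  # escalate-urgent
print(escalation({"frustration": 0.10, "churn_risk": 0.05}))  # no-escalation
```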

Decision Point 2: Gradual Unfreezing vs Full Fine-Tune

Priya (ML Engineer): I compared unfreezing strategies:

| Strategy | Val F1 (macro) | Frustration F1 | Old-task retention | Training time |
|---|---|---|---|---|
| Freeze all, head only | 0.71 | 0.68 | 100% | 12 min |
| Full fine-tune, uniform LR | 0.79 | 0.81 | 89% | 25 min |
| Full fine-tune, discriminative LR | 0.82 | 0.84 | 94% | 25 min |
| Gradual unfreezing + disc. LR | 0.84 | 0.87 | 97% | 35 min |

Aiko (Data Scientist): Gradual unfreezing wins on all metrics. The 97% old-task retention means we preserve the encoder's general language understanding while adding manga-specific sentiment features.

The mathematical intuition: lower layers capture universal syntactic features (negation, intensifiers like "never" or "absolutely"). These features are already useful for sentiment — they just need slight domain adaptation, hence the very low learning rate. Upper layers need significant restructuring to map manga-specific frustration signals (e.g., "volume 14 still not available") to sentiment labels.

Marcus (Architect): The 10 extra minutes of training is worth 97% vs 89% old-task retention. If we lose general sentiment understanding, the model fails on non-manga queries from users who also buy other products.

Resolution: Gradual unfreezing with discriminative LRs. The 35 min training time is acceptable for weekly retrains. For emergency updates (e.g., major product issue), we fall back to head-only training (12 min, 0.71 F1) and schedule a full retrain.

Decision Point 3: Focal Loss Hyperparameters

Aiko (Data Scientist): $\gamma$ and $\alpha$ interact in non-obvious ways:

| $\gamma$ | $\alpha$ weighting | Frustration F1 | churn_risk F1 | Macro F1 |
|---|---|---|---|---|
| 0 (standard BCE) | None | 0.72 | 0.53 | 0.71 |
| 0 | Frequency-inverse | 0.76 | 0.62 | 0.75 |
| 2 | None | 0.79 | 0.61 | 0.78 |
| 2 | Frequency-inverse | 0.87 | 0.69 | 0.84 |
| 3 | Frequency-inverse | 0.85 | 0.72 | 0.83 |
| 5 | Frequency-inverse | 0.78 | 0.73 | 0.79 |

Priya (ML Engineer): $\gamma=2$ with $\alpha$ weighting gives the best macro F1. Increasing $\gamma$ to 5 helps churn_risk (+4%) but hurts frustration (-9%) — the model over-focuses on the hardest rare examples and underfits common patterns.

Sam (PM): churn_risk F1 of 0.69 concerns me. Can we use a separate model for churn detection?

Aiko (Data Scientist): The 0.69 F1 is actually good for a 3% base rate label. With the optimized threshold (0.28), recall is 81% — we catch 4 out of 5 churn-risk messages. Perfect recall is impossible because some churn signals require conversation history, not just the current message.

Resolution: $\gamma=2$, $\alpha$ = frequency-inverse weights. For churn_risk specifically, we augment the model signal with a rule-based detector (keywords: "cancel", "never again", "done with") to achieve combined recall of 91%.
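
A sketch of the combined detector, ORing the thresholded model score with the keywords quoted above (the production keyword list may be longer):

```python
CHURN_PHRASES = ("cancel", "never again", "done with")

def churn_flag(text, model_prob, threshold=0.28):
    # OR of the thresholded model score and a simple keyword rule
    rule_hit = any(phrase in text.lower() for phrase in CHURN_PHRASES)
    return model_prob > threshold or rule_hit

print(churn_flag("I'm done with this store", model_prob=0.11))  # True (rule)
print(churn_flag("Where is volume 5?", model_prob=0.05))        # False
```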

Decision Point 4: Threshold Strategy

Jordan (MLOps): Per-label thresholds add complexity. Are they truly worth it?

Priya (ML Engineer): Absolutely. The threshold optimization table shows churn_risk improves from 0.53 to 0.69 F2 by lowering the threshold from 0.50 to 0.28. This is a 30% jump from a simple post-processing step that costs nothing in inference latency.

Marcus (Architect): Thresholds should be stored in DynamoDB, not hardcoded. This way our escalation team can adjust sensitivity without model retraining. Lowering the frustration threshold from 0.35 to 0.30 during a known outage (lots of frustrated users) makes operational sense.

Sam (PM): Agreed. Threshold tuning should be an operational control, not an ML control. The ML team delivers the model; the operations team tunes the thresholds based on false-positive rates they observe.

Resolution: Per-label thresholds stored in DynamoDB with a 10-second ElastiCache TTL. Operations team can adjust thresholds via an admin dashboard. ML team provides recommended thresholds; operations team overrides based on business context.
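
A simplified, testable stand-in for this lookup, with the fetch function injected; in production the fetcher would wrap a DynamoDB `get_item` call behind the ElastiCache layer, and all names here are assumptions:

```python
import time

class ThresholdCache:
    """Serve per-label thresholds, refetching after a short TTL."""

    def __init__(self, fetch_fn, ttl_seconds=10.0):
        self.fetch_fn = fetch_fn  # in production: a DynamoDB get_item wrapper
        self.ttl = ttl_seconds
        self._value = None
        self._fetched_at = -float("inf")

    def get(self):
        now = time.monotonic()
        if now - self._fetched_at > self.ttl:
            self._value = self.fetch_fn()
            self._fetched_at = now
        return self._value

cache = ThresholdCache(lambda: {"frustration": 0.35, "churn_risk": 0.28})
print(cache.get()["frustration"])  # 0.35
```

Injecting the fetcher keeps the cache logic unit-testable and lets operations swap threshold sources without touching inference code.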


Research Paper References

1. Universal Language Model Fine-tuning for Text Classification — ULMFiT (Howard & Ruder, 2018)

Key contribution: Introduced three techniques that now underpin all transformer fine-tuning: (1) discriminative fine-tuning (different LRs per layer), (2) slanted triangular learning rate schedule, (3) gradual unfreezing. Showed that these techniques together prevent catastrophic forgetting during fine-tuning and achieve state-of-the-art on 6 text classification benchmarks.

Relevance to MangaAssist: Our gradual unfreezing schedule directly implements ULMFiT's approach adapted for DistilBERT's 6-layer architecture. The discriminative LR with decay factor 2.6 is from ULMFiT's optimized hyperparameters. Gradual unfreezing improved our old-task retention from 89% to 97%.

2. Focal Loss for Dense Object Detection (Lin et al., 2017)

Key contribution: Originally designed for object detection (addressing the extreme foreground-background imbalance), focal loss down-weights easy examples and focuses training on hard negatives. The $(1-p_t)^\gamma$ modulation is elegant: it is a smooth, differentiable, and parameter-efficient way to reshape the loss landscape.

Relevance to MangaAssist: Adapted for multi-label BCE to handle our 15:1 imbalance between neutral (45%) and churn_risk (3%). Focal loss with $\gamma=2$ improved churn_risk F1 from 0.53 to 0.69, which directly impacts our ability to detect and escalate at-risk customers.

3. Multi-Task Learning Using Uncertainty to Weigh Losses (Kendall et al., 2018)

Key contribution: Used homoscedastic uncertainty to learn relative task weights automatically. The key insight: the optimal weight for a task's loss is inversely proportional to the task's observation noise. Noisier tasks get lower weights, preventing them from dominating the gradient.

Relevance to MangaAssist: Although we use fixed per-label weights in our current implementation, Kendall's uncertainty weighting is the natural extension. If we observe that certain labels have inherently noisy annotations (human annotators disagree on "confusion" vs "neutral" 30% of the time), learned uncertainty weights would down-weight the noisy labels automatically. This is planned for V2.

4. Asymmetric Loss For Multi-Label Classification (Ridnik et al., 2021)

Key contribution: Extended focal loss with asymmetric focusing parameters: $\gamma^+$ for positive samples (typically 0) and $\gamma^-$ for negative samples (typically 4). This addresses the specific challenge of multi-label: most labels are negative for most examples, creating massive negative:positive imbalance per label.

Relevance to MangaAssist: Our frustration label is positive in only 18% of messages. Asymmetric loss with $\gamma^-=4, \gamma^+=1$ further improved frustration recall by 2.3% in our ablation. Planned for production after the next retrain cycle.
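
A minimal sketch of the asymmetric focusing idea, omitting the probability-shifting/clipping term from the paper; setting both exponents to 0 recovers plain BCE:

```python
import math

def asymmetric_bce(probs, targets, gamma_pos=1.0, gamma_neg=4.0):
    # per-label asymmetric focusing; gamma_pos = gamma_neg = 0 is plain BCE
    total = 0.0
    for p, y in zip(probs, targets):
        if y == 1:
            total -= ((1 - p) ** gamma_pos) * math.log(p)
        else:
            total -= (p ** gamma_neg) * math.log(1 - p)
    return total / len(probs)

plain = asymmetric_bce([0.9, 0.1], [1, 0], gamma_pos=0.0, gamma_neg=0.0)
asym = asymmetric_bce([0.9, 0.1], [1, 0])
print(asym < plain)  # True: easy negatives are heavily down-weighted
```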


Production Evaluation and Deployment

Ablation Study Results

| Configuration | Macro F1 | Frustration Recall | churn_risk F1 | Latency |
|---|---|---|---|---|
| SST-2 pretrained (baseline) | 0.54 | 62.4% | 0.18 | 8ms |
| + Head-only training | 0.71 | 74.2% | 0.48 | 8ms |
| + Full fine-tune, uniform LR | 0.79 | 81.3% | 0.58 | 8ms |
| + Discriminative LR | 0.82 | 84.1% | 0.62 | 8ms |
| + Gradual unfreezing | 0.84 | 87.3% | 0.64 | 8ms |
| + Focal loss (γ=2) | 0.84 | 89.7% | 0.69 | 8ms |
| + Threshold optimization | 0.86 | 91.2% | 0.74 | 8ms |
| + Rule boost for churn_risk | n/a | 91.2% | 0.79 | 9ms |

Each technique contributes meaningfully. The full stack achieves 91.2% frustration recall (from 62.4% baseline) with no latency increase.

Escalation Impact

| Metric | Before (SST-2) | After (Manga-Tuned) | Change |
|---|---|---|---|
| Frustrated users escalated | 62.4% | 91.2% | +28.8% |
| False escalations per day | 12 | 28 | +16 |
| Average resolution time | 8.2 min | 5.1 min | -37.8% |
| Customer satisfaction (escalated) | 3.2 / 5 | 4.1 / 5 | +28.1% |
| Monthly churn rate | 4.8% | 3.1% | -35.4% |

The 16 extra false escalations per day cost about $32/day (roughly $960/month) in human agent time. The 1.7-point monthly churn reduction saves roughly $12K/month in customer lifetime value, an ROI of about 12:1.