08. Sentiment Classifier Fine-Tuning — Frustration Detection for Escalation
Problem Statement and MangaAssist Context
MangaAssist needs a sentiment classifier that detects frustrated customers in real time so the system can escalate to a human agent before the customer abandons. The default DistilBERT sentiment model (fine-tuned on SST-2) classifies positive/negative but misses manga-domain frustration signals: "I've been waiting 3 weeks for volume 14 and still nothing" reads as neutral to SST-2, but any manga buyer recognizes the frustration. This document covers fine-tuning DistilBERT for multi-label sentiment on manga customer queries, including gradual unfreezing, multi-label classification with BCE loss, threshold optimization, and the unique challenge of domain-specific emotional language.
Current Sentiment Performance
| Metric | SST-2 Pretrained | Target (Manga Fine-Tuned) |
|---|---|---|
| Frustration recall | 62.4% | ≥ 90% |
| Satisfaction precision | 78.1% | ≥ 85% |
| Multi-label accuracy (exact match) | 41.2% | ≥ 75% |
| Latency (P50) | 8ms | ≤ 12ms |
Why Multi-Label?
A single message can express multiple sentiments:
- "The manga arrived damaged — but the story is amazing" → frustration + satisfaction
- "I love this series but when will volume 5 restock?" → satisfaction + urgency
- "Never buying from here again" → frustration + churn_risk
Labels: frustration, satisfaction, urgency, confusion, churn_risk, neutral.
Mathematical Foundations
Binary Cross-Entropy for Multi-Label Classification
In multi-class, we use softmax + cross-entropy (labels are mutually exclusive). In multi-label, each label is an independent binary prediction — we use sigmoid + Binary Cross-Entropy (BCE):
$$\mathcal{L}_{\text{BCE}} = -\frac{1}{L} \sum_{l=1}^{L} \left[ y_l \log(\sigma(z_l)) + (1 - y_l) \log(1 - \sigma(z_l)) \right]$$
where:
- $L = 6$ (number of sentiment labels)
- $y_l \in \{0, 1\}$ is the ground truth for label $l$
- $z_l$ is the logit (raw model output) for label $l$
- $\sigma(z) = \frac{1}{1 + e^{-z}}$ is the sigmoid function
Why BCE, not cross-entropy? Cross-entropy forces a probability distribution (sums to 1 via softmax). BCE treats each label independently — sigmoid outputs are independent probabilities in $[0, 1]$ that need not sum to 1. A message can be 90% frustration AND 85% urgency simultaneously.
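A quick check of that independence claim, using hypothetical logits (in LABELS order: frustration, satisfaction, urgency, confusion, churn_risk, neutral) for the damaged-shipment example:

import torch

logits = torch.tensor([2.2, 1.7, -2.0, -2.4, -0.8, -2.9])  # hypothetical logits for one message

softmax_probs = torch.softmax(logits, dim=0)  # forced to compete: probabilities sum to 1
sigmoid_probs = torch.sigmoid(logits)         # independent: each probability lives in [0, 1]

print(round(softmax_probs.sum().item(), 4))   # 1.0 — one label steals mass from the others
print(sigmoid_probs[:2])                      # tensor([0.9002, 0.8455]) — frustration AND satisfaction both high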
Gradient Analysis of BCE
The gradient of BCE with respect to logit $z_l$:
$$\frac{\partial \mathcal{L}}{\partial z_l} = \frac{1}{L} (\sigma(z_l) - y_l)$$
This is clean and intuitive:
- If $y_l = 1$ (label is active): gradient = $\sigma(z_l) - 1 < 0$ → pushes the logit up
- If $y_l = 0$ (label inactive): gradient = $\sigma(z_l) - 0 > 0$ → pushes the logit down
- Gradient magnitude = prediction error for that label
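The closed form is easy to verify against autograd, since PyTorch's built-in BCE-with-logits averages over the $L$ labels:

import torch
import torch.nn.functional as F

L = 6
z = torch.randn(L, requires_grad=True)
y = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0])

F.binary_cross_entropy_with_logits(z, y).backward()  # default 'mean' reduction over L labels

analytic = (torch.sigmoid(z) - y) / L
assert torch.allclose(z.grad, analytic, atol=1e-6)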
Class imbalance problem: In our dataset, "neutral" appears in 45% of messages, "frustration" in 18%, "churn_risk" in 3%. BCE weights false positives and false negatives equally. A model that predicts "never churn_risk" achieves 97% label accuracy for that label — but misses every actual churn risk.
Focal Loss for Imbalanced Multi-Label
To handle class imbalance, we use focal loss (Lin et al., 2017) adapted for BCE:
$$\mathcal{L}_{\text{focal}} = -\frac{1}{L} \sum_{l=1}^{L} \alpha_l \left[ y_l (1-p_l)^\gamma \log(p_l) + (1-y_l) p_l^\gamma \log(1-p_l) \right]$$
where $p_l = \sigma(z_l)$, $\gamma$ is the focusing parameter, and $\alpha_l$ is the class weight.
How $\gamma$ works:
- When $\gamma = 0$: standard BCE
- When $\gamma = 2$ (common): if $p_l = 0.9$ (confident correct), the $(1-0.9)^2 = 0.01$ scaling reduces the loss by 100×. Rare but important errors get amplified.
| $\gamma$ | Easy example ($p=0.95$) | Hard example ($p=0.3$) | Focus ratio |
|---|---|---|---|
| 0 | 0.051 | 1.204 | 23.4× |
| 1 | 0.003 | 0.843 | 337× |
| 2 | 0.0001 | 0.590 | 4,712× |
At $\gamma = 2$, the model focuses 4,712× more on hard examples relative to easy ones.
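The table can be reproduced from the positive-label focal term $(1-p)^\gamma \cdot (-\log p)$; small differences from the printed ratios are rounding:

import math

def focal_term(p: float, gamma: float) -> float:
    # loss contribution of a correctly-classified positive label with confidence p
    return (1 - p) ** gamma * -math.log(p)

for gamma in (0, 1, 2):
    easy, hard = focal_term(0.95, gamma), focal_term(0.30, gamma)
    print(f"gamma={gamma}: easy={easy:.4f}  hard={hard:.3f}  ratio={hard / easy:,.0f}x")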
Per-label weights $\alpha_l$:
$$\alpha_l = \frac{f_{\text{ref}}}{f_l}$$
where $f_l$ is the frequency of label $l$ and the reference frequency $f_{\text{ref}} = 18\%$ is frustration's (the label closest to the median), so frustration anchors at $\alpha = 1$:
| Label | Frequency | $\alpha_l$ |
|---|---|---|
| neutral | 45% | 0.40 |
| frustration | 18% | 1.00 (reference) |
| satisfaction | 15% | 1.20 |
| urgency | 12% | 1.50 |
| confusion | 7% | 2.57 |
| churn_risk | 3% | 6.00 |
churn_risk gets 6× frustration's gradient weight and 15× neutral's, exactly offsetting its 15× lower frequency relative to neutral.
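The weight column reduces to a two-line computation:

FREQS = {"neutral": 0.45, "frustration": 0.18, "satisfaction": 0.15,
         "urgency": 0.12, "confusion": 0.07, "churn_risk": 0.03}

reference = FREQS["frustration"]  # the 18% reference frequency from the formula above
alpha = {label: round(reference / freq, 2) for label, freq in FREQS.items()}
print(alpha)  # {'neutral': 0.4, 'frustration': 1.0, 'satisfaction': 1.2, 'urgency': 1.5, 'confusion': 2.57, 'churn_risk': 6.0}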
Gradual Unfreezing — ULMFiT Strategy
Howard & Ruder (2018) showed that unfreezing layers gradually — from top to bottom — prevents catastrophic forgetting of pretrained features while allowing domain adaptation. ULMFiT demonstrated this on LSTM language models; the strategy transfers directly to transformers.
DistilBERT has 6 transformer layers:
| Layer | Captures | When to Unfreeze | Learning Rate |
|---|---|---|---|
| Embeddings | Token/positional encoding | Epoch 4+ | $\eta_{\text{base}} / 2.6^4$ |
| Layer 0-1 | Basic syntax | Epoch 4+ | $\eta_{\text{base}} / 2.6^3$ |
| Layer 2-3 | Phrase-level patterns, negation | Epoch 3 | $\eta_{\text{base}} / 2.6^2$ |
| Layer 4-5 | Sentence-level semantics | Epoch 2 | $\eta_{\text{base}} / 2.6$ |
| Classification head | Sentiment-specific features | Epoch 1 | $\eta_{\text{base}}$ |
Discriminative learning rates: Each layer group gets a progressively lower learning rate. With $\eta_{\text{base}} = 3 \times 10^{-5}$ and decay factor 2.6:
| Layer Group | Learning Rate | Relative |
|---|---|---|
| Classification head | $3.0 \times 10^{-5}$ | 1.0× |
| Layer 4-5 | $1.15 \times 10^{-5}$ | 0.38× |
| Layer 2-3 | $4.44 \times 10^{-6}$ | 0.15× |
| Layer 0-1 | $1.71 \times 10^{-6}$ | 0.057× |
| Embeddings | $6.57 \times 10^{-7}$ | 0.022× |
Why 2.6? In ULMFiT, Howard & Ruder found 2.6 empirically optimal across tasks. The intuition: lower layers capture universal features (syntax, morphology) that should change minimally. Higher layers capture task-specific features that must adapt significantly.
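The full learning-rate table falls out of one loop:

base_lr, decay = 3e-5, 2.6
for depth, group in enumerate(["head", "layer_4_5", "layer_2_3", "layer_0_1", "embeddings"]):
    lr = base_lr / decay ** depth
    print(f"{group:12s} lr={lr:.2e}  relative={lr / base_lr:.3f}x")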
Threshold Optimization
Unlike multi-class (argmax), multi-label classification requires choosing a threshold per label. The default 0.5 is rarely optimal.
Per-label threshold search:
For each label $l$, sweep thresholds $t \in [0.1, 0.9]$ and optimize a metric:
$$t_l^* = \arg\max_t F_\beta(l, t)$$
where $F_\beta$ balances precision and recall:
$$F_\beta = (1 + \beta^2) \frac{\text{precision} \cdot \text{recall}}{\beta^2 \cdot \text{precision} + \text{recall}}$$
For frustration and churn_risk, we use $\beta = 2$ (recall-focused — better to false-positive than miss a frustrated customer). For satisfaction, $\beta = 0.5$ (precision-focused — avoid false attribution of happiness).
Optimized thresholds (from our validation set):
| Label | Default (0.5) F1 | Optimal $t^*$ | Optimized F1 | Change |
|---|---|---|---|---|
| frustration | 0.72 | 0.35 | 0.81 | +12.5% |
| satisfaction | 0.74 | 0.55 | 0.77 | +4.1% |
| urgency | 0.68 | 0.40 | 0.76 | +11.8% |
| confusion | 0.61 | 0.32 | 0.73 | +19.7% |
| churn_risk | 0.53 | 0.28 | 0.69 | +30.2% |
| neutral | 0.82 | 0.48 | 0.83 | +1.2% |
Lower thresholds for rare labels (churn_risk: 0.28) dramatically improve recall at the cost of some precision — exactly the tradeoff we want for escalation-critical signals.
Model Internals — Layer-by-Layer Diagrams
Gradual Unfreezing Schedule
graph TB
subgraph "Epoch 1: Only head trainable"
E1_EMB["Embeddings ❄️"]
E1_L01["Layers 0-1 ❄️"]
E1_L23["Layers 2-3 ❄️"]
E1_L45["Layers 4-5 ❄️"]
E1_HEAD["Classification Head 🔥<br>lr = 3e-5"]
E1_EMB --> E1_L01 --> E1_L23 --> E1_L45 --> E1_HEAD
end
subgraph "Epoch 2: Top layers unfrozen"
E2_EMB["Embeddings ❄️"]
E2_L01["Layers 0-1 ❄️"]
E2_L23["Layers 2-3 ❄️"]
E2_L45["Layers 4-5 🔥<br>lr = 1.15e-5"]
E2_HEAD["Classification Head 🔥<br>lr = 3e-5"]
E2_EMB --> E2_L01 --> E2_L23 --> E2_L45 --> E2_HEAD
end
subgraph "Epoch 3: Middle layers unfrozen"
E3_EMB["Embeddings ❄️"]
E3_L01["Layers 0-1 ❄️"]
E3_L23["Layers 2-3 🔥<br>lr = 4.4e-6"]
E3_L45["Layers 4-5 🔥<br>lr = 1.15e-5"]
E3_HEAD["Classification Head 🔥<br>lr = 3e-5"]
E3_EMB --> E3_L01 --> E3_L23 --> E3_L45 --> E3_HEAD
end
subgraph "Epoch 4+: All layers unfrozen"
E4_EMB["Embeddings 🔥<br>lr = 6.6e-7"]
E4_L01["Layers 0-1 🔥<br>lr = 1.7e-6"]
E4_L23["Layers 2-3 🔥<br>lr = 4.4e-6"]
E4_L45["Layers 4-5 🔥<br>lr = 1.15e-5"]
E4_HEAD["Classification Head 🔥<br>lr = 3e-5"]
E4_EMB --> E4_L01 --> E4_L23 --> E4_L45 --> E4_HEAD
end
style E1_HEAD fill:#ffcdd2
style E2_L45 fill:#fff9c4
style E2_HEAD fill:#ffcdd2
style E3_L23 fill:#c8e6c9
style E3_L45 fill:#fff9c4
style E3_HEAD fill:#ffcdd2
style E4_EMB fill:#e1bee7
style E4_L01 fill:#e1bee7
style E4_L23 fill:#c8e6c9
style E4_L45 fill:#fff9c4
style E4_HEAD fill:#ffcdd2
Gradient Magnitude Heatmap Across Layers
graph LR
subgraph "Gradient Magnitude During Training"
subgraph "Epoch 1 (head only)"
G1_E["Embed: 0.000"]
G1_L0["L0: 0.000"]
G1_L2["L2: 0.000"]
G1_L4["L4: 0.000"]
G1_H["Head: 0.042"]
end
subgraph "Epoch 2 (+ top layers)"
G2_E["Embed: 0.000"]
G2_L0["L0: 0.000"]
G2_L2["L2: 0.000"]
G2_L4["L4: 0.008"]
G2_H["Head: 0.031"]
end
subgraph "Epoch 4 (all layers)"
G4_E["Embed: 0.001"]
G4_L0["L0: 0.002"]
G4_L2["L2: 0.005"]
G4_L4["L4: 0.012"]
G4_H["Head: 0.018"]
end
end
style G1_H fill:#ff5252
style G2_L4 fill:#ffab40
style G2_H fill:#ff5252
style G4_E fill:#e8eaf6
style G4_L0 fill:#c5cae9
style G4_L2 fill:#fff9c4
style G4_L4 fill:#ffab40
style G4_H fill:#ff5252
Multi-Label Classification Architecture
graph TB
INPUT["Input: 'The manga arrived damaged<br>but the story is amazing'"]
subgraph "DistilBERT Encoder (66M params)"
TOK["Tokenizer → [CLS] The manga arrived damaged ..."]
EMB["Embedding Layer (23M)"]
L0["Transformer Layer 0"]
L1["Transformer Layer 1"]
L2["Transformer Layer 2"]
L3["Transformer Layer 3"]
L4["Transformer Layer 4"]
L5["Transformer Layer 5"]
CLS["[CLS] representation ∈ ℝ⁷⁶⁸"]
end
subgraph "Classification Head (multi-label)"
DROP["Dropout (0.2)"]
FC1["Linear: 768 → 256 + ReLU"]
FC2["Linear: 256 → 6 (one per label)"]
SIG["Sigmoid (independent per label)"]
end
subgraph "Output Probabilities"
O1["frustration: 0.89 ⬆️"]
O2["satisfaction: 0.74 ⬆️"]
O3["urgency: 0.12"]
O4["confusion: 0.08"]
O5["churn_risk: 0.31"]
O6["neutral: 0.05"]
end
subgraph "Threshold Gate"
T1["frustration: 0.89 > 0.35 ✅"]
T2["satisfaction: 0.74 > 0.55 ✅"]
T3["urgency: 0.12 < 0.40 ❌"]
T4["confusion: 0.08 < 0.32 ❌"]
T5["churn_risk: 0.31 > 0.28 ✅"]
T6["neutral: 0.05 < 0.48 ❌"]
end
INPUT --> TOK --> EMB --> L0 --> L1 --> L2 --> L3 --> L4 --> L5 --> CLS
CLS --> DROP --> FC1 --> FC2 --> SIG
SIG --> O1 & O2 & O3 & O4 & O5 & O6
O1 --> T1
O2 --> T2
O3 --> T3
O4 --> T4
O5 --> T5
O6 --> T6
style O1 fill:#ffcdd2
style O2 fill:#c8e6c9
style O5 fill:#fff9c4
style T1 fill:#ffcdd2
style T2 fill:#c8e6c9
style T5 fill:#fff9c4
Focal Loss Effect on Gradient Distribution
graph TB
subgraph "Standard BCE: All examples weighted equally"
BCE_EASY["Easy examples (p > 0.8)<br>70% of data<br>Total gradient share: 70%"]
BCE_HARD["Hard examples (p < 0.5)<br>15% of data<br>Total gradient share: 15%"]
BCE_RARE["Rare class hard examples<br>3% of data (churn_risk)<br>Total gradient share: 3%"]
end
subgraph "Focal Loss (γ=2): Hard examples amplified"
FL_EASY["Easy examples (p > 0.8)<br>70% of data<br>Gradient share: 12% ⬇️"]
FL_HARD["Hard examples (p < 0.5)<br>15% of data<br>Gradient share: 48% ⬆️"]
FL_RARE["Rare class hard examples<br>3% of data (churn_risk)<br>Gradient share: 40% ⬆️"]
end
BCE_EASY -->|"Focal<br>reweighting"| FL_EASY
BCE_HARD -->|"γ=2"| FL_HARD
BCE_RARE -->|"γ=2 + α=6.0"| FL_RARE
style BCE_RARE fill:#ffcdd2
style FL_RARE fill:#c8e6c9
style FL_HARD fill:#fff9c4
Threshold Optimization Visualization
graph TD
subgraph "Per-Label Threshold Search"
S["sweep t ∈ [0.1, 0.9]<br>step = 0.05"]
subgraph "frustration (β=2, recall-focused)"
F1["t=0.50 → P=0.82 R=0.64 F2=0.67"]
F2["t=0.35 → P=0.71 R=0.86 F2=0.82 ◀ OPTIMAL"]
F3["t=0.20 → P=0.58 R=0.93 F2=0.83"]
end
subgraph "churn_risk (β=2, recall-focused)"
C1["t=0.50 → P=0.78 R=0.42 F2=0.46"]
C2["t=0.28 → P=0.55 R=0.81 F2=0.74 ◀ OPTIMAL"]
C3["t=0.15 → P=0.32 R=0.91 F2=0.67"]
end
subgraph "satisfaction (β=0.5, precision-focused)"
SA1["t=0.40 → P=0.69 R=0.85 F0.5=0.71"]
SA2["t=0.55 → P=0.83 R=0.72 F0.5=0.81 ◀ OPTIMAL"]
SA3["t=0.70 → P=0.91 R=0.54 F0.5=0.82"]
end
end
S --> F1 & F2 & F3
S --> C1 & C2 & C3
S --> SA1 & SA2 & SA3
style F2 fill:#c8e6c9
style C2 fill:#c8e6c9
style SA2 fill:#c8e6c9
Implementation Deep-Dive
Multi-Label DistilBERT with Gradual Unfreezing
import torch
import torch.nn as nn
from transformers import (
DistilBertModel,
DistilBertTokenizer,
get_linear_schedule_with_warmup,
)
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import f1_score, precision_recall_fscore_support
LABELS = ["frustration", "satisfaction", "urgency", "confusion", "churn_risk", "neutral"]
NUM_LABELS = len(LABELS)
class MultiLabelSentimentModel(nn.Module):
"""Multi-label sentiment classifier with per-label sigmoid outputs."""
def __init__(self, model_name: str = "distilbert-base-uncased", hidden_dim: int = 256):
super().__init__()
self.encoder = DistilBertModel.from_pretrained(model_name)
self.head = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(768, hidden_dim),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_dim, NUM_LABELS),
)
def forward(self, input_ids, attention_mask):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
cls_emb = outputs.last_hidden_state[:, 0, :]
logits = self.head(cls_emb)
return logits # (batch, NUM_LABELS) — raw logits, apply sigmoid externally
class FocalBCELoss(nn.Module):
"""Focal Binary Cross-Entropy for multi-label with class imbalance."""
    def __init__(self, gamma: float = 2.0, alpha: list[float] | None = None):
super().__init__()
self.gamma = gamma
if alpha is not None:
self.alpha = torch.tensor(alpha, dtype=torch.float32)
else:
self.alpha = None
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
probs = torch.sigmoid(logits)
# Standard BCE components
pos_loss = targets * torch.log(probs + 1e-8)
neg_loss = (1 - targets) * torch.log(1 - probs + 1e-8)
# Focal modulation
pos_focal = (1 - probs) ** self.gamma
neg_focal = probs ** self.gamma
loss = -(pos_focal * pos_loss + neg_focal * neg_loss)
# Apply per-label weights
if self.alpha is not None:
alpha = self.alpha.to(logits.device)
loss = loss * alpha.unsqueeze(0)
return loss.mean()
class SentimentDataset(Dataset):
def __init__(self, texts: list[str], labels: list[list[int]], tokenizer, max_length: int = 128):
self.encodings = tokenizer(
texts, padding=True, truncation=True,
max_length=max_length, return_tensors="pt",
)
self.labels = torch.tensor(labels, dtype=torch.float32)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return {
"input_ids": self.encodings["input_ids"][idx],
"attention_mask": self.encodings["attention_mask"][idx],
"labels": self.labels[idx],
}
Gradual Unfreezing Trainer
class GradualUnfreezingTrainer:
"""
Implements ULMFiT-style gradual unfreezing with discriminative LRs.
"""
def __init__(
self,
model: MultiLabelSentimentModel,
train_loader: DataLoader,
val_loader: DataLoader,
base_lr: float = 3e-5,
decay_factor: float = 2.6,
num_epochs: int = 8,
gamma: float = 2.0,
        label_weights: list[float] | None = None,
):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.base_lr = base_lr
self.decay_factor = decay_factor
self.num_epochs = num_epochs
self.criterion = FocalBCELoss(gamma=gamma, alpha=label_weights)
# Layer groups for unfreezing schedule
self.layer_groups = self._get_layer_groups()
def _get_layer_groups(self):
"""Group model parameters by layer depth."""
encoder = self.model.encoder
groups = {
"embeddings": list(encoder.embeddings.parameters()),
"layer_0_1": (
list(encoder.transformer.layer[0].parameters())
+ list(encoder.transformer.layer[1].parameters())
),
"layer_2_3": (
list(encoder.transformer.layer[2].parameters())
+ list(encoder.transformer.layer[3].parameters())
),
"layer_4_5": (
list(encoder.transformer.layer[4].parameters())
+ list(encoder.transformer.layer[5].parameters())
),
"head": list(self.model.head.parameters()),
}
return groups
def _freeze_all_except(self, active_groups: list[str]):
"""Freeze all parameters except those in active groups."""
for name, params in self.layer_groups.items():
requires_grad = name in active_groups
for p in params:
p.requires_grad = requires_grad
def _build_optimizer(self, active_groups: list[str]):
"""Build optimizer with discriminative learning rates."""
lr_multipliers = {
"head": 1.0,
"layer_4_5": 1.0 / self.decay_factor,
"layer_2_3": 1.0 / (self.decay_factor ** 2),
"layer_0_1": 1.0 / (self.decay_factor ** 3),
"embeddings": 1.0 / (self.decay_factor ** 4),
}
param_groups = []
for name in active_groups:
param_groups.append({
"params": self.layer_groups[name],
"lr": self.base_lr * lr_multipliers[name],
})
return torch.optim.AdamW(param_groups, weight_decay=0.01)
def _get_unfreeze_schedule(self) -> dict[int, list[str]]:
"""Define which layer groups are active at each epoch."""
return {
0: ["head"], # Epoch 1
1: ["head", "layer_4_5"], # Epoch 2
2: ["head", "layer_4_5", "layer_2_3"], # Epoch 3
3: ["head", "layer_4_5", "layer_2_3", "layer_0_1", "embeddings"], # Epoch 4+
}
def train(self):
schedule = self._get_unfreeze_schedule()
best_f1 = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(device)
for epoch in range(self.num_epochs):
# Determine active groups
active_key = min(epoch, max(schedule.keys()))
active_groups = schedule[active_key]
self._freeze_all_except(active_groups)
# Rebuild optimizer with current unfreezing
optimizer = self._build_optimizer(active_groups)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=len(self.train_loader) // 10,
num_training_steps=len(self.train_loader),
)
# Training loop
self.model.train()
total_loss = 0
for batch in self.train_loader:
batch = {k: v.to(device) for k, v in batch.items()}
logits = self.model(batch["input_ids"], batch["attention_mask"])
loss = self.criterion(logits, batch["labels"])
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
# Validation
val_metrics = self._evaluate(device)
f1_macro = val_metrics["f1_macro"]
active_str = ", ".join(active_groups)
print(
f"Epoch {epoch+1}/{self.num_epochs} | "
f"Active: [{active_str}] | "
f"Loss: {total_loss/len(self.train_loader):.4f} | "
f"F1: {f1_macro:.4f}"
)
if f1_macro > best_f1:
best_f1 = f1_macro
torch.save(self.model.state_dict(), "best_sentiment_model.pt")
return best_f1
def _evaluate(self, device) -> dict:
"""Evaluate with optimized per-label thresholds."""
self.model.eval()
all_logits, all_labels = [], []
with torch.no_grad():
for batch in self.val_loader:
batch = {k: v.to(device) for k, v in batch.items()}
logits = self.model(batch["input_ids"], batch["attention_mask"])
all_logits.append(logits.cpu())
all_labels.append(batch["labels"].cpu())
all_logits = torch.cat(all_logits)
all_labels = torch.cat(all_labels)
probs = torch.sigmoid(all_logits)
# Apply default threshold
preds = (probs > 0.5).int().numpy()
labels = all_labels.int().numpy()
p, r, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")
per_label = {}
for i, label in enumerate(LABELS):
lp, lr, lf, _ = precision_recall_fscore_support(
labels[:, i], preds[:, i], average="binary"
)
per_label[label] = {"precision": lp, "recall": lr, "f1": lf}
return {"f1_macro": f1, "precision": p, "recall": r, "per_label": per_label}
Threshold Optimization
from sklearn.metrics import fbeta_score
def optimize_thresholds(
probs: np.ndarray, # (N, L)
labels: np.ndarray, # (N, L)
betas: dict[str, float] = None,
) -> dict[str, float]:
"""
Find optimal per-label thresholds using F-beta search.
betas: {label_name: beta_value}. Higher beta = more recall-focused.
"""
if betas is None:
betas = {
"frustration": 2.0, # Recall-focused: catch frustrated users
"satisfaction": 0.5, # Precision-focused: avoid false positives
"urgency": 2.0, # Recall-focused
"confusion": 1.5, # Slightly recall-focused
"churn_risk": 2.0, # Recall-focused: never miss churn risk
"neutral": 1.0, # Balanced
}
optimal_thresholds = {}
for i, label_name in enumerate(LABELS):
beta = betas[label_name]
best_score = 0
best_threshold = 0.5
for threshold in np.arange(0.1, 0.91, 0.05):
preds = (probs[:, i] > threshold).astype(int)
score = fbeta_score(labels[:, i], preds, beta=beta)
if score > best_score:
best_score = score
best_threshold = threshold
optimal_thresholds[label_name] = {
"threshold": round(best_threshold, 2),
"fbeta_score": round(best_score, 4),
"beta": beta,
}
return optimal_thresholds
# After training
thresholds = optimize_thresholds(val_probs, val_labels)
# {
# "frustration": {"threshold": 0.35, "fbeta_score": 0.82, "beta": 2.0},
# "churn_risk": {"threshold": 0.28, "fbeta_score": 0.74, "beta": 2.0},
# ...
# }
SageMaker Training Pipeline
import sagemaker
from sagemaker.huggingface import HuggingFace
def launch_sentiment_training(
train_s3: str,
val_s3: str,
instance_type: str = "ml.g4dn.xlarge",
):
"""Launch multi-label sentiment fine-tuning on SageMaker."""
hyperparameters = {
"model_name": "distilbert-base-uncased",
"num_epochs": 8,
"batch_size": 32,
"base_lr": 3e-5,
"decay_factor": 2.6,
"gamma": 2.0,
"label_weights": "0.4,1.0,1.2,1.5,2.57,6.0",
"gradual_unfreezing": True,
}
estimator = HuggingFace(
entry_point="train_sentiment.py",
source_dir="./src/sentiment/",
instance_type=instance_type,
instance_count=1,
role=sagemaker.get_execution_role(),
transformers_version="4.37",
pytorch_version="2.1",
py_version="py310",
hyperparameters=hyperparameters,
)
estimator.fit({
"train": train_s3,
"validation": val_s3,
})
return estimator
Group Discussion: Key Decision Points
Decision Point 1: Multi-Label vs Multi-Class
Priya (ML Engineer): The fundamental question: should we use multi-class (one label per message) or multi-label (any combination)?
I tested both on our 5,000-message labeled dataset:
| Approach | Exact Match | Hamming Loss | Frustration Recall |
|---|---|---|---|
| Multi-class (6 exclusive) | 68.4% | — | 74.2% |
| Multi-class (21 combinations) | 52.1% | — | 71.8% |
| Multi-label BCE | 76.2% | 0.089 | 86.1% |
| Multi-label Focal BCE | 74.8% | 0.082 | 89.7% |
Aiko (Data Scientist): 23% of our messages have 2+ labels. Multi-class with combinations creates a sparse classification problem (21 classes, some with <50 examples). Multi-label with independent per-label predictions is both simpler and more accurate.
Marcus (Architect): Multi-label also allows independent threshold tuning per label. We want aggressive frustration detection (low threshold, high recall) but conservative satisfaction attribution (high threshold, high precision). Multi-class makes this impossible.
Sam (PM): One concern: the escalation trigger uses frustration + churn_risk. With multi-label, do we OR them (either triggers escalation) or AND them (both needed)?
Jordan (MLOps): We use a priority-weighted OR. If frustration > 0.35 OR churn_risk > 0.28, escalate. But the urgency of escalation increases if both are active (different SLA for the human agent response).
Resolution: Multi-label with focal BCE. Independent thresholds per label, priority-weighted OR for escalation triggers. This achieves 89.7% frustration recall vs 74.2% with multi-class.
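A sketch of that trigger (the tier names and exact structure are illustrative):

def should_escalate(probs: dict[str, float], thresholds: dict[str, float]) -> tuple[bool, str]:
    """Priority-weighted OR: either signal escalates; both together raises priority."""
    frustrated = probs["frustration"] > thresholds["frustration"]  # 0.35 in production
    churning = probs["churn_risk"] > thresholds["churn_risk"]      # 0.28 in production
    if frustrated and churning:
        return True, "high"    # tighter human-agent SLA
    if frustrated or churning:
        return True, "normal"
    return False, "none"

print(should_escalate({"frustration": 0.89, "churn_risk": 0.31},
                      {"frustration": 0.35, "churn_risk": 0.28}))  # (True, 'high')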
Decision Point 2: Gradual Unfreezing vs Full Fine-Tune
Priya (ML Engineer): I compared unfreezing strategies:
| Strategy | Val F1 (macro) | Frustration F1 | Old-task retention | Training time |
|---|---|---|---|---|
| Freeze all, head only | 0.71 | 0.68 | 100% | 12 min |
| Full fine-tune, uniform LR | 0.79 | 0.81 | 89% | 25 min |
| Full fine-tune, discriminative LR | 0.82 | 0.84 | 94% | 25 min |
| Gradual unfreezing + disc. LR | 0.84 | 0.87 | 97% | 35 min |
Aiko (Data Scientist): Gradual unfreezing wins on all metrics. The 97% old-task retention means we preserve the encoder's general language understanding while adding manga-specific sentiment features.
The mathematical intuition: lower layers capture universal syntactic features (negation, intensifiers like "never" or "absolutely"). These features are already useful for sentiment — they just need slight domain adaptation, hence the very low learning rate. Upper layers need significant restructuring to map manga-specific frustration signals (e.g., "volume 14 still not available") to sentiment labels.
Marcus (Architect): The 10 extra minutes of training is worth 97% vs 89% old-task retention. If we lose general sentiment understanding, the model fails on non-manga queries from users who also buy other products.
Resolution: Gradual unfreezing with discriminative LRs. The 35 min training time is acceptable for weekly retrains. For emergency updates (e.g., major product issue), we fall back to head-only training (12 min, 0.71 F1) and schedule a full retrain.
Decision Point 3: Focal Loss Hyperparameters
Aiko (Data Scientist): $\gamma$ and $\alpha$ interact in non-obvious ways:
| $\gamma$ | $\alpha$ weighting | Frustration F1 | churn_risk F1 | Macro F1 |
|---|---|---|---|---|
| 0 (standard BCE) | None | 0.72 | 0.53 | 0.71 |
| 0 | Frequency-inverse | 0.76 | 0.62 | 0.75 |
| 2 | None | 0.79 | 0.61 | 0.78 |
| 2 | Frequency-inverse | 0.87 | 0.69 | 0.84 |
| 3 | Frequency-inverse | 0.85 | 0.72 | 0.83 |
| 5 | Frequency-inverse | 0.78 | 0.73 | 0.79 |
Priya (ML Engineer): $\gamma=2$ with $\alpha$ weighting gives the best macro F1. Increasing $\gamma$ to 5 helps churn_risk (+4%) but hurts frustration (-9%) — the model over-focuses on the hardest rare examples and underfits common patterns.
Sam (PM): churn_risk F1 of 0.69 concerns me. Can we use a separate model for churn detection?
Aiko (Data Scientist): The 0.69 F1 is actually good for a 3% base rate label. With the optimized threshold (0.28), recall is 81% — we catch 4 out of 5 churn-risk messages. Perfect recall is impossible because some churn signals require conversation history, not just the current message.
Resolution: $\gamma=2$, $\alpha$ = frequency-inverse weights. For churn_risk specifically, we augment the model signal with a rule-based detector (keywords: "cancel", "never again", "done with") to achieve combined recall of 91%.
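A sketch of the combined signal (plain substring matching shown; the production matcher is an assumption and may normalize text further):

CHURN_PHRASES = ("cancel", "never again", "done with")  # from the resolution above

def churn_risk_flag(message: str, model_prob: float, threshold: float = 0.28) -> bool:
    # OR-combine the model probability with the keyword rules
    rule_hit = any(phrase in message.lower() for phrase in CHURN_PHRASES)
    return model_prob > threshold or rule_hit

print(churn_risk_flag("I'm done with this store", model_prob=0.12))  # True, via the rule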
Decision Point 4: Threshold Strategy
Jordan (MLOps): Per-label thresholds add complexity. Are they truly worth it?
Priya (ML Engineer): Absolutely. The threshold optimization table shows churn_risk improves from 0.53 to 0.69 F2 by lowering the threshold from 0.50 to 0.28. This is a 30% jump from a simple post-processing step that costs nothing in inference latency.
Marcus (Architect): Thresholds should be stored in DynamoDB, not hardcoded. This way our escalation team can adjust sensitivity without model retraining. Lowering the frustration threshold from 0.35 to 0.30 during a known outage (lots of frustrated users) makes operational sense.
Sam (PM): Agreed. Threshold tuning should be an operational control, not an ML control. The ML team delivers the model; the operations team tunes the thresholds based on false-positive rates they observe.
Resolution: Per-label thresholds stored in DynamoDB with a 10-second ElastiCache TTL. Operations team can adjust thresholds via an admin dashboard. ML team provides recommended thresholds; operations team overrides based on business context.
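A sketch of the threshold lookup, assuming a hypothetical DynamoDB table named sentiment-thresholds keyed by label; a simple in-process TTL cache stands in here for the ElastiCache layer:

import time
import boto3

_table = boto3.resource("dynamodb").Table("sentiment-thresholds")  # hypothetical table name
_cache: dict[str, tuple[float, float]] = {}  # label -> (threshold, fetched_at)
CACHE_TTL_S = 10

def get_threshold(label: str, default: float = 0.5) -> float:
    hit = _cache.get(label)
    if hit and time.time() - hit[1] < CACHE_TTL_S:
        return hit[0]  # serve from cache within the 10-second TTL
    item = _table.get_item(Key={"label": label}).get("Item")
    threshold = float(item["threshold"]) if item else default
    _cache[label] = (threshold, time.time())
    return threshold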
Research Paper References
1. Universal Language Model Fine-tuning for Text Classification — ULMFiT (Howard & Ruder, 2018)
Key contribution: Introduced three techniques that are now staples of transformer fine-tuning: (1) discriminative fine-tuning (different LRs per layer), (2) the slanted triangular learning rate schedule, (3) gradual unfreezing. Showed that together these prevent catastrophic forgetting during fine-tuning and achieved state-of-the-art on six text classification benchmarks.
Relevance to MangaAssist: Our gradual unfreezing schedule directly implements ULMFiT's approach adapted for DistilBERT's 6-layer architecture. The discriminative LR with decay factor 2.6 is from ULMFiT's optimized hyperparameters. Gradual unfreezing improved our old-task retention from 89% to 97%.
2. Focal Loss for Dense Object Detection (Lin et al., 2017)
Key contribution: Originally designed for object detection (addressing the extreme foreground-background imbalance), focal loss down-weights easy examples and focuses training on hard negatives. The $(1-p_t)^\gamma$ modulation is elegant: it is a smooth, differentiable, and parameter-efficient way to reshape the loss landscape.
Relevance to MangaAssist: Adapted for multi-label BCE to handle our 15:1 imbalance between neutral (45%) and churn_risk (3%). Focal loss with $\gamma=2$ improved churn_risk F1 from 0.53 to 0.69, which directly impacts our ability to detect and escalate at-risk customers.
3. Multi-Task Learning Using Uncertainty to Weigh Losses (Kendall et al., 2018)
Key contribution: Used homoscedastic uncertainty to learn relative task weights automatically. The key insight: the optimal weight for a task's loss is inversely proportional to the task's observation noise. Noisier tasks get lower weights, preventing them from dominating the gradient.
Relevance to MangaAssist: Although we use fixed per-label weights in our current implementation, Kendall's uncertainty weighting is the natural extension. If we observe that certain labels have inherently noisy annotations (human annotators disagree on "confusion" vs "neutral" 30% of the time), learned uncertainty weights would down-weight the noisy labels automatically. This is planned for V2.
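A sketch of what that V2 extension could look like, using a common classification-style adaptation of Kendall's weighting (one learned log-variance per label; the class name is ours):

import torch
import torch.nn as nn
import torch.nn.functional as F

class UncertaintyWeightedBCE(nn.Module):
    """Learned per-label weights: noisy labels drive log_var up, shrinking their
    gradient share; the +log_var term penalizes the trivial infinite-uncertainty
    solution."""
    def __init__(self, num_labels: int = 6):
        super().__init__()
        self.log_var = nn.Parameter(torch.zeros(num_labels))

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        per_label = bce.mean(dim=0)  # mean over the batch, one value per label
        return (torch.exp(-self.log_var) * per_label + self.log_var).mean()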
4. Asymmetric Loss For Multi-Label Classification (Ridnik et al., 2021)
Key contribution: Extended focal loss with asymmetric focusing parameters: $\gamma^+$ for positive samples (typically 0) and $\gamma^-$ for negative samples (typically 4). This addresses the specific challenge of multi-label: most labels are negative for most examples, creating massive negative:positive imbalance per label.
Relevance to MangaAssist: Our frustration label is positive in only 18% of messages. Asymmetric loss with $\gamma^-=4, \gamma^+=1$ further improved frustration recall by 2.3% in our ablation. Planned for production after the next retrain cycle.
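A sketch of the asymmetric loss slated for that retrain, with the ablation's $\gamma^- = 4$, $\gamma^+ = 1$ and the paper's probability margin for negatives (the margin value here is an assumption):

import torch
import torch.nn as nn

class AsymmetricBCELoss(nn.Module):
    """Ridnik-style asymmetric focusing: negatives get a steeper exponent plus a
    probability margin that zeroes out very easy negatives entirely."""
    def __init__(self, gamma_pos: float = 1.0, gamma_neg: float = 4.0, margin: float = 0.05):
        super().__init__()
        self.gamma_pos, self.gamma_neg, self.margin = gamma_pos, gamma_neg, margin

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        p = torch.sigmoid(logits)
        p_neg = (p - self.margin).clamp(min=0)  # shifted negatives: easy ones contribute zero loss
        pos = targets * (1 - p) ** self.gamma_pos * torch.log(p.clamp(min=1e-8))
        neg = (1 - targets) * p_neg ** self.gamma_neg * torch.log((1 - p_neg).clamp(min=1e-8))
        return -(pos + neg).mean()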
Production Evaluation and Deployment
Ablation Study Results
| Configuration | Macro F1 | Frustration Recall | churn_risk F1 | Latency |
|---|---|---|---|---|
| SST-2 pretrained (baseline) | 0.54 | 62.4% | 0.18 | 8ms |
| + Head-only training | 0.71 | 74.2% | 0.48 | 8ms |
| + Full fine-tune, uniform LR | 0.79 | 81.3% | 0.58 | 8ms |
| + Discriminative LR | 0.82 | 84.1% | 0.62 | 8ms |
| + Gradual unfreezing | 0.84 | 87.3% | 0.64 | 8ms |
| + Focal loss (γ=2) | 0.84 | 89.7% | 0.69 | 8ms |
| + Threshold optimization | 0.86 | 91.2% | 0.74 | 8ms |
| + Rule boost for churn_risk | — | 91.2% | 0.79 | 9ms |
Each technique contributes meaningfully. The full stack achieves 91.2% frustration recall (up from the 62.4% baseline) at 9ms P50, only 1ms above baseline latency (added by the rule-based churn boost).
Escalation Impact
| Metric | Before (SST-2) | After (Manga-Tuned) | Change |
|---|---|---|---|
| Frustrated users escalated | 62.4% | 91.2% | +28.8 pts |
| False escalations per day | 12 | 28 | +16 |
| Average resolution time | 8.2 min | 5.1 min | -37.8% |
| Customer satisfaction (escalated) | 3.2 / 5 | 4.1 / 5 | +28.1% |
| Monthly churn rate | 4.8% | 3.1% | -35.4% |
The 16 extra false escalations per day cost ~$32 in human agent time, roughly $960/month. The 1.7-point churn reduction saves ~$12K/month in customer lifetime value, an ROI of roughly 12.5:1.