13. Multi-Task Learning — Single Model for Intent, Sentiment, and Entities
Problem Statement and MangaAssist Context
MangaAssist currently runs three separate DistilBERT models: intent classification (10 classes), sentiment analysis (multi-label), and a potential entity extraction module. Each model consumes 66MB (INT8) to 265MB (FP32) of memory and adds its own latency to the pipeline. Multi-task learning (MTL) trains a single shared encoder with multiple task-specific heads, potentially reducing total memory by 60-70% and enabling shared representations that improve all tasks simultaneously. The challenge: tasks can interfere with each other during training, degrading performance on some tasks while improving others.
Current vs MTL Architecture
| Metric | 3 Separate Models | Single MTL Model |
|---|---|---|
| Parameters | 3 × 66M = 198M | 66M + 3 heads ≈ 68M |
| Memory (INT8) | 3 × 66MB = 198MB | 70MB |
| Total inference time | 15ms + 12ms + 10ms = 37ms | 18ms (shared encoder + 3 heads) |
| Encoder forward passes | 3 | 1 |
Mathematical Foundations
Multi-Task Loss Formulation
Given $T$ tasks with individual losses $\mathcal{L}_1, \mathcal{L}_2, \ldots, \mathcal{L}_T$, the simplest MTL objective:
$$\mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} w_t \mathcal{L}_t$$
The weights $w_t$ determine task priority. Naive equal weighting ($w_t = 1/T$) rarely works because:
1. Loss magnitudes differ: cross-entropy for classification might be ~1.5 while token-level NER loss might be ~0.3
2. Learning rates differ: some tasks converge faster
3. Gradient magnitudes differ: dominant tasks suppress learning on other tasks
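A toy illustration of the imbalance, using the gradient-norm numbers from the GradNorm example later in this section (all values illustrative):

```python
# Hypothetical per-task losses and shared-layer gradient norms: with naive
# equal weighting, the task with the largest gradients dominates the update.
task_losses = {"intent": 1.5, "sentiment": 0.3, "ner": 0.9}
grad_norms = {"intent": 2.5, "sentiment": 0.3, "ner": 1.1}

w = 1 / len(task_losses)  # naive equal weighting: w_t = 1/T
weighted_losses = {t: w * l for t, l in task_losses.items()}
ratio = grad_norms["intent"] / grad_norms["sentiment"]
print(f"intent dominates sentiment by {ratio:.1f}x at the shared layer")  # 8.3x
```

Equal loss weights do nothing about this 8.3x gradient imbalance, which is why the strategies below act on the weights or on the gradients directly.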
Uncertainty Weighting (Kendall et al., 2018)
Learns task weights automatically using homoscedastic uncertainty:
For each task $t$, introduce a learnable log-variance parameter $\sigma_t$. The weighted loss:
$$\mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} \left[\frac{1}{2\sigma_t^2} \mathcal{L}_t + \log \sigma_t\right]$$
Derivation: Start from the Gaussian likelihood for task $t$:
$$p(y_t | f_t(x)) = \frac{1}{\sqrt{2\pi}\sigma_t} \exp\left(-\frac{(y_t - f_t(x))^2}{2\sigma_t^2}\right)$$
Taking the negative log-likelihood:
$$-\log p = \frac{(y_t - f_t(x))^2}{2\sigma_t^2} + \log \sigma_t + \text{const}$$
The $\frac{1}{2\sigma_t^2}$ term acts as the weight: when a task has high uncertainty ($\sigma_t$ large), its loss contribution is downweighted. The $\log \sigma_t$ regularizer prevents $\sigma_t$ from growing without bound, which would drive every task weight to zero and trivially minimize the loss.
Gradient with respect to $\sigma_t$:
$$\frac{\partial \mathcal{L}}{\partial \sigma_t} = -\frac{\mathcal{L}_t}{\sigma_t^3} + \frac{1}{\sigma_t}$$
At equilibrium: $\sigma_t^2 = \mathcal{L}_t$. Tasks with higher loss (harder tasks) automatically get lower weight — the model doesn't waste capacity fighting noise in hard tasks.
In practice, we parameterize with $s_t = \log \sigma_t^2$:
$$\mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} \left[\frac{1}{2} e^{-s_t} \mathcal{L}_t + \frac{1}{2} s_t\right]$$
This is numerically stable ($e^{-s_t}$ avoids division by zero).
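A quick numeric check of the equilibrium condition under the $s_t$ parameterization (the loss value is illustrative):

```python
import math

# d/ds [0.5*exp(-s)*L + 0.5*s] = -0.5*exp(-s)*L + 0.5, which vanishes
# exactly when exp(s) = L, i.e. sigma_t^2 = L_t as derived above.
L = 1.45                      # example task loss
s = math.log(L)               # candidate equilibrium: s_t = log sigma_t^2 = log L_t
grad = -0.5 * math.exp(-s) * L + 0.5
print(f"gradient at s = log L: {grad:.6f}")                   # 0.000000
print(f"implied weight 1/(2*sigma^2) = {1 / (2 * math.exp(s)):.3f}")  # 0.345
```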
GradNorm (Chen et al., 2018)
Instead of weighting losses, GradNorm normalizes gradient magnitudes across tasks. The idea: if one task's gradients are 10× larger than another's, the first task dominates training regardless of loss weights.
Define the gradient norm for task $t$ at a shared layer:
$$G_t = \left\| \nabla_{W_{\text{shared}}} \, w_t \mathcal{L}_t \right\|_2$$
GradNorm wants all tasks to have similar gradient norms, scaled by their relative training speed:
$$\bar{G} = \mathbb{E}_t[G_t], \quad r_t = \frac{\tilde{\mathcal{L}}_t}{\mathbb{E}_t[\tilde{\mathcal{L}}_t]}$$
where $\tilde{\mathcal{L}}_t = \mathcal{L}_t(t) / \mathcal{L}_t(0)$ is the ratio of current loss to initial loss (training rate).
The target gradient norm for task $t$:
$$G_t^{\text{target}} = \bar{G} \cdot r_t^{\alpha}$$
where $\alpha$ controls how aggressively we balance tasks ($\alpha = 1.5$ is typical). Tasks that are training slowly ($r_t$ large) get larger target gradient norms.
Update rule for task weights:
$$\mathcal{L}_{\text{grad}} = \sum_t |G_t - G_t^{\text{target}}|_1$$
$$w_t \leftarrow w_t - \eta_w \frac{\partial \mathcal{L}_{\text{grad}}}{\partial w_t}$$
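The target-norm computation can be sketched numerically. The gradient norms and loss ratios below are illustrative; the point is the ordering: the slow-training task receives the largest target.

```python
# GradNorm targets: G_target_t = G_bar * r_t**alpha, where r_t is the task's
# loss ratio relative to the mean. Slow tasks (large r_t) get large targets.
alpha = 1.5
grad_norms = {"intent": 2.5, "sentiment": 0.3, "ner": 1.1}
loss_ratios = {"intent": 0.18, "sentiment": 0.73, "ner": 0.54}  # L(t)/L(0)

g_bar = sum(grad_norms.values()) / len(grad_norms)              # average norm
r_bar = sum(loss_ratios.values()) / len(loss_ratios)
targets = {t: g_bar * (r / r_bar) ** alpha for t, r in loss_ratios.items()}
for task, g in sorted(targets.items(), key=lambda kv: kv[1]):
    # intent (fast) gets the smallest target, sentiment (slow) the largest
    print(f"target grad norm for {task}: {g:.2f}")
```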
Gradient Surgery — PCGrad (Yu et al., 2020)
When two tasks have conflicting gradients (their gradient vectors point in opposite directions), standard MTL forces a compromise that hurts both. PCGrad (Projecting Conflicting Gradients) detects conflicts and removes the conflicting component:
For tasks $i$ and $j$ with gradients $g_i$ and $g_j$:
$$\text{If } g_i \cdot g_j < 0 \text{ (conflict):} \quad g_i' = g_i - \frac{g_i \cdot g_j}{\|g_j\|^2} g_j$$
This projects $g_i$ onto the plane perpendicular to $g_j$, removing the conflicting component while preserving the non-conflicting direction.
Geometric interpretation: Each task gradient lives in a high-dimensional parameter space. When two tasks conflict, their gradients form an obtuse angle. PCGrad projects each onto the other's normal plane, finding a direction that helps (or at least doesn't hurt) both tasks.
For our 3 tasks (intent, sentiment, NER), the measured conflict rate is approximately 15-25%.
This means 15-25% of mini-batches have at least one pair of conflicting gradients. Without PCGrad, these conflicts cause oscillations in training and degrade the weaker task.
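A minimal 2-D sketch of the projection (hypothetical gradient vectors):

```python
# Two conflicting task gradients in R^2: project g_i onto the plane
# perpendicular to g_j, removing the component that fights task j.
g_i = [1.0, 0.0]
g_j = [-1.0, 1.0]

dot = sum(a * b for a, b in zip(g_i, g_j))      # -1.0 < 0 -> conflict
norm_j_sq = sum(b * b for b in g_j)             # ||g_j||^2 = 2.0
g_i_proj = [a - (dot / norm_j_sq) * b for a, b in zip(g_i, g_j)]
print(g_i_proj)                                  # [0.5, 0.5]
print(sum(a * b for a, b in zip(g_i_proj, g_j)))  # 0.0: conflict removed
```

After projection, the dot product with $g_j$ is exactly zero, so the update no longer pushes against task $j$.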
Task Affinity and Negative Transfer
Not all task combinations are beneficial. Task affinity measures whether co-training helps:
$$A_{i \to j} = \frac{\text{Performance of } j \text{ when co-trained with } i}{\text{Performance of } j \text{ alone}} - 1$$
If $A_{i \to j} > 0$: task $i$ helps task $j$ (positive transfer). If $A_{i \to j} < 0$: task $i$ hurts task $j$ (negative transfer).
For MangaAssist:
| Task Pair | Affinity | Direction |
|---|---|---|
| Intent → Sentiment | +2.1% | Positive: intent features help sentiment |
| Sentiment → Intent | +1.4% | Positive: emotional context aids intent |
| NER → Intent | +0.8% | Weak positive |
| Intent → NER | -0.3% | Weak negative: intent task dominates |
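The affinity definition is a one-liner in code. The helper below is illustrative, and the example scores are hypothetical (the table above reports only the resulting affinities, not the underlying single-task and co-trained metrics):

```python
# A_{i->j}: relative change in task j's metric when co-trained with task i.
def affinity(perf_cotrained: float, perf_alone: float) -> float:
    return perf_cotrained / perf_alone - 1.0

# Hypothetical example: sentiment F1 of 0.842 alone vs 0.860 when
# co-trained with intent would correspond to an affinity of about +2.1%.
print(f"{affinity(0.860, 0.842):+.1%}")
```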
Model Internals — Layer-by-Layer Diagrams
Multi-Task Architecture
graph TB
subgraph "Shared Encoder (DistilBERT, 66M params)"
INPUT["Input: 'I want to return this damaged manga volume'"]
TOK["Tokenizer → [CLS] I want to return this damaged manga volume [SEP]"]
EMB["Embedding Layer (shared)"]
L1["Transformer Layer 1 (shared)"]
L2["Transformer Layer 2 (shared)"]
L3["Transformer Layer 3 (shared)"]
L4["Transformer Layer 4 (shared)"]
L5["Transformer Layer 5 (shared)"]
L6["Transformer Layer 6 (shared)"]
INPUT --> TOK --> EMB --> L1 --> L2 --> L3 --> L4 --> L5 --> L6
end
subgraph "Task-Specific Heads (2M params total)"
CLS["[CLS] token → 768-dim"]
SEQ["All tokens → 768-dim each"]
L6 --> CLS & SEQ
H_INT["Intent Head<br>Linear(768→256→10)<br>Softmax<br>Prediction: return_refund"]
H_SENT["Sentiment Head<br>Linear(768→256→6)<br>Sigmoid per label<br>Pred: frustration=0.82"]
H_NER["NER Head<br>Linear(768→256→9)<br>Per-token softmax<br>Pred: manga=PRODUCT"]
CLS --> H_INT & H_SENT
SEQ --> H_NER
end
style L1 fill:#e3f2fd
style L2 fill:#e3f2fd
style L3 fill:#e3f2fd
style L4 fill:#e3f2fd
style L5 fill:#e3f2fd
style L6 fill:#e3f2fd
style H_INT fill:#c8e6c9
style H_SENT fill:#fff9c4
style H_NER fill:#ffccbc
Gradient Flow and Conflict Detection
graph TB
subgraph "Gradient Conflict in Shared Layers"
LOSS_I["Intent Loss = 0.42<br>∇W₃(intent)"]
LOSS_S["Sentiment Loss = 0.68<br>∇W₃(sentiment)"]
LOSS_N["NER Loss = 0.31<br>∇W₃(NER)"]
CHECK["Conflict Detection:<br>g_intent · g_sentiment = -0.15 ❌ CONFLICT<br>g_intent · g_ner = +0.34 ✅ aligned<br>g_sentiment · g_ner = +0.08 ✅ aligned"]
LOSS_I --> CHECK
LOSS_S --> CHECK
LOSS_N --> CHECK
PCGRAD["PCGrad Resolution:<br>g'_intent = g_intent - proj(g_intent, g_sentiment)<br>Remove conflicting component<br><br>g'_sentiment = g_sentiment - proj(g_sentiment, g_intent)<br>Remove conflicting component<br><br>g_ner unchanged (no conflicts)"]
CHECK --> PCGRAD
FINAL["Final gradient for Layer 3:<br>g_shared = g'_intent + g'_sentiment + g_ner<br>No task dominance, no destructive interference"]
PCGRAD --> FINAL
end
style CHECK fill:#ffcdd2
style PCGRAD fill:#c8e6c9
style FINAL fill:#c8e6c9
Uncertainty Weighting Dynamics
graph LR
subgraph "Epoch 1: Equal weights"
E1_I["Intent<br>w=0.33<br>Loss=2.30<br>σ²=2.30"]
E1_S["Sentiment<br>w=0.33<br>Loss=0.85<br>σ²=0.85"]
E1_N["NER<br>w=0.33<br>Loss=1.45<br>σ²=1.45"]
end
subgraph "Epoch 5: Weights adapted"
E5_I["Intent<br>w=0.21 ⬇<br>Loss=0.95<br>Easy → downweight"]
E5_S["Sentiment<br>w=0.38 ⬆<br>Loss=0.62<br>Hard → upweight"]
E5_N["NER<br>w=0.41 ⬆<br>Loss=0.78<br>Still learning"]
end
subgraph "Epoch 15: Converged"
E15_I["Intent<br>w=0.25<br>Loss=0.42<br>Converged"]
E15_S["Sentiment<br>w=0.35<br>Loss=0.38<br>Near-converged"]
E15_N["NER<br>w=0.40<br>Loss=0.31<br>Converged"]
end
E1_I -->|"σ adapts"| E5_I -->|"stabilizes"| E15_I
E1_S -->|"σ adapts"| E5_S -->|"stabilizes"| E15_S
E1_N -->|"σ adapts"| E5_N -->|"stabilizes"| E15_N
GradNorm Balancing
graph TB
subgraph "GradNorm: Normalize gradient magnitudes"
G1["Without GradNorm:<br>‖∇intent‖ = 2.5<br>‖∇sentiment‖ = 0.3<br>‖∇NER‖ = 1.1<br><br>Intent dominates 8.3×<br>over sentiment!"]
GN["GradNorm target (α=1.5):<br>r_intent = L(t)/L(0) = 0.18 (fast)<br>r_sentiment = 0.73 (slow)<br>r_NER = 0.54<br><br>Target ‖∇intent‖ = 0.49<br>Target ‖∇sentiment‖ = 1.80<br>Target ‖∇NER‖ = 1.21"]
G2["With GradNorm:<br>w_intent adjusted down → ‖g‖ ≈ 0.5<br>w_sentiment adjusted up → ‖g‖ ≈ 1.8<br>w_NER adjusted slightly → ‖g‖ ≈ 1.2<br><br>Balanced! Sentiment gets<br>more gradient signal"]
G1 --> GN --> G2
end
style G1 fill:#ffcdd2
style G2 fill:#c8e6c9
Task Head Architecture Detail
graph TB
subgraph "Intent Head (Classification)"
CLS1["[CLS] embedding: 768-dim"]
INT_D["Dropout(0.1)"]
INT_L1["Linear(768→256) + GELU"]
INT_LN["LayerNorm(256)"]
INT_L2["Linear(256→10)"]
INT_SF["Softmax → 10 intent probabilities"]
CLS1 --> INT_D --> INT_L1 --> INT_LN --> INT_L2 --> INT_SF
end
subgraph "Sentiment Head (Multi-Label)"
CLS2["[CLS] embedding: 768-dim"]
SENT_D["Dropout(0.1)"]
SENT_L1["Linear(768→256) + GELU"]
SENT_LN["LayerNorm(256)"]
SENT_L2["Linear(256→6)"]
SENT_SIG["Sigmoid per label → 6 independent probabilities"]
CLS2 --> SENT_D --> SENT_L1 --> SENT_LN --> SENT_L2 --> SENT_SIG
end
subgraph "NER Head (Token Classification)"
TOK["All token embeddings: [n×768]"]
NER_D["Dropout(0.1)"]
NER_L1["Linear(768→256) + GELU"]
NER_LN["LayerNorm(256)"]
NER_L2["Linear(256→9)"]
NER_SF["Per-token Softmax → BIO tags"]
TOK --> NER_D --> NER_L1 --> NER_LN --> NER_L2 --> NER_SF
end
style INT_SF fill:#c8e6c9
style SENT_SIG fill:#fff9c4
style NER_SF fill:#ffccbc
Implementation Deep-Dive
Multi-Task Model
import torch
import torch.nn as nn
from transformers import DistilBertModel
class MultiTaskHead(nn.Module):
"""A single task head with configurable output."""
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
super().__init__()
self.head = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, output_dim),
)
def forward(self, x):
return self.head(x)
class MultiTaskDistilBERT(nn.Module):
"""
Single DistilBERT encoder with three task-specific heads.
Replaces 3 separate models → 65% memory reduction.
"""
def __init__(
self,
num_intents: int = 10,
num_sentiments: int = 6,
num_ner_tags: int = 9,
model_name: str = "distilbert-base-uncased",
):
super().__init__()
self.encoder = DistilBertModel.from_pretrained(model_name)
hidden = self.encoder.config.hidden_size # 768
self.intent_head = MultiTaskHead(hidden, 256, num_intents)
self.sentiment_head = MultiTaskHead(hidden, 256, num_sentiments)
self.ner_head = MultiTaskHead(hidden, 256, num_ner_tags)
def forward(self, input_ids, attention_mask, task: str = "all"):
encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
hidden_states = encoder_output.last_hidden_state # [batch, seq_len, 768]
cls_output = hidden_states[:, 0, :] # [batch, 768]
results = {}
if task in ("all", "intent"):
results["intent_logits"] = self.intent_head(cls_output)
if task in ("all", "sentiment"):
results["sentiment_logits"] = self.sentiment_head(cls_output)
if task in ("all", "ner"):
results["ner_logits"] = self.ner_head(hidden_states)
return results
Uncertainty-Weighted Loss
class UncertaintyWeightedLoss(nn.Module):
"""
Kendall et al. 2018: automatically learn task weights
from homoscedastic uncertainty.
"""
def __init__(self, num_tasks: int = 3):
super().__init__()
# Log-variance parameters (learnable)
self.log_vars = nn.Parameter(torch.zeros(num_tasks))
def forward(self, losses: list[torch.Tensor]) -> tuple[torch.Tensor, dict]:
total_loss = 0
weights = {}
for i, loss in enumerate(losses):
# L_total = (1/2σ²) * L_t + log(σ)
# Using log_var = log(σ²):
precision = torch.exp(-self.log_vars[i]) # 1/σ²
total_loss += 0.5 * precision * loss + 0.5 * self.log_vars[i]
weights[f"task_{i}_weight"] = precision.item()
weights[f"task_{i}_sigma"] = torch.exp(0.5 * self.log_vars[i]).item()
return total_loss, weights
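As a standalone sanity check, the same arithmetic as the module above, inlined so it runs on its own (loss values are illustrative):

```python
import torch
import torch.nn as nn

# At initialization (log_vars = 0), every task weight is 0.5 and the total
# is half the plain loss sum; the gradient on each log-variance follows
# d/ds = -0.5*exp(-s)*L + 0.5, matching the derivation above.
log_vars = nn.Parameter(torch.zeros(3))
losses = torch.tensor([2.30, 0.85, 1.45])
total = (0.5 * torch.exp(-log_vars) * losses + 0.5 * log_vars).sum()
total.backward()
print(round(total.item(), 2))   # 2.3
print(log_vars.grad)            # per task: -0.65, 0.075, -0.225
```

Note the signs of the gradients: the high-loss task (2.30) pushes its $s_t$ up (downweighting it), while the low-loss task (0.85) pushes its $s_t$ down.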
PCGrad Implementation
class PCGrad:
"""
Yu et al. 2020: Project Conflicting Gradients.
Removes conflicting gradient components between tasks.
"""
def __init__(self, optimizer):
self.optimizer = optimizer
def step(self, losses: list[torch.Tensor], shared_params: list[nn.Parameter]):
"""
Compute per-task gradients, resolve conflicts, then update.
"""
task_grads = []
for loss in losses:
self.optimizer.zero_grad()
loss.backward(retain_graph=True)
grads = [p.grad.clone() if p.grad is not None else torch.zeros_like(p)
for p in shared_params]
task_grads.append(grads)
# Resolve conflicts via projection
num_tasks = len(losses)
        projected_grads = [list(g) for g in task_grads]  # Copy per-task lists (tensors are replaced on projection)
for i in range(num_tasks):
for j in range(num_tasks):
if i == j:
continue
for k in range(len(shared_params)):
g_i = projected_grads[i][k]
g_j = task_grads[j][k]
dot = (g_i * g_j).sum()
if dot < 0:
# Conflict detected: project g_i onto plane perpendicular to g_j
projected_grads[i][k] = g_i - (dot / (g_j.norm() ** 2 + 1e-8)) * g_j
# Sum projected gradients and apply
self.optimizer.zero_grad()
for k, param in enumerate(shared_params):
param.grad = sum(projected_grads[i][k] for i in range(num_tasks))
self.optimizer.step()
class GradNormLoss(nn.Module):
"""
Chen et al. 2018: Dynamically balance gradient norms.
"""
def __init__(self, num_tasks: int = 3, alpha: float = 1.5):
super().__init__()
self.num_tasks = num_tasks
self.alpha = alpha
self.task_weights = nn.Parameter(torch.ones(num_tasks))
self.initial_losses = None
def forward(
self,
losses: list[torch.Tensor],
shared_layer: nn.Module,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
weighted_loss: The task-weighted MTL loss
gradnorm_loss: The GradNorm regularization loss
"""
# Record initial losses for training rate computation
if self.initial_losses is None:
self.initial_losses = [l.item() for l in losses]
# Weighted task loss
        # Detach the weights here so only the GradNorm loss (below) updates them
        weighted_loss = sum(self.task_weights[i].detach() * losses[i] for i in range(self.num_tasks))
# Compute gradient norms for shared layer
grad_norms = []
for i in range(self.num_tasks):
gw = torch.autograd.grad(
self.task_weights[i] * losses[i],
shared_layer.parameters(),
retain_graph=True,
create_graph=True,
)
grad_norms.append(torch.norm(torch.cat([g.flatten() for g in gw])))
# Average gradient norm
avg_norm = sum(grad_norms) / self.num_tasks
# Training rate ratios
loss_ratios = [losses[i].item() / (self.initial_losses[i] + 1e-8)
for i in range(self.num_tasks)]
avg_ratio = sum(loss_ratios) / self.num_tasks
relative_rates = [r / (avg_ratio + 1e-8) for r in loss_ratios]
# Target gradient norms
target_norms = [avg_norm * (r ** self.alpha) for r in relative_rates]
# GradNorm loss
gradnorm_loss = sum(
torch.abs(grad_norms[i] - target_norms[i].detach())
for i in range(self.num_tasks)
)
return weighted_loss, gradnorm_loss
Multi-Task Trainer
class MultiTaskTrainer:
"""End-to-end MTL trainer with configurable loss balancing."""
def __init__(
self,
model: MultiTaskDistilBERT,
strategy: str = "uncertainty", # "uncertainty", "pcgrad", "gradnorm"
):
self.model = model
self.strategy = strategy
if strategy == "uncertainty":
self.loss_module = UncertaintyWeightedLoss(num_tasks=3)
elif strategy == "gradnorm":
self.loss_module = GradNormLoss(num_tasks=3, alpha=1.5)
        all_params = list(model.parameters())
        if strategy in ("uncertainty", "gradnorm"):
            all_params += list(self.loss_module.parameters())
self.optimizer = torch.optim.AdamW(all_params, lr=2e-5, weight_decay=0.01)
if strategy == "pcgrad":
self.pcgrad = PCGrad(self.optimizer)
self.intent_loss_fn = nn.CrossEntropyLoss()
self.sentiment_loss_fn = nn.BCEWithLogitsLoss()
self.ner_loss_fn = nn.CrossEntropyLoss()
def train_step(self, batch: dict) -> dict:
self.model.train()
outputs = self.model(
batch["input_ids"], batch["attention_mask"], task="all",
)
# Compute per-task losses
intent_loss = self.intent_loss_fn(outputs["intent_logits"], batch["intent_labels"])
sentiment_loss = self.sentiment_loss_fn(
outputs["sentiment_logits"], batch["sentiment_labels"].float(),
)
ner_loss = self.ner_loss_fn(
outputs["ner_logits"].view(-1, 9), batch["ner_labels"].view(-1),
)
losses = [intent_loss, sentiment_loss, ner_loss]
if self.strategy == "uncertainty":
total_loss, weight_info = self.loss_module(losses)
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
return {"loss": total_loss.item(), **weight_info}
        elif self.strategy == "pcgrad":
            shared_params = list(self.model.encoder.parameters())
            self.pcgrad.step(losses, shared_params)
            # Update the task heads; clear encoder grads first so the shared
            # layers are not updated a second time with unprojected gradients
            head_loss = sum(losses)
            self.optimizer.zero_grad()
            head_loss.backward()
            for p in shared_params:
                p.grad = None
            self.optimizer.step()
            return {"loss": sum(l.item() for l in losses)}
elif self.strategy == "gradnorm":
shared_layer = self.model.encoder.transformer.layer[-1]
weighted_loss, gn_loss = self.loss_module(losses, shared_layer)
self.optimizer.zero_grad()
(weighted_loss + gn_loss).backward()
self.optimizer.step()
# Renormalize task weights
with torch.no_grad():
self.loss_module.task_weights.data = (
self.loss_module.task_weights / self.loss_module.task_weights.sum() * 3
)
return {"loss": weighted_loss.item(), "gn_loss": gn_loss.item()}
def evaluate(self, val_loader) -> dict:
self.model.eval()
intent_correct = intent_total = 0
sentiment_f1_samples = []
with torch.no_grad():
for batch in val_loader:
outputs = self.model(
batch["input_ids"], batch["attention_mask"], task="all",
)
# Intent accuracy
preds = outputs["intent_logits"].argmax(dim=-1)
intent_correct += (preds == batch["intent_labels"]).sum().item()
intent_total += batch["intent_labels"].size(0)
# Sentiment F1 (per-label threshold)
sent_preds = (torch.sigmoid(outputs["sentiment_logits"]) > 0.5).float()
# Simplified per-sample F1
for i in range(sent_preds.size(0)):
tp = (sent_preds[i] * batch["sentiment_labels"][i]).sum()
fp = (sent_preds[i] * (1 - batch["sentiment_labels"][i])).sum()
fn = ((1 - sent_preds[i]) * batch["sentiment_labels"][i]).sum()
prec = tp / (tp + fp + 1e-8)
rec = tp / (tp + fn + 1e-8)
f1 = 2 * prec * rec / (prec + rec + 1e-8)
sentiment_f1_samples.append(f1.item())
return {
"intent_accuracy": intent_correct / intent_total,
"sentiment_f1": sum(sentiment_f1_samples) / len(sentiment_f1_samples),
}
Group Discussion: Key Decision Points
Decision Point 1: Loss Balancing Strategy
Priya (ML Engineer): Compared all three strategies on MangaAssist data:
| Strategy | Intent Acc | Sentiment F1 | NER F1 | Training Stability |
|---|---|---|---|---|
| Equal weights (baseline) | 90.1% | 82.3% | 76.5% | Oscillatory |
| Uncertainty weighting | 91.8% | 85.1% | 79.2% | Smooth |
| GradNorm (α=1.5) | 91.5% | 84.8% | 80.1% | Smooth |
| PCGrad | 92.0% | 85.4% | 78.8% | Slightly noisy |
Separate models (reference): Intent 92.1%, Sentiment 84.2%, NER 78.0%
Aiko (Data Scientist): PCGrad gives the best intent and sentiment scores but NER drops compared to GradNorm. Uncertainty weighting is the best all-around: consistent improvement across all tasks with the simplest implementation.
Marcus (Architect): PCGrad adds computational overhead: we need to compute per-task gradients separately (3 backward passes) and then project. That triples the backward pass time.
Jordan (MLOps): Uncertainty weighting adds only 3 parameters ($s_1, s_2, s_3$) and zero computational overhead. It self-tunes during training. I strongly prefer it for production simplicity.
Resolution: Uncertainty weighting as the default. PCGrad as an option for tasks where we need to maximize a specific metric (e.g., intent accuracy) at the expense of others.
Decision Point 2: MTL vs Separate Models
Marcus (Architect): The fundamental trade-off:
| Aspect | Separate Models | Single MTL |
|---|---|---|
| Memory | 198MB (INT8) | 70MB (INT8) |
| Latency | 37ms (3 passes) | 18ms (1 pass) |
| Intent accuracy | 92.1% | 91.8% |
| Sentiment F1 | 84.2% | 85.1% |
| NER F1 | 78.0% | 79.2% |
| Training complexity | Simple | Moderate |
| Independent deployment | ✅ | ❌ (coupled) |
Sam (PM): MTL improves sentiment and NER while only losing 0.3% on intent. That's positive transfer — the shared encoder learns features useful for all tasks.
Priya (ML Engineer): The sentiment improvement (+0.9% F1) makes sense: understanding intent ("return") helps predict sentiment ("frustrated"). The NER improvement (+1.2% F1) comes from context learned across tasks.
Jordan (MLOps): My concern is coupling: if we need to retrain sentiment due to new labels, we have to retrain the entire MTL model. With separate models, we only retrain sentiment.
Marcus (Architect): Mitigation: freeze the shared encoder and only retrain the specific head that changed. This preserves the shared representations while allowing independent head updates.
Resolution: Deploy MTL model for production. Keep the ability to freeze the encoder and retrain individual heads. Maintain separate models as a benchmark reference and fallback if MTL causes unexpected regressions.
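Marcus's encoder-freezing mitigation can be sketched with stand-in modules (nn.Linear placeholders rather than the real MultiTaskDistilBERT, so the snippet is self-contained):

```python
import torch
import torch.nn as nn

# Freeze the shared encoder, retrain only one head after a label change.
encoder = nn.Linear(768, 768)        # stand-in for the shared DistilBERT encoder
sentiment_head = nn.Linear(768, 7)   # stand-in head, e.g. after adding a 7th label

for p in encoder.parameters():
    p.requires_grad = False          # shared representations stay fixed

optimizer = torch.optim.AdamW(sentiment_head.parameters(), lr=1e-4)
with torch.no_grad():
    feats = encoder(torch.randn(4, 768))   # frozen encoder forward pass
loss = sentiment_head(feats).sum()
loss.backward()
optimizer.step()
print(encoder.weight.grad is None)   # True: the encoder receives no gradient
```

The other two heads never see a gradient either, so their behavior is unchanged while the retrained head adapts to the new label set.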
Decision Point 3: Which Layers to Share
Priya (ML Engineer): Not all tasks benefit from sharing all layers:
| Sharing Config | Intent | Sentiment | NER | Total Params |
|---|---|---|---|---|
| Share all 6 layers | 91.8% | 85.1% | 79.2% | 68M |
| Share layers 1-4, split 5-6 | 92.2% | 84.6% | 80.5% | 82M |
| Share layers 1-2, split 3-6 | 92.0% | 83.8% | 81.0% | 110M |
| No sharing (separate) | 92.1% | 84.2% | 78.0% | 198M |
Aiko (Data Scientist): Sharing all 6 layers is the sweet spot for our case. The partial sharing configurations give marginal improvements on some tasks but increase parameters. The 82M model with partial sharing doesn't justify the 20% parameter increase for a 0.4% intent improvement.
Resolution: Share all encoder layers. The memory efficiency (68M vs 198M) and latency benefit (1 forward pass vs 3) dominate over the marginal quality differences.
Research Paper References
1. Multi-Task Learning Using Uncertainty to Weigh Losses (Kendall et al., 2018)
Key contribution: Introduced homoscedastic uncertainty weighting for MTL — task weights are derived from the task-specific noise variance, which is learned jointly with the model parameters. This eliminates the need to manually tune task weights and automatically adapts during training. The paper proved this approach on computer vision (semantic segmentation + depth estimation + instance segmentation).
Relevance to MangaAssist: Uncertainty weighting is our primary balancing strategy. It correctly downweights the noisy NER task during early training and upweights it as the shared representations improve. The automatic weight adaptation saves us from extensive hyperparameter searches.
2. GradNorm: Gradient Normalization for Adaptive Loss Balancing (Chen et al., 2018)
Key contribution: Normalized gradient magnitudes across tasks by dynamically adjusting task weights. Introduced the asymmetry parameter $\alpha$ that controls how aggressively the algorithm balances tasks based on their training speed. Higher $\alpha$ forces more equal convergence rates across tasks.
Relevance to MangaAssist: GradNorm serves as our secondary balancing strategy. It is particularly useful when tasks have very different convergence speeds — which happens when we add a new task to an existing MTL model and need to catch up the new task without degrading established ones.
3. Gradient Surgery for Multi-Task Learning (Yu et al., 2020)
Key contribution: Identified that conflicting gradients (negative cosine similarity) are a primary cause of negative transfer in MTL. PCGrad projects conflicting gradients onto non-conflicting directions, provably reducing the variance of the combined gradient. The paper showed improvements across multi-task RL, NLP, and vision benchmarks.
Relevance to MangaAssist: PCGrad resolves the 15-25% of conflicting batches in our training. While we don't use it by default (computational overhead), understanding gradient conflicts helped us diagnose why our initial MTL model underperformed on NER — the intent gradient was dominating the shared layers.
Production Results
MTL vs Separate Models
| Metric | Separate Models | MTL (Uncertainty) | Change |
|---|---|---|---|
| Intent accuracy | 92.1% | 91.8% | -0.3% |
| Sentiment F1 | 84.2% | 85.1% | +0.9% |
| NER F1 | 78.0% | 79.2% | +1.2% |
| Combined latency | 37ms | 18ms | -51% |
| Memory (INT8) | 198MB | 70MB | -65% |
| Monthly cost (compute) | $42 | $15 | -64% |
ROI Analysis
Memory savings: 128MB freed on GPU → can fit larger KV cache for LLM → supports 20% more concurrent users.
Latency savings: 19ms freed in pipeline → reduces P99 from 720ms to 701ms → better user experience.
Cost savings: $27/month × 12 = $324/year. Training cost: one-time $8 for MTL retraining on g5.xlarge.
Net annual benefit: $316 in compute + 20% more concurrency + positive transfer on sentiment (+0.9%) and NER (+1.2%).