
13. Multi-Task Learning — Single Model for Intent, Sentiment, and Entities

Problem Statement and MangaAssist Context

MangaAssist currently runs three separate DistilBERT models: intent classification (10 classes), sentiment analysis (multi-label), and a potential entity extraction module. Each model consumes 66MB (INT8) to 265MB (FP32) of memory and adds its own latency to the pipeline. Multi-task learning (MTL) trains a single shared encoder with multiple task-specific heads, potentially reducing total memory by 60-70% and enabling shared representations that improve all tasks simultaneously. The challenge: tasks can interfere with each other during training, degrading performance on some tasks while improving others.

Current vs MTL Architecture

| Metric | 3 Separate Models | Single MTL Model |
|---|---|---|
| Parameters | 3 × 66M = 198M | 66M + 3 heads ≈ 68M |
| Memory (INT8) | 3 × 66MB = 198MB | 70MB |
| Total inference time | 15ms + 12ms + 10ms = 37ms | 18ms (shared encoder + 3 heads) |
| Encoder forward passes | 3 | 1 |

Mathematical Foundations

Multi-Task Loss Formulation

Given $T$ tasks with individual losses $\mathcal{L}_1, \mathcal{L}_2, \ldots, \mathcal{L}_T$, the simplest MTL objective:

$$\mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} w_t \mathcal{L}_t$$

The weights $w_t$ determine task priority. Naive equal weighting ($w_t = 1/T$) rarely works because:

1. Loss magnitudes differ: cross-entropy for classification might be ~1.5 while token-level NER loss might be ~0.3
2. Convergence speeds differ: some tasks learn faster than others
3. Gradient magnitudes differ: dominant tasks suppress learning on the other tasks
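A quick numeric sketch (with hypothetical loss values on the scale mentioned above) shows how equal weighting still lets the largest loss dominate the objective:

```python
# Hypothetical per-task loss magnitudes (illustrative only)
losses = {"intent": 1.5, "sentiment": 0.9, "ner": 0.3}

# Naive equal weighting: every task gets w_t = 1/T
T = len(losses)
total = sum(l / T for l in losses.values())

# Fraction of the combined objective each task actually drives
shares = {name: (l / T) / total for name, l in losses.items()}
print(shares)  # intent contributes ~56% of the objective
```

Equal weights equalize nothing when the raw magnitudes differ by 5×; the balancing schemes below address exactly this.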

Uncertainty Weighting (Kendall et al., 2018)

Learns task weights automatically using homoscedastic uncertainty:

For each task $t$, introduce a learnable observation-noise scale $\sigma_t$. The weighted loss:

$$\mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} \left[\frac{1}{2\sigma_t^2} \mathcal{L}_t + \log \sigma_t\right]$$

Derivation: Start from the Gaussian likelihood for task $t$:

$$p(y_t | f_t(x)) = \frac{1}{\sqrt{2\pi}\sigma_t} \exp\left(-\frac{(y_t - f_t(x))^2}{2\sigma_t^2}\right)$$

Taking the negative log-likelihood:

$$-\log p = \frac{(y_t - f_t(x))^2}{2\sigma_t^2} + \log \sigma_t + \text{const}$$

The $\frac{1}{2\sigma_t^2}$ term acts as the weight: when a task has high uncertainty ($\sigma_t$ large), its loss contribution is downweighted. The $\log \sigma_t$ regularizer prevents all weights from going to infinity (which would trivially minimize the loss).

Gradient with respect to $\sigma_t$:

$$\frac{\partial \mathcal{L}}{\partial \sigma_t} = -\frac{\mathcal{L}_t}{\sigma_t^3} + \frac{1}{\sigma_t}$$

At equilibrium: $\sigma_t^2 = \mathcal{L}_t$. Tasks with higher loss (harder tasks) automatically get lower weight — the model doesn't waste capacity fighting noise in hard tasks.

In practice, we parameterize with $s_t = \log \sigma_t^2$:

$$\mathcal{L}_{\text{MTL}} = \sum_{t=1}^{T} \left[\frac{1}{2} e^{-s_t} \mathcal{L}_t + \frac{1}{2} s_t\right]$$

This is numerically stable ($e^{-s_t}$ avoids division by zero).
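A tiny numeric check (the fixed loss value is hypothetical) confirms the equilibrium $\sigma_t^2 = \mathcal{L}_t$ under this parameterization:

```python
import math

# Minimize 0.5*exp(-s)*L + 0.5*s over s = log(sigma^2) by gradient descent,
# holding the task loss L fixed (hypothetical value).
L = 1.45
s = 0.0  # start at sigma^2 = 1
for _ in range(2000):
    grad = -0.5 * math.exp(-s) * L + 0.5  # d/ds of the objective
    s -= 0.1 * grad

print(math.exp(s))  # ≈ 1.45: sigma^2 settles at the task loss, as derived above
```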

GradNorm (Chen et al., 2018)

Instead of weighting losses, GradNorm normalizes gradient magnitudes across tasks. The idea: if one task's gradients are 10× larger than another's, the first task dominates training regardless of loss weights.

Define the gradient norm for task $t$ at a shared layer:

$$G_t = \left\| \nabla_{W_{\text{shared}}} w_t \mathcal{L}_t \right\|_2$$

GradNorm wants all tasks to have similar gradient norms, scaled by their relative training speed:

$$\bar{G} = \mathbb{E}_t[G_t], \quad r_t = \frac{\tilde{\mathcal{L}}_t}{\mathbb{E}_t[\tilde{\mathcal{L}}_t]}$$

where $\tilde{\mathcal{L}}_t = \mathcal{L}_t(t) / \mathcal{L}_t(0)$ is the ratio of current loss to initial loss (training rate).

The target gradient norm for task $t$:

$$G_t^{\text{target}} = \bar{G} \cdot r_t^{\alpha}$$

where $\alpha$ controls how aggressively we balance tasks ($\alpha = 1.5$ is typical). Tasks that are training slowly ($r_t$ large) get larger target gradient norms.

Update rule for task weights:

$$\mathcal{L}_{\text{grad}} = \sum_t |G_t - G_t^{\text{target}}|_1$$

$$w_t \leftarrow w_t - \eta_w \frac{\partial \mathcal{L}_{\text{grad}}}{\partial w_t}$$
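Tracing the target computation with hypothetical gradient norms and loss ratios (illustrative numbers, not measurements):

```python
# Hypothetical per-task gradient norms at the shared layer and
# training-rate ratios L_t(t)/L_t(0); alpha as in the text.
alpha = 1.5
grad_norms = [2.5, 0.3, 1.1]      # intent, sentiment, NER
loss_ratios = [0.18, 0.73, 0.54]  # small ratio = task trained fast

g_bar = sum(grad_norms) / len(grad_norms)    # average gradient norm
r_bar = sum(loss_ratios) / len(loss_ratios)
relative = [r / r_bar for r in loss_ratios]  # r_t in the text

targets = [g_bar * r ** alpha for r in relative]
# The fast task's target falls below its current norm; the slow task's rises
print([round(t, 2) for t in targets])
```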

Gradient Surgery — PCGrad (Yu et al., 2020)

When two tasks have conflicting gradients (their gradient vectors point in opposite directions), standard MTL forces a compromise that hurts both. PCGrad (Projecting Conflicting Gradients) detects conflicts and removes the conflicting component:

For tasks $i$ and $j$ with gradients $g_i$ and $g_j$:

$$\text{If } g_i \cdot g_j < 0 \text{ (conflict):} \quad g_i' = g_i - \frac{g_i \cdot g_j}{\|g_j\|^2} g_j$$

This projects $g_i$ onto the plane perpendicular to $g_j$, removing the conflicting component while preserving the non-conflicting direction.

Geometric interpretation: Each task gradient lives in a high-dimensional parameter space. When two tasks conflict, their gradients form an obtuse angle. PCGrad projects each onto the other's normal plane, finding a direction that helps (or at least doesn't hurt) both tasks.

For our 3 tasks (intent, sentiment, NER):

$$\text{Conflict rate} \approx 15\text{–}25\%$$

This means 15-25% of mini-batches have at least one pair of conflicting gradients. Without PCGrad, these conflicts cause oscillations in training and degrade the weaker task.
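The projection step itself is tiny; a toy 2-D example (pure Python, hypothetical gradient vectors) shows the conflicting component being removed:

```python
def project_out(g_i, g_j):
    """PCGrad projection: strip from g_i its component along g_j on conflict."""
    dot = sum(a * b for a, b in zip(g_i, g_j))
    if dot >= 0:
        return list(g_i)  # no conflict: gradient kept as-is
    norm_sq = sum(b * b for b in g_j)
    return [a - (dot / norm_sq) * b for a, b in zip(g_i, g_j)]

g_i, g_j = [1.0, 1.0], [1.0, -2.0]  # dot product = -1 → conflict
g_i_new = project_out(g_i, g_j)      # ≈ [1.2, 0.6]

# The projected gradient is orthogonal to g_j: it no longer fights that task
print(sum(a * b for a, b in zip(g_i_new, g_j)))  # ≈ 0
```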

Task Affinity and Negative Transfer

Not all task combinations are beneficial. Task affinity measures whether co-training helps:

$$A_{i \to j} = \frac{\text{Performance of } j \text{ when co-trained with } i}{\text{Performance of } j \text{ alone}} - 1$$

If $A_{i \to j} > 0$: task $i$ helps task $j$ (positive transfer). If $A_{i \to j} < 0$: task $i$ hurts task $j$ (negative transfer).
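As a one-line sketch (the metric values below are hypothetical):

```python
def affinity(perf_cotrained: float, perf_alone: float) -> float:
    """A_{i->j}: relative change in task j's metric when co-trained with i."""
    return perf_cotrained / perf_alone - 1

# Hypothetical sentiment F1: 0.851 co-trained with intent vs 0.842 alone
print(f"{affinity(0.851, 0.842):+.1%}")  # +1.1% → positive transfer
```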

For MangaAssist:

| Task Pair | Affinity | Direction |
|---|---|---|
| Intent → Sentiment | +2.1% | Positive: intent features help sentiment |
| Sentiment → Intent | +1.4% | Positive: emotional context aids intent |
| NER → Intent | +0.8% | Weak positive |
| Intent → NER | -0.3% | Weak negative: intent task dominates |

Model Internals — Layer-by-Layer Diagrams

Multi-Task Architecture

graph TB
    subgraph "Shared Encoder (DistilBERT, 66M params)"
        INPUT["Input: 'I want to return this damaged manga volume'"]
        TOK["Tokenizer → [CLS] I want to return this damaged manga volume [SEP]"]
        EMB["Embedding Layer (shared)"]
        L1["Transformer Layer 1 (shared)"]
        L2["Transformer Layer 2 (shared)"]
        L3["Transformer Layer 3 (shared)"]
        L4["Transformer Layer 4 (shared)"]
        L5["Transformer Layer 5 (shared)"]
        L6["Transformer Layer 6 (shared)"]

        INPUT --> TOK --> EMB --> L1 --> L2 --> L3 --> L4 --> L5 --> L6
    end

    subgraph "Task-Specific Heads (2M params total)"
        CLS["[CLS] token → 768-dim"]
        SEQ["All tokens → 768-dim each"]

        L6 --> CLS & SEQ

        H_INT["Intent Head<br>Linear(768→256→10)<br>Softmax<br>Prediction: return_refund"]
        H_SENT["Sentiment Head<br>Linear(768→256→6)<br>Sigmoid per label<br>Pred: frustration=0.82"]
        H_NER["NER Head<br>Linear(768→256→9)<br>Per-token softmax<br>Pred: manga=PRODUCT"]

        CLS --> H_INT & H_SENT
        SEQ --> H_NER
    end

    style L1 fill:#e3f2fd
    style L2 fill:#e3f2fd
    style L3 fill:#e3f2fd
    style L4 fill:#e3f2fd
    style L5 fill:#e3f2fd
    style L6 fill:#e3f2fd
    style H_INT fill:#c8e6c9
    style H_SENT fill:#fff9c4
    style H_NER fill:#ffccbc

Gradient Flow and Conflict Detection

graph TB
    subgraph "Gradient Conflict in Shared Layers"
        LOSS_I["Intent Loss = 0.42<br>∇W₃(intent)"]
        LOSS_S["Sentiment Loss = 0.68<br>∇W₃(sentiment)"]
        LOSS_N["NER Loss = 0.31<br>∇W₃(NER)"]

        CHECK["Conflict Detection:<br>g_intent · g_sentiment = -0.15 ❌ CONFLICT<br>g_intent · g_ner = +0.34 ✅ aligned<br>g_sentiment · g_ner = +0.08 ✅ aligned"]

        LOSS_I --> CHECK
        LOSS_S --> CHECK
        LOSS_N --> CHECK

        PCGRAD["PCGrad Resolution:<br>g'_intent = g_intent - proj(g_intent, g_sentiment)<br>Remove conflicting component<br><br>g'_sentiment = g_sentiment - proj(g_sentiment, g_intent)<br>Remove conflicting component<br><br>g_ner unchanged (no conflicts)"]

        CHECK --> PCGRAD

        FINAL["Final gradient for Layer 3:<br>g_shared = g'_intent + g'_sentiment + g_ner<br>No task dominance, no destructive interference"]

        PCGRAD --> FINAL
    end

    style CHECK fill:#ffcdd2
    style PCGRAD fill:#c8e6c9
    style FINAL fill:#c8e6c9

Uncertainty Weighting Dynamics

graph LR
    subgraph "Epoch 1: Equal weights"
        E1_I["Intent<br>w=0.33<br>Loss=2.30<br>σ²=2.30"]
        E1_S["Sentiment<br>w=0.33<br>Loss=0.85<br>σ²=0.85"]
        E1_N["NER<br>w=0.33<br>Loss=1.45<br>σ²=1.45"]
    end

    subgraph "Epoch 5: Weights adapted"
        E5_I["Intent<br>w=0.27<br>Loss=0.95<br>Loss fell fastest"]
        E5_S["Sentiment<br>w=0.41 ⬆<br>Loss=0.62<br>Lowest loss → highest weight"]
        E5_N["NER<br>w=0.32<br>Loss=0.78<br>Still learning"]
    end

    subgraph "Epoch 15: Converged"
        E15_I["Intent<br>w=0.25<br>Loss=0.42<br>Converged"]
        E15_S["Sentiment<br>w=0.35<br>Loss=0.38<br>Near-converged"]
        E15_N["NER<br>w=0.40<br>Loss=0.31<br>Converged"]
    end

    E1_I -->|"σ adapts"| E5_I -->|"stabilizes"| E15_I
    E1_S -->|"σ adapts"| E5_S -->|"stabilizes"| E15_S
    E1_N -->|"σ adapts"| E5_N -->|"stabilizes"| E15_N

GradNorm Balancing

graph TB
    subgraph "GradNorm: Normalize gradient magnitudes"
        G1["Without GradNorm:<br>‖∇intent‖ = 2.5<br>‖∇sentiment‖ = 0.3<br>‖∇NER‖ = 1.1<br><br>Intent dominates 8.3×<br>over sentiment!"]

        GN["GradNorm target (α=1.5):<br>r_intent = L(t)/L(0) = 0.18 (fast)<br>r_sentiment = 0.73 (slow)<br>r_NER = 0.54<br><br>Target ‖∇intent‖ ≈ 0.30<br>Target ‖∇sentiment‖ ≈ 2.41<br>Target ‖∇NER‖ ≈ 1.54"]

        G2["With GradNorm:<br>w_intent adjusted down → ‖g‖ ≈ 0.3<br>w_sentiment adjusted up → ‖g‖ ≈ 2.4<br>w_NER adjusted up → ‖g‖ ≈ 1.5<br><br>Balanced! Sentiment gets<br>more gradient signal"]

        G1 --> GN --> G2
    end

    style G1 fill:#ffcdd2
    style G2 fill:#c8e6c9

Task Head Architecture Detail

graph TB
    subgraph "Intent Head (Classification)"
        CLS1["[CLS] embedding: 768-dim"]
        INT_D["Dropout(0.1)"]
        INT_L1["Linear(768→256) + GELU"]
        INT_LN["LayerNorm(256)"]
        INT_L2["Linear(256→10)"]
        INT_SF["Softmax → 10 intent probabilities"]

        CLS1 --> INT_D --> INT_L1 --> INT_LN --> INT_L2 --> INT_SF
    end

    subgraph "Sentiment Head (Multi-Label)"
        CLS2["[CLS] embedding: 768-dim"]
        SENT_D["Dropout(0.1)"]
        SENT_L1["Linear(768→256) + GELU"]
        SENT_LN["LayerNorm(256)"]
        SENT_L2["Linear(256→6)"]
        SENT_SIG["Sigmoid per label → 6 independent probabilities"]

        CLS2 --> SENT_D --> SENT_L1 --> SENT_LN --> SENT_L2 --> SENT_SIG
    end

    subgraph "NER Head (Token Classification)"
        TOK["All token embeddings: [n×768]"]
        NER_D["Dropout(0.1)"]
        NER_L1["Linear(768→256) + GELU"]
        NER_LN["LayerNorm(256)"]
        NER_L2["Linear(256→9)"]
        NER_SF["Per-token Softmax → BIO tags"]

        TOK --> NER_D --> NER_L1 --> NER_LN --> NER_L2 --> NER_SF
    end

    style INT_SF fill:#c8e6c9
    style SENT_SIG fill:#fff9c4
    style NER_SF fill:#ffccbc

Implementation Deep-Dive

Multi-Task Model

import torch
import torch.nn as nn
from transformers import DistilBertModel


class MultiTaskHead(nn.Module):
    """A single task head with configurable output."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.head(x)


class MultiTaskDistilBERT(nn.Module):
    """
    Single DistilBERT encoder with three task-specific heads.
    Replaces 3 separate models → 65% memory reduction.
    """

    def __init__(
        self,
        num_intents: int = 10,
        num_sentiments: int = 6,
        num_ner_tags: int = 9,
        model_name: str = "distilbert-base-uncased",
    ):
        super().__init__()
        self.encoder = DistilBertModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size  # 768

        self.intent_head = MultiTaskHead(hidden, 256, num_intents)
        self.sentiment_head = MultiTaskHead(hidden, 256, num_sentiments)
        self.ner_head = MultiTaskHead(hidden, 256, num_ner_tags)

    def forward(self, input_ids, attention_mask, task: str = "all"):
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_output.last_hidden_state  # [batch, seq_len, 768]
        cls_output = hidden_states[:, 0, :]  # [batch, 768]

        results = {}
        if task in ("all", "intent"):
            results["intent_logits"] = self.intent_head(cls_output)
        if task in ("all", "sentiment"):
            results["sentiment_logits"] = self.sentiment_head(cls_output)
        if task in ("all", "ner"):
            results["ner_logits"] = self.ner_head(hidden_states)

        return results
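A quick standalone shape check of the head design (the head is re-declared locally so the snippet runs without the full model; batch and sequence sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Same layout as MultiTaskHead above: Dropout → Linear → GELU → LayerNorm → Linear
head = nn.Sequential(
    nn.Dropout(0.1),
    nn.Linear(768, 256),
    nn.GELU(),
    nn.LayerNorm(256),
    nn.Linear(256, 10),
)

cls_vec = torch.randn(4, 768)     # [batch, hidden], e.g. the [CLS] embedding
tokens = torch.randn(4, 32, 768)  # [batch, seq_len, hidden] for token-level tasks

# nn.Linear acts on the last dimension, so the same head handles both shapes
print(head(cls_vec).shape)  # torch.Size([4, 10])
print(head(tokens).shape)   # torch.Size([4, 32, 10])
```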

Uncertainty-Weighted Loss

class UncertaintyWeightedLoss(nn.Module):
    """
    Kendall et al. 2018: automatically learn task weights
    from homoscedastic uncertainty.
    """

    def __init__(self, num_tasks: int = 3):
        super().__init__()
        # Log-variance parameters (learnable)
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: list[torch.Tensor]) -> tuple[torch.Tensor, dict]:
        total_loss = 0
        weights = {}

        for i, loss in enumerate(losses):
            # L_total = (1/2σ²) * L_t + log(σ)
            # Using log_var = log(σ²):
            precision = torch.exp(-self.log_vars[i])  # 1/σ²
            total_loss += 0.5 * precision * loss + 0.5 * self.log_vars[i]
            weights[f"task_{i}_weight"] = precision.item()
            weights[f"task_{i}_sigma"] = torch.exp(0.5 * self.log_vars[i]).item()

        return total_loss, weights

PCGrad Implementation


class PCGrad:
    """
    Yu et al. 2020: Project Conflicting Gradients.
    Removes conflicting gradient components between tasks.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, losses: list[torch.Tensor], shared_params: list[nn.Parameter]):
        """
        Compute per-task gradients, resolve conflicts, then update.
        """
        task_grads = []

        for loss in losses:
            self.optimizer.zero_grad()
            loss.backward(retain_graph=True)
            grads = [p.grad.clone() if p.grad is not None else torch.zeros_like(p)
                     for p in shared_params]
            task_grads.append(grads)

        # Resolve conflicts via projection
        num_tasks = len(losses)
        projected_grads = [list(g) for g in task_grads]  # per-task copies (entries reassigned, never mutated in place)

        for i in range(num_tasks):
            for j in range(num_tasks):
                if i == j:
                    continue
                for k in range(len(shared_params)):
                    g_i = projected_grads[i][k]
                    g_j = task_grads[j][k]

                    dot = (g_i * g_j).sum()
                    if dot < 0:
                        # Conflict detected: project g_i onto plane perpendicular to g_j
                        projected_grads[i][k] = g_i - (dot / (g_j.norm() ** 2 + 1e-8)) * g_j

        # Sum projected gradients and apply
        self.optimizer.zero_grad()
        for k, param in enumerate(shared_params):
            param.grad = sum(projected_grads[i][k] for i in range(num_tasks))

        self.optimizer.step()


class GradNormLoss(nn.Module):
    """
    Chen et al. 2018: Dynamically balance gradient norms.
    """

    def __init__(self, num_tasks: int = 3, alpha: float = 1.5):
        super().__init__()
        self.num_tasks = num_tasks
        self.alpha = alpha
        self.task_weights = nn.Parameter(torch.ones(num_tasks))
        self.initial_losses = None

    def forward(
        self,
        losses: list[torch.Tensor],
        shared_layer: nn.Module,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            weighted_loss: The task-weighted MTL loss
            gradnorm_loss: The GradNorm regularization loss
        """
        # Record initial losses for training rate computation
        if self.initial_losses is None:
            self.initial_losses = [l.item() for l in losses]

        # Weighted task loss. The weights are detached here: they should be
        # trained only by the GradNorm loss below, otherwise this term pushes
        # them toward zero.
        weighted_loss = sum(
            self.task_weights[i].detach() * losses[i] for i in range(self.num_tasks)
        )

        # Compute gradient norms for shared layer
        grad_norms = []
        for i in range(self.num_tasks):
            gw = torch.autograd.grad(
                self.task_weights[i] * losses[i],
                shared_layer.parameters(),
                retain_graph=True,
                create_graph=True,
            )
            grad_norms.append(torch.norm(torch.cat([g.flatten() for g in gw])))

        # Average gradient norm
        avg_norm = sum(grad_norms) / self.num_tasks

        # Training rate ratios
        loss_ratios = [losses[i].item() / (self.initial_losses[i] + 1e-8)
                       for i in range(self.num_tasks)]
        avg_ratio = sum(loss_ratios) / self.num_tasks
        relative_rates = [r / (avg_ratio + 1e-8) for r in loss_ratios]

        # Target gradient norms
        target_norms = [avg_norm * (r ** self.alpha) for r in relative_rates]

        # GradNorm loss
        gradnorm_loss = sum(
            torch.abs(grad_norms[i] - target_norms[i].detach())
            for i in range(self.num_tasks)
        )

        return weighted_loss, gradnorm_loss

Multi-Task Trainer

class MultiTaskTrainer:
    """End-to-end MTL trainer with configurable loss balancing."""

    def __init__(
        self,
        model: MultiTaskDistilBERT,
        strategy: str = "uncertainty",  # "uncertainty", "pcgrad", "gradnorm"
    ):
        self.model = model
        self.strategy = strategy

        if strategy == "uncertainty":
            self.loss_module = UncertaintyWeightedLoss(num_tasks=3)
        elif strategy == "gradnorm":
            self.loss_module = GradNormLoss(num_tasks=3, alpha=1.5)

        all_params = list(model.parameters())
        if strategy == "uncertainty":
            all_params += list(self.loss_module.parameters())
        elif strategy == "gradnorm":
            all_params += list(self.loss_module.parameters())

        self.optimizer = torch.optim.AdamW(all_params, lr=2e-5, weight_decay=0.01)

        if strategy == "pcgrad":
            self.pcgrad = PCGrad(self.optimizer)

        self.intent_loss_fn = nn.CrossEntropyLoss()
        self.sentiment_loss_fn = nn.BCEWithLogitsLoss()
        # ignore_index=-100 skips padding/special tokens, assuming labels
        # follow the HuggingFace convention of -100 at ignored positions
        self.ner_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    def train_step(self, batch: dict) -> dict:
        self.model.train()
        outputs = self.model(
            batch["input_ids"], batch["attention_mask"], task="all",
        )

        # Compute per-task losses
        intent_loss = self.intent_loss_fn(outputs["intent_logits"], batch["intent_labels"])
        sentiment_loss = self.sentiment_loss_fn(
            outputs["sentiment_logits"], batch["sentiment_labels"].float(),
        )
        ner_logits = outputs["ner_logits"]
        ner_loss = self.ner_loss_fn(
            # infer the tag count from the logits instead of hardcoding 9
            ner_logits.view(-1, ner_logits.size(-1)), batch["ner_labels"].view(-1),
        )

        losses = [intent_loss, sentiment_loss, ner_loss]

        if self.strategy == "uncertainty":
            total_loss, weight_info = self.loss_module(losses)
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            return {"loss": total_loss.item(), **weight_info}

        elif self.strategy == "pcgrad":
            # PCGrad applies projected gradients to the shared encoder
            shared_params = list(self.model.encoder.parameters())
            self.pcgrad.step(losses, shared_params)
            # Update the task heads with their plain gradients; clear the
            # encoder gradients first so it is not stepped a second time
            # with unprojected gradients
            head_loss = sum(losses)
            self.optimizer.zero_grad()
            head_loss.backward()
            for p in shared_params:
                p.grad = None
            self.optimizer.step()
            return {"loss": sum(l.item() for l in losses)}

        elif self.strategy == "gradnorm":
            shared_layer = self.model.encoder.transformer.layer[-1]
            weighted_loss, gn_loss = self.loss_module(losses, shared_layer)
            self.optimizer.zero_grad()
            (weighted_loss + gn_loss).backward()
            self.optimizer.step()
            # Renormalize task weights
            with torch.no_grad():
                self.loss_module.task_weights.data = (
                    self.loss_module.task_weights / self.loss_module.task_weights.sum() * 3
                )
            return {"loss": weighted_loss.item(), "gn_loss": gn_loss.item()}

    def evaluate(self, val_loader) -> dict:
        self.model.eval()
        intent_correct = intent_total = 0
        sentiment_f1_samples = []
        # NER span-level F1 (e.g. via seqeval) is omitted in this simplified loop

        with torch.no_grad():
            for batch in val_loader:
                outputs = self.model(
                    batch["input_ids"], batch["attention_mask"], task="all",
                )

                # Intent accuracy
                preds = outputs["intent_logits"].argmax(dim=-1)
                intent_correct += (preds == batch["intent_labels"]).sum().item()
                intent_total += batch["intent_labels"].size(0)

                # Sentiment F1 (per-label threshold)
                sent_preds = (torch.sigmoid(outputs["sentiment_logits"]) > 0.5).float()
                # Simplified per-sample F1
                for i in range(sent_preds.size(0)):
                    tp = (sent_preds[i] * batch["sentiment_labels"][i]).sum()
                    fp = (sent_preds[i] * (1 - batch["sentiment_labels"][i])).sum()
                    fn = ((1 - sent_preds[i]) * batch["sentiment_labels"][i]).sum()
                    prec = tp / (tp + fp + 1e-8)
                    rec = tp / (tp + fn + 1e-8)
                    f1 = 2 * prec * rec / (prec + rec + 1e-8)
                    sentiment_f1_samples.append(f1.item())

        return {
            "intent_accuracy": intent_correct / intent_total,
            "sentiment_f1": sum(sentiment_f1_samples) / len(sentiment_f1_samples),
        }
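The per-sample F1 arithmetic in the loop above can be checked by hand on a single hypothetical row:

```python
# One hypothetical multi-label row: thresholded predictions vs gold labels
pred = [1, 0, 1, 0, 0, 0]
gold = [1, 1, 0, 0, 0, 0]

tp = sum(p * g for p, g in zip(pred, gold))        # 1 (label 0)
fp = sum(p * (1 - g) for p, g in zip(pred, gold))  # 1 (label 2)
fn = sum((1 - p) * g for p, g in zip(pred, gold))  # 1 (label 1)

prec = tp / (tp + fp)
rec = tp / (tp + fn)
f1 = 2 * prec * rec / (prec + rec)
print(f1)  # 0.5
```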

Group Discussion: Key Decision Points

Decision Point 1: Loss Balancing Strategy

Priya (ML Engineer): Compared all three strategies on MangaAssist data:

| Strategy | Intent Acc | Sentiment F1 | NER F1 | Training Stability |
|---|---|---|---|---|
| Equal weights (baseline) | 90.1% | 82.3% | 76.5% | Oscillatory |
| Uncertainty weighting | 91.8% | 85.1% | 79.2% | Smooth |
| GradNorm (α=1.5) | 91.5% | 84.8% | 80.1% | Smooth |
| PCGrad | 92.0% | 85.4% | 78.8% | Slightly noisy |
Separate models (reference): Intent 92.1%, Sentiment 84.2%, NER 78.0%

Aiko (Data Scientist): PCGrad gives the best intent and sentiment scores but NER drops compared to GradNorm. Uncertainty weighting is the best all-around: consistent improvement across all tasks with the simplest implementation.

Marcus (Architect): PCGrad adds computational overhead: we need to compute per-task gradients separately (3 backward passes) and then project. That triples the backward pass time.

Jordan (MLOps): Uncertainty weighting adds only 3 parameters ($s_1, s_2, s_3$) and negligible computational overhead. It self-tunes during training. I strongly prefer it for production simplicity.

Resolution: Uncertainty weighting as the default. PCGrad as an option for tasks where we need to maximize a specific metric (e.g., intent accuracy) at the expense of others.

Decision Point 2: MTL vs Separate Models

Marcus (Architect): The fundamental trade-off:

| Aspect | Separate Models | Single MTL |
|---|---|---|
| Memory | 198MB (INT8) | 70MB (INT8) |
| Latency | 37ms (3 passes) | 18ms (1 pass) |
| Intent accuracy | 92.1% | 91.8% |
| Sentiment F1 | 84.2% | 85.1% |
| NER F1 | 78.0% | 79.2% |
| Training complexity | Simple | Moderate |
| Independent deployment | ✅ | ❌ (coupled) |

Sam (PM): MTL improves sentiment and NER while only losing 0.3% on intent. That's positive transfer — the shared encoder learns features useful for all tasks.

Priya (ML Engineer): The sentiment improvement (+0.9% F1) makes sense: understanding intent ("return") helps predict sentiment ("frustrated"). The NER improvement (+1.2% F1) comes from context learned across tasks.

Jordan (MLOps): My concern is coupling: if we need to retrain sentiment due to new labels, we have to retrain the entire MTL model. With separate models, we only retrain sentiment.

Marcus (Architect): Mitigation: freeze the shared encoder and only retrain the specific head that changed. This preserves the shared representations while allowing independent head updates.
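A minimal sketch of that mitigation, assuming the model exposes its shared encoder as `model.encoder` (as in the `MultiTaskDistilBERT` implementation above):

```python
import torch.nn as nn

def freeze_encoder(model: nn.Module) -> None:
    """Freeze the shared encoder so only task heads receive gradient updates."""
    for param in model.encoder.parameters():
        param.requires_grad = False
    model.encoder.eval()  # also disables dropout inside the frozen encoder

# After freezing, build the optimizer over the trainable parameters only, e.g.:
# optimizer = torch.optim.AdamW(
#     (p for p in model.parameters() if p.requires_grad), lr=2e-5
# )
```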

Resolution: Deploy MTL model for production. Keep the ability to freeze the encoder and retrain individual heads. Maintain separate models as a benchmark reference and fallback if MTL causes unexpected regressions.

Decision Point 3: Which Layers to Share

Priya (ML Engineer): Not all tasks benefit from sharing all layers:

| Sharing Config | Intent | Sentiment | NER | Total Params |
|---|---|---|---|---|
| Share all 6 layers | 91.8% | 85.1% | 79.2% | 68M |
| Share layers 1-4, split 5-6 | 92.2% | 84.6% | 80.5% | 82M |
| Share layers 1-2, split 3-6 | 92.0% | 83.8% | 81.0% | 110M |
| No sharing (separate) | 92.1% | 84.2% | 78.0% | 198M |

Aiko (Data Scientist): Sharing all 6 layers is the sweet spot for our case. The partial sharing configurations give marginal improvements on some tasks but increase parameters. The 82M model with partial sharing doesn't justify the 20% parameter increase for a 0.4% intent improvement.

Resolution: Share all encoder layers. The memory efficiency (68M vs 198M) and latency benefit (1 forward pass vs 3) dominate over the marginal quality differences.


Research Paper References

1. Multi-Task Learning Using Uncertainty to Weigh Losses (Kendall et al., 2018)

Key contribution: Introduced homoscedastic uncertainty weighting for MTL — task weights are derived from the task-specific noise variance, which is learned jointly with the model parameters. This eliminates the need to manually tune task weights and automatically adapts during training. The paper proved this approach on computer vision (semantic segmentation + depth estimation + instance segmentation).

Relevance to MangaAssist: Uncertainty weighting is our primary balancing strategy. It correctly downweights the noisy NER task during early training and upweights it as the shared representations improve. The automatic weight adaptation saves us from extensive hyperparameter searches.

2. GradNorm: Gradient Normalization for Adaptive Loss Balancing (Chen et al., 2018)

Key contribution: Normalized gradient magnitudes across tasks by dynamically adjusting task weights. Introduced the asymmetry parameter $\alpha$ that controls how aggressively the algorithm balances tasks based on their training speed. Higher $\alpha$ forces more equal convergence rates across tasks.

Relevance to MangaAssist: GradNorm serves as our secondary balancing strategy. It is particularly useful when tasks have very different convergence speeds — which happens when we add a new task to an existing MTL model and need to catch up the new task without degrading established ones.

3. Gradient Surgery for Multi-Task Learning (Yu et al., 2020)

Key contribution: Identified that conflicting gradients (negative cosine similarity) are a primary cause of negative transfer in MTL. PCGrad projects conflicting gradients onto non-conflicting directions, provably reducing the variance of the combined gradient. The paper showed improvements across multi-task RL, NLP, and vision benchmarks.

Relevance to MangaAssist: PCGrad resolves the 15-25% of conflicting batches in our training. While we don't use it by default (computational overhead), understanding gradient conflicts helped us diagnose why our initial MTL model underperformed on NER — the intent gradient was dominating the shared layers.


Production Results

MTL vs Separate Models

| Metric | Separate Models | MTL (Uncertainty) | Change |
|---|---|---|---|
| Intent accuracy | 92.1% | 91.8% | -0.3% |
| Sentiment F1 | 84.2% | 85.1% | +0.9% |
| NER F1 | 78.0% | 79.2% | +1.2% |
| Combined latency | 37ms | 18ms | -51% |
| Memory (INT8) | 198MB | 70MB | -65% |
| Monthly cost (compute) | $42 | $15 | -64% |

ROI Analysis

Memory savings: 128MB freed on GPU → can fit larger KV cache for LLM → supports 20% more concurrent users.

Latency savings: 19ms freed in pipeline → reduces P99 from 720ms to 701ms → better user experience.

Cost savings: $27/month × 12 = $324/year. Training cost: one-time $8 for MTL retraining on g5.xlarge.

Net annual benefit: $316 in compute + 20% more concurrency + positive transfer on sentiment (+0.9%) and NER (+1.2%).