
12. Quantization-Aware Training — INT8/INT4 Without Quality Loss

Problem Statement and MangaAssist Context

MangaAssist runs five models in its inference pipeline. At FP32 precision, the combined memory footprint would exceed available GPU/CPU budgets. Naive post-training quantization (PTQ) simply rounds weights after training, which degrades quality — especially for smaller models where each weight carries more information. Quantization-aware training (QAT) simulates quantization during training so the model learns weight values that are robust to rounding. Modern PTQ techniques like GPTQ and AWQ go further: they use second-order information (the Hessian) and activation-aware scaling to minimize quantization error without full retraining.

Memory Budget Pressure

| Model | FP32 Size | FP16 Size | INT8 Size | INT4 Size |
|---|---|---|---|---|
| Llama 3 70B | 280GB | 140GB | 70GB | 35GB |
| Llama 3 8B (distilled) | 32GB | 16GB | 8GB | 4GB |
| DistilBERT (intent) | 265MB | 132MB | 66MB | 33MB |
| MiniLM-L6 (reranker) | 90MB | 45MB | 22MB | 11MB |
| DistilBERT (sentiment) | 265MB | 132MB | 66MB | 33MB |

Goal: Run the full pipeline on a single g5.xlarge (1× A10G, 24GB VRAM) instead of g5.2xlarge (1× A10G, 24GB VRAM but more system RAM). The LLM fallback (Llama 3 8B) must fit in 4-8GB to leave headroom for the smaller models and KV cache.


Mathematical Foundations

Uniform Quantization

Map a floating-point weight $w \in [w_{\min}, w_{\max}]$ to a $b$-bit integer: $q \in [0, 2^b - 1]$ for asymmetric (unsigned) quantization, or $q \in [-(2^{b-1} - 1), 2^{b-1} - 1]$ for symmetric (signed) quantization:

Symmetric quantization (zero maps to zero):

$$q = \text{round}\left(\frac{w}{s}\right), \quad s = \frac{\max(|w_{\min}|, |w_{\max}|)}{2^{b-1} - 1}$$

$$\hat{w} = q \cdot s$$

Asymmetric quantization (allows asymmetric weight distributions):

$$q = \text{round}\left(\frac{w - z}{s}\right), \quad s = \frac{w_{\max} - w_{\min}}{2^b - 1}, \quad z = w_{\min}$$

$$\hat{w} = q \cdot s + z$$

Quantization error for a single weight:

$$\epsilon = w - \hat{w} = w - \text{round}\left(\frac{w}{s}\right) \cdot s$$

The maximum error per weight is $\epsilon_{\max} = \frac{s}{2}$. For INT8 symmetric quantization of weights in $[-1, 1]$:

$$s = \frac{1}{127} \approx 0.0079, \quad \epsilon_{\max} \approx 0.004$$

For INT4: $s = \frac{1}{7} \approx 0.143, \quad \epsilon_{\max} \approx 0.071$ — 18× larger error per weight. This is why INT4 requires more sophisticated techniques.
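
A minimal NumPy sketch of the symmetric scheme above, applied to the weight row used in the diagrams below; quantize_symmetric and dequantize are illustrative helpers, not a library API.

# Minimal sketch of symmetric uniform quantization (illustrative, not a library API).
import numpy as np

def quantize_symmetric(w: np.ndarray, bits: int):
    # Scale so the largest magnitude maps to the top of the signed integer range.
    qmax = 2 ** (bits - 1) - 1            # 127 for INT8, 7 for INT4
    s = np.abs(w).max() / qmax
    q = np.clip(np.round(w / s), -qmax, qmax).astype(np.int32)
    return q, s

def dequantize(q: np.ndarray, s: float) -> np.ndarray:
    return q.astype(np.float32) * s

w = np.array([-0.82, 0.15, -0.03, 0.67, 1.24, -0.56, 0.91, -1.10], dtype=np.float32)
for bits in (8, 4):
    q, s = quantize_symmetric(w, bits)
    err = np.abs(w - dequantize(q, s))
    print(f"INT{bits}: scale={s:.4f}, max error={err.max():.4f}")
# Expect max error ≤ s/2: roughly 0.005 for INT8 vs roughly 0.09 for INT4 on this row.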

Straight-Through Estimator (STE)

The rounding function $\text{round}(\cdot)$ has zero gradient almost everywhere. QAT uses the straight-through estimator to bypass this:

Forward pass: Apply quantization (round)

$$\hat{w} = \text{round}\left(\frac{w}{s}\right) \cdot s$$

Backward pass: Pretend quantization didn't happen (pass gradient through)

$$\frac{\partial \mathcal{L}}{\partial w} \approx \frac{\partial \mathcal{L}}{\partial \hat{w}}$$

More precisely, the STE gradient is:

$$\frac{\partial \hat{w}}{\partial w} = \begin{cases} 1 & \text{if } w_{\min} \leq w \leq w_{\max} \\ 0 & \text{otherwise (clamped)} \end{cases}$$

Why this works: The STE is biased but low-variance. During training, the model adjusts its weight values to positions where rounding introduces minimal error. Weights migrate toward "quantization-friendly" values that are near integer multiples of the scale factor $s$.
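
A minimal PyTorch sketch of fake quantization with an STE backward pass; FakeQuantSTE is an illustrative custom autograd function, whereas production QAT relies on torch.ao.quantization's built-in fake-quant observers (used in the implementation section below).

# Sketch of a fake-quantize op with a straight-through estimator (illustrative only).
import torch

class FakeQuantSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, w, scale, qmax):
        ctx.save_for_backward(w)
        ctx.clip = scale * qmax
        q = torch.clamp(torch.round(w / scale), -qmax, qmax)
        return q * scale                      # forward: rounded-and-dequantized weights

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # STE: pass the gradient through unchanged inside the clipping range, zero outside.
        mask = (w.abs() <= ctx.clip).to(grad_output.dtype)
        return grad_output * mask, None, None

w = torch.randn(4, requires_grad=True)
scale = w.abs().max().item() / 127            # INT8 symmetric scale
loss = FakeQuantSTE.apply(w, scale, 127).sum()
loss.backward()                               # w.grad is 1 wherever w lies inside the range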

GPTQ — Post-Training Quantization with Hessian Information (Frantar et al., 2022)

GPTQ quantizes weights column-by-column using second-order error correction. For a weight matrix $W$ and a calibration set producing activation matrix $X$:

The quantization objective per row:

$$\min_{\hat{W}} \| WX - \hat{W}X \|_2^2$$

This is an optimal brain surgeon (OBS) problem. The Hessian of this reconstruction error with respect to one row of $W$ (identical for every row) is:

$$H = 2X X^T$$

GPTQ quantizes columns sequentially. After quantizing column $j$:

  1. Compute the quantization error $\delta_j = w_j - \hat{w}_j$
  2. Compensate remaining columns to minimize total error:

$$w_k \leftarrow w_k - \delta_j \cdot \frac{[H^{-1}]_{jk}}{[H^{-1}]_{jj}}, \quad \forall k > j$$

This error compensation is the key insight: the remaining unquantized columns absorb the error from quantizing column $j$, and the inverse Hessian tells us how to redistribute that error optimally.

Computational trick: GPTQ avoids repeatedly updating the inverse Hessian after every column by precomputing its Cholesky decomposition once, and it processes weights in lazy blocks of 128 columns for speed and numerical stability.
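
A toy sketch of the sequential quantize-and-compensate loop for one weight row, assuming a NumPy environment; it inverts the Hessian directly and uses a single row scale for clarity, whereas real GPTQ works from a Cholesky factor of $H^{-1}$ with per-group scales.

# Toy GPTQ-style sequential quantization with error compensation (illustrative only).
import numpy as np

def quantize_row_gptq_style(w: np.ndarray, X: np.ndarray, bits: int = 4) -> np.ndarray:
    qmax = 2 ** (bits - 1) - 1
    s = np.abs(w).max() / qmax                 # one scale for the whole row (simplified)
    H = 2.0 * X @ X.T                          # Hessian of the layer-output MSE
    Hinv = np.linalg.inv(H + 1e-6 * np.eye(len(w)))
    w = w.copy()
    for j in range(len(w)):
        w_q = np.clip(np.round(w[j] / s), -qmax, qmax) * s
        delta = w[j] - w_q                     # error introduced by rounding column j
        w[j] = w_q
        # Compensate the not-yet-quantized columns using the inverse Hessian.
        for k in range(j + 1, len(w)):
            w[k] -= delta * Hinv[j, k] / Hinv[j, j]
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 128))                  # 6 input channels, 128 calibration samples
w = rng.normal(size=6)
w_hat = quantize_row_gptq_style(w, X)
print("output error:", np.linalg.norm(w @ X - w_hat @ X))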

AWQ — Activation-Aware Weight Quantization (Lin et al., 2023)

AWQ observes that not all weights are equally important. Weights that interact with large activations have more impact on the output. Instead of quantizing all weights uniformly, AWQ scales weights before quantization:

$$\hat{w}_i = \text{quant}(w_i \cdot s_i) / s_i$$

where $s_i$ is a per-channel scaling factor determined by the activation magnitudes:

$$s_i = \left(\frac{\text{mean}(|X_i|)}{\max_j \text{mean}(|X_j|)}\right)^{\alpha}$$

with $\alpha \in [0, 1]$ searched per layer (typically $\alpha \approx 0.5$).

Intuition: Weights connected to channels with large activations get scaled up before quantization (more quantization bins allocated to them) and scaled back down after. This is equivalent to allocating more "precision bits" to the most important weight channels.
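
A small sketch of this idea using the simplified, normalized scaling formula above; awq_scale_quantize is illustrative and uses a single per-tensor quantization scale, whereas real AWQ searches $\alpha$ per layer, quantizes in groups, and fuses the channel scales into adjacent ops.

# AWQ-style activation-aware scaling for one linear layer (illustrative only).
import numpy as np

def awq_scale_quantize(W: np.ndarray, X: np.ndarray, bits: int = 4, alpha: float = 0.5):
    """W: (out, in) weights, X: (n_samples, in) calibration activations."""
    qmax = 2 ** (bits - 1) - 1
    act = np.abs(X).mean(axis=0)                 # mean |activation| per input channel
    s = (act / act.max()) ** alpha               # most important channel gets s = 1
    Ws = W * s                                   # scale columns before quantization
    scale = np.abs(Ws).max() / qmax              # per-tensor scale (simplified)
    W_q = np.clip(np.round(Ws / scale), -qmax, qmax) * scale
    return W_q / s                               # undo the channel scaling afterwards

rng = np.random.default_rng(1)
W = rng.normal(size=(8, 16))
X = rng.normal(size=(128, 16)) * np.linspace(0.1, 2.0, 16)   # non-uniform channel magnitudes
W_awq = awq_scale_quantize(W, X)
print("mean output error:", np.abs(X @ W.T - X @ W_awq.T).mean())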

SmoothQuant (Xiao et al., 2022)

SmoothQuant addresses the activation outlier problem. Some activation channels have magnitudes 100× larger than others, making activation quantization difficult. SmoothQuant migrates the quantization difficulty from activations to weights:

$$Y = (X \text{diag}(s)^{-1}) \cdot (\text{diag}(s) W) = \hat{X} \hat{W}$$

The smoothing factor $s$ is per-channel:

$$s_j = \frac{\max(|X_j|)^\alpha}{\max(|W_j|)^{1-\alpha}}$$

This balances the dynamic range between activations and weights, making both easier to quantize.
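
A minimal sketch of the smoothing step, assuming the factors are folded offline into the weights; `smooth` is an illustrative helper, and in practice the division by $s$ is fused into the preceding LayerNorm rather than applied at inference time.

# SmoothQuant-style per-channel smoothing (illustrative only).
import numpy as np

def smooth(W: np.ndarray, X: np.ndarray, alpha: float = 0.5):
    """W: (out, in), X: (n_samples, in). Returns X_hat, W_hat with X_hat @ W_hat.T == X @ W.T."""
    x_max = np.abs(X).max(axis=0)                      # per-channel activation range
    w_max = np.abs(W).max(axis=0)                      # per-channel weight range
    s = x_max ** alpha / w_max ** (1 - alpha)          # smoothing factor per input channel
    return X / s, W * s                                # migrate difficulty from X to W

rng = np.random.default_rng(2)
W = rng.normal(size=(8, 16))
X = rng.normal(size=(64, 16))
X[:, 3] *= 100.0                                       # simulate an outlier activation channel
X_hat, W_hat = smooth(W, X)
print(np.allclose(X @ W.T, X_hat @ W_hat.T))           # True: the layer output is unchanged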

Quantization Error Propagation Through Layers

For a model with $L$ layers, each with quantization error $\epsilon_l$, the total output error is approximately:

$$\epsilon_{\text{total}} \approx \sum_{l=1}^{L} \left(\prod_{k=l+1}^{L} \|W_k\|\right) \cdot \epsilon_l$$

Implication: errors in early layers get amplified by all subsequent layers. This means:

  1. Early layers should be quantized more conservatively (higher bits)
  2. Late layers (closest to the output) contribute less error propagation
  3. Mixed-precision approaches assign more bits to early layers

For DistilBERT with 6 layers and weight norms $\|W_k\| \approx 1.2$:

- Layer 1 error amplification: $1.2^5 \approx 2.49×$
- Layer 6 error amplification: $1.2^0 = 1.0×$

For Llama 3 8B with 32 layers and $\|W_k\| \approx 1.05$:

- Layer 1 error amplification: $1.05^{31} \approx 4.54×$
- Layer 32 error amplification: $1.0×$
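
As a sanity check on these figures, a two-line helper reproduces them; `amplification` is just illustrative arithmetic, not part of any quantization library.

# Back-of-envelope check of the amplification figures above (illustrative only).
def amplification(layer: int, num_layers: int, w_norm: float) -> float:
    # Error from layer l passes through layers l+1 .. L, each multiplying by ~||W_k||.
    return w_norm ** (num_layers - layer)

print(amplification(1, 6, 1.2))     # DistilBERT layer 1   -> ~2.49
print(amplification(1, 32, 1.05))   # Llama 3 8B layer 1   -> ~4.54
print(amplification(32, 32, 1.05))  # final layer          -> 1.0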


Model Internals — Layer-by-Layer Diagrams

Quantization Grid Visualization

graph TB
    subgraph "FP32 → INT8 Quantization (one weight row)"
        FP32["FP32 weights:<br>[-0.82, 0.15, -0.03, 0.67, 1.24, -0.56, 0.91, -1.10]<br>Continuous values, infinite precision"]

        SCALE["Compute scale: s = max(|w|) / 127<br>s = 1.24 / 127 = 0.00976"]

        INT8["INT8 quantized:<br>[-84, 15, -3, 69, 127, -57, 93, -113]<br>Only 256 possible values"]

        DEQ["Dequantized (FP32):<br>[-0.820, 0.146, -0.029, 0.673, 1.240, -0.557, 0.908, -1.103]<br>Errors: [0, 0.004, 0.001, 0.003, 0, 0.003, 0.002, 0.003]"]

        FP32 --> SCALE --> INT8 --> DEQ
    end

    subgraph "FP32 → INT4 Quantization (same weights)"
        FP32_2["Same FP32 weights:<br>[-0.82, 0.15, -0.03, 0.67, 1.24, -0.56, 0.91, -1.10]"]

        SCALE_2["Compute scale: s = max(|w|) / 7<br>s = 1.24 / 7 = 0.177"]

        INT4["INT4 quantized:<br>[-5, 1, 0, 4, 7, -3, 5, -6]<br>Only 16 possible values!"]

        DEQN4["Dequantized (FP32):<br>[-0.886, 0.177, 0.000, 0.709, 1.240, -0.531, 0.886, -1.063]<br>Errors: [0.066, 0.027, 0.030, 0.039, 0, 0.029, 0.024, 0.037]"]

        FP32_2 --> SCALE_2 --> INT4 --> DEQN4
    end

    style INT8 fill:#c8e6c9
    style INT4 fill:#fff9c4
    style DEQN4 fill:#ffcdd2

QAT Forward-Backward Pass

sequenceDiagram
    participant W as FP32 Weights (master)
    participant FQ as Fake Quantize
    participant FWD as Forward Pass
    participant LOSS as Loss
    participant BWD as Backward (STE)

    Note over W,BWD: QAT Training Loop
    W->>FQ: w_float
    FQ->>FQ: q = round(w/s) · s<br>(simulate INT8 rounding)
    FQ->>FWD: w_quantized (but still FP32 tensor)
    FWD->>FWD: Normal forward with quantized weights<br>Model "sees" what INT8 inference will look like
    FWD->>LOSS: logits → cross-entropy

    LOSS->>BWD: ∂L/∂logits
    BWD->>BWD: Backprop through all layers
    BWD->>FQ: ∂L/∂w_quantized
    FQ->>W: ∂L/∂w ≈ ∂L/∂w_quantized (STE)<br>Pass gradient straight through!
    W->>W: w ← w - η · ∂L/∂w<br>Update FP32 master weights

    Note over W,BWD: Key: FP32 weights adjust to<br>positions where rounding error is minimal

GPTQ Column-by-Column Quantization

graph LR
    subgraph "GPTQ: Quantize col j, compensate cols j+1...n"
        C1["Col 1<br>✅ Quantized<br>Error: δ₁"]
        C2["Col 2<br>✅ Quantized<br>Error: δ₂<br>+ correction<br>from δ₁"]
        C3["Col 3<br>🔄 Quantizing now<br>Error: δ₃"]
        C4["Col 4<br>⬜ Float<br>Absorbs δ₃ error<br>w₄ -= δ₃·H₃₄/H₃₃"]
        C5["Col 5<br>⬜ Float<br>Absorbs δ₃ error<br>w₅ -= δ₃·H₃₅/H₃₃"]
        CN["Col n<br>⬜ Float<br>Last col absorbs<br>all residual error"]

        C1 --> C2 --> C3 --> C4 --> C5 --> CN
    end

    subgraph "Hessian H = 2XX^T"
        H["H tells us how<br>columns interact.<br>H_jk large → col k can<br>compensate col j's error.<br>H_jj large → col j is<br>sensitive to quantization.<br><br>Calibration: 128 samples<br>from MangaAssist queries"]
    end

    style C1 fill:#c8e6c9
    style C2 fill:#c8e6c9
    style C3 fill:#fff9c4
    style C4 fill:#e3f2fd
    style C5 fill:#e3f2fd

Mixed-Precision Layer Assignment

graph TB
    subgraph "Llama 3 8B: Per-Layer Bit Assignment"
        EMB["Embedding layer<br>INT8 (higher precision)<br>Maps tokens → continuous space<br>Error here amplifies 32×"]

        L1["Layers 1-4 (early)<br>INT8 (8-bit)<br>Error amplification: 4.5×<br>Capture low-level patterns"]

        L5["Layers 5-16 (middle)<br>INT4 (4-bit, GPTQ)<br>Error amplification: 2.1×<br>Redundant representations"]

        L17["Layers 17-28 (late-middle)<br>INT4 (4-bit, GPTQ)<br>Error amplification: 1.3×<br>High-level features"]

        L29["Layers 29-32 (final)<br>INT8 (8-bit)<br>Error amplification: 1.0×<br>Direct output influence"]

        HEAD["LM Head<br>FP16 (highest precision)<br>Maps to vocabulary logits<br>Precision critical for next-token"]

        EMB --> L1 --> L5 --> L17 --> L29 --> HEAD
    end

    subgraph "Memory Summary"
        MEM["Embedding: 512MB (INT8)<br>Layers 1-4: 800MB (INT8)<br>Layers 5-28: 2.4GB (INT4)<br>Layers 29-32: 800MB (INT8)<br>LM Head: 512MB (FP16)<br>─────────────<br>Total: ~5.0GB<br>(vs 16GB FP16, 8GB uniform INT8)"]
    end

    style EMB fill:#c8e6c9
    style L1 fill:#c8e6c9
    style L5 fill:#fff9c4
    style L17 fill:#fff9c4
    style L29 fill:#c8e6c9
    style HEAD fill:#bbdefb

AWQ Activation-Aware Scaling

graph TB
    subgraph "AWQ: Scale important channels before quantization"
        ACT["Calibration activations X<br>128 MangaAssist queries"]

        CH1["Channel 1: mean(|X₁|) = 0.02<br>Low activation → low importance<br>Scale: s₁ = 0.14^0.5 = 0.37"]
        CH2["Channel 2: mean(|X₂|) = 0.15<br>High activation → HIGH importance<br>Scale: s₂ = 1.0^0.5 = 1.0"]
        CH3["Channel 3: mean(|X₃|) = 0.08<br>Medium activation<br>Scale: s₃ = 0.53^0.5 = 0.73"]

        ACT --> CH1 & CH2 & CH3

        Q1["w₁ scaled by 0.37<br>→ fewer quant bins<br>→ more error (OK, low impact)"]
        Q2["w₂ scaled by 1.0<br>→ full quant bins<br>→ less error (good, high impact)"]
        Q3["w₃ scaled by 0.73<br>→ moderate bins<br>→ moderate error"]

        CH1 --> Q1
        CH2 --> Q2
        CH3 --> Q3
    end

    style Q2 fill:#c8e6c9
    style Q1 fill:#ffcdd2

Implementation Deep-Dive

QAT for DistilBERT Intent Classifier

import torch
import torch.nn as nn
import torch.quantization as quant
from transformers import DistilBertModel, DistilBertTokenizer
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert


class QuantizedIntentClassifier(nn.Module):
    """
    DistilBERT intent classifier with quantization-aware training.
    Simulates INT8 during training so the model learns quantization-robust weights.
    """

    def __init__(self, num_intents: int = 10, model_name: str = "distilbert-base-uncased"):
        super().__init__()
        self.quant = QuantStub()     # Quantize input
        self.dequant = DeQuantStub() # Dequantize output

        self.bert = DistilBertModel.from_pretrained(model_name)
        self.classifier = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_intents),
        )

    def forward(self, input_ids, attention_mask):
        # Quantize activations entering the classifier head
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS] token
        x = self.quant(cls_output)
        logits = self.classifier(x)
        logits = self.dequant(logits)
        return logits


class QATTrainer:
    """Train with fake-quantization to prepare for INT8 deployment."""

    def __init__(self, num_intents: int = 10):
        self.model = QuantizedIntentClassifier(num_intents)
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

    def prepare_qat(self):
        """Insert fake-quantization modules for QAT."""
        # Define quantization config
        self.model.qconfig = quant.get_default_qat_qconfig("x86")

        # Fuse modules for efficiency (Conv+BN+ReLU, Linear+ReLU)
        # For our classifier head:
        torch.quantization.fuse_modules(
            self.model.classifier,
            [["0", "1"]],  # Linear + ReLU
            inplace=True,
        )

        # Insert fake-quant observers
        prepare_qat(self.model, inplace=True)
        print("QAT prepared. Fake-quantization modules inserted.")

    def train(self, train_loader, val_loader, epochs: int = 5, lr: float = 2e-5):
        """
        Train with quantization simulation.
        The model "sees" quantized weights during forward pass
        but updates FP32 master weights via STE during backward pass.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        self.model.train()
        for epoch in range(epochs):
            total_loss = 0
            correct = 0
            total = 0

            # Freeze fake-quant observer ranges after epoch 3 so scales stop shifting
            if epoch >= 3:
                self.model.apply(torch.quantization.disable_observer)

            for batch in train_loader:
                input_ids = batch["input_ids"]
                attention_mask = batch["attention_mask"]
                labels = batch["labels"]

                logits = self.model(input_ids, attention_mask)
                loss = criterion(logits, labels)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                total_loss += loss.item()
                correct += (logits.argmax(dim=-1) == labels).sum().item()
                total += labels.size(0)

            train_acc = correct / total
            val_acc = self._evaluate(val_loader)
            print(f"Epoch {epoch+1}: loss={total_loss/len(train_loader):.4f}, "
                  f"train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")

    def convert_to_int8(self):
        """Convert QAT model to actual INT8 for deployment."""
        self.model.eval()
        quantized_model = convert(self.model, inplace=False)
        return quantized_model

    def _evaluate(self, loader):
        self.model.eval()
        correct = total = 0
        with torch.no_grad():
            for batch in loader:
                logits = self.model(batch["input_ids"], batch["attention_mask"])
                correct += (logits.argmax(dim=-1) == batch["labels"]).sum().item()
                total += batch["labels"].size(0)
        self.model.train()
        return correct / total

    def benchmark_size_and_speed(self, quantized_model, sample_input):
        """Compare FP32 vs INT8 model."""
        import time

        # Size
        torch.save(self.model.state_dict(), "/tmp/fp32_model.pt")
        torch.save(quantized_model.state_dict(), "/tmp/int8_model.pt")

        import os
        fp32_size = os.path.getsize("/tmp/fp32_model.pt") / 1024 / 1024
        int8_size = os.path.getsize("/tmp/int8_model.pt") / 1024 / 1024
        print(f"FP32: {fp32_size:.1f}MB, INT8: {int8_size:.1f}MB, "
              f"Compression: {fp32_size/int8_size:.1f}×")

        # Speed (CPU inference)
        input_ids, attention_mask = sample_input
        self.model.eval()
        quantized_model.eval()

        # Warmup
        for _ in range(10):
            self.model(input_ids, attention_mask)
            quantized_model(input_ids, attention_mask)

        # Benchmark
        start = time.time()
        for _ in range(100):
            self.model(input_ids, attention_mask)
        fp32_time = (time.time() - start) / 100

        start = time.time()
        for _ in range(100):
            quantized_model(input_ids, attention_mask)
        int8_time = (time.time() - start) / 100

        print(f"FP32: {fp32_time*1000:.1f}ms, INT8: {int8_time*1000:.1f}ms, "
              f"Speedup: {fp32_time/int8_time:.1f}×")

GPTQ for Llama 3 8B

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer
import torch


class GPTQQuantizer:
    """
    Apply GPTQ quantization to Llama 3 8B.
    Uses calibration data from MangaAssist to optimize quantization.
    """

    def __init__(self, model_name: str = "meta-llama/Llama-3-8b-hf"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def quantize(
        self,
        calibration_texts: list[str],
        bits: int = 4,
        group_size: int = 128,
        desc_act: bool = True,
    ):
        """
        Quantize model using GPTQ with MangaAssist calibration data.

        Args:
            calibration_texts: 128 representative MangaAssist conversations
            bits: 4 for INT4, 8 for INT8
            group_size: Quantize weights in groups of 128 that share one scale (smaller groups = more scales = higher precision)
            desc_act: Sort columns by activation magnitude (descending) before quantizing
        """
        quantize_config = BaseQuantizeConfig(
            bits=bits,
            group_size=group_size,
            desc_act=desc_act,
            sym=True,  # Symmetric quantization
            model_file_base_name="mangaassist-llama3-8b",
        )

        # Load model in FP16
        model = AutoGPTQForCausalLM.from_pretrained(
            self.model_name,
            quantize_config=quantize_config,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Prepare calibration data
        calibration_dataset = [
            self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
            for text in calibration_texts
        ]

        # Run GPTQ quantization
        model.quantize(
            calibration_dataset,
            batch_size=1,
        )

        return model

    def save_quantized(self, model, output_dir: str):
        model.save_quantized(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        print(f"Saved GPTQ model to {output_dir}")

    def evaluate_quality(self, model, test_prompts: list[str]) -> dict:
        """Compare quantized model quality against FP16 baseline."""
        results = {"perplexity_samples": [], "generation_samples": []}

        for prompt in test_prompts[:20]:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                outputs = model(**inputs, labels=inputs["input_ids"])
                perplexity = torch.exp(outputs.loss).item()
                results["perplexity_samples"].append(perplexity)

            generated = model.generate(**inputs, max_new_tokens=100)
            text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
            results["generation_samples"].append({
                "prompt": prompt,
                "response": text[len(prompt):],
            })

        results["mean_perplexity"] = sum(results["perplexity_samples"]) / len(results["perplexity_samples"])
        return results
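
A hypothetical end-to-end run of the class above; calibration_texts, test_prompts, and the output path are placeholders.

# Hypothetical run; calibration_texts holds 128 MangaAssist conversations (not defined here).
quantizer = GPTQQuantizer()
gptq_model = quantizer.quantize(calibration_texts, bits=4, group_size=128)
quantizer.save_quantized(gptq_model, "models/mangaassist-llama3-8b-gptq")
report = quantizer.evaluate_quality(gptq_model, test_prompts)
print(report["mean_perplexity"])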

AWQ Quantization

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


class AWQQuantizer:
    """
    Activation-Aware Weight Quantization.
    Preserves channels with large activations at higher precision.
    """

    def __init__(self, model_name: str = "meta-llama/Llama-3-8b-hf"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def quantize(self, calibration_texts: list[str]):
        """
        Apply AWQ quantization.
        Automatically determines per-channel scaling factors
        from calibration data.
        """
        model = AutoAWQForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        quant_config = {
            "zero_point": True,
            "q_group_size": 128,
            "w_bit": 4,
        }

        # AutoAWQ tokenizes calibration samples internally, so pass raw strings
        calib_data = calibration_texts[:128]

        model.quantize(
            tokenizer=self.tokenizer,
            quant_config=quant_config,
            calib_data=calib_data,
        )

        return model


class QuantizationBenchmark:
    """Compare GPTQ vs AWQ vs QAT on MangaAssist tasks."""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8b-hf")

    def run_benchmark(self, models: dict, test_data: list[dict]) -> dict:
        """
        Benchmark multiple quantization methods head-to-head.
        """
        results = {}
        for name, model in models.items():
            perplexities = []
            latencies = []
            import time

            for item in test_data:
                inputs = self.tokenizer(
                    item["prompt"], return_tensors="pt", max_length=256, truncation=True,
                ).to(model.device)

                # Perplexity
                with torch.no_grad():
                    outputs = model(**inputs, labels=inputs["input_ids"])
                    perplexities.append(torch.exp(outputs.loss).item())

                # Latency
                start = time.time()
                with torch.no_grad():
                    model.generate(**inputs, max_new_tokens=50)
                latencies.append(time.time() - start)

            results[name] = {
                "mean_perplexity": sum(perplexities) / len(perplexities),
                "mean_latency_ms": sum(latencies) / len(latencies) * 1000,
                "model_size_gb": sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9,
            }

        return results
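
A hypothetical comparison run tying the pieces together; calibration_texts, gptq_model (from the GPTQ section), test_conversations, and the output path are placeholders, and save_quantized is assumed to be AutoAWQ's standard save method.

# Hypothetical comparison run; inputs are placeholders from the sections above.
awq = AWQQuantizer()
awq_model = awq.quantize(calibration_texts)
awq_model.save_quantized("models/mangaassist-llama3-8b-awq")   # assumed AutoAWQ save method
awq.tokenizer.save_pretrained("models/mangaassist-llama3-8b-awq")

bench = QuantizationBenchmark()
results = bench.run_benchmark(
    {"gptq-int4": gptq_model, "awq-int4": awq_model},
    test_data=test_conversations,
)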

Group Discussion: Key Decision Points

Decision Point 1: GPTQ vs AWQ vs QAT

Priya (ML Engineer): Benchmark on Llama 3 8B for MangaAssist:

| Method | Perplexity (↓) | Intent Routing Accuracy | Size | Quantize Time | Inference (ms) |
|---|---|---|---|---|---|
| FP16 (baseline) | 5.82 | 93.1% | 16GB | N/A | 480 |
| QAT INT8 | 5.87 | 92.8% | 8GB | 6h training | 320 |
| GPTQ INT4 | 5.98 | 92.1% | 4.2GB | 15 min calibration | 280 |
| AWQ INT4 | 5.93 | 92.4% | 4.1GB | 12 min calibration | 260 |
| NF4 (QLoRA, doc 04) | 6.02 | 91.8% | 4.0GB | N/A | 290 |

Aiko (Data Scientist): AWQ is the winner: best perplexity among INT4 methods, fastest inference, and only 12 minutes to quantize. The activation-aware scaling pays off on MangaAssist data because our activation distributions are highly non-uniform (manga-specific tokens have consistently large activations).

Marcus (Architect): AWQ INT4 gives us roughly a quarter of the FP16 memory footprint (4.1GB vs 16GB, a 74% reduction) with only a 0.7-point accuracy loss (92.4% vs 93.1%). That is well within our CPQ threshold.

Jordan (MLOps): AWQ also has better production support: vLLM natively supports AWQ kernels for fast inference. GPTQ requires the AutoGPTQ library which has more compatibility issues.

Resolution: AWQ INT4 for Llama 3 8B (primary deployment). GPTQ INT4 as fallback if AWQ has compatibility issues with future model versions. QAT INT8 for DistilBERT models where 2× compression is sufficient and we want maximum accuracy.
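
As a rough sketch of Jordan's point about serving, loading the saved AWQ checkpoint with vLLM might look like the following; the model path and prompt are placeholders, and the exact flags depend on the vLLM version.

# Sketch of serving the AWQ checkpoint with vLLM (path and prompt are placeholders).
from vllm import LLM, SamplingParams

llm = LLM(model="models/mangaassist-llama3-8b-awq", quantization="awq", dtype="float16")
params = SamplingParams(temperature=0.7, max_tokens=100)
outputs = llm.generate(["Recommend a manga similar to Vinland Saga."], params)
print(outputs[0].outputs[0].text)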

Decision Point 2: Calibration Data Composition

Aiko (Data Scientist): Calibration data matters more than people think:

| Calibration Source | Perplexity | Intent Accuracy | Notes |
|---|---|---|---|
| Random Wikipedia | 6.31 | 90.2% | Domain mismatch |
| C4 (generic web) | 6.15 | 91.0% | Better but still generic |
| MangaAssist conversations | 5.93 | 92.4% | Domain-matched ✅ |
| MangaAssist + edge cases | 5.91 | 92.6% | Marginally better |

Priya (ML Engineer): Using domain-specific calibration data (actual MangaAssist conversations) reduces perplexity by 0.22 compared to generic web text. The calibration data shapes what the Hessian "sees" as important, so it directly affects which weights get priority during quantization.

Resolution: Use 128 representative MangaAssist conversations spanning all 10 intent categories as calibration data. Refresh calibration data quarterly to capture evolving conversation patterns.

Decision Point 3: Mixed-Precision Strategy

Marcus (Architect): Should all layers be INT4, or should we use mixed precision?

Priya (ML Engineer): Error propagation analysis shows early and late layers are more sensitive:

| Strategy | Perplexity | Size | Inference |
|---|---|---|---|
| Uniform INT4 | 5.93 | 4.1GB | 260ms |
| Mixed (early/late INT8, middle INT4) | 5.88 | 5.0GB | 275ms |
| Mixed with FP16 LM head | 5.85 | 5.5GB | 280ms |
| Uniform INT8 | 5.87 | 8.0GB | 320ms |

Jordan (MLOps): The mixed-precision approach (5.0GB) is interesting but adds deployment complexity. We need custom model definitions instead of using off-the-shelf AWQ configs.

Sam (PM): How does 5.0GB vs 4.1GB affect our deployment? Both fit on a single A10G (24GB). The question is headroom for KV cache and other models.

Marcus (Architect): With AWQ INT4 at 4.1GB, we have 19.9GB for KV cache and the smaller models. With mixed-precision at 5.0GB, we have 19GB. Both are comfortable.

Resolution: Start with uniform AWQ INT4 (simpler, 4.1GB). The 0.05 perplexity gap from mixed-precision doesn't justify the deployment complexity. If we observe quality issues in production, implement mixed-precision as a targeted fix.


Research Paper References

1. GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers (Frantar et al., 2022)

Key contribution: Applied optimal brain surgeon (OBS) techniques to LLM quantization. The column-by-column quantization with Hessian-based error compensation achieves INT4 quantization of 175B parameter models with <1% perplexity degradation. The Cholesky-based batched processing makes it run in under 4 hours on a single GPU.

Relevance to MangaAssist: GPTQ is our fallback quantization method. Its theoretical guarantee (minimal output error given the Hessian) provides a reliable baseline against which we benchmark AWQ.

2. AWQ: Activation-Aware Weight Quantization for LLM Compression and Acceleration (Lin et al., 2023)

Key contribution: Identified that 0.1-1% of weight channels are critical (connected to large activations) and that protecting these via per-channel scaling eliminates most quantization error. AWQ achieves better quality than GPTQ without using the more complex Hessian-based approach. The method is also hardware-friendly — the scaling can be fused into the quantized computation.

Relevance to MangaAssist: AWQ is our primary quantization method for Llama 3 8B. The activation-aware scaling aligns well with our domain-specific calibration data, preserving the channels most active during manga-related conversations.

3. SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models (Xiao et al., 2022)

Key contribution: Addressed the activation outlier problem by migrating quantization difficulty from activations to weights through per-channel smoothing. This enables W8A8 (8-bit weights AND activations) quantization, which further accelerates inference on hardware with INT8 matrix multiply support.

Relevance to MangaAssist: SmoothQuant is applicable to our INT8 quantized DistilBERT models. The smoothing technique improves INT8 accuracy by 0.3% on our intent classifier — modest but free (zero inference overhead since the smoothing is folded into weights).


Production Results

Quantization Impact Summary

| Model | Method | Original Size | Quantized Size | Accuracy Delta | Latency Change |
|---|---|---|---|---|---|
| Llama 3 8B | AWQ INT4 | 16GB | 4.1GB | -0.7% | -42% (faster) |
| DistilBERT (intent) | QAT INT8 | 265MB | 66MB | -0.3% | -35% (faster) |
| MiniLM-L6 (reranker) | QAT INT8 | 90MB | 22MB | -0.2% | -30% (faster) |
| DistilBERT (sentiment) | QAT INT8 | 265MB | 66MB | -0.4% | -33% (faster) |

Total Pipeline Memory

| Configuration | Total VRAM | Instance | Monthly Cost |
|---|---|---|---|
| All FP16 | 16.6GB | g5.2xlarge | $680 |
| All INT8 | 8.2GB | g5.xlarge | $510 |
| Mixed (LLM INT4, rest INT8) | 4.3GB | g5.xlarge | $510 |
| Mixed + KV cache headroom | 4.3GB + 8GB cache | g5.xlarge | $510 |

Savings: Mixed quantization saves $170/month ($2,040/year) by fitting on a smaller instance while losing <1% quality. The combined pipeline latency drops from 680ms to 520ms — a 24% improvement that directly impacts user experience.