12. Quantization-Aware Training — INT8/INT4 Without Quality Loss
Problem Statement and MangaAssist Context
MangaAssist runs five models in its inference pipeline. At FP32 precision, the combined memory footprint would exceed available GPU/CPU budgets. Post-training quantization (PTQ) naively rounds weights after training, which degrades quality — especially for smaller models where each weight carries more information. Quantization-aware training (QAT) simulates quantization during training so the model learns weight values that are robust to rounding. Modern techniques like GPTQ and AWQ go further: they use second-order information (Hessian) and activation-aware scaling to minimize quantization error without full retraining.
Memory Budget Pressure
| Model | FP32 Size | FP16 Size | INT8 Size | INT4 Size |
|---|---|---|---|---|
| Llama 3 70B | 280GB | 140GB | 70GB | 35GB |
| Llama 3 8B (distilled) | 32GB | 16GB | 8GB | 4GB |
| DistilBERT (intent) | 265MB | 132MB | 66MB | 33MB |
| MiniLM-L6 (reranker) | 90MB | 45MB | 22MB | 11MB |
| DistilBERT (sentiment) | 265MB | 132MB | 66MB | 33MB |
Goal: Run the full pipeline on a single g5.xlarge (1× A10G, 24GB VRAM) instead of a g5.2xlarge, which pairs the same A10G with more vCPUs and system RAM at a higher price. The LLM fallback (Llama 3 8B) must fit in 4-8GB to leave headroom for the smaller models and the KV cache.
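The table figures follow directly from parameter count times bytes per weight; a quick back-of-envelope check (illustrative only — it ignores the small overhead of scales and zero-points):

```python
def model_size_gb(n_params: float, bits: int) -> float:
    """Approximate weight memory; ignores scale/zero-point overhead."""
    return n_params * bits / 8 / 1e9

for bits, label in [(32, "FP32"), (16, "FP16"), (8, "INT8"), (4, "INT4")]:
    print(f"Llama 3 8B @ {label}: {model_size_gb(8e9, bits):.0f}GB")
# INT4 lands at 4GB — inside the 4-8GB budget for the LLM fallback
```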
Mathematical Foundations
Uniform Quantization
Map a floating-point weight $w \in [w_{\min}, w_{\max}]$ to an integer $q \in [0, 2^b - 1]$ for $b$-bit quantization:
Symmetric quantization (zero maps to zero):
$$q = \text{round}\left(\frac{w}{s}\right), \quad s = \frac{\max(|w_{\min}|, |w_{\max}|)}{2^{b-1} - 1}$$
$$\hat{w} = q \cdot s$$
Asymmetric quantization (allows asymmetric weight distributions):
$$q = \text{round}\left(\frac{w - z}{s}\right), \quad s = \frac{w_{\max} - w_{\min}}{2^b - 1}, \quad z = w_{\min}$$
$$\hat{w} = q \cdot s + z$$
Quantization error: For a single weight:
$$\epsilon = w - \hat{w} = w - \text{round}\left(\frac{w}{s}\right) \cdot s$$
The maximum error per weight is $\epsilon_{\max} = \frac{s}{2}$. For INT8 symmetric quantization of weights in $[-1, 1]$:
$$s = \frac{1}{127} \approx 0.0079, \quad \epsilon_{\max} \approx 0.004$$
For INT4: $s = \frac{1}{7} \approx 0.143, \quad \epsilon_{\max} \approx 0.071$ — 18× larger error per weight. This is why INT4 requires more sophisticated techniques.
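A minimal NumPy sketch of the symmetric scheme above, confirming that the worst-case round-trip error stays within the $s/2$ bound for both INT8 and INT4:

```python
import numpy as np

def quantize_symmetric(w: np.ndarray, bits: int):
    """Symmetric uniform quantization: q = round(w/s), w_hat = q * s."""
    qmax = 2 ** (bits - 1) - 1               # 127 for INT8, 7 for INT4
    s = np.abs(w).max() / qmax               # scale from the largest magnitude
    q = np.clip(np.round(w / s), -qmax, qmax).astype(int)
    return q, s, q * s                       # integers, scale, dequantized weights

rng = np.random.default_rng(0)
w = rng.uniform(-1, 1, size=10_000)
for bits in (8, 4):
    q, s, w_hat = quantize_symmetric(w, bits)
    max_err = np.abs(w - w_hat).max()
    print(f"INT{bits}: s={s:.4f}, max error={max_err:.4f} <= s/2={s / 2:.4f}")
```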
Straight-Through Estimator (STE)
The rounding function $\text{round}(\cdot)$ has zero gradient almost everywhere. QAT uses the straight-through estimator to bypass this:
Forward pass: Apply quantization (round)
$$\hat{w} = \text{round}\left(\frac{w}{s}\right) \cdot s$$
Backward pass: Pretend quantization didn't happen (pass gradient through)
$$\frac{\partial \mathcal{L}}{\partial w} \approx \frac{\partial \mathcal{L}}{\partial \hat{w}}$$
More precisely, the STE gradient is:
$$\frac{\partial \hat{w}}{\partial w} = \begin{cases} 1 & \text{if } w_{\min} \leq w \leq w_{\max} \\ 0 & \text{otherwise (clamped)} \end{cases}$$
Why this works: The STE is biased but low-variance. During training, the model adjusts its weight values to positions where rounding introduces minimal error. Weights migrate toward "quantization-friendly" values that are near integer multiples of the scale factor $s$.
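The two passes can be sketched without an autograd framework — the forward simulates rounding, and the backward returns the incoming gradient untouched except where the forward pass clamped. This is a hand-rolled illustration, not PyTorch's built-in fake-quantize machinery:

```python
import numpy as np

def fake_quant_forward(w: np.ndarray, s: float, qmax: int = 127) -> np.ndarray:
    """Forward: simulate the INT8 rounding the deployed model will see."""
    return np.clip(np.round(w / s), -qmax, qmax) * s

def fake_quant_backward(grad_out: np.ndarray, w: np.ndarray, s: float,
                        qmax: int = 127) -> np.ndarray:
    """Backward (STE): pass the gradient through unchanged inside the
    representable range, and zero it where the forward pass clamped."""
    return grad_out * (np.abs(w / s) <= qmax)

w = np.array([0.30, -0.55, 5.0])         # the last weight is out of range
s = 1 / 127
grad = np.ones_like(w)                   # pretend dL/dw_hat = 1 everywhere
print(fake_quant_backward(grad, w, s))   # → [1. 1. 0.]
```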
GPTQ — Post-Training Quantization with Hessian Information (Frantar et al., 2022)
GPTQ quantizes weights column-by-column using second-order error correction. For a weight matrix $W$ and a calibration set producing activation matrix $X$:
The quantization objective per row:
$$\min_{\hat{W}} \| WX - \hat{W}X \|_2^2$$
This is an optimal brain surgeon (OBS) problem. The objective is quadratic in the weights, so its Hessian is the same for every row and depends only on the calibration activations:
$$H = 2X X^T$$
GPTQ quantizes columns sequentially. After quantizing column $j$:
- Compute the quantization error $\delta_j = w_j - \hat{w}_j$
- Compensate the remaining float columns to minimize total output error, using entries of the inverse Hessian:
$$w_k \leftarrow w_k - \frac{\delta_j \cdot [H^{-1}]_{jk}}{[H^{-1}]_{jj}}, \quad \forall k > j$$
This error compensation is the key insight: the remaining unquantized columns absorb the error from quantizing column $j$. The Hessian tells us how to redistribute the error optimally.
Computational trick: GPTQ avoids explicit Hessian computation by using the Cholesky decomposition of $H^{-1}$, processing weights in blocks of 128 columns for numerical stability.
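A toy NumPy sketch of the column loop and error compensation. This is a greedy simplification: real GPTQ works on the Cholesky factor of $H^{-1}$ and processes 128-column blocks, and the damping term here is an assumption for invertibility, not taken from the paper:

```python
import numpy as np

def gptq_like(W: np.ndarray, X: np.ndarray, bits: int = 4) -> np.ndarray:
    """Quantize W's columns left to right; after each column, push its
    rounding error onto the still-float columns via inverse-Hessian entries."""
    W = W.copy()
    H = 2 * X @ X.T + 1e-4 * np.eye(X.shape[0])  # small damping for invertibility
    Hinv = np.linalg.inv(H)
    qmax = 2 ** (bits - 1) - 1
    s = np.abs(W).max() / qmax                   # one shared scale, for simplicity
    for j in range(W.shape[1]):
        q = np.clip(np.round(W[:, j] / s), -qmax, qmax) * s
        delta = W[:, j] - q                      # quantization error of column j
        W[:, j] = q
        for k in range(j + 1, W.shape[1]):       # remaining columns absorb it
            W[:, k] -= delta * Hinv[j, k] / Hinv[j, j]
    return W

rng = np.random.default_rng(1)
W, X = rng.normal(size=(8, 16)), rng.normal(size=(16, 64))
s = np.abs(W).max() / 7
naive = np.clip(np.round(W / s), -7, 7) * s      # plain round-to-nearest
for name, Wq in [("round-to-nearest", naive), ("compensated", gptq_like(W, X))]:
    print(f"{name}: output error {np.linalg.norm(W @ X - Wq @ X):.2f}")
```

On typical calibration data the compensated variant produces a noticeably smaller output error than plain rounding, which is the whole point of the OBS update.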
AWQ — Activation-Aware Weight Quantization (Lin et al., 2023)
AWQ observes that not all weights are equally important. Weights that interact with large activations have more impact on the output. Instead of quantizing all weights uniformly, AWQ scales weights before quantization:
$$\hat{w}_i = \text{quant}(w_i \cdot s_i) / s_i$$
where $s_i$ is a per-channel scaling factor determined by the activation magnitudes:
$$s_i = \left(\frac{\text{mean}(|X_i|)}{\max_j \text{mean}(|X_j|)}\right)^{\alpha}$$
with $\alpha \in [0, 1]$ searched per layer (typically $\alpha \approx 0.5$).
Intuition: Weights connected to channels with large activations get scaled up before quantization (more quantization bins allocated to them) and scaled back down after. This is equivalent to allocating more "precision bits" to the most important weight channels.
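A two-channel toy example of the scaling trick: once the salient channel dominates the shared scale, its rounding error vanishes, at a small cost on the low-activation channel. The numbers are illustrative and not taken from the AWQ paper:

```python
import numpy as np

def roundtrip(w: np.ndarray, qmax: int = 7) -> np.ndarray:
    """INT4 quantize/dequantize with one scale shared by the whole group."""
    s = np.abs(w).max() / qmax
    return np.clip(np.round(w / s), -qmax, qmax) * s

w = np.array([0.90, 0.33])           # channel 0: large weight, small activations
act = np.array([0.02, 0.15])         # channel 1 is the salient one
scale = (act / act.max()) ** 0.5     # AWQ per-channel factors, alpha = 0.5

plain_err = np.abs(w - roundtrip(w)) @ act
awq_err = np.abs(w - roundtrip(w * scale) / scale) @ act
print(f"activation-weighted error: plain={plain_err:.5f}, awq={awq_err:.5f}")
```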
SmoothQuant (Xiao et al., 2022)
SmoothQuant addresses the activation outlier problem. Some activation channels have magnitudes 100× larger than others, making activation quantization difficult. SmoothQuant migrates the quantization difficulty from activations to weights:
$$Y = (X \text{diag}(s)^{-1}) \cdot (\text{diag}(s) W) = \hat{X} \hat{W}$$
The smoothing factor $s$ is per-channel:
$$s_j = \frac{\max(|X_j|)^\alpha}{\max(|W_j|)^{1-\alpha}}$$
This balances the dynamic range between activations and weights, making both easier to quantize.
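The identity and its effect on dynamic range can be checked in a few lines (synthetic data; channel 2 plays the 100× outlier):

```python
import numpy as np

rng = np.random.default_rng(3)
X = rng.normal(size=(4, 8))               # 4 tokens, 8 channels
X[:, 2] *= 100                            # channel 2 is an activation outlier
W = rng.normal(size=(8, 5))
alpha = 0.5
s = np.abs(X).max(axis=0) ** alpha / np.abs(W).max(axis=1) ** (1 - alpha)

X_hat = X / s                             # activations divided per channel
W_hat = W * s[:, None]                    # weights multiplied per channel
assert np.allclose(X @ W, X_hat @ W_hat)  # output is mathematically unchanged

before = np.abs(X).max(axis=0).max() / np.abs(X).max(axis=0).min()
after = np.abs(X_hat).max(axis=0).max() / np.abs(X_hat).max(axis=0).min()
print(f"channel dynamic range: {before:.0f}x before smoothing, {after:.0f}x after")
```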
Quantization Error Propagation Through Layers
For a model with $L$ layers, each with quantization error $\epsilon_l$, the total output error is approximately:
$$\epsilon_{\text{total}} \approx \sum_{l=1}^{L} \left(\prod_{k=l+1}^{L} \|W_k\|\right) \cdot \epsilon_l$$
Implication: Errors in early layers are amplified by all subsequent layers. This means:
1. Early layers should be quantized more conservatively (higher bits)
2. Late layers (closest to the output) contribute less error propagation
3. Mixed-precision approaches assign more bits to early layers
For DistilBERT with 6 layers and weight norms $\|W_k\| \approx 1.2$:
- Layer 1 error amplification: $1.2^5 \approx 2.49×$
- Layer 6 error amplification: $1.2^0 = 1.0×$

For Llama 3 8B with 32 layers and $\|W_k\| \approx 1.05$:
- Layer 1 error amplification: $1.05^{31} \approx 4.54×$
- Layer 32 error amplification: $1.0×$
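The amplification figures above are just powers of the weight norm:

```python
def amplification(layer: int, n_layers: int, w_norm: float) -> float:
    """Error injected at `layer` is scaled by the product of the
    downstream weight norms: w_norm ** (n_layers - layer)."""
    return w_norm ** (n_layers - layer)

print(f"DistilBERT (6 layers), layer 1:  {amplification(1, 6, 1.2):.2f}x")    # 2.49x
print(f"Llama 3 8B (32 layers), layer 1: {amplification(1, 32, 1.05):.2f}x")  # 4.54x
```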
Model Internals — Layer-by-Layer Diagrams
Quantization Grid Visualization
graph TB
subgraph "FP32 → INT8 Quantization (one weight row)"
FP32["FP32 weights:<br>[-0.82, 0.15, -0.03, 0.67, 1.24, -0.56, 0.91, -1.10]<br>Continuous values, infinite precision"]
SCALE["Compute scale: s = max(|w|) / 127<br>s = 1.24 / 127 = 0.00976"]
INT8["INT8 quantized:<br>[-84, 15, -3, 69, 127, -57, 93, -113]<br>Only 256 possible values"]
DEQ["Dequantized (FP32):<br>[-0.820, 0.146, -0.029, 0.673, 1.240, -0.557, 0.908, -1.103]<br>Errors: [0, 0.004, 0.001, 0.003, 0, 0.003, 0.002, 0.003]"]
FP32 --> SCALE --> INT8 --> DEQ
end
subgraph "FP32 → INT4 Quantization (same weights)"
FP32_2["Same FP32 weights:<br>[-0.82, 0.15, -0.03, 0.67, 1.24, -0.56, 0.91, -1.10]"]
SCALE_2["Compute scale: s = max(|w|) / 7<br>s = 1.24 / 7 = 0.177"]
INT4["INT4 quantized:<br>[-5, 1, 0, 4, 7, -3, 5, -6]<br>Only 16 possible values!"]
DEQN4["Dequantized (FP32):<br>[-0.886, 0.177, 0.000, 0.709, 1.240, -0.531, 0.886, -1.063]<br>Errors: [0.066, 0.027, 0.030, 0.039, 0, 0.029, 0.024, 0.037]"]
FP32_2 --> SCALE_2 --> INT4 --> DEQN4
end
style INT8 fill:#c8e6c9
style INT4 fill:#fff9c4
style DEQN4 fill:#ffcdd2
QAT Forward-Backward Pass
sequenceDiagram
participant W as FP32 Weights (master)
participant FQ as Fake Quantize
participant FWD as Forward Pass
participant LOSS as Loss
participant BWD as Backward (STE)
Note over W,BWD: QAT Training Loop
W->>FQ: w_float
FQ->>FQ: q = round(w/s) · s<br>(simulate INT8 rounding)
FQ->>FWD: w_quantized (but still FP32 tensor)
FWD->>FWD: Normal forward with quantized weights<br>Model "sees" what INT8 inference will look like
FWD->>LOSS: logits → cross-entropy
LOSS->>BWD: ∂L/∂logits
BWD->>BWD: Backprop through all layers
BWD->>FQ: ∂L/∂w_quantized
FQ->>W: ∂L/∂w ≈ ∂L/∂w_quantized (STE)<br>Pass gradient straight through!
W->>W: w ← w - η · ∂L/∂w<br>Update FP32 master weights
Note over W,BWD: Key: FP32 weights adjust to<br>positions where rounding error is minimal
GPTQ Column-by-Column Quantization
graph LR
subgraph "GPTQ: Quantize col j, compensate cols j+1...n"
C1["Col 1<br>✅ Quantized<br>Error: δ₁"]
C2["Col 2<br>✅ Quantized<br>Error: δ₂<br>+ correction<br>from δ₁"]
C3["Col 3<br>🔄 Quantizing now<br>Error: δ₃"]
C4["Col 4<br>⬜ Float<br>Absorbs δ₃ error<br>w₄ -= δ₃·H₃₄/H₃₃"]
C5["Col 5<br>⬜ Float<br>Absorbs δ₃ error<br>w₅ -= δ₃·H₃₅/H₃₃"]
CN["Col n<br>⬜ Float<br>Last col absorbs<br>all residual error"]
C1 --> C2 --> C3 --> C4 --> C5 --> CN
end
subgraph "Hessian H = 2XX^T"
H["H tells us how<br>columns interact.<br>H_jk large → col k can<br>compensate col j's error.<br>H_jj large → col j is<br>sensitive to quantization.<br><br>Calibration: 128 samples<br>from MangaAssist queries"]
end
style C1 fill:#c8e6c9
style C2 fill:#c8e6c9
style C3 fill:#fff9c4
style C4 fill:#e3f2fd
style C5 fill:#e3f2fd
Mixed-Precision Layer Assignment
graph TB
subgraph "Llama 3 8B: Per-Layer Bit Assignment"
EMB["Embedding layer<br>INT8 (higher precision)<br>Maps tokens → continuous space<br>Error here amplifies 32×"]
L1["Layers 1-4 (early)<br>INT8 (8-bit)<br>Error amplification: 4.5×<br>Capture low-level patterns"]
L5["Layers 5-16 (middle)<br>INT4 (4-bit, GPTQ)<br>Error amplification: 2.1×<br>Redundant representations"]
L17["Layers 17-28 (late-middle)<br>INT4 (4-bit, GPTQ)<br>Error amplification: 1.3×<br>High-level features"]
L29["Layers 29-32 (final)<br>INT8 (8-bit)<br>Error amplification: 1.0×<br>Direct output influence"]
HEAD["LM Head<br>FP16 (highest precision)<br>Maps to vocabulary logits<br>Precision critical for next-token"]
EMB --> L1 --> L5 --> L17 --> L29 --> HEAD
end
subgraph "Memory Summary"
MEM["Embedding: 512MB (INT8)<br>Layers 1-4: 800MB (INT8)<br>Layers 5-28: 2.4GB (INT4)<br>Layers 29-32: 800MB (INT8)<br>LM Head: 512MB (FP16)<br>─────────────<br>Total: ~5.0GB<br>(vs 16GB FP16, 8GB uniform INT8)"]
end
style EMB fill:#c8e6c9
style L1 fill:#c8e6c9
style L5 fill:#fff9c4
style L17 fill:#fff9c4
style L29 fill:#c8e6c9
style HEAD fill:#bbdefb
AWQ Activation-Aware Scaling
graph TB
subgraph "AWQ: Scale important channels before quantization"
ACT["Calibration activations X<br>128 MangaAssist queries"]
CH1["Channel 1: mean(|X₁|) = 0.02<br>Low activation → low importance<br>Scale: s₁ = 0.14^0.5 = 0.37"]
CH2["Channel 2: mean(|X₂|) = 0.15<br>High activation → HIGH importance<br>Scale: s₂ = 1.0^0.5 = 1.0"]
CH3["Channel 3: mean(|X₃|) = 0.08<br>Medium activation<br>Scale: s₃ = 0.53^0.5 = 0.73"]
ACT --> CH1 & CH2 & CH3
Q1["w₁ scaled by 0.37<br>→ fewer quant bins<br>→ more error (OK, low impact)"]
Q2["w₂ scaled by 1.0<br>→ full quant bins<br>→ less error (good, high impact)"]
Q3["w₃ scaled by 0.73<br>→ moderate bins<br>→ moderate error"]
CH1 --> Q1
CH2 --> Q2
CH3 --> Q3
end
style Q2 fill:#c8e6c9
style Q1 fill:#ffcdd2
Implementation Deep-Dive
QAT for DistilBERT Intent Classifier
import torch
import torch.nn as nn
import torch.quantization as quant
from transformers import DistilBertModel, DistilBertTokenizer
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert
class QuantizedIntentClassifier(nn.Module):
"""
DistilBERT intent classifier with quantization-aware training.
Simulates INT8 during training so the model learns quantization-robust weights.
"""
def __init__(self, num_intents: int = 10, model_name: str = "distilbert-base-uncased"):
super().__init__()
self.quant = QuantStub() # Quantize input
self.dequant = DeQuantStub() # Dequantize output
self.bert = DistilBertModel.from_pretrained(model_name)
self.classifier = nn.Sequential(
nn.Linear(768, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_intents),
)
def forward(self, input_ids, attention_mask):
# Quantize activations entering the classifier head
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
cls_output = outputs.last_hidden_state[:, 0, :] # [CLS] token
x = self.quant(cls_output)
logits = self.classifier(x)
logits = self.dequant(logits)
return logits
class QATTrainer:
"""Train with fake-quantization to prepare for INT8 deployment."""
def __init__(self, num_intents: int = 10):
self.model = QuantizedIntentClassifier(num_intents)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
def prepare_qat(self):
"""Insert fake-quantization modules for QAT."""
# Define quantization config
self.model.qconfig = quant.get_default_qat_qconfig("x86")
# Fuse Linear+ReLU in the classifier head for efficiency.
# fuse_modules() expects eval mode, so use the QAT variant while training:
torch.ao.quantization.fuse_modules_qat(
self.model.classifier,
[["0", "1"]], # Linear + ReLU
inplace=True,
)
# Insert fake-quant observers
prepare_qat(self.model, inplace=True)
print("QAT prepared. Fake-quantization modules inserted.")
def train(self, train_loader, val_loader, epochs: int = 5, lr: float = 2e-5):
"""
Train with quantization simulation.
The model "sees" quantized weights during forward pass
but updates FP32 master weights via STE during backward pass.
"""
optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
self.model.train()
for epoch in range(epochs):
total_loss = 0
correct = 0
total = 0
# Freeze fake-quant observer stats after epoch 3 so scale estimates stabilize
if epoch >= 3:
self.model.apply(torch.quantization.disable_observer)
for batch in train_loader:
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]
logits = self.model(input_ids, attention_mask)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
correct += (logits.argmax(dim=-1) == labels).sum().item()
total += labels.size(0)
train_acc = correct / total
val_acc = self._evaluate(val_loader)
print(f"Epoch {epoch+1}: loss={total_loss/len(train_loader):.4f}, "
f"train_acc={train_acc:.4f}, val_acc={val_acc:.4f}")
def convert_to_int8(self):
"""Convert QAT model to actual INT8 for deployment."""
self.model.eval()
quantized_model = convert(self.model, inplace=False)
return quantized_model
def _evaluate(self, loader):
self.model.eval()
correct = total = 0
with torch.no_grad():
for batch in loader:
logits = self.model(batch["input_ids"], batch["attention_mask"])
correct += (logits.argmax(dim=-1) == batch["labels"]).sum().item()
total += batch["labels"].size(0)
self.model.train()
return correct / total
def benchmark_size_and_speed(self, quantized_model, sample_input):
"""Compare FP32 vs INT8 model."""
import time
# Size
torch.save(self.model.state_dict(), "/tmp/fp32_model.pt")
torch.save(quantized_model.state_dict(), "/tmp/int8_model.pt")
import os
fp32_size = os.path.getsize("/tmp/fp32_model.pt") / 1024 / 1024
int8_size = os.path.getsize("/tmp/int8_model.pt") / 1024 / 1024
print(f"FP32: {fp32_size:.1f}MB, INT8: {int8_size:.1f}MB, "
f"Compression: {fp32_size/int8_size:.1f}×")
# Speed (CPU inference)
input_ids, attention_mask = sample_input
self.model.eval()
quantized_model.eval()
# Warmup
for _ in range(10):
self.model(input_ids, attention_mask)
quantized_model(input_ids, attention_mask)
# Benchmark
start = time.time()
for _ in range(100):
self.model(input_ids, attention_mask)
fp32_time = (time.time() - start) / 100
start = time.time()
for _ in range(100):
quantized_model(input_ids, attention_mask)
int8_time = (time.time() - start) / 100
print(f"FP32: {fp32_time*1000:.1f}ms, INT8: {int8_time*1000:.1f}ms, "
f"Speedup: {fp32_time/int8_time:.1f}×")
GPTQ for Llama 3 8B
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer
import torch
class GPTQQuantizer:
"""
Apply GPTQ quantization to Llama 3 8B.
Uses calibration data from MangaAssist to optimize quantization.
"""
def __init__(self, model_name: str = "meta-llama/Meta-Llama-3-8B"):
self.model_name = model_name
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
def quantize(
self,
calibration_texts: list[str],
bits: int = 4,
group_size: int = 128,
desc_act: bool = True,
):
"""
Quantize model using GPTQ with MangaAssist calibration data.
Args:
calibration_texts: 128 representative MangaAssist conversations
bits: 4 for INT4, 8 for INT8
group_size: Quantize in groups of 128 weights (more groups = more precision)
desc_act: Sort columns by activation magnitude (descending) before quantizing
"""
quantize_config = BaseQuantizeConfig(
bits=bits,
group_size=group_size,
desc_act=desc_act,
sym=True, # Symmetric quantization
model_file_base_name="mangaassist-llama3-8b",
)
# Load model in FP16
model = AutoGPTQForCausalLM.from_pretrained(
self.model_name,
quantize_config=quantize_config,
torch_dtype=torch.float16,
device_map="auto",
)
# Prepare calibration data
calibration_dataset = [
self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
for text in calibration_texts
]
# Run GPTQ quantization
model.quantize(
calibration_dataset,
batch_size=1,
)
return model
def save_quantized(self, model, output_dir: str):
model.save_quantized(output_dir)
self.tokenizer.save_pretrained(output_dir)
print(f"Saved GPTQ model to {output_dir}")
def evaluate_quality(self, model, test_prompts: list[str]) -> dict:
"""Compare quantized model quality against FP16 baseline."""
results = {"perplexity_samples": [], "generation_samples": []}
for prompt in test_prompts[:20]:
inputs = self.tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
perplexity = torch.exp(outputs.loss).item()
results["perplexity_samples"].append(perplexity)
generated = model.generate(**inputs, max_new_tokens=100)
text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
results["generation_samples"].append({
"prompt": prompt,
"response": text[len(prompt):],
})
results["mean_perplexity"] = sum(results["perplexity_samples"]) / len(results["perplexity_samples"])
return results
AWQ Quantization
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
import torch
class AWQQuantizer:
"""
Activation-Aware Weight Quantization.
Preserves channels with large activations at higher precision.
"""
def __init__(self, model_name: str = "meta-llama/Meta-Llama-3-8B"):
self.model_name = model_name
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
def quantize(self, calibration_texts: list[str]):
"""
Apply AWQ quantization.
Automatically determines per-channel scaling factors
from calibration data.
"""
model = AutoAWQForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
device_map="auto",
)
quant_config = {
"zero_point": True,
"q_group_size": 128,
"w_bit": 4,
}
# AutoAWQ tokenizes calibration text internally, so pass raw strings
calib_data = calibration_texts[:128]
model.quantize(
tokenizer=self.tokenizer,
quant_config=quant_config,
calib_data=calib_data,
)
return model
class QuantizationBenchmark:
"""Compare GPTQ vs AWQ vs QAT on MangaAssist tasks."""
def __init__(self):
self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
def run_benchmark(self, models: dict, test_data: list[dict]) -> dict:
"""
Benchmark multiple quantization methods head-to-head.
"""
results = {}
for name, model in models.items():
perplexities = []
latencies = []
import time
for item in test_data:
inputs = self.tokenizer(
item["prompt"], return_tensors="pt", max_length=256, truncation=True,
).to(model.device)
# Perplexity
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
perplexities.append(torch.exp(outputs.loss).item())
# Latency
start = time.time()
with torch.no_grad():
model.generate(**inputs, max_new_tokens=50)
latencies.append(time.time() - start)
results[name] = {
"mean_perplexity": sum(perplexities) / len(perplexities),
"mean_latency_ms": sum(latencies) / len(latencies) * 1000,
"model_size_gb": sum(p.numel() * p.element_size() for p in model.parameters()) / 1e9,
}
return results
Group Discussion: Key Decision Points
Decision Point 1: GPTQ vs AWQ vs QAT
Priya (ML Engineer): Benchmark on Llama 3 8B for MangaAssist:
| Method | Perplexity (↓) | Accuracy (intent routing) | Size | Quantize Time | Inference (ms) |
|---|---|---|---|---|---|
| FP16 (baseline) | 5.82 | 93.1% | 16GB | N/A | 480 |
| QAT INT8 | 5.87 | 92.8% | 8GB | 6h training | 320 |
| GPTQ INT4 | 5.98 | 92.1% | 4.2GB | 15 min calibration | 280 |
| AWQ INT4 | 5.93 | 92.4% | 4.1GB | 12 min calibration | 260 |
| NF4 (QLoRA, doc 04) | 6.02 | 91.8% | 4.0GB | N/A | 290 |
Aiko (Data Scientist): AWQ is the winner: best perplexity among INT4 methods, fastest inference, and only 12 minutes to quantize. The activation-aware scaling pays off on MangaAssist data because our activation distributions are highly non-uniform (manga-specific tokens have consistently large activations).
Marcus (Architect): AWQ INT4 gives us roughly a quarter of the FP16 memory footprint (4.1GB vs 16GB) with only a 0.7-point accuracy drop (92.4% vs 93.1%). That is well within our CPQ threshold.
Jordan (MLOps): AWQ also has better production support: vLLM natively supports AWQ kernels for fast inference. GPTQ requires the AutoGPTQ library which has more compatibility issues.
Resolution: AWQ INT4 for Llama 3 8B (primary deployment). GPTQ INT4 as fallback if AWQ has compatibility issues with future model versions. QAT INT8 for DistilBERT models where 2× compression is sufficient and we want maximum accuracy.
Decision Point 2: Calibration Data Composition
Aiko (Data Scientist): Calibration data matters more than people think:
| Calibration Source | Perplexity | Intent Accuracy | Notes |
|---|---|---|---|
| Random Wikipedia | 6.31 | 90.2% | Domain mismatch |
| C4 (generic web) | 6.15 | 91.0% | Better but still generic |
| MangaAssist conversations | 5.93 | 92.4% | Domain-matched ✅ |
| MangaAssist + edge cases | 5.91 | 92.6% | Marginally better |
Priya (ML Engineer): Using domain-specific calibration data (actual MangaAssist conversations) reduces perplexity by 0.22 compared to generic web text. The calibration data shapes what the Hessian "sees" as important, so it directly affects which weights get priority during quantization.
Resolution: Use 128 representative MangaAssist conversations spanning all 10 intent categories as calibration data. Refresh calibration data quarterly to capture evolving conversation patterns.
Decision Point 3: Mixed-Precision Strategy
Marcus (Architect): Should all layers be INT4, or should we use mixed precision?
Priya (ML Engineer): Error propagation analysis shows early and late layers are more sensitive:
| Strategy | Perplexity | Size | Inference |
|---|---|---|---|
| Uniform INT4 | 5.93 | 4.1GB | 260ms |
| Mixed (early/late INT8, middle INT4) | 5.88 | 5.0GB | 275ms |
| Mixed with FP16 LM head | 5.85 | 5.5GB | 280ms |
| Uniform INT8 | 5.87 | 8.0GB | 320ms |
Jordan (MLOps): The mixed-precision approach (5.0GB) is interesting but adds deployment complexity. We need custom model definitions instead of using off-the-shelf AWQ configs.
Sam (PM): How does 5.0GB vs 4.1GB affect our deployment? Both fit on a single A10G (24GB). The question is headroom for KV cache and other models.
Marcus (Architect): With AWQ INT4 at 4.1GB, we have 19.9GB for KV cache and the smaller models. With mixed-precision at 5.0GB, we have 19GB. Both are comfortable.
Resolution: Start with uniform AWQ INT4 (simpler, 4.1GB). The 0.05 perplexity gap from mixed-precision doesn't justify the deployment complexity. If we observe quality issues in production, implement mixed-precision as a targeted fix.
Research Paper References
1. GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers (Frantar et al., 2022)
Key contribution: Applied optimal brain surgeon (OBS) techniques to LLM quantization. The column-by-column quantization with Hessian-based error compensation achieves INT4 quantization of 175B parameter models with <1% perplexity degradation. The Cholesky-based batched processing makes it run in under 4 hours on a single GPU.
Relevance to MangaAssist: GPTQ is our fallback quantization method. Its theoretical guarantee (minimal output error given the Hessian) provides a reliable baseline against which we benchmark AWQ.
2. AWQ: Activation-Aware Weight Quantization for LLM Compression and Acceleration (Lin et al., 2023)
Key contribution: Identified that 0.1-1% of weight channels are critical (connected to large activations) and that protecting these via per-channel scaling eliminates most quantization error. AWQ achieves better quality than GPTQ without using the more complex Hessian-based approach. The method is also hardware-friendly — the scaling can be fused into the quantized computation.
Relevance to MangaAssist: AWQ is our primary quantization method for Llama 3 8B. The activation-aware scaling aligns well with our domain-specific calibration data, preserving the channels most active during manga-related conversations.
3. SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models (Xiao et al., 2022)
Key contribution: Addressed the activation outlier problem by migrating quantization difficulty from activations to weights through per-channel smoothing. This enables W8A8 (8-bit weights AND activations) quantization, which further accelerates inference on hardware with INT8 matrix multiply support.
Relevance to MangaAssist: SmoothQuant is applicable to our INT8 quantized DistilBERT models. The smoothing technique improves INT8 accuracy by 0.3% on our intent classifier — modest but free (zero inference overhead since the smoothing is folded into weights).
Production Results
Quantization Impact Summary
| Model | Method | Original Size | Quantized Size | Accuracy Delta | Latency Change |
|---|---|---|---|---|---|
| Llama 3 8B | AWQ INT4 | 16GB | 4.1GB | -0.7% | -42% faster |
| DistilBERT (intent) | QAT INT8 | 265MB | 66MB | -0.3% | -35% faster |
| MiniLM-L6 (reranker) | QAT INT8 | 90MB | 22MB | -0.2% | -30% faster |
| DistilBERT (sentiment) | QAT INT8 | 265MB | 66MB | -0.4% | -33% faster |
Total Pipeline Memory
| Configuration | Total VRAM | Instance | Monthly Cost |
|---|---|---|---|
| All FP16 | 16.6GB | g5.2xlarge | $680 |
| All INT8 | 8.2GB | g5.xlarge | $510 |
| Mixed (LLM INT4, rest INT8) | 4.3GB | g5.xlarge | $510 |
| Mixed + KV cache headroom | 4.3GB + 8GB cache | g5.xlarge | $510 |
Savings: Mixed quantization saves $170/month ($2,040/year) by fitting on a smaller instance while losing <1% quality. The combined pipeline latency drops from 680ms to 520ms — a 24% improvement that directly impacts user experience.