
14. Retrieval-Augmented Fine-Tuning (RAFT) — Teaching the LLM to Use Retrieved Context

Problem Statement and MangaAssist Context

MangaAssist's RAG pipeline retrieves relevant manga catalog chunks and passes them to Claude 3.5 Sonnet (or our Llama 3 8B fallback) as context. The problem: off-the-shelf LLMs don't always know how to use retrieved documents effectively. They may:

  1. Ignore relevant context (12% of responses): the model answers from parametric memory instead of using the retrieved manga details
  2. Over-rely on irrelevant chunks (8%): the retriever returns partially relevant documents and the model quotes them verbatim instead of filtering
  3. Miss key details (6%): the model skims the context and misses critical information like price, availability, or volume count

RAFT (Retrieval-Augmented Fine-Tuning) fine-tunes the LLM to be a better "context reader" — teaching it to identify relevant passages, ignore distractors, and extract precise answers from retrieved chunks.

Before vs After RAFT

| Metric | Without RAFT | With RAFT | Improvement |
| --- | --- | --- | --- |
| Context utilization rate | 72% | 94% | +22% |
| Distractor rejection | 64% | 91% | +27% |
| Answer precision (from context) | 78% | 93% | +15% |
| Hallucination rate | 4.0% | 1.1% | -73% |

Mathematical Foundations

RAFT Training Framework (Zhang et al., 2024)

RAFT creates training data where the model must answer questions given a mix of relevant and irrelevant documents. For each training example:

  • Question $q$: A customer query
  • Oracle document $d^*$: The document that contains the answer
  • Distractor documents $\{d_1^-, d_2^-, \ldots, d_k^-\}$: Retrieved documents that are topically related but don't contain the answer

Two types of training examples:

Type 1 (oracle present, probability $p$): The model sees $\{d^*, d_1^-, d_2^-, \ldots, d_k^-\}$ and must find the answer in $d^*$ while ignoring distractors.

$$\mathcal{L}_1 = -\log P_\theta(a^* | q, d^*, d_1^-, \ldots, d_k^-)$$

Type 2 (oracle absent, probability $1-p$): The model sees only distractors $\{d_1^-, d_2^-, \ldots, d_k^-\}$ and must admit it cannot answer (or answer from parametric knowledge).

$$\mathcal{L}_2 = -\log P_\theta(\text{"I don't have enough information"} | q, d_1^-, \ldots, d_k^-)$$

The combined training loss:

$$\mathcal{L}_{\text{RAFT}} = p \cdot \mathbb{E}[\mathcal{L}_1] + (1-p) \cdot \mathbb{E}[\mathcal{L}_2]$$

Typical setting: $p = 0.8$ (80% of examples include the oracle), $k = 4$ distractors.
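The sampling and loss mixing can be sketched in a few lines (hypothetical helper names; assumes the per-example negative log-likelihoods $\mathcal{L}_1$ and $\mathcal{L}_2$ are already computed by the trainer):

```python
import random

def sample_raft_example(oracle, distractors, p=0.8):
    """Build one RAFT example: with probability p the oracle document is
    mixed in with the distractors (Type 1); otherwise the model sees only
    distractors and must abstain (Type 2)."""
    if random.random() < p:
        docs = [oracle] + list(distractors)
        random.shuffle(docs)
        return {"docs": docs, "type": "oracle_present"}
    return {"docs": list(distractors), "type": "oracle_absent"}

def raft_loss(type1_nlls, type2_nlls, p=0.8):
    """Combined RAFT objective: p * E[L1] + (1 - p) * E[L2], where each
    list holds per-example negative log-likelihoods."""
    e1 = sum(type1_nlls) / len(type1_nlls)
    e2 = sum(type2_nlls) / len(type2_nlls)
    return p * e1 + (1 - p) * e2
```

In practice the two example types are interleaved in one training set (as in the data generator below) and a single causal LM loss is applied, so the $p$/$1-p$ weighting falls out of the data mixture itself.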

Chain-of-Thought with Citations

RAFT trains the model to generate chain-of-thought reasoning with explicit citations:

$$a^* = \text{"According to [Doc 3], One Piece has 107 volumes as of 2024. The price is ¥484 per volume [Doc 3, paragraph 2]."}$$

This forces the model to:

  1. Identify which document is relevant
  2. Extract specific information
  3. Cite the source explicitly
  4. Reason step-by-step

Attention Attribution Analysis

To understand how the model uses context, we measure attention attribution:

$$\text{Attribution}(d_i) = \frac{1}{|a|} \sum_{t \in a} \sum_{l=1}^{L} \sum_{h=1}^{H} \alpha_{l,h}^{(t \to d_i)}$$

where $\alpha_{l,h}^{(t \to d_i)}$ is the attention weight from output token $t$ to document $d_i$ at layer $l$, head $h$.

Before RAFT:

  • 45% of attention goes to the question (the model answers from its own understanding)
  • 30% to relevant documents
  • 25% to distractors

After RAFT:

  • 20% of attention goes to the question
  • 65% to the oracle document (model reads the context)
  • 15% to distractors (model has learned to identify and downweight them)
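The attribution measure above can be sketched as a small NumPy routine (a simplified illustration; it assumes the attention weights have already been stacked into one `(layers, heads, answer_tokens, input_tokens)` array and that each document occupies a contiguous span of input positions):

```python
import numpy as np

def attribution(attn, doc_slice):
    """Compute Attribution(d_i): average over answer tokens of the total
    attention mass (summed over layers and heads) falling on the input
    positions covered by document d_i.

    attn: array of shape (layers, heads, answer_tokens, input_tokens)
    doc_slice: slice of input positions belonging to d_i
    """
    # Sum attention over the document's input positions -> (L, H, T_out)
    mass = attn[:, :, :, doc_slice].sum(axis=-1)
    # Sum over layers and heads, then average over answer tokens
    return mass.sum(axis=(0, 1)).mean()
```

Comparing `attribution` across the oracle and distractor slices (and the question span) yields the before/after percentages reported above.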

Context Utilization Score

We define a quantitative metric for how well the model uses retrieved context:

$$\text{CUS}(a, d^*) = \frac{|\text{facts}(a) \cap \text{facts}(d^*)|}{|\text{facts}(d^*)|}$$

where $\text{facts}(\cdot)$ extracts atomic factual claims. CUS = 1 means the model used all available facts from the oracle document. CUS = 0 means the model ignored the context entirely.
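A direct sketch of the metric, assuming fact extraction has already produced sets of atomic claims (the extraction step itself is out of scope here):

```python
def context_utilization_score(answer_facts, oracle_facts):
    """CUS = |facts(a) ∩ facts(d*)| / |facts(d*)|.
    Both arguments are sets of atomic factual claims."""
    oracle_facts = set(oracle_facts)
    if not oracle_facts:
        return 0.0  # no oracle facts to utilize
    return len(set(answer_facts) & oracle_facts) / len(oracle_facts)
```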

Self-RAG: Self-Reflective RAG (Asai et al., 2023)

Self-RAG adds reflection tokens that the model generates to assess its own retrieval:

  1. [Retrieve]: Should I retrieve? (yes/no)
  2. [IsRel]: Is this retrieved passage relevant? (relevant/irrelevant)
  3. [IsSup]: Is my response supported by the passage? (fully/partially/no)
  4. [IsUse]: Is my response useful? (5-point scale)

These reflection tokens are trained as additional classification heads:

$$P(\text{[IsRel]} = \text{relevant} | q, d_i) = \sigma(W_{\text{rel}} \cdot h_{[CLS]})$$

The model learns to self-assess retrieval quality at inference time, enabling it to decide whether to use, ignore, or re-retrieve context.
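Following the classification-head formulation above, a toy NumPy sketch of the [IsRel] head and the resulting chunk filtering (hypothetical names; pooled hidden states and the head weights would come from the trained model):

```python
import numpy as np

def is_rel_probability(h_cls, w_rel, b_rel=0.0):
    """P([IsRel] = relevant | q, d_i) = sigmoid(W_rel · h_[CLS]).
    h_cls: pooled hidden state for the (query, document) pair;
    w_rel, b_rel: learned classification-head parameters."""
    logit = float(np.dot(w_rel, h_cls)) + b_rel
    return 1.0 / (1.0 + np.exp(-logit))

def assess_chunks(pooled_states, w_rel, threshold=0.5):
    """Keep indices of chunks the model itself scores as relevant."""
    return [i for i, h in enumerate(pooled_states)
            if is_rel_probability(h, w_rel) >= threshold]
```

At inference, only chunks passing the [IsRel] check would be passed on to generation, with [IsSup]/[IsUse] applied analogously to the generated answer.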

REALM Objective (Guu et al., 2020)

REALM pre-trains the retriever and reader jointly by treating retrieval as a latent variable:

$$P(y|x) = \sum_{z \in \mathcal{Z}} \underbrace{P(z|x)}_{\text{retriever}} \cdot \underbrace{P(y|x,z)}_{\text{reader}}$$

where $z$ is a retrieved document and $\mathcal{Z}$ is the corpus. The gradient flows through both the retriever and reader, improving both simultaneously.

For computational feasibility, the sum is approximated with top-$k$ documents:

$$P(y|x) \approx \sum_{z \in \text{top-}k} P(z|x) \cdot P(y|x,z)$$
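The top-$k$ marginalization is a one-line computation; this sketch assumes the retriever scores are already normalized over the top-$k$ documents:

```python
def realm_marginal(p_z_given_x, p_y_given_xz):
    """P(y|x) ≈ sum over top-k docs z of P(z|x) * P(y|x, z).
    p_z_given_x: retriever probabilities for the top-k documents;
    p_y_given_xz: reader probability of the answer given each document."""
    return sum(pz * py for pz, py in zip(p_z_given_x, p_y_given_xz))
```

Because the marginal is a sum of products, gradients flow into both factors, which is what lets REALM improve the retriever and reader jointly.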

Connection to RAFT: REALM optimizes retrieval quality, while RAFT accepts the retriever as-is and optimizes the reader to handle imperfect retrieval. For MangaAssist, we combine both: our embedding adapter (doc 02) improves retrieval, and RAFT improves reading.


Model Internals — Layer-by-Layer Diagrams

RAFT Training Data Construction

graph TB
    subgraph "RAFT Data Pipeline"
        QUERY["Customer query:<br>'How many volumes does<br>One Piece have?'"]

        RETRIEVER["Retriever (doc 02 embedding)<br>Returns top-5 chunks"]

        ORACLE["Oracle doc (d*):<br>'One Piece by Eiichiro Oda<br>currently has 107 volumes<br>as of January 2024...'"]

        D1["Distractor d₁:<br>'One Piece merchandise<br>includes figures, posters...'"]
        D2["Distractor d₂:<br>'Naruto by Masashi Kishimoto<br>has 72 volumes...'"]
        D3["Distractor d₃:<br>'One Piece anime has<br>over 1000 episodes...'"]
        D4["Distractor d₄:<br>'Manga sales in Japan<br>reached ¥612B in 2023...'"]

        RETRIEVER --> ORACLE & D1 & D2 & D3 & D4

        TYPE1["Type 1 (80% of data):<br>Input: q + [d*, d₁, d₂, d₃, d₄]<br>Target: 'According to [Doc 1],<br>One Piece has 107 volumes<br>as of January 2024.'"]

        TYPE2["Type 2 (20% of data):<br>Input: q + [d₁, d₂, d₃, d₄]<br>Target: 'Based on the available<br>documents, I don't have the<br>exact current volume count.'"]

        ORACLE --> TYPE1
        D1 & D2 & D3 & D4 --> TYPE1 & TYPE2
    end

    style ORACLE fill:#c8e6c9
    style D1 fill:#ffcdd2
    style D2 fill:#ffcdd2
    style D3 fill:#ffcdd2
    style D4 fill:#ffcdd2
    style TYPE1 fill:#c8e6c9
    style TYPE2 fill:#fff9c4

Attention Shift Before/After RAFT

graph LR
    subgraph "Before RAFT: Attention Distribution"
        B_Q["Question tokens<br>45% of attention ⚠️<br>Model answers from<br>parametric memory"]
        B_O["Oracle doc<br>30% of attention ⚠️<br>Under-utilized"]
        B_D["Distractors<br>25% of attention ⚠️<br>Not filtered"]
    end

    subgraph "After RAFT: Attention Distribution"
        A_Q["Question tokens<br>20% of attention ✅<br>Used for understanding<br>the query"]
        A_O["Oracle doc<br>65% of attention ✅<br>Model reads and<br>extracts from context"]
        A_D["Distractors<br>15% of attention ✅<br>Identified and<br>downweighted"]
    end

    B_Q -->|"RAFT<br>training"| A_Q
    B_O -->|"RAFT<br>training"| A_O
    B_D -->|"RAFT<br>training"| A_D

    style B_Q fill:#fff9c4
    style B_O fill:#ffcdd2
    style B_D fill:#ffcdd2
    style A_Q fill:#c8e6c9
    style A_O fill:#c8e6c9
    style A_D fill:#c8e6c9

Self-RAG Reflection Flow

sequenceDiagram
    participant Q as Query
    participant LLM as RAFT LLM
    participant R as Retriever
    participant REF as Reflection Tokens

    Q->>LLM: "What's the price of Demon Slayer vol 23?"
    LLM->>REF: [Retrieve] = yes (needs external info)
    REF->>R: Trigger retrieval
    R->>LLM: Returns 5 chunks

    LLM->>REF: [IsRel] per chunk
    Note over REF: Chunk 1: relevant ✅ (has price)<br>Chunk 2: irrelevant ❌ (anime info)<br>Chunk 3: irrelevant ❌ (reviews)<br>Chunk 4: partially ⚠️ (old price)<br>Chunk 5: irrelevant ❌ (merch)

    LLM->>LLM: Generate answer using Chunk 1
    LLM->>REF: [IsSup] = fully supported ✅
    LLM->>REF: [IsUse] = 5/5 (directly answers)
    LLM->>Q: "Demon Slayer vol 23 is ¥484 [Source: Chunk 1]"

RAFT vs Standard RAG Pipeline

graph TB
    subgraph "Standard RAG"
        S_R["Retriever returns<br>top-5 chunks"]
        S_C["Concatenate ALL chunks<br>into prompt context"]
        S_L["LLM processes<br>long context (2000+ tokens)<br>May fixate on distractors"]
        S_A["Answer may include<br>irrelevant info from<br>distractor chunks"]

        S_R --> S_C --> S_L --> S_A
    end

    subgraph "RAFT-Enhanced RAG"
        R_R["Retriever returns<br>top-5 chunks"]
        R_C["Concatenate chunks<br>with position markers<br>[Doc 1], [Doc 2], ..."]
        R_L["RAFT LLM:<br>1. Identifies oracle via learned<br>   attention patterns<br>2. Ignores distractor content<br>3. Extracts precise facts<br>4. Generates with citations"]
        R_A["Answer is precise,<br>cited, and distractor-free:<br>'According to [Doc 1]...'"]

        R_R --> R_C --> R_L --> R_A
    end

    style S_A fill:#ffcdd2
    style R_A fill:#c8e6c9

Training Loop Architecture

graph TB
    subgraph "RAFT Fine-Tuning Loop (QLoRA)"
        DATA["RAFT training data<br>2,000 examples<br>80% oracle-present<br>20% oracle-absent"]

        MODEL["Llama 3 8B + QLoRA<br>r=16, target: q,k,v,o projections<br>Frozen base + trainable adapter"]

        FWD["Forward pass:<br>Input: [q; Doc1; Doc2; ... ; Doc5]<br>Output: chain-of-thought +<br>cited answer"]

        LOSS["Loss: standard causal LM loss<br>Only on answer tokens<br>(prompt tokens masked)"]

        BWD["Backward: gradients update<br>LoRA adapters only<br>Model learns to:<br>• Attend to oracle document<br>• Generate citations<br>• Reject distractors"]

        DATA --> MODEL --> FWD --> LOSS --> BWD
        BWD -->|"next batch"| MODEL
    end

    style MODEL fill:#e3f2fd
    style LOSS fill:#fff9c4

Implementation Deep-Dive

RAFT Data Generator

import json
import random
import boto3
from typing import Optional


class RAFTDataGenerator:
    """
    Generate RAFT training data from the MangaAssist knowledge base.
    Creates (question, oracle_doc, distractor_docs, answer) tuples.
    """

    def __init__(self, opensearch_client, bedrock_client=None):
        self.os_client = opensearch_client
        self.bedrock = bedrock_client or boto3.client("bedrock-runtime")

    def generate_qa_pairs(
        self,
        num_pairs: int = 2000,
        num_distractors: int = 4,
        oracle_present_ratio: float = 0.8,
    ) -> list[dict]:
        """Generate RAFT training examples from the manga catalog."""
        examples = []

        # Get all indexed documents
        all_docs = self._get_all_documents()

        for _ in range(num_pairs):
            # Select a random oracle document
            oracle_doc = random.choice(all_docs)

            # Generate a question about this document using Claude
            question = self._generate_question(oracle_doc)

            # Generate a chain-of-thought answer with citation
            answer = self._generate_cited_answer(question, oracle_doc)

            # Find distractors: topically similar but don't answer the question
            distractors = self._find_distractors(
                question, oracle_doc, all_docs, num_distractors,
            )

            # Create two types of examples
            if random.random() < oracle_present_ratio:
                # Type 1: Oracle present among distractors
                docs = [oracle_doc] + distractors
                random.shuffle(docs)
                oracle_idx = docs.index(oracle_doc) + 1  # 1-indexed

                examples.append({
                    "question": question,
                    "documents": [
                        {"doc_id": i + 1, "content": d["content"]}
                        for i, d in enumerate(docs)
                    ],
                    "oracle_doc_id": oracle_idx,
                    "answer": answer,
                    "type": "oracle_present",
                })
            else:
                # Type 2: Oracle absent — only distractors
                examples.append({
                    "question": question,
                    "documents": [
                        {"doc_id": i + 1, "content": d["content"]}
                        for i, d in enumerate(distractors)
                    ],
                    "oracle_doc_id": None,
                    "answer": (
                        "Based on the available documents, I don't have enough "
                        "information to answer this question precisely. Let me "
                        "check with our manga specialist team."
                    ),
                    "type": "oracle_absent",
                })

        return examples

    def _generate_question(self, doc: dict) -> str:
        """Generate a natural customer question about a document."""
        prompt = f"""Generate a natural customer question that would be answered by this manga product information:

{doc['content'][:500]}

The question should sound like a real customer asking in a bookstore chat. Output only the question."""

        response = self.bedrock.invoke_model(
            modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
            body=json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 100,
                "messages": [{"role": "user", "content": prompt}],
            }),
        )
        return json.loads(response["body"].read())["content"][0]["text"].strip()

    def _generate_cited_answer(self, question: str, oracle_doc: dict) -> str:
        """Generate a chain-of-thought answer with citations."""
        prompt = f"""Answer this customer question using ONLY the provided document.
Include explicit citations like [Doc X, paragraph Y].
Think step-by-step before giving the final answer.

Question: {question}

Document: {oracle_doc['content'][:800]}

Format:
## Reasoning
[Step-by-step reasoning about which parts of the document answer the question]

## Answer
[Conversational answer with citations]"""

        response = self.bedrock.invoke_model(
            modelId="anthropic.claude-3-5-sonnet-20241022-v2:0",
            body=json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 300,
                "messages": [{"role": "user", "content": prompt}],
            }),
        )
        return json.loads(response["body"].read())["content"][0]["text"].strip()

    def _find_distractors(
        self, question: str, oracle: dict, all_docs: list, n: int,
    ) -> list[dict]:
        """Find topically similar but non-answering documents."""
        # Use semantic search to find similar docs
        similar = self._semantic_search(question, top_k=n + 5)
        distractors = [d for d in similar if d["id"] != oracle["id"]][:n]

        # If not enough similar docs, pad with random ones
        while len(distractors) < n:
            rand_doc = random.choice(all_docs)
            if rand_doc["id"] != oracle["id"] and rand_doc not in distractors:
                distractors.append(rand_doc)

        return distractors

    def _get_all_documents(self) -> list[dict]:
        """Retrieve all documents from OpenSearch."""
        results = self.os_client.search(
            index="manga-catalog",
            body={"query": {"match_all": {}}, "size": 10000},
        )
        return [
            {"id": hit["_id"], "content": hit["_source"]["content"]}
            for hit in results["hits"]["hits"]
        ]

    def _semantic_search(self, query: str, top_k: int = 10) -> list[dict]:
        """Search for similar documents."""
        # Embed query and search (simplified)
        results = self.os_client.search(
            index="manga-catalog",
            body={
                "query": {"match": {"content": query}},
                "size": top_k,
            },
        )
        return [
            {"id": hit["_id"], "content": hit["_source"]["content"]}
            for hit in results["hits"]["hits"]
        ]

RAFT Trainer

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from torch.utils.data import Dataset, DataLoader


class RAFTDataset(Dataset):
    """Format RAFT examples for causal LM training."""

    def __init__(self, examples: list[dict], tokenizer, max_length: int = 2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.processed = []

        for ex in examples:
            # Format input: question + documents
            doc_text = "\n\n".join(
                f"[Doc {d['doc_id']}]: {d['content']}"
                for d in ex["documents"]
            )

            input_text = (
                f"Question: {ex['question']}\n\n"
                f"Retrieved Documents:\n{doc_text}\n\n"
                f"Instructions: Answer the question using ONLY the provided documents. "
                f"Cite specific documents. If the documents don't contain the answer, "
                f"say so.\n\n"
                f"Answer: {ex['answer']}"
            )

            encoded = tokenizer(
                input_text, truncation=True, max_length=max_length,
                padding="max_length", return_tensors="pt",
            )

            # Find where the final answer starts to mask prompt tokens;
            # use rfind so an "Answer: " inside a document doesn't match
            answer_prefix = "Answer: "
            answer_start = input_text.rfind(answer_prefix) + len(answer_prefix)
            prompt_tokens = len(tokenizer(input_text[:answer_start])["input_ids"])

            labels = encoded["input_ids"].clone().squeeze()
            labels[:prompt_tokens] = -100  # Mask prompt tokens
            labels[labels == tokenizer.pad_token_id] = -100

            self.processed.append({
                "input_ids": encoded["input_ids"].squeeze(),
                "attention_mask": encoded["attention_mask"].squeeze(),
                "labels": labels,
            })

    def __len__(self):
        return len(self.processed)

    def __getitem__(self, idx):
        return self.processed[idx]


class RAFTTrainer:
    """Fine-tune LLM with RAFT for better context utilization."""

    def __init__(self, model_name: str = "meta-llama/Meta-Llama-3-8B"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # QLoRA setup
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            load_in_4bit=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model = prepare_model_for_kbit_training(self.model)

        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(self.model, lora_config)

    def train(self, raft_data: list[dict], epochs: int = 3, lr: float = 2e-5):
        dataset = RAFTDataset(raft_data, self.tokenizer)
        loader = DataLoader(dataset, batch_size=2, shuffle=True)

        optimizer = torch.optim.AdamW(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=lr, weight_decay=0.01,
        )

        self.model.train()
        for epoch in range(epochs):
            total_loss = 0
            for batch in loader:
                batch = {k: v.to(self.model.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                loss = outputs.loss

                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    [p for p in self.model.parameters() if p.requires_grad], 1.0,
                )
                optimizer.step()
                optimizer.zero_grad()
                total_loss += loss.item()

            print(f"Epoch {epoch+1}: loss={total_loss/len(loader):.4f}")

    def evaluate_context_utilization(
        self, test_examples: list[dict],
    ) -> dict:
        """Measure how well the model uses retrieved context."""
        self.model.eval()
        results = {
            "context_utilization": [],
            "distractor_rejection": [],
            "citation_accuracy": [],
        }

        for ex in test_examples:
            # Format input (without answer)
            doc_text = "\n\n".join(
                f"[Doc {d['doc_id']}]: {d['content']}"
                for d in ex["documents"]
            )
            input_text = (
                f"Question: {ex['question']}\n\n"
                f"Retrieved Documents:\n{doc_text}\n\n"
                f"Instructions: Answer the question using ONLY the provided documents. "
                f"Cite specific documents. If the documents don't contain the answer, "
                f"say so.\n\n"
                f"Answer:"
            )

            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs, max_new_tokens=300,
                    do_sample=True, temperature=0.1,  # temperature requires sampling
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = response[len(input_text):]

            # Check citation accuracy
            cited_docs = self._extract_citations(answer)
            if ex["oracle_doc_id"] is not None:
                correct_citation = ex["oracle_doc_id"] in cited_docs
                results["citation_accuracy"].append(1.0 if correct_citation else 0.0)

                # Distractor rejection: cited only oracle, not distractors
                distractor_cited = any(
                    d in cited_docs for d in range(1, len(ex["documents"]) + 1)
                    if d != ex["oracle_doc_id"]
                )
                results["distractor_rejection"].append(0.0 if distractor_cited else 1.0)
            else:
                # Oracle absent: model should not confidently answer
                hedging = any(
                    phrase in answer.lower()
                    for phrase in ["don't have", "not enough", "cannot find"]
                )
                results["context_utilization"].append(1.0 if hedging else 0.0)

        return {
            k: sum(v) / len(v) if v else 0.0
            for k, v in results.items()
        }

    def _extract_citations(self, text: str) -> list[int]:
        """Extract document numbers from citations like [Doc 1], [Doc 3]."""
        import re
        return [int(m) for m in re.findall(r'\[Doc\s*(\d+)\]', text)]

Integration with MangaAssist RAG Pipeline

class RAFTEnhancedRAG:
    """
    Production RAG pipeline with RAFT-trained LLM.
    Replaces direct Claude calls with a RAFT-aware inference step.
    """

    def __init__(self, raft_model_path: str, retriever):
        self.tokenizer = AutoTokenizer.from_pretrained(raft_model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            raft_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.retriever = retriever

    def answer(self, query: str, top_k: int = 5) -> dict:
        """Full RAG pipeline with RAFT-trained context reading."""
        # Step 1: Retrieve (existing pipeline from docs 02, 03)
        chunks = self.retriever.search(query, top_k=top_k)

        # Step 2: Format with document markers
        doc_text = "\n\n".join(
            f"[Doc {i+1}]: {chunk['content']}"
            for i, chunk in enumerate(chunks)
        )

        prompt = (
            f"Question: {query}\n\n"
            f"Retrieved Documents:\n{doc_text}\n\n"
            f"Instructions: Answer the question using ONLY the provided documents. "
            f"Cite specific documents. If the documents don't contain the answer, "
            f"say so.\n\n"
            f"Answer:"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=300,
                temperature=0.3,
                do_sample=True,
                top_p=0.9,
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = response[len(prompt):]

        return {
            "answer": answer,
            "sources": chunks,
            "num_docs_retrieved": len(chunks),
        }

Group Discussion: Key Decision Points

Decision Point 1: RAFT vs Standard Fine-Tuning for RAG

Priya (ML Engineer): Comparison of approaches for improving context utilization:

| Approach | Context Utilization | Hallucination | Training Data | Training Cost |
| --- | --- | --- | --- | --- |
| Standard RAG (no FT) | 72% | 4.0% | 0 | $0 |
| Standard SFT on QA pairs | 82% | 2.8% | 2,000 pairs | $48 |
| RAFT (oracle + distractors) | 94% | 1.1% | 2,000 pairs | $48 |
| Self-RAG (+ reflection tokens) | 92% | 0.9% | 5,000 pairs | $120 |

Aiko (Data Scientist): RAFT matches Self-RAG on context utilization with 60% less training data. The distractor training is the key differentiator — it teaches the model when NOT to use context, which standard SFT doesn't cover.

Marcus (Architect): Self-RAG's reflection tokens (IsRel, IsSup, IsUse) add inference overhead: the model generates extra tokens for self-assessment. At scale, that's 15-20% more tokens per response, which directly impacts our LLM latency budget.

Sam (PM): Self-RAG has a lower hallucination rate (0.9% vs 1.1%). Is that gap worth the extra complexity?

Aiko (Data Scientist): At our volume (10K messages/day), that is roughly 90 vs 110 hallucinated responses per day. A gap of ~20 responses is within day-to-day measurement noise, so the difference is practically negligible.

Resolution: RAFT for production. It provides 94% context utilization with simple implementation. Self-RAG is over-engineered for our current scale. We may adopt the [IsRel] reflection token concept (just relevance assessment, not the full reflection framework) as a lightweight enhancement later.

Decision Point 2: Oracle-Present Ratio

Aiko (Data Scientist): Varying the oracle-present ratio $p$ in training:

| $p$ (oracle present) | Context Util | Distractor Rejection | "I don't know" accuracy |
| --- | --- | --- | --- |
| 1.0 (always present) | 91% | 82% | 12% (terrible) |
| 0.9 | 93% | 87% | 68% |
| 0.8 | 94% | 91% | 84% |
| 0.7 | 92% | 93% | 89% |
| 0.5 | 88% | 95% | 94% |

Priya (ML Engineer): $p = 0.8$ is optimal for MangaAssist. Our retriever (doc 02 + doc 03) has ~85% recall@5, meaning the oracle is present in retrieved results 85% of the time. Training with $p = 0.8$ matches the real-world distribution.

Resolution: Use $p = 0.8$, matching our retriever's actual recall rate. If we improve retriever recall (e.g., to 92%), we should increase $p$ accordingly.

Decision Point 3: Number of Distractor Documents

Marcus (Architect): More distractors = harder training. How many?

| Distractors | Context Util | Training Tokens/Example | Max Seq Length Needed |
| --- | --- | --- | --- |
| 1 | 88% | ~800 | 1024 |
| 2 | 91% | ~1,200 | 1536 |
| 4 | 94% | ~2,000 | 2048 |
| 6 | 94.5% | ~2,800 | 3072 |

Jordan (MLOps): 4 distractors is the sweet spot: 94% context utilization within 2048 sequence length. Going to 6 distractors requires 3072 tokens, which increases memory and training time by ~40% for 0.5% improvement.

Priya (ML Engineer): 4 distractors also matches our production pipeline: we retrieve top-5 documents, so the model sees 1 oracle + 4 distractors in the typical case.

Resolution: 4 distractors in training, matching the top-5 retrieval configuration. Sequence length of 2048 for training.


Research Paper References

1. RAFT: Adapting Language Model to Domain Specific RAG (Zhang et al., 2024)

Key contribution: Proposed mixing oracle and distractor documents during fine-tuning, combined with chain-of-thought citations. Achieved 35-60% improvement in context utilization across multiple domains. The paper showed that distractor training is more important than increasing oracle data — models need to learn what to ignore as much as what to attend to.

Relevance to MangaAssist: RAFT is our primary technique for improving the Llama 3 8B fallback model's ability to use retrieved manga catalog data. The 22% improvement in context utilization directly reduces hallucinations.

2. Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection (Asai et al., 2023)

Key contribution: Introduced reflection tokens that enable the model to self-assess retrieval quality. The model learns to decide when to retrieve, whether retrieved passages are relevant, and whether its generation is supported by evidence. This end-to-end self-reflection eliminates the need for external relevance classifiers.

Relevance to MangaAssist: While we don't use full Self-RAG, its [IsRel] token concept inspired our relevance scoring in the re-ranking step. Understanding Self-RAG helps us design future improvements to our RAG pipeline.

3. REALM: Retrieval-Augmented Language Model Pre-Training (Guu et al., 2020)

Key contribution: Jointly trained the retriever and reader as a single model by treating retrieval as a latent variable. This end-to-end optimization ensures the retriever learns to find documents that the reader can actually use, and vice versa. The paper showed that pre-training with retrieval improves downstream task performance even on tasks without explicit retrieval.

Relevance to MangaAssist: REALM's joint optimization concept motivates our combined approach: embedding adapter fine-tuning (doc 02) for the retriever + RAFT for the reader. While we don't train them end-to-end (too complex), we evaluate them jointly to ensure improvements in one don't degrade the other.


Production Results

RAFT Impact on MangaAssist

| Metric | Before RAFT | After RAFT | Target |
| --- | --- | --- | --- |
| Context utilization | 72% | 94% | ≥90% ✅ |
| Distractor rejection | 64% | 91% | ≥85% ✅ |
| Citation accuracy | 0% (no citations) | 87% | ≥80% ✅ |
| Hallucination rate | 4.0% | 1.1% | ≤2% ✅ |
| "I don't know" accuracy | 35% | 84% | ≥75% ✅ |
| User satisfaction | 3.8/5 | 4.3/5 | ≥4.0/5 ✅ |

Cost

| Item | Cost |
| --- | --- |
| RAFT data generation (2,000 examples via Claude) | $45 |
| RAFT training (3 epochs, 1× A100, 4 hours) | $32 |
| Evaluation pipeline | $8 |
| Total per training cycle | $85 |

ROI: Cutting hallucinations from 4.0% to 1.1% eliminates ~290 incorrect responses/day (2.9 points × 10K messages/day). At $0.50 of customer support cost each, that is ~$145/day ≈ $4,350/month. Training costs $85 per quarterly cycle ($340/year), giving an ROI of roughly 150:1.