
PO-03: RAG Pipeline Retrieval Performance

User Story

As an ML platform engineer, I want to reduce RAG retrieval latency from ~300ms to under 200ms at p95 while maintaining retrieval quality, so that grounded responses (FAQ, product questions, recommendations) are generated faster without sacrificing accuracy.

Acceptance Criteria

  • Embedding generation completes in under 30ms at p95.
  • OpenSearch KNN search returns top-10 candidates in under 100ms at p95.
  • Reranking of top-10 candidates to top-3 completes in under 50ms at p95.
  • Total RAG retrieval pipeline (embed → search → rerank) is under 200ms at p95 (a verification sketch follows this list).
  • Pre-filtering by metadata reduces search space by 40-60% without losing relevant chunks.
  • Embedding cache hit rate exceeds 20% for repeated/similar queries.
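
These budgets can be verified in a load test by computing p95 over recorded per-stage timings. A minimal sketch, with the budget values copied from the criteria above and timings_ms assumed to come from instrumented pipeline runs:

import numpy as np

# p95 budgets in milliseconds, from the acceptance criteria above
P95_BUDGETS_MS = {"embed": 30, "search": 100, "rerank": 50, "total": 200}


def check_latency_budgets(timings_ms: dict[str, list[float]]) -> dict[str, bool]:
    """Return pass/fail per stage, comparing observed p95 to its budget."""
    return {
        stage: float(np.percentile(timings_ms[stage], 95)) < budget
        for stage, budget in P95_BUDGETS_MS.items()
    }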

High-Level Design

RAG Latency Breakdown

graph LR
    subgraph "Current Pipeline (~300ms)"
        A[Embed Query<br>~40ms] --> B[KNN Search<br>~150ms]
        B --> C[Rerank Top 10<br>~80ms]
        C --> D[Return Top 3<br>~5ms]
    end

    subgraph "Optimized Pipeline (~170ms)"
        E[Embed Query<br>~25ms] --> F[Pre-filtered<br>KNN Search<br>~80ms]
        F --> G[Light Rerank<br>~40ms]
        G --> H[Return Top 3<br>~5ms]
    end

    style A fill:#f66,stroke:#333
    style B fill:#f66,stroke:#333
    style C fill:#f66,stroke:#333
    style E fill:#2d8,stroke:#333
    style F fill:#2d8,stroke:#333
    style G fill:#2d8,stroke:#333

Optimization Strategy Overview

graph TD
    subgraph "Embedding Optimization"
        A1[Embedding Cache<br>for repeated queries]
        A2[Quantized Embedding<br>Model]
        A3[Batched Embedding<br>for multi-query]
    end

    subgraph "Search Optimization"
        B1[Metadata Pre-filtering<br>Reduce search space]
        B2[HNSW Parameter<br>Tuning]
        B3[Approximate KNN<br>with ef_search tuning]
    end

    subgraph "Reranking Optimization"
        C1[Cross-Encoder<br>on GPU endpoint]
        C2[Score Threshold<br>Early Exit]
        C3[Chunk Deduplication<br>before rerank]
    end

    A1 --> D[Total: < 200ms p95]
    A2 --> D
    A3 --> D
    B1 --> D
    B2 --> D
    B3 --> D
    C1 --> D
    C2 --> D
    C3 --> D

Low-Level Design

1. Embedding Cache and Optimization

Repeated queries produce identical embedding vectors, so caching the vector avoids redundant Bedrock Titan Embeddings calls. Light normalization (lowercasing and trimming) also lets trivially different forms of the same query share a cache entry.

sequenceDiagram
    participant Orchestrator
    participant EmbedCache as Embedding Cache<br>(Redis)
    participant Bedrock as Bedrock Titan<br>Embeddings

    Orchestrator->>EmbedCache: GET embed:{query_hash}
    alt Cache Hit (~20%)
        EmbedCache-->>Orchestrator: Cached vector (1ms)
    else Cache Miss
        EmbedCache-->>Orchestrator: null
        Orchestrator->>Bedrock: Embed query (25ms)
        Bedrock-->>Orchestrator: Vector [1536-dim]
        Orchestrator->>EmbedCache: SET embed:{query_hash} TTL=1h
    end

Code Example: Embedding Service with Cache

import asyncio
import hashlib
import json
import time
from dataclasses import dataclass

import boto3
import redis


@dataclass
class EmbeddingResult:
    vector: list[float]
    latency_ms: float
    source: str  # "cache" or "bedrock"
    dimensions: int


class CachedEmbeddingService:
    """Embedding service with Redis cache for repeated queries."""

    CACHE_TTL_SECONDS = 3600  # 1 hour

    def __init__(
        self,
        redis_client: redis.Redis,
        region: str = "us-east-1",
        model_id: str = "amazon.titan-embed-text-v2:0",
    ):
        self.redis = redis_client
        self.bedrock = boto3.client("bedrock-runtime", region_name=region)
        self.model_id = model_id

    def _cache_key(self, text: str) -> str:
        """Stable hash of the input text for cache lookup."""
        normalized = text.lower().strip()
        digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:24]
        return f"embed:{digest}"

    async def embed(self, text: str) -> EmbeddingResult:
        """Embed text with cache-first strategy."""
        start = time.monotonic()
        cache_key = self._cache_key(text)

        # Check cache
        cached = self.redis.get(cache_key)
        if cached is not None:
            vector = json.loads(cached)
            return EmbeddingResult(
                vector=vector,
                latency_ms=(time.monotonic() - start) * 1000,
                source="cache",
                dimensions=len(vector),
            )

        # Call Bedrock Titan Embeddings off the event loop (boto3 is blocking)
        response = await asyncio.to_thread(
            self.bedrock.invoke_model,
            modelId=self.model_id,
            body=json.dumps({"inputText": text}),
            contentType="application/json",
        )

        body = json.loads(response["body"].read())
        vector = body["embedding"]

        # Cache the result
        self.redis.setex(
            cache_key,
            self.CACHE_TTL_SECONDS,
            json.dumps(vector),
        )

        return EmbeddingResult(
            vector=vector,
            latency_ms=(time.monotonic() - start) * 1000,
            source="bedrock",
            dimensions=len(vector),
        )

    async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
        """Embed multiple texts concurrently, using cache where available."""
        results = await asyncio.gather(*[self.embed(text) for text in texts])
        return list(results)
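
A brief usage sketch (run inside an async context; the Redis endpoint is a placeholder): because cache keys are computed from normalized text, a repeated query that differs only in casing or surrounding whitespace still hits the cache.

svc = CachedEmbeddingService(redis_client=redis.Redis(host="localhost", port=6379))
first = await svc.embed("What is your return policy?")      # source == "bedrock"
second = await svc.embed("  what is your return policy? ")  # source == "cache"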

2. Metadata Pre-Filtering for KNN Search

Instead of searching the entire vector index, apply metadata filters first to reduce the candidate set.

graph TD
    subgraph "Without Pre-filter"
        A1[Query Vector] --> B1[Search ALL 500K chunks]
        B1 --> C1[KNN top 10<br>~150ms]
    end

    subgraph "With Pre-filter"
        A2[Query Vector + Intent] --> B2[Filter by source_type<br>+ category]
        B2 --> C2[Search 50-100K chunks<br>reduced set]
        C2 --> D2[KNN top 10<br>~80ms]
    end

    style C1 fill:#f66,stroke:#333
    style D2 fill:#2d8,stroke:#333

Intent-to-Filter Mapping

| Intent           | source_type filter             | category filter | Expected Reduction            |
|------------------|--------------------------------|-----------------|-------------------------------|
| faq              | faq, policy                    | any             | ~70% (only FAQ/policy chunks) |
| product_question | product_description, review    | manga           | ~50%                          |
| recommendation   | product_description, editorial | manga           | ~45%                          |
| return_request   | policy, faq                    | any             | ~75%                          |
| promotion        | editorial, faq                 | manga           | ~60%                          |

Code Example: Pre-Filtered OpenSearch Query

import time
from dataclasses import dataclass
from typing import Optional

from opensearchpy import AsyncOpenSearch


@dataclass
class RetrievedChunk:
    chunk_id: str
    content: str
    score: float
    source_type: str
    asin: str | None
    metadata: dict


@dataclass
class RetrievalResult:
    chunks: list[RetrievedChunk]
    total_candidates: int
    search_latency_ms: float
    pre_filter_applied: bool


# Intent to OpenSearch filter mapping
INTENT_FILTERS = {
    "faq": {
        "source_types": ["faq", "policy"],
        "categories": None,  # All categories
    },
    "product_question": {
        "source_types": ["product_description", "review"],
        "categories": ["manga"],
    },
    "recommendation": {
        "source_types": ["product_description", "editorial"],
        "categories": ["manga"],
    },
    "return_request": {
        "source_types": ["policy", "faq"],
        "categories": None,
    },
    "promotion": {
        "source_types": ["editorial", "faq"],
        "categories": ["manga"],
    },
    "checkout_help": {
        "source_types": ["faq", "policy"],
        "categories": None,
    },
}

# HNSW search parameter tuned for the latency vs recall trade-off.
# ef_search sets the size of the candidate list during graph traversal;
# lowering it speeds up search at a small recall cost (~95% recall here).
# In the OpenSearch k-NN plugin this is a dynamic index setting, not a
# per-query parameter (see the sketch after this class).
HNSW_PARAMS = {
    "ef_search": 100,
}


class OptimizedVectorSearch:
    """Pre-filtered KNN search with tuned HNSW parameters."""

    def __init__(
        self,
        opensearch_client: AsyncOpenSearch,
        index_name: str = "manga_knowledge_base",
    ):
        self.client = opensearch_client
        self.index_name = index_name

    async def search(
        self,
        query_vector: list[float],
        intent: str,
        top_k: int = 10,
        asin_filter: Optional[str] = None,
    ) -> RetrievalResult:
        """Execute pre-filtered KNN search with optimized parameters."""
        start = time.monotonic()

        # Build the filter clause based on intent
        filter_clause = self._build_filter(intent, asin_filter)

        # Construct the KNN query with pre-filtering
        query_body = {
            "size": top_k,
            "query": {
                "knn": {
                    "embedding": {
                        "vector": query_vector,
                        "k": top_k,
                    }
                }
            },
        }

        # Apply pre-filter if available
        if filter_clause:
            query_body["query"]["knn"]["embedding"]["filter"] = {
                "bool": {"must": filter_clause}
            }
            pre_filtered = True
        else:
            pre_filtered = False

        # Execute search through the configured search pipeline
        response = await self.client.search(
            index=self.index_name,
            body=query_body,
            params={"search_pipeline": "rag-search-pipeline"},
        )

        hits = response.get("hits", {}).get("hits", [])
        total = response.get("hits", {}).get("total", {}).get("value", 0)

        chunks = [
            RetrievedChunk(
                chunk_id=hit["_source"]["chunk_id"],
                content=hit["_source"]["content"],
                score=hit["_score"],
                source_type=hit["_source"].get("source_type", "unknown"),
                asin=hit["_source"].get("asin"),
                metadata={
                    k: v for k, v in hit["_source"].items()
                    if k not in ("chunk_id", "content", "embedding")
                },
            )
            for hit in hits
        ]

        return RetrievalResult(
            chunks=chunks,
            total_candidates=total,
            search_latency_ms=(time.monotonic() - start) * 1000,
            pre_filter_applied=pre_filtered,
        )

    def _build_filter(
        self, intent: str, asin_filter: Optional[str]
    ) -> list[dict]:
        """Build OpenSearch filter clause from intent mapping."""
        clauses = []

        intent_config = INTENT_FILTERS.get(intent)
        if intent_config:
            source_types = intent_config.get("source_types")
            if source_types:
                clauses.append({"terms": {"source_type": source_types}})

            categories = intent_config.get("categories")
            if categories:
                clauses.append({"terms": {"category": categories}})

        if asin_filter:
            clauses.append({"term": {"asin": asin_filter}})

        return clauses
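
HNSW_PARAMS is applied once as a dynamic index setting rather than per query. A minimal sketch, reusing the AsyncOpenSearch client and HNSW_PARAMS from the code above and assuming an nmslib- or faiss-backed k-NN index:

async def apply_hnsw_params(
    client: AsyncOpenSearch, index_name: str = "manga_knowledge_base"
) -> None:
    """Apply the tuned ef_search value as a dynamic index setting."""
    await client.indices.put_settings(
        index=index_name,
        body={"index": {"knn.algo_param.ef_search": HNSW_PARAMS["ef_search"]}},
    )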

3. Lightweight Reranker with Early Exit

The reranker re-scores candidate chunks by cross-encoding (query, chunk) pairs. Optimizations include score thresholding and deduplication.

graph TD
    A[10 Candidate Chunks] --> B[Deduplicate<br>Remove near-duplicates]
    B --> C[7-8 Unique Chunks]
    C --> D[Cross-Encoder<br>Rerank]
    D --> E{Score > 0.3?}
    E -->|Yes| F[Keep Chunk]
    E -->|No| G[Discard]
    F --> H[Return Top 3<br>by rerank score]

    style B fill:#2d8,stroke:#333
    style E fill:#2d8,stroke:#333

Code Example: Optimized Reranker

import asyncio
import json
from dataclasses import dataclass


@dataclass
class RankedChunk:
    chunk_id: str
    content: str
    original_score: float
    rerank_score: float
    source_type: str


class OptimizedReranker:
    """Cross-encoder reranker with deduplication and early exit."""

    SIMILARITY_DEDUP_THRESHOLD = 0.92  # Remove chunks that are >92% similar
    MIN_RERANK_SCORE = 0.3  # Discard low-relevance chunks
    MAX_OUTPUT_CHUNKS = 3

    def __init__(self, cross_encoder_endpoint: str, sagemaker_client):
        self.endpoint = cross_encoder_endpoint
        self.sagemaker = sagemaker_client

    async def rerank(
        self, query: str, chunks: list["RetrievedChunk"]
    ) -> list[RankedChunk]:
        """Deduplicate, rerank, and filter chunks."""
        start = time.monotonic()

        # Step 1: Deduplicate near-identical chunks
        unique_chunks = self._deduplicate(chunks)

        # Step 2: Score all (query, chunk) pairs with cross-encoder
        scores = await self._cross_encode(query, unique_chunks)

        # Step 3: Filter by minimum score and sort
        ranked = []
        for chunk, score in zip(unique_chunks, scores):
            if score >= self.MIN_RERANK_SCORE:
                ranked.append(
                    RankedChunk(
                        chunk_id=chunk.chunk_id,
                        content=chunk.content,
                        original_score=chunk.score,
                        rerank_score=score,
                        source_type=chunk.source_type,
                    )
                )

        ranked.sort(key=lambda c: c.rerank_score, reverse=True)
        return ranked[: self.MAX_OUTPUT_CHUNKS]

    def _deduplicate(
        self, chunks: list["RetrievedChunk"]
    ) -> list["RetrievedChunk"]:
        """Remove chunks with near-identical content using Jaccard similarity."""
        if len(chunks) <= 1:
            return chunks

        unique = [chunks[0]]
        for chunk in chunks[1:]:
            is_duplicate = False
            chunk_tokens = set(chunk.content.lower().split())
            for existing in unique:
                existing_tokens = set(existing.content.lower().split())
                intersection = chunk_tokens & existing_tokens
                union = chunk_tokens | existing_tokens
                jaccard = len(intersection) / len(union) if union else 0
                if jaccard > self.SIMILARITY_DEDUP_THRESHOLD:
                    is_duplicate = True
                    break
            if not is_duplicate:
                unique.append(chunk)

        return unique

    async def _cross_encode(
        self, query: str, chunks: list["RetrievedChunk"]
    ) -> list[float]:
        """Score (query, chunk) pairs using cross-encoder on SageMaker."""
        import json

        pairs = [
            {"text_1": query, "text_2": chunk.content}
            for chunk in chunks
        ]

        payload = json.dumps({"pairs": pairs})

        response = await asyncio.to_thread(
            self.sagemaker.invoke_endpoint,
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=payload,
        )

        result = json.loads(response["Body"].read())
        return result["scores"]

4. Full Optimized RAG Pipeline

sequenceDiagram
    participant Orchestrator
    participant EmbedCache as Embedding<br>Cache
    participant Bedrock as Bedrock<br>Titan Embed
    participant OpenSearch
    participant Reranker

    Orchestrator->>EmbedCache: Check embed cache
    alt Cache Hit
        EmbedCache-->>Orchestrator: Vector (1ms)
    else Cache Miss
        Orchestrator->>Bedrock: Embed query (25ms)
        Bedrock-->>Orchestrator: Vector
        Orchestrator->>EmbedCache: Cache vector
    end

    Orchestrator->>OpenSearch: Pre-filtered KNN (80ms)
    OpenSearch-->>Orchestrator: 10 candidates

    Orchestrator->>Reranker: Dedup + rerank (40ms)
    Reranker-->>Orchestrator: Top 3 chunks

    Note over Orchestrator: Total: ~150ms (cache hit)<br>or ~170ms (cache miss)

Code Example: Unified RAG Pipeline

import time
from dataclasses import dataclass


@dataclass
class RAGResult:
    chunks: list[dict]
    total_latency_ms: float
    embed_latency_ms: float
    search_latency_ms: float
    rerank_latency_ms: float
    embed_source: str
    candidates_searched: int
    pre_filter_applied: bool


class OptimizedRAGPipeline:
    """Full RAG pipeline with all performance optimizations."""

    def __init__(
        self,
        embedding_service: "CachedEmbeddingService",
        vector_search: "OptimizedVectorSearch",
        reranker: "OptimizedReranker",
    ):
        self.embedding = embedding_service
        self.search = vector_search
        self.reranker = reranker

    async def retrieve(
        self,
        query: str,
        intent: str,
        asin: str | None = None,
        top_k: int = 3,
    ) -> RAGResult:
        """End-to-end retrieval with caching, pre-filtering, and reranking."""
        start = time.monotonic()

        # Step 1: Embed query (with cache)
        embed_result = await self.embedding.embed(query)

        # Step 2: Pre-filtered KNN search
        search_result = await self.search.search(
            query_vector=embed_result.vector,
            intent=intent,
            top_k=10,  # Retrieve 10 for reranking
            asin_filter=asin,
        )

        # Step 3: Rerank and select top chunks
        if search_result.chunks:
            ranked_chunks = await self.reranker.rerank(
                query=query, chunks=search_result.chunks
            )
        else:
            ranked_chunks = []

        total_ms = (time.monotonic() - start) * 1000

        return RAGResult(
            chunks=[
                {
                    "chunk_id": c.chunk_id,
                    "content": c.content,
                    "score": c.rerank_score,
                    "source_type": c.source_type,
                }
                for c in ranked_chunks[:top_k]
            ],
            total_latency_ms=total_ms,
            embed_latency_ms=embed_result.latency_ms,
            search_latency_ms=search_result.search_latency_ms,
            # Residual time: rerank plus any remaining pipeline overhead
            rerank_latency_ms=(
                total_ms - embed_result.latency_ms - search_result.search_latency_ms
            ),
            embed_source=embed_result.source,
            candidates_searched=search_result.total_candidates,
            pre_filter_applied=search_result.pre_filter_applied,
        )
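
A sketch of wiring the components together; the hosts, endpoint name, and example query below are illustrative placeholders, not values fixed by this design.

import asyncio

import boto3
import redis
from opensearchpy import AsyncOpenSearch


async def main() -> None:
    pipeline = OptimizedRAGPipeline(
        embedding_service=CachedEmbeddingService(
            redis_client=redis.Redis(host="localhost", port=6379),
        ),
        vector_search=OptimizedVectorSearch(
            opensearch_client=AsyncOpenSearch(hosts=["https://localhost:9200"]),
        ),
        reranker=OptimizedReranker(
            cross_encoder_endpoint="rag-cross-encoder",  # hypothetical endpoint name
            sagemaker_client=boto3.client("sagemaker-runtime"),
        ),
    )
    result = await pipeline.retrieve(
        query="Is volume 3 of this series in stock?",
        intent="product_question",
    )
    print(f"{result.total_latency_ms:.0f}ms, {len(result.chunks)} chunks")


if __name__ == "__main__":
    asyncio.run(main())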

Metrics and Monitoring

| Metric                   | Target        | Alarm Threshold       |
|--------------------------|---------------|-----------------------|
| rag.total_latency_ms     | p95 < 200ms   | p95 > 300ms for 5 min |
| rag.embed_latency_ms     | p95 < 30ms    | p95 > 50ms            |
| rag.search_latency_ms    | p95 < 100ms   | p95 > 150ms           |
| rag.rerank_latency_ms    | p95 < 50ms    | p95 > 80ms            |
| rag.embed_cache_hit_rate | > 20%         | < 10%                 |
| rag.pre_filter_reduction | > 40%         | < 25%                 |
| rag.chunks_after_dedup   | < 8 (from 10) | = 10 (no dedup)       |
| rag.avg_rerank_score     | > 0.5         | < 0.3 (quality issue) |

graph LR
    subgraph "RAG Latency Budget"
        A[Embed: 25ms] --> D[Total: < 200ms]
        B[Search: 80ms] --> D
        C[Rerank: 40ms] --> D
        E[Other: 25ms] --> D
    end