PO-03: RAG Pipeline Retrieval Performance
User Story
As an ML platform engineer, I want to reduce RAG retrieval latency from ~300ms to under 200ms at p95 while maintaining retrieval quality, so that grounded responses (FAQ, product questions, recommendations) are generated faster without sacrificing accuracy.
Acceptance Criteria
- Embedding generation completes in under 30ms at p95.
- OpenSearch KNN search returns top-10 candidates in under 100ms at p95.
- Reranking of top-10 candidates to top-3 completes in under 50ms at p95.
- Total RAG retrieval pipeline (embed → search → rerank) is under 200ms at p95.
- Pre-filtering by metadata reduces search space by 40-60% without losing relevant chunks.
- Embedding cache hit rate exceeds 20% for repeated/similar queries.
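These targets can be checked offline from recorded per-stage latency samples. A minimal sketch, assuming samples are collected in milliseconds per stage; the budget values come from the criteria above, and the helper name is hypothetical:

import numpy as np

# p95 budgets from the acceptance criteria (milliseconds)
P95_BUDGETS_MS = {"embed": 30, "search": 100, "rerank": 50, "total": 200}

def check_p95_budgets(samples: dict[str, list[float]]) -> dict[str, bool]:
    """Return pass/fail per stage: observed p95 must stay under budget."""
    return {
        stage: float(np.percentile(samples[stage], 95)) < budget
        for stage, budget in P95_BUDGETS_MS.items()
        if stage in samples
    }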
High-Level Design
RAG Latency Breakdown
graph LR
subgraph "Current Pipeline (~300ms)"
A[Embed Query<br>~40ms] --> B[KNN Search<br>~150ms]
B --> C[Rerank Top 10<br>~80ms]
C --> D[Return Top 3<br>~5ms]
end
subgraph "Optimized Pipeline (~170ms)"
E[Embed Query<br>~25ms] --> F[Pre-filtered<br>KNN Search<br>~80ms]
F --> G[Light Rerank<br>~40ms]
G --> H[Return Top 3<br>~5ms]
end
style A fill:#f66,stroke:#333
style B fill:#f66,stroke:#333
style C fill:#f66,stroke:#333
style E fill:#2d8,stroke:#333
style F fill:#2d8,stroke:#333
style G fill:#2d8,stroke:#333
Optimization Strategy Overview
graph TD
subgraph "Embedding Optimization"
A1[Embedding Cache<br>for repeated queries]
A2[Quantized Embedding<br>Model]
A3[Batched Embedding<br>for multi-query]
end
subgraph "Search Optimization"
B1[Metadata Pre-filtering<br>Reduce search space]
B2[HNSW Parameter<br>Tuning]
B3[Approximate KNN<br>with ef_search tuning]
end
subgraph "Reranking Optimization"
C1[Cross-Encoder<br>on GPU endpoint]
C2[Score Threshold<br>Early Exit]
C3[Chunk Deduplication<br>before rerank]
end
A1 --> D[Total: < 200ms p95]
A2 --> D
A3 --> D
B1 --> D
B2 --> D
B3 --> D
C1 --> D
C2 --> D
C3 --> D
Low-Level Design
1. Embedding Cache and Optimization
Identical queries, plus trivial variants that differ only in case or surrounding whitespace, normalize to the same cache key. Caching their embedding vectors avoids redundant Bedrock Titan Embeddings calls.
sequenceDiagram
participant Orchestrator
participant EmbedCache as Embedding Cache<br>(Redis)
participant Bedrock as Bedrock Titan<br>Embeddings
Orchestrator->>EmbedCache: GET embed:{query_hash}
alt Cache Hit (~20%)
EmbedCache-->>Orchestrator: Cached vector (1ms)
else Cache Miss
EmbedCache-->>Orchestrator: null
Orchestrator->>Bedrock: Embed query (25ms)
Bedrock-->>Orchestrator: Vector [1024-dim]
Orchestrator->>EmbedCache: SET embed:{query_hash} TTL=1h
end
Code Example: Embedding Service with Cache
import asyncio
import hashlib
import json
import time
from dataclasses import dataclass
import boto3
import redis
@dataclass
class EmbeddingResult:
vector: list[float]
latency_ms: float
source: str # "cache" or "bedrock"
dimensions: int
class CachedEmbeddingService:
"""Embedding service with Redis cache for repeated queries."""
CACHE_TTL_SECONDS = 3600 # 1 hour
def __init__(
self,
redis_client: redis.Redis,
region: str = "us-east-1",
model_id: str = "amazon.titan-embed-text-v2:0",
):
self.redis = redis_client
self.bedrock = boto3.client("bedrock-runtime", region_name=region)
self.model_id = model_id
def _cache_key(self, text: str) -> str:
"""Stable hash of the input text for cache lookup."""
normalized = text.lower().strip()
digest = hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:24]
return f"embed:{digest}"
async def embed(self, text: str) -> EmbeddingResult:
"""Embed text with cache-first strategy."""
start = time.monotonic()
cache_key = self._cache_key(text)
        # Check cache first (the redis-py call is blocking, but the lookup is sub-millisecond)
        cached = self.redis.get(cache_key)
if cached is not None:
vector = json.loads(cached)
return EmbeddingResult(
vector=vector,
latency_ms=(time.monotonic() - start) * 1000,
source="cache",
dimensions=len(vector),
)
# Call Bedrock Titan Embeddings
response = await asyncio.to_thread(
self.bedrock.invoke_model,
modelId=self.model_id,
body=json.dumps({"inputText": text}),
contentType="application/json",
)
body = json.loads(response["body"].read())
vector = body["embedding"]
# Cache the result
self.redis.setex(
cache_key,
self.CACHE_TTL_SECONDS,
json.dumps(vector),
)
return EmbeddingResult(
vector=vector,
latency_ms=(time.monotonic() - start) * 1000,
source="bedrock",
dimensions=len(vector),
)
async def embed_batch(self, texts: list[str]) -> list[EmbeddingResult]:
"""Embed multiple texts, using cache where available."""
results = await asyncio.gather(*[self.embed(text) for text in texts])
return list(results)
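A minimal wiring sketch for the service above; the Redis host, query text, and expected sources are illustrative, and default AWS credentials are assumed:

import asyncio
import redis

async def main() -> None:
    client = redis.Redis(host="localhost", port=6379)  # Placeholder host
    service = CachedEmbeddingService(redis_client=client)
    first = await service.embed("Do you ship manga internationally?")
    # Differs only in case and trailing whitespace, so it hits the cache
    second = await service.embed("do you ship manga internationally? ")
    print(first.source, second.source)  # Expected: "bedrock" then "cache"

asyncio.run(main())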
2. Metadata Pre-Filtered KNN Search
Instead of searching the entire vector index, apply metadata filters first to reduce the candidate set.
graph TD
subgraph "Without Pre-filter"
A1[Query Vector] --> B1[Search ALL 500K chunks]
B1 --> C1[KNN top 10<br>~150ms]
end
subgraph "With Pre-filter"
A2[Query Vector + Intent] --> B2[Filter by source_type<br>+ category]
B2 --> C2[Search 50-100K chunks<br>reduced set]
C2 --> D2[KNN top 10<br>~80ms]
end
style C1 fill:#f66,stroke:#333
style D2 fill:#2d8,stroke:#333
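Concretely, the pre-filter is a bool clause attached to the KNN query (built by _build_filter in the code example below). For the faq intent with no ASIN pinned, the mapping in the next table yields a clause like this sketch:

# Filter clause produced for intent="faq" (illustrative)
FAQ_FILTER = {
    "bool": {
        "must": [
            {"terms": {"source_type": ["faq", "policy"]}},
        ]
    }
}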
Intent-to-Filter Mapping
| Intent | source_type filter | category filter | Expected Reduction |
|---|---|---|---|
| `faq` | `faq`, `policy` | any | ~70% (only FAQ/policy chunks) |
| `product_question` | `product_description`, `review` | `manga` | ~50% |
| `recommendation` | `product_description`, `editorial` | `manga` | ~45% |
| `return_request` | `policy`, `faq` | any | ~75% |
| `promotion` | `editorial`, `faq` | `manga` | ~60% |
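Pre-filtering assumes the index stores these fields as filterable keywords alongside the vector. A possible mapping sketch for manga_knowledge_base: the field names match the retrieval code, but the engine, space type, and HNSW build parameters are illustrative assumptions:

# Possible index mapping; create once with
# client.indices.create(index="manga_knowledge_base", body=KNOWLEDGE_BASE_MAPPING)
KNOWLEDGE_BASE_MAPPING = {
    "settings": {
        "index": {
            "knn": True,
            "knn.algo_param.ef_search": 100,  # Index-level fallback for ef_search
        }
    },
    "mappings": {
        "properties": {
            "chunk_id": {"type": "keyword"},
            "content": {"type": "text"},
            "source_type": {"type": "keyword"},  # Target of the terms pre-filter
            "category": {"type": "keyword"},
            "asin": {"type": "keyword"},
            "embedding": {
                "type": "knn_vector",
                "dimension": 1024,  # Titan Embeddings v2 default
                "method": {
                    "name": "hnsw",
                    "engine": "faiss",  # faiss supports efficient filtered k-NN
                    "space_type": "innerproduct",
                    "parameters": {"ef_construction": 256, "m": 16},
                },
            },
        }
    },
}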
Code Example: Pre-Filtered OpenSearch Query
import time
from dataclasses import dataclass
from typing import Optional
from opensearchpy import AsyncOpenSearch
@dataclass
class RetrievedChunk:
chunk_id: str
content: str
score: float
source_type: str
asin: str | None
metadata: dict
@dataclass
class RetrievalResult:
chunks: list[RetrievedChunk]
total_candidates: int
search_latency_ms: float
pre_filter_applied: bool
# Intent to OpenSearch filter mapping
INTENT_FILTERS = {
"faq": {
"source_types": ["faq", "policy"],
"categories": None, # All categories
},
"product_question": {
"source_types": ["product_description", "review"],
"categories": ["manga"],
},
"recommendation": {
"source_types": ["product_description", "editorial"],
"categories": ["manga"],
},
"return_request": {
"source_types": ["policy", "faq"],
"categories": None,
},
"promotion": {
"source_types": ["editorial", "faq"],
"categories": ["manga"],
},
"checkout_help": {
"source_types": ["faq", "policy"],
"categories": None,
},
}
# HNSW search parameters tuned for the latency vs recall trade-off.
# Applied per query via "method_parameters" (OpenSearch >= 2.16); older
# clusters should set index.knn.algo_param.ef_search instead (see below).
HNSW_PARAMS = {
    "ef_search": 100,  # Down from the 512 default; recall stays ~95%
}
class OptimizedVectorSearch:
"""Pre-filtered KNN search with tuned HNSW parameters."""
def __init__(
self,
opensearch_client: AsyncOpenSearch,
index_name: str = "manga_knowledge_base",
):
self.client = opensearch_client
self.index_name = index_name
async def search(
self,
query_vector: list[float],
intent: str,
top_k: int = 10,
asin_filter: Optional[str] = None,
) -> RetrievalResult:
"""Execute pre-filtered KNN search with optimized parameters."""
start = time.monotonic()
# Build the filter clause based on intent
filter_clause = self._build_filter(intent, asin_filter)
        # Construct the KNN query with pre-filtering and tuned HNSW parameters
        query_body = {
            "size": top_k,
            "query": {
                "knn": {
                    "embedding": {
                        "vector": query_vector,
                        "k": top_k,
                        "method_parameters": HNSW_PARAMS,
                    }
                }
            },
        }
# Apply pre-filter if available
if filter_clause:
query_body["query"]["knn"]["embedding"]["filter"] = {
"bool": {"must": filter_clause}
}
pre_filtered = True
else:
pre_filtered = False
        # Execute the search, routed through the shared RAG search pipeline
response = await self.client.search(
index=self.index_name,
body=query_body,
params={"search_pipeline": "rag-search-pipeline"},
)
hits = response.get("hits", {}).get("hits", [])
total = response.get("hits", {}).get("total", {}).get("value", 0)
chunks = [
RetrievedChunk(
chunk_id=hit["_source"]["chunk_id"],
content=hit["_source"]["content"],
score=hit["_score"],
source_type=hit["_source"].get("source_type", "unknown"),
asin=hit["_source"].get("asin"),
metadata={
k: v for k, v in hit["_source"].items()
if k not in ("chunk_id", "content", "embedding")
},
)
for hit in hits
]
return RetrievalResult(
chunks=chunks,
total_candidates=total,
search_latency_ms=(time.monotonic() - start) * 1000,
pre_filter_applied=pre_filtered,
)
def _build_filter(
self, intent: str, asin_filter: Optional[str]
) -> list[dict]:
"""Build OpenSearch filter clause from intent mapping."""
clauses = []
intent_config = INTENT_FILTERS.get(intent)
if intent_config:
source_types = intent_config.get("source_types")
if source_types:
clauses.append({"terms": {"source_type": source_types}})
categories = intent_config.get("categories")
if categories:
clauses.append({"terms": {"category": categories}})
if asin_filter:
clauses.append({"term": {"asin": asin_filter}})
return clauses
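The per-query method_parameters override requires OpenSearch 2.16 or later. On older clusters, the same ef_search value can be applied as an index setting instead; a sketch:

from opensearchpy import AsyncOpenSearch

async def apply_ef_search(client: AsyncOpenSearch) -> None:
    # Index-level ef_search fallback for clusters without per-query overrides
    await client.indices.put_settings(
        index="manga_knowledge_base",
        body={"index": {"knn.algo_param.ef_search": 100}},
    )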
3. Lightweight Reranker with Early Exit
The reranker re-scores candidate chunks by cross-encoding (query, chunk) pairs. Two optimizations keep it inside the 50ms budget: near-duplicate chunks are removed before scoring, so fewer pairs hit the cross-encoder, and chunks below a relevance threshold are discarded so weak context never reaches generation.
graph TD
A[10 Candidate Chunks] --> B[Deduplicate<br>Remove near-duplicates]
B --> C[7-8 Unique Chunks]
C --> D[Cross-Encoder<br>Rerank]
D --> E{Score > 0.3?}
E -->|Yes| F[Keep Chunk]
E -->|No| G[Discard]
F --> H[Return Top 3<br>by rerank score]
style B fill:#2d8,stroke:#333
style E fill:#2d8,stroke:#333
Code Example: Optimized Reranker
import asyncio
import json
from dataclasses import dataclass
@dataclass
class RankedChunk:
chunk_id: str
content: str
original_score: float
rerank_score: float
source_type: str
class OptimizedReranker:
"""Cross-encoder reranker with deduplication and early exit."""
SIMILARITY_DEDUP_THRESHOLD = 0.92 # Remove chunks that are >92% similar
MIN_RERANK_SCORE = 0.3 # Discard low-relevance chunks
MAX_OUTPUT_CHUNKS = 3
def __init__(self, cross_encoder_endpoint: str, sagemaker_client):
self.endpoint = cross_encoder_endpoint
self.sagemaker = sagemaker_client
async def rerank(
self, query: str, chunks: list["RetrievedChunk"]
) -> list[RankedChunk]:
"""Deduplicate, rerank, and filter chunks."""
# Step 1: Deduplicate near-identical chunks
unique_chunks = self._deduplicate(chunks)
# Step 2: Score all (query, chunk) pairs with cross-encoder
scores = await self._cross_encode(query, unique_chunks)
# Step 3: Filter by minimum score and sort
ranked = []
for chunk, score in zip(unique_chunks, scores):
if score >= self.MIN_RERANK_SCORE:
ranked.append(
RankedChunk(
chunk_id=chunk.chunk_id,
content=chunk.content,
original_score=chunk.score,
rerank_score=score,
source_type=chunk.source_type,
)
)
ranked.sort(key=lambda c: c.rerank_score, reverse=True)
return ranked[: self.MAX_OUTPUT_CHUNKS]
def _deduplicate(
self, chunks: list["RetrievedChunk"]
) -> list["RetrievedChunk"]:
"""Remove chunks with near-identical content using Jaccard similarity."""
if len(chunks) <= 1:
return chunks
unique = [chunks[0]]
for chunk in chunks[1:]:
is_duplicate = False
chunk_tokens = set(chunk.content.lower().split())
for existing in unique:
existing_tokens = set(existing.content.lower().split())
intersection = chunk_tokens & existing_tokens
union = chunk_tokens | existing_tokens
jaccard = len(intersection) / len(union) if union else 0
if jaccard > self.SIMILARITY_DEDUP_THRESHOLD:
is_duplicate = True
break
if not is_duplicate:
unique.append(chunk)
return unique
async def _cross_encode(
self, query: str, chunks: list["RetrievedChunk"]
) -> list[float]:
"""Score (query, chunk) pairs using cross-encoder on SageMaker."""
pairs = [
{"text_1": query, "text_2": chunk.content}
for chunk in chunks
]
payload = json.dumps({"pairs": pairs})
response = await asyncio.to_thread(
self.sagemaker.invoke_endpoint,
EndpointName=self.endpoint,
ContentType="application/json",
Body=payload,
)
result = json.loads(response["Body"].read())
return result["scores"]
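The class above scores all unique candidates in one batch. If the "Score Threshold Early Exit" lever from the strategy overview is needed, one possible variant (a sketch, not part of the class above) scores small batches in KNN-score order and stops as soon as enough chunks clear the threshold, trading some ordering accuracy for latency:

async def rerank_early_exit(
    reranker: OptimizedReranker,
    query: str,
    chunks: list["RetrievedChunk"],
    batch_size: int = 4,
) -> list[RankedChunk]:
    """Cross-encode in mini-batches, best KNN score first, with early exit."""
    ordered = sorted(
        reranker._deduplicate(chunks), key=lambda c: c.score, reverse=True
    )
    kept: list[RankedChunk] = []
    for i in range(0, len(ordered), batch_size):
        batch = ordered[i : i + batch_size]
        scores = await reranker._cross_encode(query, batch)
        for chunk, score in zip(batch, scores):
            if score >= reranker.MIN_RERANK_SCORE:
                kept.append(
                    RankedChunk(
                        chunk_id=chunk.chunk_id,
                        content=chunk.content,
                        original_score=chunk.score,
                        rerank_score=score,
                        source_type=chunk.source_type,
                    )
                )
        if len(kept) >= reranker.MAX_OUTPUT_CHUNKS:
            break  # Early exit: enough strong chunks already found
    kept.sort(key=lambda c: c.rerank_score, reverse=True)
    return kept[: reranker.MAX_OUTPUT_CHUNKS]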
4. Full Optimized RAG Pipeline
sequenceDiagram
participant Orchestrator
participant EmbedCache as Embedding<br>Cache
participant Bedrock as Bedrock<br>Titan Embed
participant OpenSearch
participant Reranker
Orchestrator->>EmbedCache: Check embed cache
alt Cache Hit
EmbedCache-->>Orchestrator: Vector (1ms)
else Cache Miss
Orchestrator->>Bedrock: Embed query (25ms)
Bedrock-->>Orchestrator: Vector
Orchestrator->>EmbedCache: Cache vector
end
Orchestrator->>OpenSearch: Pre-filtered KNN (80ms)
OpenSearch-->>Orchestrator: 10 candidates
Orchestrator->>Reranker: Dedup + rerank (40ms)
Reranker-->>Orchestrator: Top 3 chunks
Note over Orchestrator: Total: ~150ms (cache hit)<br>or ~170ms (cache miss)
Code Example: Unified RAG Pipeline
import time
from dataclasses import dataclass
@dataclass
class RAGResult:
chunks: list[dict]
total_latency_ms: float
embed_latency_ms: float
search_latency_ms: float
rerank_latency_ms: float
embed_source: str
candidates_searched: int
pre_filter_applied: bool
class OptimizedRAGPipeline:
"""Full RAG pipeline with all performance optimizations."""
def __init__(
self,
embedding_service: "CachedEmbeddingService",
vector_search: "OptimizedVectorSearch",
reranker: "OptimizedReranker",
):
self.embedding = embedding_service
self.search = vector_search
self.reranker = reranker
async def retrieve(
self,
query: str,
intent: str,
asin: str | None = None,
top_k: int = 3,
) -> RAGResult:
"""End-to-end retrieval with caching, pre-filtering, and reranking."""
start = time.monotonic()
# Step 1: Embed query (with cache)
embed_result = await self.embedding.embed(query)
# Step 2: Pre-filtered KNN search
search_result = await self.search.search(
query_vector=embed_result.vector,
intent=intent,
top_k=10, # Retrieve 10 for reranking
asin_filter=asin,
)
# Step 3: Rerank and select top chunks
if search_result.chunks:
ranked_chunks = await self.reranker.rerank(
query=query, chunks=search_result.chunks
)
else:
ranked_chunks = []
total_ms = (time.monotonic() - start) * 1000
return RAGResult(
chunks=[
{
"chunk_id": c.chunk_id,
"content": c.content,
"score": c.rerank_score,
"source_type": c.source_type,
}
for c in ranked_chunks[:top_k]
],
total_latency_ms=total_ms,
embed_latency_ms=embed_result.latency_ms,
search_latency_ms=search_result.search_latency_ms,
            rerank_latency_ms=(
                total_ms - embed_result.latency_ms - search_result.search_latency_ms
            ),  # Remainder after embed + search; includes pipeline overhead
embed_source=embed_result.source,
candidates_searched=search_result.total_candidates,
pre_filter_applied=search_result.pre_filter_applied,
)
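A wiring sketch for the full pipeline, assuming the components above; the hosts, endpoint name, region, and query are placeholders:

import asyncio
import boto3
import redis
from opensearchpy import AsyncOpenSearch

async def main() -> None:
    pipeline = OptimizedRAGPipeline(
        embedding_service=CachedEmbeddingService(
            redis_client=redis.Redis(host="localhost", port=6379),
        ),
        vector_search=OptimizedVectorSearch(
            opensearch_client=AsyncOpenSearch(hosts=["https://localhost:9200"]),
        ),
        reranker=OptimizedReranker(
            cross_encoder_endpoint="rag-cross-encoder",  # Placeholder name
            sagemaker_client=boto3.client(
                "sagemaker-runtime", region_name="us-east-1"
            ),
        ),
    )
    result = await pipeline.retrieve(
        query="Is volume 5 of this series in stock?",
        intent="product_question",
    )
    print(f"{result.total_latency_ms:.0f}ms, {len(result.chunks)} chunks")

asyncio.run(main())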
Metrics and Monitoring
| Metric | Target | Alarm Threshold |
|---|---|---|
| `rag.total_latency_ms` | p95 < 200ms | p95 > 300ms for 5 min |
| `rag.embed_latency_ms` | p95 < 30ms | p95 > 50ms |
| `rag.search_latency_ms` | p95 < 100ms | p95 > 150ms |
| `rag.rerank_latency_ms` | p95 < 50ms | p95 > 80ms |
| `rag.embed_cache_hit_rate` | > 20% | < 10% |
| `rag.pre_filter_reduction` | > 40% | < 25% |
| `rag.chunks_after_dedup` | < 8 (from 10) | = 10 (no dedup) |
| `rag.avg_rerank_score` | > 0.5 | < 0.3 (quality issue) |
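The latency metrics can be emitted directly from each RAGResult. A sketch using CloudWatch put_metric_data; the namespace is an assumption, not an existing convention:

import boto3

cloudwatch = boto3.client("cloudwatch", region_name="us-east-1")

def emit_rag_metrics(result: "RAGResult") -> None:
    # Namespace "RAG/Retrieval" is an assumed convention
    cloudwatch.put_metric_data(
        Namespace="RAG/Retrieval",
        MetricData=[
            {"MetricName": name, "Value": value, "Unit": "Milliseconds"}
            for name, value in [
                ("rag.total_latency_ms", result.total_latency_ms),
                ("rag.embed_latency_ms", result.embed_latency_ms),
                ("rag.search_latency_ms", result.search_latency_ms),
                ("rag.rerank_latency_ms", result.rerank_latency_ms),
            ]
        ],
    )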
graph LR
subgraph "RAG Latency Budget"
A[Embed: 25ms] --> D[Total: < 200ms]
B[Search: 80ms] --> D
C[Rerank: 40ms] --> D
E[Other: 25ms] --> D
end
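The alarm thresholds in the table map to CloudWatch percentile alarms. A sketch for the first row (p95 total latency above 300ms for 5 minutes), assuming the namespace from the emission sketch above:

import boto3

cloudwatch = boto3.client("cloudwatch", region_name="us-east-1")

cloudwatch.put_metric_alarm(
    AlarmName="rag-total-latency-p95",
    Namespace="RAG/Retrieval",
    MetricName="rag.total_latency_ms",
    ExtendedStatistic="p95",       # Percentile statistic for the alarm
    Period=60,                     # Evaluate one-minute windows...
    EvaluationPeriods=5,           # ...breaching for 5 consecutive minutes
    Threshold=300.0,
    ComparisonOperator="GreaterThanThreshold",
    TreatMissingData="notBreaching",
)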