PO-05: Caching Layer Performance
User Story
As an infrastructure engineer, I want to optimize ElastiCache Redis for sub-2ms cache hits with high availability and efficient memory usage, so that cached product data, recommendations, and promotions are served instantly without adding latency to the request path.
Acceptance Criteria
- Cache hit latency is under 2ms at p95 for single-key reads.
- Multi-key reads (pipeline) complete in under 5ms for up to 10 keys.
- Connection pool eliminates per-request connection overhead.
- Cache hit rate exceeds 70% for product details and 50% for recommendations.
- Cache cluster maintains < 70% memory utilization with proper eviction policies.
- Failover to a replica completes in under 5 seconds with no data loss on the read path.
High-Level Design
Cache Latency Profile
graph LR
subgraph "Without Optimization"
A[New Connection<br>per request<br>~5ms overhead] --> B[Single GET<br>~3ms]
B --> C[Another GET<br>~3ms]
C --> D[Total: ~11ms<br>for 2 keys]
end
subgraph "With Optimization"
E[Connection Pool<br>~0ms overhead] --> F[Pipeline GET×2<br>~1.5ms total]
F --> G[Total: ~1.5ms<br>for 2 keys]
end
style D fill:#f66,stroke:#333
style G fill:#2d8,stroke:#333
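The figures in the diagram above are illustrative. A micro-benchmark along the lines of the sketch below (host and key names are placeholders) can validate them for a given network path:

import asyncio
import time

import redis.asyncio as aioredis

async def bench(host: str = "localhost", n: int = 50) -> None:
    # Pattern 1: a fresh connection per GET pays the TCP handshake every time.
    start = time.monotonic()
    for i in range(n):
        client = aioredis.Redis(host=host)
        await client.get(f"bench:{i}")
        await client.close()
    print(f"per-connection GET: {(time.monotonic() - start) * 1000 / n:.2f} ms/key")

    # Pattern 2: one pooled client, all GETs batched into a single pipeline.
    pooled = aioredis.Redis(host=host)
    start = time.monotonic()
    async with pooled.pipeline(transaction=False) as pipe:
        for i in range(n):
            pipe.get(f"bench:{i}")
        await pipe.execute()
    print(f"pipelined GET: {(time.monotonic() - start) * 1000 / n:.2f} ms/key")
    await pooled.close()

asyncio.run(bench())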
Optimization Strategy
graph TD
subgraph "Connection Management"
A1[Connection Pool<br>Pre-warmed]
A2[Keep-Alive<br>TCP settings]
A3[Cluster Mode<br>Local reads]
end
subgraph "Data Access Patterns"
B1[Pipeline Reads<br>Batch multiple keys]
B2[Serialization<br>MessagePack vs JSON]
B3[Key Design<br>Minimize key length]
end
subgraph "Cluster Config"
C1[Node Sizing<br>r7g.large]
C2[Replica Placement<br>Same AZ as compute]
C3[Eviction Policy<br>allkeys-lfu]
end
A1 --> D[p95 < 2ms]
B1 --> D
C2 --> D
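On the serialization point: MessagePack encodes to a more compact binary form than JSON and typically decodes faster, which matters when every cache hit pays the deserialization cost. A minimal size comparison, using a hypothetical product document:

import json

import msgpack

product = {
    "asin": "B01N5IB20Q",  # hypothetical sample
    "title": "Example Manga Vol. 1",
    "price_cents": 1099,
    "rating": 4.7,
    "genres": ["shonen", "action"],
}

as_json = json.dumps(product).encode()
as_msgpack = msgpack.packb(product, use_bin_type=True)

# MessagePack drops quotes, braces, and key/value separators in favor of
# length-prefixed fields, so dict-heavy payloads shrink noticeably.
print(f"json: {len(as_json)} bytes, msgpack: {len(as_msgpack)} bytes")
assert msgpack.unpackb(as_msgpack, raw=False) == product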
Low-Level Design
1. Connection Pool Management
Creating a new Redis connection per request adds 3-5ms of TCP handshake overhead. A pre-warmed connection pool amortizes this to near-zero.
graph TD
subgraph "Orchestrator Instance"
A[Request Handler 1] --> D[Connection Pool<br>min=10, max=50]
B[Request Handler 2] --> D
C[Request Handler 3] --> D
end
D --> E[Redis Primary<br>Writes]
D --> F[Redis Replica 1<br>Same-AZ Reads]
D --> G[Redis Replica 2<br>Cross-AZ Reads]
style D fill:#2d8,stroke:#333
style F fill:#2d8,stroke:#333
Code Example: Optimized Redis Cache Client
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional

import msgpack
import redis.asyncio as aioredis

logger = logging.getLogger(__name__)
@dataclass
class CacheGetResult:
value: Any
hit: bool
latency_ms: float
key: str
@dataclass
class CacheStats:
hits: int = 0
misses: int = 0
errors: int = 0
avg_hit_latency_ms: float = 0.0
avg_miss_latency_ms: float = 0.0
class OptimizedRedisClient:
"""High-performance Redis client with connection pooling and pipelining."""
def __init__(
self,
primary_endpoint: str,
reader_endpoint: str,
port: int = 6379,
password: Optional[str] = None,
min_connections: int = 10,
max_connections: int = 50,
socket_timeout: float = 0.5,
socket_connect_timeout: float = 0.5,
):
        # Write pool (primary). redis-py pools open connections lazily up to
        # max_connections and have no native minimum size, so the requested
        # min_connections is applied by warm_up() below instead.
        self.min_connections = min_connections
        self._write_pool = aioredis.ConnectionPool.from_url(
            f"redis://{primary_endpoint}:{port}",
            password=password,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_connect_timeout,
            socket_keepalive=True,
            health_check_interval=15,
        )
self._write_client = aioredis.Redis(connection_pool=self._write_pool)
        # Read pool (reader endpoint, so reads land on replicas for lower latency)
        self._read_pool = aioredis.ConnectionPool.from_url(
            f"redis://{reader_endpoint}:{port}",
            password=password,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_connect_timeout,
            socket_keepalive=True,
            health_check_interval=15,
        )
self._read_client = aioredis.Redis(connection_pool=self._read_pool)
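    # Sketch: redis-py has no built-in pool pre-warming, so concurrent PINGs
    # are used here to force each pool to open connections up-front.
    async def warm_up(self) -> None:
        """Open roughly min_connections sockets per pool before serving traffic."""
        for client in (self._write_client, self._read_client):
            await asyncio.gather(
                *(client.ping() for _ in range(self.min_connections))
            )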
async def get(self, key: str) -> CacheGetResult:
"""Single key read from replica with latency tracking."""
start = time.monotonic()
try:
raw = await self._read_client.get(key)
latency = (time.monotonic() - start) * 1000
if raw is None:
return CacheGetResult(value=None, hit=False, latency_ms=latency, key=key)
value = msgpack.unpackb(raw, raw=False)
return CacheGetResult(value=value, hit=True, latency_ms=latency, key=key)
except Exception as e:
latency = (time.monotonic() - start) * 1000
logger.warning(f"Cache GET failed for {key}: {e}")
return CacheGetResult(value=None, hit=False, latency_ms=latency, key=key)
async def get_many(self, keys: list[str]) -> list[CacheGetResult]:
"""Pipeline multiple key reads in a single round-trip."""
if not keys:
return []
start = time.monotonic()
try:
async with self._read_client.pipeline(transaction=False) as pipe:
for key in keys:
pipe.get(key)
raw_results = await pipe.execute()
latency = (time.monotonic() - start) * 1000
per_key_latency = latency / len(keys)
results = []
for key, raw in zip(keys, raw_results):
if raw is not None:
value = msgpack.unpackb(raw, raw=False)
results.append(
CacheGetResult(
value=value, hit=True,
latency_ms=per_key_latency, key=key,
)
)
else:
results.append(
CacheGetResult(
value=None, hit=False,
latency_ms=per_key_latency, key=key,
)
)
return results
except Exception as e:
latency = (time.monotonic() - start) * 1000
logger.warning(f"Cache pipeline GET failed: {e}")
return [
CacheGetResult(value=None, hit=False, latency_ms=latency, key=k)
for k in keys
]
async def set(
self,
key: str,
value: Any,
ttl_seconds: int = 300,
) -> bool:
"""Write to primary with MessagePack serialization."""
try:
packed = msgpack.packb(value, use_bin_type=True)
await self._write_client.setex(key, ttl_seconds, packed)
return True
except Exception as e:
logger.warning(f"Cache SET failed for {key}: {e}")
return False
async def set_many(
self,
items: dict[str, tuple[Any, int]],
) -> int:
"""Pipeline multiple writes. items = {key: (value, ttl_seconds)}."""
if not items:
return 0
success_count = 0
try:
async with self._write_client.pipeline(transaction=False) as pipe:
for key, (value, ttl) in items.items():
packed = msgpack.packb(value, use_bin_type=True)
pipe.setex(key, ttl, packed)
results = await pipe.execute()
success_count = sum(1 for r in results if r)
except Exception as e:
logger.warning(f"Cache pipeline SET failed: {e}")
return success_count
async def delete(self, key: str) -> bool:
"""Delete a key from cache (for invalidation)."""
try:
await self._write_client.delete(key)
return True
except Exception as e:
logger.warning(f"Cache DELETE failed for {key}: {e}")
return False
async def close(self) -> None:
"""Clean shutdown of connection pools."""
await self._write_client.close()
await self._read_client.close()
await self._write_pool.disconnect()
await self._read_pool.disconnect()
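Usage sketch (endpoints are placeholders for the cluster's primary and reader endpoints; the key and value are sample data):

import asyncio

async def main() -> None:
    cache = OptimizedRedisClient(
        primary_endpoint="my-cache.xxxxxx.use1.cache.amazonaws.com",
        reader_endpoint="my-cache-ro.xxxxxx.use1.cache.amazonaws.com",
    )
    await cache.warm_up()  # open connections before traffic arrives
    await cache.set("product:B01N5IB20Q", {"title": "Example Manga Vol. 1"})
    result = await cache.get("product:B01N5IB20Q")
    print(result.hit, f"{result.latency_ms:.2f}ms")
    await cache.close()

asyncio.run(main())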
2. Cache-Aside Pattern with Fallback
The orchestrator always tries the cache first. On a miss, it fetches from the origin service and populates the cache. If Redis is entirely down, requests fall through directly to the origin service.
sequenceDiagram
participant Orchestrator
participant Cache as ElastiCache Redis
participant Origin as Origin Service
Orchestrator->>Cache: GET product:{asin}
alt Cache Hit (~70%)
Cache-->>Orchestrator: Data (1ms)
else Cache Miss (~30%)
Cache-->>Orchestrator: null
Orchestrator->>Origin: GET /products/{asin}
Origin-->>Orchestrator: Product data (50ms)
Orchestrator->>Cache: SET product:{asin} TTL=300s
else Redis Down
Cache-->>Orchestrator: Connection error
Note over Orchestrator: Fallback to origin directly
Orchestrator->>Origin: GET /products/{asin}
Origin-->>Orchestrator: Product data (50ms)
end
Code Example: Cache-Aside Service Wrapper
import asyncio
import logging
from typing import Any, Awaitable, Callable
logger = logging.getLogger(__name__)
class CacheAsideWrapper:
"""Generic cache-aside wrapper for any origin service call."""
def __init__(self, cache: "OptimizedRedisClient"):
self.cache = cache
async def get_or_fetch(
self,
cache_key: str,
        fetch_fn: Callable[[], Awaitable[Any]],
ttl_seconds: int = 300,
) -> tuple[Any, bool]:
"""
Try cache first, fetch on miss, populate cache.
Returns (data, was_cache_hit).
"""
# Try cache
result = await self.cache.get(cache_key)
if result.hit:
return result.value, True
# Fetch from origin
data = await fetch_fn()
        if data is not None:
            # Populate cache off the critical path (fire-and-forget). In a
            # long-lived service, keep a reference to the task so it is not
            # garbage-collected before it completes.
            asyncio.create_task(
                self.cache.set(cache_key, data, ttl_seconds)
            )
return data, False
async def get_many_or_fetch(
self,
        keys_and_fetchers: dict[str, tuple[Callable[[], Awaitable[Any]], int]],
) -> dict[str, tuple[Any, bool]]:
"""
Batch cache check, then fetch only missed keys.
keys_and_fetchers = {cache_key: (fetch_fn, ttl_seconds)}
"""
keys = list(keys_and_fetchers.keys())
# Pipeline cache check
cache_results = await self.cache.get_many(keys)
results = {}
fetch_tasks = {}
for result in cache_results:
if result.hit:
results[result.key] = (result.value, True)
else:
# Schedule fetch for this key
fetch_fn, ttl = keys_and_fetchers[result.key]
fetch_tasks[result.key] = (fetch_fn, ttl)
# Fetch all missed keys in parallel
if fetch_tasks:
fetch_results = await asyncio.gather(
*[fn() for fn, _ in fetch_tasks.values()],
return_exceptions=True,
)
populate_items = {}
for (key, (_, ttl)), fetched in zip(
fetch_tasks.items(), fetch_results
):
if isinstance(fetched, Exception):
logger.warning(f"Fetch failed for {key}: {fetched}")
results[key] = (None, False)
else:
results[key] = (fetched, False)
if fetched is not None:
populate_items[key] = (fetched, ttl)
# Populate cache for all fetched items
if populate_items:
asyncio.create_task(self.cache.set_many(populate_items))
return results
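Usage sketch, assuming a hypothetical async fetch_product() origin call; functools.partial turns it into the zero-argument callable the wrapper expects:

import asyncio
import functools

async def fetch_product(asin: str) -> dict:
    # Hypothetical origin call; in practice an HTTP request to the product service.
    await asyncio.sleep(0.05)  # simulate ~50ms origin latency
    return {"asin": asin, "title": "Example Manga Vol. 1"}

async def handle_request(wrapper: CacheAsideWrapper) -> None:
    keys_and_fetchers = {
        f"product:{asin}": (functools.partial(fetch_product, asin), 300)
        for asin in ["B000000001", "B000000002"]
    }
    results = await wrapper.get_many_or_fetch(keys_and_fetchers)
    for key, (data, was_hit) in results.items():
        print(key, "HIT" if was_hit else "MISS", data)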
3. Event-Driven Cache Invalidation
Stale data is worse than a cache miss. SNS events from the catalog, promotion, and review services trigger targeted invalidation.
graph TD
subgraph "Event Sources"
A[Product Catalog<br>Change Event] -->|SNS| D[Invalidation Lambda]
B[Promotion<br>Change Event] -->|SNS| D
C[Review Update<br>Event] -->|SNS| D
end
subgraph "Invalidation Lambda"
D --> E{Event Type?}
E -->|product_updated| F[DELETE product:{asin}]
E -->|product_price_changed| G[DELETE product:{asin}<br>+ related reco keys]
E -->|promotion_changed| H[DELETE promo:{section}]
E -->|review_updated| I[DELETE review:{asin}]
end
F --> J[ElastiCache Redis]
G --> J
H --> J
I --> J
Code Example: Cache Invalidation Handler
import json
import logging
import os
from typing import Any

import redis
logger = logging.getLogger(__name__)
class CacheInvalidationHandler:
"""Handles SNS-triggered cache invalidation events."""
def __init__(self, redis_client: redis.Redis):
self.redis = redis_client
def handle_sns_event(self, event: dict) -> dict[str, int]:
"""Process an SNS event and invalidate affected cache keys."""
records = event.get("Records", [])
deleted_count = 0
error_count = 0
for record in records:
try:
message = json.loads(record["Sns"]["Message"])
event_type = message.get("event_type")
keys = self._resolve_keys(event_type, message)
if keys:
deleted = self.redis.delete(*keys)
deleted_count += deleted
logger.info(
f"Invalidated {deleted} keys for event {event_type}: {keys}"
)
except Exception as e:
error_count += 1
logger.error(f"Failed to process invalidation event: {e}")
return {"deleted": deleted_count, "errors": error_count}
def _resolve_keys(self, event_type: str, message: dict) -> list[str]:
"""Map event type to cache keys that need invalidation."""
keys = []
if event_type == "product_updated":
asin = message.get("asin")
if asin:
keys.append(f"product:{asin}")
# Also invalidate related recommendation caches
# Pattern-based deletion for reco keys containing this ASIN
for key in self._scan_pattern(f"reco:*:{asin}"):
keys.append(key)
elif event_type == "product_price_changed":
# Prices are never cached, but product details include price
asin = message.get("asin")
if asin:
keys.append(f"product:{asin}")
elif event_type == "promotion_changed":
section = message.get("store_section", "manga-home")
keys.append(f"promo:{section}")
elif event_type == "review_updated":
asin = message.get("asin")
if asin:
keys.append(f"review:{asin}")
return keys
def _scan_pattern(self, pattern: str) -> list[str]:
"""Scan for keys matching a pattern (use sparingly)."""
keys = []
cursor = 0
while True:
cursor, found = self.redis.scan(cursor, match=pattern, count=100)
keys.extend(k.decode() if isinstance(k, bytes) else k for k in found)
if cursor == 0:
break
return keys
# Lambda handler entry point
def lambda_handler(event: dict, context: Any) -> dict:
    """AWS Lambda handler for SNS cache invalidation events."""
    # In production, create the client at module scope so warm invocations
    # reuse the connection instead of paying a handshake per event.
    redis_client = redis.Redis(
        host=os.environ["REDIS_PRIMARY_ENDPOINT"],
        port=6379,
        password=os.environ.get("REDIS_PASSWORD"),
        ssl=True,
    )
handler = CacheInvalidationHandler(redis_client)
result = handler.handle_sns_event(event)
return {
"statusCode": 200,
"body": json.dumps(result),
}
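For local testing, a minimal SNS-shaped event can be fed to the handler directly (the ASIN is a hypothetical sample; the envelope follows the standard SNS-to-Lambda record format):

import json

sample_event = {
    "Records": [
        {
            "Sns": {
                "Message": json.dumps(
                    {"event_type": "product_updated", "asin": "B01N5IB20Q"}
                )
            }
        }
    ]
}
# handler.handle_sns_event(sample_event) deletes product:B01N5IB20Q plus any
# reco:*:B01N5IB20Q keys found by SCAN.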
Metrics and Monitoring
| Metric | Target | Alarm Threshold |
|---|---|---|
| cache.get_latency_ms | p95 < 2ms | p95 > 5ms for 3 min |
| cache.pipeline_latency_ms | p95 < 5ms (10 keys) | p95 > 10ms |
| cache.hit_rate_product | > 70% | < 50% |
| cache.hit_rate_reco | > 50% | < 30% |
| cache.hit_rate_promo | > 80% | < 60% |
| cache.memory_utilization | < 70% | > 85% |
| cache.evictions_per_sec | < 10 | > 50 |
| cache.connection_pool_usage | < 80% | > 90% |
| cache.invalidation_latency_ms | p95 < 100ms | p95 > 500ms |
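A minimal sketch for pushing these metrics from the orchestrator via boto3 (the namespace and the periodic-flush contract are assumptions; this publishes averages, so true p95s would need per-request values or CloudWatch EMF):

import boto3

cloudwatch = boto3.client("cloudwatch")

def publish_cache_metrics(stats: "CacheStats", hit_rate: float) -> None:
    """Flush accumulated cache stats to CloudWatch, e.g. once per minute."""
    cloudwatch.put_metric_data(
        Namespace="Orchestrator/Cache",  # assumed namespace
        MetricData=[
            {
                "MetricName": "cache.get_latency_ms",
                "Value": stats.avg_hit_latency_ms,
                "Unit": "Milliseconds",
            },
            {
                "MetricName": "cache.hit_rate_product",
                "Value": hit_rate * 100,
                "Unit": "Percent",
            },
        ],
    )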
graph LR
subgraph "Cache Performance Targets"
A[Single GET<br>< 2ms p95]
B[Pipeline 10 keys<br>< 5ms p95]
C[Hit Rate<br>> 65% overall]
D[Memory Usage<br>< 70%]
end