
PO-05: Caching Layer Performance

User Story

As an infrastructure engineer, I want to optimize ElastiCache Redis for sub-2ms cache hits with high availability and efficient memory usage, So that cached product data, recommendations, and promotions are served instantly without adding latency to the request path.

Acceptance Criteria

  • Cache hit latency is under 2ms at p95 for single-key reads.
  • Multi-key reads (pipeline) complete in under 5ms for up to 10 keys.
  • Connection pool eliminates per-request connection overhead.
  • Cache hit rate exceeds 70% for product details and 50% for recommendations.
  • Cache cluster maintains < 70% memory utilization with proper eviction policies.
  • Failover to a replica completes in under 5 seconds with no data loss on the read path.

High-Level Design

Cache Latency Profile

graph LR
    subgraph "Without Optimization"
        A[New Connection<br>per request<br>~5ms overhead] --> B[Single GET<br>~3ms]
        B --> C[Another GET<br>~3ms]
        C --> D[Total: ~11ms<br>for 2 keys]
    end

    subgraph "With Optimization"
        E[Connection Pool<br>~0ms overhead] --> F[Pipeline GET×2<br>~1.5ms total]
        F --> G[Total: ~1.5ms<br>for 2 keys]
    end

    style D fill:#f66,stroke:#333
    style G fill:#2d8,stroke:#333

Optimization Strategy

graph TD
    subgraph "Connection Management"
        A1[Connection Pool<br>Pre-warmed]
        A2[Keep-Alive<br>TCP settings]
        A3[Cluster Mode<br>Local reads]
    end

    subgraph "Data Access Patterns"
        B1[Pipeline Reads<br>Batch multiple keys]
        B2[Serialization<br>MessagePack vs JSON]
        B3[Key Design<br>Minimize key length]
    end

    subgraph "Cluster Config"
        C1[Node Sizing<br>r7g.large]
        C2[Replica Placement<br>Same AZ as compute]
        C3[Eviction Policy<br>allkeys-lfu]
    end

    A1 --> D[p95 < 2ms]
    B1 --> D
    C2 --> D
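
The node sizing, replica placement, and eviction policy above are provisioning-time cluster settings rather than client code. A minimal boto3 sketch, with hypothetical resource names and illustrative AZs, omitting networking and auth parameters:

import boto3

elasticache = boto3.client("elasticache")

# Parameter group carrying the allkeys-lfu eviction policy.
elasticache.create_cache_parameter_group(
    CacheParameterGroupName="manga-cache-lfu",  # hypothetical name
    CacheParameterGroupFamily="redis7",
    Description="allkeys-lfu eviction for the product cache",
)
elasticache.modify_cache_parameter_group(
    CacheParameterGroupName="manga-cache-lfu",
    ParameterNameValues=[
        {"ParameterName": "maxmemory-policy", "ParameterValue": "allkeys-lfu"},
    ],
)

# One primary plus two read replicas on r7g.large with automatic failover.
elasticache.create_replication_group(
    ReplicationGroupId="manga-product-cache",  # hypothetical ID
    ReplicationGroupDescription="Product, recommendation, and promotion cache",
    Engine="redis",
    CacheNodeType="cache.r7g.large",
    NumCacheClusters=3,  # 1 primary + 2 replicas
    # Place nodes in the AZs where orchestrator compute runs (illustrative).
    PreferredCacheClusterAZs=["us-east-1a", "us-east-1b", "us-east-1c"],
    AutomaticFailoverEnabled=True,
    MultiAZEnabled=True,
    CacheParameterGroupName="manga-cache-lfu",
)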

Low-Level Design

1. Connection Pool Management

Creating a new Redis connection per request adds 3-5ms of TCP handshake overhead. A pre-warmed connection pool amortizes this to near-zero.

graph TD
    subgraph "Orchestrator Instance"
        A[Request Handler 1] --> D[Connection Pool<br>min=10, max=50]
        B[Request Handler 2] --> D
        C[Request Handler 3] --> D
    end

    D --> E[Redis Primary<br>Writes]
    D --> F[Redis Replica 1<br>Same-AZ Reads]
    D --> G[Redis Replica 2<br>Cross-AZ Reads]

    style D fill:#2d8,stroke:#333
    style F fill:#2d8,stroke:#333

Code Example: Optimized Redis Cache Client

import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Any, Optional

import msgpack
import redis.asyncio as aioredis

logger = logging.getLogger(__name__)


@dataclass
class CacheGetResult:
    value: Any
    hit: bool
    latency_ms: float
    key: str


@dataclass
class CacheStats:
    """Aggregate hit/miss counters, maintained by the caller for reporting."""

    hits: int = 0
    misses: int = 0
    errors: int = 0
    avg_hit_latency_ms: float = 0.0
    avg_miss_latency_ms: float = 0.0


class OptimizedRedisClient:
    """High-performance Redis client with connection pooling and pipelining."""

    def __init__(
        self,
        primary_endpoint: str,
        reader_endpoint: str,
        port: int = 6379,
        password: Optional[str] = None,
        min_connections: int = 10,
        max_connections: int = 50,
        socket_timeout: float = 0.5,
        socket_connect_timeout: float = 0.5,
    ):
        # redis-py pools have no min_connections option and create
        # connections lazily, so warm_up() below pre-opens the minimum.
        self._min_connections = min_connections

        # Write pool (primary)
        self._write_pool = aioredis.ConnectionPool.from_url(
            f"redis://{primary_endpoint}:{port}",
            password=password,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_connect_timeout,
            socket_keepalive=True,
            health_check_interval=15,
        )
        self._write_client = aioredis.Redis(connection_pool=self._write_pool)

        # Read pool (replica for lower latency)
        self._read_pool = aioredis.ConnectionPool.from_url(
            f"redis://{reader_endpoint}:{port}",
            password=password,
            max_connections=max_connections,
            socket_timeout=socket_timeout,
            socket_connect_timeout=socket_connect_timeout,
            socket_keepalive=True,
            health_check_interval=15,
        )
        self._read_client = aioredis.Redis(connection_pool=self._read_pool)

    async def warm_up(self) -> None:
        """Pre-open min_connections on each pool via concurrent PINGs."""
        await asyncio.gather(
            *[self._write_client.ping() for _ in range(self._min_connections)],
            *[self._read_client.ping() for _ in range(self._min_connections)],
        )

    async def get(self, key: str) -> CacheGetResult:
        """Single key read from replica with latency tracking."""
        start = time.monotonic()
        try:
            raw = await self._read_client.get(key)
            latency = (time.monotonic() - start) * 1000

            if raw is None:
                return CacheGetResult(value=None, hit=False, latency_ms=latency, key=key)

            value = msgpack.unpackb(raw, raw=False)
            return CacheGetResult(value=value, hit=True, latency_ms=latency, key=key)

        except Exception as e:
            latency = (time.monotonic() - start) * 1000
            logger.warning(f"Cache GET failed for {key}: {e}")
            return CacheGetResult(value=None, hit=False, latency_ms=latency, key=key)

    async def get_many(self, keys: list[str]) -> list[CacheGetResult]:
        """Pipeline multiple key reads in a single round-trip."""
        if not keys:
            return []

        start = time.monotonic()
        try:
            async with self._read_client.pipeline(transaction=False) as pipe:
                for key in keys:
                    pipe.get(key)
                raw_results = await pipe.execute()

            latency = (time.monotonic() - start) * 1000
            per_key_latency = latency / len(keys)

            results = []
            for key, raw in zip(keys, raw_results):
                if raw is not None:
                    value = msgpack.unpackb(raw, raw=False)
                    results.append(
                        CacheGetResult(
                            value=value, hit=True,
                            latency_ms=per_key_latency, key=key,
                        )
                    )
                else:
                    results.append(
                        CacheGetResult(
                            value=None, hit=False,
                            latency_ms=per_key_latency, key=key,
                        )
                    )

            return results

        except Exception as e:
            latency = (time.monotonic() - start) * 1000
            logger.warning(f"Cache pipeline GET failed: {e}")
            return [
                CacheGetResult(value=None, hit=False, latency_ms=latency, key=k)
                for k in keys
            ]

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int = 300,
    ) -> bool:
        """Write to primary with MessagePack serialization."""
        try:
            packed = msgpack.packb(value, use_bin_type=True)
            await self._write_client.setex(key, ttl_seconds, packed)
            return True
        except Exception as e:
            logger.warning(f"Cache SET failed for {key}: {e}")
            return False

    async def set_many(
        self,
        items: dict[str, tuple[Any, int]],
    ) -> int:
        """Pipeline multiple writes. items = {key: (value, ttl_seconds)}."""
        if not items:
            return 0

        success_count = 0
        try:
            async with self._write_client.pipeline(transaction=False) as pipe:
                for key, (value, ttl) in items.items():
                    packed = msgpack.packb(value, use_bin_type=True)
                    pipe.setex(key, ttl, packed)
                results = await pipe.execute()
                success_count = sum(1 for r in results if r)
        except Exception as e:
            logger.warning(f"Cache pipeline SET failed: {e}")

        return success_count

    async def delete(self, key: str) -> bool:
        """Delete a key from cache (for invalidation)."""
        try:
            await self._write_client.delete(key)
            return True
        except Exception as e:
            logger.warning(f"Cache DELETE failed for {key}: {e}")
            return False

    async def close(self) -> None:
        """Clean shutdown of connection pools."""
        await self._write_client.close()
        await self._read_client.close()
        await self._write_pool.disconnect()
        await self._read_pool.disconnect()
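
Usage sketch: warm the pools once at startup, then read ten keys in a single pipelined round-trip. The endpoints are placeholders, and warm_up is the helper defined above:

async def main() -> None:
    cache = OptimizedRedisClient(
        primary_endpoint="primary.example.cache.amazonaws.com",  # placeholder
        reader_endpoint="replica.example.cache.amazonaws.com",   # placeholder
    )
    await cache.warm_up()  # pre-open pooled connections before taking traffic

    results = await cache.get_many([f"product:B00000000{i}" for i in range(10)])
    for r in results:
        print(f"{r.key}: hit={r.hit}, {r.latency_ms:.2f}ms")

    await cache.close()


asyncio.run(main())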

2. Cache-Aside Pattern with Fallback

The orchestrator always tries cache first. On miss, it fetches from the origin service and populates the cache. If Redis is entirely down, requests fall through to origin.

sequenceDiagram
    participant Orchestrator
    participant Cache as ElastiCache Redis
    participant Origin as Origin Service

    Orchestrator->>Cache: GET product:{asin}
    alt Cache Hit (~70%)
        Cache-->>Orchestrator: Data (1ms)
    else Cache Miss (~30%)
        Cache-->>Orchestrator: null
        Orchestrator->>Origin: GET /products/{asin}
        Origin-->>Orchestrator: Product data (50ms)
        Orchestrator->>Cache: SET product:{asin} TTL=300s
    else Redis Down
        Cache-->>Orchestrator: Connection error
        Note over Orchestrator: Fallback to origin directly
        Orchestrator->>Origin: GET /products/{asin}
        Origin-->>Orchestrator: Product data (50ms)
    end

Code Example: Cache-Aside Service Wrapper

import asyncio
import logging
from typing import Any, Awaitable, Callable

logger = logging.getLogger(__name__)


class CacheAsideWrapper:
    """Generic cache-aside wrapper for any origin service call."""

    def __init__(self, cache: "OptimizedRedisClient"):
        self.cache = cache

    async def get_or_fetch(
        self,
        cache_key: str,
        fetch_fn: Callable[[], Awaitable[Any]],
        ttl_seconds: int = 300,
    ) -> tuple[Any, bool]:
        """
        Try cache first, fetch on miss, populate cache.
        Returns (data, was_cache_hit).
        """
        # Try cache
        result = await self.cache.get(cache_key)
        if result.hit:
            return result.value, True

        # Fetch from origin
        data = await fetch_fn()
        if data is not None:
            # Populate cache (fire-and-forget)
            asyncio.create_task(
                self.cache.set(cache_key, data, ttl_seconds)
            )

        return data, False

    async def get_many_or_fetch(
        self,
        keys_and_fetchers: dict[str, tuple[Callable[[], Awaitable[Any]], int]],
    ) -> dict[str, tuple[Any, bool]]:
        """
        Batch cache check, then fetch only missed keys.
        keys_and_fetchers = {cache_key: (fetch_fn, ttl_seconds)}
        """
        keys = list(keys_and_fetchers.keys())

        # Pipeline cache check
        cache_results = await self.cache.get_many(keys)

        results = {}
        fetch_tasks = {}

        for result in cache_results:
            if result.hit:
                results[result.key] = (result.value, True)
            else:
                # Schedule fetch for this key
                fetch_fn, ttl = keys_and_fetchers[result.key]
                fetch_tasks[result.key] = (fetch_fn, ttl)

        # Fetch all missed keys in parallel
        if fetch_tasks:
            fetch_results = await asyncio.gather(
                *[fn() for fn, _ in fetch_tasks.values()],
                return_exceptions=True,
            )

            populate_items = {}
            for (key, (_, ttl)), fetched in zip(
                fetch_tasks.items(), fetch_results
            ):
                if isinstance(fetched, Exception):
                    logger.warning(f"Fetch failed for {key}: {fetched}")
                    results[key] = (None, False)
                else:
                    results[key] = (fetched, False)
                    if fetched is not None:
                        populate_items[key] = (fetched, ttl)

            # Populate cache for all fetched items
            if populate_items:
                asyncio.create_task(self.cache.set_many(populate_items))

        return results
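
A usage sketch for single-key loads; fetch_product is a hypothetical origin call and the ASIN is illustrative:

async def fetch_product(asin: str) -> dict:
    """Hypothetical origin fetch; real code would call the product service."""
    return {"asin": asin, "title": "placeholder"}


async def load_product(cache: OptimizedRedisClient, asin: str) -> Any:
    wrapper = CacheAsideWrapper(cache)
    # Key layout matches the invalidation handler in section 3: product:{asin}
    data, was_hit = await wrapper.get_or_fetch(
        cache_key=f"product:{asin}",
        fetch_fn=lambda: fetch_product(asin),
        ttl_seconds=300,
    )
    return data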

3. Event-Driven Cache Invalidation

Stale data is worse than a cache miss. SNS events from catalog and promotion services trigger targeted invalidation.

graph TD
    subgraph "Event Sources"
        A[Product Catalog<br>Change Event] -->|SNS| D[Invalidation Lambda]
        B[Promotion<br>Change Event] -->|SNS| D
        C[Review Update<br>Event] -->|SNS| D
    end

    subgraph "Invalidation Lambda"
        D --> E{Event Type?}
        E -->|product_updated| F[DELETE product:{asin}]
        E -->|product_price_changed| G[DELETE product:{asin}<br>+ related reco keys]
        E -->|promotion_changed| H[DELETE promo:{section}]
        E -->|review_updated| I[DELETE review:{asin}]
    end

    F --> J[ElastiCache Redis]
    G --> J
    H --> J
    I --> J

Code Example: Cache Invalidation Handler

import json
import logging
from typing import Any

import redis

logger = logging.getLogger(__name__)


class CacheInvalidationHandler:
    """Handles SNS-triggered cache invalidation events."""

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    def handle_sns_event(self, event: dict) -> dict[str, int]:
        """Process an SNS event and invalidate affected cache keys."""
        records = event.get("Records", [])
        deleted_count = 0
        error_count = 0

        for record in records:
            try:
                message = json.loads(record["Sns"]["Message"])
                event_type = message.get("event_type")
                keys = self._resolve_keys(event_type, message)

                if keys:
                    deleted = self.redis.delete(*keys)
                    deleted_count += deleted
                    logger.info(
                        f"Invalidated {deleted} keys for event {event_type}: {keys}"
                    )

            except Exception as e:
                error_count += 1
                logger.error(f"Failed to process invalidation event: {e}")

        return {"deleted": deleted_count, "errors": error_count}

    def _resolve_keys(self, event_type: str, message: dict) -> list[str]:
        """Map event type to cache keys that need invalidation."""
        keys = []

        if event_type == "product_updated":
            asin = message.get("asin")
            if asin:
                keys.append(f"product:{asin}")
                # Also invalidate related recommendation caches
                # Pattern-based deletion for reco keys containing this ASIN
                for key in self._scan_pattern(f"reco:*:{asin}"):
                    keys.append(key)

        elif event_type == "product_price_changed":
            # Price is not cached under its own key, but the cached
            # product-detail blob embeds it, so that key must go.
            asin = message.get("asin")
            if asin:
                keys.append(f"product:{asin}")

        elif event_type == "promotion_changed":
            section = message.get("store_section", "manga-home")
            keys.append(f"promo:{section}")

        elif event_type == "review_updated":
            asin = message.get("asin")
            if asin:
                keys.append(f"review:{asin}")

        return keys

    def _scan_pattern(self, pattern: str) -> list[str]:
        """Scan for keys matching a pattern (use sparingly)."""
        keys = []
        cursor = 0
        while True:
            cursor, found = self.redis.scan(cursor, match=pattern, count=100)
            keys.extend(k.decode() if isinstance(k, bytes) else k for k in found)
            if cursor == 0:
                break
        return keys


# Lambda entry point. Client and handler are created once per container,
# so warm invocations reuse the connection instead of paying a TCP+TLS
# handshake on every event.
import os

_redis_client = redis.Redis(
    host=os.environ["REDIS_PRIMARY_ENDPOINT"],
    port=6379,
    password=os.environ.get("REDIS_PASSWORD"),
    ssl=True,
)
_handler = CacheInvalidationHandler(_redis_client)


def lambda_handler(event: dict, context: Any) -> dict:
    """AWS Lambda handler for SNS cache invalidation events."""
    result = _handler.handle_sns_event(event)
    return {
        "statusCode": 200,
        "body": json.dumps(result),
    }
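
For reference, a minimal SNS event of the shape handle_sns_event expects; the ASIN and event type are illustrative, and the deleted count reflects keys actually present in Redis:

sample_event = {
    "Records": [
        {
            "Sns": {
                "Message": json.dumps({
                    "event_type": "product_updated",
                    "asin": "B000123456",
                })
            }
        }
    ]
}
# handler.handle_sns_event(sample_event) resolves this to DELETE product:B000123456
# plus any matching reco:*:B000123456 keys.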

Metrics and Monitoring

Metric                         Target               Alarm Threshold
cache.get_latency_ms           p95 < 2ms            p95 > 5ms for 3 min
cache.pipeline_latency_ms      p95 < 5ms (10 keys)  p95 > 10ms
cache.hit_rate_product         > 70%                < 50%
cache.hit_rate_reco            > 50%                < 30%
cache.hit_rate_promo           > 80%                < 60%
cache.memory_utilization       < 70%                > 85%
cache.evictions_per_sec        < 10                 > 50
cache.connection_pool_usage    < 80%                > 90%
cache.invalidation_latency_ms  p95 < 100ms          p95 > 500ms

graph LR
    subgraph "Cache Performance Targets"
        A[Single GET<br>< 2ms p95]
        B[Pipeline 10 keys<br>< 5ms p95]
        C[Hit Rate<br>> 65% overall]
        D[Memory Usage<br>< 70%]
    end