
Resilient FM Systems Architecture

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.


Skill Mapping

| Attribute | Value |
|---|---|
| Certification | AWS Certified AI Practitioner (AIP-C01) |
| Domain | 2: Implementation and Integration of Foundation Models |
| Task | 2.4: Design resilient and scalable FM-based applications |
| Skill | 2.4.3: Create resilient FM systems to ensure reliable operations |
| Focus Areas | AWS SDK exponential backoff, API Gateway rate limiting, fallback mechanisms for graceful degradation, X-Ray for observability across service boundaries |

Mindmap: FM Resilience Pillars

mindmap
  root((Resilient FM Systems))
    Retries & Backoff
      Exponential Backoff
        Base delay doubling
        Maximum retry cap
      Jitter Strategy
        Full jitter
        Decorrelated jitter
        Equal jitter
      Circuit Breaker
        Closed state — normal
        Open state — fast fail
        Half-open state — probe
      SDK Retry Config
        boto3 retry modes
        Adaptive retry
        Standard retry
    Rate Limiting
      API Gateway Throttling
        Account-level limits
        Stage-level limits
        Method-level limits
      Usage Plans
        Throttle rate
        Burst capacity
        Quota allocation
      Token Bucket Algorithm
        Steady-state rate
        Burst allowance
        Refill cadence
      Client-Side Rate Limiting
        Pre-emptive throttle
        Request queuing
        Priority lanes
    Fallback Mechanisms
      Model Tiering
        Sonnet primary
        Haiku secondary
        Cached tertiary
      Graceful Degradation
        Feature reduction
        Static responses
        Maintenance messaging
      Cache-Based Fallback
        Redis hot cache
        Stale-while-revalidate
        TTL strategies
      Content Fallback
        Pre-computed answers
        FAQ database
        Human escalation
    Observability
      AWS X-Ray
        Distributed tracing
        Service maps
        Subsegment annotations
      CloudWatch Metrics
        Latency percentiles
        Error rates
        Throttle counts
      Alarms & Dashboards
        Composite alarms
        Anomaly detection
        Real-time dashboards
      Logging
        Structured JSON logs
        Correlation IDs
        Log Insights queries

Architecture Flowchart: MangaAssist Resilience Layers

flowchart TB
    subgraph ClientLayer["Client Layer"]
        User["Manga Store User<br/>Mobile / Web"]
        WSConn["WebSocket Connection<br/>Persistent Session"]
    end

    subgraph EdgeLayer["Edge & Ingress Layer"]
        CF["CloudFront<br/>Edge Cache"]
        WAF["AWS WAF<br/>Rate Rules + IP Throttle"]
        APIGW["API Gateway WebSocket<br/>Stage + Route Throttling"]
    end

    subgraph RateLimitLayer["Rate Limiting Layer"]
        RL_Account["Account-Level Throttle<br/>10,000 req/s steady"]
        RL_Stage["Stage-Level Throttle<br/>5,000 req/s per stage"]
        RL_Method["Route-Level Throttle<br/>2,000 req/s sendMessage"]
        RL_User["Per-User Throttle<br/>Redis Token Bucket"]
    end

    subgraph OrchestrationLayer["Orchestration Layer — ECS Fargate"]
        LB["Application Load Balancer<br/>Health Checks"]
        Orch["FM Orchestrator Service<br/>Retry + Fallback Logic"]
        CB["Circuit Breaker<br/>Failure Threshold Monitor"]
        BackoffEngine["Exponential Backoff Engine<br/>Jitter + Max Retry"]
    end

    subgraph FMLayer["Foundation Model Layer"]
        Sonnet["Claude 3 Sonnet<br/>Primary Model<br/>$3/$15 per 1M tokens"]
        Haiku["Claude 3 Haiku<br/>Fallback Model<br/>$0.25/$1.25 per 1M tokens"]
        BedrockGW["Bedrock Runtime API<br/>InvokeModel / Converse"]
    end

    subgraph FallbackLayer["Fallback & Cache Layer"]
        Redis["ElastiCache Redis<br/>Response Cache<br/>Hot Answers"]
        DDB["DynamoDB<br/>Product Data + Sessions"]
        StaticResp["Static Response Store<br/>Pre-Computed FAQ Answers"]
        GracefulMsg["Graceful Degradation<br/>User-Friendly Messages"]
    end

    subgraph ObservabilityLayer["Observability Layer"]
        XRay["AWS X-Ray<br/>Distributed Tracing"]
        CW["CloudWatch<br/>Metrics + Alarms"]
        CWLogs["CloudWatch Logs<br/>Structured JSON"]
        Dashboard["CloudWatch Dashboard<br/>FM Health Overview"]
        SNS["SNS Alerts<br/>On-Call Notification"]
    end

    User --> WSConn --> CF --> WAF --> APIGW
    APIGW --> RL_Account --> RL_Stage --> RL_Method --> RL_User
    RL_User --> LB --> Orch
    Orch --> CB
    CB -->|Closed| BackoffEngine --> BedrockGW --> Sonnet
    CB -->|Open| Redis
    BedrockGW -->|Throttled/Error| BackoffEngine
    BackoffEngine -->|Max Retries Exceeded| Haiku
    Haiku -->|Fail| Redis -->|Miss| StaticResp -->|Miss| GracefulMsg

    Orch -.->|Trace| XRay
    Orch -.->|Metrics| CW
    Orch -.->|Logs| CWLogs
    CW -.-> Dashboard
    CW -.->|Alarm| SNS
    XRay -.-> Dashboard

    style ClientLayer fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
    style EdgeLayer fill:#fff3e0,stroke:#e65100,stroke-width:2px
    style RateLimitLayer fill:#fce4ec,stroke:#c62828,stroke-width:2px
    style OrchestrationLayer fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
    style FMLayer fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px
    style FallbackLayer fill:#fff8e1,stroke:#f57f17,stroke-width:2px
    style ObservabilityLayer fill:#e0f7fa,stroke:#00695c,stroke-width:2px

Resilience Decision Flow

flowchart TD
    Start["Incoming User Message"] --> Validate["Validate Request<br/>Auth + Schema"]
    Validate -->|Invalid| Reject["Return 400<br/>Validation Error"]
    Validate -->|Valid| CheckRL["Check Rate Limits"]
    CheckRL -->|Exceeded| Throttle["Return 429<br/>Retry-After Header"]
    CheckRL -->|OK| CheckCB["Check Circuit Breaker State"]
    CheckCB -->|Open| CacheFallback["Serve Cached Response"]
    CheckCB -->|Closed/Half-Open| InvokePrimary["Invoke Claude 3 Sonnet<br/>via Bedrock"]

    InvokePrimary -->|Success| Return["Return Response<br/>Cache Result"]
    InvokePrimary -->|ThrottlingException| Backoff["Exponential Backoff<br/>with Full Jitter"]
    InvokePrimary -->|ServiceUnavailable| Backoff
    InvokePrimary -->|ModelTimeout| Backoff

    Backoff --> RetryCheck{"Retry Count<br/>< Max Retries?"}
    RetryCheck -->|Yes| InvokePrimary
    RetryCheck -->|No| FallbackHaiku["Fallback: Claude 3 Haiku"]

    FallbackHaiku -->|Success| Return
    FallbackHaiku -->|Fail| CacheFallback

    CacheFallback -->|Hit| Return
    CacheFallback -->|Miss| StaticFallback["Static FAQ Response"]

    StaticFallback -->|Match| Return
    StaticFallback -->|No Match| GracefulDeg["Graceful Degradation Message<br/>'Let me check on that...'"]
    GracefulDeg --> Return

    style Start fill:#c8e6c9
    style Return fill:#c8e6c9
    style Reject fill:#ffcdd2
    style Throttle fill:#ffcdd2
    style Backoff fill:#fff9c4
    style FallbackHaiku fill:#fff9c4
    style CacheFallback fill:#ffe0b2
    style GracefulDeg fill:#ffe0b2

1. Exponential Backoff with Jitter for Bedrock Calls

Why Exponential Backoff Matters for MangaAssist

At 1M messages/day (about 11.6 requests/second on average, with burst peaks of 50-100 rps during evening hours in Japan), MangaAssist should expect Bedrock API throttling. The default boto3 retry behavior is not enough on its own because:

  1. Uncoordinated retries: each ECS task retries independently, with no fleet-wide retry budget, so a single throttling event multiplies load on Bedrock (a thundering herd)
  2. Synchronized delays: any backoff without enough jitter lets retries from many tasks realign and amplify the overload
  3. No failure awareness: SDK retries do not distinguish transient throttles from persistent failures, and they never trip a circuit breaker or switch to a fallback model
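A toy model makes the thundering-herd effect in point 1 easy to see. The sketch assumes 100 ECS tasks that are all throttled at the same instant. Each task retries four times with a 100 ms base delay. The sketch counts the most retries that land in any one aligned 10 ms window, first with no jitter and then with full jitter. It models only retry timing, not Bedrock's responses, and the function names are illustrative.

```python
import random
from collections import Counter
from typing import List


def retry_times_ms(n_tasks: int, retries: int, base_ms: float, cap_ms: float,
                   jitter: bool, seed: int = 7) -> List[float]:
    """Absolute retry times for n_tasks that were all throttled at t=0."""
    rng = random.Random(seed)
    times: List[float] = []
    for _ in range(n_tasks):
        t = 0.0
        for attempt in range(retries):
            ceiling = min(cap_ms, base_ms * 2 ** attempt)
            # No jitter: every task sleeps exactly `ceiling`.
            # Full jitter: each task draws its sleep from [0, ceiling].
            t += rng.uniform(0, ceiling) if jitter else ceiling
            times.append(t)
    return times


def peak_retries(times: List[float], window_ms: float) -> int:
    """Most retries that land inside any single aligned window of window_ms."""
    return max(Counter(int(t // window_ms) for t in times).values())


synced = peak_retries(retry_times_ms(100, 4, 100, 10_000, jitter=False), 10)
spread = peak_retries(retry_times_ms(100, 4, 100, 10_000, jitter=True), 10)
print(f"peak retries per 10 ms window: no jitter={synced}, full jitter={spread}")
```

Without jitter, all 100 tasks hit Bedrock in the same window on every retry. With full jitter, the same retries spread across many windows.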

Jitter Strategies Compared

| Strategy | Formula | Best For | Drawback |
|---|---|---|---|
| No Jitter | min(cap, base * 2^attempt) | Never (causes thundering herd) | Synchronized retries |
| Full Jitter | random(0, min(cap, base * 2^attempt)) | General purpose | Occasionally very short waits |
| Equal Jitter | min(cap, base * 2^attempt) / 2 + random(0, min(cap, base * 2^attempt) / 2) | When a minimum wait is needed | Slightly less spread |
| Decorrelated Jitter | min(cap, random(base, prev_sleep * 3)) | High-contention scenarios | Can produce long waits early |

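MangaAssist defaults to decorrelated jitter. During peak manga browsing hours (7 PM to 11 PM JST), the whole fleet of ECS tasks contends for the same Bedrock quota. Decorrelated jitter spreads those retries widely while still lengthening the wait after each failure. Full jitter is a simpler alternative with similar behavior.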

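Backoff Timing Visualization

The chart assumes decorrelated jitter with a 100 ms base and a 10 s cap. Bars show each retry's upper bound. Any individual draw can be as short as the 100 ms base.

Attempt 1: |==|                                    delay: 100-300ms
Attempt 2: |======|                                delay: 100-900ms
Attempt 3: |==============|                        delay: 100-2700ms
Attempt 4: |============================|          delay: 100-8100ms
Attempt 5: |=========================================| delay: 100-10000ms (capped)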
            0     2s    4s    6s    8s    10s
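The chart's upper bounds also show why the retry count must be tied to the latency target. The sketch assumes the BackoffConfig defaults used below (100 ms base, 10 s cap, 5 retries). It adds up the worst-case decorrelated sleeps and counts how many fit into the 3-second budget. It ignores the time each Bedrock call takes, so real headroom is smaller. The helper names are hypothetical.

```python
from typing import List


def decorrelated_upper_bounds(base_ms: float, cap_ms: float, retries: int) -> List[float]:
    """Longest possible sleep before each retry under decorrelated jitter.

    Each sleep is drawn from [base, prev_sleep * 3] and capped, so the worst
    case is the chain in which every draw hits its upper bound.
    """
    bounds: List[float] = []
    prev = base_ms
    for _ in range(retries):
        prev = min(cap_ms, prev * 3)
        bounds.append(prev)
    return bounds


def retries_that_fit(bounds: List[float], budget_ms: float) -> int:
    """Number of worst-case sleeps that fit in a latency budget (call time ignored)."""
    total, count = 0.0, 0
    for delay in bounds:
        if total + delay > budget_ms:
            break
        total += delay
        count += 1
    return count


bounds = decorrelated_upper_bounds(100, 10_000, 5)
print(bounds, sum(bounds), retries_that_fit(bounds, 3_000))
```

In the worst case, the sleeps alone total 22 s, and only two retries fit before the 3-second target is spent. A user-facing path would therefore likely cap retries or total retry time and hand off to Haiku sooner.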

Boto3 Retry Configuration Modes

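Standard Mode:
  - Max attempts: 3 (default, including the initial call)
  - Retryable errors: Throttling errors, transient errors (such as timeouts), and HTTP 500/502/503/504
  - Backoff: Exponential (base 2) with full jitter, capped at 20 s
  - Client-side rate limiter: No (a retry-quota token bucket still limits retries during sustained failures)

Adaptive Mode:
  - Max attempts: 3 (default)
  - Retryable errors: Same as standard
  - Backoff: Same as standard
  - Client-side rate limiter: Yes, a token bucket that adjusts the send rate in response to throttling errors

Legacy Mode (boto3 default unless configured otherwise):
  - Max attempts: 5 (default)
  - Retryable errors: A narrower set than standard mode
  - Backoff: Exponential (base 2) scaled by a random factor
  - Client-side rate limiter: No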
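Standard and adaptive modes share one backoff rule. As we understand botocore's standard retry handler, the sleep after failed attempt n is a uniform draw from [0, 2^(n-1)] seconds, capped at 20 s. The sketch below restates that rule in plain Python to reason about worst cases. It is not botocore's code, and the function names are hypothetical.

```python
import random
from typing import Callable, List


def standard_mode_delay_s(attempt_number: int,
                          rand: Callable[[], float] = random.random,
                          base: float = 2.0,
                          max_backoff_s: float = 20.0) -> float:
    """Sleep in seconds after failed attempt `attempt_number` (1-based):
    full jitter over base ** (attempt_number - 1), capped at max_backoff_s."""
    return min(rand() * base ** (attempt_number - 1), max_backoff_s)


def worst_case_sleeps_s(max_attempts: int) -> List[float]:
    """Upper bound for each sleep between max_attempts calls (rand() -> 1.0)."""
    return [standard_mode_delay_s(n, rand=lambda: 1.0)
            for n in range(1, max_attempts)]


print(worst_case_sleeps_s(3), sum(worst_case_sleeps_s(3)))
```

With the default of 3 attempts, SDK-level sleeps alone can approach 3 seconds, which is the entire latency target. That is one reason to keep SDK retries low when the application runs its own backoff loop.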

Code: ExponentialBackoffClient

"""
ExponentialBackoffClient — Resilient Bedrock client with exponential backoff,
jitter strategies, and circuit breaker integration for MangaAssist.

Handles:
- ThrottlingException from Bedrock API
- ServiceUnavailableException during regional outages
- ModelTimeoutException for long-running inferences
- ModelNotReadyException during cold starts

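Design: Decorrelated jitter with circuit breaker awareness. The breaker is
checked before the first attempt. When it does not allow traffic, the client
raises CircuitBreakerOpenError immediately instead of calling Bedrock, so the
caller can route the request to the fallback chain.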
"""

import time
import random
import logging
import hashlib
import json
from enum import Enum
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Callable, List
from datetime import datetime, timezone

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

logger = logging.getLogger("mangaassist.backoff")


class JitterStrategy(Enum):
    """Jitter strategies for exponential backoff."""
    FULL = "full"
    EQUAL = "equal"
    DECORRELATED = "decorrelated"
    NONE = "none"


class CircuitState(Enum):
    """Circuit breaker states."""
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


@dataclass
class BackoffConfig:
    """
    Configuration for exponential backoff behavior.

    Attributes:
        base_delay_ms: Starting delay in milliseconds (default 100ms).
        max_delay_ms: Maximum delay cap in milliseconds (default 10s).
        max_retries: Maximum number of retry attempts (default 5).
        jitter_strategy: Which jitter algorithm to use.
        retry_budget_per_minute: Max retries across the fleet per minute.
        retryable_errors: Set of error codes that should trigger retry.
    """
    base_delay_ms: float = 100.0
    max_delay_ms: float = 10000.0
    max_retries: int = 5
    jitter_strategy: JitterStrategy = JitterStrategy.DECORRELATED
    retry_budget_per_minute: int = 100
    retryable_errors: set = field(default_factory=lambda: {
        "ThrottlingException",
        "ServiceUnavailableException",
        "ModelTimeoutException",
        "ModelNotReadyException",
        "TooManyRequestsException",
        "InternalServerException",
    })


@dataclass
class RetryMetrics:
    """Tracks retry statistics for observability."""
    total_attempts: int = 0
    total_retries: int = 0
    total_successes: int = 0
    total_failures: int = 0
    total_circuit_breaks: int = 0
    last_error: Optional[str] = None
    last_retry_timestamp: Optional[datetime] = None
    retry_latencies_ms: List[float] = field(default_factory=list)

    @property
    def success_rate(self) -> float:
        if self.total_attempts == 0:
            return 0.0
        return self.total_successes / self.total_attempts

    @property
    def avg_retry_latency_ms(self) -> float:
        if not self.retry_latencies_ms:
            return 0.0
        return sum(self.retry_latencies_ms) / len(self.retry_latencies_ms)

    @property
    def p99_retry_latency_ms(self) -> float:
        if not self.retry_latencies_ms:
            return 0.0
        sorted_latencies = sorted(self.retry_latencies_ms)
        idx = int(len(sorted_latencies) * 0.99)
        return sorted_latencies[min(idx, len(sorted_latencies) - 1)]


class CircuitBreaker:
    """
    Circuit breaker to prevent cascading failures when Bedrock is unhealthy.

    State transitions:
      CLOSED -> OPEN:      failure_count >= failure_threshold
      OPEN -> HALF_OPEN:   recovery_timeout elapsed
      HALF_OPEN -> CLOSED: probe request succeeds
      HALF_OPEN -> OPEN:   probe request fails
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout_s: float = 30.0,
        half_open_max_requests: int = 1,
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout_s = recovery_timeout_s
        self.half_open_max_requests = half_open_max_requests

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: Optional[float] = None
        self._half_open_request_count = 0

    @property
    def state(self) -> CircuitState:
        if self._state == CircuitState.OPEN:
            if self._last_failure_time and (
                time.time() - self._last_failure_time >= self.recovery_timeout_s
            ):
                self._state = CircuitState.HALF_OPEN
                self._half_open_request_count = 0
                logger.info("Circuit breaker transitioned to HALF_OPEN")
        return self._state

    def allow_request(self) -> bool:
        """Check if the circuit breaker allows a request through."""
        state = self.state
        if state == CircuitState.CLOSED:
            return True
        elif state == CircuitState.HALF_OPEN:
            if self._half_open_request_count < self.half_open_max_requests:
                self._half_open_request_count += 1
                return True
            return False
        else:  # OPEN
            return False

    def record_success(self):
        """Record a successful request."""
        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            logger.info("Circuit breaker transitioned to CLOSED (recovered)")
        elif self._state == CircuitState.CLOSED:
            self._failure_count = max(0, self._failure_count - 1)

    def record_failure(self):
        """Record a failed request."""
        self._failure_count += 1
        self._last_failure_time = time.time()

        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.OPEN
            logger.warning("Circuit breaker re-opened from HALF_OPEN (probe failed)")
        elif (
            self._state == CircuitState.CLOSED
            and self._failure_count >= self.failure_threshold
        ):
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit breaker OPENED after {self._failure_count} failures"
            )


class ExponentialBackoffClient:
    """
    Resilient Bedrock client with configurable exponential backoff, jitter,
    and circuit breaker integration.

    Usage:
        client = ExponentialBackoffClient(
            config=BackoffConfig(
                base_delay_ms=100,
                max_delay_ms=10000,
                max_retries=5,
                jitter_strategy=JitterStrategy.DECORRELATED,
            )
        )

        response = client.invoke_model(
            model_id="anthropic.claude-3-sonnet-20240229-v1:0",
            body={
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 512,
                "messages": [{"role": "user", "content": "..."}],
            },
        )
    """

    RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504}

    def __init__(
        self,
        config: Optional[BackoffConfig] = None,
        circuit_breaker: Optional[CircuitBreaker] = None,
        region: str = "ap-northeast-1",
    ):
        self.config = config or BackoffConfig()
        self.circuit_breaker = circuit_breaker or CircuitBreaker()
        self.metrics = RetryMetrics()
        self._prev_sleep_ms = self.config.base_delay_ms

        # Configure boto3 with adaptive retry as the base layer
        boto_config = Config(
            region_name=region,
            retries={
                "mode": "adaptive",
                "max_attempts": 2,  # Low — we handle retries ourselves
            },
            read_timeout=30,
            connect_timeout=5,
        )
        self._bedrock = boto3.client("bedrock-runtime", config=boto_config)

    def _calculate_delay_ms(self, attempt: int) -> float:
        """
        Calculate the delay for the given retry attempt using the configured
        jitter strategy.

        Args:
            attempt: Zero-based attempt number (0 = first retry).

        Returns:
            Delay in milliseconds.
        """
        base = self.config.base_delay_ms
        cap = self.config.max_delay_ms
        strategy = self.config.jitter_strategy

        if strategy == JitterStrategy.NONE:
            return min(cap, base * (2 ** attempt))

        elif strategy == JitterStrategy.FULL:
            exp_delay = min(cap, base * (2 ** attempt))
            return random.uniform(0, exp_delay)

        elif strategy == JitterStrategy.EQUAL:
            exp_delay = min(cap, base * (2 ** attempt))
            half = exp_delay / 2
            return half + random.uniform(0, half)

        elif strategy == JitterStrategy.DECORRELATED:
            delay = random.uniform(base, self._prev_sleep_ms * 3)
            delay = min(cap, delay)
            self._prev_sleep_ms = delay
            return delay

        return min(cap, base * (2 ** attempt))

    def _is_retryable(self, error: ClientError) -> bool:
        """Determine if the error is retryable."""
        error_code = error.response.get("Error", {}).get("Code", "")
        http_code = error.response.get("ResponseMetadata", {}).get(
            "HTTPStatusCode", 0
        )
        return (
            error_code in self.config.retryable_errors
            or http_code in self.RETRYABLE_HTTP_CODES
        )

    def _generate_cache_key(self, model_id: str, body: Dict) -> str:
        """Generate a deterministic cache key for a Bedrock request."""
        content = json.dumps({"model": model_id, "body": body}, sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def invoke_model(
        self,
        model_id: str,
        body: Dict[str, Any],
        on_retry: Optional[Callable[[int, float, str], None]] = None,
    ) -> Dict[str, Any]:
        """
        Invoke a Bedrock model with exponential backoff and circuit breaker.

        Args:
            model_id: Bedrock model identifier.
            body: Request body for the model.
            on_retry: Optional callback(attempt, delay_ms, error_code) on each retry.

        Returns:
            Parsed model response.

        Raises:
            CircuitBreakerOpenError: If the circuit breaker is open.
            MaxRetriesExceededError: If all retries are exhausted.
        """
        self.metrics.total_attempts += 1
        self._prev_sleep_ms = self.config.base_delay_ms
        cache_key = self._generate_cache_key(model_id, body)

        # Check circuit breaker
        if not self.circuit_breaker.allow_request():
            self.metrics.total_circuit_breaks += 1
            logger.warning(
                f"Circuit breaker OPEN — skipping Bedrock call [{cache_key}]"
            )
            raise CircuitBreakerOpenError(
                f"Circuit breaker is {self.circuit_breaker.state.value}"
            )

        last_exception = None

        for attempt in range(self.config.max_retries + 1):
            try:
                start_time = time.monotonic()

                response = self._bedrock.invoke_model(
                    modelId=model_id,
                    contentType="application/json",
                    accept="application/json",
                    body=json.dumps(body),
                )

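                elapsed_ms = (time.monotonic() - start_time) * 1000
                self.metrics.retry_latencies_ms.append(elapsed_ms)
                # Keep only the latest 1,000 samples so memory stays bounded
                # at 1M messages/day (a no-op while the list is shorter)
                del self.metrics.retry_latencies_ms[:-1000]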
                self.circuit_breaker.record_success()
                self.metrics.total_successes += 1

                result = json.loads(response["body"].read())
                logger.info(
                    f"Bedrock call succeeded [{cache_key}] "
                    f"attempt={attempt} latency={elapsed_ms:.0f}ms"
                )
                return result

            except ClientError as e:
                last_exception = e
                error_code = e.response.get("Error", {}).get("Code", "Unknown")
                self.metrics.last_error = error_code

                if not self._is_retryable(e):
                    self.circuit_breaker.record_failure()
                    self.metrics.total_failures += 1
                    logger.error(
                        f"Non-retryable error [{cache_key}]: {error_code}"
                    )
                    raise

                if attempt < self.config.max_retries:
                    delay_ms = self._calculate_delay_ms(attempt)
                    delay_s = delay_ms / 1000.0
                    self.metrics.total_retries += 1
                    self.metrics.last_retry_timestamp = datetime.now(timezone.utc)

                    logger.warning(
                        f"Retryable error [{cache_key}]: {error_code} — "
                        f"attempt {attempt + 1}/{self.config.max_retries}, "
                        f"sleeping {delay_ms:.0f}ms"
                    )

                    if on_retry:
                        on_retry(attempt, delay_ms, error_code)

                    time.sleep(delay_s)
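                else:
                    self.circuit_breaker.record_failure()
                    self.metrics.total_failures += 1

            except Exception:
                # Timeouts and connection errors from botocore are not
                # ClientError subclasses. Record them as failures so a
                # HALF_OPEN probe cannot leave the breaker stuck, then
                # re-raise for the caller's fallback logic.
                self.circuit_breaker.record_failure()
                self.metrics.total_failures += 1
                raise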

        raise MaxRetriesExceededError(
            f"All {self.config.max_retries} retries exhausted for [{cache_key}]",
            last_exception=last_exception,
        )

    def get_metrics_snapshot(self) -> Dict[str, Any]:
        """Return a snapshot of retry metrics for CloudWatch publishing."""
        return {
            "total_attempts": self.metrics.total_attempts,
            "total_retries": self.metrics.total_retries,
            "total_successes": self.metrics.total_successes,
            "total_failures": self.metrics.total_failures,
            "total_circuit_breaks": self.metrics.total_circuit_breaks,
            "success_rate": round(self.metrics.success_rate, 4),
            "avg_retry_latency_ms": round(self.metrics.avg_retry_latency_ms, 1),
            "p99_retry_latency_ms": round(self.metrics.p99_retry_latency_ms, 1),
            "last_error": self.metrics.last_error,
            "circuit_breaker_state": self.circuit_breaker.state.value,
        }


class CircuitBreakerOpenError(Exception):
    """Raised when the circuit breaker is open and no requests are allowed."""
    pass


class MaxRetriesExceededError(Exception):
    """Raised when all retry attempts have been exhausted."""

    def __init__(self, message: str, last_exception: Optional[Exception] = None):
        super().__init__(message)
        self.last_exception = last_exception
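BackoffConfig declares retry_budget_per_minute, but the client above never consults it. The class below is a minimal sketch of a per-process budget, assuming a sliding 60-second window. RetryBudget is a hypothetical helper, not part of boto3. A fleet-wide budget, as the attribute's docstring describes, would need shared state such as a Redis counter.

```python
import time
from collections import deque
from typing import Callable, Deque


class RetryBudget:
    """Hypothetical sliding-window retry cap for a single process."""

    def __init__(self, max_retries: int, window_s: float = 60.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_retries = max_retries
        self.window_s = window_s
        self._clock = clock
        self._stamps: Deque[float] = deque()  # times of retries still in the window

    def try_acquire(self) -> bool:
        """Spend one retry if the window has room; False means fail fast."""
        now = self._clock()
        while self._stamps and now - self._stamps[0] >= self.window_s:
            self._stamps.popleft()  # this retry has aged out of the window
        if len(self._stamps) >= self.max_retries:
            return False
        self._stamps.append(now)
        return True


# Usage with a controllable clock (seconds).
fake_now = [0.0]
budget = RetryBudget(max_retries=2, window_s=60.0, clock=lambda: fake_now[0])
first, second, third = (budget.try_acquire() for _ in range(3))
fake_now[0] = 61.0
after_window = budget.try_acquire()
print(first, second, third, after_window)
```

To wire it in, call budget.try_acquire() just before time.sleep(delay_s) in invoke_model. When it returns False, stop retrying and let the caller fall back to Haiku or the cache.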

2. API Gateway Rate Limiting and Throttling

API Gateway Throttling Architecture

API Gateway provides three layers of throttling for MangaAssist:

Layer 1: Account-Level      → 10,000 req/s steady, 5,000 burst (AWS defaults per Region, can request increase)
Layer 2: Stage-Level         → Configured per deployment stage (prod, staging)
Layer 3: Method/Route-Level  → Fine-grained per WebSocket route ($connect, $default, sendMessage)

Throttle Settings for MangaAssist

| Layer | Setting | Value | Rationale |
|---|---|---|---|
| Account | Default rate | 10,000 rps | AWS default, sufficient for 1M msgs/day |
| Prod Stage | Rate | 5,000 rps | Reserve headroom for other services |
| Prod Stage | Burst | 7,500 | Handle evening manga browsing spikes (effective burst stays capped at the 5,000 account default unless that limit is raised) |
| $connect Route | Rate | 500 rps | Limit new connections (DDoS mitigation) |
| sendMessage Route | Rate | 2,000 rps | Primary chat route |
| sendMessage Route | Burst | 4,000 | Handle rapid-fire user queries |
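A request must clear every layer, so the tightest rate and burst along its path set the ceiling. The sketch applies that rule to the values in the table, plus the API Gateway regional default burst of 5,000, which the table does not list. It also checks daily volume against the 50M/month quota_limit used in ThrottleConfig. This is a simplification: every API in the Region shares the account limit, so the real ceiling can be lower.

```python
from typing import Tuple

Limit = Tuple[float, float]  # (steady-state rps, burst)

ACCOUNT: Limit = (10_000, 5_000)      # API Gateway regional defaults
PROD_STAGE: Limit = (5_000, 7_500)    # from the table
SEND_MESSAGE: Limit = (2_000, 4_000)  # from the table


def effective_limit(*layers: Limit) -> Limit:
    """The tightest rate and burst across all layers a request must pass
    (ignores other APIs sharing the account limit)."""
    return min(rate for rate, _ in layers), min(burst for _, burst in layers)


stage_ceiling = effective_limit(ACCOUNT, PROD_STAGE)
send_message_ceiling = effective_limit(ACCOUNT, PROD_STAGE, SEND_MESSAGE)

# Volume check: 1M messages/day over a 31-day month vs the 50M/month quota.
avg_rps = 1_000_000 / 86_400
worst_month = 1_000_000 * 31
quota_headroom = 50_000_000 - worst_month
print(stage_ceiling, send_message_ceiling, round(avg_rps, 1), quota_headroom)
```

The prod stage's 7,500 burst cannot take effect until the account burst is raised. The sendMessage limit of 2,000 rps leaves a wide margin over both the ~11.6 rps average and the 50-100 rps evening peaks.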

Token Bucket Algorithm Explained

Token Bucket for MangaAssist sendMessage route:
  ┌─────────────────────────────────────────┐
  │  Bucket Capacity: 4,000 tokens (burst)  │
  │  Refill Rate: 2,000 tokens/sec (rate)   │
  │                                         │
  │  Time 0s:  [████████████████] 4000/4000 │
  │  Burst:    [░░░░░░██████████] 2500/4000 │  1500 requests consumed
  │  Time 1s:  [██████████████░░] 3500/4000 │  +2000 refilled, -1000 consumed
  │  Time 2s:  [████████████████] 4000/4000 │  Refilled to capacity
  └─────────────────────────────────────────┘

When tokens = 0: API Gateway returns HTTP 429 Too Many Requests
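You can replay the timeline in the box with a small in-memory token bucket that refills lazily on each call. This is the same refill-then-consume arithmetic the Redis Lua script below uses. The class is a single-process sketch that takes an explicit `now` argument so the timeline is deterministic. It is not API Gateway's implementation.

```python
class TokenBucket:
    """Single-process token bucket with lazy refill."""

    def __init__(self, rate: float, capacity: float, now: float = 0.0):
        self.rate = rate               # tokens added per second
        self.capacity = capacity       # maximum burst
        self.tokens = float(capacity)  # the bucket starts full
        self.last = now

    def _refill(self, now: float) -> None:
        # Add tokens for the elapsed time, never exceeding capacity.
        elapsed = max(0.0, now - self.last)
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last = now

    def try_consume(self, now: float, n: int = 1) -> bool:
        self._refill(now)
        if self.tokens >= n:
            self.tokens -= n
            return True
        return False  # API Gateway would answer 429 at this point

    def available(self, now: float) -> float:
        self._refill(now)
        return self.tokens


# Replay the sendMessage timeline: a 1,500-request burst at t = 0 s,
# then 1,000 requests spread evenly over the next second, then idle.
bucket = TokenBucket(rate=2_000, capacity=4_000)
burst_ok = all(bucket.try_consume(0.0) for _ in range(1_500))
after_burst = bucket.available(0.0)
steady_ok = all(bucket.try_consume(k / 1_000) for k in range(1, 1_001))
at_1s = bucket.available(1.0)
at_2s = bucket.available(2.0)
print(after_burst, round(at_1s), at_2s)
```

Spreading the 1,000 requests across the second is what leaves 3,500 tokens at t = 1 s. If all 1,000 arrived at once at t = 1 s, the refill would cap at 4,000 first and leave 3,000.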

Code: RateLimitManager

"""
RateLimitManager — Manages API Gateway rate limiting configuration and
provides a client-side pre-emptive rate limiter for MangaAssist.

The server-side limits are API Gateway stage- and route-level throttles
(WebSocket APIs do not support usage plans). The client-side limiter runs in
the ECS Fargate orchestrator, behind API Gateway. It holds back Bedrock
requests that would be throttled, which avoids wasted round trips and retry load.

Architecture:
  API Gateway (stage/route throttling) → Orchestrator (Redis token buckets) → Bedrock (service quotas)
"""

import time
import logging
import json
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone

import boto3
import redis

logger = logging.getLogger("mangaassist.ratelimit")


@dataclass
class ThrottleConfig:
    """
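    Throttle configuration mirroring the API Gateway route throttle settings.
    WebSocket APIs do not support usage plans, so quota_limit and quota_period
    are application-level settings rather than API Gateway configuration.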

    Attributes:
        rate_limit: Steady-state requests per second.
        burst_limit: Maximum burst capacity.
        quota_limit: Requests per quota period.
        quota_period: DAY, WEEK, or MONTH.
    """
    rate_limit: float = 2000.0
    burst_limit: int = 4000
    quota_limit: int = 50_000_000  # 50M per month
    quota_period: str = "MONTH"


@dataclass
class UserThrottleConfig:
    """Per-user throttle settings for fair resource allocation."""
    rate_limit: float = 5.0       # 5 messages/second per user
    burst_limit: int = 10          # Allow short bursts of 10
    daily_quota: int = 5000        # 5000 messages/day per user
    premium_multiplier: float = 3.0  # Premium users get 3x


class RedisTokenBucket:
    """
    Redis-backed token bucket for distributed rate limiting across ECS tasks.

    Uses a Lua script for atomic token consumption to avoid race conditions
    when multiple ECS tasks check the same bucket simultaneously.

    Key structure:
        mangaassist:ratelimit:{scope}:{identifier}:tokens  — current token count
        mangaassist:ratelimit:{scope}:{identifier}:last_ts — last refill timestamp
    """

    # Atomic Lua script: refill tokens based on elapsed time, then consume
    LUA_CONSUME = """
    local tokens_key = KEYS[1]
    local ts_key = KEYS[2]
    local rate = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local now = tonumber(ARGV[3])
    local requested = tonumber(ARGV[4])

    local last_ts = tonumber(redis.call('get', ts_key) or now)
    local current_tokens = tonumber(redis.call('get', tokens_key) or capacity)

    local elapsed = math.max(0, now - last_ts)
    local new_tokens = math.min(capacity, current_tokens + (elapsed * rate))

    if new_tokens >= requested then
        new_tokens = new_tokens - requested
        redis.call('set', tokens_key, new_tokens)
        redis.call('set', ts_key, now)
        redis.call('expire', tokens_key, 300)
        redis.call('expire', ts_key, 300)
        return 1
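    else
        -- SET discards any existing TTL, so re-apply the expiry here as well;
        -- otherwise idle per-user keys would never be cleaned up
        redis.call('set', tokens_key, new_tokens)
        redis.call('set', ts_key, now)
        redis.call('expire', tokens_key, 300)
        redis.call('expire', ts_key, 300)
        return 0
    end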
    """

    def __init__(
        self,
        redis_client: redis.Redis,
        scope: str = "global",
        rate: float = 2000.0,
        capacity: int = 4000,
    ):
        self.redis = redis_client
        self.scope = scope
        self.rate = rate
        self.capacity = capacity
        # register_script returns a redis-py Script object: calls use EVALSHA
        # and reload the script automatically if Redis answers NOSCRIPT
        # (for example after a node restart empties the script cache)
        self._script = redis_client.register_script(self.LUA_CONSUME)

    def try_consume(self, identifier: str, tokens: int = 1) -> bool:
        """
        Attempt to consume tokens from the bucket.

        Args:
            identifier: Unique identifier (user_id, api_key, etc.).
            tokens: Number of tokens to consume (default 1).

        Returns:
            True if tokens were consumed, False if rate limit exceeded.
        """
        prefix = f"mangaassist:ratelimit:{self.scope}:{identifier}"
        tokens_key = f"{prefix}:tokens"
        ts_key = f"{prefix}:last_ts"

        result = self._script(
            keys=[tokens_key, ts_key],
            args=[
                str(self.rate),
                str(self.capacity),
                str(time.time()),
                str(tokens),
            ],
        )
        return bool(result)

    def get_remaining(self, identifier: str) -> Dict[str, Any]:
        """Get the remaining tokens and time until refill for an identifier."""
        prefix = f"mangaassist:ratelimit:{self.scope}:{identifier}"
        tokens_key = f"{prefix}:tokens"
        ts_key = f"{prefix}:last_ts"

        current = self.redis.get(tokens_key)
        last_ts = self.redis.get(ts_key)

        if current is None:
            return {"remaining": self.capacity, "reset_in_seconds": 0}

        current_tokens = float(current)
        elapsed = time.time() - float(last_ts or time.time())
        refilled = min(self.capacity, current_tokens + (elapsed * self.rate))

        return {
            "remaining": int(refilled),
            "capacity": self.capacity,
            "rate_per_second": self.rate,
            "reset_in_seconds": max(
                0, (self.capacity - refilled) / self.rate
            ),
        }


class RateLimitManager:
    """
    Manages rate limiting for MangaAssist at multiple tiers:
      1. Global — protects Bedrock from aggregate overload
      2. Per-user — ensures fair access across manga readers
      3. Per-route — different limits for different WebSocket routes

    Integrates with:
      - Redis token bucket for distributed enforcement
      - API Gateway stage- and route-level throttling for server-side enforcement
      - CloudWatch metrics for rate limit monitoring
    """

    def __init__(
        self,
        redis_client: redis.Redis,
        global_config: Optional[ThrottleConfig] = None,
        user_config: Optional[UserThrottleConfig] = None,
    ):
        self.global_config = global_config or ThrottleConfig()
        self.user_config = user_config or UserThrottleConfig()

        # Global bucket — protects Bedrock aggregate capacity
        self.global_bucket = RedisTokenBucket(
            redis_client=redis_client,
            scope="global",
            rate=self.global_config.rate_limit,
            capacity=self.global_config.burst_limit,
        )

        # Per-user bucket — fair access across users
        self.user_bucket = RedisTokenBucket(
            redis_client=redis_client,
            scope="user",
            rate=self.user_config.rate_limit,
            capacity=self.user_config.burst_limit,
        )

        self._cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")

    def check_rate_limit(
        self, user_id: str, is_premium: bool = False
    ) -> "RateLimitResult":
        """
        Check all rate limit tiers for a user request.

        Order of checks:
          1. Global limit — is the system at capacity?
          2. User limit — has this user exceeded their allowance?

        Args:
            user_id: Unique user identifier.
            is_premium: Whether the user has a premium subscription.

        Returns:
            RateLimitResult with allow/deny and metadata.
        """
        # Check global limit first
        if not self.global_bucket.try_consume("bedrock"):
            self._publish_throttle_metric("global", user_id)
            remaining = self.global_bucket.get_remaining("bedrock")
            return RateLimitResult(
                allowed=False,
                tier="global",
                reason="System at capacity — please retry shortly",
                retry_after_seconds=remaining["reset_in_seconds"],
                remaining_tokens=remaining["remaining"],
            )

        # Check per-user limit
        if not self.user_bucket.try_consume(user_id):
            self._publish_throttle_metric("user", user_id)
            remaining = self.user_bucket.get_remaining(user_id)
            retry_after = remaining["reset_in_seconds"]
            if is_premium:
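                # Premium users share the same bucket; only the retry hint is
                # scaled down. Real extra burst would need a separate, larger
                # bucket sized by premium_multiplier.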
                retry_after = max(0, retry_after / self.user_config.premium_multiplier)
            return RateLimitResult(
                allowed=False,
                tier="user",
                reason="You're sending messages too fast — slow down a bit!",
                retry_after_seconds=retry_after,
                remaining_tokens=remaining["remaining"],
            )

        return RateLimitResult(allowed=True, tier="none")

    def _publish_throttle_metric(self, tier: str, user_id: str):
        """Publish a throttle event metric to CloudWatch."""
        try:
            self._cloudwatch.put_metric_data(
                Namespace="MangaAssist/RateLimiting",
                MetricData=[
                    {
                        "MetricName": "ThrottledRequests",
                        "Dimensions": [
                            {"Name": "Tier", "Value": tier},
                            {"Name": "Service", "Value": "MangaAssist"},
                        ],
                        "Timestamp": datetime.now(timezone.utc),
                        "Value": 1,
                        "Unit": "Count",
                    }
                ],
            )
        except Exception as e:
            logger.warning(f"Failed to publish throttle metric: {e}")

    def configure_api_gateway_throttling(
        self,
        api_id: str,
        stage_name: str = "prod",
    ) -> Dict[str, Any]:
        """
        Apply server-side throttling to the MangaAssist WebSocket API stage.

        This acts as a second layer of protection behind the application-side
        Redis rate limiter. Usage plans (with their API keys and quotas) apply
        only to REST APIs, so a WebSocket API (API Gateway v2) is throttled
        with stage-wide defaults plus per-route overrides instead.
        """
        apigw = boto3.client("apigatewayv2", region_name="ap-northeast-1")

        response = apigw.update_stage(
            ApiId=api_id,
            StageName=stage_name,
            # Default limits for every route without its own override
            DefaultRouteSettings={
                "ThrottlingBurstLimit": self.global_config.burst_limit,
                "ThrottlingRateLimit": self.global_config.rate_limit,
            },
            # Per-route overrides, keyed by WebSocket route key
            RouteSettings={
                "sendMessage": {
                    "ThrottlingBurstLimit": self.global_config.burst_limit,
                    "ThrottlingRateLimit": self.global_config.rate_limit,
                },
                "$connect": {
                    "ThrottlingBurstLimit": 1000,
                    "ThrottlingRateLimit": 500.0,
                },
            },
        )

        logger.info(
            f"Applied throttling to {api_id}/{stage_name}: "
            f"rate={self.global_config.rate_limit} "
            f"burst={self.global_config.burst_limit}"
        )
        return response


@dataclass
class RateLimitResult:
    """Result of a rate limit check."""
    allowed: bool
    tier: str = "none"
    reason: str = ""
    retry_after_seconds: float = 0.0
    remaining_tokens: int = 0
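`get_remaining` reports `reset_in_seconds` as the time until the bucket is completely full, and `check_rate_limit` hands that value to the client as `retry_after_seconds`. For a request that needs only one token, that overstates the wait. The sketch below restates the lazy-refill arithmetic as pure functions, assuming the same model as `RedisTokenBucket` (capacity in tokens, refill rate in tokens per second, one token per request); the function names are illustrative and not part of the code above.

```python
def refill(tokens: float, elapsed_s: float, rate: float, capacity: float) -> float:
    """Lazy refill: add rate * elapsed tokens, capped at the bucket capacity."""
    return min(capacity, tokens + elapsed_s * rate)


def seconds_until_full(tokens: float, rate: float, capacity: float) -> float:
    """Time until the bucket is full again (what get_remaining reports)."""
    return max(0.0, (capacity - tokens) / rate)


def seconds_until_next_token(tokens: float, rate: float, cost: float = 1.0) -> float:
    """Shortest wait before a request costing `cost` tokens can succeed."""
    return max(0.0, (cost - tokens) / rate)


# A 10-token bucket refilling at 2 tokens/s with 0.5 tokens left
print(seconds_until_full(0.5, rate=2.0, capacity=10.0))   # 4.75
print(seconds_until_next_token(0.5, rate=2.0))            # 0.25
```

With 0.5 tokens left, the bucket is full again after 4.75 s, but the next single-token request can succeed after 0.25 s, so the second figure is the tighter retry hint.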

3. Fallback Patterns for Graceful Degradation

MangaAssist Fallback Cascade

Priority 1: Claude 3 Sonnet  → Full capability, rich manga recommendations
     ↓ (on failure)
Priority 2: Claude 3 Haiku   → Reduced capability, faster/cheaper responses
     ↓ (on failure)
Priority 3: Cached Response   → Redis cached answer for similar query
     ↓ (on miss)
Priority 4: Static FAQ        → Pre-computed answers for common manga questions
     ↓ (on miss)
Priority 5: Graceful Message  → Friendly "checking on that" message + human escalation
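The cascade reduces to one loop: try each tier in priority order, treat an exception (model error, throttle, timeout) or a `None` result (cache or FAQ miss) as a reason to move on, and finish with a message that cannot fail. A minimal sketch, with hypothetical stand-in handlers in place of the Bedrock, Redis and FAQ calls:

```python
from typing import Callable, List, Optional, Tuple

# (tier name, handler): a handler returns an answer, returns None on a miss,
# or raises on failure
Tier = Tuple[str, Callable[[str], Optional[str]]]


def run_cascade(query: str, tiers: List[Tier], last_resort: str) -> Tuple[str, str]:
    """Try each tier in priority order and return (tier_name, answer)."""
    for name, handler in tiers:
        try:
            answer = handler(query)
        except Exception:
            continue  # model error, throttle, or timeout: try the next tier
        if answer is not None:
            return name, answer
    # Nothing answered: the graceful message cannot fail
    return "graceful", last_resort


def bedrock_down(query: str) -> Optional[str]:
    raise RuntimeError("ThrottlingException")  # stand-in for a failed model call


def faq_lookup(query: str) -> Optional[str]:
    return "Standard shipping takes 3-5 business days." if "ship" in query else None


tiers: List[Tier] = [
    ("sonnet", bedrock_down),
    ("haiku", bedrock_down),
    ("cached", lambda q: None),  # cache miss
    ("static_faq", faq_lookup),
]
print(run_cascade("how long does shipping take?", tiers, "Please try again shortly."))
```

Treating a miss and a failure the same way lets the cache and FAQ tiers share the loop with the model tiers; the real orchestrator additionally records which tier answered.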

Fallback Quality vs. Cost Tradeoff

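| Tier | Model/Source | Quality | Latency | Cost per 1M | Availability |
|------|--------------|---------|---------|-------------|--------------|
| 1 | Claude 3 Sonnet | Excellent | 1-3 s | $3 / $15 (input/output) | 99.5% |
| 2 | Claude 3 Haiku | Good | 0.3-1 s | $0.25 / $1.25 (input/output) | 99.5% |
| 3 | Redis Cache | Varies (stale risk) | <10 ms | $0.0001 | 99.99% |
| 4 | Static FAQ | Basic | <5 ms | $0 | 99.999% |
| 5 | Graceful Msg | Minimal | <1 ms | $0 | 100% |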
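The two model tiers differ in cost as much as in latency. A worked example using the table's per-1M-token prices, assuming an average of 450 input and 380 output tokens per message (illustrative figures, not measured MangaAssist traffic):

```python
# Per-1M-token prices in USD from the table: (input, output)
PRICES = {
    "sonnet": (3.00, 15.00),
    "haiku": (0.25, 1.25),
}


def request_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """USD cost of one request: tokens x price per token, input and output summed."""
    input_price, output_price = PRICES[model]
    return (input_tokens * input_price + output_tokens * output_price) / 1_000_000


# Assumed average message size (illustrative, not measured)
IN_TOKENS, OUT_TOKENS = 450, 380
MESSAGES_PER_DAY = 1_000_000

for model in PRICES:
    per_msg = request_cost(model, IN_TOKENS, OUT_TOKENS)
    print(f"{model}: ${per_msg:.5f}/message, ${per_msg * MESSAGES_PER_DAY:,.2f}/day")
```

At these token counts a Sonnet message costs about $0.00705 and a Haiku message about $0.00059, a 12x difference; across 1M messages/day that is roughly $7,050 versus $587.50 per day.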

Code: FallbackOrchestrator

"""
FallbackOrchestrator — Multi-tier fallback system for MangaAssist.

Implements the fallback cascade: Sonnet → Haiku → Cached → Static → Graceful.

Each tier is attempted in order. Failures at any tier are logged and
the system automatically proceeds to the next tier. The user always
receives a response — the quality may degrade but availability is maintained.
"""

import json
import time
import logging
import hashlib
from enum import Enum
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List, Callable

import boto3
import redis

logger = logging.getLogger("mangaassist.fallback")


class FallbackTier(Enum):
    """Fallback tiers in priority order."""
    SONNET = "sonnet"
    HAIKU = "haiku"
    CACHED = "cached"
    STATIC_FAQ = "static_faq"
    GRACEFUL = "graceful"


@dataclass
class FallbackResponse:
    """
    Response from the fallback orchestrator.

    Attributes:
        content: The response text to return to the user.
        tier: Which fallback tier produced the response.
        latency_ms: Time taken to produce the response.
        is_degraded: Whether the response is from a degraded tier.
        confidence: Estimated quality score (1.0 = full Sonnet, 0.1 = graceful).
        cache_hit: Whether this was served from cache.
        metadata: Additional metadata (model_id, cache_key, etc.).
    """
    content: str
    tier: FallbackTier
    latency_ms: float
    is_degraded: bool = False
    confidence: float = 1.0
    cache_hit: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class FallbackMetrics:
    """Tracks which fallback tiers are being used and how often."""
    tier_counts: Dict[str, int] = field(default_factory=lambda: {
        "sonnet": 0,
        "haiku": 0,
        "cached": 0,
        "static_faq": 0,
        "graceful": 0,
    })
    tier_latencies: Dict[str, List[float]] = field(default_factory=lambda: {
        "sonnet": [],
        "haiku": [],
        "cached": [],
        "static_faq": [],
        "graceful": [],
    })
    total_degraded: int = 0
    total_requests: int = 0


class StaticFAQStore:
    """
    Pre-computed answers for common manga-related questions.

    They serve as a reliable fallback when both Bedrock models and the
    Redis cache are unavailable. This version ships a hardcoded default
    set; in production the entries would be loaded from DynamoDB and
    cached in memory.
    """

    def __init__(self):
        self._faq_entries: Dict[str, str] = {}
        self._load_default_faqs()

    def _load_default_faqs(self):
        """Load default FAQ entries for MangaAssist."""
        self._faq_entries = {
            "shipping": (
                "Standard shipping takes 3-5 business days within Japan. "
                "International shipping is available to select countries "
                "and typically takes 7-14 business days."
            ),
            "return_policy": (
                "You can return unopened manga within 30 days of purchase "
                "for a full refund. Opened items can be exchanged if damaged."
            ),
            "payment_methods": (
                "We accept credit cards (Visa, Mastercard, JCB), "
                "PayPay, Line Pay, and convenience store payment."
            ),
            "order_tracking": (
                "You can track your order by visiting My Orders in your "
                "account. You'll also receive tracking updates via email."
            ),
            "manga_recommendation": (
                "Our most popular series right now include One Piece, "
                "Jujutsu Kaisen, Spy x Family, and Chainsaw Man. "
                "Check our trending section for more recommendations!"
            ),
            "store_hours": (
                "Our online store is available 24/7. Physical store "
                "hours are 10:00 AM - 9:00 PM JST, Monday through Sunday."
            ),
            "membership": (
                "Join our membership program for free! Members earn points "
                "on every purchase, get early access to new releases, and "
                "receive exclusive discounts."
            ),
        }

    def find_best_match(self, query: str) -> Optional[str]:
        """
        Find the best matching FAQ entry for a user query.

        Uses simple keyword matching. For production, this would use
        a lightweight embedding model or TF-IDF matching.
        """
        query_lower = query.lower()

        keyword_map = {
            "shipping": ["ship", "deliver", "delivery", "send", "arrive"],
            "return_policy": ["return", "refund", "exchange", "damaged"],
            "payment_methods": ["pay", "payment", "credit", "card", "paypay"],
            "order_tracking": ["track", "order", "where", "status"],
            "manga_recommendation": [
                "recommend", "suggest", "popular", "best", "good manga",
                "what should i read", "trending",
            ],
            "store_hours": ["hours", "open", "close", "when", "store"],
            "membership": ["member", "membership", "join", "points", "loyalty"],
        }

        best_key = None
        best_score = 0

        for faq_key, keywords in keyword_map.items():
            score = sum(1 for kw in keywords if kw in query_lower)
            if score > best_score:
                best_score = score
                best_key = faq_key

        if best_key and best_score > 0:
            return self._faq_entries.get(best_key)
        return None


class FallbackOrchestrator:
    """
    Orchestrates the multi-tier fallback cascade for MangaAssist.

    Usage:
        orchestrator = FallbackOrchestrator(
            backoff_client=exponential_backoff_client,
            redis_client=redis.Redis(...),
        )

        response = orchestrator.get_response(
            user_query="What manga do you recommend?",
            conversation_history=[...],
        )
        # response.tier tells you which tier served the response
    """

    GRACEFUL_MESSAGES = [
        (
            "I'm having a bit of trouble right now, but I'm still here! "
            "Could you try asking again in a moment? In the meantime, "
            "you can browse our manga catalog directly."
        ),
        (
            "I apologize for the wait! Our recommendation engine is taking "
            "a short break. You can check out our trending manga section "
            "while I get back up to speed."
        ),
        (
            "I'm experiencing some temporary difficulties. For immediate "
            "help, you can reach our support team at support@mangaassist.jp "
            "or browse our FAQ section."
        ),
    ]

    def __init__(
        self,
        backoff_client,  # ExponentialBackoffClient
        redis_client: redis.Redis,
        faq_store: Optional[StaticFAQStore] = None,
        cache_ttl_seconds: int = 3600,
        sonnet_model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0",
        haiku_model_id: str = "anthropic.claude-3-haiku-20240307-v1:0",
    ):
        self.backoff_client = backoff_client
        self.redis = redis_client
        self.faq_store = faq_store or StaticFAQStore()
        self.cache_ttl = cache_ttl_seconds
        self.sonnet_model_id = sonnet_model_id
        self.haiku_model_id = haiku_model_id
        self.metrics = FallbackMetrics()
        self._cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")

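    def _build_request_body(
        self, query: str, history: List[Dict], system_prompt: str
    ) -> Dict:
        """Build an Anthropic Messages API request body for Bedrock InvokeModel."""
        recent = history[-10:]  # Last 10 messages for context
        # The Messages API expects the conversation to open with a user turn,
        # so drop any leading assistant messages left over from truncation.
        while recent and recent[0]["role"] != "user":
            recent = recent[1:]

        messages = [
            {
                "role": msg["role"],
                "content": [{"type": "text", "text": msg["content"]}],
            }
            for msg in recent
        ]
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": query}],
        })

        return {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "system": system_prompt,
            "messages": messages,
        }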

    def _cache_key(self, query: str) -> str:
        """Generate a cache key for a user query."""
        normalized = query.lower().strip()
        return f"mangaassist:cache:{hashlib.sha256(normalized.encode()).hexdigest()[:16]}"

    def _try_cache(self, query: str) -> Optional[str]:
        """Try to retrieve a cached response for the query."""
        try:
            key = self._cache_key(query)
            cached = self.redis.get(key)
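            if cached:
                # redis-py returns bytes unless the client uses decode_responses=True
                return cached.decode("utf-8") if isinstance(cached, bytes) else cached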
        except Exception as e:
            logger.warning(f"Redis cache lookup failed: {e}")
        return None

    def _store_cache(self, query: str, response: str):
        """Store a response in the Redis cache."""
        try:
            key = self._cache_key(query)
            self.redis.setex(key, self.cache_ttl, response)
        except Exception as e:
            logger.warning(f"Redis cache store failed: {e}")

    def get_response(
        self,
        user_query: str,
        conversation_history: Optional[List[Dict]] = None,
        system_prompt: Optional[str] = None,
    ) -> FallbackResponse:
        """
        Get a response using the fallback cascade.

        Tries each tier in order until one succeeds.
        """
        self.metrics.total_requests += 1
        history = conversation_history or []
        sys_prompt = system_prompt or (
            "You are MangaAssist, a helpful assistant for a Japanese manga store. "
            "Help customers find manga, answer questions about orders, and provide "
            "recommendations based on their preferences."
        )

        body = self._build_request_body(user_query, history, sys_prompt)

        # Tier 1: Claude 3 Sonnet (full capability)
        try:
            start = time.monotonic()
            result = self.backoff_client.invoke_model(
                model_id=self.sonnet_model_id, body=body
            )
            elapsed = (time.monotonic() - start) * 1000
            content = result["content"][0]["text"]

            self._store_cache(user_query, content)
            self._record_tier("sonnet", elapsed)

            return FallbackResponse(
                content=content,
                tier=FallbackTier.SONNET,
                latency_ms=elapsed,
                confidence=1.0,
                metadata={"model_id": self.sonnet_model_id},
            )
        except Exception as e:
            logger.warning(f"Sonnet tier failed: {e}")

        # Tier 2: Claude 3 Haiku (reduced capability, faster/cheaper)
        try:
            start = time.monotonic()
            result = self.backoff_client.invoke_model(
                model_id=self.haiku_model_id, body=body
            )
            elapsed = (time.monotonic() - start) * 1000
            content = result["content"][0]["text"]

            self._store_cache(user_query, content)
            self._record_tier("haiku", elapsed)
            self.metrics.total_degraded += 1

            return FallbackResponse(
                content=content,
                tier=FallbackTier.HAIKU,
                latency_ms=elapsed,
                is_degraded=True,
                confidence=0.7,
                metadata={"model_id": self.haiku_model_id},
            )
        except Exception as e:
            logger.warning(f"Haiku tier failed: {e}")

        # Tier 3: Redis cached response
        start = time.monotonic()
        cached_content = self._try_cache(user_query)
        elapsed = (time.monotonic() - start) * 1000

        if cached_content:
            self._record_tier("cached", elapsed)
            self.metrics.total_degraded += 1
            return FallbackResponse(
                content=cached_content,
                tier=FallbackTier.CACHED,
                latency_ms=elapsed,
                is_degraded=True,
                confidence=0.5,
                cache_hit=True,
                metadata={"cache_key": self._cache_key(user_query)},
            )

        # Tier 4: Static FAQ match
        start = time.monotonic()
        faq_answer = self.faq_store.find_best_match(user_query)
        elapsed = (time.monotonic() - start) * 1000

        if faq_answer:
            self._record_tier("static_faq", elapsed)
            self.metrics.total_degraded += 1
            return FallbackResponse(
                content=faq_answer,
                tier=FallbackTier.STATIC_FAQ,
                latency_ms=elapsed,
                is_degraded=True,
                confidence=0.3,
                metadata={"source": "faq_store"},
            )

        # Tier 5: Graceful degradation message
        import random
        graceful_msg = random.choice(self.GRACEFUL_MESSAGES)
        self._record_tier("graceful", 0)
        self.metrics.total_degraded += 1

        return FallbackResponse(
            content=graceful_msg,
            tier=FallbackTier.GRACEFUL,
            latency_ms=0,
            is_degraded=True,
            confidence=0.1,
            metadata={"message_variant": self.GRACEFUL_MESSAGES.index(graceful_msg)},
        )

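    def _record_tier(self, tier: str, latency_ms: float):
        """Record tier usage for metrics."""
        self.metrics.tier_counts[tier] += 1
        latencies = self.metrics.tier_latencies[tier]
        latencies.append(latency_ms)
        # Keep only recent samples so a long-running Fargate task doesn't
        # grow these lists without bound at 1M messages/day
        if len(latencies) > 1000:
            del latencies[:-1000]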
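`StaticFAQStore.find_best_match` scores keywords with substring tests (`kw in query_lower`), so "store" also fires on "restore" and "open" on "opened"; a query such as "I want to restore my account" gets the store-hours answer. A minimal sketch of whole-word matching follows. It is still keyword scoring, not the embedding or TF-IDF matching the docstring suggests for production, and `match_faq_key` is a hypothetical helper that returns the FAQ key rather than the answer text.

```python
import re
from typing import Dict, List, Optional


def match_faq_key(query: str, keyword_map: Dict[str, List[str]]) -> Optional[str]:
    """Return the FAQ key whose keywords match the most whole words, or None."""
    # Lowercase word tokens re-joined with single spaces and padded at both
    # ends, so " kw " only matches a whole word (or a whole multi-word phrase)
    padded = " " + " ".join(re.findall(r"[a-z0-9]+", query.lower())) + " "
    best_key, best_score = None, 0
    for faq_key, keywords in keyword_map.items():
        score = sum(1 for kw in keywords if f" {kw} " in padded)
        if score > best_score:  # ties keep the earlier key, as find_best_match does
            best_key, best_score = faq_key, score
    return best_key


keywords = {
    "store_hours": ["hours", "open", "close", "store"],
    "manga_recommendation": ["recommend", "what should i read"],
}
print(match_faq_key("I want to restore my account", keywords))   # None
print(match_faq_key("When does the store open?", keywords))      # store_hours
print(match_faq_key("So, what should I read next?", keywords))   # manga_recommendation
```

Returning `None` on the first query lets the cascade fall through to the graceful message instead of serving an unrelated FAQ answer.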

4. X-Ray Tracing Across Service Boundaries

X-Ray Trace Flow for MangaAssist

                           Trace ID: 1-5759e988-bd862e3fe1be46a994272793
    ┌─────────────────────────────────────────────────────────────────┐
    │ Segment: API Gateway                                           │
    │ ├── Duration: 15ms                                             │
    │ ├── Subsegment: WebSocket $default route                       │
    │ │   └── Annotations: route=$default, connection_id=abc123      │
    │ └── HTTP: 200                                                  │
    ├─────────────────────────────────────────────────────────────────┤
    │ Segment: ECS Fargate (FM Orchestrator)                         │
    │ ├── Duration: 2150ms                                           │
    │ ├── Subsegment: Rate Limit Check                               │
    │ │   ├── Duration: 2ms                                          │
    │ │   └── Annotations: tier=user, allowed=true                   │
    │ ├── Subsegment: OpenSearch Vector Search                       │
    │ │   ├── Duration: 45ms                                         │
    │ │   └── Annotations: index=manga-embeddings, hits=5            │
    │ ├── Subsegment: Bedrock InvokeModel                            │
    │ │   ├── Duration: 2050ms                                       │
    │ │   ├── Annotations: model=claude-3-sonnet, fallback_tier=1    │
    │ │   ├── Metadata: input_tokens=450, output_tokens=380          │
    │ │   └── HTTP: 200                                              │
    │ └── Subsegment: Redis Cache Store                              │
    │     ├── Duration: 3ms                                          │
    │     └── Annotations: cache_key=abc123, ttl=3600                │
    ├─────────────────────────────────────────────────────────────────┤
    │ Segment: Bedrock Runtime                                       │
    │ ├── Duration: 2000ms                                           │
    │ └── HTTP: 200                                                  │
    └─────────────────────────────────────────────────────────────────┘
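The trace also shows how little room the 3-second target leaves. Only the top-level segments are summed below: the ECS segment already contains the rate-limit, OpenSearch, Bedrock and Redis subsegments, and the Bedrock Runtime segment overlaps the Bedrock call. The sketch assumes the API Gateway and ECS segments run back to back.

```python
TARGET_MS = 3_000  # "useful answer in under 3 seconds"


def remaining_budget_ms(top_level_segments_ms: dict, target_ms: int = TARGET_MS) -> int:
    """Milliseconds left in the latency target after the given segments."""
    return target_ms - sum(top_level_segments_ms.values())


# Top-level segment durations from the example trace (ms)
budget = remaining_budget_ms({"api_gateway": 15, "ecs_orchestrator": 2150})

SONNET_CALL_MS = 2050   # Bedrock subsegment in the example trace
HAIKU_UPPER_MS = 1000   # slow end of Haiku's 0.3-1 s range in the fallback table

print(budget)                     # 835
print(budget >= SONNET_CALL_MS)   # False
print(budget >= HAIKU_UPPER_MS)   # False
```

If Sonnet fails only after consuming a full call's worth of time, the 835 ms left cannot cover a Sonnet retry and may not cover a Haiku call at the slow end of its range, so the primary tier needs a short timeout for the cascade to stay inside the target.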

X-Ray Service Map

flowchart LR
    Client["Client<br/>1M msgs/day"] --> APIGW["API Gateway<br/>avg 15ms"]
    APIGW --> ECS["ECS Fargate<br/>avg 2.1s"]
    ECS --> Bedrock["Bedrock<br/>avg 2.0s"]
    ECS --> OpenSearch["OpenSearch<br/>avg 45ms"]
    ECS --> Redis["ElastiCache<br/>avg 3ms"]
    ECS --> DDB["DynamoDB<br/>avg 8ms"]

    style Client fill:#e3f2fd
    style APIGW fill:#fff3e0
    style ECS fill:#e8f5e9
    style Bedrock fill:#f3e5f5
    style OpenSearch fill:#fce4ec
    style Redis fill:#fff8e1
    style DDB fill:#e0f7fa
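At the map's 1M msgs/day, a sampling rule with a 1 trace/second reservoir plus a 5% fixed rate (the values `create_sampling_rule` uses in the code below) samples noticeably more than 5% of requests, because the reservoir alone admits up to 86,400 traces a day. A rough estimate, assuming perfectly uniform traffic and a reservoir shared across the whole service; real traffic is bursty, so the reservoir's share shrinks at peaks.

```python
SECONDS_PER_DAY = 86_400


def sampled_traces_per_day(requests_per_day: float, reservoir_per_s: float,
                           fixed_rate: float) -> float:
    """Expected traces/day for one sampling rule under perfectly uniform traffic.

    Each second the reservoir takes up to `reservoir_per_s` requests, and the
    fixed rate applies to whatever is left in that second.
    """
    per_second = requests_per_day / SECONDS_PER_DAY
    reserved = min(per_second, reservoir_per_s)
    return (reserved + fixed_rate * (per_second - reserved)) * SECONDS_PER_DAY


daily = sampled_traces_per_day(1_000_000, reservoir_per_s=1, fixed_rate=0.05)
print(round(daily))                 # 132080
print(f"{daily / 1_000_000:.1%}")   # 13.2%
```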

Code: XRayFMTracer

"""
XRayFMTracer — AWS X-Ray instrumentation for MangaAssist FM calls.

Provides distributed tracing across:
  API Gateway → ECS Fargate → Bedrock / OpenSearch / Redis / DynamoDB

Key annotations:
  - model_id: Which Bedrock model was invoked
  - fallback_tier: Which tier of the fallback cascade served the response
  - input_tokens / output_tokens: Token usage for cost tracking
  - cache_hit: Whether the response came from cache
  - error_code: Bedrock error code (if any)
  - user_id: Anonymized user identifier for per-user trace analysis

Subsegments are created for each downstream service call, enabling
the X-Ray service map to visualize the full request flow.
"""

import json
import time
import logging
import functools
from typing import Optional, Dict, Any, Callable
from contextlib import contextmanager

from aws_xray_sdk.core import xray_recorder, patch_all
from aws_xray_sdk.core.models.subsegment import Subsegment
import boto3

logger = logging.getLogger("mangaassist.xray")

# Patch all supported libraries (boto3, requests, etc.)
patch_all()


class XRayFMTracer:
    """
    X-Ray tracing wrapper specialized for Foundation Model calls.

    Creates rich subsegments with FM-specific annotations and metadata
    that enable:
      - Service map visualization of the full MangaAssist request flow
      - Latency breakdown by service (Bedrock, OpenSearch, Redis, DynamoDB)
      - Error analysis with Bedrock-specific error codes
      - Cost attribution via token tracking
      - Fallback cascade visibility

    Usage:
        tracer = XRayFMTracer(service_name="MangaAssist")

        with tracer.trace_fm_call(
            model_id="anthropic.claude-3-sonnet-20240229-v1:0",
            user_id="user_abc123",
        ) as trace:
            response = bedrock_client.invoke_model(...)
            trace.record_success(
                input_tokens=450,
                output_tokens=380,
                fallback_tier="sonnet",
            )
    """

    def __init__(
        self,
        service_name: str = "MangaAssist",
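        sampling_rate: float = 0.05,  # Fixed rate applied by create_sampling_rule()
    ):
        self.service_name = service_name
        self.sampling_rate = sampling_rate

        # Configure the X-Ray recorder. sampling=True defers to X-Ray sampling
        # rules; sampling_rate takes effect only once create_sampling_rule()
        # has registered the centralized rules.
        xray_recorder.configure(
            service=service_name,
            sampling=True,
            context_missing="LOG_ERROR",
            daemon_address="127.0.0.1:2000",
        )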

    @contextmanager
    def trace_fm_call(
        self,
        model_id: str,
        user_id: str,
        operation: str = "InvokeModel",
    ):
        """
        Context manager that creates an X-Ray subsegment for an FM call.

        Args:
            model_id: Bedrock model identifier.
            user_id: Anonymized user identifier.
            operation: Bedrock API operation name.

        Yields:
            FMTraceContext with methods to record success/failure.
        """
        subsegment = xray_recorder.begin_subsegment(
            name=f"Bedrock.{operation}",
            namespace="aws",
        )

        trace_ctx = FMTraceContext(
            subsegment=subsegment,
            model_id=model_id,
            user_id=user_id,
            start_time=time.monotonic(),
        )

        # Set initial annotations
        if subsegment:
            subsegment.put_annotation("model_id", model_id)
            subsegment.put_annotation("user_id", user_id)
            subsegment.put_annotation("operation", operation)
            subsegment.put_annotation("service", self.service_name)

        try:
            yield trace_ctx
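        except Exception as e:
            # botocore's ClientError carries the service error code in e.response
            err = getattr(e, "response", None)
            code = (
                err.get("Error", {}).get("Code", "Unknown")
                if isinstance(err, dict)
                else "Unknown"
            )
            trace_ctx.record_failure(
                str(e),
                error_code=code,
                is_throttle=(code == "ThrottlingException"),
            )
            raise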
        finally:
            if subsegment:
                elapsed_ms = (time.monotonic() - trace_ctx.start_time) * 1000
                subsegment.put_metadata(
                    "latency_ms", round(elapsed_ms, 1), self.service_name
                )
                xray_recorder.end_subsegment()

    @contextmanager
    def trace_service_call(
        self,
        service_name: str,
        operation: str,
        namespace: str = "aws",
    ):
        """
        Context manager for tracing calls to other AWS services
        (OpenSearch, Redis, DynamoDB).
        """
        subsegment = xray_recorder.begin_subsegment(
            name=f"{service_name}.{operation}",
            namespace=namespace,
        )

        start_time = time.monotonic()

        try:
            yield subsegment
        except Exception as e:
            if subsegment:
                subsegment.add_exception(e, stack=True)
            raise
        finally:
            if subsegment:
                elapsed_ms = (time.monotonic() - start_time) * 1000
                subsegment.put_metadata(
                    "latency_ms", round(elapsed_ms, 1), service_name
                )
                xray_recorder.end_subsegment()

    def trace_function(
        self, name: Optional[str] = None, capture_args: bool = False
    ) -> Callable:
        """
        Decorator to trace a function as an X-Ray subsegment.

        Args:
            name: Subsegment name (defaults to function name).
            capture_args: Whether to capture function arguments as metadata.
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                seg_name = name or func.__name__
                subsegment = xray_recorder.begin_subsegment(seg_name)

                if subsegment and capture_args:
                    subsegment.put_metadata(
                        "args", str(args)[:500], self.service_name
                    )
                    subsegment.put_metadata(
                        "kwargs", str(kwargs)[:500], self.service_name
                    )

                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if subsegment:
                        subsegment.add_exception(e, stack=True)
                    raise
                finally:
                    # Only end a subsegment this wrapper actually opened
                    if subsegment:
                        xray_recorder.end_subsegment()

            return wrapper
        return decorator

    def create_sampling_rule(self) -> Dict[str, Any]:
        """
        Create X-Ray sampling rules for MangaAssist.

        At 1M messages/day, we use reservoir-based sampling:
          - 1 trace/second guaranteed (reservoir)
          - 5% of remaining traffic sampled (fixed rate)
          - 100% of requests whose URL path matches */error* (separate rule)

        Sampling is decided when a request starts, so no rule can select
        traces by outcome; failures in unsampled requests are visible only
        through CloudWatch metrics and logs, not X-Ray.
        """
        xray_client = boto3.client("xray", region_name="ap-northeast-1")

        # Main sampling rule: 1 trace/second reservoir plus 5% of the rest
        main_rule = xray_client.create_sampling_rule(
            SamplingRule={
                "RuleName": "MangaAssist-Default",
                "Priority": 100,
                "FixedRate": self.sampling_rate,
                "ReservoirSize": 1,  # 1 trace/second guaranteed
                "ServiceName": self.service_name,
                "ServiceType": "AWS::ECS::Container",
                "Host": "*",
                "HTTPMethod": "*",
                "URLPath": "*",
                "ResourceARN": "*",
                "Version": 1,
            }
        )

        # Error-path rule: samples every request to an */error* URL path
        error_rule = xray_client.create_sampling_rule(
            SamplingRule={
                "RuleName": "MangaAssist-Errors",
                "Priority": 50,  # Lower number = evaluated before the default rule
                "FixedRate": 1.0,  # Sample every matching request
                "ReservoirSize": 10,
                "ServiceName": self.service_name,
                "ServiceType": "AWS::ECS::Container",
                "Host": "*",
                "HTTPMethod": "*",
                "URLPath": "*/error*",
                "ResourceARN": "*",
                "Version": 1,
            }
        )

        return {
            "main_rule": main_rule,
            "error_rule": error_rule,
        }

    def get_trace_summary(self, trace_id: str) -> Dict[str, Any]:
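        """
        Retrieve a trace summary for debugging and analysis.

        Only top-level segment annotations are reported; annotations written
        on subsegments (as trace_fm_call does) sit under each segment
        document's "subsegments" key.
        """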
        xray_client = boto3.client("xray", region_name="ap-northeast-1")

        response = xray_client.batch_get_traces(TraceIds=[trace_id])

        if not response.get("Traces"):
            return {"error": f"Trace {trace_id} not found"}

        trace = response["Traces"][0]
        segments = []

        for segment in trace.get("Segments", []):
            doc = json.loads(segment["Document"])
            segments.append({
                "name": doc.get("name"),
                "duration_ms": (
                    (doc.get("end_time", 0) - doc.get("start_time", 0)) * 1000
                ),
                "annotations": doc.get("annotations", {}),
                "error": doc.get("error", False),
                "fault": doc.get("fault", False),
            })

        return {
            "trace_id": trace_id,
            "duration_ms": trace.get("Duration", 0) * 1000,
            "segments": segments,
        }


class FMTraceContext:
    """
    Context object yielded by trace_fm_call for recording FM-specific data.
    """

    def __init__(
        self,
        subsegment: Optional[Subsegment],
        model_id: str,
        user_id: str,
        start_time: float,
    ):
        self.subsegment = subsegment
        self.model_id = model_id
        self.user_id = user_id
        self.start_time = start_time

    def record_success(
        self,
        input_tokens: int = 0,
        output_tokens: int = 0,
        fallback_tier: str = "sonnet",
        cache_hit: bool = False,
    ):
        """Record a successful FM invocation."""
        if not self.subsegment:
            return

        self.subsegment.put_annotation("fallback_tier", fallback_tier)
        self.subsegment.put_annotation("cache_hit", cache_hit)
        self.subsegment.put_annotation("status", "success")

        self.subsegment.put_metadata("input_tokens", input_tokens, "fm_usage")
        self.subsegment.put_metadata("output_tokens", output_tokens, "fm_usage")
        self.subsegment.put_metadata(
            "estimated_cost_usd",
            self._estimate_cost(input_tokens, output_tokens),
            "fm_usage",
        )

    def record_failure(
        self,
        error_message: str,
        error_code: str = "Unknown",
        is_throttle: bool = False,
    ):
        """Record a failed FM invocation."""
        if not self.subsegment:
            return

        self.subsegment.put_annotation("status", "error")
        self.subsegment.put_annotation("error_code", error_code)
        self.subsegment.put_annotation("is_throttle", is_throttle)
        self.subsegment.put_metadata("error_message", error_message, "errors")
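        # X-Ray convention: client errors set "error"; throttles (HTTP 429) also set "throttle".
        self.subsegment.add_error_flag()
        if is_throttle:
            self.subsegment.add_throttle_flag()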

    def _estimate_cost(self, input_tokens: int, output_tokens: int) -> float:
        """Estimate the cost of the FM call based on model pricing."""
        pricing = {
            "anthropic.claude-3-sonnet-20240229-v1:0": {
                "input_per_1m": 3.0,
                "output_per_1m": 15.0,
            },
            "anthropic.claude-3-haiku-20240307-v1:0": {
                "input_per_1m": 0.25,
                "output_per_1m": 1.25,
            },
        }

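        # Unknown model IDs fall back to Sonnet pricing, the higher of the two rates,
        # so the estimate errs on the side of overstating cost.
        model_pricing = pricing.get(self.model_id, {
            "input_per_1m": 3.0,
            "output_per_1m": 15.0,
        })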

        input_cost = (input_tokens / 1_000_000) * model_pricing["input_per_1m"]
        output_cost = (output_tokens / 1_000_000) * model_pricing["output_per_1m"]
        return round(input_cost + output_cost, 6)
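The trace-summary method above reads `annotations` only from each top-level segment document, while `FMTraceContext` writes its annotations on a subsegment, which X-Ray returns nested under the parent document's `subsegments` key (possibly several levels deep). A minimal sketch, assuming the standard X-Ray segment-document fields (`name`, `start_time`, `end_time`, `annotations`, `error`, `throttle`, `subsegments`); `iter_subsegments` and `collect_fm_annotations` are hypothetical helpers, not X-Ray SDK functions.

```python
import json
from typing import Any, Dict, Iterator, List, Optional


def iter_subsegments(doc: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
    """Depth-first walk over every subsegment nested inside a segment document."""
    for sub in doc.get("subsegments", []):
        yield sub
        yield from iter_subsegments(sub)


def collect_fm_annotations(segment_documents: List[str]) -> List[Dict[str, Any]]:
    """Summarize every annotated subsegment.

    segment_documents: the raw "Document" JSON strings from a BatchGetTraces response.
    """
    results: List[Dict[str, Any]] = []
    for raw in segment_documents:
        for sub in iter_subsegments(json.loads(raw)):
            annotations = sub.get("annotations")
            if not annotations:
                continue  # skip subsegments that carry no annotations
            start: Optional[float] = sub.get("start_time")
            end: Optional[float] = sub.get("end_time")
            results.append({
                "name": sub.get("name"),
                # In-progress subsegments have no end_time yet; report None, not a negative value.
                "duration_ms": (end - start) * 1000 if start is not None and end is not None else None,
                "annotations": annotations,
                "error": sub.get("error", False),
                "throttle": sub.get("throttle", False),
            })
    return results
```

Filtering the result on `annotations["fallback_tier"]` or `annotations["is_throttle"]` gives a per-trace view of which tier answered and whether Bedrock throttled the call.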

Summary: Resilience Layer Integration

flowchart TB
    subgraph ResilienceLayers["MangaAssist Resilience Stack"]
        direction TB
        L1["Layer 1: Edge Protection<br/>WAF + CloudFront + API GW Throttle"]
        L2["Layer 2: Rate Limiting<br/>Redis Token Bucket + Usage Plans"]
        L3["Layer 3: Retry + Backoff<br/>Exponential Backoff + Jitter"]
        L4["Layer 4: Circuit Breaker<br/>Fast-Fail on Persistent Failures"]
        L5["Layer 5: Fallback Cascade<br/>Sonnet → Haiku → Cache → FAQ → Graceful"]
        L6["Layer 6: Observability<br/>X-Ray + CloudWatch + Alarms"]

        L1 --> L2 --> L3 --> L4 --> L5
        L6 -.->|monitors| L1
        L6 -.->|monitors| L2
        L6 -.->|monitors| L3
        L6 -.->|monitors| L4
        L6 -.->|monitors| L5
    end

    style L1 fill:#fff3e0,stroke:#e65100
    style L2 fill:#fce4ec,stroke:#c62828
    style L3 fill:#fff9c4,stroke:#f57f17
    style L4 fill:#e8f5e9,stroke:#2e7d32
    style L5 fill:#e3f2fd,stroke:#1565c0
    style L6 fill:#f3e5f5,stroke:#6a1b9a
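Layer 5 of the stack, the fallback cascade, can be sketched as an ordered list of tiers tried until one produces an answer. This is a minimal sketch under stated assumptions: each tier is a hypothetical callable (standing in for the Bedrock Sonnet and Haiku calls, the Redis cache lookup, and the FAQ match) that returns `None` or raises on failure, and the weights 1.0 to 0.1 from the quick-reference card are treated as relative quality scores. `run_cascade`, `FallbackResult`, and `GRACEFUL_MESSAGE` are illustrative names, not part of any AWS SDK.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

# A tier is (name, relative quality, handler). Handlers return None or raise on failure.
Tier = Tuple[str, float, Callable[[str], Optional[str]]]

# Hypothetical static text for the last-resort tier.
GRACEFUL_MESSAGE = "Sorry, our assistant is busy right now. Please try again in a moment."


@dataclass
class FallbackResult:
    tier: str
    quality: float
    answer: str
    attempted: List[str]


def run_cascade(question: str, tiers: List[Tier]) -> FallbackResult:
    """Try each tier in order; the final static tier cannot fail."""
    attempted: List[str] = []
    for name, quality, handler in tiers:
        attempted.append(name)
        try:
            answer = handler(question)
        except Exception:
            answer = None  # any tier error counts as a miss, so fall through to the next tier
        if answer:
            return FallbackResult(name, quality, answer, attempted)
    attempted.append("graceful")
    return FallbackResult("graceful", 0.1, GRACEFUL_MESSAGE, attempted)
```

Because the last tier is static, the caller always gets a response; `tier` and `attempted` are the values one would record as the `fallback_tier` annotation and in CloudWatch metrics.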

Key Design Decisions

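| Decision | Choice | Rationale |
|---|---|---|
| Jitter Strategy | Decorrelated jitter | Spreads retries from many ECS tasks calling the same Bedrock endpoint, avoiding synchronized retry waves under high contention |
| Circuit Breaker Threshold | 5 consecutive failures | Balances sensitivity to real outages against false trips on isolated errors |
| Recovery Timeout | 30 seconds | Gives transient Bedrock issues time to clear before a half-open probe |
| Sampling Rate | 5% (normal), 100% (errors) | Keeps tracing cost manageable at 1M messages/day while capturing every error |
| Cache TTL | 1 hour | Manga catalog data changes infrequently; pricing updates are handled separately |
| Fallback Tiers | 5 levels | Keeps the chatbot answering, at reduced quality, even during a complete Bedrock outage |
| Boto3 Retry Mode | Adaptive | Adds a client-side token bucket that slows the send rate after throttling responses, so fewer requests are sent only to be throttled |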
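The circuit-breaker rows of the table (5 consecutive failures, 30-second recovery timeout) map onto a three-state machine. A minimal sketch, assuming one breaker per process (for example per ECS task) and an injectable clock for testing; the `CircuitBreaker` class and its method names are hypothetical, and a fleet-wide breaker would need shared state, for example in Redis.

```python
import time
from typing import Callable


class CircuitBreaker:
    """CLOSED -> OPEN after N consecutive failures; OPEN -> HALF_OPEN after a timeout."""

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0,
                 clock: Callable[[], float] = time.monotonic):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.clock = clock
        self.state = "CLOSED"
        self.consecutive_failures = 0
        self.opened_at = 0.0
        self.probe_in_flight = False

    def allow_request(self) -> bool:
        """Return True if the caller may attempt the protected call (e.g. Bedrock)."""
        if self.state == "OPEN":
            if self.clock() - self.opened_at < self.recovery_timeout:
                return False  # fast-fail: skip retries and go to a fallback tier
            self.state = "HALF_OPEN"
            self.probe_in_flight = False
        if self.state == "HALF_OPEN":
            if self.probe_in_flight:
                return False  # only one probe request at a time
            self.probe_in_flight = True
        return True

    def record_success(self) -> None:
        # Any success (including the half-open probe) closes the circuit and resets the count.
        self.state = "CLOSED"
        self.consecutive_failures = 0
        self.probe_in_flight = False

    def record_failure(self) -> None:
        self.consecutive_failures += 1
        # A failed probe reopens immediately; otherwise open once the threshold is reached.
        if self.state == "HALF_OPEN" or self.consecutive_failures >= self.failure_threshold:
            self.state = "OPEN"
            self.opened_at = self.clock()
            self.probe_in_flight = False
```

A caller checks `allow_request()` before invoking Bedrock; when it returns `False`, the request skips retries entirely, matching the "Stops retries, jumps to cache fallback" behavior in the quick-reference card.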

Quick Reference Card

EXPONENTIAL BACKOFF:
  Formula (decorrelated): min(cap, random(base, prev_delay * 3))
  Base: 100ms | Cap: 10s | Max retries: 5
  Boto3 mode: adaptive (client-side token bucket)

RATE LIMITING:
  Global: 2,000 rps (burst 4,000) via Redis token bucket
  Per-user: 5 rps (burst 10), premium 3x multiplier
  API GW: Usage plan with stage + route throttles

FALLBACK CASCADE:
  Sonnet (1.0) → Haiku (0.7) → Cache (0.5) → FAQ (0.3) → Graceful (0.1)  (relative quality)
  Each fallback hop adds ~10ms of decision overhead, on top of the time spent on the failed tier

CIRCUIT BREAKER:
  Threshold: 5 consecutive failures → OPEN
  Recovery: 30s timeout → HALF_OPEN → 1 probe → CLOSED (probe succeeds) / OPEN (probe fails)
  Integration: Stops retries, jumps to cache fallback

X-RAY TRACING:
  Sampling: 5% normal, 100% errors
  Annotations: model_id, fallback_tier, cache_hit, status, error_code, is_throttle
  Metadata: input_tokens, output_tokens, estimated_cost_usd (namespace fm_usage); error_message (namespace errors)
  Service map: APIGW → ECS → Bedrock/OpenSearch/Redis/DDB
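The backoff formula and per-user rate limits in the card can be made concrete in a few lines. A minimal sketch, assuming the card's parameters (base 100 ms, cap 10 s, 5 retries; per-user 5 rps with burst 10); `decorrelated_jitter_delays`, `TokenBucket`, and `per_user_bucket` are hypothetical names, and the in-memory bucket is a single-process stand-in for the Redis-backed one, which would need an atomic operation (for example a Lua script) to share state across ECS tasks.

```python
import random
import time
from typing import Callable, List


def decorrelated_jitter_delays(
    base: float = 0.1,        # seconds (100 ms)
    cap: float = 10.0,        # seconds
    max_retries: int = 5,
    rand: Callable[[float, float], float] = random.uniform,
) -> List[float]:
    """Retry sleeps following sleep = min(cap, uniform(base, prev_sleep * 3))."""
    delays: List[float] = []
    prev = base  # seed with base, so the first draw is uniform(base, 3 * base)
    for _ in range(max_retries):
        prev = min(cap, rand(base, prev * 3))
        delays.append(prev)
    return delays


class TokenBucket:
    """In-memory token bucket: refills `rate` tokens per second, holds at most `burst`."""

    def __init__(self, rate: float, burst: float,
                 clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.burst = burst
        self.clock = clock
        self.tokens = burst          # start full so a new client can burst immediately
        self.updated = clock()

    def try_acquire(self, cost: float = 1.0) -> bool:
        now = self.clock()
        # Refill in proportion to elapsed time, never above the burst size.
        self.tokens = min(self.burst, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= cost:
            self.tokens -= cost
            return True
        return False                 # caller should reject or queue the request


def per_user_bucket(premium: bool, clock: Callable[[], float] = time.monotonic) -> TokenBucket:
    # Assumption: the 3x premium multiplier scales both the rate and the burst.
    multiplier = 3 if premium else 1
    return TokenBucket(rate=5 * multiplier, burst=10 * multiplier, clock=clock)
```

If boto3's built-in retries are also enabled, each application-level retry wraps the SDK's own attempts, so the total number of calls to Bedrock multiplies.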