
FM API Interface Architecture for GenAI Workloads

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.


Skill Mapping

Certification: AWS AIP-C01
Domain: 2 — Implementation & Integration
Task: 2.5 — Application Integration Patterns
Skill: 2.5.1 — Create FM API interfaces to address the specific requirements of GenAI workloads
Focus Areas: API Gateway for streaming responses, token limit management, retry strategies for model timeouts
MangaAssist Relevance: WebSocket streaming from Bedrock, per-request token budgets, exponential backoff with circuit breakers for Claude 3 timeouts

FM API Interface Mind Map

mindmap
  root((FM API Interfaces<br/>for GenAI Workloads))
    Streaming Responses
      API Gateway WebSocket
        $connect / $disconnect / $default routes
        10-minute idle timeout
        32KB frame limit
        Binary and text frames
      Bedrock InvokeModelWithResponseStream
        Server-sent event chunks
        content_block_delta parsing
        message_start / message_stop lifecycle
      Stream Relay Pipeline
        First-token latency tracking
        Progressive rendering to client
        Backpressure and flow control
        Stream cancellation on disconnect
      REST Fallback
        29-second hard timeout constraint
        Synchronous invoke for short queries
        Health and admin endpoints
    Token Limit Management
      Input Token Budgets
        Prompt template overhead reservation
        Context window allocation per model
        Japanese text token estimation 1-2 chars per token
        Truncation vs summarization strategies
      Output Token Caps
        max_tokens parameter enforcement
        Dynamic cap based on input size
        Cost guardrails per request
      Per-Session Accounting
        Cumulative token counters
        Daily budget enforcement
        Alert thresholds and circuit breaking
      Model-Specific Limits
        Sonnet 200K context window
        Haiku 200K context window
        Safety margin allocation
    Retry Strategies
      Exponential Backoff
        Base delay with jitter
        Max retry ceiling
        Retry budget per time window
      Circuit Breaker Pattern
        Closed / Open / Half-Open states
        Failure rate threshold
        Recovery probe interval
      Timeout Hierarchy
        Client timeout 10s streaming start
        API Gateway 29s REST hard limit
        Orchestrator 25s Bedrock call
        Bedrock 60s model timeout
      Idempotency
        Content-hash dedup keys
        5-second dedup window
        Response caching for replays
    Request Response Transformation
      Payload Normalization
        Client JSON to Bedrock Messages API
        System prompt injection
        Conversation history assembly
      Response Enrichment
        Token usage metadata
        Cost attribution tags
        Latency metrics injection
      Error Mapping
        Bedrock errors to client-friendly messages
        Throttling responses with retry-after
        Model-specific error codes
    API Gateway Configuration
      WebSocket API
        Route selection expression
        Lambda authorizer integration
        Connection management Lambda
        Usage plans and throttling
      REST API
        Integration timeout configuration
        Request validation models
        Response mapping templates
        Stage variables for multi-env

Architecture Overview

The FM API Interface layer is the critical boundary between MangaAssist client applications and Amazon Bedrock foundation models. It handles five responsibilities that are specific to GenAI workloads and rarely appear in traditional API designs:

  1. Streaming relay — Token-by-token delivery from Bedrock through WebSocket to the browser
  2. Token budget enforcement — Preventing cost explosions at 1M messages/day scale
  3. Retry with idempotency — Surviving model timeouts without duplicate billing
  4. Request transformation — Converting simple client messages into Claude 3 Messages API payloads
  5. Latency management — Keeping perceived response time under 3 seconds via streaming

MangaAssist API Interface Layer — Architecture Flowchart

graph TB
    subgraph Clients["Client Applications"]
        WEB[Web App<br/>React on Amplify]
        MOB[Mobile App<br/>React Native]
        PARTNER[Partner SDK<br/>REST Client]
    end

    subgraph APIGW["API Gateway Layer"]
        WS_API[WebSocket API<br/>wss://manga-api.example.com]
        REST_API[REST API<br/>https://manga-api.example.com]
        AUTHORIZER[Lambda Authorizer<br/>JWT + API Key Validation]
        THROTTLE[Throttling Engine<br/>100 req/s per connection<br/>10,000 req/s account]
        ROUTES[Route Dispatcher<br/>$connect / $default / $disconnect]
    end

    subgraph Interface["FM API Interface Layer — ECS Fargate"]
        FM_INTERFACE[FMAPIInterface<br/>Central Coordinator]
        TOKEN_ENFORCER[TokenLimitEnforcer<br/>Budget Validation]
        RETRY_MGR[RetryStrategyManager<br/>Backoff + Circuit Breaker]
        STREAM_HANDLER[StreamingAPIHandler<br/>Chunk Relay + Metrics]
        TRANSFORMER[PayloadTransformer<br/>Client to Bedrock Format]
        IDEM[IdempotencyGuard<br/>Dedup + Response Cache]
    end

    subgraph Backend["Backend Services"]
        BEDROCK_S[Bedrock Claude 3 Sonnet<br/>Complex Recommendations]
        BEDROCK_H[Bedrock Claude 3 Haiku<br/>Simple Q&A]
        DYNAMO[DynamoDB<br/>Sessions + Products + Connections]
        OPENSEARCH[OpenSearch Serverless<br/>Manga Vector Store]
        REDIS[ElastiCache Redis<br/>Cache + Rate Limits + Dedup]
    end

    subgraph Monitoring["Observability"]
        CW[CloudWatch Metrics<br/>TTFT / TPS / Error Rate]
        XRAY[X-Ray Traces<br/>End-to-End Latency]
        ALARM[CloudWatch Alarms<br/>P95 TTFT > 800ms]
    end

    WEB -->|wss://| WS_API
    MOB -->|wss://| WS_API
    PARTNER -->|HTTPS POST| REST_API

    WS_API --> AUTHORIZER
    REST_API --> AUTHORIZER
    AUTHORIZER --> THROTTLE
    THROTTLE --> ROUTES
    ROUTES --> FM_INTERFACE

    FM_INTERFACE --> TOKEN_ENFORCER
    FM_INTERFACE --> RETRY_MGR
    FM_INTERFACE --> STREAM_HANDLER
    FM_INTERFACE --> TRANSFORMER
    FM_INTERFACE --> IDEM

    TOKEN_ENFORCER --> REDIS
    RETRY_MGR --> BEDROCK_S
    RETRY_MGR --> BEDROCK_H
    STREAM_HANDLER --> BEDROCK_S
    STREAM_HANDLER --> BEDROCK_H
    TRANSFORMER --> OPENSEARCH
    IDEM --> REDIS

    FM_INTERFACE --> DYNAMO
    FM_INTERFACE --> CW
    FM_INTERFACE --> XRAY

    CW --> ALARM

    style BEDROCK_S fill:#ff9900,color:#000
    style BEDROCK_H fill:#ffb84d,color:#000
    style DYNAMO fill:#3b48cc,color:#fff
    style OPENSEARCH fill:#005eb8,color:#fff
    style REDIS fill:#dc382c,color:#fff
    style FM_INTERFACE fill:#1a73e8,color:#fff
    style TOKEN_ENFORCER fill:#34a853,color:#fff
    style RETRY_MGR fill:#ea4335,color:#fff
    style STREAM_HANDLER fill:#fbbc04,color:#000

API Gateway Configuration for Streaming FM Responses

Why WebSocket for GenAI Streaming

REST APIs on API Gateway have a hard 29-second integration timeout. A Bedrock Claude 3 Sonnet call generating a long manga recommendation can take 5-15 seconds, and the REST timeout leaves zero margin for retries or RAG retrieval. WebSocket APIs solve this with:

  • No per-message timeout (only a 10-minute idle timeout)
  • Full-duplex communication for streaming token-by-token
  • Server-initiated push without client polling
  • Connection persistence across multiple conversation turns

WebSocket API — Route Architecture

sequenceDiagram
    participant Client as Browser Client
    participant APIGW as API Gateway WebSocket
    participant Auth as Lambda Authorizer
    participant Connect as $connect Lambda
    participant Default as $default Lambda
    participant ECS as ECS Fargate Orchestrator
    participant Bedrock as Claude 3 Sonnet
    participant Disconnect as $disconnect Lambda

    Client->>APIGW: WebSocket Upgrade + JWT
    APIGW->>Auth: Validate JWT token
    Auth-->>APIGW: Allow (IAM policy document)
    APIGW->>Connect: $connect route invoked
    Connect->>Connect: Store connectionId in DynamoDB + Redis
    Connect-->>APIGW: 200 OK
    APIGW-->>Client: WebSocket connection established

    Note over Client,Bedrock: Streaming Chat Flow

    Client->>APIGW: {"action":"chat","message":"Best manga for beginners?"}
    APIGW->>Default: $default route invoked
    Default->>ECS: Forward via Cloud Map service discovery
    ECS->>ECS: Token budget check, RAG retrieval
    ECS->>Bedrock: InvokeModelWithResponseStream
    Bedrock-->>ECS: chunk: "For someone new to manga..."
    ECS-->>APIGW: postToConnection chunk 1
    APIGW-->>Client: {"type":"chunk","text":"For someone new to manga..."}
    Bedrock-->>ECS: chunk: "I'd recommend starting with..."
    ECS-->>APIGW: postToConnection chunk 2
    APIGW-->>Client: {"type":"chunk","text":"I'd recommend starting with..."}
    Bedrock-->>ECS: stream complete, usage metadata
    ECS-->>APIGW: postToConnection done
    APIGW-->>Client: {"type":"done","tokens":{"in":487,"out":312},"cost":"$0.0061"}

    Note over Client,Bedrock: Disconnect Flow

    Client->>APIGW: Close WebSocket
    APIGW->>Disconnect: $disconnect route invoked
    Disconnect->>Disconnect: Clean connectionId, preserve session
    Disconnect-->>APIGW: 200 OK
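
The postToConnection step in the diagram maps to the ApiGatewayManagementApi client in boto3. A minimal sketch of the relay helper the orchestrator might use (the management URL is a placeholder; take the real value from the WebSocketManagementURL stack output below):

import json
import boto3

# Management endpoint for the deployed stage (placeholder API id)
MGMT_URL = "https://<api-id>.execute-api.ap-northeast-1.amazonaws.com/prod"

apigw_mgmt = boto3.client("apigatewaymanagementapi", endpoint_url=MGMT_URL)

def relay_chunk(connection_id: str, text: str) -> bool:
    """Push one streamed chunk to a connected WebSocket client.
    Returns False if the client has already disconnected."""
    try:
        apigw_mgmt.post_to_connection(
            ConnectionId=connection_id,
            Data=json.dumps({"type": "chunk", "text": text}).encode(),
        )
        return True
    except apigw_mgmt.exceptions.GoneException:
        # Client went away mid-stream; caller should cancel the Bedrock stream
        return False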

WebSocket API — CloudFormation / SAM Configuration

# template.yaml — API Gateway WebSocket for MangaAssist FM Streaming
AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31
Description: MangaAssist FM API Interface — WebSocket API for streaming Bedrock responses

Globals:
  Function:
    Runtime: python3.12
    Timeout: 30
    MemorySize: 256
    Environment:
      Variables:
        CONNECTIONS_TABLE: !Ref ConnectionsTable
        SESSIONS_TABLE: !Ref SessionsTable
        REDIS_ENDPOINT: !GetAtt RedisCluster.RedisEndpoint.Address
        BEDROCK_REGION: ap-northeast-1

Resources:
  # ─── WebSocket API ───────────────────────────────────────────────
  MangaAssistWebSocketAPI:
    Type: AWS::ApiGatewayV2::Api
    Properties:
      Name: MangaAssist-FM-WebSocket
      ProtocolType: WEBSOCKET
      RouteSelectionExpression: "$request.body.action"
      Description: WebSocket API for streaming FM responses to MangaAssist clients

  # ─── Authorizer ──────────────────────────────────────────────────
  WebSocketAuthorizer:
    Type: AWS::ApiGatewayV2::Authorizer
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      AuthorizerType: REQUEST
      AuthorizerUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AuthFunction.Arn}/invocations"
      IdentitySource:
        - "route.request.querystring.token"
      Name: MangaAssistJWTAuthorizer

  # ─── $connect Route ──────────────────────────────────────────────
  ConnectRoute:
    Type: AWS::ApiGatewayV2::Route
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      RouteKey: "$connect"
      AuthorizationType: CUSTOM
      AuthorizerId: !Ref WebSocketAuthorizer
      OperationName: ConnectRoute
      Target: !Sub "integrations/${ConnectIntegration}"

  ConnectIntegration:
    Type: AWS::ApiGatewayV2::Integration
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      IntegrationType: AWS_PROXY
      IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${ConnectFunction.Arn}/invocations"

  # ─── $disconnect Route ───────────────────────────────────────────
  DisconnectRoute:
    Type: AWS::ApiGatewayV2::Route
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      RouteKey: "$disconnect"
      AuthorizationType: NONE
      OperationName: DisconnectRoute
      Target: !Sub "integrations/${DisconnectIntegration}"

  DisconnectIntegration:
    Type: AWS::ApiGatewayV2::Integration
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      IntegrationType: AWS_PROXY
      IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DisconnectFunction.Arn}/invocations"

  # ─── $default Route ──────────────────────────────────────────────
  DefaultRoute:
    Type: AWS::ApiGatewayV2::Route
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      RouteKey: "$default"
      AuthorizationType: NONE
      OperationName: DefaultRoute
      Target: !Sub "integrations/${DefaultIntegration}"

  DefaultIntegration:
    Type: AWS::ApiGatewayV2::Integration
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      IntegrationType: AWS_PROXY
      IntegrationUri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${DefaultFunction.Arn}/invocations"

  # ─── Stage + Deployment ──────────────────────────────────────────
  WebSocketStage:
    Type: AWS::ApiGatewayV2::Stage
    Properties:
      ApiId: !Ref MangaAssistWebSocketAPI
      StageName: prod
      AutoDeploy: true
      DefaultRouteSettings:
        ThrottlingRateLimit: 100
        ThrottlingBurstLimit: 200
      StageVariables:
        ENVIRONMENT: production
        BEDROCK_MODEL_SONNET: "anthropic.claude-3-sonnet-20240229-v1:0"
        BEDROCK_MODEL_HAIKU: "anthropic.claude-3-haiku-20240307-v1:0"

  # ─── Lambda Functions ────────────────────────────────────────────
  AuthFunction:
    Type: AWS::Serverless::Function
    Properties:
      Handler: handlers.auth.handle_authorize
      CodeUri: src/
      Description: JWT token validation for WebSocket connections

  ConnectFunction:
    Type: AWS::Serverless::Function
    Properties:
      Handler: handlers.connect.handle_connect
      CodeUri: src/
      Description: WebSocket $connect — register connection and link to session

  DisconnectFunction:
    Type: AWS::Serverless::Function
    Properties:
      Handler: handlers.disconnect.handle_disconnect
      CodeUri: src/
      Description: WebSocket $disconnect — clean up connection, preserve session

  DefaultFunction:
    Type: AWS::Serverless::Function
    Properties:
      Handler: handlers.default.handle_default
      CodeUri: src/
      Timeout: 30
      MemorySize: 512
      Description: WebSocket $default — route chat messages to FM interface layer
      Policies:
        - Statement:
            - Effect: Allow
              Action:
                - "bedrock:InvokeModel"
                - "bedrock:InvokeModelWithResponseStream"
              Resource: "arn:aws:bedrock:ap-northeast-1::foundation-model/*"
            - Effect: Allow
              Action:
                - "execute-api:ManageConnections"
              Resource: !Sub "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${MangaAssistWebSocketAPI}/*"

Outputs:
  WebSocketURI:
    Description: WebSocket connection URL
    Value: !Sub "wss://${MangaAssistWebSocketAPI}.execute-api.${AWS::Region}.amazonaws.com/prod"
  WebSocketManagementURL:
    Description: Management API endpoint for postToConnection
    Value: !Sub "https://${MangaAssistWebSocketAPI}.execute-api.${AWS::Region}.amazonaws.com/prod"

REST API — Configuration for Non-Streaming Endpoints

# REST API for health checks, admin operations, and sync queries
  MangaAssistRESTAPI:
    Type: AWS::Serverless::Api
    Properties:
      Name: MangaAssist-FM-REST
      StageName: prod
      EndpointConfiguration:
        Type: REGIONAL
      Auth:
        DefaultAuthorizer: MangaAssistCognitoAuth
        Authorizers:
          MangaAssistCognitoAuth:
            UserPoolArn: !GetAtt CognitoUserPool.Arn
      MethodSettings:
        - ResourcePath: "/*"
          HttpMethod: "*"
          ThrottlingRateLimit: 50
          ThrottlingBurstLimit: 100
          MetricsEnabled: true
          DataTraceEnabled: false
          LoggingLevel: INFO
      # CRITICAL: 29-second hard limit for REST integrations
      # This is why streaming chat uses WebSocket, not REST
      DefinitionBody:
        swagger: "2.0"
        info:
          title: MangaAssist FM REST API
        paths:
          /health:
            get:
              x-amazon-apigateway-integration:
                type: aws_proxy
                httpMethod: POST
                uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HealthFunction.Arn}/invocations"
                timeoutInMillis: 5000  # Health checks should be fast
          /chat/sync:
            post:
              x-amazon-apigateway-integration:
                type: aws_proxy
                httpMethod: POST
                uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${SyncChatFunction.Arn}/invocations"
                timeoutInMillis: 29000  # Hard maximum for REST
              x-amazon-apigateway-request-validator: full
          /admin/sessions:
            get:
              x-amazon-apigateway-integration:
                type: aws_proxy
                httpMethod: POST
                uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${AdminFunction.Arn}/invocations"
                timeoutInMillis: 10000
          /admin/metrics:
            get:
              x-amazon-apigateway-integration:
                type: aws_proxy
                httpMethod: POST
                uri: !Sub "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MetricsFunction.Arn}/invocations"
                timeoutInMillis: 10000

Token Limit Management at the API Layer

Token limits in GenAI workloads serve three purposes: preventing cost explosions, staying within model context windows, and maintaining response quality. At MangaAssist's scale of 1M messages/day, even small per-request inefficiencies compound dramatically.

Token Budget Allocation Strategy

Claude 3 Sonnet Context Window: 200,000 tokens

Allocation for MangaAssist Chat Request:
+-- System Prompt Template .......... 300 tokens (fixed)
+-- RAG Context (manga metadata) .... 1,500 tokens (variable)
+-- Conversation History ............ 2,000 tokens (sliding window)
+-- Current User Message ............ 500 tokens (variable)
+-- Safety Margin ................... 500 tokens (reserved)
+-- Output Budget ................... 1,024 tokens (max_tokens param)
+-- Remaining (unused) .............. 194,176 tokens

For Haiku Simple Queries:
+-- System Prompt Template .......... 150 tokens (shorter)
+-- Current User Message ............ 300 tokens (simple questions)
+-- Safety Margin ................... 200 tokens
+-- Output Budget ................... 256 tokens (short answers)
+-- Total Per Request ............... ~906 tokens
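
The "Remaining" line follows directly from the allocation arithmetic, which is the same check _validate_context_window performs later in this section:

# Sonnet allocation from the breakdown above
context_window = 200_000
used = 300 + 1_500 + 2_000 + 500 + 500 + 1_024  # system + RAG + history + message + safety + output
print(context_window - used)  # 194176 tokens of unused headroom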

Cost Impact at Scale

Scenario                         Input Tok/Req   Output Tok/Req   Model    Daily Cost (of 1M msgs/day)
Uncontrolled (no budgets)        5,000           2,000            Sonnet   $45,000/day
Budget-enforced Sonnet           2,000           1,024            Sonnet   $21,360/day
Haiku for simple queries (70%)   500             256              Haiku    $312/day (700K msgs)
Sonnet for complex (30%)         2,000           1,024            Sonnet   $6,408/day (300K msgs)
Optimized blend (70/30)          mixed           mixed            mixed    $6,720/day

The difference between uncontrolled and optimized: roughly $38,280/day, or about $1.15M/month.
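
Each row follows from the per-token prices quoted at the top of this document. A quick sketch of the arithmetic:

# USD per 1M tokens: (input, output)
PRICING = {"sonnet": (3.00, 15.00), "haiku": (0.25, 1.25)}

def daily_cost(model: str, tokens_in: int, tokens_out: int, messages: int) -> float:
    p_in, p_out = PRICING[model]
    per_msg = (tokens_in * p_in + tokens_out * p_out) / 1_000_000
    return per_msg * messages

uncontrolled = daily_cost("sonnet", 5_000, 2_000, 1_000_000)    # $45,000/day
optimized = (daily_cost("haiku", 500, 256, 700_000)             # ~$312/day
             + daily_cost("sonnet", 2_000, 1_024, 300_000))     # $6,408/day
print(f"${uncontrolled:,.0f} vs ${optimized:,.0f} per day")     # $45,000 vs $6,720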


Retry Strategies for Model Timeouts

Timeout Hierarchy

graph TB
    subgraph Timeouts["Timeout Hierarchy — Innermost to Outermost"]
        direction TB
        BEDROCK[Bedrock Model Timeout<br/>~60s internal<br/>Not configurable]
        ORCH[Orchestrator Timeout<br/>25s per Bedrock call<br/>Configurable in boto3 Config]
        APIGW_REST[API Gateway REST Timeout<br/>29s hard limit<br/>Not configurable]
        APIGW_WS[API Gateway WebSocket<br/>10 min idle timeout<br/>No per-message limit]
        CLIENT[Client Timeout<br/>10s for first token<br/>60s for full response]
    end

    BEDROCK --> ORCH
    ORCH --> APIGW_REST
    ORCH --> APIGW_WS
    APIGW_REST --> CLIENT
    APIGW_WS --> CLIENT

    subgraph Strategy["Retry Strategy per Layer"]
        R1["Bedrock: boto3 adaptive mode<br/>3 retries, SDK-managed backoff"]
        R2["Orchestrator: custom retry<br/>2 retries with jitter backoff"]
        R3["API Gateway: no retry<br/>Client must reconnect"]
        R4["Client: reconnect + replay<br/>Idempotency key prevents dupes"]
    end

    style BEDROCK fill:#dc3545,color:#fff
    style ORCH fill:#ff6b35,color:#fff
    style APIGW_REST fill:#ffc107,color:#000
    style APIGW_WS fill:#28a745,color:#fff
    style CLIENT fill:#1a73e8,color:#fff
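
The orchestrator's 25-second budget is set through botocore's Config. A sketch of how that looks for a synchronous call (note that the FMAPIInterface listing below uses read_timeout=60 instead, giving streaming responses more headroom since the WebSocket path has no 29s ceiling):

import boto3
from botocore.config import Config

# 25s read timeout keeps headroom under API Gateway's 29s REST hard limit
sync_config = Config(
    region_name="ap-northeast-1",
    connect_timeout=5,
    read_timeout=25,
    retries={"max_attempts": 3, "mode": "adaptive"},  # SDK-managed backoff
)
bedrock_runtime = boto3.client("bedrock-runtime", config=sync_config)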

Exponential Backoff with Jitter

Retry 1: base_delay * (2^0) + random(0, base_delay) = 1s + jitter
Retry 2: base_delay * (2^1) + random(0, base_delay) = 2s + jitter
Retry 3: base_delay * (2^2) + random(0, base_delay) = 4s + jitter
Max delay cap: 10 seconds (never exceed regardless of retry count)

With base_delay = 1.0s:
  Retry 1: 1.0 - 2.0s
  Retry 2: 2.0 - 3.0s
  Retry 3: 4.0 - 5.0s (still below the 10s cap; later retries would be clipped to it)
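
The schedule above can be sanity-checked with the same formula RetryStrategyManager._calculate_backoff uses later in this section:

import random

def backoff(attempt: int, base: float = 1.0, cap: float = 10.0) -> float:
    # min(cap, base * 2^attempt + full jitter in [0, base))
    return min(cap, base * (2 ** attempt) + random.uniform(0, base))

for attempt in range(3):
    print(f"Retry {attempt + 1}: {backoff(attempt):.2f}s")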

Circuit Breaker State Machine

stateDiagram-v2
    [*] --> Closed: Initial state

    Closed --> Closed: Success (reset failure count)
    Closed --> Open: Failure count >= threshold (5)

    Open --> Open: Reject request immediately
    Open --> HalfOpen: Recovery timeout elapsed (30s)

    HalfOpen --> Closed: Probe request succeeds
    HalfOpen --> Open: Probe request fails

    note right of Closed
        Normal operation
        Track consecutive failures
        Threshold: 5 failures in 60s
    end note

    note right of Open
        Fast-fail all requests
        Return cached / fallback response
        Wait 30s before probing
    end note

    note right of HalfOpen
        Allow single probe request
        If success then close circuit
        If fail then reopen for another 30s
    end note
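
A minimal in-process sketch of this state machine (the RetryStrategyManager below keeps the equivalent state in Redis so it is shared across Fargate tasks; this version is single-process only and uses a single probe):

import time

class SimpleCircuitBreaker:
    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 30.0):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failures = 0
        self.opened_at: float | None = None

    def allow_request(self) -> bool:
        if self.opened_at is None:
            return True                                  # Closed: normal flow
        if time.time() - self.opened_at >= self.recovery_timeout:
            return True                                  # Half-Open: allow a probe
        return False                                     # Open: fast-fail

    def record_success(self) -> None:
        self.failures = 0
        self.opened_at = None                            # Close the circuit

    def record_failure(self) -> None:
        self.failures += 1
        if self.failures >= self.failure_threshold:
            self.opened_at = time.time()                 # Open (or re-open)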

Request/Response Transformation for FM Payloads

The API interface layer must transform simple client messages into Bedrock-compatible payloads and enrich responses with metadata before returning them to the client.

Transformation Pipeline

graph LR
    subgraph Input["Client Request"]
        CLIENT_MSG["action: chat<br/>message: Recommend shonen manga<br/>sessionId: abc-123"]
    end

    subgraph Transform["Transformation Steps"]
        T1[1. Session Resolution<br/>Load conversation history]
        T2[2. RAG Retrieval<br/>Query OpenSearch for context]
        T3[3. Token Budget Check<br/>Validate total input tokens]
        T4[4. Prompt Assembly<br/>System + history + RAG + user msg]
        T5[5. Model Selection<br/>Haiku for simple, Sonnet for complex]
    end

    subgraph Output["Bedrock Payload"]
        BEDROCK_MSG["anthropic_version: bedrock-2023-05-31<br/>max_tokens: 1024<br/>system: You are MangaAssist...<br/>messages: history + user msg<br/>temperature: 0.3"]
    end

    CLIENT_MSG --> T1 --> T2 --> T3 --> T4 --> T5 --> BEDROCK_MSG

Core Implementation: FMAPIInterface

"""
MangaAssist FM API Interface — Central Coordinator
Skill 2.5.1: Create FM API interfaces for GenAI workloads.

This module is the main entry point for all FM (Foundation Model) interactions.
It coordinates token budgeting, retry logic, streaming, and request transformation.
"""

import json
import time
import logging
import hashlib
from dataclasses import dataclass, field
from typing import Optional, Generator, Any
from enum import Enum

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# ─── Configuration Constants ──────────────────────────────────────────────────

class ModelTier(Enum):
    """Available Bedrock model tiers with their identifiers."""
    SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
    HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"


# Pricing per 1M tokens (USD)
MODEL_PRICING = {
    ModelTier.SONNET: {"input": 3.00, "output": 15.00},
    ModelTier.HAIKU: {"input": 0.25, "output": 1.25},
}

# Model context windows
MODEL_CONTEXT_WINDOWS = {
    ModelTier.SONNET: 200_000,
    ModelTier.HAIKU: 200_000,
}
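
# ─── Token Estimation Helper ──────────────────────────────────────────────────
# NOTE: estimate_token_count is called by prepare_request below but was missing
# from the original listing. This is a rough character-based sketch; a real
# deployment would use an actual tokenizer for accurate counts.

def estimate_token_count(text: str, language: str = "en") -> int:
    """Rough estimate: ~4 chars/token for English, ~1.4 chars/token for Japanese."""
    chars_per_token = 1.4 if language == "ja" else 4.0
    return max(1, int(len(text) / chars_per_token))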


# ─── Data Classes ─────────────────────────────────────────────────────────────

@dataclass
class FMRequestConfig:
    """Configuration for a single FM API request."""
    model_tier: ModelTier = ModelTier.SONNET
    max_output_tokens: int = 1024
    temperature: float = 0.3
    top_p: float = 0.9
    stream: bool = True
    timeout_seconds: int = 25
    max_retries: int = 2
    idempotency_key: Optional[str] = None


@dataclass
class FMRequest:
    """Fully-formed request ready for Bedrock invocation."""
    session_id: str
    user_message: str
    system_prompt: str
    conversation_history: list[dict]
    rag_context: Optional[str] = None
    config: FMRequestConfig = field(default_factory=FMRequestConfig)
    estimated_input_tokens: int = 0
    request_id: str = ""


@dataclass
class FMResponse:
    """Complete response from an FM invocation."""
    text: str
    input_tokens: int
    output_tokens: int
    model_tier: ModelTier
    time_to_first_token_ms: float
    total_latency_ms: float
    tokens_per_second: float
    cost_usd: float
    request_id: str
    cached: bool = False


@dataclass
class StreamChunk:
    """Single chunk in a streaming response."""
    text: str
    chunk_index: int
    cumulative_tokens: int
    timestamp: float


class FMAPIInterface:
    """
    Central coordinator for all Foundation Model API interactions.

    Responsibilities:
    - Assemble prompts from session history, RAG context, and user messages
    - Enforce token budgets before invocation
    - Route to appropriate model tier (Sonnet vs Haiku)
    - Manage streaming response relay
    - Track costs and latency metrics
    - Coordinate retry strategies

    Usage:
        interface = FMAPIInterface(redis_client=redis, region="ap-northeast-1")
        response = interface.invoke(request)
        # or for streaming:
        for chunk in interface.invoke_stream(request):
            send_to_websocket(chunk)
    """

    # System prompt template for MangaAssist
    SYSTEM_PROMPT_TEMPLATE = (
        "You are MangaAssist, a knowledgeable and friendly manga recommendation "
        "assistant for a Japanese manga store. You help customers discover manga "
        "based on their preferences, reading history, and current interests.\n\n"
        "Guidelines:\n"
        "- Respond in the same language the customer uses (Japanese or English)\n"
        "- Provide specific volume and chapter recommendations when possible\n"
        "- Consider the customer's reading level and genre preferences\n"
        "- Mention current promotions or new releases when relevant\n"
        "{rag_context}"
    )

    def __init__(
        self,
        redis_client,
        region: str = "ap-northeast-1",
        token_enforcer: Optional["TokenLimitEnforcer"] = None,
        retry_manager: Optional["RetryStrategyManager"] = None,
    ):
        self.redis = redis_client
        self.region = region

        # Initialize Bedrock client with retry configuration
        bedrock_config = Config(
            region_name=region,
            retries={"max_attempts": 3, "mode": "adaptive"},
            read_timeout=60,
            connect_timeout=5,
        )
        self.bedrock_runtime = boto3.client(
            "bedrock-runtime", config=bedrock_config
        )

        # Initialize sub-components (TokenLimitEnforcer and RetryStrategyManager
        # are defined in the modules that follow this listing)
        self.token_enforcer = token_enforcer or TokenLimitEnforcer(redis_client)
        self.retry_manager = retry_manager or RetryStrategyManager(redis_client)

        # Metrics tracking
        self._request_count = 0
        self._error_count = 0

    def prepare_request(
        self,
        session_id: str,
        user_message: str,
        conversation_history: list[dict],
        rag_context: Optional[str] = None,
        config: Optional[FMRequestConfig] = None,
    ) -> FMRequest:
        """
        Assemble a complete FM request from components.

        Steps:
        1. Build system prompt with optional RAG context
        2. Estimate total input tokens
        3. Select model tier based on query complexity
        4. Generate idempotency key
        5. Return validated FMRequest
        """
        if config is None:
            config = FMRequestConfig()

        # Build system prompt
        rag_section = ""
        if rag_context:
            rag_section = (
                f"\n\nRelevant product information:\n{rag_context}\n"
            )
        system_prompt = self.SYSTEM_PROMPT_TEMPLATE.format(
            rag_context=rag_section
        )

        # Estimate input tokens
        total_text = system_prompt + user_message
        for turn in conversation_history:
            total_text += turn.get("content", "")
        estimated_tokens = estimate_token_count(total_text, language="ja")

        # Auto-select model tier based on complexity
        if config.model_tier == ModelTier.SONNET:
            # Check if query is simple enough for Haiku
            if self._is_simple_query(user_message, conversation_history):
                config.model_tier = ModelTier.HAIKU
                config.max_output_tokens = min(config.max_output_tokens, 512)
                logger.info(
                    f"Auto-routed to Haiku for simple query: "
                    f"'{user_message[:50]}...'"
                )

        # Generate idempotency key
        ts_window = int(time.time()) // 5  # 5-second dedup window
        idempotency_key = hashlib.sha256(
            f"{session_id}:{user_message}:{ts_window}".encode()
        ).hexdigest()[:16]
        config.idempotency_key = idempotency_key

        request = FMRequest(
            session_id=session_id,
            user_message=user_message,
            system_prompt=system_prompt,
            conversation_history=conversation_history,
            rag_context=rag_context,
            config=config,
            estimated_input_tokens=estimated_tokens,
            request_id=idempotency_key,
        )

        return request

    def invoke(self, request: FMRequest) -> FMResponse:
        """
        Synchronous FM invocation with token enforcement and retries.
        Used for REST API endpoints or when streaming is not needed.
        """
        # Step 1: Token budget enforcement
        self.token_enforcer.validate_request(request)

        # Step 2: Check idempotency cache
        cached = self._check_idempotency_cache(request.config.idempotency_key)
        if cached:
            return cached

        # Step 3: Build Bedrock payload
        payload = self._build_bedrock_payload(request, stream=False)

        # Step 4: Invoke with retry strategy
        start_time = time.time()

        def _do_invoke():
            return self.bedrock_runtime.invoke_model(
                modelId=request.config.model_tier.value,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(payload),
            )

        raw_response = self.retry_manager.execute_with_retry(
            operation=_do_invoke,
            operation_name=f"bedrock_invoke_{request.config.model_tier.name}",
            max_retries=request.config.max_retries,
        )

        total_latency = (time.time() - start_time) * 1000

        # Step 5: Parse response
        response_body = json.loads(raw_response["body"].read())
        text = response_body["content"][0]["text"]
        usage = response_body.get("usage", {})
        input_tokens = usage.get("input_tokens", 0)
        output_tokens = usage.get("output_tokens", 0)

        # Step 6: Calculate cost
        pricing = MODEL_PRICING[request.config.model_tier]
        cost = (
            (input_tokens / 1_000_000) * pricing["input"]
            + (output_tokens / 1_000_000) * pricing["output"]
        )

        response = FMResponse(
            text=text,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            model_tier=request.config.model_tier,
            time_to_first_token_ms=total_latency,  # No streaming = same as total
            total_latency_ms=total_latency,
            tokens_per_second=output_tokens / (total_latency / 1000) if total_latency > 0 else 0,
            cost_usd=cost,
            request_id=request.request_id,
        )

        # Record actual usage so session/daily budgets accumulate, then cache
        self.token_enforcer.record_usage(request, input_tokens, output_tokens, cost)
        self._cache_response(request.config.idempotency_key, response)
        self._request_count += 1

        return response

    def invoke_stream(
        self, request: FMRequest
    ) -> Generator[StreamChunk, None, FMResponse]:
        """
        Streaming FM invocation — yields chunks as they arrive.

        Yields StreamChunk objects for each text delta.
        Returns the final FMResponse with complete metrics.

        Usage:
            gen = interface.invoke_stream(request)
            for chunk in gen:
                websocket.send(chunk.text)
        """
        # Step 1: Token budget enforcement
        self.token_enforcer.validate_request(request)

        # Step 2: Check idempotency cache
        cached = self._check_idempotency_cache(request.config.idempotency_key)
        if cached:
            yield StreamChunk(
                text=cached.text,
                chunk_index=0,
                cumulative_tokens=cached.output_tokens,
                timestamp=time.time(),
            )
            return cached

        # Step 3: Build Bedrock payload
        payload = self._build_bedrock_payload(request, stream=True)

        # Step 4: Invoke streaming with retry
        start_time = time.time()
        first_token_time = None
        chunks_text = []
        chunk_index = 0
        input_tokens = 0
        output_tokens = 0

        def _do_stream():
            return self.bedrock_runtime.invoke_model_with_response_stream(
                modelId=request.config.model_tier.value,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(payload),
            )

        raw_response = self.retry_manager.execute_with_retry(
            operation=_do_stream,
            operation_name=f"bedrock_stream_{request.config.model_tier.name}",
            max_retries=request.config.max_retries,
        )

        stream = raw_response.get("body")

        for event in stream:
            chunk = event.get("chunk")
            if not chunk:
                continue

            chunk_data = json.loads(chunk.get("bytes", b"{}"))
            chunk_type = chunk_data.get("type")

            if chunk_type == "content_block_delta":
                delta = chunk_data.get("delta", {})
                text = delta.get("text", "")

                if text:
                    now = time.time()
                    if first_token_time is None:
                        first_token_time = now

                    chunks_text.append(text)
                    chunk_index += 1

                    yield StreamChunk(
                        text=text,
                        chunk_index=chunk_index,
                        cumulative_tokens=chunk_index,  # Approximate
                        timestamp=now,
                    )

            elif chunk_type == "message_start":
                msg_usage = chunk_data.get("message", {}).get("usage", {})
                input_tokens = msg_usage.get("input_tokens", 0)

            elif chunk_type == "message_delta":
                usage = chunk_data.get("usage", {})
                output_tokens = usage.get("output_tokens", 0)

        # Build final response
        end_time = time.time()
        total_latency = (end_time - start_time) * 1000
        ttft = ((first_token_time - start_time) * 1000) if first_token_time else total_latency

        pricing = MODEL_PRICING[request.config.model_tier]
        cost = (
            (input_tokens / 1_000_000) * pricing["input"]
            + (output_tokens / 1_000_000) * pricing["output"]
        )

        full_text = "".join(chunks_text)
        duration_s = (end_time - (first_token_time or start_time))

        response = FMResponse(
            text=full_text,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            model_tier=request.config.model_tier,
            time_to_first_token_ms=ttft,
            total_latency_ms=total_latency,
            tokens_per_second=output_tokens / duration_s if duration_s > 0 else 0,
            cost_usd=cost,
            request_id=request.request_id,
        )

        # Record actual usage so session/daily budgets accumulate, then cache
        self.token_enforcer.record_usage(request, input_tokens, output_tokens, cost)
        self._cache_response(request.config.idempotency_key, response)
        self._request_count += 1

        logger.info(
            "Stream completed",
            extra={
                "request_id": request.request_id,
                "model": request.config.model_tier.name,
                "ttft_ms": round(ttft),
                "total_ms": round(total_latency),
                "tps": round(response.tokens_per_second, 1),
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "cost_usd": round(cost, 6),
            },
        )

        return response

    def _build_bedrock_payload(self, request: FMRequest, stream: bool) -> dict:
        """Construct the Bedrock Messages API payload."""
        messages = []
        for turn in request.conversation_history:
            messages.append({
                "role": turn["role"],
                "content": [{"type": "text", "text": turn["content"]}],
            })
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": request.user_message}],
        })

        return {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": request.config.max_output_tokens,
            "system": request.system_prompt,
            "messages": messages,
            "temperature": request.config.temperature,
            "top_p": request.config.top_p,
        }

    def _is_simple_query(
        self, message: str, history: list[dict]
    ) -> bool:
        """
        Heuristic to determine if a query can be handled by Haiku.
        Simple queries: greetings, yes/no, short factual questions.
        Complex queries: recommendations, comparisons, long-form reviews.
        """
        simple_indicators = [
            len(message) < 50,
            any(kw in message.lower() for kw in [
                "hello", "hi", "thanks", "bye",
                "yes", "no", "ok",
                "price", "hours", "address", "stock",
            ]),
            len(history) == 0,  # First message often simple
        ]
        complex_indicators = [
            any(kw in message.lower() for kw in [
                "recommend", "suggest", "compare", "review",
                "similar to", "like", "best",
                "explain", "describe", "analysis",
            ]),
            len(message) > 200,
        ]
        return (
            sum(simple_indicators) >= 2
            and not any(complex_indicators)
        )

    def _check_idempotency_cache(
        self, idempotency_key: str
    ) -> Optional[FMResponse]:
        """Check Redis for a cached response from a duplicate request."""
        try:
            cached = self.redis.get(f"idemp_resp:{idempotency_key}")
            if cached:
                data = json.loads(cached)
                return FMResponse(
                    text=data["text"],
                    input_tokens=data["input_tokens"],
                    output_tokens=data["output_tokens"],
                    model_tier=ModelTier(data["model_tier"]),
                    time_to_first_token_ms=0,
                    total_latency_ms=0,
                    tokens_per_second=0,
                    cost_usd=data["cost_usd"],
                    request_id=idempotency_key,
                    cached=True,
                )
        except Exception as e:
            logger.warning(f"Idempotency cache check failed: {e}")
        return None

    def _cache_response(
        self, idempotency_key: str, response: FMResponse
    ) -> None:
        """Cache response in Redis for idempotency dedup."""
        try:
            data = {
                "text": response.text,
                "input_tokens": response.input_tokens,
                "output_tokens": response.output_tokens,
                "model_tier": response.model_tier.value,
                "cost_usd": response.cost_usd,
            }
            self.redis.setex(
                f"idemp_resp:{idempotency_key}",
                30,  # 30-second TTL for dedup
                json.dumps(data, ensure_ascii=False),
            )
        except Exception as e:
            logger.warning(f"Idempotency cache write failed: {e}")

TokenLimitEnforcer

"""
MangaAssist Token Limit Enforcer
Validates and enforces token budgets at the API layer before Bedrock invocation.
"""

import time
import json
import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)


class TokenBudgetExceeded(Exception):
    """Raised when a request exceeds its token budget."""
    def __init__(self, message: str, budget_type: str, limit: int, actual: int):
        super().__init__(message)
        self.budget_type = budget_type
        self.limit = limit
        self.actual = actual


class DailyBudgetExhausted(Exception):
    """Raised when daily token/cost budget is exhausted."""
    def __init__(self, message: str, budget_type: str, limit: float, used: float):
        super().__init__(message)
        self.budget_type = budget_type
        self.limit = limit
        self.used = used


@dataclass
class TokenBudgetConfig:
    """Token budget configuration for different request types."""
    # Per-request limits
    max_input_tokens: int = 4_000
    max_output_tokens: int = 1_024
    max_total_tokens: int = 5_024
    # Per-session limits
    max_session_input_tokens: int = 50_000
    max_session_output_tokens: int = 25_000
    # Daily limits (per user)
    daily_input_token_limit: int = 500_000
    daily_output_token_limit: int = 250_000
    daily_cost_limit_usd: float = 5.00
    # Model-specific context windows
    context_window_size: int = 200_000
    prompt_template_overhead: int = 300
    safety_margin: int = 500
    # Japanese text multiplier (JP chars use more tokens)
    jp_token_multiplier: float = 0.7  # ~1.4 chars per token for JP


class TokenLimitEnforcer:
    """
    Enforces token limits at multiple levels:
    1. Per-request: input tokens, output tokens, total
    2. Per-session: cumulative usage across conversation
    3. Per-user daily: prevent runaway costs
    4. Context window: ensure input fits in model window

    All limits are checked before Bedrock invocation to fail fast
    and avoid unnecessary API calls and charges.
    """

    def __init__(
        self,
        redis_client,
        config: Optional[TokenBudgetConfig] = None,
    ):
        self.redis = redis_client
        self.config = config or TokenBudgetConfig()

    def validate_request(self, request: "FMRequest") -> None:
        """
        Run all token budget validations.
        Raises TokenBudgetExceeded or DailyBudgetExhausted on failure.
        """
        self._validate_per_request_limits(request)
        self._validate_context_window(request)
        self._validate_session_limits(request)
        self._validate_daily_limits(request)

    def _validate_per_request_limits(self, request: "FMRequest") -> None:
        """Check that the request fits within per-request token limits."""
        estimated_input = request.estimated_input_tokens
        max_output = request.config.max_output_tokens

        if estimated_input > self.config.max_input_tokens:
            raise TokenBudgetExceeded(
                f"Input tokens ({estimated_input}) exceed per-request limit "
                f"({self.config.max_input_tokens}). Consider truncating "
                f"conversation history or RAG context.",
                budget_type="per_request_input",
                limit=self.config.max_input_tokens,
                actual=estimated_input,
            )

        if max_output > self.config.max_output_tokens:
            raise TokenBudgetExceeded(
                f"Requested output tokens ({max_output}) exceed limit "
                f"({self.config.max_output_tokens}).",
                budget_type="per_request_output",
                limit=self.config.max_output_tokens,
                actual=max_output,
            )

        total = estimated_input + max_output
        if total > self.config.max_total_tokens:
            raise TokenBudgetExceeded(
                f"Total tokens ({total}) exceed per-request limit "
                f"({self.config.max_total_tokens}).",
                budget_type="per_request_total",
                limit=self.config.max_total_tokens,
                actual=total,
            )

    def _validate_context_window(self, request: "FMRequest") -> None:
        """Ensure the request fits within the model's context window."""
        overhead = self.config.prompt_template_overhead
        safety = self.config.safety_margin
        max_output = request.config.max_output_tokens
        available = self.config.context_window_size - overhead - safety - max_output

        if request.estimated_input_tokens > available:
            raise TokenBudgetExceeded(
                f"Input tokens ({request.estimated_input_tokens}) exceed "
                f"available context window ({available} = {self.config.context_window_size} "
                f"- {overhead} overhead - {safety} safety - {max_output} output).",
                budget_type="context_window",
                limit=available,
                actual=request.estimated_input_tokens,
            )

    def _validate_session_limits(self, request: "FMRequest") -> None:
        """Check cumulative session token usage."""
        session_key = f"session_tokens:{request.session_id}"
        try:
            usage_raw = self.redis.get(session_key)
            if usage_raw:
                usage = json.loads(usage_raw)
                session_input = usage.get("input", 0)
                session_output = usage.get("output", 0)

                if session_input + request.estimated_input_tokens > self.config.max_session_input_tokens:
                    raise TokenBudgetExceeded(
                        f"Session input tokens would reach "
                        f"{session_input + request.estimated_input_tokens}, "
                        f"exceeding limit {self.config.max_session_input_tokens}. "
                        f"Consider starting a new session.",
                        budget_type="session_input",
                        limit=self.config.max_session_input_tokens,
                        actual=session_input + request.estimated_input_tokens,
                    )

                if session_output > self.config.max_session_output_tokens:
                    raise TokenBudgetExceeded(
                        f"Session output tokens ({session_output}) already at limit "
                        f"({self.config.max_session_output_tokens}).",
                        budget_type="session_output",
                        limit=self.config.max_session_output_tokens,
                        actual=session_output,
                    )
        except (ConnectionError, TimeoutError) as e:
            # Redis failure — allow request but log warning
            logger.warning(f"Redis unavailable for session limit check: {e}")

    def _validate_daily_limits(self, request: "FMRequest") -> None:
        """Check user's daily token and cost budget."""
        day_key = time.strftime("%Y-%m-%d")
        user_key = f"daily_tokens:{request.session_id}:{day_key}"

        try:
            usage_raw = self.redis.get(user_key)
            if usage_raw:
                usage = json.loads(usage_raw)
                daily_cost = usage.get("cost_usd", 0.0)

                if daily_cost >= self.config.daily_cost_limit_usd:
                    raise DailyBudgetExhausted(
                        f"Daily cost limit reached: ${daily_cost:.2f} / "
                        f"${self.config.daily_cost_limit_usd:.2f}. "
                        f"Budget resets at midnight UTC.",
                        budget_type="daily_cost",
                        limit=self.config.daily_cost_limit_usd,
                        used=daily_cost,
                    )

                daily_input = usage.get("input_tokens", 0)
                if daily_input + request.estimated_input_tokens > self.config.daily_input_token_limit:
                    raise DailyBudgetExhausted(
                        f"Daily input token limit would be exceeded: "
                        f"{daily_input + request.estimated_input_tokens} / "
                        f"{self.config.daily_input_token_limit}.",
                        budget_type="daily_input_tokens",
                        limit=self.config.daily_input_token_limit,
                        used=daily_input,
                    )
        except (ConnectionError, TimeoutError) as e:
            logger.warning(f"Redis unavailable for daily limit check: {e}")

    def record_usage(
        self,
        request: "FMRequest",
        input_tokens: int,
        output_tokens: int,
        cost_usd: float,
    ) -> None:
        """Record actual token usage after successful invocation."""
        try:
            pipe = self.redis.pipeline()

            # Session-level tracking
            session_key = f"session_tokens:{request.session_id}"
            session_data = self.redis.get(session_key)
            if session_data:
                usage = json.loads(session_data)
            else:
                usage = {"input": 0, "output": 0, "cost_usd": 0.0, "requests": 0}

            usage["input"] += input_tokens
            usage["output"] += output_tokens
            usage["cost_usd"] += cost_usd
            usage["requests"] += 1

            pipe.setex(session_key, 86400, json.dumps(usage))

            # Daily tracking
            day_key = time.strftime("%Y-%m-%d")
            user_key = f"daily_tokens:{request.session_id}:{day_key}"
            daily_data = self.redis.get(user_key)
            if daily_data:
                daily = json.loads(daily_data)
            else:
                daily = {"input_tokens": 0, "output_tokens": 0, "cost_usd": 0.0, "requests": 0}

            daily["input_tokens"] += input_tokens
            daily["output_tokens"] += output_tokens
            daily["cost_usd"] += cost_usd
            daily["requests"] += 1

            # TTL is 24h from the last write (approximates the daily reset)
            pipe.setex(user_key, 86400, json.dumps(daily))

            pipe.execute()

            logger.info(
                "Usage recorded",
                extra={
                    "session_id": request.session_id,
                    "input_tokens": input_tokens,
                    "output_tokens": output_tokens,
                    "cost_usd": round(cost_usd, 6),
                    "session_total_cost": round(usage["cost_usd"], 4),
                    "daily_total_cost": round(daily["cost_usd"], 4),
                },
            )
        except Exception as e:
            logger.error(f"Failed to record usage: {e}")

    def get_remaining_budget(self, session_id: str) -> dict:
        """Return remaining budget information for a session."""
        day_key = time.strftime("%Y-%m-%d")
        user_key = f"daily_tokens:{session_id}:{day_key}"
        session_key = f"session_tokens:{session_id}"

        try:
            daily_raw = self.redis.get(user_key)
            session_raw = self.redis.get(session_key)

            daily = json.loads(daily_raw) if daily_raw else {"input_tokens": 0, "output_tokens": 0, "cost_usd": 0.0}
            session = json.loads(session_raw) if session_raw else {"input": 0, "output": 0, "cost_usd": 0.0}

            return {
                "daily": {
                    "input_remaining": self.config.daily_input_token_limit - daily.get("input_tokens", 0),
                    "output_remaining": self.config.daily_output_token_limit - daily.get("output_tokens", 0),
                    "cost_remaining_usd": self.config.daily_cost_limit_usd - daily.get("cost_usd", 0.0),
                },
                "session": {
                    "input_remaining": self.config.max_session_input_tokens - session.get("input", 0),
                    "output_remaining": self.config.max_session_output_tokens - session.get("output", 0),
                },
                "per_request": {
                    "max_input": self.config.max_input_tokens,
                    "max_output": self.config.max_output_tokens,
                },
            }
        except Exception as e:
            logger.warning(f"Budget query failed: {e}")
            return {"error": str(e)}

RetryStrategyManager

"""
MangaAssist Retry Strategy Manager
Implements exponential backoff with jitter and circuit breaker pattern for Bedrock calls.
"""

import time
import random
import json
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Any, Optional

from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Circuit breaker states."""
    CLOSED = "closed"          # Normal operation — requests flow through
    OPEN = "open"              # Failures exceeded threshold — reject immediately
    HALF_OPEN = "half_open"    # Recovery probe — allow single test request


class CircuitBreakerOpen(Exception):
    """Raised when the circuit breaker is open and rejecting requests."""
    def __init__(self, service: str, retry_after_seconds: float):
        super().__init__(
            f"Circuit breaker OPEN for {service}. "
            f"Retry after {retry_after_seconds:.0f}s."
        )
        self.service = service
        self.retry_after_seconds = retry_after_seconds


class MaxRetriesExceeded(Exception):
    """Raised when all retry attempts have been exhausted."""
    def __init__(self, operation: str, attempts: int, last_error: Exception):
        super().__init__(
            f"Max retries ({attempts}) exceeded for {operation}: {last_error}"
        )
        self.operation = operation
        self.attempts = attempts
        self.last_error = last_error


@dataclass
class RetryConfig:
    """Configuration for retry behavior."""
    base_delay_seconds: float = 1.0
    max_delay_seconds: float = 10.0
    max_retries: int = 3
    jitter_range: float = 1.0  # Random jitter 0 to this value
    # Retryable error codes from Bedrock
    retryable_errors: tuple = (
        "ThrottlingException",
        "ServiceUnavailableException",
        "ModelTimeoutException",
        "InternalServerException",
        "ModelNotReadyException",
    )
    # Non-retryable errors (fail immediately)
    non_retryable_errors: tuple = (
        "ValidationException",
        "AccessDeniedException",
        "ResourceNotFoundException",
        "ModelStreamErrorException",
    )


@dataclass
class CircuitBreakerConfig:
    """Configuration for circuit breaker behavior."""
    failure_threshold: int = 5       # Failures before opening circuit
    failure_window_seconds: int = 60  # Time window for counting failures
    recovery_timeout_seconds: int = 30  # How long to stay open before half-open
    success_threshold: int = 1       # Half-open successes before closing (single probe, matching the state diagram)


class RetryStrategyManager:
    """
    Manages retry strategies and circuit breakers for Bedrock API calls.

    Features:
    - Exponential backoff with full jitter
    - Per-service circuit breakers (separate for Sonnet vs Haiku)
    - Retry budget tracking (prevent retry storms)
    - Detailed logging for debugging timeout patterns

    Usage:
        manager = RetryStrategyManager(redis_client)
        result = manager.execute_with_retry(
            operation=lambda: bedrock.invoke_model(...),
            operation_name="bedrock_invoke_sonnet",
        )
    """

    def __init__(
        self,
        redis_client,
        retry_config: Optional[RetryConfig] = None,
        circuit_config: Optional[CircuitBreakerConfig] = None,
    ):
        self.redis = redis_client
        self.retry_config = retry_config or RetryConfig()
        self.circuit_config = circuit_config or CircuitBreakerConfig()

    def execute_with_retry(
        self,
        operation: Callable[[], Any],
        operation_name: str,
        max_retries: Optional[int] = None,
    ) -> Any:
        """
        Execute an operation with exponential backoff retry and circuit breaker.

        Args:
            operation: Callable that performs the Bedrock API call
            operation_name: Identifier for circuit breaker tracking
            max_retries: Override default max retries

        Returns:
            The result of the operation

        Raises:
            CircuitBreakerOpen: If circuit is open
            MaxRetriesExceeded: If all retries fail
        """
        retries = max_retries if max_retries is not None else self.retry_config.max_retries

        # Check circuit breaker before attempting
        self._check_circuit_breaker(operation_name)

        last_error = None
        for attempt in range(retries + 1):
            try:
                result = operation()

                # Success — record it for circuit breaker
                self._record_success(operation_name)

                if attempt > 0:
                    logger.info(
                        f"Retry succeeded on attempt {attempt + 1} for {operation_name}"
                    )

                return result

            except ClientError as e:
                error_code = e.response["Error"]["Code"]
                last_error = e

                # Non-retryable error — fail immediately
                if error_code in self.retry_config.non_retryable_errors:
                    logger.error(
                        f"Non-retryable error for {operation_name}: {error_code}"
                    )
                    self._record_failure(operation_name, error_code)
                    raise

                # Retryable error — backoff and retry
                if error_code in self.retry_config.retryable_errors:
                    self._record_failure(operation_name, error_code)

                    if attempt < retries:
                        delay = self._calculate_backoff(attempt)
                        logger.warning(
                            f"Retryable error for {operation_name}: {error_code}. "
                            f"Attempt {attempt + 1}/{retries + 1}. "
                            f"Backing off {delay:.2f}s."
                        )
                        time.sleep(delay)

                        # Re-check circuit breaker before retry
                        self._check_circuit_breaker(operation_name)
                    else:
                        logger.error(
                            f"All retries exhausted for {operation_name}: {error_code}"
                        )
                else:
                    # Unknown error: treat as non-retryable, but still
                    # count it toward the circuit breaker window
                    self._record_failure(operation_name, error_code)
                    logger.error(
                        f"Unknown error for {operation_name}: {error_code}"
                    )
                    raise

            except Exception as e:
                last_error = e
                self._record_failure(operation_name, type(e).__name__)

                if attempt < retries:
                    delay = self._calculate_backoff(attempt)
                    logger.warning(
                        f"Unexpected error for {operation_name}: {e}. "
                        f"Attempt {attempt + 1}/{retries + 1}. "
                        f"Backing off {delay:.2f}s."
                    )
                    time.sleep(delay)
                    self._check_circuit_breaker(operation_name)

        raise MaxRetriesExceeded(
            operation=operation_name,
            attempts=retries + 1,
            last_error=last_error,
        )

    def _calculate_backoff(self, attempt: int) -> float:
        """
        Calculate delay using exponential backoff with additive jitter.

        Formula: min(max_delay, base * 2^attempt + random(0, jitter_range))

        With the defaults (base=1s, jitter=1s, cap=10s): attempt 0 waits
        1-2s, attempt 1 waits 2-3s, attempt 2 waits 4-5s.
        """
        exponential = self.retry_config.base_delay_seconds * (2 ** attempt)
        jitter = random.uniform(0, self.retry_config.jitter_range)
        delay = min(
            self.retry_config.max_delay_seconds,
            exponential + jitter,
        )
        return delay

    def _check_circuit_breaker(self, service: str) -> None:
        """Check if the circuit breaker allows this request."""
        try:
            state_key = f"circuit:{service}:state"
            state_raw = self.redis.get(state_key)

            if not state_raw:
                return  # No state = CLOSED (default)

            state_data = json.loads(state_raw)
            state = CircuitState(state_data["state"])

            if state == CircuitState.OPEN:
                opened_at = state_data.get("opened_at", 0)
                elapsed = time.time() - opened_at

                if elapsed >= self.circuit_config.recovery_timeout_seconds:
                    # Transition to HALF_OPEN
                    self._set_circuit_state(service, CircuitState.HALF_OPEN)
                    logger.info(f"Circuit breaker HALF_OPEN for {service}")
                else:
                    retry_after = self.circuit_config.recovery_timeout_seconds - elapsed
                    raise CircuitBreakerOpen(service, retry_after)

            elif state == CircuitState.HALF_OPEN:
                # Allow the request through as a probe
                pass

        except CircuitBreakerOpen:
            raise
        except Exception as e:
            logger.warning(f"Circuit breaker check failed: {e}")

    def _record_success(self, service: str) -> None:
        """Record a successful call for circuit breaker tracking."""
        try:
            state_key = f"circuit:{service}:state"
            state_raw = self.redis.get(state_key)

            if state_raw:
                state_data = json.loads(state_raw)
                state = CircuitState(state_data["state"])

                if state == CircuitState.HALF_OPEN:
                    successes = state_data.get("half_open_successes", 0) + 1
                    if successes >= self.circuit_config.success_threshold:
                        self._set_circuit_state(service, CircuitState.CLOSED)
                        logger.info(f"Circuit breaker CLOSED for {service}")
                    else:
                        state_data["half_open_successes"] = successes
                        self.redis.setex(
                            state_key, 300,
                            json.dumps(state_data),
                        )

            # Reset failure counter on success
            self.redis.delete(f"circuit:{service}:failures")

        except Exception as e:
            logger.warning(f"Circuit breaker success recording failed: {e}")

    def _record_failure(self, service: str, error_code: str) -> None:
        """Record a failure for circuit breaker tracking."""
        try:
            failure_key = f"circuit:{service}:failures"

            # Add timestamped failure
            now = time.time()
            pipe = self.redis.pipeline()
            pipe.zadd(failure_key, {f"{error_code}:{now}": now})
            # Remove failures outside the window
            cutoff = now - self.circuit_config.failure_window_seconds
            pipe.zremrangebyscore(failure_key, "-inf", cutoff)
            pipe.zcard(failure_key)
            pipe.expire(failure_key, self.circuit_config.failure_window_seconds * 2)
            results = pipe.execute()

            failure_count = results[2]  # zcard result

            if failure_count >= self.circuit_config.failure_threshold:
                state_key = f"circuit:{service}:state"
                current_raw = self.redis.get(state_key)
                current_state = CircuitState.CLOSED
                if current_raw:
                    current_state = CircuitState(json.loads(current_raw)["state"])

                if current_state == CircuitState.CLOSED:
                    self._set_circuit_state(service, CircuitState.OPEN)
                    logger.warning(
                        f"Circuit breaker OPENED for {service} after "
                        f"{failure_count} failures in "
                        f"{self.circuit_config.failure_window_seconds}s"
                    )
                elif current_state == CircuitState.HALF_OPEN:
                    self._set_circuit_state(service, CircuitState.OPEN)
                    logger.warning(
                        f"Circuit breaker re-OPENED for {service} "
                        f"(probe failed in half-open)"
                    )

        except Exception as e:
            logger.warning(f"Circuit breaker failure recording failed: {e}")

    def _set_circuit_state(self, service: str, state: CircuitState) -> None:
        """Set circuit breaker state in Redis."""
        state_key = f"circuit:{service}:state"
        state_data = {
            "state": state.value,
            "updated_at": time.time(),
        }
        if state == CircuitState.OPEN:
            state_data["opened_at"] = time.time()
        elif state == CircuitState.HALF_OPEN:
            state_data["half_open_successes"] = 0

        self.redis.setex(state_key, 300, json.dumps(state_data))

    def get_circuit_status(self, service: str) -> dict:
        """Return current circuit breaker status for monitoring."""
        try:
            state_key = f"circuit:{service}:state"
            failure_key = f"circuit:{service}:failures"

            state_raw = self.redis.get(state_key)
            state = "closed"
            if state_raw:
                state = json.loads(state_raw).get("state", "closed")

            now = time.time()
            cutoff = now - self.circuit_config.failure_window_seconds
            failure_count = self.redis.zcount(failure_key, cutoff, "+inf")

            return {
                "service": service,
                "state": state,
                "recent_failures": failure_count,
                "failure_threshold": self.circuit_config.failure_threshold,
                "recovery_timeout_s": self.circuit_config.recovery_timeout_seconds,
            }
        except Exception as e:
            return {"service": service, "state": "unknown", "error": str(e)}

StreamingAPIHandler

"""
MangaAssist Streaming API Handler
Relays Bedrock streaming responses through API Gateway WebSocket to clients.
"""

import json
import time
import logging
from dataclasses import dataclass, field
from typing import Optional

import boto3
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)


@dataclass
class StreamingConfig:
    """Configuration for streaming behavior."""
    max_frame_size_bytes: int = 32_768      # API Gateway 32KB frame limit
    heartbeat_interval_seconds: float = 5.0  # Keep-alive ping interval
    max_stream_duration_seconds: float = 60.0  # Hard cap on stream duration
    buffer_flush_interval_ms: float = 50.0   # Minimum ms between sends
    enable_backpressure: bool = True
    metrics_enabled: bool = True


@dataclass
class StreamSession:
    """Tracks the state of an active streaming session."""
    connection_id: str
    session_id: str
    request_id: str
    start_time: float = 0.0
    first_chunk_time: float = 0.0
    last_send_time: float = 0.0
    chunks_sent: int = 0
    bytes_sent: int = 0
    errors: list = field(default_factory=list)
    cancelled: bool = False


class StreamingAPIHandler:
    """
    Manages the relay of streaming Bedrock responses to WebSocket clients.

    Responsibilities:
    - Chunk Bedrock stream events into WebSocket-safe frames (<=32KB)
    - Send heartbeat pings during long generation pauses
    - Handle client disconnections mid-stream (cancel Bedrock call)
    - Track streaming metrics (TTFT, TPS, bytes sent)
    - Manage backpressure when client cannot keep up

    Heartbeat and backpressure settings live in StreamingConfig; their
    send loops are elided from this excerpt.

    Usage:
        handler = StreamingAPIHandler(
            apigw_endpoint="https://abc123.execute-api.ap-northeast-1.amazonaws.com/prod",
            redis_client=redis,
        )
        handler.relay_stream(
            connection_id="abc123",
            session_id="session-456",
            request_id="req-789",
            bedrock_stream=response["body"],
        )
    """

    def __init__(
        self,
        apigw_endpoint: str,
        redis_client,
        config: Optional[StreamingConfig] = None,
    ):
        self.config = config or StreamingConfig()
        self.redis = redis_client

        # API Gateway Management API client for postToConnection
        self.apigw_mgmt = boto3.client(
            "apigatewaymanagementapi",
            endpoint_url=apigw_endpoint,
        )

    def relay_stream(
        self,
        connection_id: str,
        session_id: str,
        request_id: str,
        bedrock_stream,
        fm_request: Optional["FMRequest"] = None,
    ) -> StreamSession:
        """
        Relay a Bedrock streaming response to a WebSocket client.

        Reads events from the Bedrock stream, formats them as JSON,
        and sends them via API Gateway postToConnection API.

        Returns StreamSession with metrics.
        """
        stream_session = StreamSession(
            connection_id=connection_id,
            session_id=session_id,
            request_id=request_id,
            start_time=time.time(),
        )

        text_buffer = []
        total_text = []
        input_tokens = 0
        output_tokens = 0

        try:
            for event in bedrock_stream:
                # Check if client disconnected
                if self._is_disconnected(connection_id):
                    stream_session.cancelled = True
                    logger.info(
                        f"Stream cancelled — client disconnected: {connection_id}"
                    )
                    break

                # Check stream duration limit
                elapsed = time.time() - stream_session.start_time
                if elapsed > self.config.max_stream_duration_seconds:
                    logger.warning(
                        f"Stream duration limit reached: {elapsed:.1f}s"
                    )
                    break

                chunk = event.get("chunk")
                if not chunk:
                    continue

                chunk_data = json.loads(chunk.get("bytes", b"{}"))
                chunk_type = chunk_data.get("type")

                if chunk_type == "content_block_delta":
                    delta = chunk_data.get("delta", {})
                    text = delta.get("text", "")

                    if text:
                        now = time.time()
                        if stream_session.first_chunk_time == 0:
                            stream_session.first_chunk_time = now

                        text_buffer.append(text)
                        total_text.append(text)

                        # Flush buffer based on timing or size
                        buffer_text = "".join(text_buffer)
                        time_since_last = (now - stream_session.last_send_time) * 1000

                        should_flush = (
                            time_since_last >= self.config.buffer_flush_interval_ms
                            or len(buffer_text.encode("utf-8")) > self.config.max_frame_size_bytes // 2
                        )

                        if should_flush:
                            self._send_chunk(
                                stream_session,
                                buffer_text,
                                stream_session.chunks_sent,
                            )
                            text_buffer = []

                elif chunk_type == "message_start":
                    msg_usage = chunk_data.get("message", {}).get("usage", {})
                    input_tokens = msg_usage.get("input_tokens", 0)

                elif chunk_type == "message_delta":
                    usage = chunk_data.get("usage", {})
                    output_tokens = usage.get("output_tokens", 0)

                elif chunk_type == "message_stop":
                    # Flush remaining buffer
                    if text_buffer:
                        self._send_chunk(
                            stream_session,
                            "".join(text_buffer),
                            stream_session.chunks_sent,
                        )
                        text_buffer = []

            # Send completion message with metrics
            full_text = "".join(total_text)
            ttft = (
                (stream_session.first_chunk_time - stream_session.start_time) * 1000
                if stream_session.first_chunk_time
                else 0
            )
            total_latency = (time.time() - stream_session.start_time) * 1000
            duration_s = time.time() - (stream_session.first_chunk_time or stream_session.start_time)
            tps = output_tokens / duration_s if duration_s > 0 else 0

            # Calculate cost (ModelTier and MODEL_PRICING are defined in
            # the FM interface module earlier in this document)
            model_tier = ModelTier.SONNET  # Default; real impl gets from request
            if fm_request:
                model_tier = fm_request.config.model_tier
            pricing = MODEL_PRICING[model_tier]
            cost = (
                (input_tokens / 1_000_000) * pricing["input"]
                + (output_tokens / 1_000_000) * pricing["output"]
            )

            done_payload = {
                "type": "done",
                "requestId": request_id,
                "tokens": {
                    "input": input_tokens,
                    "output": output_tokens,
                },
                "metrics": {
                    "ttft_ms": round(ttft),
                    "total_ms": round(total_latency),
                    "tps": round(tps, 1),
                    "chunks": stream_session.chunks_sent,
                },
                "cost_usd": round(cost, 6),
            }

            # Skip the completion frame if the client already disconnected
            if not stream_session.cancelled:
                self._post_to_connection(
                    connection_id,
                    json.dumps(done_payload, ensure_ascii=False),
                )

            logger.info(
                "Stream relay completed",
                extra={
                    "connection_id": connection_id,
                    "request_id": request_id,
                    "ttft_ms": round(ttft),
                    "total_ms": round(total_latency),
                    "tps": round(tps, 1),
                    "chunks": stream_session.chunks_sent,
                    "bytes": stream_session.bytes_sent,
                    "cancelled": stream_session.cancelled,
                },
            )

        except ClientError as e:
            if e.response["Error"]["Code"] == "GoneException":
                stream_session.cancelled = True
                logger.info(f"Client gone during stream: {connection_id}")
            else:
                stream_session.errors.append(str(e))
                logger.error(f"Stream relay error: {e}")
                self._send_error(connection_id, "Stream error occurred")

        except Exception as e:
            stream_session.errors.append(str(e))
            logger.error(f"Unexpected stream error: {e}", exc_info=True)
            self._send_error(connection_id, "Internal stream error")

        return stream_session

    def _send_chunk(
        self,
        session: StreamSession,
        text: str,
        chunk_index: int,
    ) -> None:
        """Send a text chunk to the WebSocket client."""
        payload = {
            "type": "chunk",
            "text": text,
            "index": chunk_index,
            "requestId": session.request_id,
        }

        encoded = json.dumps(payload, ensure_ascii=False)

        # Check frame size limit
        encoded_bytes = encoded.encode("utf-8")
        if len(encoded_bytes) > self.config.max_frame_size_bytes:
            # Split into multiple frames
            self._send_chunked(session, text, chunk_index)
            return

        self._post_to_connection(session.connection_id, encoded)
        session.chunks_sent += 1
        session.bytes_sent += len(encoded_bytes)
        session.last_send_time = time.time()

    def _send_chunked(
        self,
        session: StreamSession,
        text: str,
        base_index: int,
    ) -> None:
        """Split oversized text into multiple WebSocket frames."""
        max_text_bytes = self.config.max_frame_size_bytes - 200  # Reserve for JSON wrapper
        encoded_text = text.encode("utf-8")
        offset = 0
        sub_index = 0

        while offset < len(encoded_text):
            end = min(offset + max_text_bytes, len(encoded_text))
            # Never split a multi-byte UTF-8 character: continuation bytes
            # match 0b10xxxxxx, so back the cut point up to a char boundary
            while end < len(encoded_text) and (encoded_text[end] & 0xC0) == 0x80:
                end -= 1
            chunk_bytes = encoded_text[offset:end]
            chunk_text = chunk_bytes.decode("utf-8")

            payload = {
                "type": "chunk",
                "text": chunk_text,
                "index": base_index,
                "subIndex": sub_index,
                "requestId": session.request_id,
            }

            self._post_to_connection(
                session.connection_id,
                json.dumps(payload, ensure_ascii=False),
            )

            session.chunks_sent += 1
            session.bytes_sent += len(chunk_bytes)
            session.last_send_time = time.time()

            offset = end
            sub_index += 1

    def _post_to_connection(self, connection_id: str, data: str) -> None:
        """Send data to a WebSocket connection via API Gateway Management API."""
        try:
            self.apigw_mgmt.post_to_connection(
                ConnectionId=connection_id,
                Data=data.encode("utf-8"),
            )
        except ClientError as e:
            if e.response["Error"]["Code"] == "GoneException":
                logger.warning(f"Stale connection: {connection_id}")
                self._mark_disconnected(connection_id)
            else:
                raise

    def _send_error(self, connection_id: str, message: str) -> None:
        """Send an error message to the WebSocket client."""
        try:
            payload = json.dumps({
                "type": "error",
                "message": message,
            })
            self._post_to_connection(connection_id, payload)
        except Exception:
            pass  # Best effort

    def _is_disconnected(self, connection_id: str) -> bool:
        """Check if a connection has been marked as disconnected."""
        try:
            return self.redis.get(f"disconnected:{connection_id}") is not None
        except Exception:
            return False

    def _mark_disconnected(self, connection_id: str) -> None:
        """Mark a connection as disconnected in Redis."""
        try:
            self.redis.setex(f"disconnected:{connection_id}", 60, "1")
        except Exception:
            pass
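
The _is_disconnected check above only works if something writes the disconnected:{connection_id} marker. A natural place is the WebSocket $disconnect route handler; a minimal sketch, assuming the Redis endpoint arrives via an environment variable (the handler name and variable are illustrative):

import os

import redis

_redis = redis.Redis(host=os.environ["REDIS_HOST"], port=6379)

def disconnect_handler(event, context):
    """API Gateway WebSocket $disconnect route Lambda."""
    connection_id = event["requestContext"]["connectionId"]
    # 60s TTL mirrors _mark_disconnected; the relay loop polls this key
    # on every stream event and stops reading from Bedrock when it appears
    _redis.setex(f"disconnected:{connection_id}", 60, "1")
    return {"statusCode": 200}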

Utility Functions

"""
Shared utility functions for the FM API Interface layer.
"""


def estimate_token_count(text: str, language: str = "ja") -> int:
    """
    Estimate token count for text.

    Japanese text: ~1 token per 1-2 characters (conservative: 0.7 chars/token)
    English text: ~1 token per 4 characters

    For production accuracy, use Anthropic's token counting API or
    the tiktoken library as a proxy.
    """
    if language == "ja":
        # Code points above U+3000 cover hiragana, katakana, kanji, and
        # fullwidth forms; everything else is treated as ASCII/Latin
        jp_chars = sum(1 for c in text if ord(c) > 0x3000)
        en_chars = len(text) - jp_chars
        return int(jp_chars * 1.4) + (en_chars // 4)
    return len(text) // 4


def format_cost_usd(cost: float) -> str:
    """Format a cost value as a USD string."""
    if cost < 0.01:
        return f"${cost:.6f}"
    elif cost < 1.00:
        return f"${cost:.4f}"
    else:
        return f"${cost:.2f}"

Key Takeaways

| # | Takeaway | MangaAssist Application |
|---|----------|-------------------------|
| 1 | WebSocket APIs bypass the 29-second REST timeout — API Gateway WebSocket routes provide persistent connections for streaming token-by-token from Bedrock without per-message time limits. | MangaAssist uses WebSocket for all chat interactions, achieving <500ms time-to-first-token for Claude 3 responses. REST is reserved for health checks and admin APIs. |
| 2 | Token budgets must be enforced at four levels — per-request, per-session, daily per-user, and model context window. Missing any level creates a cost explosion vector at scale. | At 1M messages/day, the difference between uncontrolled and budget-enforced spend is $38,147/day ($1.14M/month). TokenLimitEnforcer validates all four levels before every Bedrock call. |
| 3 | Exponential backoff with jitter prevents retry storms — without jitter, synchronized retries from multiple clients create thundering-herd effects that amplify Bedrock throttling. | RetryStrategyManager adds random jitter (0-1s) on top of the exponential delay, capped at 10 seconds. |
| 4 | Circuit breakers protect against cascading failures — when Bedrock experiences sustained errors, continuing to retry wastes time and money. Circuit breakers fail fast and allow recovery. | Separate circuit breakers for Sonnet and Haiku mean a Sonnet outage does not prevent simple Haiku queries from succeeding. |
| 5 | Idempotency keys prevent duplicate Bedrock invocations — WebSocket reconnects, mobile double-taps, and retry logic can all send the same message twice. Content-hash dedup catches these. | A 5-second dedup window keyed on a SHA-256 hash of session+message catches rapid duplicates without blocking legitimate repeated questions (see the sketch below this table). |
| 6 | Model routing saves 70%+ on costs — automatically routing simple queries (greetings, stock checks, yes/no) to Haiku ($0.25/$1.25 per 1M) instead of Sonnet ($3/$15) is the single largest cost optimization. | FMAPIInterface auto-routes based on message complexity heuristics. A 70% Haiku / 30% Sonnet blend reduces daily cost from $45K to $6.8K. |
| 7 | Streaming metrics (TTFT, TPS) are essential observability — time-to-first-token and tokens-per-second reveal whether latency originates in Bedrock, the network, or the relay layer, enabling targeted optimization. | A CloudWatch custom-metrics dashboard shows P50/P95/P99 TTFT with alarms when P95 exceeds 800ms. Every stream logs TTFT, TPS, and cost. |
7 Streaming metrics (TTFT, TPS) are essential observability — Time-to-first-token and tokens-per-second reveal whether latency originates in Bedrock, the network, or the relay layer, enabling targeted optimization. CloudWatch custom metrics dashboard shows P50/P95/P99 TTFT with alarms when P95 exceeds 800ms. Every stream logs TTFT, TPS, and cost.