
Streaming Response Handling, Token Management, and Retry Patterns

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.
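
At those prices, a back-of-envelope model makes the Sonnet/Haiku trade-off concrete. The per-message token counts below (2,000 input / 500 output) are illustrative assumptions, not measured MangaAssist figures:

```python
# Back-of-envelope cost model at the Bedrock prices quoted above.
PRICING = {  # USD per 1M tokens: (input, output)
    "sonnet": (3.00, 15.00),
    "haiku": (0.25, 1.25),
}

def cost_per_message(model: str, input_tokens: int, output_tokens: int) -> float:
    """Cost in USD for a single request at the quoted per-1M-token prices."""
    in_rate, out_rate = PRICING[model]
    return (input_tokens / 1_000_000) * in_rate + (output_tokens / 1_000_000) * out_rate

sonnet = cost_per_message("sonnet", 2_000, 500)  # ≈ $0.0135/message
haiku = cost_per_message("haiku", 2_000, 500)    # ≈ $0.001125/message
```

At 1M messages/day those assumptions put Sonnet around $13.5K/day versus roughly $1.1K/day for Haiku, which is why model selection dominates the cost profile at this scale.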


Skill Mapping

Field Value
Certification AWS AIP-C01
Domain 2 — Implementation & Integration
Task 2.5 — Application Integration Patterns
Skill 2.5.1 — Create FM API interfaces to address the specific requirements of GenAI workloads
Deep-Dive Focus Chunked streaming via WebSocket, token counting middleware, adaptive retry timing, connection pooling, request queuing

Overview

This document is a deep-dive companion to 01-fm-api-interface-architecture.md. Where the architecture document defines the components (FMAPIInterface, TokenLimitEnforcer, RetryStrategyManager, StreamingAPIHandler), this document examines the three hardest operational problems in GenAI API interfaces:

  1. Streaming — How chunks flow from Bedrock through ECS Fargate through API Gateway WebSocket to the browser, handling backpressure, reconnection, and partial failures
  2. Token management — How middleware intercepts every request to count tokens, enforce budgets, and track costs in real time
  3. Retry patterns — How adaptive retry timing avoids thundering herds while maximizing request success rate under Bedrock throttling

Part 1: Chunked Streaming via WebSocket

The Streaming Pipeline — End to End

graph LR
    subgraph Bedrock["Amazon Bedrock"]
        MODEL[Claude 3 Sonnet/Haiku]
        STREAM_OUT[InvokeModelWithResponseStream<br/>yields content_block_delta events]
    end

    subgraph Fargate["ECS Fargate Orchestrator"]
        CONSUMER[Stream Consumer<br/>reads event stream]
        TOKENIZER[Token Counter<br/>real-time output tracking]
        CHUNKER[Chunk Aggregator<br/>buffers small deltas]
        METRICS[Latency Tracker<br/>TTFT + TPS metrics]
        BACKPRESSURE[Backpressure Monitor<br/>pause if client slow]
    end

    subgraph APIGW["API Gateway"]
        MGMT_API[Management API<br/>postToConnection]
        FRAME[Frame Encoder<br/>JSON + UTF-8 <= 32KB]
    end

    subgraph Client["Browser / Mobile"]
        WS_CLIENT[WebSocket Client<br/>onmessage handler]
        RENDERER[Progressive Renderer<br/>append text as it arrives]
        RECONNECT[Reconnect Handler<br/>exponential backoff]
    end

    MODEL --> STREAM_OUT
    STREAM_OUT --> CONSUMER
    CONSUMER --> TOKENIZER
    CONSUMER --> CHUNKER
    CONSUMER --> METRICS
    CHUNKER --> BACKPRESSURE
    BACKPRESSURE --> FRAME
    FRAME --> MGMT_API
    MGMT_API --> WS_CLIENT
    WS_CLIENT --> RENDERER
    RECONNECT -.->|on disconnect| MGMT_API

    style MODEL fill:#ff9900,color:#000
    style MGMT_API fill:#527fff,color:#fff
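
The Reconnect Handler box above relies on exponential backoff. A minimal sketch of capped exponential backoff with full jitter — the function name and constants are illustrative, not from the MangaAssist codebase:

```python
import random

def reconnect_delay(attempt: int, base_s: float = 0.5, cap_s: float = 30.0) -> float:
    """Delay before reconnect attempt `attempt` (0-based).

    Capped exponential backoff with full jitter, so a fleet of clients
    dropped by the same network blip does not reconnect in lockstep.
    """
    ceiling = min(cap_s, base_s * (2 ** attempt))  # 0.5, 1, 2, 4, ... capped at 30
    return random.uniform(0.0, ceiling)
```

A client would typically reset the attempt counter after a successful reconnect.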

Why Aggregate Chunks Before Sending

Bedrock's InvokeModelWithResponseStream emits very small deltas — often a single word or even a partial word. Sending each delta as a separate WebSocket frame would:

  • Overwhelm the client with hundreds of tiny messages per second
  • Hit API Gateway's per-connection rate limit (documented as "several hundred" messages/second)
  • Increase overhead since each frame has JSON wrapper + network overhead
  • Fragment Japanese text mid-character across frame boundaries
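
The mid-character hazard is concrete with kanji, which occupy 3 bytes each in UTF-8: a split at an arbitrary byte offset can land inside a character. A minimal illustration of boundary-safe splitting (the `safe_split` helper is illustrative):

```python
text = "漫画推薦"                 # four kanji, 3 bytes each in UTF-8
raw = text.encode("utf-8")       # 12 bytes

# A naive cut at byte 7 lands inside "推" and is not valid UTF-8:
# raw[:7].decode("utf-8") raises UnicodeDecodeError

def safe_split(raw: bytes, limit: int) -> tuple[str, int]:
    """Decode at most `limit` bytes, trimming any trailing partial
    character, and report how many bytes were actually consumed."""
    piece = raw[:limit].decode("utf-8", errors="ignore")
    return piece, len(piece.encode("utf-8"))

piece, consumed = safe_split(raw, 7)
# piece == "漫画", consumed == 6 — the partial "推" is carried to the next frame
```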

The solution is a chunk aggregator that buffers deltas and flushes on whichever comes first: a time interval (typically 50-200ms) or a size threshold (a few kilobytes).

Chunk Aggregation Timing

Bedrock output rate: ~40-80 tokens/second for Sonnet
Average token size: ~4 bytes English, ~3 bytes Japanese (UTF-8)
Bytes per second: ~120-240 bytes/second of text (at ~3 bytes/token)

Aggregation strategy:
  - Flush every 50ms → ~6-12 bytes per frame (too small)
  - Flush every 100ms → ~12-24 bytes per frame (still small but acceptable)
  - Flush every 200ms → ~24-48 bytes per frame (good balance)
  - Flush when buffer > 8KB → handles burst output

MangaAssist setting: flush every 100ms OR when buffer > 4KB
  → Typical frame: 12-24 bytes of text + ~150 bytes JSON wrapper
  → ~5-10 frames per second to client
  → Well within API Gateway rate limits
  → Smooth progressive rendering in browser
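
The arithmetic behind the figures above can be written down directly (the helper name is illustrative):

```python
def frame_profile(tokens_per_s: float, bytes_per_token: float,
                  flush_interval_ms: float) -> tuple[float, float]:
    """Return (frames_per_second, text_bytes_per_frame) for a given
    output rate and time-based flush interval."""
    frames_per_s = 1000.0 / flush_interval_ms
    bytes_per_s = tokens_per_s * bytes_per_token
    return frames_per_s, bytes_per_s / frames_per_s

# Sonnet at 40-80 tok/s, ~3 bytes/token, flushing every 100ms:
frame_profile(40, 3, 100)   # (10.0, 12.0) → 10 frames/s, 12 bytes of text each
frame_profile(80, 3, 100)   # (10.0, 24.0) → 10 frames/s, 24 bytes of text each
```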

StreamingMiddleware Implementation

"""
MangaAssist Streaming Middleware
Sits between the Bedrock stream consumer and the WebSocket relay layer.
Handles chunk aggregation, backpressure detection, heartbeats, and stream lifecycle.
"""

import json
import time
import logging
from dataclasses import dataclass, field
from typing import Optional, Callable
from enum import Enum

import boto3
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)


class StreamState(Enum):
    """States of a streaming session."""
    INITIALIZING = "initializing"
    STREAMING = "streaming"
    PAUSED = "paused"         # Backpressure — waiting for client to catch up
    COMPLETING = "completing"  # Final flush + done message
    COMPLETED = "completed"
    CANCELLED = "cancelled"   # Client disconnected
    ERROR = "error"


@dataclass
class ChunkBuffer:
    """
    Accumulates small Bedrock deltas and flushes them as aggregated chunks.
    Prevents overwhelming the WebSocket with hundreds of tiny frames.
    """
    texts: list = field(default_factory=list)
    byte_count: int = 0
    last_flush_time: float = 0.0
    flush_interval_ms: float = 100.0      # Flush every 100ms
    max_buffer_bytes: int = 4096          # Flush when buffer exceeds 4KB
    total_flushed_chunks: int = 0
    total_flushed_bytes: int = 0

    def add(self, text: str) -> None:
        """Add a text delta to the buffer."""
        self.texts.append(text)
        self.byte_count += len(text.encode("utf-8"))

    def should_flush(self) -> bool:
        """Determine if the buffer should be flushed now."""
        if not self.texts:
            return False

        # Size-based flush
        if self.byte_count >= self.max_buffer_bytes:
            return True

        # Time-based flush
        now = time.time()
        elapsed_ms = (now - self.last_flush_time) * 1000
        if elapsed_ms >= self.flush_interval_ms:
            return True

        return False

    def flush(self) -> Optional[str]:
        """Flush the buffer and return aggregated text, or None if empty."""
        if not self.texts:
            return None

        aggregated = "".join(self.texts)
        aggregated_bytes = len(aggregated.encode("utf-8"))

        self.total_flushed_chunks += 1
        self.total_flushed_bytes += aggregated_bytes

        self.texts = []
        self.byte_count = 0
        self.last_flush_time = time.time()

        return aggregated

    @property
    def is_empty(self) -> bool:
        return len(self.texts) == 0


@dataclass
class StreamMetricsCollector:
    """Collects real-time metrics during a streaming session."""
    stream_start: float = 0.0
    first_token_time: float = 0.0
    last_token_time: float = 0.0
    input_tokens: int = 0
    output_tokens: int = 0
    chunks_from_bedrock: int = 0
    chunks_to_client: int = 0
    bytes_to_client: int = 0
    backpressure_pauses: int = 0
    total_pause_duration_ms: float = 0.0
    reconnections: int = 0

    @property
    def time_to_first_token_ms(self) -> float:
        """Time from request start to first token received from Bedrock."""
        if self.first_token_time and self.stream_start:
            return (self.first_token_time - self.stream_start) * 1000
        return 0.0

    @property
    def tokens_per_second(self) -> float:
        """Output token generation rate."""
        if self.first_token_time and self.last_token_time:
            duration = self.last_token_time - self.first_token_time
            if duration > 0:
                return self.output_tokens / duration
        return 0.0

    @property
    def total_latency_ms(self) -> float:
        """Total time from start to last token."""
        if self.last_token_time and self.stream_start:
            return (self.last_token_time - self.stream_start) * 1000
        return 0.0

    @property
    def compression_ratio(self) -> float:
        """Ratio of Bedrock chunks to client frames — higher = more aggregation."""
        if self.chunks_to_client > 0:
            return self.chunks_from_bedrock / self.chunks_to_client
        return 0.0

    def to_dict(self) -> dict:
        """Export metrics as a dictionary for logging and CloudWatch."""
        return {
            "ttft_ms": round(self.time_to_first_token_ms),
            "tps": round(self.tokens_per_second, 1),
            "total_latency_ms": round(self.total_latency_ms),
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "chunks_from_bedrock": self.chunks_from_bedrock,
            "chunks_to_client": self.chunks_to_client,
            "bytes_to_client": self.bytes_to_client,
            "compression_ratio": round(self.compression_ratio, 1),
            "backpressure_pauses": self.backpressure_pauses,
            "pause_duration_ms": round(self.total_pause_duration_ms),
            "reconnections": self.reconnections,
        }


class StreamingMiddleware:
    """
    Middleware that sits between the Bedrock stream and the WebSocket client.

    Features:
    - Chunk aggregation with time-based and size-based flushing
    - Backpressure detection (monitors postToConnection failures)
    - Heartbeat pings during pauses in Bedrock output
    - Graceful stream cancellation on client disconnect
    - Real-time metrics collection (TTFT, TPS, compression ratio)
    - Japanese text boundary awareness (never split mid-character)

    Architecture:
        Bedrock Stream → StreamingMiddleware → API Gateway postToConnection → Client

    Usage:
        middleware = StreamingMiddleware(
            apigw_endpoint="https://...",
            redis_client=redis,
        )
        metrics = middleware.process_stream(
            connection_id="abc",
            session_id="sess-123",
            request_id="req-456",
            bedrock_stream=response["body"],
        )
    """

    HEARTBEAT_INTERVAL_S = 5.0   # Send heartbeat if no data for 5s
    MAX_STREAM_DURATION_S = 120.0  # Hard kill after 2 minutes
    BACKPRESSURE_PAUSE_S = 0.5   # Pause duration on send failure
    MAX_BACKPRESSURE_PAUSES = 10  # Cancel stream after this many pauses
    MAX_FRAME_BYTES = 32_768     # API Gateway WebSocket frame limit

    def __init__(
        self,
        apigw_endpoint: str,
        redis_client,
        chunk_buffer_config: Optional[dict] = None,
    ):
        self.redis = redis_client
        self.apigw_mgmt = boto3.client(
            "apigatewaymanagementapi",
            endpoint_url=apigw_endpoint,
        )
        self._buffer_config = chunk_buffer_config or {}

    def process_stream(
        self,
        connection_id: str,
        session_id: str,
        request_id: str,
        bedrock_stream,
        on_complete: Optional[Callable] = None,
    ) -> StreamMetricsCollector:
        """
        Process a Bedrock stream and relay to WebSocket client.

        Args:
            connection_id: API Gateway WebSocket connection ID
            session_id: MangaAssist session ID
            request_id: Unique request identifier for idempotency
            bedrock_stream: The Bedrock response stream object
            on_complete: Optional callback when stream finishes

        Returns:
            StreamMetricsCollector with performance data
        """
        metrics = StreamMetricsCollector(stream_start=time.time())
        buffer = ChunkBuffer(**self._buffer_config)
        state = StreamState.INITIALIZING
        full_text_parts = []

        try:
            state = StreamState.STREAMING

            for event in bedrock_stream:
                # Check for client disconnect
                if self._is_client_gone(connection_id):
                    state = StreamState.CANCELLED
                    logger.info(f"Stream cancelled — client gone: {connection_id}")
                    break

                # Check stream duration limit
                elapsed = time.time() - metrics.stream_start
                if elapsed > self.MAX_STREAM_DURATION_S:
                    logger.warning(f"Stream duration limit: {elapsed:.1f}s")
                    break

                chunk = event.get("chunk")
                if not chunk:
                    continue

                chunk_data = json.loads(chunk.get("bytes", b"{}"))
                chunk_type = chunk_data.get("type")

                # ─── Handle content_block_delta (text chunks) ──────────
                if chunk_type == "content_block_delta":
                    delta = chunk_data.get("delta", {})
                    text = delta.get("text", "")

                    if text:
                        now = time.time()
                        metrics.chunks_from_bedrock += 1

                        if metrics.first_token_time == 0:
                            metrics.first_token_time = now

                        metrics.last_token_time = now
                        buffer.add(text)
                        full_text_parts.append(text)

                        # Flush if buffer is ready
                        if buffer.should_flush():
                            aggregated = buffer.flush()
                            if aggregated:
                                success = self._send_text_chunk(
                                    connection_id,
                                    request_id,
                                    aggregated,
                                    metrics.chunks_to_client,
                                )
                                if success:
                                    metrics.chunks_to_client += 1
                                    metrics.bytes_to_client += len(
                                        aggregated.encode("utf-8")
                                    )
                                else:
                                    # Backpressure detected
                                    metrics.backpressure_pauses += 1
                                    if metrics.backpressure_pauses > self.MAX_BACKPRESSURE_PAUSES:
                                        state = StreamState.CANCELLED
                                        logger.warning(
                                            f"Stream cancelled — too many "
                                            f"backpressure pauses: {connection_id}"
                                        )
                                        break
                                    pause_start = time.time()
                                    time.sleep(self.BACKPRESSURE_PAUSE_S)
                                    metrics.total_pause_duration_ms += (
                                        (time.time() - pause_start) * 1000
                                    )

                # ─── Handle message_start (input token count) ──────────
                elif chunk_type == "message_start":
                    msg = chunk_data.get("message", {})
                    usage = msg.get("usage", {})
                    metrics.input_tokens = usage.get("input_tokens", 0)

                # ─── Handle message_delta (output token count) ─────────
                elif chunk_type == "message_delta":
                    usage = chunk_data.get("usage", {})
                    metrics.output_tokens = usage.get("output_tokens", 0)

                # ─── Handle message_stop (stream complete) ─────────────
                elif chunk_type == "message_stop":
                    state = StreamState.COMPLETING

            # Final buffer flush
            if not buffer.is_empty:
                remaining = buffer.flush()
                if remaining:
                    sent = self._send_text_chunk(
                        connection_id, request_id,
                        remaining, metrics.chunks_to_client,
                    )
                    # Only count the final chunk if the send succeeded
                    if sent:
                        metrics.chunks_to_client += 1
                        metrics.bytes_to_client += len(remaining.encode("utf-8"))

            # Send completion message
            if state != StreamState.CANCELLED:
                self._send_done_message(
                    connection_id, request_id, metrics,
                    "".join(full_text_parts),
                )
                state = StreamState.COMPLETED

            if on_complete:
                on_complete(metrics, "".join(full_text_parts))

        except ClientError as e:
            if e.response["Error"]["Code"] == "GoneException":
                state = StreamState.CANCELLED
                logger.info(f"Client gone during stream: {connection_id}")
            else:
                state = StreamState.ERROR
                logger.error(f"Stream error: {e}")

        except Exception as e:
            state = StreamState.ERROR
            logger.error(f"Unexpected stream error: {e}", exc_info=True)
            self._send_error_message(connection_id, "Stream processing error")

        # Log final metrics
        logger.info(
            f"Stream finished: state={state.value}",
            extra={
                "connection_id": connection_id,
                "request_id": request_id,
                "state": state.value,
                **metrics.to_dict(),
            },
        )

        return metrics

    def _send_text_chunk(
        self,
        connection_id: str,
        request_id: str,
        text: str,
        chunk_index: int,
    ) -> bool:
        """Send a text chunk to the client. Returns True on success."""
        payload = {
            "type": "chunk",
            "text": text,
            "index": chunk_index,
            "requestId": request_id,
        }

        encoded = json.dumps(payload, ensure_ascii=False).encode("utf-8")

        # Split if exceeds frame limit
        if len(encoded) > self.MAX_FRAME_BYTES:
            return self._send_split_frames(
                connection_id, request_id, text, chunk_index
            )

        try:
            self.apigw_mgmt.post_to_connection(
                ConnectionId=connection_id,
                Data=encoded,
            )
            return True
        except ClientError as e:
            if e.response["Error"]["Code"] == "GoneException":
                self._mark_client_gone(connection_id)
                return False
            logger.warning(f"Send failed: {e}")
            return False

    def _send_split_frames(
        self,
        connection_id: str,
        request_id: str,
        text: str,
        chunk_index: int,
    ) -> bool:
        """Split oversized text across multiple frames."""
        max_text_bytes = self.MAX_FRAME_BYTES - 256  # Reserve for JSON wrapper
        text_bytes = text.encode("utf-8")
        offset = 0
        sub_index = 0
        all_sent = True

        while offset < len(text_bytes):
            segment = text_bytes[offset:offset + max_text_bytes]
            # Decode with errors="ignore" so a trailing partial multi-byte
            # UTF-8 character is trimmed rather than corrupted
            segment_text = segment.decode("utf-8", errors="ignore")
            if not segment_text:
                break

            payload = json.dumps({
                "type": "chunk",
                "text": segment_text,
                "index": chunk_index,
                "subIndex": sub_index,
                "requestId": request_id,
            }, ensure_ascii=False).encode("utf-8")

            try:
                self.apigw_mgmt.post_to_connection(
                    ConnectionId=connection_id,
                    Data=payload,
                )
            except ClientError:
                all_sent = False
                break

            # Advance by the bytes actually sent, so any trimmed partial
            # character is re-sent at the start of the next frame
            offset += len(segment_text.encode("utf-8"))
            sub_index += 1

        return all_sent

    def _send_done_message(
        self,
        connection_id: str,
        request_id: str,
        metrics: StreamMetricsCollector,
        full_text: str,
    ) -> None:
        """Send stream completion message with metrics."""
        payload = {
            "type": "done",
            "requestId": request_id,
            "tokens": {
                "input": metrics.input_tokens,
                "output": metrics.output_tokens,
            },
            "metrics": {
                "ttft_ms": round(metrics.time_to_first_token_ms),
                "total_ms": round(metrics.total_latency_ms),
                "tps": round(metrics.tokens_per_second, 1),
                "chunks": metrics.chunks_to_client,
            },
            "textLength": len(full_text),
        }

        try:
            self.apigw_mgmt.post_to_connection(
                ConnectionId=connection_id,
                Data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
            )
        except ClientError:
            logger.warning(f"Failed to send done message: {connection_id}")

    def _send_error_message(self, connection_id: str, message: str) -> None:
        """Send an error message to the client."""
        try:
            self.apigw_mgmt.post_to_connection(
                ConnectionId=connection_id,
                Data=json.dumps({
                    "type": "error",
                    "message": message,
                }).encode("utf-8"),
            )
        except Exception:
            pass

    def _is_client_gone(self, connection_id: str) -> bool:
        """Check Redis for disconnection marker."""
        try:
            return self.redis.get(f"disconnected:{connection_id}") is not None
        except Exception:
            return False

    def _mark_client_gone(self, connection_id: str) -> None:
        """Mark a connection as gone in Redis."""
        try:
            self.redis.setex(f"disconnected:{connection_id}", 120, "1")
        except Exception:
            pass
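
Developing against the event schema is easier with a synthetic stream. The sketch below replays Bedrock-style events (`message_start`, `content_block_delta`, `message_delta`, `message_stop`) through a parsing loop that mirrors `process_stream`; it is a test-harness sketch, not production code:

```python
import json

def consume(stream) -> dict:
    """Replay Bedrock-style events and collect text plus token usage."""
    text, usage = [], {"input": 0, "output": 0}
    for event in stream:
        data = json.loads(event["chunk"]["bytes"])
        kind = data.get("type")
        if kind == "message_start":
            usage["input"] = data["message"]["usage"].get("input_tokens", 0)
        elif kind == "content_block_delta":
            text.append(data["delta"].get("text", ""))
        elif kind == "message_delta":
            usage["output"] = data["usage"].get("output_tokens", 0)
    return {"text": "".join(text), **usage}

fake_stream = [
    {"chunk": {"bytes": json.dumps(
        {"type": "message_start",
         "message": {"usage": {"input_tokens": 42}}}).encode()}},
    {"chunk": {"bytes": json.dumps(
        {"type": "content_block_delta",
         "delta": {"text": "おすすめは"}}).encode()}},
    {"chunk": {"bytes": json.dumps(
        {"type": "content_block_delta",
         "delta": {"text": "ワンピースです"}}).encode()}},
    {"chunk": {"bytes": json.dumps(
        {"type": "message_delta",
         "usage": {"output_tokens": 7}}).encode()}},
    {"chunk": {"bytes": json.dumps({"type": "message_stop"}).encode()}},
]

result = consume(fake_stream)
# result == {"text": "おすすめはワンピースです", "input": 42, "output": 7}
```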

Part 2: Token Counting Middleware

The Token Counting Problem

Token counting in GenAI applications is harder than it appears:

  1. Pre-request estimation — You must know approximately how many input tokens a request will consume before sending it to Bedrock, because Bedrock charges for all input tokens regardless of output length.
  2. Japanese text complexity — Japanese characters consume 1-3 tokens each depending on the script (hiragana, katakana, or kanji), and real queries mix in romaji and English, making estimation harder than for English alone.
  3. Real-time output tracking — During streaming, you need to track output tokens as they arrive to enforce per-request and per-session budgets.
  4. Middleware position — Token counting must happen at the middleware layer, intercepting every request before it reaches Bedrock and every response chunk as it streams back.

Token Estimation Accuracy

| Text Type | Chars per Token | Example | Estimation Method |
| --- | --- | --- | --- |
| English prose | ~4.0 | "manga recommendation" | `len(text) // 4` |
| Japanese hiragana | ~1.5 | "おすすめのマンガ" | `len(text) * 0.7` |
| Japanese kanji | ~1.0 | "漫画推薦" | `len(text) * 1.0` |
| Mixed JP/EN | ~2.0 | "One Pieceのおすすめ" | Weighted average |
| JSON/structured | ~3.5 | `{"title":"..."}` | `len(text) // 3.5` |
| URLs/paths | ~2.0 | `https://manga.store/vol/1` | `len(text) // 2` |

For MangaAssist with its mixed JP/EN content, a blanket estimate of 1 token per 2 characters tracks the mixed-content average; kanji-heavy text can exceed it, which is why the estimator weights its ratio by the detected share of Japanese characters before checking for context-window overflow.
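
The `_detect_japanese_ratio` and `_count_tokens` helpers referenced by the middleware are not shown in this excerpt; a self-contained sketch of the ratio-weighted approach they imply (function names and blend constants are illustrative):

```python
def japanese_ratio(text: str) -> float:
    """Fraction of characters in the kana/kanji Unicode ranges."""
    if not text:
        return 0.0
    jp = sum(
        1 for ch in text
        if "\u3040" <= ch <= "\u30ff"   # hiragana + katakana
        or "\u4e00" <= ch <= "\u9fff"   # common CJK ideographs (kanji)
    )
    return jp / len(text)

def estimate_tokens(text: str) -> int:
    """Blend chars-per-token from ~4.0 (pure English) down to ~1.4
    (pure Japanese) based on the detected Japanese ratio."""
    ratio = japanese_ratio(text)
    chars_per_token = 4.0 - ratio * (4.0 - 1.4)
    return max(1, round(len(text) / chars_per_token))

estimate_tokens("manga recommendation")   # → 5  (20 chars / 4.0)
estimate_tokens("おすすめのマンガ")         # → 6  (8 chars / 1.4)
```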

Token Budget Flow

graph TB
    subgraph PreRequest["Pre-Request Token Check"]
        ESTIMATE[Estimate Input Tokens<br/>system + history + RAG + user msg]
        CHECK_REQ[Check Per-Request Budget<br/>max 4,000 input tokens]
        CHECK_SESSION[Check Session Budget<br/>max 50,000 cumulative input]
        CHECK_DAILY[Check Daily Budget<br/>max $5.00/day per user]
        CHECK_WINDOW[Check Context Window<br/>200K - overhead - output - safety]
    end

    subgraph DuringStream["During Streaming"]
        COUNT_OUT[Count Output Tokens<br/>from message_delta events]
        ENFORCE_OUT[Enforce Output Budget<br/>cancel if exceeds max_tokens]
        TRACK_COST[Track Running Cost<br/>input + output so far]
    end

    subgraph PostRequest["Post-Request Recording"]
        RECORD_ACTUAL[Record Actual Usage<br/>exact tokens from Bedrock]
        UPDATE_SESSION[Update Session Counters<br/>Redis + DynamoDB]
        UPDATE_DAILY[Update Daily Counters<br/>Redis with TTL]
        EMIT_METRIC[Emit CloudWatch Metric<br/>token usage + cost]
    end

    ESTIMATE --> CHECK_REQ
    CHECK_REQ -->|pass| CHECK_SESSION
    CHECK_REQ -->|fail| REJECT[Reject: Token Budget Exceeded]
    CHECK_SESSION -->|pass| CHECK_DAILY
    CHECK_SESSION -->|fail| REJECT
    CHECK_DAILY -->|pass| CHECK_WINDOW
    CHECK_DAILY -->|fail| REJECT
    CHECK_WINDOW -->|pass| INVOKE[Invoke Bedrock]
    CHECK_WINDOW -->|fail| TRUNCATE[Truncate History + Retry]

    INVOKE --> COUNT_OUT
    COUNT_OUT --> ENFORCE_OUT
    ENFORCE_OUT --> TRACK_COST

    TRACK_COST --> RECORD_ACTUAL
    RECORD_ACTUAL --> UPDATE_SESSION
    RECORD_ACTUAL --> UPDATE_DAILY
    RECORD_ACTUAL --> EMIT_METRIC

    style REJECT fill:#dc3545,color:#fff
    style TRUNCATE fill:#ffc107,color:#000
    style INVOKE fill:#28a745,color:#fff

TokenCounterMiddleware Implementation

"""
MangaAssist Token Counter Middleware
Intercepts every FM request to estimate, validate, and track token usage.
"""

import json
import time
import logging
import re
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class TokenEstimate:
    """Detailed token estimation breakdown for a request."""
    system_prompt_tokens: int = 0
    conversation_history_tokens: int = 0
    rag_context_tokens: int = 0
    user_message_tokens: int = 0
    total_input_tokens: int = 0
    max_output_tokens: int = 0
    total_budget: int = 0
    # Estimation metadata
    language_detected: str = "mixed"
    estimation_method: str = "character_ratio"
    confidence: float = 0.8  # 0.0 to 1.0

    def to_dict(self) -> dict:
        return {
            "system_prompt": self.system_prompt_tokens,
            "history": self.conversation_history_tokens,
            "rag_context": self.rag_context_tokens,
            "user_message": self.user_message_tokens,
            "total_input": self.total_input_tokens,
            "max_output": self.max_output_tokens,
            "total_budget": self.total_budget,
            "language": self.language_detected,
            "method": self.estimation_method,
            "confidence": self.confidence,
        }


@dataclass
class TokenUsageRecord:
    """Actual token usage after a completed request."""
    input_tokens: int
    output_tokens: int
    model_id: str
    cost_usd: float
    estimation_error_pct: float  # How far off the estimate was
    session_id: str
    request_id: str
    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "model_id": self.model_id,
            "cost_usd": round(self.cost_usd, 6),
            "estimation_error_pct": round(self.estimation_error_pct, 1),
            "session_id": self.session_id,
            "request_id": self.request_id,
            "timestamp": self.timestamp,
        }


class TokenCounterMiddleware:
    """
    Middleware that intercepts every FM request to count and validate tokens.

    This middleware wraps the FMAPIInterface and provides:
    1. Pre-request token estimation and budget validation
    2. Real-time output token tracking during streaming
    3. Post-request actual usage recording and cost calculation
    4. Automatic conversation history truncation when over budget
    5. Japanese-aware token estimation

    The middleware maintains running counters in Redis for session-level
    and daily-level budget enforcement.

    Usage:
        middleware = TokenCounterMiddleware(redis_client=redis)

        # Pre-request
        estimate = middleware.estimate_request_tokens(request)
        middleware.validate_budget(request, estimate)

        # During streaming (called per Bedrock event)
        middleware.track_output_token(request, chunk_data)

        # Post-request
        middleware.record_actual_usage(request, actual_input, actual_output)
    """

    # ─── Pricing Constants ─────────────────────────────────────────
    PRICING = {
        "anthropic.claude-3-sonnet-20240229-v1:0": {
            "input_per_1m": 3.00,
            "output_per_1m": 15.00,
        },
        "anthropic.claude-3-haiku-20240307-v1:0": {
            "input_per_1m": 0.25,
            "output_per_1m": 1.25,
        },
    }

    # ─── Budget Limits ─────────────────────────────────────────────
    PER_REQUEST_INPUT_LIMIT = 4_000
    PER_REQUEST_OUTPUT_LIMIT = 1_024
    SESSION_INPUT_LIMIT = 50_000
    SESSION_OUTPUT_LIMIT = 25_000
    DAILY_COST_LIMIT_USD = 5.00
    CONTEXT_WINDOW_SIZE = 200_000
    PROMPT_OVERHEAD = 300
    SAFETY_MARGIN = 500

    def __init__(self, redis_client):
        self.redis = redis_client
        # Track estimation accuracy over time for calibration
        self._estimation_errors: list[float] = []

    def estimate_request_tokens(
        self,
        system_prompt: str,
        conversation_history: list[dict],
        rag_context: Optional[str],
        user_message: str,
        max_output_tokens: int = 1024,
    ) -> TokenEstimate:
        """
        Estimate token counts for each component of the request.

        Uses language detection to choose the right estimation ratio:
        - Japanese text: ~1 token per 1.4 characters
        - English text: ~1 token per 4 characters
        - Mixed: weighted average based on character set ratios
        """
        # Detect language mix
        all_text = f"{system_prompt} {user_message}"
        for turn in conversation_history:
            all_text += f" {turn.get('content', '')}"
        if rag_context:
            all_text += f" {rag_context}"

        jp_ratio = self._detect_japanese_ratio(all_text)
        lang = "ja" if jp_ratio > 0.3 else "en" if jp_ratio < 0.1 else "mixed"

        # Estimate each component
        sys_tokens = self._count_tokens(system_prompt, jp_ratio)
        history_tokens = sum(
            self._count_tokens(turn.get("content", ""), jp_ratio)
            for turn in conversation_history
        )
        rag_tokens = self._count_tokens(rag_context or "", jp_ratio)
        user_tokens = self._count_tokens(user_message, jp_ratio)

        total_input = sys_tokens + history_tokens + rag_tokens + user_tokens

        estimate = TokenEstimate(
            system_prompt_tokens=sys_tokens,
            conversation_history_tokens=history_tokens,
            rag_context_tokens=rag_tokens,
            user_message_tokens=user_tokens,
            total_input_tokens=total_input,
            max_output_tokens=max_output_tokens,
            total_budget=total_input + max_output_tokens,
            language_detected=lang,
            estimation_method="character_ratio_weighted",
            confidence=0.85 if lang == "en" else 0.75,  # JP is less precise
        )

        logger.debug(
            "Token estimate",
            extra=estimate.to_dict(),
        )

        return estimate

    def validate_budget(
        self,
        session_id: str,
        estimate: TokenEstimate,
        model_id: str,
    ) -> dict:
        """
        Validate the estimated token usage against all budget levels.

        Returns a dict with validation result and any adjustments needed.
        Raises TokenBudgetExceeded for hard limit violations.
        """
        result = {
            "valid": True,
            "warnings": [],
            "adjustments": [],
        }

        # Per-request input limit
        if estimate.total_input_tokens > self.PER_REQUEST_INPUT_LIMIT:
            # Calculate how much history to trim
            excess = estimate.total_input_tokens - self.PER_REQUEST_INPUT_LIMIT
            result["warnings"].append(
                f"Input tokens ({estimate.total_input_tokens}) exceed "
                f"per-request limit ({self.PER_REQUEST_INPUT_LIMIT}). "
                f"Need to trim {excess} tokens from history."
            )
            result["adjustments"].append({
                "action": "truncate_history",
                "tokens_to_trim": excess,
            })

        # Context window check
        available = (
            self.CONTEXT_WINDOW_SIZE
            - self.PROMPT_OVERHEAD
            - self.SAFETY_MARGIN
            - estimate.max_output_tokens
        )
        if estimate.total_input_tokens > available:
            result["valid"] = False
            result["warnings"].append(
                f"Input tokens ({estimate.total_input_tokens}) exceed "
                f"context window capacity ({available})."
            )

        # Session budget check
        session_usage = self._get_session_usage(session_id)
        projected_session_input = (
            session_usage.get("input", 0) + estimate.total_input_tokens
        )
        if projected_session_input > self.SESSION_INPUT_LIMIT:
            result["warnings"].append(
                f"Session would reach {projected_session_input} input tokens "
                f"(limit: {self.SESSION_INPUT_LIMIT})."
            )

        # Daily cost check
        daily_cost = self._get_daily_cost(session_id)
        estimated_cost = self._estimate_cost(
            model_id,
            estimate.total_input_tokens,
            estimate.max_output_tokens,
        )
        if daily_cost + estimated_cost > self.DAILY_COST_LIMIT_USD:
            result["valid"] = False
            result["warnings"].append(
                f"Daily cost would reach ${daily_cost + estimated_cost:.4f} "
                f"(limit: ${self.DAILY_COST_LIMIT_USD:.2f})."
            )

        return result

    def truncate_history_to_budget(
        self,
        conversation_history: list[dict],
        target_tokens: int,
    ) -> list[dict]:
        """
        Trim conversation history to fit within a token budget.

        Strategy: Remove oldest turns first, always keeping the most recent
        user message and assistant response pair.
        """
        if not conversation_history:
            return []

        current_tokens = sum(
            self._count_tokens(turn.get("content", ""), 0.3)
            for turn in conversation_history
        )

        if current_tokens <= target_tokens:
            return conversation_history

        trimmed = list(conversation_history)
        while current_tokens > target_tokens and len(trimmed) > 2:
            # Remove oldest turn
            removed = trimmed.pop(0)
            removed_tokens = self._count_tokens(
                removed.get("content", ""), 0.3
            )
            current_tokens -= removed_tokens

        logger.info(
            f"Trimmed history from {len(conversation_history)} to "
            f"{len(trimmed)} turns ({current_tokens} tokens)"
        )

        return trimmed

    def record_actual_usage(
        self,
        session_id: str,
        request_id: str,
        model_id: str,
        actual_input_tokens: int,
        actual_output_tokens: int,
        estimated_input_tokens: int,
    ) -> TokenUsageRecord:
        """
        Record actual token usage after a completed request.
        Updates session and daily counters. Calculates cost and estimation error.
        """
        # Calculate cost
        pricing = self.PRICING.get(model_id, self.PRICING[
            "anthropic.claude-3-sonnet-20240229-v1:0"
        ])
        cost = (
            (actual_input_tokens / 1_000_000) * pricing["input_per_1m"]
            + (actual_output_tokens / 1_000_000) * pricing["output_per_1m"]
        )

        # Calculate estimation error
        if estimated_input_tokens > 0:
            error_pct = (
                (actual_input_tokens - estimated_input_tokens)
                / estimated_input_tokens
            ) * 100
        else:
            error_pct = 0.0

        self._estimation_errors.append(error_pct)
        if len(self._estimation_errors) > 1000:
            self._estimation_errors = self._estimation_errors[-500:]

        record = TokenUsageRecord(
            input_tokens=actual_input_tokens,
            output_tokens=actual_output_tokens,
            model_id=model_id,
            cost_usd=cost,
            estimation_error_pct=error_pct,
            session_id=session_id,
            request_id=request_id,
        )

        # Update counters
        self._update_session_usage(
            session_id, actual_input_tokens, actual_output_tokens, cost
        )
        self._update_daily_usage(
            session_id, actual_input_tokens, actual_output_tokens, cost
        )

        logger.info(
            "Token usage recorded",
            extra=record.to_dict(),
        )

        return record

    def get_estimation_accuracy(self) -> dict:
        """Return statistics on token estimation accuracy."""
        if not self._estimation_errors:
            return {"samples": 0, "mean_error_pct": 0, "status": "no_data"}

        errors = self._estimation_errors
        mean_error = sum(errors) / len(errors)
        abs_errors = [abs(e) for e in errors]
        mean_abs_error = sum(abs_errors) / len(abs_errors)

        return {
            "samples": len(errors),
            "mean_error_pct": round(mean_error, 1),
            "mean_abs_error_pct": round(mean_abs_error, 1),
            "max_overestimate_pct": round(min(errors), 1),
            "max_underestimate_pct": round(max(errors), 1),
            "status": "calibrated" if mean_abs_error < 20 else "needs_calibration",
        }

    # ─── Private Methods ───────────────────────────────────────────

    def _count_tokens(self, text: str, jp_ratio: float) -> int:
        """
        Estimate token count for text with a given Japanese character ratio.
        """
        if not text:
            return 0

        # Heuristic: code points above U+3000 (CJK blocks) count as Japanese
        jp_chars = sum(1 for c in text if ord(c) > 0x3000)
        en_chars = len(text) - jp_chars

        # JP chars: ~1 token per 1.4 chars (conservative)
        # EN chars: ~1 token per 4 chars
        jp_tokens = int(jp_chars * 0.71)  # 1/1.4
        en_tokens = en_chars // 4

        return jp_tokens + en_tokens + 1  # +1 to avoid zero

    def _detect_japanese_ratio(self, text: str) -> float:
        """Detect the ratio of Japanese characters in text."""
        if not text:
            return 0.0
        jp_chars = sum(1 for c in text if ord(c) > 0x3000)
        return jp_chars / len(text)

    def _estimate_cost(
        self, model_id: str, input_tokens: int, output_tokens: int
    ) -> float:
        """Estimate cost for a request."""
        pricing = self.PRICING.get(model_id, self.PRICING[
            "anthropic.claude-3-sonnet-20240229-v1:0"
        ])
        return (
            (input_tokens / 1_000_000) * pricing["input_per_1m"]
            + (output_tokens / 1_000_000) * pricing["output_per_1m"]
        )

    def _get_session_usage(self, session_id: str) -> dict:
        """Get current session token usage from Redis."""
        try:
            raw = self.redis.get(f"session_tokens:{session_id}")
            if raw:
                return json.loads(raw)
        except Exception as e:
            logger.warning(f"Redis session read failed: {e}")
        return {"input": 0, "output": 0, "cost_usd": 0.0, "requests": 0}

    def _get_daily_cost(self, session_id: str) -> float:
        """Get today's cost from Redis."""
        try:
            day = time.strftime("%Y-%m-%d")
            raw = self.redis.get(f"daily_tokens:{session_id}:{day}")
            if raw:
                return json.loads(raw).get("cost_usd", 0.0)
        except Exception as e:
            logger.warning(f"Redis daily read failed: {e}")
        return 0.0

    def _update_session_usage(
        self, session_id: str, input_t: int, output_t: int, cost: float
    ) -> None:
        """Update session counters in Redis."""
        try:
            key = f"session_tokens:{session_id}"
            raw = self.redis.get(key)
            data = json.loads(raw) if raw else {
                "input": 0, "output": 0, "cost_usd": 0.0, "requests": 0
            }
            data["input"] += input_t
            data["output"] += output_t
            data["cost_usd"] += cost
            data["requests"] += 1
            self.redis.setex(key, 86400, json.dumps(data))
        except Exception as e:
            logger.error(f"Session usage update failed: {e}")

    def _update_daily_usage(
        self, session_id: str, input_t: int, output_t: int, cost: float
    ) -> None:
        """Update daily counters in Redis."""
        try:
            day = time.strftime("%Y-%m-%d")
            key = f"daily_tokens:{session_id}:{day}"
            raw = self.redis.get(key)
            data = json.loads(raw) if raw else {
                "input_tokens": 0, "output_tokens": 0, "cost_usd": 0.0, "requests": 0
            }
            data["input_tokens"] += input_t
            data["output_tokens"] += output_t
            data["cost_usd"] += cost
            data["requests"] += 1
            self.redis.setex(key, 86400, json.dumps(data))
        except Exception as e:
            logger.error(f"Daily usage update failed: {e}")

Part 3: Adaptive Retry Timing and Connection Pooling

Why Adaptive Retries for GenAI APIs

Standard fixed-interval retries are dangerous for GenAI APIs because:

  1. Bedrock throttling is load-dependent — A 429 at 9:00 AM might need 2s backoff; the same 429 at peak (lunch hour in Japan) might need 15s.
  2. Model timeouts vary by input size — A 500-token request times out differently than a 4,000-token request. Retry timing should account for expected completion time.
  3. Cost of retries is non-zero — Each retry that gets far enough to invoke the model incurs token charges. A $0.01 request retried 3 times can cost up to $0.04 in total.
  4. Thundering herd on recovery — When Bedrock recovers from a brief outage, all clients retrying simultaneously can re-trigger throttling.

Adaptive retry timing adjusts backoff parameters based on observed conditions: error type, current throughput, time of day, and recent success rates.
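Point 4 is easy to quantify. The sketch below (illustrative numbers only) simulates 100 clients retrying after a shared failure: with a fixed 2 s delay every retry lands in the same 100 ms window, while full jitter spreads retries across the whole interval.

```python
import random

random.seed(42)  # deterministic for illustration

CLIENTS = 100
BASE_DELAY_S = 2.0
BUCKET_MS = 100

def peak_bucket(delays: list[float]) -> int:
    """Max number of retries landing in any single 100 ms window."""
    buckets: dict[int, int] = {}
    for d in delays:
        b = int(d * 1000) // BUCKET_MS
        buckets[b] = buckets.get(b, 0) + 1
    return max(buckets.values())

fixed = [BASE_DELAY_S] * CLIENTS  # every client retries at exactly t=2.0s
jittered = [random.uniform(0, BASE_DELAY_S) for _ in range(CLIENTS)]  # full jitter

print(peak_bucket(fixed))     # 100 — the entire herd arrives together
print(peak_bucket(jittered))  # only a small fraction of that collides
```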

Adaptive Retry State Machine

stateDiagram-v2
    [*] --> Normal: Start

    Normal --> Normal: Success (track latency)
    Normal --> SlightPressure: Single throttle or timeout

    SlightPressure --> Normal: 3 consecutive successes
    SlightPressure --> SlightPressure: Sporadic errors (< 20% rate)
    SlightPressure --> HighPressure: Error rate > 20%

    HighPressure --> SlightPressure: Error rate drops below 20%
    HighPressure --> HighPressure: Sustained errors
    HighPressure --> CircuitOpen: Error rate > 50% for 60s

    CircuitOpen --> Recovery: 30s cooldown
    Recovery --> Normal: 2 consecutive successes
    Recovery --> CircuitOpen: Probe fails

    note right of Normal
        Base delay: 1s
        Jitter: 0-1s
        Max retries: 3
    end note

    note right of SlightPressure
        Base delay: 2s
        Jitter: 0-2s
        Max retries: 2
    end note

    note right of HighPressure
        Base delay: 5s
        Jitter: 0-5s
        Max retries: 1
    end note

    note right of CircuitOpen
        Reject all requests
        Return fallback response
    end note

AdaptiveRetryClient Implementation

"""
MangaAssist Adaptive Retry Client
Dynamically adjusts retry behavior based on observed Bedrock API conditions.
"""

import json
import time
import random
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Any, Optional
from collections import deque

from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)


class PressureLevel(Enum):
    """Current pressure level based on observed API behavior."""
    NORMAL = "normal"
    SLIGHT = "slight_pressure"
    HIGH = "high_pressure"
    CIRCUIT_OPEN = "circuit_open"
    RECOVERY = "recovery"


@dataclass
class PressureConfig:
    """Retry parameters for each pressure level."""
    base_delay: float
    jitter_range: float
    max_retries: int
    max_delay: float


PRESSURE_CONFIGS = {
    PressureLevel.NORMAL: PressureConfig(
        base_delay=1.0, jitter_range=1.0, max_retries=3, max_delay=10.0
    ),
    PressureLevel.SLIGHT: PressureConfig(
        base_delay=2.0, jitter_range=2.0, max_retries=2, max_delay=15.0
    ),
    PressureLevel.HIGH: PressureConfig(
        base_delay=5.0, jitter_range=5.0, max_retries=1, max_delay=30.0
    ),
    PressureLevel.RECOVERY: PressureConfig(
        base_delay=3.0, jitter_range=2.0, max_retries=1, max_delay=20.0
    ),
}


@dataclass
class RequestOutcome:
    """Records the outcome of a single API call."""
    timestamp: float
    success: bool
    latency_ms: float
    error_code: Optional[str] = None
    model_id: Optional[str] = None


class AdaptiveRetryClient:
    """
    Retry client that adapts its behavior based on observed API conditions.

    Unlike static retry configurations, this client:
    - Tracks recent success/failure rates in a sliding window
    - Adjusts backoff delays based on current pressure level
    - Maintains separate pressure tracking per model (Sonnet vs Haiku)
    - Reduces retry count under high pressure to avoid worsening conditions
    - Uses Redis to share pressure state across ECS Fargate tasks

    Pressure Level Transitions:
    - NORMAL → SLIGHT: Error rate > 5% (in practice, often a single throttle/timeout)
    - SLIGHT → NORMAL: 3 consecutive successes
    - SLIGHT → HIGH: Error rate > 20% in observation window
    - HIGH → SLIGHT: Error rate drops below 20%
    - HIGH → CIRCUIT_OPEN: Error rate > 50% for 60s
    - CIRCUIT_OPEN → RECOVERY: After 30s cooldown
    - RECOVERY → NORMAL: 2 consecutive successes
    - RECOVERY → CIRCUIT_OPEN: Probe fails
    """

    OBSERVATION_WINDOW_S = 60    # Look at last 60s of outcomes
    MAX_OUTCOMES = 200           # Keep last 200 outcomes max
    CIRCUIT_OPEN_DURATION_S = 30  # Time in CIRCUIT_OPEN before recovery
    SLIGHT_THRESHOLD = 0.05      # > 5% error rate = slight pressure
    HIGH_THRESHOLD = 0.20        # > 20% error rate = high pressure
    CIRCUIT_THRESHOLD = 0.50     # > 50% error rate = circuit open
    RECOVERY_SUCCESSES = 2       # Successes needed to exit RECOVERY
    NORMAL_SUCCESSES = 3         # Successes to go from SLIGHT → NORMAL

    def __init__(self, redis_client):
        self.redis = redis_client
        # Per-model outcome tracking
        self._outcomes: dict[str, deque] = {}
        self._pressure_levels: dict[str, PressureLevel] = {}
        self._circuit_open_time: dict[str, float] = {}
        self._consecutive_successes: dict[str, int] = {}

    def execute(
        self,
        operation: Callable[[], Any],
        model_id: str,
        operation_name: str = "bedrock_call",
    ) -> Any:
        """
        Execute an operation with adaptive retry behavior.

        Checks current pressure level, selects retry config,
        then executes with appropriate backoff.
        """
        pressure = self._get_pressure_level(model_id)

        # Circuit breaker check
        if pressure == PressureLevel.CIRCUIT_OPEN:
            elapsed = time.time() - self._circuit_open_time.get(model_id, 0)
            if elapsed < self.CIRCUIT_OPEN_DURATION_S:
                raise CircuitOpenError(
                    f"Circuit open for {model_id}. "
                    f"Retry in {self.CIRCUIT_OPEN_DURATION_S - elapsed:.0f}s."
                )
            # Transition to RECOVERY
            self._set_pressure(model_id, PressureLevel.RECOVERY)
            pressure = PressureLevel.RECOVERY

        config = PRESSURE_CONFIGS.get(pressure, PRESSURE_CONFIGS[PressureLevel.NORMAL])

        last_error = None
        for attempt in range(config.max_retries + 1):
            start = time.time()
            try:
                result = operation()
                latency = (time.time() - start) * 1000

                # Record success
                self._record_outcome(
                    model_id, True, latency
                )

                if attempt > 0:
                    logger.info(
                        f"Adaptive retry succeeded: {operation_name} "
                        f"attempt {attempt + 1}, pressure={pressure.value}"
                    )

                return result

            except ClientError as e:
                error_code = e.response["Error"]["Code"]
                latency = (time.time() - start) * 1000
                last_error = e

                # Non-retryable errors
                non_retryable = {
                    "ValidationException",
                    "AccessDeniedException",
                    "ResourceNotFoundException",
                }
                if error_code in non_retryable:
                    self._record_outcome(model_id, False, latency, error_code)
                    raise

                # Record failure
                self._record_outcome(model_id, False, latency, error_code)

                if attempt < config.max_retries:
                    delay = self._calculate_adaptive_backoff(
                        attempt, config, pressure
                    )
                    logger.warning(
                        f"Adaptive retry: {operation_name} {error_code}, "
                        f"pressure={pressure.value}, "
                        f"attempt {attempt + 1}/{config.max_retries + 1}, "
                        f"backoff={delay:.2f}s"
                    )
                    time.sleep(delay)

                    # Re-check pressure after backoff
                    pressure = self._get_pressure_level(model_id)
                    config = PRESSURE_CONFIGS.get(
                        pressure, PRESSURE_CONFIGS[PressureLevel.NORMAL]
                    )

            except Exception as e:
                latency = (time.time() - start) * 1000
                last_error = e
                self._record_outcome(
                    model_id, False, latency, type(e).__name__
                )

                if attempt < config.max_retries:
                    delay = self._calculate_adaptive_backoff(
                        attempt, config, pressure
                    )
                    time.sleep(delay)

        raise MaxRetriesExhausted(
            f"All retries exhausted for {operation_name} under "
            f"pressure={pressure.value}: {last_error}"
        )

    def _calculate_adaptive_backoff(
        self,
        attempt: int,
        config: PressureConfig,
        pressure: PressureLevel,
    ) -> float:
        """
        Calculate backoff with adaptive timing.

        Under higher pressure, the base delay increases and jitter range
        widens to spread out retry attempts across the fleet.
        """
        # Exponential component
        exponential = config.base_delay * (2 ** attempt)

        # Full jitter
        jitter = random.uniform(0, config.jitter_range)

        # Pressure multiplier — under HIGH pressure, add extra delay
        pressure_multiplier = {
            PressureLevel.NORMAL: 1.0,
            PressureLevel.SLIGHT: 1.2,
            PressureLevel.HIGH: 2.0,
            PressureLevel.RECOVERY: 1.5,
        }.get(pressure, 1.0)

        delay = min(
            config.max_delay,
            (exponential + jitter) * pressure_multiplier,
        )

        return delay

    def _record_outcome(
        self,
        model_id: str,
        success: bool,
        latency_ms: float,
        error_code: Optional[str] = None,
    ) -> None:
        """Record an outcome and update pressure level."""
        if model_id not in self._outcomes:
            self._outcomes[model_id] = deque(maxlen=self.MAX_OUTCOMES)
            self._consecutive_successes[model_id] = 0

        outcome = RequestOutcome(
            timestamp=time.time(),
            success=success,
            latency_ms=latency_ms,
            error_code=error_code,
            model_id=model_id,
        )
        self._outcomes[model_id].append(outcome)

        # Track consecutive successes
        if success:
            self._consecutive_successes[model_id] = (
                self._consecutive_successes.get(model_id, 0) + 1
            )
        else:
            self._consecutive_successes[model_id] = 0

        # Update pressure level
        self._update_pressure(model_id)

        # Sync to Redis for cross-task visibility
        self._sync_pressure_to_redis(model_id)

    def _update_pressure(self, model_id: str) -> None:
        """Update pressure level based on recent outcomes."""
        current = self._pressure_levels.get(model_id, PressureLevel.NORMAL)
        error_rate = self._get_recent_error_rate(model_id)
        consecutive = self._consecutive_successes.get(model_id, 0)

        new_level = current

        if current == PressureLevel.NORMAL:
            if error_rate > self.SLIGHT_THRESHOLD:
                new_level = PressureLevel.SLIGHT

        elif current == PressureLevel.SLIGHT:
            if consecutive >= self.NORMAL_SUCCESSES:
                new_level = PressureLevel.NORMAL
            elif error_rate > self.HIGH_THRESHOLD:
                new_level = PressureLevel.HIGH

        elif current == PressureLevel.HIGH:
            if error_rate < self.HIGH_THRESHOLD:
                new_level = PressureLevel.SLIGHT
            elif error_rate > self.CIRCUIT_THRESHOLD:
                new_level = PressureLevel.CIRCUIT_OPEN
                self._circuit_open_time[model_id] = time.time()

        elif current == PressureLevel.RECOVERY:
            if consecutive >= self.RECOVERY_SUCCESSES:
                new_level = PressureLevel.NORMAL
            elif error_rate > self.SLIGHT_THRESHOLD:
                new_level = PressureLevel.CIRCUIT_OPEN
                self._circuit_open_time[model_id] = time.time()

        if new_level != current:
            logger.info(
                f"Pressure transition for {model_id}: "
                f"{current.value} -> {new_level.value} "
                f"(error_rate={error_rate:.1%}, consecutive_ok={consecutive})"
            )
            self._set_pressure(model_id, new_level)

    def _get_recent_error_rate(self, model_id: str) -> float:
        """Calculate error rate in the observation window."""
        outcomes = self._outcomes.get(model_id, deque())
        if not outcomes:
            return 0.0

        cutoff = time.time() - self.OBSERVATION_WINDOW_S
        recent = [o for o in outcomes if o.timestamp >= cutoff]

        if not recent:
            return 0.0

        failures = sum(1 for o in recent if not o.success)
        return failures / len(recent)

    def _get_pressure_level(self, model_id: str) -> PressureLevel:
        """Get current pressure level, checking Redis for cross-task state."""
        # Check local state first
        local = self._pressure_levels.get(model_id, PressureLevel.NORMAL)

        # Check Redis for fleet-wide state
        try:
            redis_raw = self.redis.get(f"pressure:{model_id}")
            if redis_raw:
                redis_level = PressureLevel(redis_raw.decode() if isinstance(redis_raw, bytes) else redis_raw)
                # Use the more conservative (higher-pressure) of the two;
                # RECOVERY ranks between SLIGHT and HIGH in conservatism
                levels = [PressureLevel.NORMAL, PressureLevel.SLIGHT,
                          PressureLevel.RECOVERY, PressureLevel.HIGH,
                          PressureLevel.CIRCUIT_OPEN]
                local_idx = levels.index(local) if local in levels else 0
                redis_idx = levels.index(redis_level) if redis_level in levels else 0
                return levels[max(local_idx, redis_idx)]
        except Exception as e:
            logger.debug(f"Redis pressure read failed: {e}")

        return local

    def _set_pressure(self, model_id: str, level: PressureLevel) -> None:
        """Set pressure level locally."""
        self._pressure_levels[model_id] = level

    def _sync_pressure_to_redis(self, model_id: str) -> None:
        """Sync local pressure state to Redis for other Fargate tasks."""
        try:
            level = self._pressure_levels.get(model_id, PressureLevel.NORMAL)
            self.redis.setex(
                f"pressure:{model_id}",
                self.OBSERVATION_WINDOW_S,
                level.value,
            )
        except Exception as e:
            logger.debug(f"Redis pressure sync failed: {e}")

    def get_status(self, model_id: str) -> dict:
        """Return current adaptive retry status for monitoring."""
        pressure = self._get_pressure_level(model_id)
        error_rate = self._get_recent_error_rate(model_id)
        config = PRESSURE_CONFIGS.get(pressure, PRESSURE_CONFIGS[PressureLevel.NORMAL])
        outcomes = self._outcomes.get(model_id, deque())

        recent_latencies = [
            o.latency_ms for o in outcomes
            if o.success and o.timestamp >= time.time() - 60
        ]
        avg_latency = (
            sum(recent_latencies) / len(recent_latencies)
            if recent_latencies else 0
        )

        return {
            "model_id": model_id,
            "pressure_level": pressure.value,
            "error_rate": round(error_rate, 3),
            "consecutive_successes": self._consecutive_successes.get(model_id, 0),
            "base_delay": config.base_delay,
            "max_retries": config.max_retries,
            "avg_latency_ms": round(avg_latency),
            "recent_outcomes": len(outcomes),
        }


class CircuitOpenError(Exception):
    """Raised when circuit breaker is open."""
    pass


class MaxRetriesExhausted(Exception):
    """Raised when all adaptive retries are exhausted."""
    pass
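The backoff arithmetic in `_calculate_adaptive_backoff` can be checked in isolation. This sketch mirrors its formula with the `PRESSURE_CONFIGS` parameters and pressure multipliers from above; it is a standalone copy for illustration, not the class itself.

```python
import random

random.seed(7)  # deterministic for illustration

# (base_delay, jitter_range, max_delay, pressure_multiplier) per level,
# copied from PRESSURE_CONFIGS and the multiplier table above
LEVELS = {
    "normal": (1.0, 1.0, 10.0, 1.0),
    "slight": (2.0, 2.0, 15.0, 1.2),
    "high":   (5.0, 5.0, 30.0, 2.0),
}

def adaptive_backoff(attempt: int, level: str) -> float:
    base, jitter_range, max_delay, mult = LEVELS[level]
    exponential = base * (2 ** attempt)       # exponential component
    jitter = random.uniform(0, jitter_range)  # full jitter
    return min(max_delay, (exponential + jitter) * mult)

for level in LEVELS:
    delays = [round(adaptive_backoff(a, level), 2) for a in range(3)]
    print(f"{level}: {delays}")

# Under HIGH pressure, attempt 2 always hits the 30 s cap,
# since (5 * 2**2 + jitter) * 2.0 >= 40 > 30.
```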

Connection Pooling for FM APIs

Why Connection Pooling Matters

Each Bedrock API call from an ECS Fargate task requires a TLS handshake (~100-200ms in ap-northeast-1). At 12 requests/second average (1M messages/day), establishing a new connection per request wastes significant time and CPU.
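A quick back-of-the-envelope check of that claim, assuming ~150 ms per handshake (the midpoint of the range above):

```python
REQUESTS_PER_DAY = 1_000_000
HANDSHAKE_S = 0.150  # assumed midpoint of the 100-200 ms range

wasted_hours = REQUESTS_PER_DAY * HANDSHAKE_S / 3600
print(f"~{wasted_hours:.1f} cumulative hours/day of pure TLS handshake time")
```

Roughly 42 hours per day of cumulative handshake latency disappears once connections are reused.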

Connection Pool Architecture

graph TB
    subgraph Fargate["ECS Fargate Task (1 of N)"]
        POOL[Connection Pool<br/>max_pool_connections=25<br/>max_retry=3]
        TASK1[Request Handler 1]
        TASK2[Request Handler 2]
        TASK3[Request Handler 3]
        TASKN[Request Handler N]
    end

    subgraph Bedrock["Amazon Bedrock Endpoint"]
        EP[bedrock-runtime.ap-northeast-1.amazonaws.com<br/>TLS 1.3 / HTTP/1.1]
    end

    TASK1 --> POOL
    TASK2 --> POOL
    TASK3 --> POOL
    TASKN --> POOL
    POOL -->|Reuse connections| EP

    style POOL fill:#28a745,color:#fff

Connection Pool Configuration

"""
MangaAssist Bedrock Connection Pool Configuration
Optimizes connection reuse for high-throughput FM API calls.
"""

import boto3
from botocore.config import Config

def create_bedrock_client_with_pool(
    region: str = "ap-northeast-1",
    max_pool_connections: int = 25,
    max_retry_attempts: int = 3,
    connect_timeout: int = 5,
    read_timeout: int = 60,
) -> "botocore.client.BedrockRuntime":
    """
    Create a Bedrock Runtime client with optimized connection pooling.

    Connection pool sizing rationale for MangaAssist:
    - Average load: ~12 req/s across all Fargate tasks
    - With 4 tasks: ~3 req/s per task
    - Average request duration: ~3s (streaming)
    - Concurrent connections per task: ~9 (3 * 3s)
    - Pool size 25 gives headroom for bursts (3x peak)

    Pool connections are reused across requests, avoiding TLS
    handshake overhead (~150ms per new connection).
    """
    bedrock_config = Config(
        region_name=region,
        retries={
            "max_attempts": max_retry_attempts,
            "mode": "adaptive",
        },
        max_pool_connections=max_pool_connections,
        connect_timeout=connect_timeout,
        read_timeout=read_timeout,
        # Enable TCP keep-alive to prevent idle connection drops
        tcp_keepalive=True,
    )

    return boto3.client("bedrock-runtime", config=bedrock_config)


# Singleton pool for the Fargate task
_bedrock_client = None


def get_bedrock_client() -> "botocore.client.BedrockRuntime":
    """
    Get or create the shared Bedrock client with connection pool.
    Thread-safe — boto3 clients are safe to share across threads.
    """
    global _bedrock_client
    if _bedrock_client is None:
        _bedrock_client = create_bedrock_client_with_pool()
    return _bedrock_client

Request Queuing Under Load

Queue Architecture

When incoming request rate exceeds Bedrock's throughput (throttling scenario), a request queue prevents dropped messages while maintaining fairness.
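Before the implementation, it is worth sanity-checking the sorted-set scoring: the priority term must dominate any realistic Unix timestamp (~1.7 x 10^9), which is why the implementation below multiplies priority by 10^10.

```python
PRIORITY_SCALE = 10_000_000_000  # must exceed any realistic Unix timestamp

def score(priority: int, enqueued_at: float) -> float:
    return priority * PRIORITY_SCALE + enqueued_at

now = 1_700_000_000.0  # illustrative Unix timestamp

# Lower score pops first: HIGH (2) beats NORMAL (3) even if enqueued later
assert score(2, now + 5) < score(3, now)

# Within the same priority, FIFO order holds
assert score(3, now) < score(3, now + 1)

print("priority dominates; FIFO within priority")
```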

graph TB
    subgraph Incoming["Incoming Requests"]
        R1[Request 1<br/>Priority: Normal]
        R2[Request 2<br/>Priority: High]
        R3[Request 3<br/>Priority: Normal]
    end

    subgraph Queue["Request Queue (Redis Sorted Set)"]
        PQ[Priority Queue<br/>score = priority * 10^10 + timestamp]
        CAPACITY[Capacity: 1000 requests<br/>TTL: 30s per entry]
    end

    subgraph Workers["Bedrock Workers"]
        W1[Worker 1<br/>Sonnet]
        W2[Worker 2<br/>Sonnet]
        W3[Worker 3<br/>Haiku]
        W4[Worker 4<br/>Haiku]
    end

    R1 --> PQ
    R2 --> PQ
    R3 --> PQ
    PQ --> W1
    PQ --> W2
    PQ --> W3
    PQ --> W4

Request Queue Implementation

"""
MangaAssist Request Queue
Buffers FM requests during load spikes to prevent dropped messages.
"""

import json
import time
import logging
from dataclasses import dataclass
from typing import Optional
from enum import IntEnum

logger = logging.getLogger(__name__)


class RequestPriority(IntEnum):
    """Request priority levels. Lower number = higher priority."""
    CRITICAL = 1    # System health checks
    HIGH = 2        # Paying/premium users
    NORMAL = 3      # Standard requests
    LOW = 4         # Batch/background tasks
    BULK = 5        # Analytics, non-interactive


@dataclass
class QueuedRequest:
    """A request waiting in the queue."""
    request_id: str
    session_id: str
    priority: RequestPriority
    payload: dict
    enqueued_at: float
    ttl_seconds: float = 30.0  # Drop if not processed within 30s
    model_preference: str = "auto"  # "sonnet", "haiku", or "auto"

    @property
    def is_expired(self) -> bool:
        return (time.time() - self.enqueued_at) > self.ttl_seconds

    @property
    def wait_time_ms(self) -> float:
        return (time.time() - self.enqueued_at) * 1000

    def to_json(self) -> str:
        return json.dumps({
            "request_id": self.request_id,
            "session_id": self.session_id,
            "priority": self.priority.value,
            "payload": self.payload,
            "enqueued_at": self.enqueued_at,
            "ttl_seconds": self.ttl_seconds,
            "model_preference": self.model_preference,
        }, ensure_ascii=False)

    @classmethod
    def from_json(cls, data: str) -> "QueuedRequest":
        d = json.loads(data)
        return cls(
            request_id=d["request_id"],
            session_id=d["session_id"],
            priority=RequestPriority(d["priority"]),
            payload=d["payload"],
            enqueued_at=d["enqueued_at"],
            ttl_seconds=d.get("ttl_seconds", 30.0),
            model_preference=d.get("model_preference", "auto"),
        )


class RequestQueue:
    """
    Redis-backed priority queue for FM API requests.

    When Bedrock is throttling or all workers are busy, requests are
    queued rather than rejected. The queue ensures:
    - Priority ordering (premium users first)
    - TTL-based expiry (stale requests are dropped)
    - Fair scheduling across sessions
    - Backpressure signaling to API Gateway

    The queue uses a Redis sorted set where:
    - score = priority * 10_000_000_000 + timestamp
    - This ensures priority ordering with FIFO within the same priority

    Capacity is capped at 1000 requests. Beyond that, new requests
    receive a 429 response with a Retry-After header.
    """

    QUEUE_KEY = "fm_request_queue"
    MAX_QUEUE_SIZE = 1000
    CLEANUP_INTERVAL_S = 5.0  # Clean expired entries every 5s

    def __init__(self, redis_client):
        self.redis = redis_client
        self._last_cleanup = 0.0

    def enqueue(self, request: QueuedRequest) -> bool:
        """
        Add a request to the queue.
        Returns True if enqueued, False if queue is full.
        """
        # Check queue capacity (best effort: ZCARD followed by ZADD is
        # not atomic, so concurrent workers can briefly exceed the cap)
        current_size = self.redis.zcard(self.QUEUE_KEY)
        if current_size >= self.MAX_QUEUE_SIZE:
            logger.warning(
                f"Queue full ({current_size}), rejecting request "
                f"{request.request_id}"
            )
            return False

        # Score: priority * large_constant + timestamp for FIFO within priority
        score = request.priority.value * 10_000_000_000 + request.enqueued_at

        self.redis.zadd(
            self.QUEUE_KEY,
            {request.to_json(): score},
        )

        logger.info(
            f"Enqueued request {request.request_id} "
            f"(priority={request.priority.name}, "
            f"queue_size={current_size + 1})"
        )

        # Periodic cleanup
        self._maybe_cleanup()

        return True

    def dequeue(self) -> Optional[QueuedRequest]:
        """
        Pop the highest-priority, oldest request from the queue.
        Skips expired requests.
        """
        while True:
            # Pop lowest score (highest priority, oldest)
            results = self.redis.zpopmin(self.QUEUE_KEY, count=1)
            if not results:
                return None

            data, score = results[0]
            if isinstance(data, bytes):
                data = data.decode("utf-8")

            request = QueuedRequest.from_json(data)

            # Skip expired
            if request.is_expired:
                logger.debug(
                    f"Skipping expired request {request.request_id} "
                    f"(waited {request.wait_time_ms:.0f}ms)"
                )
                continue

            logger.info(
                f"Dequeued request {request.request_id} "
                f"(wait={request.wait_time_ms:.0f}ms, "
                f"priority={request.priority.name})"
            )

            return request

    def get_queue_stats(self) -> dict:
        """Return queue statistics for monitoring."""
        size = self.redis.zcard(self.QUEUE_KEY)

        # Peek at oldest entry
        oldest_entries = self.redis.zrange(self.QUEUE_KEY, 0, 0, withscores=True)
        oldest_wait_ms = 0
        if oldest_entries:
            data, score = oldest_entries[0]
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            req = QueuedRequest.from_json(data)
            oldest_wait_ms = req.wait_time_ms

        return {
            "queue_size": size,
            "max_capacity": self.MAX_QUEUE_SIZE,
            "utilization_pct": round((size / self.MAX_QUEUE_SIZE) * 100, 1),
            "oldest_wait_ms": round(oldest_wait_ms),
        }

    def _maybe_cleanup(self) -> None:
        """Periodically remove expired entries.

        Scores combine priority and timestamp, so a single range sweep
        from -inf would also delete every higher-priority entry
        regardless of age. Sweep each priority band separately instead.
        """
        now = time.time()
        if now - self._last_cleanup < self.CLEANUP_INTERVAL_S:
            return

        self._last_cleanup = now
        removed = 0
        for priority in RequestPriority:
            band = priority.value * 10_000_000_000
            # Drop entries in this band enqueued more than 30s ago (the
            # default TTL; custom TTLs are still enforced on dequeue)
            removed += self.redis.zremrangebyscore(
                self.QUEUE_KEY, band, band + (now - 30)
            )
        if removed:
            logger.info(f"Cleaned {removed} expired queue entries")
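

# --- Usage sketch (illustrative, not from the original code) ---
# A minimal consumer loop: dequeue() already enforces priority order and
# skips expired entries, so the worker only drains and dispatches.
# `handle_request` is a hypothetical downstream callback.
def drain_queue(queue, handle_request) -> int:
    """Process queued requests until the queue is empty; return count."""
    processed = 0
    request = queue.dequeue()
    while request is not None:
        handle_request(request)
        processed += 1
        request = queue.dequeue()
    return processed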

```

## Key Takeaways

| # | Takeaway | MangaAssist Application |
|---|----------|-------------------------|
| 1 | **Chunk aggregation prevents WebSocket frame storms** — buffering Bedrock deltas for 100ms before sending reduces the client frame rate from 40-80/s to 5-10/s, staying well within API Gateway limits and enabling smooth progressive rendering. | `ChunkBuffer` with a 100ms flush interval and 4KB size threshold. Compression ratio averages 8:1 (8 Bedrock deltas per client frame). |
| 2 | **Token estimation must be language-aware** — Japanese text consumes ~1 token per 1.4 characters versus ~1 per 4 for English; using English ratios for JP text would underestimate by ~3x, causing context window overflows. | `TokenCounterMiddleware` detects the JP character ratio and uses weighted estimation. Accuracy tracking shows a mean absolute error of ~15% for JP text. |
| 3 | **Four-level token budgets prevent cost explosions at every scale** — per-request, per-session, daily per-user, and context-window checks must all pass before any Bedrock invocation, preventing $45K/day runaway costs. | `TokenCounterMiddleware` validates all four levels synchronously before invoking Bedrock. Redis failures degrade gracefully (allow the request, log a warning). |
| 4 | **Adaptive retry timing outperforms static backoff under variable load** — Bedrock throttling severity changes throughout the day; a static 1s/2s/4s backoff is either too aggressive (peak) or too conservative (off-peak). | `AdaptiveRetryClient` tracks per-model pressure levels (NORMAL/SLIGHT/HIGH/CIRCUIT_OPEN) and adjusts the base delay from 1s to 5s with proportional jitter. |
| 5 | **Separate circuit breakers per model prevent cascading failures** — a Sonnet timeout should not block Haiku queries from succeeding; per-model pressure tracking isolates failures. | `AdaptiveRetryClient` maintains independent outcome histories and pressure levels for each `model_id`, synced to Redis for fleet-wide coordination. |
| 6 | **Connection pooling saves ~150ms per request** — reusing TLS connections to Bedrock endpoints avoids per-request handshake overhead; at 1M messages/day this saves ~41 hours of cumulative handshake time. | Singleton Bedrock client with `max_pool_connections=25` and TCP keepalive. Each Fargate task maintains its own pool. |
| 7 | **Request queuing preserves messages during throttling** — rather than rejecting requests when Bedrock is saturated, a priority queue buffers them with a 30s TTL so premium users and critical requests are served first. | Redis sorted-set queue with priority scoring (CRITICAL through BULK). Maximum 1,000 entries with automatic expiry cleanup. |
7 Request queuing preserves messages during throttling — Rather than rejecting requests when Bedrock is saturated, a priority queue buffers them with 30s TTL, ensuring premium users and critical requests are served first. Redis sorted set queue with priority scoring (CRITICAL through BULK). Maximum 1000 entries with automatic expiry cleanup.