Streaming Response Handling, Token Management, and Retry Patterns
MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.
Skill Mapping
| Field | Value |
|---|---|
| Certification | AWS AIP-C01 |
| Domain | 2 — Implementation & Integration |
| Task | 2.5 — Application Integration Patterns |
| Skill | 2.5.1 — Create FM API interfaces to address the specific requirements of GenAI workloads |
| Deep-Dive Focus | Chunked streaming via WebSocket, token counting middleware, adaptive retry timing, connection pooling, request queuing |
Overview
This document is a deep-dive companion to 01-fm-api-interface-architecture.md. Where the architecture document defines the components (FMAPIInterface, TokenLimitEnforcer, RetryStrategyManager, StreamingAPIHandler), this document examines the three hardest operational problems in GenAI API interfaces:
- Streaming — How chunks flow from Bedrock through ECS Fargate through API Gateway WebSocket to the browser, handling backpressure, reconnection, and partial failures
- Token management — How middleware intercepts every request to count tokens, enforce budgets, and track costs in real time
- Retry patterns — How adaptive retry timing avoids thundering herds while maximizing request success rate under Bedrock throttling
Part 1: Chunked Streaming via WebSocket
The Streaming Pipeline — End to End
graph LR
subgraph Bedrock["Amazon Bedrock"]
MODEL[Claude 3 Sonnet/Haiku]
STREAM_OUT[InvokeModelWithResponseStream<br/>yields content_block_delta events]
end
subgraph Fargate["ECS Fargate Orchestrator"]
CONSUMER[Stream Consumer<br/>reads event stream]
TOKENIZER[Token Counter<br/>real-time output tracking]
CHUNKER[Chunk Aggregator<br/>buffers small deltas]
METRICS[Latency Tracker<br/>TTFT + TPS metrics]
BACKPRESSURE[Backpressure Monitor<br/>pause if client slow]
end
subgraph APIGW["API Gateway"]
MGMT_API[Management API<br/>postToConnection]
FRAME[Frame Encoder<br/>JSON + UTF-8 <= 32KB]
end
subgraph Client["Browser / Mobile"]
WS_CLIENT[WebSocket Client<br/>onmessage handler]
RENDERER[Progressive Renderer<br/>append text as it arrives]
RECONNECT[Reconnect Handler<br/>exponential backoff]
end
MODEL --> STREAM_OUT
STREAM_OUT --> CONSUMER
CONSUMER --> TOKENIZER
CONSUMER --> CHUNKER
CONSUMER --> METRICS
CHUNKER --> BACKPRESSURE
BACKPRESSURE --> FRAME
FRAME --> MGMT_API
MGMT_API --> WS_CLIENT
WS_CLIENT --> RENDERER
RECONNECT -.->|on disconnect| MGMT_API
style MODEL fill:#ff9900,color:#000
style MGMT_API fill:#527fff,color:#fff
Why Aggregate Chunks Before Sending
Bedrock's InvokeModelWithResponseStream emits very small deltas — often a single word or even a partial word. Sending each delta as a separate WebSocket frame would:
- Overwhelm the client with hundreds of tiny messages per second
- Hit API Gateway's per-connection rate limit (documented as "several hundred" messages/second)
- Increase overhead since each frame has JSON wrapper + network overhead
- Fragment Japanese text mid-character across frame boundaries
The solution is a chunk aggregator that buffers deltas and flushes based on time intervals (50-100ms) or size thresholds (8KB), whichever comes first.
Chunk Aggregation Timing
Bedrock output rate: ~40-80 tokens/second for Sonnet
Average token size: ~4 bytes English, ~3 bytes Japanese (UTF-8)
Bytes per second: ~120-240 bytes/second text content
Aggregation strategy:
- Flush every 50ms → ~6-12 bytes per frame (too small)
- Flush every 100ms → ~12-24 bytes per frame (still small but acceptable)
- Flush every 200ms → ~24-48 bytes per frame (good balance)
- Flush when buffer > 8KB → handles burst output
MangaAssist setting: flush every 100ms OR when buffer > 4KB
→ Typical frame: 12-48 bytes of text + ~150 bytes JSON wrapper
→ ~5-10 frames per second to client
→ Well within API Gateway rate limits
→ Smooth progressive rendering in browser
StreamingMiddleware Implementation
"""
MangaAssist Streaming Middleware
Sits between the Bedrock stream consumer and the WebSocket relay layer.
Handles chunk aggregation, backpressure detection, heartbeats, and stream lifecycle.
"""
import json
import time
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Optional, AsyncIterator, Callable
from enum import Enum
import boto3
from botocore.exceptions import ClientError
logger = logging.getLogger(__name__)
class StreamState(Enum):
"""States of a streaming session."""
INITIALIZING = "initializing"
STREAMING = "streaming"
PAUSED = "paused" # Backpressure — waiting for client to catch up
COMPLETING = "completing" # Final flush + done message
COMPLETED = "completed"
CANCELLED = "cancelled" # Client disconnected
ERROR = "error"
@dataclass
class ChunkBuffer:
"""
Accumulates small Bedrock deltas and flushes them as aggregated chunks.
Prevents overwhelming the WebSocket with hundreds of tiny frames.
"""
texts: list = field(default_factory=list)
byte_count: int = 0
last_flush_time: float = 0.0
flush_interval_ms: float = 100.0 # Flush every 100ms
max_buffer_bytes: int = 4096 # Flush when buffer exceeds 4KB
total_flushed_chunks: int = 0
total_flushed_bytes: int = 0
def add(self, text: str) -> None:
"""Add a text delta to the buffer."""
self.texts.append(text)
self.byte_count += len(text.encode("utf-8"))
def should_flush(self) -> bool:
"""Determine if the buffer should be flushed now."""
if not self.texts:
return False
# Size-based flush
if self.byte_count >= self.max_buffer_bytes:
return True
# Time-based flush
now = time.time()
elapsed_ms = (now - self.last_flush_time) * 1000
if elapsed_ms >= self.flush_interval_ms:
return True
return False
def flush(self) -> Optional[str]:
"""Flush the buffer and return aggregated text, or None if empty."""
if not self.texts:
return None
aggregated = "".join(self.texts)
aggregated_bytes = len(aggregated.encode("utf-8"))
self.total_flushed_chunks += 1
self.total_flushed_bytes += aggregated_bytes
self.texts = []
self.byte_count = 0
self.last_flush_time = time.time()
return aggregated
@property
def is_empty(self) -> bool:
return len(self.texts) == 0
@dataclass
class StreamMetricsCollector:
"""Collects real-time metrics during a streaming session."""
stream_start: float = 0.0
first_token_time: float = 0.0
last_token_time: float = 0.0
input_tokens: int = 0
output_tokens: int = 0
chunks_from_bedrock: int = 0
chunks_to_client: int = 0
bytes_to_client: int = 0
backpressure_pauses: int = 0
total_pause_duration_ms: float = 0.0
reconnections: int = 0
@property
def time_to_first_token_ms(self) -> float:
"""Time from request start to first token received from Bedrock."""
if self.first_token_time and self.stream_start:
return (self.first_token_time - self.stream_start) * 1000
return 0.0
@property
def tokens_per_second(self) -> float:
"""Output token generation rate."""
if self.first_token_time and self.last_token_time:
duration = self.last_token_time - self.first_token_time
if duration > 0:
return self.output_tokens / duration
return 0.0
@property
def total_latency_ms(self) -> float:
"""Total time from start to last token."""
if self.last_token_time and self.stream_start:
return (self.last_token_time - self.stream_start) * 1000
return 0.0
@property
def compression_ratio(self) -> float:
"""Ratio of Bedrock chunks to client frames — higher = more aggregation."""
if self.chunks_to_client > 0:
return self.chunks_from_bedrock / self.chunks_to_client
return 0.0
def to_dict(self) -> dict:
"""Export metrics as a dictionary for logging and CloudWatch."""
return {
"ttft_ms": round(self.time_to_first_token_ms),
"tps": round(self.tokens_per_second, 1),
"total_latency_ms": round(self.total_latency_ms),
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"chunks_from_bedrock": self.chunks_from_bedrock,
"chunks_to_client": self.chunks_to_client,
"bytes_to_client": self.bytes_to_client,
"compression_ratio": round(self.compression_ratio, 1),
"backpressure_pauses": self.backpressure_pauses,
"pause_duration_ms": round(self.total_pause_duration_ms),
"reconnections": self.reconnections,
}
class StreamingMiddleware:
"""
Middleware that sits between the Bedrock stream and the WebSocket client.
Features:
- Chunk aggregation with time-based and size-based flushing
- Backpressure detection (monitors postToConnection failures)
- Heartbeat pings during pauses in Bedrock output
- Graceful stream cancellation on client disconnect
- Real-time metrics collection (TTFT, TPS, compression ratio)
- Japanese text boundary awareness (never split mid-character)
Architecture:
Bedrock Stream → StreamingMiddleware → API Gateway postToConnection → Client
Usage:
middleware = StreamingMiddleware(
apigw_endpoint="https://...",
redis_client=redis,
)
metrics = await middleware.process_stream(
connection_id="abc",
session_id="sess-123",
request_id="req-456",
bedrock_stream=response["body"],
)
"""
HEARTBEAT_INTERVAL_S = 5.0 # Send heartbeat if no data for 5s
MAX_STREAM_DURATION_S = 120.0 # Hard kill after 2 minutes
BACKPRESSURE_PAUSE_S = 0.5 # Pause duration on send failure
MAX_BACKPRESSURE_PAUSES = 10 # Cancel stream after this many pauses
MAX_FRAME_BYTES = 32_768 # API Gateway WebSocket frame limit
def __init__(
self,
apigw_endpoint: str,
redis_client,
chunk_buffer_config: Optional[dict] = None,
):
self.redis = redis_client
self.apigw_mgmt = boto3.client(
"apigatewaymanagementapi",
endpoint_url=apigw_endpoint,
)
self._buffer_config = chunk_buffer_config or {}
def process_stream(
self,
connection_id: str,
session_id: str,
request_id: str,
bedrock_stream,
on_complete: Optional[Callable] = None,
) -> StreamMetricsCollector:
"""
Process a Bedrock stream and relay to WebSocket client.
Args:
connection_id: API Gateway WebSocket connection ID
session_id: MangaAssist session ID
request_id: Unique request identifier for idempotency
bedrock_stream: The Bedrock response stream object
on_complete: Optional callback when stream finishes
Returns:
StreamMetricsCollector with performance data
"""
metrics = StreamMetricsCollector(stream_start=time.time())
buffer = ChunkBuffer(**self._buffer_config)
state = StreamState.INITIALIZING
full_text_parts = []
try:
state = StreamState.STREAMING
for event in bedrock_stream:
# Check for client disconnect
if self._is_client_gone(connection_id):
state = StreamState.CANCELLED
logger.info(f"Stream cancelled — client gone: {connection_id}")
break
# Check stream duration limit
elapsed = time.time() - metrics.stream_start
if elapsed > self.MAX_STREAM_DURATION_S:
logger.warning(f"Stream duration limit: {elapsed:.1f}s")
break
chunk = event.get("chunk")
if not chunk:
continue
chunk_data = json.loads(chunk.get("bytes", b"{}"))
chunk_type = chunk_data.get("type")
# ─── Handle content_block_delta (text chunks) ──────────
if chunk_type == "content_block_delta":
delta = chunk_data.get("delta", {})
text = delta.get("text", "")
if text:
now = time.time()
metrics.chunks_from_bedrock += 1
if metrics.first_token_time == 0:
metrics.first_token_time = now
metrics.last_token_time = now
buffer.add(text)
full_text_parts.append(text)
# Flush if buffer is ready
if buffer.should_flush():
aggregated = buffer.flush()
if aggregated:
success = self._send_text_chunk(
connection_id,
request_id,
aggregated,
metrics.chunks_to_client,
)
if success:
metrics.chunks_to_client += 1
metrics.bytes_to_client += len(
aggregated.encode("utf-8")
)
else:
# Backpressure detected
metrics.backpressure_pauses += 1
if metrics.backpressure_pauses > self.MAX_BACKPRESSURE_PAUSES:
state = StreamState.CANCELLED
logger.warning(
f"Stream cancelled — too many "
f"backpressure pauses: {connection_id}"
)
break
pause_start = time.time()
time.sleep(self.BACKPRESSURE_PAUSE_S)
metrics.total_pause_duration_ms += (
(time.time() - pause_start) * 1000
)
# ─── Handle message_start (input token count) ──────────
elif chunk_type == "message_start":
msg = chunk_data.get("message", {})
usage = msg.get("usage", {})
metrics.input_tokens = usage.get("input_tokens", 0)
# ─── Handle message_delta (output token count) ─────────
elif chunk_type == "message_delta":
usage = chunk_data.get("usage", {})
metrics.output_tokens = usage.get("output_tokens", 0)
# ─── Handle message_stop (stream complete) ─────────────
elif chunk_type == "message_stop":
state = StreamState.COMPLETING
# Final buffer flush
if not buffer.is_empty:
remaining = buffer.flush()
if remaining:
self._send_text_chunk(
connection_id, request_id,
remaining, metrics.chunks_to_client,
)
metrics.chunks_to_client += 1
metrics.bytes_to_client += len(remaining.encode("utf-8"))
# Send completion message
if state != StreamState.CANCELLED:
self._send_done_message(
connection_id, request_id, metrics,
"".join(full_text_parts),
)
state = StreamState.COMPLETED
if on_complete:
on_complete(metrics, "".join(full_text_parts))
except ClientError as e:
if e.response["Error"]["Code"] == "GoneException":
state = StreamState.CANCELLED
logger.info(f"Client gone during stream: {connection_id}")
else:
state = StreamState.ERROR
logger.error(f"Stream error: {e}")
except Exception as e:
state = StreamState.ERROR
logger.error(f"Unexpected stream error: {e}", exc_info=True)
self._send_error_message(connection_id, "Stream processing error")
# Log final metrics
logger.info(
f"Stream finished: state={state.value}",
extra={
"connection_id": connection_id,
"request_id": request_id,
"state": state.value,
**metrics.to_dict(),
},
)
return metrics
def _send_text_chunk(
self,
connection_id: str,
request_id: str,
text: str,
chunk_index: int,
) -> bool:
"""Send a text chunk to the client. Returns True on success."""
payload = {
"type": "chunk",
"text": text,
"index": chunk_index,
"requestId": request_id,
}
encoded = json.dumps(payload, ensure_ascii=False).encode("utf-8")
# Split if exceeds frame limit
if len(encoded) > self.MAX_FRAME_BYTES:
return self._send_split_frames(
connection_id, request_id, text, chunk_index
)
try:
self.apigw_mgmt.post_to_connection(
ConnectionId=connection_id,
Data=encoded,
)
return True
except ClientError as e:
if e.response["Error"]["Code"] == "GoneException":
self._mark_client_gone(connection_id)
return False
logger.warning(f"Send failed: {e}")
return False
def _send_split_frames(
self,
connection_id: str,
request_id: str,
text: str,
chunk_index: int,
) -> bool:
"""Split oversized text across multiple frames."""
max_text_bytes = self.MAX_FRAME_BYTES - 256 # Reserve for JSON wrapper
text_bytes = text.encode("utf-8")
offset = 0
sub_index = 0
all_sent = True
while offset < len(text_bytes):
segment = text_bytes[offset:offset + max_text_bytes]
# Avoid splitting multi-byte UTF-8 characters
segment_text = segment.decode("utf-8", errors="ignore")
payload = json.dumps({
"type": "chunk",
"text": segment_text,
"index": chunk_index,
"subIndex": sub_index,
"requestId": request_id,
}, ensure_ascii=False).encode("utf-8")
try:
self.apigw_mgmt.post_to_connection(
ConnectionId=connection_id,
Data=payload,
)
except ClientError:
all_sent = False
break
offset += max_text_bytes
sub_index += 1
return all_sent
def _send_done_message(
self,
connection_id: str,
request_id: str,
metrics: StreamMetricsCollector,
full_text: str,
) -> None:
"""Send stream completion message with metrics."""
payload = {
"type": "done",
"requestId": request_id,
"tokens": {
"input": metrics.input_tokens,
"output": metrics.output_tokens,
},
"metrics": {
"ttft_ms": round(metrics.time_to_first_token_ms),
"total_ms": round(metrics.total_latency_ms),
"tps": round(metrics.tokens_per_second, 1),
"chunks": metrics.chunks_to_client,
},
"textLength": len(full_text),
}
try:
self.apigw_mgmt.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
)
except ClientError:
logger.warning(f"Failed to send done message: {connection_id}")
def _send_error_message(self, connection_id: str, message: str) -> None:
"""Send an error message to the client."""
try:
self.apigw_mgmt.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps({
"type": "error",
"message": message,
}).encode("utf-8"),
)
except Exception:
pass
def _is_client_gone(self, connection_id: str) -> bool:
"""Check Redis for disconnection marker."""
try:
return self.redis.get(f"disconnected:{connection_id}") is not None
except Exception:
return False
def _mark_client_gone(self, connection_id: str) -> None:
"""Mark a connection as gone in Redis."""
try:
self.redis.setex(f"disconnected:{connection_id}", 120, "1")
except Exception:
pass
Part 2: Token Counting Middleware
The Token Counting Problem
Token counting in GenAI applications is harder than it appears:
- Pre-request estimation — You must know approximately how many input tokens a request will consume before sending it to Bedrock, because Bedrock charges for all input tokens regardless of output length.
- Japanese text complexity — Japanese characters consume 1-3 tokens each depending on the character type (hiragana, katakana, kanji, romaji), making estimation harder than English.
- Real-time output tracking — During streaming, you need to track output tokens as they arrive to enforce per-request and per-session budgets.
- Middleware position — Token counting must happen at the middleware layer, intercepting every request before it reaches Bedrock and every response chunk as it streams back.
Token Estimation Accuracy
| Text Type | Chars per Token | Example | Estimation Method |
|---|---|---|---|
| English prose | ~4.0 | "manga recommendation" | len(text) // 4 |
| Japanese hiragana | ~1.5 | "おすすめのマンガ" | len(text) * 0.7 |
| Japanese kanji | ~1.0 | "漫画推薦" | len(text) * 1.0 |
| Mixed JP/EN | ~2.0 | "One Pieceのおすすめ" | Weighted average |
| JSON/structured | ~3.5 | {"title":"..."} |
len(text) // 3.5 |
| URLs/paths | ~2.0 | https://manga.store/vol/1 |
len(text) // 2 |
For MangaAssist with its mixed JP/EN content, the conservative estimate of 1 token per 2 characters provides a safe upper bound that prevents context window overflows.
Token Budget Flow
graph TB
subgraph PreRequest["Pre-Request Token Check"]
ESTIMATE[Estimate Input Tokens<br/>system + history + RAG + user msg]
CHECK_REQ[Check Per-Request Budget<br/>max 4,000 input tokens]
CHECK_SESSION[Check Session Budget<br/>max 50,000 cumulative input]
CHECK_DAILY[Check Daily Budget<br/>max $5.00/day per user]
CHECK_WINDOW[Check Context Window<br/>200K - overhead - output - safety]
end
subgraph DuringStream["During Streaming"]
COUNT_OUT[Count Output Tokens<br/>from message_delta events]
ENFORCE_OUT[Enforce Output Budget<br/>cancel if exceeds max_tokens]
TRACK_COST[Track Running Cost<br/>input + output so far]
end
subgraph PostRequest["Post-Request Recording"]
RECORD_ACTUAL[Record Actual Usage<br/>exact tokens from Bedrock]
UPDATE_SESSION[Update Session Counters<br/>Redis + DynamoDB]
UPDATE_DAILY[Update Daily Counters<br/>Redis with TTL]
EMIT_METRIC[Emit CloudWatch Metric<br/>token usage + cost]
end
ESTIMATE --> CHECK_REQ
CHECK_REQ -->|pass| CHECK_SESSION
CHECK_REQ -->|fail| REJECT[Reject: Token Budget Exceeded]
CHECK_SESSION -->|pass| CHECK_DAILY
CHECK_SESSION -->|fail| REJECT
CHECK_DAILY -->|pass| CHECK_WINDOW
CHECK_DAILY -->|fail| REJECT
CHECK_WINDOW -->|pass| INVOKE[Invoke Bedrock]
CHECK_WINDOW -->|fail| TRUNCATE[Truncate History + Retry]
INVOKE --> COUNT_OUT
COUNT_OUT --> ENFORCE_OUT
ENFORCE_OUT --> TRACK_COST
TRACK_COST --> RECORD_ACTUAL
RECORD_ACTUAL --> UPDATE_SESSION
RECORD_ACTUAL --> UPDATE_DAILY
RECORD_ACTUAL --> EMIT_METRIC
style REJECT fill:#dc3545,color:#fff
style TRUNCATE fill:#ffc107,color:#000
style INVOKE fill:#28a745,color:#fff
TokenCounterMiddleware Implementation
"""
MangaAssist Token Counter Middleware
Intercepts every FM request to estimate, validate, and track token usage.
"""
import json
import time
import logging
import re
from dataclasses import dataclass, field
from typing import Optional
logger = logging.getLogger(__name__)
@dataclass
class TokenEstimate:
"""Detailed token estimation breakdown for a request."""
system_prompt_tokens: int = 0
conversation_history_tokens: int = 0
rag_context_tokens: int = 0
user_message_tokens: int = 0
total_input_tokens: int = 0
max_output_tokens: int = 0
total_budget: int = 0
# Estimation metadata
language_detected: str = "mixed"
estimation_method: str = "character_ratio"
confidence: float = 0.8 # 0.0 to 1.0
def to_dict(self) -> dict:
return {
"system_prompt": self.system_prompt_tokens,
"history": self.conversation_history_tokens,
"rag_context": self.rag_context_tokens,
"user_message": self.user_message_tokens,
"total_input": self.total_input_tokens,
"max_output": self.max_output_tokens,
"total_budget": self.total_budget,
"language": self.language_detected,
"method": self.estimation_method,
"confidence": self.confidence,
}
@dataclass
class TokenUsageRecord:
"""Actual token usage after a completed request."""
input_tokens: int
output_tokens: int
model_id: str
cost_usd: float
estimation_error_pct: float # How far off the estimate was
session_id: str
request_id: str
timestamp: float = field(default_factory=time.time)
def to_dict(self) -> dict:
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"model_id": self.model_id,
"cost_usd": round(self.cost_usd, 6),
"estimation_error_pct": round(self.estimation_error_pct, 1),
"session_id": self.session_id,
"request_id": self.request_id,
"timestamp": self.timestamp,
}
class TokenCounterMiddleware:
"""
Middleware that intercepts every FM request to count and validate tokens.
This middleware wraps the FMAPIInterface and provides:
1. Pre-request token estimation and budget validation
2. Real-time output token tracking during streaming
3. Post-request actual usage recording and cost calculation
4. Automatic conversation history truncation when over budget
5. Japanese-aware token estimation
The middleware maintains running counters in Redis for session-level
and daily-level budget enforcement.
Usage:
middleware = TokenCounterMiddleware(redis_client=redis)
# Pre-request
estimate = middleware.estimate_request_tokens(request)
middleware.validate_budget(request, estimate)
# During streaming (called per Bedrock event)
middleware.track_output_token(request, chunk_data)
# Post-request
middleware.record_actual_usage(request, actual_input, actual_output)
"""
# ─── Pricing Constants ─────────────────────────────────────────
PRICING = {
"anthropic.claude-3-sonnet-20240229-v1:0": {
"input_per_1m": 3.00,
"output_per_1m": 15.00,
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"input_per_1m": 0.25,
"output_per_1m": 1.25,
},
}
# ─── Budget Limits ─────────────────────────────────────────────
PER_REQUEST_INPUT_LIMIT = 4_000
PER_REQUEST_OUTPUT_LIMIT = 1_024
SESSION_INPUT_LIMIT = 50_000
SESSION_OUTPUT_LIMIT = 25_000
DAILY_COST_LIMIT_USD = 5.00
CONTEXT_WINDOW_SIZE = 200_000
PROMPT_OVERHEAD = 300
SAFETY_MARGIN = 500
def __init__(self, redis_client):
self.redis = redis_client
# Track estimation accuracy over time for calibration
self._estimation_errors: list[float] = []
def estimate_request_tokens(
self,
system_prompt: str,
conversation_history: list[dict],
rag_context: Optional[str],
user_message: str,
max_output_tokens: int = 1024,
) -> TokenEstimate:
"""
Estimate token counts for each component of the request.
Uses language detection to choose the right estimation ratio:
- Japanese text: ~1 token per 1.4 characters
- English text: ~1 token per 4 characters
- Mixed: weighted average based on character set ratios
"""
# Detect language mix
all_text = f"{system_prompt} {user_message}"
for turn in conversation_history:
all_text += f" {turn.get('content', '')}"
if rag_context:
all_text += f" {rag_context}"
jp_ratio = self._detect_japanese_ratio(all_text)
lang = "ja" if jp_ratio > 0.3 else "en" if jp_ratio < 0.1 else "mixed"
# Estimate each component
sys_tokens = self._count_tokens(system_prompt, jp_ratio)
history_tokens = sum(
self._count_tokens(turn.get("content", ""), jp_ratio)
for turn in conversation_history
)
rag_tokens = self._count_tokens(rag_context or "", jp_ratio)
user_tokens = self._count_tokens(user_message, jp_ratio)
total_input = sys_tokens + history_tokens + rag_tokens + user_tokens
estimate = TokenEstimate(
system_prompt_tokens=sys_tokens,
conversation_history_tokens=history_tokens,
rag_context_tokens=rag_tokens,
user_message_tokens=user_tokens,
total_input_tokens=total_input,
max_output_tokens=max_output_tokens,
total_budget=total_input + max_output_tokens,
language_detected=lang,
estimation_method="character_ratio_weighted",
confidence=0.85 if lang == "en" else 0.75, # JP is less precise
)
logger.debug(
"Token estimate",
extra=estimate.to_dict(),
)
return estimate
def validate_budget(
self,
session_id: str,
estimate: TokenEstimate,
model_id: str,
) -> dict:
"""
Validate the estimated token usage against all budget levels.
Returns a dict with validation result and any adjustments needed.
Raises TokenBudgetExceeded for hard limit violations.
"""
result = {
"valid": True,
"warnings": [],
"adjustments": [],
}
# Per-request input limit
if estimate.total_input_tokens > self.PER_REQUEST_INPUT_LIMIT:
# Calculate how much history to trim
excess = estimate.total_input_tokens - self.PER_REQUEST_INPUT_LIMIT
result["warnings"].append(
f"Input tokens ({estimate.total_input_tokens}) exceed "
f"per-request limit ({self.PER_REQUEST_INPUT_LIMIT}). "
f"Need to trim {excess} tokens from history."
)
result["adjustments"].append({
"action": "truncate_history",
"tokens_to_trim": excess,
})
# Context window check
available = (
self.CONTEXT_WINDOW_SIZE
- self.PROMPT_OVERHEAD
- self.SAFETY_MARGIN
- estimate.max_output_tokens
)
if estimate.total_input_tokens > available:
result["valid"] = False
result["warnings"].append(
f"Input tokens ({estimate.total_input_tokens}) exceed "
f"context window capacity ({available})."
)
# Session budget check
session_usage = self._get_session_usage(session_id)
projected_session_input = (
session_usage.get("input", 0) + estimate.total_input_tokens
)
if projected_session_input > self.SESSION_INPUT_LIMIT:
result["warnings"].append(
f"Session would reach {projected_session_input} input tokens "
f"(limit: {self.SESSION_INPUT_LIMIT})."
)
# Daily cost check
daily_cost = self._get_daily_cost(session_id)
estimated_cost = self._estimate_cost(
model_id,
estimate.total_input_tokens,
estimate.max_output_tokens,
)
if daily_cost + estimated_cost > self.DAILY_COST_LIMIT_USD:
result["valid"] = False
result["warnings"].append(
f"Daily cost would reach ${daily_cost + estimated_cost:.4f} "
f"(limit: ${self.DAILY_COST_LIMIT_USD:.2f})."
)
return result
def truncate_history_to_budget(
self,
conversation_history: list[dict],
target_tokens: int,
) -> list[dict]:
"""
Trim conversation history to fit within a token budget.
Strategy: Remove oldest turns first, always keeping the most recent
user message and assistant response pair.
"""
if not conversation_history:
return []
current_tokens = sum(
self._count_tokens(turn.get("content", ""), 0.3)
for turn in conversation_history
)
if current_tokens <= target_tokens:
return conversation_history
trimmed = list(conversation_history)
while current_tokens > target_tokens and len(trimmed) > 2:
# Remove oldest turn
removed = trimmed.pop(0)
removed_tokens = self._count_tokens(
removed.get("content", ""), 0.3
)
current_tokens -= removed_tokens
logger.info(
f"Trimmed history from {len(conversation_history)} to "
f"{len(trimmed)} turns ({current_tokens} tokens)"
)
return trimmed
def record_actual_usage(
self,
session_id: str,
request_id: str,
model_id: str,
actual_input_tokens: int,
actual_output_tokens: int,
estimated_input_tokens: int,
) -> TokenUsageRecord:
"""
Record actual token usage after a completed request.
Updates session and daily counters. Calculates cost and estimation error.
"""
# Calculate cost
pricing = self.PRICING.get(model_id, self.PRICING[
"anthropic.claude-3-sonnet-20240229-v1:0"
])
cost = (
(actual_input_tokens / 1_000_000) * pricing["input_per_1m"]
+ (actual_output_tokens / 1_000_000) * pricing["output_per_1m"]
)
# Calculate estimation error
if estimated_input_tokens > 0:
error_pct = (
(actual_input_tokens - estimated_input_tokens)
/ estimated_input_tokens
) * 100
else:
error_pct = 0.0
self._estimation_errors.append(error_pct)
if len(self._estimation_errors) > 1000:
self._estimation_errors = self._estimation_errors[-500:]
record = TokenUsageRecord(
input_tokens=actual_input_tokens,
output_tokens=actual_output_tokens,
model_id=model_id,
cost_usd=cost,
estimation_error_pct=error_pct,
session_id=session_id,
request_id=request_id,
)
# Update counters
self._update_session_usage(
session_id, actual_input_tokens, actual_output_tokens, cost
)
self._update_daily_usage(
session_id, actual_input_tokens, actual_output_tokens, cost
)
logger.info(
"Token usage recorded",
extra=record.to_dict(),
)
return record
def get_estimation_accuracy(self) -> dict:
"""Return statistics on token estimation accuracy."""
if not self._estimation_errors:
return {"samples": 0, "mean_error_pct": 0, "status": "no_data"}
errors = self._estimation_errors
mean_error = sum(errors) / len(errors)
abs_errors = [abs(e) for e in errors]
mean_abs_error = sum(abs_errors) / len(abs_errors)
return {
"samples": len(errors),
"mean_error_pct": round(mean_error, 1),
"mean_abs_error_pct": round(mean_abs_error, 1),
"max_overestimate_pct": round(min(errors), 1),
"max_underestimate_pct": round(max(errors), 1),
"status": "calibrated" if mean_abs_error < 20 else "needs_calibration",
}
# ─── Private Methods ───────────────────────────────────────────
def _count_tokens(self, text: str, jp_ratio: float) -> int:
"""
Estimate token count for text with a given Japanese character ratio.
"""
if not text:
return 0
jp_chars = sum(1 for c in text if ord(c) > 0x3000)
en_chars = len(text) - jp_chars
# JP chars: ~1 token per 1.4 chars (conservative)
# EN chars: ~1 token per 4 chars
jp_tokens = int(jp_chars * 0.71) # 1/1.4
en_tokens = en_chars // 4
return jp_tokens + en_tokens + 1 # +1 to avoid zero
def _detect_japanese_ratio(self, text: str) -> float:
"""Detect the ratio of Japanese characters in text."""
if not text:
return 0.0
jp_chars = sum(1 for c in text if ord(c) > 0x3000)
return jp_chars / len(text) if len(text) > 0 else 0.0
def _estimate_cost(
self, model_id: str, input_tokens: int, output_tokens: int
) -> float:
"""Estimate cost for a request."""
pricing = self.PRICING.get(model_id, self.PRICING[
"anthropic.claude-3-sonnet-20240229-v1:0"
])
return (
(input_tokens / 1_000_000) * pricing["input_per_1m"]
+ (output_tokens / 1_000_000) * pricing["output_per_1m"]
)
def _get_session_usage(self, session_id: str) -> dict:
"""Get current session token usage from Redis."""
try:
raw = self.redis.get(f"session_tokens:{session_id}")
if raw:
return json.loads(raw)
except Exception as e:
logger.warning(f"Redis session read failed: {e}")
return {"input": 0, "output": 0, "cost_usd": 0.0, "requests": 0}
def _get_daily_cost(self, session_id: str) -> float:
"""Get today's cost from Redis."""
try:
day = time.strftime("%Y-%m-%d")
raw = self.redis.get(f"daily_tokens:{session_id}:{day}")
if raw:
return json.loads(raw).get("cost_usd", 0.0)
except Exception as e:
logger.warning(f"Redis daily read failed: {e}")
return 0.0
def _update_session_usage(
self, session_id: str, input_t: int, output_t: int, cost: float
) -> None:
"""Update session counters in Redis."""
try:
key = f"session_tokens:{session_id}"
raw = self.redis.get(key)
data = json.loads(raw) if raw else {
"input": 0, "output": 0, "cost_usd": 0.0, "requests": 0
}
data["input"] += input_t
data["output"] += output_t
data["cost_usd"] += cost
data["requests"] += 1
self.redis.setex(key, 86400, json.dumps(data))
except Exception as e:
logger.error(f"Session usage update failed: {e}")
def _update_daily_usage(
self, session_id: str, input_t: int, output_t: int, cost: float
) -> None:
"""Update daily counters in Redis."""
try:
day = time.strftime("%Y-%m-%d")
key = f"daily_tokens:{session_id}:{day}"
raw = self.redis.get(key)
data = json.loads(raw) if raw else {
"input_tokens": 0, "output_tokens": 0, "cost_usd": 0.0, "requests": 0
}
data["input_tokens"] += input_t
data["output_tokens"] += output_t
data["cost_usd"] += cost
data["requests"] += 1
self.redis.setex(key, 86400, json.dumps(data))
except Exception as e:
logger.error(f"Daily usage update failed: {e}")
Part 3: Adaptive Retry Timing and Connection Pooling
Why Adaptive Retries for GenAI APIs
Standard fixed-interval retries are dangerous for GenAI APIs because:
- Bedrock throttling is load-dependent — A 429 at 9:00 AM might need 2s backoff; the same 429 at peak (lunch hour in Japan) might need 15s.
- Model timeouts vary by input size — A 500-token request times out differently than a 4,000-token request. Retry timing should account for expected completion time.
- Cost of retries is non-zero — Each retry that gets far enough to invoke the model incurs token charges. Retrying a $0.01 request 3 times costs $0.04 total.
- Thundering herd on recovery — When Bedrock recovers from a brief outage, all clients retrying simultaneously can re-trigger throttling.
Adaptive retry timing adjusts backoff parameters based on observed conditions: error type, current throughput, time of day, and recent success rates.
Adaptive Retry State Machine
stateDiagram-v2
[*] --> Normal: Start
Normal --> Normal: Success (track latency)
Normal --> SlightPressure: Single throttle or timeout
SlightPressure --> Normal: 3 consecutive successes
SlightPressure --> SlightPressure: Sporadic errors (< 20% rate)
SlightPressure --> HighPressure: Error rate > 20%
HighPressure --> SlightPressure: Error rate drops below 20%
HighPressure --> HighPressure: Sustained errors
HighPressure --> CircuitOpen: Error rate > 50% for 60s
CircuitOpen --> Recovery: 30s cooldown
Recovery --> Normal: 2 consecutive successes
Recovery --> CircuitOpen: Probe fails
note right of Normal
Base delay: 1s
Jitter: 0-1s
Max retries: 3
end note
note right of SlightPressure
Base delay: 2s
Jitter: 0-2s
Max retries: 2
end note
note right of HighPressure
Base delay: 5s
Jitter: 0-5s
Max retries: 1
end note
note right of CircuitOpen
Reject all requests
Return fallback response
end note
AdaptiveRetryClient Implementation
"""
MangaAssist Adaptive Retry Client
Dynamically adjusts retry behavior based on observed Bedrock API conditions.
"""
import json
import time
import random
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Any, Optional
from collections import deque
from botocore.exceptions import ClientError
logger = logging.getLogger(__name__)
class PressureLevel(Enum):
"""Current pressure level based on observed API behavior."""
NORMAL = "normal"
SLIGHT = "slight_pressure"
HIGH = "high_pressure"
CIRCUIT_OPEN = "circuit_open"
RECOVERY = "recovery"
@dataclass
class PressureConfig:
"""Retry parameters for each pressure level."""
base_delay: float
jitter_range: float
max_retries: int
max_delay: float
PRESSURE_CONFIGS = {
PressureLevel.NORMAL: PressureConfig(
base_delay=1.0, jitter_range=1.0, max_retries=3, max_delay=10.0
),
PressureLevel.SLIGHT: PressureConfig(
base_delay=2.0, jitter_range=2.0, max_retries=2, max_delay=15.0
),
PressureLevel.HIGH: PressureConfig(
base_delay=5.0, jitter_range=5.0, max_retries=1, max_delay=30.0
),
PressureLevel.RECOVERY: PressureConfig(
base_delay=3.0, jitter_range=2.0, max_retries=1, max_delay=20.0
),
}
@dataclass
class RequestOutcome:
"""Records the outcome of a single API call."""
timestamp: float
success: bool
latency_ms: float
error_code: Optional[str] = None
model_id: Optional[str] = None
class AdaptiveRetryClient:
"""
Retry client that adapts its behavior based on observed API conditions.
Unlike static retry configurations, this client:
- Tracks recent success/failure rates in a sliding window
- Adjusts backoff delays based on current pressure level
- Maintains separate pressure tracking per model (Sonnet vs Haiku)
- Reduces retry count under high pressure to avoid worsening conditions
- Uses Redis to share pressure state across ECS Fargate tasks
Pressure Level Transitions:
- NORMAL → SLIGHT: First throttle/timeout error
- SLIGHT → NORMAL: 3 consecutive successes
- SLIGHT → HIGH: Error rate > 20% in observation window
- HIGH → SLIGHT: Error rate drops below 20%
- HIGH → CIRCUIT_OPEN: Error rate > 50% for 60s
- CIRCUIT_OPEN → RECOVERY: After 30s cooldown
- RECOVERY → NORMAL: 2 consecutive successes
- RECOVERY → CIRCUIT_OPEN: Probe fails
"""
OBSERVATION_WINDOW_S = 60 # Look at last 60s of outcomes
MAX_OUTCOMES = 200 # Keep last 200 outcomes max
CIRCUIT_OPEN_DURATION_S = 30 # Time in CIRCUIT_OPEN before recovery
SLIGHT_THRESHOLD = 0.05 # > 5% error rate = slight pressure
HIGH_THRESHOLD = 0.20 # > 20% error rate = high pressure
CIRCUIT_THRESHOLD = 0.50 # > 50% error rate = circuit open
RECOVERY_SUCCESSES = 2 # Successes needed to exit RECOVERY
NORMAL_SUCCESSES = 3 # Successes to go from SLIGHT → NORMAL
def __init__(self, redis_client):
self.redis = redis_client
# Per-model outcome tracking
self._outcomes: dict[str, deque] = {}
self._pressure_levels: dict[str, PressureLevel] = {}
self._circuit_open_time: dict[str, float] = {}
self._consecutive_successes: dict[str, int] = {}
def execute(
self,
operation: Callable[[], Any],
model_id: str,
operation_name: str = "bedrock_call",
) -> Any:
"""
Execute an operation with adaptive retry behavior.
Checks current pressure level, selects retry config,
then executes with appropriate backoff.
"""
pressure = self._get_pressure_level(model_id)
# Circuit breaker check
if pressure == PressureLevel.CIRCUIT_OPEN:
elapsed = time.time() - self._circuit_open_time.get(model_id, 0)
if elapsed < self.CIRCUIT_OPEN_DURATION_S:
raise CircuitOpenError(
f"Circuit open for {model_id}. "
f"Retry in {self.CIRCUIT_OPEN_DURATION_S - elapsed:.0f}s."
)
# Transition to RECOVERY
self._set_pressure(model_id, PressureLevel.RECOVERY)
pressure = PressureLevel.RECOVERY
config = PRESSURE_CONFIGS.get(pressure, PRESSURE_CONFIGS[PressureLevel.NORMAL])
last_error = None
for attempt in range(config.max_retries + 1):
start = time.time()
try:
result = operation()
latency = (time.time() - start) * 1000
# Record success
self._record_outcome(
model_id, True, latency
)
if attempt > 0:
logger.info(
f"Adaptive retry succeeded: {operation_name} "
f"attempt {attempt + 1}, pressure={pressure.value}"
)
return result
except ClientError as e:
error_code = e.response["Error"]["Code"]
latency = (time.time() - start) * 1000
last_error = e
# Non-retryable errors
non_retryable = {
"ValidationException",
"AccessDeniedException",
"ResourceNotFoundException",
}
if error_code in non_retryable:
self._record_outcome(model_id, False, latency, error_code)
raise
# Record failure
self._record_outcome(model_id, False, latency, error_code)
if attempt < config.max_retries:
delay = self._calculate_adaptive_backoff(
attempt, config, pressure
)
logger.warning(
f"Adaptive retry: {operation_name} {error_code}, "
f"pressure={pressure.value}, "
f"attempt {attempt + 1}/{config.max_retries + 1}, "
f"backoff={delay:.2f}s"
)
time.sleep(delay)
# Re-check pressure after backoff
pressure = self._get_pressure_level(model_id)
config = PRESSURE_CONFIGS.get(
pressure, PRESSURE_CONFIGS[PressureLevel.NORMAL]
)
except Exception as e:
latency = (time.time() - start) * 1000
last_error = e
self._record_outcome(
model_id, False, latency, type(e).__name__
)
if attempt < config.max_retries:
delay = self._calculate_adaptive_backoff(
attempt, config, pressure
)
time.sleep(delay)
raise MaxRetriesExhausted(
f"All retries exhausted for {operation_name} under "
f"pressure={pressure.value}: {last_error}"
)
def _calculate_adaptive_backoff(
self,
attempt: int,
config: PressureConfig,
pressure: PressureLevel,
) -> float:
"""
Calculate backoff with adaptive timing.
Under higher pressure, the base delay increases and jitter range
widens to spread out retry attempts across the fleet.
"""
# Exponential component
exponential = config.base_delay * (2 ** attempt)
# Full jitter
jitter = random.uniform(0, config.jitter_range)
# Pressure multiplier — under HIGH pressure, add extra delay
pressure_multiplier = {
PressureLevel.NORMAL: 1.0,
PressureLevel.SLIGHT: 1.2,
PressureLevel.HIGH: 2.0,
PressureLevel.RECOVERY: 1.5,
}.get(pressure, 1.0)
delay = min(
config.max_delay,
(exponential + jitter) * pressure_multiplier,
)
return delay
def _record_outcome(
self,
model_id: str,
success: bool,
latency_ms: float,
error_code: Optional[str] = None,
) -> None:
"""Record an outcome and update pressure level."""
if model_id not in self._outcomes:
self._outcomes[model_id] = deque(maxlen=self.MAX_OUTCOMES)
self._consecutive_successes[model_id] = 0
outcome = RequestOutcome(
timestamp=time.time(),
success=success,
latency_ms=latency_ms,
error_code=error_code,
model_id=model_id,
)
self._outcomes[model_id].append(outcome)
# Track consecutive successes
if success:
self._consecutive_successes[model_id] = (
self._consecutive_successes.get(model_id, 0) + 1
)
else:
self._consecutive_successes[model_id] = 0
# Update pressure level
self._update_pressure(model_id)
# Sync to Redis for cross-task visibility
self._sync_pressure_to_redis(model_id)
def _update_pressure(self, model_id: str) -> None:
"""Update pressure level based on recent outcomes."""
current = self._pressure_levels.get(model_id, PressureLevel.NORMAL)
error_rate = self._get_recent_error_rate(model_id)
consecutive = self._consecutive_successes.get(model_id, 0)
new_level = current
if current == PressureLevel.NORMAL:
if error_rate > self.SLIGHT_THRESHOLD:
new_level = PressureLevel.SLIGHT
elif current == PressureLevel.SLIGHT:
if consecutive >= self.NORMAL_SUCCESSES:
new_level = PressureLevel.NORMAL
elif error_rate > self.HIGH_THRESHOLD:
new_level = PressureLevel.HIGH
elif current == PressureLevel.HIGH:
if error_rate < self.HIGH_THRESHOLD:
new_level = PressureLevel.SLIGHT
elif error_rate > self.CIRCUIT_THRESHOLD:
new_level = PressureLevel.CIRCUIT_OPEN
self._circuit_open_time[model_id] = time.time()
elif current == PressureLevel.RECOVERY:
if consecutive >= self.RECOVERY_SUCCESSES:
new_level = PressureLevel.NORMAL
elif error_rate > self.SLIGHT_THRESHOLD:
new_level = PressureLevel.CIRCUIT_OPEN
self._circuit_open_time[model_id] = time.time()
if new_level != current:
logger.info(
f"Pressure transition for {model_id}: "
f"{current.value} -> {new_level.value} "
f"(error_rate={error_rate:.1%}, consecutive_ok={consecutive})"
)
self._set_pressure(model_id, new_level)
def _get_recent_error_rate(self, model_id: str) -> float:
"""Calculate error rate in the observation window."""
outcomes = self._outcomes.get(model_id, deque())
if not outcomes:
return 0.0
cutoff = time.time() - self.OBSERVATION_WINDOW_S
recent = [o for o in outcomes if o.timestamp >= cutoff]
if not recent:
return 0.0
failures = sum(1 for o in recent if not o.success)
return failures / len(recent)
def _get_pressure_level(self, model_id: str) -> PressureLevel:
"""Get current pressure level, checking Redis for cross-task state."""
# Check local state first
local = self._pressure_levels.get(model_id, PressureLevel.NORMAL)
# Check Redis for fleet-wide state
try:
redis_raw = self.redis.get(f"pressure:{model_id}")
if redis_raw:
redis_level = PressureLevel(redis_raw.decode() if isinstance(redis_raw, bytes) else redis_raw)
# Use the more conservative (higher pressure) of the two
levels = [PressureLevel.NORMAL, PressureLevel.SLIGHT,
PressureLevel.HIGH, PressureLevel.CIRCUIT_OPEN]
local_idx = levels.index(local) if local in levels else 0
redis_idx = levels.index(redis_level) if redis_level in levels else 0
return levels[max(local_idx, redis_idx)]
except Exception:
pass
return local
def _set_pressure(self, model_id: str, level: PressureLevel) -> None:
"""Set pressure level locally."""
self._pressure_levels[model_id] = level
def _sync_pressure_to_redis(self, model_id: str) -> None:
"""Sync local pressure state to Redis for other Fargate tasks."""
try:
level = self._pressure_levels.get(model_id, PressureLevel.NORMAL)
self.redis.setex(
f"pressure:{model_id}",
self.OBSERVATION_WINDOW_S,
level.value,
)
except Exception:
pass
def get_status(self, model_id: str) -> dict:
"""Return current adaptive retry status for monitoring."""
pressure = self._get_pressure_level(model_id)
error_rate = self._get_recent_error_rate(model_id)
config = PRESSURE_CONFIGS.get(pressure, PRESSURE_CONFIGS[PressureLevel.NORMAL])
outcomes = self._outcomes.get(model_id, deque())
recent_latencies = [
o.latency_ms for o in outcomes
if o.success and o.timestamp >= time.time() - 60
]
avg_latency = (
sum(recent_latencies) / len(recent_latencies)
if recent_latencies else 0
)
return {
"model_id": model_id,
"pressure_level": pressure.value,
"error_rate": round(error_rate, 3),
"consecutive_successes": self._consecutive_successes.get(model_id, 0),
"base_delay": config.base_delay,
"max_retries": config.max_retries,
"avg_latency_ms": round(avg_latency),
"recent_outcomes": len(outcomes),
}
class CircuitOpenError(Exception):
"""Raised when circuit breaker is open."""
pass
class MaxRetriesExhausted(Exception):
"""Raised when all adaptive retries are exhausted."""
pass
Connection Pooling for FM APIs
Why Connection Pooling Matters
Each Bedrock API call from an ECS Fargate task requires a TLS handshake (~100-200ms in ap-northeast-1). At 12 requests/second average (1M messages/day), establishing a new connection per request wastes significant time and CPU.
Connection Pool Architecture
graph TB
subgraph Fargate["ECS Fargate Task (1 of N)"]
POOL[Connection Pool<br/>max_pool_connections=25<br/>max_retry=3]
TASK1[Request Handler 1]
TASK2[Request Handler 2]
TASK3[Request Handler 3]
TASKN[Request Handler N]
end
subgraph Bedrock["Amazon Bedrock Endpoint"]
EP[bedrock-runtime.ap-northeast-1.amazonaws.com<br/>TLS 1.3 / HTTP/1.1]
end
TASK1 --> POOL
TASK2 --> POOL
TASK3 --> POOL
TASKN --> POOL
POOL -->|Reuse connections| EP
style POOL fill:#28a745,color:#fff
Connection Pool Configuration
"""
MangaAssist Bedrock Connection Pool Configuration
Optimizes connection reuse for high-throughput FM API calls.
"""
from botocore.config import Config
import boto3
import urllib3
def create_bedrock_client_with_pool(
region: str = "ap-northeast-1",
max_pool_connections: int = 25,
max_retry_attempts: int = 3,
connect_timeout: int = 5,
read_timeout: int = 60,
) -> "botocore.client.BedrockRuntime":
"""
Create a Bedrock Runtime client with optimized connection pooling.
Connection pool sizing rationale for MangaAssist:
- Average load: ~12 req/s across all Fargate tasks
- With 4 tasks: ~3 req/s per task
- Average request duration: ~3s (streaming)
- Concurrent connections per task: ~9 (3 * 3s)
- Pool size 25 gives headroom for bursts (3x peak)
Pool connections are reused across requests, avoiding TLS
handshake overhead (~150ms per new connection).
"""
bedrock_config = Config(
region_name=region,
retries={
"max_attempts": max_retry_attempts,
"mode": "adaptive",
},
max_pool_connections=max_pool_connections,
connect_timeout=connect_timeout,
read_timeout=read_timeout,
# Enable TCP keep-alive to prevent idle connection drops
tcp_keepalive=True,
)
return boto3.client("bedrock-runtime", config=bedrock_config)
# Singleton pool for the Fargate task
_bedrock_client = None
def get_bedrock_client() -> "botocore.client.BedrockRuntime":
"""
Get or create the shared Bedrock client with connection pool.
Thread-safe — boto3 clients are safe to share across threads.
"""
global _bedrock_client
if _bedrock_client is None:
_bedrock_client = create_bedrock_client_with_pool()
return _bedrock_client
Request Queuing Under Load
Queue Architecture
When incoming request rate exceeds Bedrock's throughput (throttling scenario), a request queue prevents dropped messages while maintaining fairness.
graph TB
subgraph Incoming["Incoming Requests"]
R1[Request 1<br/>Priority: Normal]
R2[Request 2<br/>Priority: High]
R3[Request 3<br/>Priority: Normal]
end
subgraph Queue["Request Queue (Redis Sorted Set)"]
PQ[Priority Queue<br/>score = priority * 1000 + timestamp]
CAPACITY[Capacity: 1000 requests<br/>TTL: 30s per entry]
end
subgraph Workers["Bedrock Workers"]
W1[Worker 1<br/>Sonnet]
W2[Worker 2<br/>Sonnet]
W3[Worker 3<br/>Haiku]
W4[Worker 4<br/>Haiku]
end
R1 --> PQ
R2 --> PQ
R3 --> PQ
PQ --> W1
PQ --> W2
PQ --> W3
PQ --> W4
Request Queue Implementation
"""
MangaAssist Request Queue
Buffers FM requests during load spikes to prevent dropped messages.
"""
import json
import time
import logging
from dataclasses import dataclass
from typing import Optional
from enum import IntEnum
logger = logging.getLogger(__name__)
class RequestPriority(IntEnum):
"""Request priority levels. Lower number = higher priority."""
CRITICAL = 1 # System health checks
HIGH = 2 # Paying/premium users
NORMAL = 3 # Standard requests
LOW = 4 # Batch/background tasks
BULK = 5 # Analytics, non-interactive
@dataclass
class QueuedRequest:
"""A request waiting in the queue."""
request_id: str
session_id: str
priority: RequestPriority
payload: dict
enqueued_at: float
ttl_seconds: float = 30.0 # Drop if not processed within 30s
model_preference: str = "auto" # "sonnet", "haiku", or "auto"
@property
def is_expired(self) -> bool:
return (time.time() - self.enqueued_at) > self.ttl_seconds
@property
def wait_time_ms(self) -> float:
return (time.time() - self.enqueued_at) * 1000
def to_json(self) -> str:
return json.dumps({
"request_id": self.request_id,
"session_id": self.session_id,
"priority": self.priority.value,
"payload": self.payload,
"enqueued_at": self.enqueued_at,
"ttl_seconds": self.ttl_seconds,
"model_preference": self.model_preference,
}, ensure_ascii=False)
@classmethod
def from_json(cls, data: str) -> "QueuedRequest":
d = json.loads(data)
return cls(
request_id=d["request_id"],
session_id=d["session_id"],
priority=RequestPriority(d["priority"]),
payload=d["payload"],
enqueued_at=d["enqueued_at"],
ttl_seconds=d.get("ttl_seconds", 30.0),
model_preference=d.get("model_preference", "auto"),
)
class RequestQueue:
"""
Redis-backed priority queue for FM API requests.
When Bedrock is throttling or all workers are busy, requests are
queued rather than rejected. The queue ensures:
- Priority ordering (premium users first)
- TTL-based expiry (stale requests are dropped)
- Fair scheduling across sessions
- Backpressure signaling to API Gateway
The queue uses a Redis sorted set where:
- score = priority * 10_000_000_000 + timestamp
- This ensures priority ordering with FIFO within the same priority
Capacity is capped at 1000 requests. Beyond that, new requests
receive a 429 response with a Retry-After header.
"""
QUEUE_KEY = "fm_request_queue"
MAX_QUEUE_SIZE = 1000
CLEANUP_INTERVAL_S = 5.0 # Clean expired entries every 5s
def __init__(self, redis_client):
self.redis = redis_client
self._last_cleanup = 0.0
def enqueue(self, request: QueuedRequest) -> bool:
"""
Add a request to the queue.
Returns True if enqueued, False if queue is full.
"""
# Check queue capacity
current_size = self.redis.zcard(self.QUEUE_KEY)
if current_size >= self.MAX_QUEUE_SIZE:
logger.warning(
f"Queue full ({current_size}), rejecting request "
f"{request.request_id}"
)
return False
# Score: priority * large_constant + timestamp for FIFO within priority
score = request.priority.value * 10_000_000_000 + request.enqueued_at
self.redis.zadd(
self.QUEUE_KEY,
{request.to_json(): score},
)
logger.info(
f"Enqueued request {request.request_id} "
f"(priority={request.priority.name}, "
f"queue_size={current_size + 1})"
)
# Periodic cleanup
self._maybe_cleanup()
return True
def dequeue(self) -> Optional[QueuedRequest]:
"""
Pop the highest-priority, oldest request from the queue.
Skips expired requests.
"""
while True:
# Pop lowest score (highest priority, oldest)
results = self.redis.zpopmin(self.QUEUE_KEY, count=1)
if not results:
return None
data, score = results[0]
if isinstance(data, bytes):
data = data.decode("utf-8")
request = QueuedRequest.from_json(data)
# Skip expired
if request.is_expired:
logger.debug(
f"Skipping expired request {request.request_id} "
f"(waited {request.wait_time_ms:.0f}ms)"
)
continue
logger.info(
f"Dequeued request {request.request_id} "
f"(wait={request.wait_time_ms:.0f}ms, "
f"priority={request.priority.name})"
)
return request
def get_queue_stats(self) -> dict:
"""Return queue statistics for monitoring."""
size = self.redis.zcard(self.QUEUE_KEY)
# Peek at oldest entry
oldest_entries = self.redis.zrange(self.QUEUE_KEY, 0, 0, withscores=True)
oldest_wait_ms = 0
if oldest_entries:
data, score = oldest_entries[0]
if isinstance(data, bytes):
data = data.decode("utf-8")
req = QueuedRequest.from_json(data)
oldest_wait_ms = req.wait_time_ms
return {
"queue_size": size,
"max_capacity": self.MAX_QUEUE_SIZE,
"utilization_pct": round((size / self.MAX_QUEUE_SIZE) * 100, 1),
"oldest_wait_ms": round(oldest_wait_ms),
}
def _maybe_cleanup(self) -> None:
"""Periodically remove expired entries."""
now = time.time()
if now - self._last_cleanup < self.CLEANUP_INTERVAL_S:
return
self._last_cleanup = now
cutoff_score = RequestPriority.BULK.value * 10_000_000_000 + (now - 30)
removed = self.redis.zremrangebyscore(
self.QUEUE_KEY, "-inf", cutoff_score
)
if removed:
logger.info(f"Cleaned {removed} expired queue entries")
Key Takeaways
| # | Takeaway | MangaAssist Application |
|---|---|---|
| 1 | Chunk aggregation prevents WebSocket frame storms — Buffering Bedrock deltas for 100ms before sending reduces client frame rate from 40-80/s to 5-10/s, staying well within API Gateway limits and enabling smooth progressive rendering. | MangaAssist uses a ChunkBuffer with 100ms flush interval and 4KB size threshold. Compression ratio averages 8:1 (8 Bedrock deltas per client frame). |
| 2 | Token estimation must be language-aware — Japanese text consumes ~1 token per 1.4 characters versus ~1 per 4 for English. Using English ratios for JP text would underestimate by 3x, causing context window overflows. | TokenCounterMiddleware detects JP character ratio and uses weighted estimation. Accuracy tracking shows mean absolute error of ~15% for JP text. |
| 3 | Four-level token budgets prevent cost explosions at every scale — Per-request, per-session, daily per-user, and context window checks must all pass before any Bedrock invocation to prevent $45K/day runaway costs. | TokenCounterMiddleware validates all four levels synchronously before invoking Bedrock. Redis failures degrade gracefully (allow request, log warning). |
| 4 | Adaptive retry timing outperforms static backoff under variable load — Bedrock throttling severity changes throughout the day. Static 1s/2s/4s backoff is either too aggressive (peak) or too conservative (off-peak). | AdaptiveRetryClient tracks per-model pressure levels (NORMAL/SLIGHT/HIGH/CIRCUIT_OPEN) and adjusts base delay from 1s to 5s with proportional jitter. |
| 5 | Separate circuit breakers per model prevent cascading failures — A Sonnet timeout should not prevent Haiku queries from succeeding. Per-model pressure tracking isolates failures. | AdaptiveRetryClient maintains independent outcome histories and pressure levels for each model_id, synced to Redis for fleet-wide coordination. |
| 6 | Connection pooling saves 150ms per request — Reusing TLS connections to Bedrock endpoints avoids per-request handshake overhead. At 1M messages/day, this saves 41 hours of cumulative handshake time. | Singleton Bedrock client with max_pool_connections=25 and TCP keepalive. Each Fargate task maintains its own pool. |
| 7 | Request queuing preserves messages during throttling — Rather than rejecting requests when Bedrock is saturated, a priority queue buffers them with 30s TTL, ensuring premium users and critical requests are served first. | Redis sorted set queue with priority scoring (CRITICAL through BULK). Maximum 1000 entries with automatic expiry cleanup. |