Dynamic Routing and Metric-Based Model Selection
MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.
Skill Mapping
| Field | Value |
|---|---|
| Certification | AWS AI Practitioner (AIP-C01) |
| Domain | 2 — Implementation and Integration of Foundation Models |
| Task | 2.4 — Design model deployment and inference strategies |
| Skill | 2.4.4 — Develop intelligent model routing systems to optimize model selection |
| Focus | Deep-dive into content complexity scoring, real-time metric collection, routing table management, A/B routing, shadow routing |
1. Content Complexity Scoring — Deep Dive
Content complexity scoring is the foundation of dynamic routing. A well-calibrated scorer is the difference between routing 60% of queries to cheap Haiku (saving thousands of dollars per day) and accidentally sending everything to expensive Sonnet.
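To make those stakes concrete, a back-of-the-envelope sketch. The ~500 input / 400 output tokens per query are illustrative assumptions, not measured production values:

# Back-of-the-envelope routing economics at 1M messages/day.
# Token counts per query are illustrative assumptions.
HAIKU = {"input": 0.25, "output": 1.25}    # USD per 1M tokens
SONNET = {"input": 3.00, "output": 15.00}

def cost_per_query(prices, input_tokens=500, output_tokens=400):
    return (input_tokens * prices["input"] + output_tokens * prices["output"]) / 1_000_000

daily = 1_000_000
all_sonnet = daily * cost_per_query(SONNET)
routed_60_40 = daily * (0.6 * cost_per_query(HAIKU) + 0.4 * cost_per_query(SONNET))
print(f"All Sonnet: ${all_sonnet:,.0f}/day | 60/40 routed: ${routed_60_40:,.0f}/day")
# All Sonnet: $7,500/day | 60/40 routed: $3,375/day

Under these assumptions, a 60/40 Haiku/Sonnet split roughly halves the daily model bill.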
1.1 Complexity Dimensions
mindmap
root((Complexity Scoring))
Lexical Features
Average Word Length
Vocabulary Richness TTR
Rare Word Frequency
Technical Terminology Count
Japanese Character Ratio
Structural Features
Sentence Count
Clause Depth
Question Count
Conditional Statements
Enumeration Lists
Semantic Features
Entity Density
Comparison Operators
Negation Chains
Temporal References
Causal Relationships
Contextual Features
Multi-Turn Depth
Topic Switches
Reference Resolution
Anaphora Chains
Implicit Knowledge Needs
Domain Features
Manga-Specific Terms
Cultural Context Requirements
Cross-Reference Needs
Creative vs Factual Balance
Spoiler Sensitivity
1.2 Multi-Dimensional Complexity Scorer
"""
complexity_scorer.py — MangaAssist Multi-Dimensional Complexity Scorer
Computes a fine-grained complexity score across lexical, structural,
semantic, contextual, and domain-specific dimensions. The composite
score drives model selection in the dynamic routing pipeline.
"""
import re
import math
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from collections import Counter
logger = logging.getLogger(__name__)
@dataclass
class DimensionScore:
"""Score for a single complexity dimension."""
dimension: str
raw_score: float # Dimension-specific scale
normalized: float # 0-1 normalized
weight: float # Contribution weight
weighted_score: float # normalized * weight
features: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ComplexityReport:
"""Full complexity scoring report."""
composite_score: float # 0-10 final score
confidence: float # 0-1 confidence in the score
dimensions: List[DimensionScore]
recommended_model: str
reasoning: str
scoring_time_ms: float
feature_vector: Dict[str, float] = field(default_factory=dict)
class ComplexityScorer:
"""
Multi-dimensional complexity scorer for MangaAssist queries.
Analyzes queries across 5 dimensions (lexical, structural, semantic,
contextual, domain) to produce a composite 0-10 complexity score.
Each dimension is independently scored, weighted, and normalized.
Calibration targets (from MangaAssist production analysis):
- "What time do you close?" → 1.2 (Trivial → Haiku)
- "Do you have One Piece volume 99?" → 2.8 (Simple → Haiku)
- "Compare Demon Slayer and JJK art styles" → 6.5 (Complex → Sonnet)
- "Analyze the thematic evolution of Berserk across the Golden Age arc
and how Miura's art style reflects Guts's psychological state" → 8.9 (Expert → Sonnet)
Usage:
scorer = ComplexityScorer()
report = scorer.score(query, conversation_history)
if report.composite_score >= 6.5:
use_sonnet()
else:
use_haiku()
"""
# Dimension weights (must sum to 1.0)
WEIGHTS = {
"lexical": 0.20,
"structural": 0.20,
"semantic": 0.25,
"contextual": 0.15,
"domain": 0.20,
}
# Japanese character pattern
JP_PATTERN = re.compile(r"[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\uFF00-\uFFEF]")
# Technical manga/anime terms that signal domain complexity
DOMAIN_TERMS_BASIC = {
"manga", "anime", "volume", "chapter", "series", "author",
"publisher", "price", "edition", "paperback", "hardcover",
}
DOMAIN_TERMS_INTERMEDIATE = {
"shounen", "shoujo", "seinen", "josei", "isekai", "mecha",
"slice of life", "tankoubon", "mangaka", "light novel",
"visual novel", "doujinshi", "oneshot", "serialization",
}
DOMAIN_TERMS_ADVANCED = {
"paneling technique", "screentone", "decompression",
"narrative structure", "visual metaphor", "symbolic imagery",
"character arc", "thematic analysis", "art evolution",
"cultural significance", "genre deconstruction", "metanarrative",
}
# Complex linguistic patterns
COMPARISON_PATTERNS = [
r"\b(compare|contrast|versus|vs\.?|difference between|similarities)\b",
r"\b(better than|worse than|more .+ than|less .+ than)\b",
r"\b(pros and cons|advantages|disadvantages|trade-?offs)\b",
]
ANALYTICAL_PATTERNS = [
r"\b(analyze|explain why|what caused|how does .+ affect)\b",
r"\b(significance of|meaning behind|symbolism|represents)\b",
r"\b(evolution of|development of|progression|trajectory)\b",
]
CONDITIONAL_PATTERNS = [
r"\b(if .+ then|what if|assuming|suppose|given that)\b",
r"\b(would .+ if|could .+ when|should .+ because)\b",
]
CAUSAL_PATTERNS = [
r"\b(because|therefore|consequently|as a result|leads to)\b",
r"\b(caused by|due to|reason for|explains why|contributes to)\b",
]
def __init__(
self,
sonnet_threshold: float = 6.5,
weights: Optional[Dict[str, float]] = None,
):
self.sonnet_threshold = sonnet_threshold
self.weights = weights or self.WEIGHTS.copy()
# Validate weights sum to 1.0
total = sum(self.weights.values())
if abs(total - 1.0) > 0.01:
raise ValueError(f"Dimension weights must sum to 1.0, got {total}")
# Compile regex patterns once
self._comparison_re = [re.compile(p, re.IGNORECASE) for p in self.COMPARISON_PATTERNS]
self._analytical_re = [re.compile(p, re.IGNORECASE) for p in self.ANALYTICAL_PATTERNS]
self._conditional_re = [re.compile(p, re.IGNORECASE) for p in self.CONDITIONAL_PATTERNS]
self._causal_re = [re.compile(p, re.IGNORECASE) for p in self.CAUSAL_PATTERNS]
logger.info(
"ComplexityScorer initialized | threshold=%.1f | weights=%s",
sonnet_threshold,
self.weights,
)
def score(
self,
query: str,
conversation_history: Optional[List[Dict[str, str]]] = None,
) -> ComplexityReport:
"""
Compute a full complexity report for the given query.
Args:
query: The user's current message
conversation_history: Previous conversation turns
Returns:
ComplexityReport with composite score and per-dimension breakdown
"""
import time
start = time.monotonic()
history = conversation_history or []
# Score each dimension independently
lexical = self._score_lexical(query)
structural = self._score_structural(query)
semantic = self._score_semantic(query)
contextual = self._score_contextual(query, history)
domain = self._score_domain(query)
dimensions = [lexical, structural, semantic, contextual, domain]
# Compute composite score (weighted sum, scaled to 0-10)
composite = sum(d.weighted_score for d in dimensions) * 10.0
composite = max(0.0, min(10.0, composite))
# Compute confidence based on feature coverage
non_zero_dims = sum(1 for d in dimensions if d.raw_score > 0)
confidence = non_zero_dims / len(dimensions)
# Build feature vector for logging/analysis
feature_vector = {}
for dim in dimensions:
for key, value in dim.features.items():
feature_vector[f"{dim.dimension}_{key}"] = value
# Determine recommended model
if composite >= self.sonnet_threshold:
model = "anthropic.claude-3-sonnet-20240229-v1:0"
reasoning = f"Complexity {composite:.1f} >= threshold {self.sonnet_threshold} — Sonnet recommended"
else:
model = "anthropic.claude-3-haiku-20240307-v1:0"
reasoning = f"Complexity {composite:.1f} < threshold {self.sonnet_threshold} — Haiku sufficient"
elapsed = (time.monotonic() - start) * 1000
report = ComplexityReport(
composite_score=composite,
confidence=confidence,
dimensions=dimensions,
recommended_model=model,
reasoning=reasoning,
scoring_time_ms=elapsed,
feature_vector=feature_vector,
)
logger.info(
"Complexity scored | composite=%.1f | model=%s | confidence=%.2f | %.2fms",
composite,
model.split(".")[-1][:15],
confidence,
elapsed,
)
return report
def _score_lexical(self, query: str) -> DimensionScore:
"""
Score lexical complexity: vocabulary richness, word length, rare terms.
Features:
- avg_word_length: Average characters per word
- type_token_ratio: Unique words / total words
- long_word_ratio: Words > 8 chars / total words
- jp_char_ratio: Japanese characters / total characters
"""
words = re.findall(r"\b\w+\b", query.lower())
if not words:
return self._empty_dimension("lexical")
word_count = len(words)
unique_words = len(set(words))
avg_word_length = sum(len(w) for w in words) / word_count
type_token_ratio = unique_words / word_count if word_count > 0 else 0
long_word_ratio = sum(1 for w in words if len(w) > 8) / word_count
jp_chars = len(self.JP_PATTERN.findall(query))
total_chars = max(len(query), 1)
jp_char_ratio = jp_chars / total_chars
# Raw score: 0-5 scale
raw = 0.0
raw += min(avg_word_length / 8.0, 1.0) * 1.5 # Word length
raw += min(type_token_ratio, 1.0) * 1.0 # Vocabulary richness
raw += min(long_word_ratio * 5, 1.0) * 1.0 # Long words
raw += min(jp_char_ratio * 2, 1.0) * 0.5 # Japanese content
normalized = min(raw / 4.0, 1.0)
weight = self.weights["lexical"]
return DimensionScore(
dimension="lexical",
raw_score=raw,
normalized=normalized,
weight=weight,
weighted_score=normalized * weight,
features={
"avg_word_length": round(avg_word_length, 2),
"type_token_ratio": round(type_token_ratio, 3),
"long_word_ratio": round(long_word_ratio, 3),
"jp_char_ratio": round(jp_char_ratio, 3),
"word_count": word_count,
},
)
def _score_structural(self, query: str) -> DimensionScore:
"""
Score structural complexity: sentence count, question depth, clauses.
Features:
- sentence_count: Number of sentences
- question_count: Number of question marks
- clause_indicators: Subordinate clause markers
- enumeration: Lists or numbered items
"""
sentences = re.split(r"[.!?]+", query)
sentences = [s.strip() for s in sentences if s.strip()]
sentence_count = len(sentences)
question_count = query.count("?")
clause_markers = len(re.findall(
r"\b(which|that|who|whom|whose|where|when|while|although|because|since|unless|if)\b",
query, re.IGNORECASE,
))
enum_markers = len(re.findall(r"(\d+\.|[-*]|\bfirst\b|\bsecond\b|\bthird\b)", query, re.IGNORECASE))
comma_count = query.count(",")
raw = 0.0
raw += min(sentence_count / 4.0, 1.0) * 1.0 # Multiple sentences
raw += min(question_count / 3.0, 1.0) * 1.0 # Multiple questions
raw += min(clause_markers / 3.0, 1.0) * 1.0 # Clause depth
raw += min(enum_markers / 3.0, 1.0) * 0.5 # Enumerations
raw += min(comma_count / 5.0, 1.0) * 0.5 # Structural complexity proxy
normalized = min(raw / 4.0, 1.0)
weight = self.weights["structural"]
return DimensionScore(
dimension="structural",
raw_score=raw,
normalized=normalized,
weight=weight,
weighted_score=normalized * weight,
features={
"sentence_count": sentence_count,
"question_count": question_count,
"clause_markers": clause_markers,
"enum_markers": enum_markers,
"comma_count": comma_count,
},
)
def _score_semantic(self, query: str) -> DimensionScore:
"""
Score semantic complexity: analytical depth, causal reasoning, comparisons.
Features:
- comparison_count: Comparative language patterns
- analytical_count: Analytical request patterns
- conditional_count: Hypothetical/conditional structures
- causal_count: Causal relationship indicators
- entity_density: Named entities per word
"""
comparison_count = sum(1 for p in self._comparison_re if p.search(query))
analytical_count = sum(1 for p in self._analytical_re if p.search(query))
conditional_count = sum(1 for p in self._conditional_re if p.search(query))
causal_count = sum(1 for p in self._causal_re if p.search(query))
# Entity density (capitalized words and quoted terms)
entities = re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", query)
quoted = re.findall(r'"[^"]+"|\'[^\']+\'', query)
entity_count = len(entities) + len(quoted)
word_count = max(len(query.split()), 1)
entity_density = entity_count / word_count
raw = 0.0
raw += min(comparison_count, 2) * 1.0 # Comparisons
raw += min(analytical_count, 2) * 1.5 # Analytical depth
raw += min(conditional_count, 2) * 0.8 # Conditionals
raw += min(causal_count, 2) * 0.7 # Causal reasoning
raw += min(entity_density * 10, 1.0) * 0.5 # Entity density
normalized = min(raw / 4.5, 1.0)
weight = self.weights["semantic"]
return DimensionScore(
dimension="semantic",
raw_score=raw,
normalized=normalized,
weight=weight,
weighted_score=normalized * weight,
features={
"comparison_count": comparison_count,
"analytical_count": analytical_count,
"conditional_count": conditional_count,
"causal_count": causal_count,
"entity_count": entity_count,
"entity_density": round(entity_density, 3),
},
)
def _score_contextual(
self, query: str, history: List[Dict[str, str]]
) -> DimensionScore:
"""
Score contextual complexity: multi-turn depth, topic continuity, references.
Features:
- turn_depth: Number of previous conversation turns
- has_pronouns: References requiring context resolution
- topic_shift: Whether the query shifts topic from history
- reference_count: Explicit back-references ("as I mentioned", "the one you said")
"""
turn_depth = len(history)
# Pronoun-based reference resolution
pronouns = len(re.findall(
r"\b(it|they|them|this|that|these|those|its|their|the one)\b",
query, re.IGNORECASE,
))
has_pronouns = pronouns > 0
# Explicit back-references
back_refs = len(re.findall(
r"\b(you (mentioned|said|told)|as I (said|asked)|earlier|previously|the same|like before)\b",
query, re.IGNORECASE,
))
# Topic shift detection (simple: check if query shares keywords with last turn)
topic_shift = False
if history:
last_msg = history[-1].get("content", "").lower()
query_words = set(query.lower().split())
last_words = set(last_msg.split())
overlap = len(query_words & last_words)
total = max(len(query_words), 1)
if overlap / total < 0.1:
topic_shift = True
raw = 0.0
raw += min(turn_depth / 5.0, 1.0) * 1.5 # Multi-turn depth
raw += (1.0 if has_pronouns else 0.0) * 0.5 # Pronoun resolution
raw += min(back_refs, 2) * 0.8 # Explicit references
raw += (0.5 if topic_shift else 0.0) # Topic continuity
normalized = min(raw / 3.3, 1.0)
weight = self.weights["contextual"]
return DimensionScore(
dimension="contextual",
raw_score=raw,
normalized=normalized,
weight=weight,
weighted_score=normalized * weight,
features={
"turn_depth": turn_depth,
"pronoun_count": pronouns,
"has_pronouns": has_pronouns,
"back_reference_count": back_refs,
"topic_shift": topic_shift,
},
)
def _score_domain(self, query: str) -> DimensionScore:
"""
Score domain-specific complexity: manga/anime knowledge requirements.
Features:
- basic_term_count: Common manga terms found
- intermediate_term_count: Niche genre/format terms
- advanced_term_count: Critical analysis terminology
- cultural_context_needed: Whether cultural knowledge is required
"""
q_lower = query.lower()
basic_count = sum(1 for t in self.DOMAIN_TERMS_BASIC if t in q_lower)
intermediate_count = sum(1 for t in self.DOMAIN_TERMS_INTERMEDIATE if t in q_lower)
advanced_count = sum(1 for t in self.DOMAIN_TERMS_ADVANCED if t in q_lower)
# Cultural context heuristic
cultural_indicators = [
"japanese culture", "cultural significance", "honorifics",
"japanese humor", "cultural reference", "tradition",
"ceremony", "social hierarchy", "senpai", "sensei",
]
cultural_needed = any(ind in q_lower for ind in cultural_indicators)
raw = 0.0
raw += min(basic_count / 3.0, 1.0) * 0.5 # Basic domain terms
raw += min(intermediate_count / 2.0, 1.0) * 1.0 # Intermediate terms
raw += min(advanced_count, 2) * 1.5 # Advanced analysis terms
raw += (1.0 if cultural_needed else 0.0) * 1.0 # Cultural context
normalized = min(raw / 4.0, 1.0)
weight = self.weights["domain"]
return DimensionScore(
dimension="domain",
raw_score=raw,
normalized=normalized,
weight=weight,
weighted_score=normalized * weight,
features={
"basic_term_count": basic_count,
"intermediate_term_count": intermediate_count,
"advanced_term_count": advanced_count,
"cultural_context_needed": cultural_needed,
},
)
def _empty_dimension(self, name: str) -> DimensionScore:
"""Return a zero-score dimension for empty input."""
return DimensionScore(
dimension=name,
raw_score=0.0,
normalized=0.0,
weight=self.weights.get(name, 0.0),
weighted_score=0.0,
)
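A minimal usage sketch against the calibration targets above. Printed scores are indicative only; exact values depend on the heuristics as tuned:

# Illustrative smoke test of the scorer against the calibration queries.
scorer = ComplexityScorer(sonnet_threshold=6.5)
for query in [
    "What time do you close?",
    "Do you have One Piece volume 99?",
    "Compare Demon Slayer and JJK art styles",
]:
    report = scorer.score(query)
    print(f"{report.composite_score:4.1f} | {report.recommended_model} | {report.reasoning}")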
2. Real-Time Metric Collection
2.1 Metric Collection Pipeline Architecture
flowchart TB
subgraph Sources["Metric Emission Points"]
direction TB
INV[Bedrock InvokeModel<br/>Latency + Token Count]
RESP[Response Evaluator<br/>Quality Score]
USER[User Feedback<br/>Thumbs Up/Down]
ERR[Error Handler<br/>Error Type + Rate]
COST[Cost Calculator<br/>Per-Query USD]
end
subgraph Pipeline["Collection Pipeline"]
direction TB
EMT[Metric Emitter<br/>Async Non-Blocking]
KDS[Kinesis Data Stream<br/>2 Shards]
AGG[Lambda Aggregator<br/>1-min Windows]
end
subgraph RealTime["Real-Time Layer"]
direction TB
RED[ElastiCache Redis<br/>Sorted Sets + Hashes]
PUB[Redis Pub/Sub<br/>Threshold Alerts]
end
subgraph Durable["Durable Layer"]
direction TB
DDB[DynamoDB<br/>Time-Series Metrics]
S3[S3 Parquet<br/>Historical Archive]
end
subgraph Consumers["Metric Consumers"]
direction TB
SEL[MetricBasedSelector<br/>Model Ranking]
DASH[CloudWatch Dashboard<br/>Monitoring]
ALERT[SNS Alerts<br/>Anomaly Detection]
end
INV --> EMT
RESP --> EMT
USER --> EMT
ERR --> EMT
COST --> EMT
EMT --> KDS
KDS --> AGG
AGG --> RED
AGG --> DDB
AGG --> S3
RED --> SEL
RED --> PUB
PUB --> ALERT
DDB --> DASH
2.2 Metric Data Points
| Metric | Source | Granularity | Storage | TTL |
|---|---|---|---|---|
| invocation_latency_ms | Bedrock response | Per-request | Redis sorted set | 1 hour |
| input_token_count | Bedrock response | Per-request | Redis counter | 1 hour |
| output_token_count | Bedrock response | Per-request | Redis counter | 1 hour |
| error_count | Error handler | Per-request | Redis counter | 1 hour |
| throttle_count | Bedrock 429 response | Per-request | Redis counter | 1 hour |
| quality_score | Response evaluator | Per-request | DynamoDB | 30 days |
| user_satisfaction | Feedback button | Per-session | DynamoDB | 90 days |
| cost_usd | Cost calculator | Per-request | DynamoDB | 90 days |
| p50_latency_ms | Lambda aggregator | 1-minute window | Redis hash | 1 hour |
| p95_latency_ms | Lambda aggregator | 1-minute window | Redis hash | 1 hour |
| p99_latency_ms | Lambda aggregator | 1-minute window | Redis hash | 1 hour |
| queries_per_minute | Lambda aggregator | 1-minute window | Redis hash | 1 hour |
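The latency percentiles in this table are derived from the latency_samples sorted sets written by the MetricEmitter in section 2.3 (members are timestamps, scores are latency values, so rank order is ascending latency). One way the aggregator could read a percentile, sketched with standard redis-py calls:

# Sketch: read a latency percentile from the latency_samples:{model_id}
# sorted set written by MetricEmitter (section 2.3).
import math
import redis

def percentile_latency(r: redis.Redis, model_id: str, pct: float = 0.95) -> float:
    key = f"latency_samples:{model_id}"
    n = r.zcard(key)
    if n == 0:
        return 0.0
    rank = max(0, math.ceil(pct * n) - 1)
    # ZRANGE by rank returns members ordered by score (ascending latency)
    sample = r.zrange(key, rank, rank, withscores=True)
    return sample[0][1] if sample else 0.0

# p95 = percentile_latency(redis.Redis(), "anthropic.claude-3-haiku-20240307-v1:0")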
2.3 Metric Emitter Implementation
"""
metric_emitter.py — MangaAssist Async Metric Emitter
Non-blocking metric emission for the routing pipeline.
Pushes raw metric data points to Kinesis for aggregation
and directly to Redis for real-time consumption.
"""
import json
import time
import logging
import asyncio
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, asdict
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import boto3
logger = logging.getLogger(__name__)
@dataclass
class MetricDataPoint:
"""A single metric observation."""
metric_name: str
value: float
model_id: str
timestamp_ms: int
dimensions: Dict[str, str]
unit: str = "None" # "Milliseconds", "Count", "USD", "Percent"
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@dataclass
class InvocationMetrics:
"""Complete metrics from a single model invocation."""
model_id: str
latency_ms: float
input_tokens: int
output_tokens: int
status: str # "success", "error", "throttled", "timeout"
error_type: Optional[str] = None
quality_score: Optional[float] = None
cost_usd: Optional[float] = None
route_path: str = "unknown"
query_type: str = "unknown"
complexity_score: float = 0.0
class MetricEmitter:
"""
Asynchronous metric emitter for MangaAssist routing.
Emits metrics in a non-blocking fashion to avoid adding latency
to the hot path. Uses fire-and-forget pattern with local buffering
and batch flushing.
Architecture:
    - Metric points are handed to a small thread pool and written to
      Kinesis off the hot path (fire-and-forget, one batch per invocation)
    - BATCH_SIZE / FLUSH_INTERVAL_SECONDS are reserved for a buffered
      flush mode
    - Critical metrics (errors, throttles) also go directly to Redis
    - A CloudWatch client is provisioned for embedded metric format emission
Usage:
emitter = MetricEmitter(
kinesis_stream="MangaAssist-Metrics",
redis_client=redis_conn,
)
emitter.emit_invocation(InvocationMetrics(
model_id="anthropic.claude-3-haiku-20240307-v1:0",
latency_ms=450.0,
input_tokens=200,
output_tokens=150,
status="success",
))
"""
BATCH_SIZE = 50
FLUSH_INTERVAL_SECONDS = 5.0
# Cost constants
COSTS = {
"anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.25, "output": 1.25},
"anthropic.claude-3-sonnet-20240229-v1:0": {"input": 3.00, "output": 15.00},
}
def __init__(
self,
kinesis_stream: str = "MangaAssist-Metrics",
redis_client=None,
region: str = "ap-northeast-1",
enable_cloudwatch: bool = True,
):
self.stream_name = kinesis_stream
self.redis = redis_client
self.enable_cw = enable_cloudwatch
self.kinesis = boto3.client("kinesis", region_name=region)
if enable_cloudwatch:
self.cloudwatch = boto3.client("cloudwatch", region_name=region)
self._buffer: List[MetricDataPoint] = []
        # asyncio.get_event_loop() is deprecated outside a running loop and can
        # raise; probe with get_running_loop() instead.
        try:
            asyncio.get_running_loop()
            self._buffer_lock: Optional[asyncio.Lock] = asyncio.Lock()
        except RuntimeError:
            self._buffer_lock = None
self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="metric-emit")
self._total_emitted = 0
self._total_errors = 0
logger.info(
"MetricEmitter initialized | stream=%s | cw=%s",
kinesis_stream,
enable_cloudwatch,
)
def emit_invocation(self, metrics: InvocationMetrics) -> None:
"""
Emit all metrics from a single model invocation.
Non-blocking: returns immediately and processes asynchronously.
"""
now_ms = int(time.time() * 1000)
dimensions = {
"model_id": metrics.model_id,
"route_path": metrics.route_path,
"query_type": metrics.query_type,
"status": metrics.status,
}
points = [
MetricDataPoint("invocation_latency_ms", metrics.latency_ms, metrics.model_id, now_ms, dimensions, "Milliseconds"),
MetricDataPoint("input_token_count", float(metrics.input_tokens), metrics.model_id, now_ms, dimensions, "Count"),
MetricDataPoint("output_token_count", float(metrics.output_tokens), metrics.model_id, now_ms, dimensions, "Count"),
]
# Compute cost if not provided
cost = metrics.cost_usd
if cost is None:
model_costs = self.COSTS.get(metrics.model_id)
if model_costs:
cost = (
metrics.input_tokens * model_costs["input"]
+ metrics.output_tokens * model_costs["output"]
) / 1_000_000
if cost is not None:
points.append(MetricDataPoint("cost_usd", cost, metrics.model_id, now_ms, dimensions, "USD"))
if metrics.quality_score is not None:
points.append(MetricDataPoint("quality_score", metrics.quality_score, metrics.model_id, now_ms, dimensions, "None"))
if metrics.status == "error":
points.append(MetricDataPoint("error_count", 1.0, metrics.model_id, now_ms, dimensions, "Count"))
if metrics.status == "throttled":
points.append(MetricDataPoint("throttle_count", 1.0, metrics.model_id, now_ms, dimensions, "Count"))
# Fire and forget
self._executor.submit(self._process_points, points)
# Critical metrics go directly to Redis for real-time routing
if self.redis and metrics.status in ("error", "throttled"):
self._executor.submit(self._emit_critical_to_redis, metrics)
def _process_points(self, points: List[MetricDataPoint]) -> None:
"""Process metric data points (runs in background thread)."""
try:
# Send to Kinesis
records = [
{
"Data": json.dumps(p.to_dict()).encode("utf-8"),
"PartitionKey": p.model_id,
}
for p in points
]
if records:
self.kinesis.put_records(StreamName=self.stream_name, Records=records)
self._total_emitted += len(records)
# Update Redis real-time counters
if self.redis:
self._update_redis_counters(points)
except Exception as e:
self._total_errors += 1
logger.warning("Metric emission failed | error=%s | points=%d", e, len(points))
def _update_redis_counters(self, points: List[MetricDataPoint]) -> None:
"""Update Redis real-time counters and sorted sets."""
try:
pipe = self.redis.pipeline()
for p in points:
minute_key = f"metrics_min:{p.model_id}:{int(p.timestamp_ms / 60000)}"
if p.metric_name == "invocation_latency_ms":
# Sorted set for percentile calculation
pipe.zadd(
f"latency_samples:{p.model_id}",
{str(p.timestamp_ms): p.value},
)
pipe.expire(f"latency_samples:{p.model_id}", 3600)
# Increment per-minute counters
pipe.hincrby(minute_key, p.metric_name + "_count", 1)
pipe.hincrbyfloat(minute_key, p.metric_name + "_sum", p.value)
pipe.expire(minute_key, 3600)
pipe.execute()
except Exception as e:
logger.warning("Redis counter update failed | error=%s", e)
def _emit_critical_to_redis(self, metrics: InvocationMetrics) -> None:
"""Push critical events directly to Redis for immediate routing impact."""
try:
event = {
"model_id": metrics.model_id,
"status": metrics.status,
"error_type": metrics.error_type,
"timestamp": int(time.time() * 1000),
}
self.redis.publish("routing:critical_events", json.dumps(event))
# Increment error/throttle counters
error_key = f"errors:{metrics.model_id}:{int(time.time()) // 60}"
self.redis.incr(error_key)
self.redis.expire(error_key, 300)
except Exception as e:
logger.warning("Critical metric emission failed | error=%s", e)
def get_stats(self) -> Dict[str, int]:
"""Return emitter statistics for monitoring."""
return {
"total_emitted": self._total_emitted,
"total_errors": self._total_errors,
"buffer_size": len(self._buffer),
}
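The Lambda aggregator from the section 2.1 diagram is not shown elsewhere in this section. A minimal handler sketch, assuming the MetricDataPoint JSON layout emitted above (field names as defined in metric_emitter.py; the agg_min key prefix is an assumption):

"""
metric_aggregator.py — sketch of the 1-minute window Lambda (illustrative,
not the production handler). Folds MetricDataPoint records from Kinesis
into per-model, per-minute Redis hashes.
"""
import base64
import json
import os
import redis

r = redis.Redis(host=os.environ.get("REDIS_HOST", "localhost"))

def handler(event, context):
    pipe = r.pipeline()
    for record in event["Records"]:
        # Kinesis delivers base64-encoded payloads to Lambda
        point = json.loads(base64.b64decode(record["kinesis"]["data"]))
        minute = point["timestamp_ms"] // 60_000
        key = f"agg_min:{point['model_id']}:{minute}"  # assumed key prefix
        pipe.hincrbyfloat(key, f"{point['metric_name']}_sum", point["value"])
        pipe.hincrby(key, f"{point['metric_name']}_count", 1)
        pipe.expire(key, 3600)  # match the 1-hour TTLs in section 2.2
    pipe.execute()
    return {"records_processed": len(event["Records"])}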
3. Routing Table Management
3.1 Routing Table Lifecycle
stateDiagram-v2
[*] --> Proposed: New route submitted
Proposed --> Validated: Schema + policy check
Validated --> Staged: Written to staging table
Staged --> Testing: Canary traffic enabled
Testing --> Active: Metrics pass thresholds
Testing --> Rolled_Back: Metrics fail
Active --> Deprecated: Newer route promoted
Deprecated --> Archived: TTL expired
Rolled_Back --> Proposed: Fix and resubmit
Archived --> [*]
note right of Testing
Shadow route with 5% traffic
Monitor for 30 minutes
Compare against baseline
end note
3.2 Routing Table Manager
"""
routing_table_manager.py — MangaAssist Routing Table Lifecycle Manager
Manages the full lifecycle of routing configurations: creation, validation,
staged rollout, testing, activation, deprecation, and archival.
"""
import json
import time
import logging
import hashlib
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import boto3
from boto3.dynamodb.conditions import Key, Attr
logger = logging.getLogger(__name__)
class RouteStatus(Enum):
"""Lifecycle status of a routing configuration."""
PROPOSED = "proposed"
VALIDATED = "validated"
STAGED = "staged"
TESTING = "testing"
ACTIVE = "active"
DEPRECATED = "deprecated"
ROLLED_BACK = "rolled_back"
ARCHIVED = "archived"
@dataclass
class RouteEntry:
"""A single routing table entry."""
route_id: str
intent_category: str
sub_intent: str
model_id: str
fallback_model_id: str
max_tokens: int
temperature: float
status: RouteStatus
version: int
created_at: str
updated_at: str
created_by: str
config_hash: str = ""
test_traffic_pct: float = 0.0
test_start_time: Optional[str] = None
test_metrics: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
def compute_hash(self) -> str:
"""Compute a content hash for change detection."""
content = f"{self.intent_category}:{self.sub_intent}:{self.model_id}:{self.max_tokens}:{self.temperature}"
return hashlib.sha256(content.encode()).hexdigest()[:16]
@dataclass
class ValidationResult:
"""Result of route validation."""
valid: bool
errors: List[str]
warnings: List[str]
class RoutingTableManager:
"""
Routing table lifecycle manager for MangaAssist.
Manages the full lifecycle of routing configurations with safe
rollout procedures, validation, and automatic rollback support.
Features:
- Schema validation for new route entries
- Staged rollout with configurable canary traffic
- Automatic metric-based promotion or rollback
- Version history with audit trail
- Bulk import/export for infrastructure-as-code
Usage:
manager = RoutingTableManager(
config_table="MangaAssist-RoutingConfig",
history_table="MangaAssist-RouteHistory",
)
entry = RouteEntry(...)
result = manager.propose_route(entry)
if result.valid:
manager.stage_route(entry.route_id)
manager.start_test(entry.route_id, traffic_pct=5.0)
"""
VALID_MODELS = [
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
]
MAX_TOKENS_LIMITS = {
"anthropic.claude-3-haiku-20240307-v1:0": 4096,
"anthropic.claude-3-sonnet-20240229-v1:0": 4096,
}
def __init__(
self,
config_table: str = "MangaAssist-RoutingConfig",
history_table: str = "MangaAssist-RouteHistory",
region: str = "ap-northeast-1",
redis_client=None,
):
self.dynamodb = boto3.resource("dynamodb", region_name=region)
self.config_table = self.dynamodb.Table(config_table)
self.history_table = self.dynamodb.Table(history_table)
self.redis = redis_client
logger.info(
"RoutingTableManager initialized | config=%s | history=%s",
config_table,
history_table,
)
def propose_route(self, entry: RouteEntry) -> ValidationResult:
"""
Propose a new routing entry. Validates and stores as PROPOSED.
Args:
entry: The route entry to propose
Returns:
ValidationResult indicating if proposal was accepted
"""
# Validate
result = self._validate_entry(entry)
if not result.valid:
logger.warning(
"Route proposal rejected | id=%s | errors=%s",
entry.route_id,
result.errors,
)
return result
# Store as proposed
entry.status = RouteStatus.PROPOSED
entry.config_hash = entry.compute_hash()
entry.created_at = datetime.utcnow().isoformat()
entry.updated_at = entry.created_at
self._store_entry(entry)
self._record_history(entry, "proposed", f"New route proposed by {entry.created_by}")
logger.info("Route proposed | id=%s | intent=%s:%s", entry.route_id, entry.intent_category, entry.sub_intent)
return result
def validate_route(self, route_id: str) -> ValidationResult:
"""Transition a proposed route to validated status."""
entry = self._get_entry(route_id)
if not entry:
return ValidationResult(valid=False, errors=["Route not found"], warnings=[])
if entry.status != RouteStatus.PROPOSED:
return ValidationResult(valid=False, errors=[f"Route in {entry.status.value} state, expected proposed"], warnings=[])
result = self._validate_entry(entry)
if result.valid:
entry.status = RouteStatus.VALIDATED
entry.updated_at = datetime.utcnow().isoformat()
self._store_entry(entry)
self._record_history(entry, "validated", "Passed validation checks")
return result
def stage_route(self, route_id: str) -> bool:
"""Write a validated route to the staging configuration."""
entry = self._get_entry(route_id)
if not entry or entry.status != RouteStatus.VALIDATED:
logger.error("Cannot stage route | id=%s | status=%s", route_id, entry.status.value if entry else "not_found")
return False
entry.status = RouteStatus.STAGED
entry.updated_at = datetime.utcnow().isoformat()
self._store_entry(entry)
self._record_history(entry, "staged", "Route staged for testing")
logger.info("Route staged | id=%s", route_id)
return True
def start_test(
self,
route_id: str,
traffic_pct: float = 5.0,
test_duration_minutes: int = 30,
) -> bool:
"""
Begin canary testing for a staged route.
Routes a percentage of matching traffic through the new configuration
while monitoring key metrics against the existing active route.
Args:
route_id: The staged route to test
traffic_pct: Percentage of traffic for the canary (1-50%)
test_duration_minutes: How long to run the test
Returns:
True if test was started successfully
"""
entry = self._get_entry(route_id)
if not entry or entry.status != RouteStatus.STAGED:
logger.error("Cannot start test | id=%s", route_id)
return False
if traffic_pct < 1 or traffic_pct > 50:
logger.error("Traffic percentage must be between 1-50%%, got %.1f%%", traffic_pct)
return False
entry.status = RouteStatus.TESTING
entry.test_traffic_pct = traffic_pct
entry.test_start_time = datetime.utcnow().isoformat()
entry.updated_at = datetime.utcnow().isoformat()
entry.metadata["test_duration_minutes"] = test_duration_minutes
entry.metadata["test_end_time"] = (
datetime.utcnow() + timedelta(minutes=test_duration_minutes)
).isoformat()
self._store_entry(entry)
self._record_history(
entry,
"testing_started",
f"Canary test started: {traffic_pct}% traffic for {test_duration_minutes}min",
)
# Publish test start event to Redis for real-time routing awareness
if self.redis:
self.redis.publish("routing:test_events", json.dumps({
"event": "test_started",
"route_id": route_id,
"intent": entry.intent_category,
"sub_intent": entry.sub_intent,
"traffic_pct": traffic_pct,
"model_id": entry.model_id,
}))
logger.info(
"Route test started | id=%s | traffic=%.1f%% | duration=%dmin",
route_id,
traffic_pct,
test_duration_minutes,
)
return True
def evaluate_test(self, route_id: str) -> Dict[str, Any]:
"""
Evaluate test results and determine whether to promote or rollback.
Criteria:
- P95 latency within 20% of baseline
- Error rate below 2%
- Quality score within 10% of baseline
        - Cost per query within expected bounds (not yet checked below)
Returns:
Evaluation result with recommendation
"""
entry = self._get_entry(route_id)
if not entry or entry.status != RouteStatus.TESTING:
return {"status": "error", "reason": "Route not in testing state"}
# Collect test metrics from Redis
test_metrics = self._collect_test_metrics(entry)
baseline_metrics = self._collect_baseline_metrics(entry)
evaluation = {
"route_id": route_id,
"test_duration_minutes": entry.metadata.get("test_duration_minutes", 0),
"test_queries": test_metrics.get("query_count", 0),
"checks": [],
}
passed = True
# P95 latency check
if baseline_metrics.get("p95_latency_ms", 0) > 0:
latency_ratio = test_metrics.get("p95_latency_ms", 0) / baseline_metrics["p95_latency_ms"]
check = {
"name": "p95_latency",
"test_value": test_metrics.get("p95_latency_ms", 0),
"baseline_value": baseline_metrics["p95_latency_ms"],
"threshold": "within 20%",
"passed": latency_ratio <= 1.2,
}
evaluation["checks"].append(check)
if not check["passed"]:
passed = False
# Error rate check
error_rate = test_metrics.get("error_rate_pct", 0)
check = {
"name": "error_rate",
"test_value": error_rate,
"threshold": "< 2%",
"passed": error_rate < 2.0,
}
evaluation["checks"].append(check)
if not check["passed"]:
passed = False
# Quality score check
if baseline_metrics.get("quality_score", 0) > 0:
quality_ratio = test_metrics.get("quality_score", 0) / baseline_metrics["quality_score"]
check = {
"name": "quality_score",
"test_value": test_metrics.get("quality_score", 0),
"baseline_value": baseline_metrics["quality_score"],
"threshold": "within 10%",
"passed": quality_ratio >= 0.9,
}
evaluation["checks"].append(check)
if not check["passed"]:
passed = False
evaluation["passed"] = passed
evaluation["recommendation"] = "promote" if passed else "rollback"
entry.test_metrics = evaluation
entry.updated_at = datetime.utcnow().isoformat()
self._store_entry(entry)
logger.info(
"Test evaluated | id=%s | passed=%s | recommendation=%s",
route_id,
passed,
evaluation["recommendation"],
)
return evaluation
def promote_route(self, route_id: str) -> bool:
"""Promote a tested route to active status, deprecating the current active route."""
entry = self._get_entry(route_id)
if not entry or entry.status != RouteStatus.TESTING:
return False
# Deprecate current active route for this intent
self._deprecate_active_route(entry.intent_category, entry.sub_intent)
entry.status = RouteStatus.ACTIVE
entry.test_traffic_pct = 0.0
entry.updated_at = datetime.utcnow().isoformat()
self._store_entry(entry)
self._record_history(entry, "promoted", "Route promoted to active after successful test")
# Invalidate caches
self._invalidate_caches(entry)
logger.info("Route promoted | id=%s | intent=%s:%s", route_id, entry.intent_category, entry.sub_intent)
return True
def rollback_route(self, route_id: str, reason: str = "Test failed") -> bool:
"""Rollback a testing route."""
entry = self._get_entry(route_id)
if not entry:
return False
entry.status = RouteStatus.ROLLED_BACK
entry.test_traffic_pct = 0.0
entry.updated_at = datetime.utcnow().isoformat()
self._store_entry(entry)
self._record_history(entry, "rolled_back", reason)
logger.info("Route rolled back | id=%s | reason=%s", route_id, reason)
return True
def get_active_routes(self) -> List[RouteEntry]:
"""Retrieve all currently active routes."""
try:
response = self.config_table.scan(
FilterExpression=Attr("status").eq("active"),
)
entries = []
for item in response.get("Items", []):
entries.append(self._item_to_entry(item))
return entries
except Exception as e:
logger.error("Failed to get active routes | error=%s", e)
return []
def export_config(self) -> Dict[str, Any]:
"""Export all active routes as JSON for infrastructure-as-code."""
routes = self.get_active_routes()
return {
"version": "1.0",
"exported_at": datetime.utcnow().isoformat(),
"routes": [
{
"intent_category": r.intent_category,
"sub_intent": r.sub_intent,
"model_id": r.model_id,
"fallback_model_id": r.fallback_model_id,
"max_tokens": r.max_tokens,
"temperature": r.temperature,
}
for r in routes
],
}
# --- Private helpers ---
def _validate_entry(self, entry: RouteEntry) -> ValidationResult:
"""Validate a route entry against schema and policy rules."""
errors: List[str] = []
warnings: List[str] = []
if not entry.intent_category:
errors.append("intent_category is required")
if entry.model_id not in self.VALID_MODELS:
errors.append(f"Invalid model_id: {entry.model_id}")
if entry.fallback_model_id not in self.VALID_MODELS:
errors.append(f"Invalid fallback_model_id: {entry.fallback_model_id}")
if entry.model_id == entry.fallback_model_id:
warnings.append("model_id and fallback_model_id are the same — no failover diversity")
max_limit = self.MAX_TOKENS_LIMITS.get(entry.model_id, 4096)
if entry.max_tokens > max_limit:
errors.append(f"max_tokens {entry.max_tokens} exceeds limit {max_limit} for {entry.model_id}")
if entry.temperature < 0 or entry.temperature > 1:
errors.append(f"temperature must be 0-1, got {entry.temperature}")
return ValidationResult(valid=len(errors) == 0, errors=errors, warnings=warnings)
def _store_entry(self, entry: RouteEntry) -> None:
"""Store a route entry in DynamoDB."""
self.config_table.put_item(Item={
"PK": f"ROUTE#{entry.intent_category}",
"SK": f"VERSION#{entry.version}#{entry.sub_intent}",
"route_id": entry.route_id,
"intent_category": entry.intent_category,
"sub_intent": entry.sub_intent,
"model_id": entry.model_id,
"fallback_model_id": entry.fallback_model_id,
"max_tokens": entry.max_tokens,
"temperature": str(entry.temperature),
"status": entry.status.value,
"version": entry.version,
"config_hash": entry.config_hash,
"test_traffic_pct": str(entry.test_traffic_pct),
"created_at": entry.created_at,
"updated_at": entry.updated_at,
"created_by": entry.created_by,
"metadata": json.dumps(entry.metadata),
"test_metrics": json.dumps(entry.test_metrics),
})
    def _get_entry(self, route_id: str) -> Optional[RouteEntry]:
        """Retrieve a route entry by ID."""
        try:
            # Limit on a scan caps items *examined*, not items matched, so
            # combining Limit=1 with a FilterExpression can miss the route.
            # A GSI keyed on route_id would avoid the full scan entirely.
            response = self.config_table.scan(
                FilterExpression=Attr("route_id").eq(route_id),
            )
            items = response.get("Items", [])
            if items:
                return self._item_to_entry(items[0])
        except Exception as e:
            logger.error("Failed to get route | id=%s | error=%s", route_id, e)
        return None
def _item_to_entry(self, item: Dict[str, Any]) -> RouteEntry:
"""Convert a DynamoDB item to a RouteEntry."""
return RouteEntry(
route_id=item.get("route_id", ""),
intent_category=item.get("intent_category", ""),
sub_intent=item.get("sub_intent", "default"),
model_id=item.get("model_id", ""),
fallback_model_id=item.get("fallback_model_id", ""),
max_tokens=int(item.get("max_tokens", 1024)),
temperature=float(item.get("temperature", 0.3)),
status=RouteStatus(item.get("status", "proposed")),
version=int(item.get("version", 1)),
created_at=item.get("created_at", ""),
updated_at=item.get("updated_at", ""),
created_by=item.get("created_by", "system"),
config_hash=item.get("config_hash", ""),
test_traffic_pct=float(item.get("test_traffic_pct", 0)),
metadata=json.loads(item.get("metadata", "{}")),
test_metrics=json.loads(item.get("test_metrics", "{}")),
)
def _record_history(self, entry: RouteEntry, event: str, detail: str) -> None:
"""Record a history event for audit trail."""
try:
self.history_table.put_item(Item={
"PK": f"HISTORY#{entry.route_id}",
"SK": f"EVENT#{datetime.utcnow().isoformat()}",
"route_id": entry.route_id,
"event": event,
"detail": detail,
"status": entry.status.value,
"model_id": entry.model_id,
"timestamp": datetime.utcnow().isoformat(),
"ttl": int((datetime.utcnow() + timedelta(days=90)).timestamp()),
})
except Exception as e:
logger.warning("History recording failed | id=%s | error=%s", entry.route_id, e)
def _deprecate_active_route(self, intent: str, sub_intent: str) -> None:
"""Deprecate the current active route for an intent."""
try:
response = self.config_table.scan(
FilterExpression=(
Attr("intent_category").eq(intent)
& Attr("sub_intent").eq(sub_intent)
& Attr("status").eq("active")
),
)
for item in response.get("Items", []):
self.config_table.update_item(
Key={"PK": item["PK"], "SK": item["SK"]},
UpdateExpression="SET #s = :s, updated_at = :t",
ExpressionAttributeNames={"#s": "status"},
ExpressionAttributeValues={
":s": "deprecated",
":t": datetime.utcnow().isoformat(),
},
)
except Exception as e:
logger.error("Failed to deprecate active route | error=%s", e)
def _invalidate_caches(self, entry: RouteEntry) -> None:
"""Invalidate route caches after promotion."""
if self.redis:
try:
cache_key = f"route:{entry.intent_category}:{entry.sub_intent}"
self.redis.delete(cache_key)
self.redis.publish("routing:cache_invalidation", json.dumps({
"intent": entry.intent_category,
"sub_intent": entry.sub_intent,
}))
except Exception as e:
logger.warning("Cache invalidation failed | error=%s", e)
def _collect_test_metrics(self, entry: RouteEntry) -> Dict[str, Any]:
"""Collect metrics for the test route from Redis."""
# Placeholder — in production, aggregate from Redis sorted sets
return {
"query_count": 0,
"p95_latency_ms": 0,
"error_rate_pct": 0,
"quality_score": 0,
}
def _collect_baseline_metrics(self, entry: RouteEntry) -> Dict[str, Any]:
"""Collect baseline metrics from the active route."""
return {
"p95_latency_ms": 0,
"quality_score": 0,
}
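A usage sketch walking one route through the section 3.1 lifecycle (route values are illustrative):

# Illustrative lifecycle walkthrough: propose -> validate -> stage -> test.
manager = RoutingTableManager()
entry = RouteEntry(
    route_id="route_jp_recs_v2",
    intent_category="recommendation",
    sub_intent="genre_based",
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    fallback_model_id="anthropic.claude-3-haiku-20240307-v1:0",
    max_tokens=1024,
    temperature=0.5,
    status=RouteStatus.PROPOSED,
    version=2,
    created_at="",  # set server-side by propose_route
    updated_at="",
    created_by="ops@mangaassist",
)
result = manager.propose_route(entry)
if result.valid:
    manager.validate_route(entry.route_id)
    manager.stage_route(entry.route_id)
    manager.start_test(entry.route_id, traffic_pct=5.0, test_duration_minutes=30)
    evaluation = manager.evaluate_test(entry.route_id)
    if evaluation.get("recommendation") == "promote":
        manager.promote_route(entry.route_id)
    else:
        manager.rollback_route(entry.route_id, reason="Canary metrics failed")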
4. A/B Routing for Model Comparison
A/B routing sends a controlled split of traffic to different models for the same intent, allowing direct comparison of model performance on identical query distributions.
4.1 A/B Routing Flow
flowchart TB
QUERY([Incoming Query]) --> HASH[Hash Session ID]
HASH --> BUCKET{Bucket Assignment}
BUCKET -->|0-799| CONTROL[Control Group: Haiku<br/>80% traffic]
BUCKET -->|800-949| TREATMENT_A[Treatment A: Sonnet<br/>15% traffic]
BUCKET -->|950-999| TREATMENT_B[Treatment B: Sonnet + Prompt V2<br/>5% traffic]
CONTROL --> INVOKE_C[Invoke Haiku]
TREATMENT_A --> INVOKE_A[Invoke Sonnet]
TREATMENT_B --> INVOKE_B[Invoke Sonnet + New Prompt]
INVOKE_C --> RESPOND([Response to User])
INVOKE_A --> RESPOND
INVOKE_B --> RESPOND
INVOKE_C --> LOG_C[Log: group=control]
INVOKE_A --> LOG_A[Log: group=treatment_a]
INVOKE_B --> LOG_B[Log: group=treatment_b]
LOG_C --> ANALYSIS[Statistical Analysis]
LOG_A --> ANALYSIS
LOG_B --> ANALYSIS
style CONTROL fill:#4CAF50,color:#fff
style TREATMENT_A fill:#2196F3,color:#fff
style TREATMENT_B fill:#FF9800,color:#fff
4.2 A/B Router Implementation
"""
ab_router.py — MangaAssist A/B Model Routing
Deterministic session-based traffic splitting for controlled model comparison.
Uses consistent hashing to ensure the same user always sees the same variant.
"""
import hashlib
import logging
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
@dataclass
class ABVariant:
"""A single variant in an A/B experiment."""
name: str
model_id: str
traffic_pct: float # 0-100
bucket_start: int = 0 # Computed: inclusive
bucket_end: int = 0 # Computed: exclusive
prompt_template: Optional[str] = None
max_tokens: int = 1024
temperature: float = 0.3
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ABExperiment:
"""An active A/B experiment configuration."""
experiment_id: str
name: str
description: str
intent_filter: str # Which intents this experiment applies to ("*" for all)
variants: List[ABVariant]
start_time: str
end_time: Optional[str] = None
enabled: bool = True
min_sample_size: int = 1000
confidence_level: float = 0.95
@dataclass
class ABAssignment:
"""Result of assigning a session to an A/B variant."""
experiment_id: str
variant_name: str
model_id: str
bucket: int
max_tokens: int
temperature: float
prompt_template: Optional[str]
metadata: Dict[str, Any] = field(default_factory=dict)
class ABRouter:
"""
A/B model routing for MangaAssist experiments.
Deterministically assigns sessions to experiment variants using
consistent hashing. Ensures:
- Same session always gets same variant (no flickering)
- Traffic splits are precise at scale
- Multiple concurrent experiments supported
- Graceful degradation if experiment config is invalid
Usage:
router = ABRouter()
router.add_experiment(ABExperiment(
experiment_id="exp_001",
name="Sonnet Quality Test",
variants=[
ABVariant(name="control", model_id="haiku", traffic_pct=80),
ABVariant(name="treatment", model_id="sonnet", traffic_pct=20),
],
...
))
assignment = router.assign("session_123", intent="product_detail")
# assignment.variant_name → "control" or "treatment"
"""
BUCKET_COUNT = 1000 # 0.1% granularity
def __init__(self, redis_client=None):
self.experiments: Dict[str, ABExperiment] = {}
self.redis = redis_client
self._assignment_cache: Dict[str, ABAssignment] = {}
logger.info("ABRouter initialized | buckets=%d", self.BUCKET_COUNT)
def add_experiment(self, experiment: ABExperiment) -> bool:
"""
Register an A/B experiment.
Validates traffic splits sum to 100% and computes bucket ranges.
"""
# Validate traffic splits
total_traffic = sum(v.traffic_pct for v in experiment.variants)
if abs(total_traffic - 100.0) > 0.1:
logger.error(
"Invalid traffic split | experiment=%s | total=%.1f%%",
experiment.experiment_id,
total_traffic,
)
return False
# Compute bucket ranges
current_bucket = 0
for variant in experiment.variants:
bucket_count = int(variant.traffic_pct / 100.0 * self.BUCKET_COUNT)
variant.bucket_start = current_bucket
variant.bucket_end = current_bucket + bucket_count
current_bucket = variant.bucket_end
# Handle rounding — last variant gets remaining buckets
if experiment.variants:
experiment.variants[-1].bucket_end = self.BUCKET_COUNT
self.experiments[experiment.experiment_id] = experiment
logger.info(
"Experiment registered | id=%s | variants=%d | intents=%s",
experiment.experiment_id,
len(experiment.variants),
experiment.intent_filter,
)
return True
def assign(
self,
session_id: str,
intent: str = "*",
experiment_id: Optional[str] = None,
) -> Optional[ABAssignment]:
"""
Assign a session to an A/B variant.
Uses consistent hashing to deterministically map sessions
to buckets, ensuring stable assignments across requests.
Args:
session_id: Unique session identifier
intent: Current query intent for experiment filtering
experiment_id: Specific experiment (or None for auto-match)
Returns:
ABAssignment if an experiment matches, None otherwise
"""
# Check cache for existing assignment
cache_key = f"{session_id}:{intent}:{experiment_id or 'auto'}"
if cache_key in self._assignment_cache:
return self._assignment_cache[cache_key]
# Find applicable experiment
experiment = self._find_experiment(intent, experiment_id)
if not experiment:
return None
# Compute deterministic bucket
bucket = self._compute_bucket(session_id, experiment.experiment_id)
# Find variant for bucket
variant = self._find_variant(experiment, bucket)
if not variant:
logger.warning("No variant found for bucket %d | experiment=%s", bucket, experiment.experiment_id)
return None
assignment = ABAssignment(
experiment_id=experiment.experiment_id,
variant_name=variant.name,
model_id=variant.model_id,
bucket=bucket,
max_tokens=variant.max_tokens,
temperature=variant.temperature,
prompt_template=variant.prompt_template,
metadata={
"experiment_name": experiment.name,
"traffic_pct": variant.traffic_pct,
},
)
self._assignment_cache[cache_key] = assignment
# Record assignment in Redis for analytics
if self.redis:
try:
self.redis.hincrby(
f"ab:{experiment.experiment_id}:counts",
variant.name,
1,
)
except Exception:
pass
logger.debug(
"A/B assignment | session=%s | experiment=%s | variant=%s | bucket=%d",
session_id[:8],
experiment.experiment_id,
variant.name,
bucket,
)
return assignment
def record_outcome(
self,
experiment_id: str,
variant_name: str,
session_id: str,
metrics: Dict[str, float],
) -> None:
"""
Record outcome metrics for a variant assignment.
Args:
experiment_id: The experiment ID
variant_name: The variant name
session_id: The session ID
metrics: Outcome metrics (latency_ms, quality_score, user_satisfied)
"""
if not self.redis:
return
try:
key_prefix = f"ab:{experiment_id}:{variant_name}"
pipe = self.redis.pipeline()
for metric_name, value in metrics.items():
pipe.rpush(f"{key_prefix}:{metric_name}", str(value))
pipe.expire(f"{key_prefix}:{metric_name}", 86400 * 7) # 7 days
pipe.execute()
except Exception as e:
logger.warning("Failed to record A/B outcome | error=%s", e)
def _compute_bucket(self, session_id: str, experiment_id: str) -> int:
"""Compute a deterministic bucket from session ID and experiment."""
hash_input = f"{session_id}:{experiment_id}"
hash_bytes = hashlib.sha256(hash_input.encode()).digest()
hash_int = int.from_bytes(hash_bytes[:4], byteorder="big")
return hash_int % self.BUCKET_COUNT
def _find_experiment(
self, intent: str, experiment_id: Optional[str]
) -> Optional[ABExperiment]:
"""Find an applicable experiment for the given intent."""
if experiment_id:
exp = self.experiments.get(experiment_id)
return exp if exp and exp.enabled else None
for exp in self.experiments.values():
if not exp.enabled:
continue
if exp.intent_filter == "*" or intent in exp.intent_filter.split(","):
return exp
return None
def _find_variant(
self, experiment: ABExperiment, bucket: int
) -> Optional[ABVariant]:
"""Find the variant that owns the given bucket."""
for variant in experiment.variants:
if variant.bucket_start <= bucket < variant.bucket_end:
return variant
return None
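A usage sketch mirroring the 80/15/5 split from section 4.1 (experiment values are illustrative; prompt_v2 is a placeholder template name):

# Illustrative experiment setup and per-request assignment.
router = ABRouter()
router.add_experiment(ABExperiment(
    experiment_id="exp_prompt_v2",
    name="Sonnet prompt V2 trial",
    description="Haiku baseline vs Sonnet vs Sonnet + Prompt V2",
    intent_filter="product_detail,recommendation",
    variants=[
        ABVariant(name="control", model_id="anthropic.claude-3-haiku-20240307-v1:0", traffic_pct=80),
        ABVariant(name="treatment_a", model_id="anthropic.claude-3-sonnet-20240229-v1:0", traffic_pct=15),
        ABVariant(name="treatment_b", model_id="anthropic.claude-3-sonnet-20240229-v1:0", traffic_pct=5,
                  prompt_template="prompt_v2"),
    ],
    start_time="2024-06-01T00:00:00Z",
))
assignment = router.assign("session_123", intent="product_detail")
if assignment:
    # ... invoke assignment.model_id, then record what happened
    router.record_outcome(assignment.experiment_id, assignment.variant_name,
                          "session_123", {"latency_ms": 480.0, "quality_score": 0.82})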
5. Shadow Routing for Safe Testing
Shadow routing duplicates a query to a secondary model without affecting the user response. It allows testing new models or configurations against production traffic with zero user impact.
5.1 Shadow Routing Architecture
flowchart TB
QUERY([Incoming Query]) --> PRIMARY[Primary Route<br/>Haiku — serves user]
QUERY --> SHADOW[Shadow Route<br/>Sonnet — discarded]
PRIMARY --> RESPONSE([User Response<br/>From Primary Only])
PRIMARY --> LOG_P[Log Primary Metrics]
SHADOW --> LOG_S[Log Shadow Metrics]
LOG_P --> COMPARE[Metric Comparison<br/>Lambda Aggregator]
LOG_S --> COMPARE
COMPARE --> REPORT[Shadow Report<br/>Quality Delta + Cost Impact]
style PRIMARY fill:#4CAF50,color:#fff
style SHADOW fill:#9E9E9E,color:#fff
style RESPONSE fill:#4CAF50,color:#fff
5.2 Shadow Router Implementation
"""
shadow_router.py — MangaAssist Shadow Model Routing
Sends queries to a secondary (shadow) model in parallel with the primary
model. The shadow response is discarded — only metrics are collected.
Enables safe, zero-impact production testing of new models.
"""
import asyncio
import time
import json
import logging
from typing import Dict, Any, Optional, List, Tuple
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor
import boto3
logger = logging.getLogger(__name__)
@dataclass
class ShadowConfig:
"""Configuration for shadow routing."""
shadow_model_id: str
shadow_traffic_pct: float = 10.0 # % of traffic to shadow
shadow_max_tokens: int = 1024
shadow_temperature: float = 0.3
shadow_timeout_ms: int = 5000 # Timeout for shadow invocation
collect_response: bool = True # Whether to store shadow responses
compare_quality: bool = True # Whether to run quality comparison
enabled: bool = True
@dataclass
class ShadowResult:
"""Result of a shadow routing invocation."""
primary_model_id: str
shadow_model_id: str
primary_latency_ms: float
shadow_latency_ms: float
primary_tokens: Dict[str, int] # {"input": N, "output": N}
shadow_tokens: Dict[str, int]
primary_cost_usd: float
shadow_cost_usd: float
quality_comparison: Optional[Dict[str, Any]] = None
shadow_error: Optional[str] = None
class ShadowRouter:
"""
Shadow model routing for MangaAssist.
Duplicates queries to a shadow model alongside the primary model.
The primary response is served to the user immediately; the shadow
response is used only for metric collection and comparison.
Key Properties:
- Zero user impact: shadow failures never affect responses
- Async execution: shadow invocation doesn't add latency to primary
- Configurable sampling: only shadow a percentage of traffic
- Quality comparison: automated scoring of primary vs shadow responses
- Cost tracking: full cost accounting including shadow overhead
    WARNING: Shadow routing costs real money. At a 10% shadow rate with Sonnet
    as the shadow model on 1M queries/day: ~$1,140/day additional cost (section 6).
Usage:
shadow = ShadowRouter(
config=ShadowConfig(
shadow_model_id="anthropic.claude-3-sonnet-20240229-v1:0",
shadow_traffic_pct=10.0,
),
)
# In the main routing path
primary_response = invoke_primary(query)
        shadow.invoke_shadow_async(query, primary_model_id, primary_response,
                                   primary_latency_ms, primary_tokens)
"""
COSTS = {
"anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.25, "output": 1.25},
"anthropic.claude-3-sonnet-20240229-v1:0": {"input": 3.00, "output": 15.00},
}
def __init__(
self,
config: ShadowConfig,
bedrock_client=None,
redis_client=None,
region: str = "ap-northeast-1",
):
self.config = config
self.bedrock = bedrock_client or boto3.client(
"bedrock-runtime", region_name=region
)
self.redis = redis_client
self._executor = ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="shadow-route",
)
# Sampling state
self._sample_counter = 0
        # Guard the division: a 0% rate would otherwise raise ZeroDivisionError
        self._sample_interval = max(1, int(100 / max(config.shadow_traffic_pct, 0.01)))
# Stats
self._total_shadowed = 0
self._total_shadow_errors = 0
self._total_shadow_cost_usd = 0.0
logger.info(
"ShadowRouter initialized | shadow_model=%s | traffic=%.1f%% | interval=%d",
config.shadow_model_id,
config.shadow_traffic_pct,
self._sample_interval,
)
def should_shadow(self) -> bool:
"""Determine if this request should be shadowed (sampling)."""
if not self.config.enabled:
return False
self._sample_counter += 1
return self._sample_counter % self._sample_interval == 0
def invoke_shadow_async(
self,
query: str,
primary_model_id: str,
primary_response: str,
primary_latency_ms: float,
primary_tokens: Dict[str, int],
conversation_history: Optional[List[Dict[str, str]]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""
Fire-and-forget shadow invocation.
This method returns immediately. The shadow invocation runs
in a background thread. Results are pushed to Redis for
offline analysis.
"""
if not self.should_shadow():
return
self._executor.submit(
self._execute_shadow,
query,
primary_model_id,
primary_response,
primary_latency_ms,
primary_tokens,
conversation_history or [],
metadata or {},
)
def _execute_shadow(
self,
query: str,
primary_model_id: str,
primary_response: str,
primary_latency_ms: float,
primary_tokens: Dict[str, int],
history: List[Dict[str, str]],
metadata: Dict[str, Any],
) -> None:
"""Execute shadow invocation and collect metrics (runs in background)."""
start = time.monotonic()
shadow_response = None
shadow_tokens = {"input": 0, "output": 0}
shadow_error = None
try:
# Build messages
messages = []
for msg in history[-5:]: # Last 5 turns for context
messages.append({
"role": msg.get("role", "user"),
"content": msg.get("content", ""),
})
messages.append({"role": "user", "content": query})
# Invoke shadow model
body = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": self.config.shadow_max_tokens,
"messages": messages,
"temperature": self.config.shadow_temperature,
})
response = self.bedrock.invoke_model(
modelId=self.config.shadow_model_id,
body=body,
contentType="application/json",
accept="application/json",
)
result = json.loads(response["body"].read())
shadow_response = result.get("content", [{}])[0].get("text", "")
usage = result.get("usage", {})
shadow_tokens = {
"input": usage.get("input_tokens", 0),
"output": usage.get("output_tokens", 0),
}
except Exception as e:
shadow_error = str(e)
self._total_shadow_errors += 1
logger.debug("Shadow invocation failed | error=%s", e)
shadow_latency = (time.monotonic() - start) * 1000
# Calculate costs
primary_cost = self._calculate_cost(primary_model_id, primary_tokens)
shadow_cost = self._calculate_cost(self.config.shadow_model_id, shadow_tokens)
self._total_shadow_cost_usd += shadow_cost
self._total_shadowed += 1
# Build result
result = ShadowResult(
primary_model_id=primary_model_id,
shadow_model_id=self.config.shadow_model_id,
primary_latency_ms=primary_latency_ms,
shadow_latency_ms=shadow_latency,
primary_tokens=primary_tokens,
shadow_tokens=shadow_tokens,
primary_cost_usd=primary_cost,
shadow_cost_usd=shadow_cost,
shadow_error=shadow_error,
)
# Quality comparison if both responses available
if shadow_response and self.config.compare_quality:
result.quality_comparison = self._compare_quality(
query, primary_response, shadow_response
)
# Store results
self._store_shadow_result(result, metadata)
logger.debug(
"Shadow complete | primary=%.0fms | shadow=%.0fms | cost_delta=$%.6f",
primary_latency_ms,
shadow_latency,
shadow_cost - primary_cost,
)
def _compare_quality(
self,
query: str,
primary_response: str,
shadow_response: str,
) -> Dict[str, Any]:
"""
Compare quality between primary and shadow responses.
Simple heuristic comparison (no LLM-as-judge to avoid cost spiral):
- Response length ratio
- Keyword overlap
- Entity coverage
"""
primary_words = set(primary_response.lower().split())
shadow_words = set(shadow_response.lower().split())
overlap = len(primary_words & shadow_words)
total = max(len(primary_words | shadow_words), 1)
similarity = overlap / total
return {
"response_length_primary": len(primary_response),
"response_length_shadow": len(shadow_response),
"length_ratio": len(shadow_response) / max(len(primary_response), 1),
"word_overlap_pct": round(similarity * 100, 1),
"shadow_longer": len(shadow_response) > len(primary_response),
}
def _calculate_cost(self, model_id: str, tokens: Dict[str, int]) -> float:
"""Calculate cost in USD for a model invocation."""
costs = self.COSTS.get(model_id, {"input": 0, "output": 0})
return (
tokens.get("input", 0) * costs["input"]
+ tokens.get("output", 0) * costs["output"]
) / 1_000_000
def _store_shadow_result(
self, result: ShadowResult, metadata: Dict[str, Any]
) -> None:
"""Store shadow comparison result in Redis for analysis."""
if not self.redis:
return
try:
record = {
"primary_model": result.primary_model_id,
"shadow_model": result.shadow_model_id,
"primary_latency_ms": result.primary_latency_ms,
"shadow_latency_ms": result.shadow_latency_ms,
"primary_cost_usd": result.primary_cost_usd,
"shadow_cost_usd": result.shadow_cost_usd,
"shadow_error": result.shadow_error,
"quality": result.quality_comparison,
"timestamp": int(time.time()),
}
key = f"shadow:results:{int(time.time())}"
self.redis.setex(key, 86400, json.dumps(record))
# Increment aggregate counters
self.redis.incr("shadow:total_count")
if result.shadow_error:
self.redis.incr("shadow:error_count")
self.redis.incrbyfloat("shadow:total_shadow_cost", result.shadow_cost_usd)
except Exception as e:
logger.warning("Failed to store shadow result | error=%s", e)
def get_shadow_stats(self) -> Dict[str, Any]:
"""Return shadow routing statistics."""
return {
"total_shadowed": self._total_shadowed,
"total_errors": self._total_shadow_errors,
"total_shadow_cost_usd": round(self._total_shadow_cost_usd, 4),
"error_rate_pct": (
(self._total_shadow_errors / max(self._total_shadowed, 1)) * 100
),
"config": {
"shadow_model": self.config.shadow_model_id,
"traffic_pct": self.config.shadow_traffic_pct,
"enabled": self.config.enabled,
},
}
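A sketch of the orchestrator hot path with shadow wiring. invoke_primary() is a placeholder for the production Bedrock call and is assumed to return (response_text, {"input": N, "output": N}); only the shadow integration is the point:

# Hot-path integration sketch (invoke_primary is a placeholder).
import time

shadow = ShadowRouter(config=ShadowConfig(
    shadow_model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    shadow_traffic_pct=10.0,
))

def handle_query(query: str, history: list) -> str:
    start = time.monotonic()
    response_text, tokens = invoke_primary(query, history)  # placeholder
    latency_ms = (time.monotonic() - start) * 1000
    # Fire-and-forget: returns immediately; sampling happens inside
    shadow.invoke_shadow_async(
        query=query,
        primary_model_id="anthropic.claude-3-haiku-20240307-v1:0",
        primary_response=response_text,
        primary_latency_ms=latency_ms,
        primary_tokens=tokens,
        conversation_history=history,
    )
    return response_text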
6. Shadow Routing Cost Impact Analysis
At MangaAssist scale (1M messages/day), shadow routing cost must be carefully managed:
| Shadow Config | Shadow Rate | Shadow Model | Shadow Queries/Day | Additional Daily Cost | Monthly Overhead |
|---|---|---|---|---|---|
| Conservative | 5% | Haiku | 50,000 | $31.25 | $937 |
| Standard | 10% | Haiku | 100,000 | $62.50 | $1,875 |
| Standard | 10% | Sonnet | 100,000 | $1,140.00 | $34,200 |
| Aggressive | 25% | Sonnet | 250,000 | $2,850.00 | $85,500 |
| Full Mirror | 100% | Sonnet | 1,000,000 | $11,400.00 | $342,000 |
Recommendation: shadow at 10% with Haiku as the shadow model when testing Haiku configuration changes; use Sonnet as the shadow model only at 5%, during time-boxed quality comparison studies (48 hours maximum).
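To reproduce the Standard/Sonnet row: the table does not state its per-query token mix, but roughly 800 input / 600 output tokens per shadow query matches the published figures (the Haiku rows imply a shorter ~500/400 mix):

# Reproducing the "Standard / Sonnet" row. The ~800 input / 600 output
# token mix is an assumption that matches the table's figures.
sonnet = {"input": 3.00, "output": 15.00}            # USD per 1M tokens
per_query = (800 * sonnet["input"] + 600 * sonnet["output"]) / 1_000_000
daily = 100_000 * per_query                          # 10% of 1M messages/day
print(f"${per_query:.4f}/query -> ${daily:,.0f}/day -> ${daily * 30:,.0f}/month")
# $0.0114/query -> $1,140/day -> $34,200/month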
7. Comparison: A/B vs Shadow Routing
| Dimension | A/B Routing | Shadow Routing |
|---|---|---|
| User impact | Users see the variant response | Zero — shadow response is discarded |
| Metrics quality | Real user satisfaction data | Synthetic comparison only |
| Cost | Same as production (variant replaces primary) | Additional cost (both models invoked) |
| Latency impact | None (user gets variant directly) | None (shadow is async fire-and-forget) |
| Risk | Medium — bad variant affects real users | Zero — shadow failures are invisible |
| Use case | Comparing final model performance | Pre-launch validation of new models |
| Duration | Days to weeks (need statistical significance) | Hours to days (no user impact) |
| Sample size | Limited by traffic split | Can shadow 100% if budget allows |
| Rollback | Change traffic split | Disable shadow config |
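The Duration row's statistical-significance requirement can be checked with a simple two-proportion z-test over the recorded A/B outcomes. A standard-library sketch (the counts and the thumbs-up metric are illustrative; in practice the inputs come from the ab:* Redis keys):

# Minimal two-proportion z-test for A/B outcome comparison (stdlib only).
import math

def two_proportion_z(success_a: int, n_a: int, success_b: int, n_b: int) -> float:
    """Return the z statistic for H0: p_a == p_b."""
    p_a, p_b = success_a / n_a, success_b / n_b
    pooled = (success_a + success_b) / (n_a + n_b)
    se = math.sqrt(pooled * (1 - pooled) * (1 / n_a + 1 / n_b))
    return (p_b - p_a) / se

# Example: 1,200 thumbs-up of 10,000 control vs 340 of 2,500 treatment
z = two_proportion_z(1200, 10_000, 340, 2_500)
print(f"z = {z:.2f}  (|z| > 1.96 -> significant at the 95% level)")
# z = 2.18 -> the treatment's lift is significant at the 95% level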
8. Integration Points Summary
flowchart TB
subgraph Core["Core Routing"]
SR[StaticRouter]
DR[DynamicModelRouter]
MS[MetricBasedSelector]
end
subgraph Analysis["Analysis"]
CS[ComplexityScorer]
ME[MetricEmitter]
end
subgraph Testing["Testing & Experimentation"]
AB[ABRouter]
SH[ShadowRouter]
RTM[RoutingTableManager]
end
subgraph Edge["API Gateway"]
AGT[APIGatewayRouteTransformer]
end
AGT -->|Fast-path intents| SR
AGT -->|Complex/ambiguous| DR
DR --> CS
DR --> MS
MS --> ME
RTM -->|Config updates| SR
RTM -->|Canary traffic| AB
AB -->|Experiment traffic| DR
SH -->|Parallel shadow| ME
style SR fill:#4CAF50,color:#fff
style DR fill:#2196F3,color:#fff
style MS fill:#FF9800,color:#fff
style CS fill:#9C27B0,color:#fff
9. References
| Resource | Link |
|---|---|
| Amazon Bedrock Inference APIs | https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html |
| Claude 3 Model Pricing | https://aws.amazon.com/bedrock/pricing/ |
| ElastiCache Redis Sorted Sets | https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/Sorted-Sets.html |
| DynamoDB Streams | https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Streams.html |
| Kinesis Data Streams | https://docs.aws.amazon.com/streams/latest/dev/introduction.html |
| A/B Testing Statistical Significance | https://docs.aws.amazon.com/wellarchitected/latest/machine-learning-lens/experimentation.html |
| CloudWatch Embedded Metric Format | https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch_Embedded_Metric_Format.html |