
Dynamic Routing and Metric-Based Model Selection

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.


Skill Mapping

Field | Value
Certification | AWS Certified AI Practitioner (AIF-C01)
Domain | 2 — Implementation and Integration of Foundation Models
Task | 2.4 — Design model deployment and inference strategies
Skill | 2.4.4 — Develop intelligent model routing systems to optimize model selection
Focus | Deep-dive into content complexity scoring, real-time metric collection, routing table management, A/B routing, shadow routing

1. Content Complexity Scoring — Deep Dive

Content complexity scoring is the foundation of dynamic routing. A well-calibrated scorer is the difference between routing 60% of queries to cheap Haiku (saving thousands of dollars per day at MangaAssist's 1M messages/day scale) and accidentally sending everything to expensive Sonnet.
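
To make the cost stakes concrete, here is a back-of-envelope comparison at the Bedrock prices listed above, assuming roughly 500 input and 300 output tokens per query (illustrative numbers, not measured averages):

# Rough daily-cost comparison: all-Sonnet vs a 60/40 Haiku/Sonnet split.
# Token counts per query are assumptions for illustration only.
PRICES_PER_1M = {  # USD per 1M tokens: (input, output)
    "haiku": (0.25, 1.25),
    "sonnet": (3.00, 15.00),
}

def per_query_cost(model: str, input_tokens: int = 500, output_tokens: int = 300) -> float:
    price_in, price_out = PRICES_PER_1M[model]
    return (input_tokens * price_in + output_tokens * price_out) / 1_000_000

DAILY_QUERIES = 1_000_000
all_sonnet = DAILY_QUERIES * per_query_cost("sonnet")                    # ≈ $6,000/day
routed = DAILY_QUERIES * (0.6 * per_query_cost("haiku")
                          + 0.4 * per_query_cost("sonnet"))              # ≈ $2,700/day
print(f"all-Sonnet ${all_sonnet:,.0f}/day vs routed ${routed:,.0f}/day "
      f"(saves ${all_sonnet - routed:,.0f}/day)")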

1.1 Complexity Dimensions

mindmap
  root((Complexity Scoring))
    Lexical Features
      Average Word Length
      Vocabulary Richness TTR
      Rare Word Frequency
      Technical Terminology Count
      Japanese Character Ratio
    Structural Features
      Sentence Count
      Clause Depth
      Question Count
      Conditional Statements
      Enumeration Lists
    Semantic Features
      Entity Density
      Comparison Operators
      Negation Chains
      Temporal References
      Causal Relationships
    Contextual Features
      Multi-Turn Depth
      Topic Switches
      Reference Resolution
      Anaphora Chains
      Implicit Knowledge Needs
    Domain Features
      Manga-Specific Terms
      Cultural Context Requirements
      Cross-Reference Needs
      Creative vs Factual Balance
      Spoiler Sensitivity

1.2 Multi-Dimensional Complexity Scorer

"""
complexity_scorer.py — MangaAssist Multi-Dimensional Complexity Scorer

Computes a fine-grained complexity score across lexical, structural,
semantic, contextual, and domain-specific dimensions. The composite
score drives model selection in the dynamic routing pipeline.
"""

import re
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class DimensionScore:
    """Score for a single complexity dimension."""
    dimension: str
    raw_score: float       # Dimension-specific scale
    normalized: float      # 0-1 normalized
    weight: float          # Contribution weight
    weighted_score: float  # normalized * weight
    features: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ComplexityReport:
    """Full complexity scoring report."""
    composite_score: float           # 0-10 final score
    confidence: float                # 0-1 confidence in the score
    dimensions: List[DimensionScore]
    recommended_model: str
    reasoning: str
    scoring_time_ms: float
    feature_vector: Dict[str, float] = field(default_factory=dict)


class ComplexityScorer:
    """
    Multi-dimensional complexity scorer for MangaAssist queries.

    Analyzes queries across 5 dimensions (lexical, structural, semantic,
    contextual, domain) to produce a composite 0-10 complexity score.
    Each dimension is independently scored, weighted, and normalized.

    Calibration targets (from MangaAssist production analysis):
        - "What time do you close?" → 1.2 (Trivial → Haiku)
        - "Do you have One Piece volume 99?" → 2.8 (Simple → Haiku)
        - "Compare Demon Slayer and JJK art styles" → 6.5 (Complex → Sonnet)
        - "Analyze the thematic evolution of Berserk across the Golden Age arc
           and how Miura's art style reflects Guts's psychological state" → 8.9 (Expert → Sonnet)

    Usage:
        scorer = ComplexityScorer()
        report = scorer.score(query, conversation_history)
        if report.composite_score >= 6.5:
            use_sonnet()
        else:
            use_haiku()
    """

    # Dimension weights (must sum to 1.0)
    WEIGHTS = {
        "lexical": 0.20,
        "structural": 0.20,
        "semantic": 0.25,
        "contextual": 0.15,
        "domain": 0.20,
    }

    # Japanese character pattern
    JP_PATTERN = re.compile(r"[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\uFF00-\uFFEF]")

    # Technical manga/anime terms that signal domain complexity
    DOMAIN_TERMS_BASIC = {
        "manga", "anime", "volume", "chapter", "series", "author",
        "publisher", "price", "edition", "paperback", "hardcover",
    }
    DOMAIN_TERMS_INTERMEDIATE = {
        "shounen", "shoujo", "seinen", "josei", "isekai", "mecha",
        "slice of life", "tankoubon", "mangaka", "light novel",
        "visual novel", "doujinshi", "oneshot", "serialization",
    }
    DOMAIN_TERMS_ADVANCED = {
        "paneling technique", "screentone", "decompression",
        "narrative structure", "visual metaphor", "symbolic imagery",
        "character arc", "thematic analysis", "art evolution",
        "cultural significance", "genre deconstruction", "metanarrative",
    }

    # Complex linguistic patterns
    COMPARISON_PATTERNS = [
        r"\b(compare|contrast|versus|vs\.?|difference between|similarities)\b",
        r"\b(better than|worse than|more .+ than|less .+ than)\b",
        r"\b(pros and cons|advantages|disadvantages|trade-?offs)\b",
    ]
    ANALYTICAL_PATTERNS = [
        r"\b(analyze|explain why|what caused|how does .+ affect)\b",
        r"\b(significance of|meaning behind|symbolism|represents)\b",
        r"\b(evolution of|development of|progression|trajectory)\b",
    ]
    CONDITIONAL_PATTERNS = [
        r"\b(if .+ then|what if|assuming|suppose|given that)\b",
        r"\b(would .+ if|could .+ when|should .+ because)\b",
    ]
    CAUSAL_PATTERNS = [
        r"\b(because|therefore|consequently|as a result|leads to)\b",
        r"\b(caused by|due to|reason for|explains why|contributes to)\b",
    ]

    def __init__(
        self,
        sonnet_threshold: float = 6.5,
        weights: Optional[Dict[str, float]] = None,
    ):
        self.sonnet_threshold = sonnet_threshold
        self.weights = weights or self.WEIGHTS.copy()

        # Validate weights sum to 1.0
        total = sum(self.weights.values())
        if abs(total - 1.0) > 0.01:
            raise ValueError(f"Dimension weights must sum to 1.0, got {total}")

        # Compile regex patterns once
        self._comparison_re = [re.compile(p, re.IGNORECASE) for p in self.COMPARISON_PATTERNS]
        self._analytical_re = [re.compile(p, re.IGNORECASE) for p in self.ANALYTICAL_PATTERNS]
        self._conditional_re = [re.compile(p, re.IGNORECASE) for p in self.CONDITIONAL_PATTERNS]
        self._causal_re = [re.compile(p, re.IGNORECASE) for p in self.CAUSAL_PATTERNS]

        logger.info(
            "ComplexityScorer initialized | threshold=%.1f | weights=%s",
            sonnet_threshold,
            self.weights,
        )

    def score(
        self,
        query: str,
        conversation_history: Optional[List[Dict[str, str]]] = None,
    ) -> ComplexityReport:
        """
        Compute a full complexity report for the given query.

        Args:
            query: The user's current message
            conversation_history: Previous conversation turns

        Returns:
            ComplexityReport with composite score and per-dimension breakdown
        """
        import time
        start = time.monotonic()

        history = conversation_history or []

        # Score each dimension independently
        lexical = self._score_lexical(query)
        structural = self._score_structural(query)
        semantic = self._score_semantic(query)
        contextual = self._score_contextual(query, history)
        domain = self._score_domain(query)

        dimensions = [lexical, structural, semantic, contextual, domain]

        # Compute composite score (weighted sum, scaled to 0-10)
        composite = sum(d.weighted_score for d in dimensions) * 10.0
        composite = max(0.0, min(10.0, composite))

        # Compute confidence based on feature coverage
        non_zero_dims = sum(1 for d in dimensions if d.raw_score > 0)
        confidence = non_zero_dims / len(dimensions)

        # Build feature vector for logging/analysis
        feature_vector = {}
        for dim in dimensions:
            for key, value in dim.features.items():
                feature_vector[f"{dim.dimension}_{key}"] = value

        # Determine recommended model
        if composite >= self.sonnet_threshold:
            model = "anthropic.claude-3-sonnet-20240229-v1:0"
            reasoning = f"Complexity {composite:.1f} >= threshold {self.sonnet_threshold} — Sonnet recommended"
        else:
            model = "anthropic.claude-3-haiku-20240307-v1:0"
            reasoning = f"Complexity {composite:.1f} < threshold {self.sonnet_threshold} — Haiku sufficient"

        elapsed = (time.monotonic() - start) * 1000

        report = ComplexityReport(
            composite_score=composite,
            confidence=confidence,
            dimensions=dimensions,
            recommended_model=model,
            reasoning=reasoning,
            scoring_time_ms=elapsed,
            feature_vector=feature_vector,
        )

        logger.info(
            "Complexity scored | composite=%.1f | model=%s | confidence=%.2f | %.2fms",
            composite,
            model.split(".")[-1][:15],
            confidence,
            elapsed,
        )

        return report

    def _score_lexical(self, query: str) -> DimensionScore:
        """
        Score lexical complexity: vocabulary richness, word length, rare terms.

        Features:
            - avg_word_length: Average characters per word
            - type_token_ratio: Unique words / total words
            - long_word_ratio: Words > 8 chars / total words
            - jp_char_ratio: Japanese characters / total characters
        """
        words = re.findall(r"\b\w+\b", query.lower())
        if not words:
            return self._empty_dimension("lexical")

        word_count = len(words)
        unique_words = len(set(words))

        avg_word_length = sum(len(w) for w in words) / word_count
        type_token_ratio = unique_words / word_count if word_count > 0 else 0
        long_word_ratio = sum(1 for w in words if len(w) > 8) / word_count
        jp_chars = len(self.JP_PATTERN.findall(query))
        total_chars = max(len(query), 1)
        jp_char_ratio = jp_chars / total_chars

        # Raw score: 0-5 scale
        raw = 0.0
        raw += min(avg_word_length / 8.0, 1.0) * 1.5       # Word length
        raw += min(type_token_ratio, 1.0) * 1.0             # Vocabulary richness
        raw += min(long_word_ratio * 5, 1.0) * 1.0          # Long words
        raw += min(jp_char_ratio * 2, 1.0) * 0.5            # Japanese content

        normalized = min(raw / 4.0, 1.0)
        weight = self.weights["lexical"]

        return DimensionScore(
            dimension="lexical",
            raw_score=raw,
            normalized=normalized,
            weight=weight,
            weighted_score=normalized * weight,
            features={
                "avg_word_length": round(avg_word_length, 2),
                "type_token_ratio": round(type_token_ratio, 3),
                "long_word_ratio": round(long_word_ratio, 3),
                "jp_char_ratio": round(jp_char_ratio, 3),
                "word_count": word_count,
            },
        )

    def _score_structural(self, query: str) -> DimensionScore:
        """
        Score structural complexity: sentence count, question depth, clauses.

        Features:
            - sentence_count: Number of sentences
            - question_count: Number of question marks
            - clause_indicators: Subordinate clause markers
            - enumeration: Lists or numbered items
        """
        sentences = re.split(r"[.!?]+", query)
        sentences = [s.strip() for s in sentences if s.strip()]
        sentence_count = len(sentences)

        question_count = query.count("?")
        clause_markers = len(re.findall(
            r"\b(which|that|who|whom|whose|where|when|while|although|because|since|unless|if)\b",
            query, re.IGNORECASE,
        ))
        enum_markers = len(re.findall(r"(\d+\.|[-*]|\bfirst\b|\bsecond\b|\bthird\b)", query, re.IGNORECASE))
        comma_count = query.count(",")

        raw = 0.0
        raw += min(sentence_count / 4.0, 1.0) * 1.0        # Multiple sentences
        raw += min(question_count / 3.0, 1.0) * 1.0         # Multiple questions
        raw += min(clause_markers / 3.0, 1.0) * 1.0         # Clause depth
        raw += min(enum_markers / 3.0, 1.0) * 0.5           # Enumerations
        raw += min(comma_count / 5.0, 1.0) * 0.5            # Structural complexity proxy

        normalized = min(raw / 4.0, 1.0)
        weight = self.weights["structural"]

        return DimensionScore(
            dimension="structural",
            raw_score=raw,
            normalized=normalized,
            weight=weight,
            weighted_score=normalized * weight,
            features={
                "sentence_count": sentence_count,
                "question_count": question_count,
                "clause_markers": clause_markers,
                "enum_markers": enum_markers,
                "comma_count": comma_count,
            },
        )

    def _score_semantic(self, query: str) -> DimensionScore:
        """
        Score semantic complexity: analytical depth, causal reasoning, comparisons.

        Features:
            - comparison_count: Comparative language patterns
            - analytical_count: Analytical request patterns
            - conditional_count: Hypothetical/conditional structures
            - causal_count: Causal relationship indicators
            - entity_density: Named entities per word
        """
        comparison_count = sum(1 for p in self._comparison_re if p.search(query))
        analytical_count = sum(1 for p in self._analytical_re if p.search(query))
        conditional_count = sum(1 for p in self._conditional_re if p.search(query))
        causal_count = sum(1 for p in self._causal_re if p.search(query))

        # Entity density (capitalized words and quoted terms)
        entities = re.findall(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b", query)
        quoted = re.findall(r'"[^"]+"|\'[^\']+\'', query)
        entity_count = len(entities) + len(quoted)
        word_count = max(len(query.split()), 1)
        entity_density = entity_count / word_count

        raw = 0.0
        raw += min(comparison_count, 2) * 1.0                # Comparisons
        raw += min(analytical_count, 2) * 1.5                # Analytical depth
        raw += min(conditional_count, 2) * 0.8               # Conditionals
        raw += min(causal_count, 2) * 0.7                    # Causal reasoning
        raw += min(entity_density * 10, 1.0) * 0.5           # Entity density

        normalized = min(raw / 4.5, 1.0)
        weight = self.weights["semantic"]

        return DimensionScore(
            dimension="semantic",
            raw_score=raw,
            normalized=normalized,
            weight=weight,
            weighted_score=normalized * weight,
            features={
                "comparison_count": comparison_count,
                "analytical_count": analytical_count,
                "conditional_count": conditional_count,
                "causal_count": causal_count,
                "entity_count": entity_count,
                "entity_density": round(entity_density, 3),
            },
        )

    def _score_contextual(
        self, query: str, history: List[Dict[str, str]]
    ) -> DimensionScore:
        """
        Score contextual complexity: multi-turn depth, topic continuity, references.

        Features:
            - turn_depth: Number of previous conversation turns
            - has_pronouns: References requiring context resolution
            - topic_shift: Whether the query shifts topic from history
            - reference_count: Explicit back-references ("as I mentioned", "the one you said")
        """
        turn_depth = len(history)

        # Pronoun-based reference resolution
        pronouns = len(re.findall(
            r"\b(it|they|them|this|that|these|those|its|their|the one)\b",
            query, re.IGNORECASE,
        ))
        has_pronouns = pronouns > 0

        # Explicit back-references
        back_refs = len(re.findall(
            r"\b(you (mentioned|said|told)|as I (said|asked)|earlier|previously|the same|like before)\b",
            query, re.IGNORECASE,
        ))

        # Topic shift detection (simple: check if query shares keywords with last turn)
        topic_shift = False
        if history:
            last_msg = history[-1].get("content", "").lower()
            query_words = set(query.lower().split())
            last_words = set(last_msg.split())
            overlap = len(query_words & last_words)
            total = max(len(query_words), 1)
            if overlap / total < 0.1:
                topic_shift = True

        raw = 0.0
        raw += min(turn_depth / 5.0, 1.0) * 1.5             # Multi-turn depth
        raw += (1.0 if has_pronouns else 0.0) * 0.5         # Pronoun resolution
        raw += min(back_refs, 2) * 0.8                       # Explicit references
        raw += (0.5 if topic_shift else 0.0)                 # Topic continuity

        normalized = min(raw / 3.3, 1.0)
        weight = self.weights["contextual"]

        return DimensionScore(
            dimension="contextual",
            raw_score=raw,
            normalized=normalized,
            weight=weight,
            weighted_score=normalized * weight,
            features={
                "turn_depth": turn_depth,
                "pronoun_count": pronouns,
                "has_pronouns": has_pronouns,
                "back_reference_count": back_refs,
                "topic_shift": topic_shift,
            },
        )

    def _score_domain(self, query: str) -> DimensionScore:
        """
        Score domain-specific complexity: manga/anime knowledge requirements.

        Features:
            - basic_term_count: Common manga terms found
            - intermediate_term_count: Niche genre/format terms
            - advanced_term_count: Critical analysis terminology
            - cultural_context_needed: Whether cultural knowledge is required
        """
        q_lower = query.lower()

        basic_count = sum(1 for t in self.DOMAIN_TERMS_BASIC if t in q_lower)
        intermediate_count = sum(1 for t in self.DOMAIN_TERMS_INTERMEDIATE if t in q_lower)
        advanced_count = sum(1 for t in self.DOMAIN_TERMS_ADVANCED if t in q_lower)

        # Cultural context heuristic
        cultural_indicators = [
            "japanese culture", "cultural significance", "honorifics",
            "japanese humor", "cultural reference", "tradition",
            "ceremony", "social hierarchy", "senpai", "sensei",
        ]
        cultural_needed = any(ind in q_lower for ind in cultural_indicators)

        raw = 0.0
        raw += min(basic_count / 3.0, 1.0) * 0.5            # Basic domain terms
        raw += min(intermediate_count / 2.0, 1.0) * 1.0     # Intermediate terms
        raw += min(advanced_count, 2) * 1.5                  # Advanced analysis terms
        raw += (1.0 if cultural_needed else 0.0) * 1.0      # Cultural context

        normalized = min(raw / 4.0, 1.0)
        weight = self.weights["domain"]

        return DimensionScore(
            dimension="domain",
            raw_score=raw,
            normalized=normalized,
            weight=weight,
            weighted_score=normalized * weight,
            features={
                "basic_term_count": basic_count,
                "intermediate_term_count": intermediate_count,
                "advanced_term_count": advanced_count,
                "cultural_context_needed": cultural_needed,
            },
        )

    def _empty_dimension(self, name: str) -> DimensionScore:
        """Return a zero-score dimension for empty input."""
        return DimensionScore(
            dimension=name,
            raw_score=0.0,
            normalized=0.0,
            weight=self.weights.get(name, 0.0),
            weighted_score=0.0,
        )
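
A quick way to sanity-check the scorer against the calibration targets in the class docstring is to run those queries through it directly. This is a throwaway harness, not part of the production module, and the heuristic scores will land near (not exactly at) the targets:

# Sanity-check harness: score the calibration queries from the docstring.
if __name__ == "__main__":
    scorer = ComplexityScorer()
    queries = [
        "What time do you close?",
        "Do you have One Piece volume 99?",
        "Compare Demon Slayer and JJK art styles",
        "Analyze the thematic evolution of Berserk across the Golden Age arc "
        "and how Miura's art style reflects Guts's psychological state",
    ]
    for q in queries:
        report = scorer.score(q)
        model_short = report.recommended_model.split(".")[-1]
        print(f"{report.composite_score:4.1f}  {model_short:<32}  {q[:60]}")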

2. Real-Time Metric Collection

2.1 Metric Collection Pipeline Architecture

flowchart TB
    subgraph Sources["Metric Emission Points"]
        direction TB
        INV[Bedrock InvokeModel<br/>Latency + Token Count]
        RESP[Response Evaluator<br/>Quality Score]
        USER[User Feedback<br/>Thumbs Up/Down]
        ERR[Error Handler<br/>Error Type + Rate]
        COST[Cost Calculator<br/>Per-Query USD]
    end

    subgraph Pipeline["Collection Pipeline"]
        direction TB
        EMT[Metric Emitter<br/>Async Non-Blocking]
        KDS[Kinesis Data Stream<br/>2 Shards]
        AGG[Lambda Aggregator<br/>1-min Windows]
    end

    subgraph RealTime["Real-Time Layer"]
        direction TB
        RED[ElastiCache Redis<br/>Sorted Sets + Hashes]
        PUB[Redis Pub/Sub<br/>Threshold Alerts]
    end

    subgraph Durable["Durable Layer"]
        direction TB
        DDB[DynamoDB<br/>Time-Series Metrics]
        S3[S3 Parquet<br/>Historical Archive]
    end

    subgraph Consumers["Metric Consumers"]
        direction TB
        SEL[MetricBasedSelector<br/>Model Ranking]
        DASH[CloudWatch Dashboard<br/>Monitoring]
        ALERT[SNS Alerts<br/>Anomaly Detection]
    end

    INV --> EMT
    RESP --> EMT
    USER --> EMT
    ERR --> EMT
    COST --> EMT

    EMT --> KDS
    KDS --> AGG
    AGG --> RED
    AGG --> DDB
    AGG --> S3

    RED --> SEL
    RED --> PUB
    PUB --> ALERT
    DDB --> DASH

2.2 Metric Data Points

Metric | Source | Granularity | Storage | TTL
invocation_latency_ms | Bedrock response | Per-request | Redis sorted set | 1 hour
input_token_count | Bedrock response | Per-request | Redis counter | 1 hour
output_token_count | Bedrock response | Per-request | Redis counter | 1 hour
error_count | Error handler | Per-request | Redis counter | 1 hour
throttle_count | Bedrock 429 response | Per-request | Redis counter | 1 hour
quality_score | Response evaluator | Per-request | DynamoDB | 30 days
user_satisfaction | Feedback button | Per-session | DynamoDB | 90 days
cost_usd | Cost calculator | Per-request | DynamoDB | 90 days
p50_latency_ms | Lambda aggregator | 1-minute window | Redis hash | 1 hour
p95_latency_ms | Lambda aggregator | 1-minute window | Redis hash | 1 hour
p99_latency_ms | Lambda aggregator | 1-minute window | Redis hash | 1 hour
queries_per_minute | Lambda aggregator | 1-minute window | Redis hash | 1 hour
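
On the consumer side, the MetricBasedSelector can read a rolling p95 straight out of the latency sample sorted set (member = timestamp, score = latency in ms). A minimal helper, assuming the key layout used by the emitter in 2.3 and a redis-py client:

# Read an approximate rolling p95 latency for a model from Redis.
# Key name and sample layout mirror the MetricEmitter below (assumed).
def redis_p95_latency(redis_client, model_id: str) -> float:
    key = f"latency_samples:{model_id}"
    sample_count = redis_client.zcard(key)
    if sample_count == 0:
        return 0.0
    # Entries are ordered by score (latency), so rank-based access yields the percentile.
    rank = max(int(0.95 * (sample_count - 1)), 0)
    entry = redis_client.zrange(key, rank, rank, withscores=True)
    return float(entry[0][1]) if entry else 0.0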

2.3 Metric Emitter Implementation

"""
metric_emitter.py — MangaAssist Async Metric Emitter

Non-blocking metric emission for the routing pipeline.
Pushes raw metric data points to Kinesis for aggregation
and directly to Redis for real-time consumption.
"""

import json
import time
import logging
import asyncio
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, asdict
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

import boto3

logger = logging.getLogger(__name__)


@dataclass
class MetricDataPoint:
    """A single metric observation."""
    metric_name: str
    value: float
    model_id: str
    timestamp_ms: int
    dimensions: Dict[str, str]
    unit: str = "None"  # "Milliseconds", "Count", "USD", "Percent"

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


@dataclass
class InvocationMetrics:
    """Complete metrics from a single model invocation."""
    model_id: str
    latency_ms: float
    input_tokens: int
    output_tokens: int
    status: str            # "success", "error", "throttled", "timeout"
    error_type: Optional[str] = None
    quality_score: Optional[float] = None
    cost_usd: Optional[float] = None
    route_path: str = "unknown"
    query_type: str = "unknown"
    complexity_score: float = 0.0


class MetricEmitter:
    """
    Asynchronous metric emitter for MangaAssist routing.

    Emits metrics in a non-blocking fashion to avoid adding latency
    to the hot path. Uses fire-and-forget pattern with local buffering
    and batch flushing.

    Architecture:
        - Metrics are buffered in memory
        - Background thread flushes to Kinesis every 5 seconds
        - Critical metrics (errors, throttles) also go directly to Redis
        - CloudWatch embedded metric format for automatic dashboard creation

    Usage:
        emitter = MetricEmitter(
            kinesis_stream="MangaAssist-Metrics",
            redis_client=redis_conn,
        )
        emitter.emit_invocation(InvocationMetrics(
            model_id="anthropic.claude-3-haiku-20240307-v1:0",
            latency_ms=450.0,
            input_tokens=200,
            output_tokens=150,
            status="success",
        ))
    """

    BATCH_SIZE = 50
    FLUSH_INTERVAL_SECONDS = 5.0

    # Cost constants
    COSTS = {
        "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.25, "output": 1.25},
        "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 3.00, "output": 15.00},
    }

    def __init__(
        self,
        kinesis_stream: str = "MangaAssist-Metrics",
        redis_client=None,
        region: str = "ap-northeast-1",
        enable_cloudwatch: bool = True,
    ):
        self.stream_name = kinesis_stream
        self.redis = redis_client
        self.enable_cw = enable_cloudwatch

        self.kinesis = boto3.client("kinesis", region_name=region)
        if enable_cloudwatch:
            self.cloudwatch = boto3.client("cloudwatch", region_name=region)

        self._buffer: List[MetricDataPoint] = []
        # Only create an asyncio lock when constructed inside a running event
        # loop; asyncio.get_event_loop() is deprecated for this check and can
        # raise when no loop exists.
        try:
            asyncio.get_running_loop()
            self._buffer_lock: Optional[asyncio.Lock] = asyncio.Lock()
        except RuntimeError:
            self._buffer_lock = None
        self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="metric-emit")

        self._total_emitted = 0
        self._total_errors = 0

        logger.info(
            "MetricEmitter initialized | stream=%s | cw=%s",
            kinesis_stream,
            enable_cloudwatch,
        )

    def emit_invocation(self, metrics: InvocationMetrics) -> None:
        """
        Emit all metrics from a single model invocation.
        Non-blocking: returns immediately and processes asynchronously.
        """
        now_ms = int(time.time() * 1000)
        dimensions = {
            "model_id": metrics.model_id,
            "route_path": metrics.route_path,
            "query_type": metrics.query_type,
            "status": metrics.status,
        }

        points = [
            MetricDataPoint("invocation_latency_ms", metrics.latency_ms, metrics.model_id, now_ms, dimensions, "Milliseconds"),
            MetricDataPoint("input_token_count", float(metrics.input_tokens), metrics.model_id, now_ms, dimensions, "Count"),
            MetricDataPoint("output_token_count", float(metrics.output_tokens), metrics.model_id, now_ms, dimensions, "Count"),
        ]

        # Compute cost if not provided
        cost = metrics.cost_usd
        if cost is None:
            model_costs = self.COSTS.get(metrics.model_id)
            if model_costs:
                cost = (
                    metrics.input_tokens * model_costs["input"]
                    + metrics.output_tokens * model_costs["output"]
                ) / 1_000_000
        if cost is not None:
            points.append(MetricDataPoint("cost_usd", cost, metrics.model_id, now_ms, dimensions, "USD"))

        if metrics.quality_score is not None:
            points.append(MetricDataPoint("quality_score", metrics.quality_score, metrics.model_id, now_ms, dimensions, "None"))

        if metrics.status == "error":
            points.append(MetricDataPoint("error_count", 1.0, metrics.model_id, now_ms, dimensions, "Count"))
        if metrics.status == "throttled":
            points.append(MetricDataPoint("throttle_count", 1.0, metrics.model_id, now_ms, dimensions, "Count"))

        # Fire and forget
        self._executor.submit(self._process_points, points)

        # Critical metrics go directly to Redis for real-time routing
        if self.redis and metrics.status in ("error", "throttled"):
            self._executor.submit(self._emit_critical_to_redis, metrics)

    def _process_points(self, points: List[MetricDataPoint]) -> None:
        """Process metric data points (runs in background thread)."""
        try:
            # Send to Kinesis
            records = [
                {
                    "Data": json.dumps(p.to_dict()).encode("utf-8"),
                    "PartitionKey": p.model_id,
                }
                for p in points
            ]
            if records:
                self.kinesis.put_records(StreamName=self.stream_name, Records=records)
                self._total_emitted += len(records)

            # Update Redis real-time counters
            if self.redis:
                self._update_redis_counters(points)

        except Exception as e:
            self._total_errors += 1
            logger.warning("Metric emission failed | error=%s | points=%d", e, len(points))

    def _update_redis_counters(self, points: List[MetricDataPoint]) -> None:
        """Update Redis real-time counters and sorted sets."""
        try:
            pipe = self.redis.pipeline()
            for p in points:
                minute_key = f"metrics_min:{p.model_id}:{int(p.timestamp_ms / 60000)}"

                if p.metric_name == "invocation_latency_ms":
                    # Sorted set for percentile calculation
                    pipe.zadd(
                        f"latency_samples:{p.model_id}",
                        {str(p.timestamp_ms): p.value},
                    )
                    pipe.expire(f"latency_samples:{p.model_id}", 3600)

                # Increment per-minute counters
                pipe.hincrby(minute_key, p.metric_name + "_count", 1)
                pipe.hincrbyfloat(minute_key, p.metric_name + "_sum", p.value)
                pipe.expire(minute_key, 3600)

            pipe.execute()
        except Exception as e:
            logger.warning("Redis counter update failed | error=%s", e)

    def _emit_critical_to_redis(self, metrics: InvocationMetrics) -> None:
        """Push critical events directly to Redis for immediate routing impact."""
        try:
            event = {
                "model_id": metrics.model_id,
                "status": metrics.status,
                "error_type": metrics.error_type,
                "timestamp": int(time.time() * 1000),
            }
            self.redis.publish("routing:critical_events", json.dumps(event))

            # Increment error/throttle counters
            error_key = f"errors:{metrics.model_id}:{int(time.time()) // 60}"
            self.redis.incr(error_key)
            self.redis.expire(error_key, 300)

        except Exception as e:
            logger.warning("Critical metric emission failed | error=%s", e)

    def get_stats(self) -> Dict[str, int]:
        """Return emitter statistics for monitoring."""
        return {
            "total_emitted": self._total_emitted,
            "total_errors": self._total_errors,
            "buffer_size": len(self._buffer),
        }
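
The flowchart in 2.1 places a Lambda aggregator between Kinesis and the real-time/durable stores. A minimal sketch of that 1-minute window aggregation is below; the handler shape follows the standard Kinesis-to-Lambda event format, while the Redis endpoint is a placeholder and the key layout is assumed to mirror MetricEmitter:

# Sketch of the 1-minute Lambda aggregator: consume raw MetricDataPoint
# records from Kinesis, compute latency percentiles per (model, minute),
# and write them to the per-minute Redis hashes.
import base64
import json
import statistics
from collections import defaultdict

import redis

r = redis.Redis(host="mangaassist-cache.example.internal", port=6379)  # placeholder endpoint

def handler(event, context):
    latencies = defaultdict(list)  # (model_id, minute) -> [latency_ms, ...]
    for record in event["Records"]:
        point = json.loads(base64.b64decode(record["kinesis"]["data"]))
        if point["metric_name"] == "invocation_latency_ms":
            minute = point["timestamp_ms"] // 60_000
            latencies[(point["model_id"], minute)].append(point["value"])

    for (model_id, minute), samples in latencies.items():
        if len(samples) >= 2:
            cuts = statistics.quantiles(samples, n=100)  # 99 cut points
            p50, p95, p99 = cuts[49], cuts[94], cuts[98]
        else:
            p50 = p95 = p99 = samples[0]
        key = f"metrics_min:{model_id}:{minute}"
        r.hset(key, mapping={
            "p50_latency_ms": p50,
            "p95_latency_ms": p95,
            "p99_latency_ms": p99,
            "queries_per_minute": len(samples),
        })
        r.expire(key, 3600)  # match the 1-hour TTL from the metric table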

3. Routing Table Management

3.1 Routing Table Lifecycle

stateDiagram-v2
    [*] --> Proposed: New route submitted
    Proposed --> Validated: Schema + policy check
    Validated --> Staged: Written to staging table
    Staged --> Testing: Canary traffic enabled
    Testing --> Active: Metrics pass thresholds
    Testing --> Rolled_Back: Metrics fail
    Active --> Deprecated: Newer route promoted
    Deprecated --> Archived: TTL expired
    Rolled_Back --> Proposed: Fix and resubmit
    Archived --> [*]

    note right of Testing
        Shadow route with 5% traffic
        Monitor for 30 minutes
        Compare against baseline
    end note
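
Each transition above maps onto a public method of the RoutingTableManager defined in 3.2 below. A hedged end-to-end walk-through (route IDs, intents, and parameter values are illustrative, not production configuration):

# Illustrative lifecycle walk-through using the manager from section 3.2.
entry = RouteEntry(
    route_id="route_recommendation_v3",
    intent_category="recommendation",
    sub_intent="similar_titles",
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    fallback_model_id="anthropic.claude-3-haiku-20240307-v1:0",
    max_tokens=1024,
    temperature=0.5,
    status=RouteStatus.PROPOSED,
    version=3,
    created_at="",
    updated_at="",
    created_by="ml-platform-team",
)

manager = RoutingTableManager()
result = manager.propose_route(entry)                      # Proposed
if result.valid:
    manager.validate_route(entry.route_id)                 # Validated
    manager.stage_route(entry.route_id)                    # Staged
    manager.start_test(entry.route_id, traffic_pct=5.0,
                       test_duration_minutes=30)           # Testing (5% canary)

# ...after the test window elapses...
evaluation = manager.evaluate_test(entry.route_id)
if evaluation.get("recommendation") == "promote":
    manager.promote_route(entry.route_id)                  # Active; old route deprecated
else:
    manager.rollback_route(entry.route_id, reason="Canary metrics below baseline")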

3.2 Routing Table Manager

"""
routing_table_manager.py — MangaAssist Routing Table Lifecycle Manager

Manages the full lifecycle of routing configurations: creation, validation,
staged rollout, testing, activation, deprecation, and archival.
"""

import json
import time
import logging
import hashlib
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum

import boto3
from boto3.dynamodb.conditions import Key, Attr

logger = logging.getLogger(__name__)


class RouteStatus(Enum):
    """Lifecycle status of a routing configuration."""
    PROPOSED = "proposed"
    VALIDATED = "validated"
    STAGED = "staged"
    TESTING = "testing"
    ACTIVE = "active"
    DEPRECATED = "deprecated"
    ROLLED_BACK = "rolled_back"
    ARCHIVED = "archived"


@dataclass
class RouteEntry:
    """A single routing table entry."""
    route_id: str
    intent_category: str
    sub_intent: str
    model_id: str
    fallback_model_id: str
    max_tokens: int
    temperature: float
    status: RouteStatus
    version: int
    created_at: str
    updated_at: str
    created_by: str
    config_hash: str = ""
    test_traffic_pct: float = 0.0
    test_start_time: Optional[str] = None
    test_metrics: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def compute_hash(self) -> str:
        """Compute a content hash for change detection."""
        content = f"{self.intent_category}:{self.sub_intent}:{self.model_id}:{self.max_tokens}:{self.temperature}"
        return hashlib.sha256(content.encode()).hexdigest()[:16]


@dataclass
class ValidationResult:
    """Result of route validation."""
    valid: bool
    errors: List[str]
    warnings: List[str]


class RoutingTableManager:
    """
    Routing table lifecycle manager for MangaAssist.

    Manages the full lifecycle of routing configurations with safe
    rollout procedures, validation, and automatic rollback support.

    Features:
        - Schema validation for new route entries
        - Staged rollout with configurable canary traffic
        - Automatic metric-based promotion or rollback
        - Version history with audit trail
        - Bulk import/export for infrastructure-as-code

    Usage:
        manager = RoutingTableManager(
            config_table="MangaAssist-RoutingConfig",
            history_table="MangaAssist-RouteHistory",
        )
        entry = RouteEntry(...)
        result = manager.propose_route(entry)
        if result.valid:
            manager.stage_route(entry.route_id)
            manager.start_test(entry.route_id, traffic_pct=5.0)
    """

    VALID_MODELS = [
        "anthropic.claude-3-haiku-20240307-v1:0",
        "anthropic.claude-3-sonnet-20240229-v1:0",
    ]

    MAX_TOKENS_LIMITS = {
        "anthropic.claude-3-haiku-20240307-v1:0": 4096,
        "anthropic.claude-3-sonnet-20240229-v1:0": 4096,
    }

    def __init__(
        self,
        config_table: str = "MangaAssist-RoutingConfig",
        history_table: str = "MangaAssist-RouteHistory",
        region: str = "ap-northeast-1",
        redis_client=None,
    ):
        self.dynamodb = boto3.resource("dynamodb", region_name=region)
        self.config_table = self.dynamodb.Table(config_table)
        self.history_table = self.dynamodb.Table(history_table)
        self.redis = redis_client

        logger.info(
            "RoutingTableManager initialized | config=%s | history=%s",
            config_table,
            history_table,
        )

    def propose_route(self, entry: RouteEntry) -> ValidationResult:
        """
        Propose a new routing entry. Validates and stores as PROPOSED.

        Args:
            entry: The route entry to propose

        Returns:
            ValidationResult indicating if proposal was accepted
        """
        # Validate
        result = self._validate_entry(entry)
        if not result.valid:
            logger.warning(
                "Route proposal rejected | id=%s | errors=%s",
                entry.route_id,
                result.errors,
            )
            return result

        # Store as proposed
        entry.status = RouteStatus.PROPOSED
        entry.config_hash = entry.compute_hash()
        entry.created_at = datetime.utcnow().isoformat()
        entry.updated_at = entry.created_at

        self._store_entry(entry)
        self._record_history(entry, "proposed", f"New route proposed by {entry.created_by}")

        logger.info("Route proposed | id=%s | intent=%s:%s", entry.route_id, entry.intent_category, entry.sub_intent)
        return result

    def validate_route(self, route_id: str) -> ValidationResult:
        """Transition a proposed route to validated status."""
        entry = self._get_entry(route_id)
        if not entry:
            return ValidationResult(valid=False, errors=["Route not found"], warnings=[])
        if entry.status != RouteStatus.PROPOSED:
            return ValidationResult(valid=False, errors=[f"Route in {entry.status.value} state, expected proposed"], warnings=[])

        result = self._validate_entry(entry)
        if result.valid:
            entry.status = RouteStatus.VALIDATED
            entry.updated_at = datetime.utcnow().isoformat()
            self._store_entry(entry)
            self._record_history(entry, "validated", "Passed validation checks")

        return result

    def stage_route(self, route_id: str) -> bool:
        """Write a validated route to the staging configuration."""
        entry = self._get_entry(route_id)
        if not entry or entry.status != RouteStatus.VALIDATED:
            logger.error("Cannot stage route | id=%s | status=%s", route_id, entry.status.value if entry else "not_found")
            return False

        entry.status = RouteStatus.STAGED
        entry.updated_at = datetime.utcnow().isoformat()
        self._store_entry(entry)
        self._record_history(entry, "staged", "Route staged for testing")

        logger.info("Route staged | id=%s", route_id)
        return True

    def start_test(
        self,
        route_id: str,
        traffic_pct: float = 5.0,
        test_duration_minutes: int = 30,
    ) -> bool:
        """
        Begin canary testing for a staged route.

        Routes a percentage of matching traffic through the new configuration
        while monitoring key metrics against the existing active route.

        Args:
            route_id: The staged route to test
            traffic_pct: Percentage of traffic for the canary (1-50%)
            test_duration_minutes: How long to run the test

        Returns:
            True if test was started successfully
        """
        entry = self._get_entry(route_id)
        if not entry or entry.status != RouteStatus.STAGED:
            logger.error("Cannot start test | id=%s", route_id)
            return False

        if traffic_pct < 1 or traffic_pct > 50:
            logger.error("Traffic percentage must be between 1-50%%, got %.1f%%", traffic_pct)
            return False

        entry.status = RouteStatus.TESTING
        entry.test_traffic_pct = traffic_pct
        entry.test_start_time = datetime.utcnow().isoformat()
        entry.updated_at = datetime.utcnow().isoformat()
        entry.metadata["test_duration_minutes"] = test_duration_minutes
        entry.metadata["test_end_time"] = (
            datetime.utcnow() + timedelta(minutes=test_duration_minutes)
        ).isoformat()

        self._store_entry(entry)
        self._record_history(
            entry,
            "testing_started",
            f"Canary test started: {traffic_pct}% traffic for {test_duration_minutes}min",
        )

        # Publish test start event to Redis for real-time routing awareness
        if self.redis:
            self.redis.publish("routing:test_events", json.dumps({
                "event": "test_started",
                "route_id": route_id,
                "intent": entry.intent_category,
                "sub_intent": entry.sub_intent,
                "traffic_pct": traffic_pct,
                "model_id": entry.model_id,
            }))

        logger.info(
            "Route test started | id=%s | traffic=%.1f%% | duration=%dmin",
            route_id,
            traffic_pct,
            test_duration_minutes,
        )
        return True

    def evaluate_test(self, route_id: str) -> Dict[str, Any]:
        """
        Evaluate test results and determine whether to promote or rollback.

        Criteria:
            - P95 latency within 20% of baseline
            - Error rate below 2%
            - Quality score within 10% of baseline
            - Cost per query within expected bounds

        Returns:
            Evaluation result with recommendation
        """
        entry = self._get_entry(route_id)
        if not entry or entry.status != RouteStatus.TESTING:
            return {"status": "error", "reason": "Route not in testing state"}

        # Collect test metrics from Redis
        test_metrics = self._collect_test_metrics(entry)
        baseline_metrics = self._collect_baseline_metrics(entry)

        evaluation = {
            "route_id": route_id,
            "test_duration_minutes": entry.metadata.get("test_duration_minutes", 0),
            "test_queries": test_metrics.get("query_count", 0),
            "checks": [],
        }

        passed = True

        # P95 latency check
        if baseline_metrics.get("p95_latency_ms", 0) > 0:
            latency_ratio = test_metrics.get("p95_latency_ms", 0) / baseline_metrics["p95_latency_ms"]
            check = {
                "name": "p95_latency",
                "test_value": test_metrics.get("p95_latency_ms", 0),
                "baseline_value": baseline_metrics["p95_latency_ms"],
                "threshold": "within 20%",
                "passed": latency_ratio <= 1.2,
            }
            evaluation["checks"].append(check)
            if not check["passed"]:
                passed = False

        # Error rate check
        error_rate = test_metrics.get("error_rate_pct", 0)
        check = {
            "name": "error_rate",
            "test_value": error_rate,
            "threshold": "< 2%",
            "passed": error_rate < 2.0,
        }
        evaluation["checks"].append(check)
        if not check["passed"]:
            passed = False

        # Quality score check
        if baseline_metrics.get("quality_score", 0) > 0:
            quality_ratio = test_metrics.get("quality_score", 0) / baseline_metrics["quality_score"]
            check = {
                "name": "quality_score",
                "test_value": test_metrics.get("quality_score", 0),
                "baseline_value": baseline_metrics["quality_score"],
                "threshold": "within 10%",
                "passed": quality_ratio >= 0.9,
            }
            evaluation["checks"].append(check)
            if not check["passed"]:
                passed = False

        evaluation["passed"] = passed
        evaluation["recommendation"] = "promote" if passed else "rollback"

        entry.test_metrics = evaluation
        entry.updated_at = datetime.utcnow().isoformat()
        self._store_entry(entry)

        logger.info(
            "Test evaluated | id=%s | passed=%s | recommendation=%s",
            route_id,
            passed,
            evaluation["recommendation"],
        )

        return evaluation

    def promote_route(self, route_id: str) -> bool:
        """Promote a tested route to active status, deprecating the current active route."""
        entry = self._get_entry(route_id)
        if not entry or entry.status != RouteStatus.TESTING:
            return False

        # Deprecate current active route for this intent
        self._deprecate_active_route(entry.intent_category, entry.sub_intent)

        entry.status = RouteStatus.ACTIVE
        entry.test_traffic_pct = 0.0
        entry.updated_at = datetime.utcnow().isoformat()
        self._store_entry(entry)
        self._record_history(entry, "promoted", "Route promoted to active after successful test")

        # Invalidate caches
        self._invalidate_caches(entry)

        logger.info("Route promoted | id=%s | intent=%s:%s", route_id, entry.intent_category, entry.sub_intent)
        return True

    def rollback_route(self, route_id: str, reason: str = "Test failed") -> bool:
        """Rollback a testing route."""
        entry = self._get_entry(route_id)
        if not entry:
            return False

        entry.status = RouteStatus.ROLLED_BACK
        entry.test_traffic_pct = 0.0
        entry.updated_at = datetime.utcnow().isoformat()
        self._store_entry(entry)
        self._record_history(entry, "rolled_back", reason)

        logger.info("Route rolled back | id=%s | reason=%s", route_id, reason)
        return True

    def get_active_routes(self) -> List[RouteEntry]:
        """Retrieve all currently active routes."""
        try:
            response = self.config_table.scan(
                FilterExpression=Attr("status").eq("active"),
            )
            entries = []
            for item in response.get("Items", []):
                entries.append(self._item_to_entry(item))
            return entries
        except Exception as e:
            logger.error("Failed to get active routes | error=%s", e)
            return []

    def export_config(self) -> Dict[str, Any]:
        """Export all active routes as JSON for infrastructure-as-code."""
        routes = self.get_active_routes()
        return {
            "version": "1.0",
            "exported_at": datetime.utcnow().isoformat(),
            "routes": [
                {
                    "intent_category": r.intent_category,
                    "sub_intent": r.sub_intent,
                    "model_id": r.model_id,
                    "fallback_model_id": r.fallback_model_id,
                    "max_tokens": r.max_tokens,
                    "temperature": r.temperature,
                }
                for r in routes
            ],
        }

    # --- Private helpers ---

    def _validate_entry(self, entry: RouteEntry) -> ValidationResult:
        """Validate a route entry against schema and policy rules."""
        errors: List[str] = []
        warnings: List[str] = []

        if not entry.intent_category:
            errors.append("intent_category is required")
        if entry.model_id not in self.VALID_MODELS:
            errors.append(f"Invalid model_id: {entry.model_id}")
        if entry.fallback_model_id not in self.VALID_MODELS:
            errors.append(f"Invalid fallback_model_id: {entry.fallback_model_id}")
        if entry.model_id == entry.fallback_model_id:
            warnings.append("model_id and fallback_model_id are the same — no failover diversity")

        max_limit = self.MAX_TOKENS_LIMITS.get(entry.model_id, 4096)
        if entry.max_tokens > max_limit:
            errors.append(f"max_tokens {entry.max_tokens} exceeds limit {max_limit} for {entry.model_id}")
        if entry.temperature < 0 or entry.temperature > 1:
            errors.append(f"temperature must be 0-1, got {entry.temperature}")

        return ValidationResult(valid=len(errors) == 0, errors=errors, warnings=warnings)

    def _store_entry(self, entry: RouteEntry) -> None:
        """Store a route entry in DynamoDB."""
        self.config_table.put_item(Item={
            "PK": f"ROUTE#{entry.intent_category}",
            "SK": f"VERSION#{entry.version}#{entry.sub_intent}",
            "route_id": entry.route_id,
            "intent_category": entry.intent_category,
            "sub_intent": entry.sub_intent,
            "model_id": entry.model_id,
            "fallback_model_id": entry.fallback_model_id,
            "max_tokens": entry.max_tokens,
            "temperature": str(entry.temperature),
            "status": entry.status.value,
            "version": entry.version,
            "config_hash": entry.config_hash,
            "test_traffic_pct": str(entry.test_traffic_pct),
            "created_at": entry.created_at,
            "updated_at": entry.updated_at,
            "created_by": entry.created_by,
            "metadata": json.dumps(entry.metadata),
            "test_metrics": json.dumps(entry.test_metrics),
        })

    def _get_entry(self, route_id: str) -> Optional[RouteEntry]:
        """Retrieve a route entry by ID."""
        try:
            # Note: a scan Limit is applied before the FilterExpression, so a
            # Limit=1 filtered scan can miss the item; scan without Limit and
            # take the first match instead.
            response = self.config_table.scan(
                FilterExpression=Attr("route_id").eq(route_id),
            )
            items = response.get("Items", [])
            if items:
                return self._item_to_entry(items[0])
        except Exception as e:
            logger.error("Failed to get route | id=%s | error=%s", route_id, e)
        return None

    def _item_to_entry(self, item: Dict[str, Any]) -> RouteEntry:
        """Convert a DynamoDB item to a RouteEntry."""
        return RouteEntry(
            route_id=item.get("route_id", ""),
            intent_category=item.get("intent_category", ""),
            sub_intent=item.get("sub_intent", "default"),
            model_id=item.get("model_id", ""),
            fallback_model_id=item.get("fallback_model_id", ""),
            max_tokens=int(item.get("max_tokens", 1024)),
            temperature=float(item.get("temperature", 0.3)),
            status=RouteStatus(item.get("status", "proposed")),
            version=int(item.get("version", 1)),
            created_at=item.get("created_at", ""),
            updated_at=item.get("updated_at", ""),
            created_by=item.get("created_by", "system"),
            config_hash=item.get("config_hash", ""),
            test_traffic_pct=float(item.get("test_traffic_pct", 0)),
            metadata=json.loads(item.get("metadata", "{}")),
            test_metrics=json.loads(item.get("test_metrics", "{}")),
        )

    def _record_history(self, entry: RouteEntry, event: str, detail: str) -> None:
        """Record a history event for audit trail."""
        try:
            self.history_table.put_item(Item={
                "PK": f"HISTORY#{entry.route_id}",
                "SK": f"EVENT#{datetime.utcnow().isoformat()}",
                "route_id": entry.route_id,
                "event": event,
                "detail": detail,
                "status": entry.status.value,
                "model_id": entry.model_id,
                "timestamp": datetime.utcnow().isoformat(),
                "ttl": int((datetime.utcnow() + timedelta(days=90)).timestamp()),
            })
        except Exception as e:
            logger.warning("History recording failed | id=%s | error=%s", entry.route_id, e)

    def _deprecate_active_route(self, intent: str, sub_intent: str) -> None:
        """Deprecate the current active route for an intent."""
        try:
            response = self.config_table.scan(
                FilterExpression=(
                    Attr("intent_category").eq(intent)
                    & Attr("sub_intent").eq(sub_intent)
                    & Attr("status").eq("active")
                ),
            )
            for item in response.get("Items", []):
                self.config_table.update_item(
                    Key={"PK": item["PK"], "SK": item["SK"]},
                    UpdateExpression="SET #s = :s, updated_at = :t",
                    ExpressionAttributeNames={"#s": "status"},
                    ExpressionAttributeValues={
                        ":s": "deprecated",
                        ":t": datetime.utcnow().isoformat(),
                    },
                )
        except Exception as e:
            logger.error("Failed to deprecate active route | error=%s", e)

    def _invalidate_caches(self, entry: RouteEntry) -> None:
        """Invalidate route caches after promotion."""
        if self.redis:
            try:
                cache_key = f"route:{entry.intent_category}:{entry.sub_intent}"
                self.redis.delete(cache_key)
                self.redis.publish("routing:cache_invalidation", json.dumps({
                    "intent": entry.intent_category,
                    "sub_intent": entry.sub_intent,
                }))
            except Exception as e:
                logger.warning("Cache invalidation failed | error=%s", e)

    def _collect_test_metrics(self, entry: RouteEntry) -> Dict[str, Any]:
        """Collect metrics for the test route from Redis."""
        # Placeholder — in production, aggregate from Redis sorted sets
        return {
            "query_count": 0,
            "p95_latency_ms": 0,
            "error_rate_pct": 0,
            "quality_score": 0,
        }

    def _collect_baseline_metrics(self, entry: RouteEntry) -> Dict[str, Any]:
        """Collect baseline metrics from the active route."""
        return {
            "p95_latency_ms": 0,
            "quality_score": 0,
        }

4. A/B Routing for Model Comparison

A/B routing sends a controlled split of traffic to different models for the same intent, allowing direct comparison of model performance on identical query distributions.
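
Both the flow in 4.1 and the ABRouter in 4.2 depend on a deterministic session-to-bucket mapping. The ABRouter's _compute_bucket helper can be any stable hash; a typical sketch, offered as an assumption rather than the shipped implementation:

# Deterministic bucketing sketch: same session + experiment always maps to
# the same bucket in [0, bucket_count).
import hashlib

def compute_bucket(session_id: str, experiment_id: str, bucket_count: int = 1000) -> int:
    digest = hashlib.sha256(f"{experiment_id}:{session_id}".encode("utf-8")).hexdigest()
    return int(digest[:8], 16) % bucket_count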

4.1 A/B Routing Flow

flowchart TB
    QUERY([Incoming Query]) --> HASH[Hash Session ID]
    HASH --> BUCKET{Bucket Assignment}

    BUCKET -->|0-79| CONTROL[Control Group: Haiku<br/>80% traffic]
    BUCKET -->|80-94| TREATMENT_A[Treatment A: Sonnet<br/>15% traffic]
    BUCKET -->|95-99| TREATMENT_B[Treatment B: Sonnet + Prompt V2<br/>5% traffic]

    CONTROL --> INVOKE_C[Invoke Haiku]
    TREATMENT_A --> INVOKE_A[Invoke Sonnet]
    TREATMENT_B --> INVOKE_B[Invoke Sonnet + New Prompt]

    INVOKE_C --> RESPOND([Response to User])
    INVOKE_A --> RESPOND
    INVOKE_B --> RESPOND

    INVOKE_C --> LOG_C[Log: group=control]
    INVOKE_A --> LOG_A[Log: group=treatment_a]
    INVOKE_B --> LOG_B[Log: group=treatment_b]

    LOG_C --> ANALYSIS[Statistical Analysis]
    LOG_A --> ANALYSIS
    LOG_B --> ANALYSIS

    style CONTROL fill:#4CAF50,color:#fff
    style TREATMENT_A fill:#2196F3,color:#fff
    style TREATMENT_B fill:#FF9800,color:#fff

4.2 A/B Router Implementation

"""
ab_router.py — MangaAssist A/B Model Routing

Deterministic session-based traffic splitting for controlled model comparison.
Uses consistent hashing to ensure the same user always sees the same variant.
"""

import hashlib
import logging
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class ABVariant:
    """A single variant in an A/B experiment."""
    name: str
    model_id: str
    traffic_pct: float           # 0-100
    bucket_start: int = 0        # Computed: inclusive
    bucket_end: int = 0          # Computed: exclusive
    prompt_template: Optional[str] = None
    max_tokens: int = 1024
    temperature: float = 0.3
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ABExperiment:
    """An active A/B experiment configuration."""
    experiment_id: str
    name: str
    description: str
    intent_filter: str           # Which intents this experiment applies to ("*" for all)
    variants: List[ABVariant]
    start_time: str
    end_time: Optional[str] = None
    enabled: bool = True
    min_sample_size: int = 1000
    confidence_level: float = 0.95


@dataclass
class ABAssignment:
    """Result of assigning a session to an A/B variant."""
    experiment_id: str
    variant_name: str
    model_id: str
    bucket: int
    max_tokens: int
    temperature: float
    prompt_template: Optional[str]
    metadata: Dict[str, Any] = field(default_factory=dict)


class ABRouter:
    """
    A/B model routing for MangaAssist experiments.

    Deterministically assigns sessions to experiment variants using
    consistent hashing. Ensures:
        - Same session always gets same variant (no flickering)
        - Traffic splits are precise at scale
        - Multiple concurrent experiments supported
        - Graceful degradation if experiment config is invalid

    Usage:
        router = ABRouter()
        router.add_experiment(ABExperiment(
            experiment_id="exp_001",
            name="Sonnet Quality Test",
            variants=[
                ABVariant(name="control", model_id="haiku", traffic_pct=80),
                ABVariant(name="treatment", model_id="sonnet", traffic_pct=20),
            ],
            ...
        ))

        assignment = router.assign("session_123", intent="product_detail")
        # assignment.variant_name → "control" or "treatment"
    """

    BUCKET_COUNT = 1000  # 0.1% granularity

    def __init__(self, redis_client=None):
        self.experiments: Dict[str, ABExperiment] = {}
        self.redis = redis_client
        self._assignment_cache: Dict[str, ABAssignment] = {}
        logger.info("ABRouter initialized | buckets=%d", self.BUCKET_COUNT)

    def add_experiment(self, experiment: ABExperiment) -> bool:
        """
        Register an A/B experiment.

        Validates traffic splits sum to 100% and computes bucket ranges.
        """
        # Validate traffic splits
        total_traffic = sum(v.traffic_pct for v in experiment.variants)
        if abs(total_traffic - 100.0) > 0.1:
            logger.error(
                "Invalid traffic split | experiment=%s | total=%.1f%%",
                experiment.experiment_id,
                total_traffic,
            )
            return False

        # Compute bucket ranges
        current_bucket = 0
        for variant in experiment.variants:
            bucket_count = int(variant.traffic_pct / 100.0 * self.BUCKET_COUNT)
            variant.bucket_start = current_bucket
            variant.bucket_end = current_bucket + bucket_count
            current_bucket = variant.bucket_end

        # Handle rounding — last variant gets remaining buckets
        if experiment.variants:
            experiment.variants[-1].bucket_end = self.BUCKET_COUNT

        self.experiments[experiment.experiment_id] = experiment
        logger.info(
            "Experiment registered | id=%s | variants=%d | intents=%s",
            experiment.experiment_id,
            len(experiment.variants),
            experiment.intent_filter,
        )
        return True

    def assign(
        self,
        session_id: str,
        intent: str = "*",
        experiment_id: Optional[str] = None,
    ) -> Optional[ABAssignment]:
        """
        Assign a session to an A/B variant.

        Uses consistent hashing to deterministically map sessions
        to buckets, ensuring stable assignments across requests.

        Args:
            session_id: Unique session identifier
            intent: Current query intent for experiment filtering
            experiment_id: Specific experiment (or None for auto-match)

        Returns:
            ABAssignment if an experiment matches, None otherwise
        """
        # Check cache for existing assignment
        cache_key = f"{session_id}:{intent}:{experiment_id or 'auto'}"
        if cache_key in self._assignment_cache:
            return self._assignment_cache[cache_key]

        # Find applicable experiment
        experiment = self._find_experiment(intent, experiment_id)
        if not experiment:
            return None

        # Compute deterministic bucket
        bucket = self._compute_bucket(session_id, experiment.experiment_id)

        # Find variant for bucket
        variant = self._find_variant(experiment, bucket)
        if not variant:
            logger.warning("No variant found for bucket %d | experiment=%s", bucket, experiment.experiment_id)
            return None

        assignment = ABAssignment(
            experiment_id=experiment.experiment_id,
            variant_name=variant.name,
            model_id=variant.model_id,
            bucket=bucket,
            max_tokens=variant.max_tokens,
            temperature=variant.temperature,
            prompt_template=variant.prompt_template,
            metadata={
                "experiment_name": experiment.name,
                "traffic_pct": variant.traffic_pct,
            },
        )

        self._assignment_cache[cache_key] = assignment

        # Record assignment in Redis for analytics
        if self.redis:
            try:
                self.redis.hincrby(
                    f"ab:{experiment.experiment_id}:counts",
                    variant.name,
                    1,
                )
            except Exception:
                # Best-effort analytics counter; never fail the request path.
                pass

        logger.debug(
            "A/B assignment | session=%s | experiment=%s | variant=%s | bucket=%d",
            session_id[:8],
            experiment.experiment_id,
            variant.name,
            bucket,
        )

        return assignment

    def record_outcome(
        self,
        experiment_id: str,
        variant_name: str,
        session_id: str,
        metrics: Dict[str, float],
    ) -> None:
        """
        Record outcome metrics for a variant assignment.

        Args:
            experiment_id: The experiment ID
            variant_name: The variant name
            session_id: The session ID
            metrics: Outcome metrics (latency_ms, quality_score, user_satisfied)
        """
        if not self.redis:
            return

        try:
            key_prefix = f"ab:{experiment_id}:{variant_name}"

            pipe = self.redis.pipeline()
            for metric_name, value in metrics.items():
                pipe.rpush(f"{key_prefix}:{metric_name}", str(value))
                pipe.expire(f"{key_prefix}:{metric_name}", 86400 * 7)  # 7 days
            pipe.execute()

        except Exception as e:
            logger.warning("Failed to record A/B outcome | error=%s", e)

    def _compute_bucket(self, session_id: str, experiment_id: str) -> int:
        """Compute a deterministic bucket from session ID and experiment."""
        hash_input = f"{session_id}:{experiment_id}"
        hash_bytes = hashlib.sha256(hash_input.encode()).digest()
        hash_int = int.from_bytes(hash_bytes[:4], byteorder="big")
        return hash_int % self.BUCKET_COUNT

    def _find_experiment(
        self, intent: str, experiment_id: Optional[str]
    ) -> Optional[ABExperiment]:
        """Find an applicable experiment for the given intent."""
        if experiment_id:
            exp = self.experiments.get(experiment_id)
            return exp if exp and exp.enabled else None

        for exp in self.experiments.values():
            if not exp.enabled:
                continue
            if exp.intent_filter == "*" or intent in exp.intent_filter.split(","):
                return exp
        return None

    def _find_variant(
        self, experiment: ABExperiment, bucket: int
    ) -> Optional[ABVariant]:
        """Find the variant that owns the given bucket."""
        for variant in experiment.variants:
            if variant.bucket_start <= bucket < variant.bucket_end:
                return variant
        return None
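
The 4.1 diagram ends at a "Statistical Analysis" step that the router itself does not implement: record_outcome() only pushes raw values into Redis lists keyed ab:{experiment_id}:{variant}:{metric}. The sketch below is one minimal way to close that loop offline, assuming the redis-py and scipy packages are available. It pulls one numeric metric (the quality_score field mentioned in record_outcome()) for two variants and runs Welch's t-test, gated on the experiment's min_sample_size and a significance threshold of 1 minus confidence_level. The variant names are illustrative defaults, not part of the router.

"""
ab_analysis.py (illustrative sketch) - offline analysis of ABRouter outcomes.

Reads the Redis lists written by ABRouter.record_outcome() and runs
Welch's t-test on one numeric metric for two variants.
Assumed dependencies: redis-py and scipy.
"""

import redis
from scipy import stats


def load_metric(r: redis.Redis, experiment_id: str, variant: str, metric: str) -> list:
    """Read every recorded value for one variant/metric as floats."""
    key = f"ab:{experiment_id}:{variant}:{metric}"
    return [float(v) for v in r.lrange(key, 0, -1)]


def compare_variants(
    r: redis.Redis,
    experiment_id: str,
    metric: str = "quality_score",
    control: str = "control",            # Illustrative variant names
    treatment: str = "treatment_a",
    min_sample_size: int = 1000,         # Mirrors ABExperiment.min_sample_size
    confidence_level: float = 0.95,      # Mirrors ABExperiment.confidence_level
) -> dict:
    """Two-sample comparison of one metric between control and treatment."""
    a = load_metric(r, experiment_id, control, metric)
    b = load_metric(r, experiment_id, treatment, metric)

    # Do not call a winner until both arms have enough samples.
    if len(a) < min_sample_size or len(b) < min_sample_size:
        return {"decision": "insufficient_data", "n_control": len(a), "n_treatment": len(b)}

    # Welch's t-test: independent samples, unequal variances.
    t_stat, p_value = stats.ttest_ind(a, b, equal_var=False)
    return {
        "metric": metric,
        "mean_control": sum(a) / len(a),
        "mean_treatment": sum(b) / len(b),
        "t_statistic": float(t_stat),
        "p_value": float(p_value),
        "significant": p_value < (1 - confidence_level),
    }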

5. Shadow Routing for Safe Testing

Shadow routing duplicates a query to a secondary model without affecting the user response. It allows testing new models or configurations against production traffic with zero user impact.

5.1 Shadow Routing Architecture

flowchart TB
    QUERY([Incoming Query]) --> PRIMARY[Primary Route<br/>Haiku — serves user]
    QUERY --> SHADOW[Shadow Route<br/>Sonnet — discarded]

    PRIMARY --> RESPONSE([User Response<br/>From Primary Only])

    PRIMARY --> LOG_P[Log Primary Metrics]
    SHADOW --> LOG_S[Log Shadow Metrics]

    LOG_P --> COMPARE[Metric Comparison<br/>Lambda Aggregator]
    LOG_S --> COMPARE

    COMPARE --> REPORT[Shadow Report<br/>Quality Delta + Cost Impact]

    style PRIMARY fill:#4CAF50,color:#fff
    style SHADOW fill:#9E9E9E,color:#fff
    style RESPONSE fill:#4CAF50,color:#fff

5.2 Shadow Router Implementation

"""
shadow_router.py — MangaAssist Shadow Model Routing

Sends queries to a secondary (shadow) model in parallel with the primary
model. The shadow response is discarded — only metrics are collected.
Enables safe, zero-impact production testing of new models.
"""

import json
import logging
import time
import uuid
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor

import boto3

logger = logging.getLogger(__name__)


@dataclass
class ShadowConfig:
    """Configuration for shadow routing."""
    shadow_model_id: str
    shadow_traffic_pct: float = 10.0     # % of traffic to shadow
    shadow_max_tokens: int = 1024
    shadow_temperature: float = 0.3
    shadow_timeout_ms: int = 5000        # Timeout for shadow invocation
    collect_response: bool = True         # Whether to store shadow responses
    compare_quality: bool = True          # Whether to run quality comparison
    enabled: bool = True


@dataclass
class ShadowResult:
    """Result of a shadow routing invocation."""
    primary_model_id: str
    shadow_model_id: str
    primary_latency_ms: float
    shadow_latency_ms: float
    primary_tokens: Dict[str, int]       # {"input": N, "output": N}
    shadow_tokens: Dict[str, int]
    primary_cost_usd: float
    shadow_cost_usd: float
    quality_comparison: Optional[Dict[str, Any]] = None
    shadow_error: Optional[str] = None


class ShadowRouter:
    """
    Shadow model routing for MangaAssist.

    Duplicates queries to a shadow model alongside the primary model.
    The primary response is served to the user immediately; the shadow
    response is used only for metric collection and comparison.

    Key Properties:
        - Zero user impact: shadow failures never affect responses
        - Async execution: shadow invocation doesn't add latency to primary
        - Configurable sampling: only shadow a percentage of traffic
        - Quality comparison: automated scoring of primary vs shadow responses
        - Cost tracking: full cost accounting including shadow overhead

    WARNING: Shadow routing costs real money. At a 10% shadow rate with Sonnet
    as the shadow model on 1M queries/day, expect roughly $1,140/day of
    additional spend (see the cost table in Section 6).

    Usage:
        shadow = ShadowRouter(
            config=ShadowConfig(
                shadow_model_id="anthropic.claude-3-sonnet-20240229-v1:0",
                shadow_traffic_pct=10.0,
            ),
        )

        # In the main routing path
        primary_response = invoke_primary(query)
        shadow.maybe_invoke_shadow(query, primary_response)
    """

    COSTS = {
        "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.25, "output": 1.25},
        "anthropic.claude-3-sonnet-20240229-v1:0": {"input": 3.00, "output": 15.00},
    }

    def __init__(
        self,
        config: ShadowConfig,
        bedrock_client=None,
        redis_client=None,
        region: str = "ap-northeast-1",
    ):
        self.config = config
        self.bedrock = bedrock_client or boto3.client(
            "bedrock-runtime", region_name=region
        )
        self.redis = redis_client
        self._executor = ThreadPoolExecutor(
            max_workers=4,
            thread_name_prefix="shadow-route",
        )

        # Sampling state: deterministic 1-in-N sampling. A non-positive traffic
        # percentage disables shadowing instead of raising ZeroDivisionError.
        self._sample_counter = 0
        if config.shadow_traffic_pct > 0:
            self._sample_interval = max(1, int(100 / config.shadow_traffic_pct))
        else:
            self._sample_interval = 0

        # Stats
        self._total_shadowed = 0
        self._total_shadow_errors = 0
        self._total_shadow_cost_usd = 0.0

        logger.info(
            "ShadowRouter initialized | shadow_model=%s | traffic=%.1f%% | interval=%d",
            config.shadow_model_id,
            config.shadow_traffic_pct,
            self._sample_interval,
        )

    def should_shadow(self) -> bool:
        """Determine if this request should be shadowed (deterministic sampling)."""
        if not self.config.enabled or self._sample_interval <= 0:
            return False
        self._sample_counter += 1
        return self._sample_counter % self._sample_interval == 0

    def invoke_shadow_async(
        self,
        query: str,
        primary_model_id: str,
        primary_response: str,
        primary_latency_ms: float,
        primary_tokens: Dict[str, int],
        conversation_history: Optional[List[Dict[str, str]]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Fire-and-forget shadow invocation.

        This method returns immediately. The shadow invocation runs
        in a background thread. Results are pushed to Redis for
        offline analysis.
        """
        if not self.should_shadow():
            return

        self._executor.submit(
            self._execute_shadow,
            query,
            primary_model_id,
            primary_response,
            primary_latency_ms,
            primary_tokens,
            conversation_history or [],
            metadata or {},
        )

    def _execute_shadow(
        self,
        query: str,
        primary_model_id: str,
        primary_response: str,
        primary_latency_ms: float,
        primary_tokens: Dict[str, int],
        history: List[Dict[str, str]],
        metadata: Dict[str, Any],
    ) -> None:
        """Execute shadow invocation and collect metrics (runs in background)."""
        start = time.monotonic()
        shadow_response = None
        shadow_tokens = {"input": 0, "output": 0}
        shadow_error = None

        try:
            # Build messages
            messages = []
            for msg in history[-5:]:  # Last 5 turns for context
                messages.append({
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", ""),
                })
            messages.append({"role": "user", "content": query})

            # Invoke shadow model
            body = json.dumps({
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": self.config.shadow_max_tokens,
                "messages": messages,
                "temperature": self.config.shadow_temperature,
            })

            response = self.bedrock.invoke_model(
                modelId=self.config.shadow_model_id,
                body=body,
                contentType="application/json",
                accept="application/json",
            )

            result = json.loads(response["body"].read())
            shadow_response = result.get("content", [{}])[0].get("text", "")
            usage = result.get("usage", {})
            shadow_tokens = {
                "input": usage.get("input_tokens", 0),
                "output": usage.get("output_tokens", 0),
            }

        except Exception as e:
            shadow_error = str(e)
            self._total_shadow_errors += 1
            logger.debug("Shadow invocation failed | error=%s", e)

        shadow_latency = (time.monotonic() - start) * 1000

        # Calculate costs
        primary_cost = self._calculate_cost(primary_model_id, primary_tokens)
        shadow_cost = self._calculate_cost(self.config.shadow_model_id, shadow_tokens)
        self._total_shadow_cost_usd += shadow_cost
        self._total_shadowed += 1

        # Build result
        result = ShadowResult(
            primary_model_id=primary_model_id,
            shadow_model_id=self.config.shadow_model_id,
            primary_latency_ms=primary_latency_ms,
            shadow_latency_ms=shadow_latency,
            primary_tokens=primary_tokens,
            shadow_tokens=shadow_tokens,
            primary_cost_usd=primary_cost,
            shadow_cost_usd=shadow_cost,
            shadow_error=shadow_error,
        )

        # Quality comparison if both responses available
        if shadow_response and self.config.compare_quality:
            result.quality_comparison = self._compare_quality(
                query, primary_response, shadow_response
            )

        # Store results
        self._store_shadow_result(result, metadata)

        logger.debug(
            "Shadow complete | primary=%.0fms | shadow=%.0fms | cost_delta=$%.6f",
            primary_latency_ms,
            shadow_latency,
            shadow_cost - primary_cost,
        )

    def _compare_quality(
        self,
        query: str,
        primary_response: str,
        shadow_response: str,
    ) -> Dict[str, Any]:
        """
        Compare quality between primary and shadow responses.

        Simple heuristic comparison (no LLM-as-judge to avoid cost spiral):
            - Response length ratio
            - Keyword overlap
            - Entity coverage
        """
        primary_words = set(primary_response.lower().split())
        shadow_words = set(shadow_response.lower().split())

        overlap = len(primary_words & shadow_words)
        total = max(len(primary_words | shadow_words), 1)
        similarity = overlap / total

        return {
            "response_length_primary": len(primary_response),
            "response_length_shadow": len(shadow_response),
            "length_ratio": len(shadow_response) / max(len(primary_response), 1),
            "word_overlap_pct": round(similarity * 100, 1),
            "shadow_longer": len(shadow_response) > len(primary_response),
        }

    def _calculate_cost(self, model_id: str, tokens: Dict[str, int]) -> float:
        """Calculate cost in USD for a model invocation."""
        costs = self.COSTS.get(model_id, {"input": 0, "output": 0})
        return (
            tokens.get("input", 0) * costs["input"]
            + tokens.get("output", 0) * costs["output"]
        ) / 1_000_000

    def _store_shadow_result(
        self, result: ShadowResult, metadata: Dict[str, Any]
    ) -> None:
        """Store shadow comparison result in Redis for analysis."""
        if not self.redis:
            return

        try:
            record = {
                "primary_model": result.primary_model_id,
                "shadow_model": result.shadow_model_id,
                "primary_latency_ms": result.primary_latency_ms,
                "shadow_latency_ms": result.shadow_latency_ms,
                "primary_cost_usd": result.primary_cost_usd,
                "shadow_cost_usd": result.shadow_cost_usd,
                "shadow_error": result.shadow_error,
                "quality": result.quality_comparison,
                "timestamp": int(time.time()),
            }

            # Second-resolution timestamps alone would collide (and overwrite)
            # at production shadow volume, so add a unique suffix per record.
            key = f"shadow:results:{int(time.time())}:{uuid.uuid4().hex[:8]}"
            self.redis.setex(key, 86400, json.dumps(record))

            # Increment aggregate counters
            self.redis.incr("shadow:total_count")
            if result.shadow_error:
                self.redis.incr("shadow:error_count")
            self.redis.incrbyfloat("shadow:total_shadow_cost", result.shadow_cost_usd)

        except Exception as e:
            logger.warning("Failed to store shadow result | error=%s", e)

    def get_shadow_stats(self) -> Dict[str, Any]:
        """Return shadow routing statistics."""
        return {
            "total_shadowed": self._total_shadowed,
            "total_errors": self._total_shadow_errors,
            "total_shadow_cost_usd": round(self._total_shadow_cost_usd, 4),
            "error_rate_pct": (
                (self._total_shadow_errors / max(self._total_shadowed, 1)) * 100
            ),
            "config": {
                "shadow_model": self.config.shadow_model_id,
                "traffic_pct": self.config.shadow_traffic_pct,
                "enabled": self.config.enabled,
            },
        }
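
Each shadow record written by _store_shadow_result() expires after 24 hours, so the "Metric Comparison / Lambda Aggregator" step in the 5.1 diagram has to sweep the records before they disappear. The sketch below is a minimal offline aggregation under stated assumptions: it depends on redis-py, reads only the fields written above, and the daily projection at the end simply reuses the Standard 10% volume from Section 6.

"""
shadow_analysis.py (illustrative sketch, not part of the router)

Aggregates the JSON records written by ShadowRouter._store_shadow_result()
into latency and cost deltas. Assumed dependency: redis-py.
"""

import json

import redis


def summarize_shadow_results(r: redis.Redis) -> dict:
    """Sweep shadow:results:* keys and summarize primary-vs-shadow deltas."""
    latency_deltas: list = []
    cost_deltas: list = []
    errors = 0
    total = 0

    # SCAN (not KEYS) so the sweep does not block Redis at production volume.
    for key in r.scan_iter(match="shadow:results:*", count=500):
        raw = r.get(key)
        if raw is None:  # The key may expire between SCAN and GET
            continue
        record = json.loads(raw)
        total += 1
        if record.get("shadow_error"):
            errors += 1
            continue
        latency_deltas.append(record["shadow_latency_ms"] - record["primary_latency_ms"])
        cost_deltas.append(record["shadow_cost_usd"] - record["primary_cost_usd"])

    n = max(len(cost_deltas), 1)
    return {
        "records": total,
        "shadow_error_rate_pct": round(100 * errors / max(total, 1), 2),
        "avg_latency_delta_ms": round(sum(latency_deltas) / n, 1),
        "avg_cost_delta_usd": round(sum(cost_deltas) / n, 6),
        # Projection assumes the Standard 10% config: 100,000 shadowed queries/day.
        "projected_daily_cost_delta_usd": round(sum(cost_deltas) / n * 100_000, 2),
    }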

6. Shadow Routing Cost Impact Analysis

At MangaAssist scale (1M messages/day), shadow routing cost must be carefully managed:

| Shadow Config | Shadow Rate | Shadow Model | Shadow Queries/Day | Additional Daily Cost | Monthly Overhead |
|---------------|-------------|--------------|--------------------|-----------------------|------------------|
| Conservative | 5% | Haiku | 50,000 | $31.25 | $937 |
| Standard | 10% | Haiku | 100,000 | $62.50 | $1,875 |
| Standard | 10% | Sonnet | 100,000 | $1,140.00 | $34,200 |
| Aggressive | 25% | Sonnet | 250,000 | $2,850.00 | $85,500 |
| Full Mirror | 100% | Sonnet | 1,000,000 | $11,400.00 | $342,000 |

Recommendation: shadow at 10% with Haiku as the shadow model when testing Haiku configuration changes; shadow at 5% with Sonnet only during time-boxed quality comparison studies (48 hours maximum).
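
Each row above is simply the Bedrock per-token prices applied to the shadowed volume: daily overhead equals shadowed queries per day times the per-query shadow cost. The sketch below reproduces two of the rows. The per-query token counts are assumptions back-solved from the table (roughly 1,000 input / 300 output tokens for the Haiku rows and 1,000 / 560 for the Sonnet rows), not measured MangaAssist averages.

"""
shadow_cost.py (illustrative sketch) - reproduces rows of the cost table above.

Prices are the Bedrock Claude 3 list prices in USD per 1M tokens.
Token counts per query are back-solved assumptions, not measurements.
"""

PRICES = {
    "haiku": {"input": 0.25, "output": 1.25},
    "sonnet": {"input": 3.00, "output": 15.00},
}


def daily_shadow_cost(
    messages_per_day: int,
    shadow_rate_pct: float,
    model: str,
    input_tokens: int,
    output_tokens: int,
) -> float:
    """Daily shadow overhead in USD for a given sampling configuration."""
    shadow_queries = messages_per_day * shadow_rate_pct / 100
    price = PRICES[model]
    per_query = (input_tokens * price["input"] + output_tokens * price["output"]) / 1_000_000
    return shadow_queries * per_query


if __name__ == "__main__":
    # "Standard 10% / Haiku" row: 100,000 shadowed queries -> $62.50/day
    print(daily_shadow_cost(1_000_000, 10, "haiku", 1_000, 300))
    # "Standard 10% / Sonnet" row: 100,000 shadowed queries -> $1,140/day
    print(daily_shadow_cost(1_000_000, 10, "sonnet", 1_000, 560))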


7. Comparison: A/B vs Shadow Routing

| Dimension | A/B Routing | Shadow Routing |
|-----------|-------------|----------------|
| User impact | Users see the variant response | Zero — shadow response is discarded |
| Metrics quality | Real user satisfaction data | Synthetic comparison only |
| Cost | Same as production (variant replaces primary) | Additional cost (both models invoked) |
| Latency impact | None (user gets variant directly) | None (shadow is async fire-and-forget) |
| Risk | Medium — bad variant affects real users | Zero — shadow failures are invisible |
| Use case | Comparing final model performance | Pre-launch validation of new models |
| Duration | Days to weeks (need statistical significance) | Hours to days (no user impact) |
| Sample size | Limited by traffic split | Can shadow 100% if budget allows |
| Rollback | Change traffic split | Disable shadow config |

8. Integration Points Summary

flowchart TB
    subgraph Core["Core Routing"]
        SR[StaticRouter]
        DR[DynamicModelRouter]
        MS[MetricBasedSelector]
    end

    subgraph Analysis["Analysis"]
        CS[ComplexityScorer]
        ME[MetricEmitter]
    end

    subgraph Testing["Testing & Experimentation"]
        AB[ABRouter]
        SH[ShadowRouter]
        RTM[RoutingTableManager]
    end

    subgraph Edge["API Gateway"]
        AGT[APIGatewayRouteTransformer]
    end

    AGT -->|Fast-path intents| SR
    AGT -->|Complex/ambiguous| DR
    DR --> CS
    DR --> MS
    MS --> ME
    RTM -->|Config updates| SR
    RTM -->|Canary traffic| AB
    AB -->|Experiment traffic| DR
    SH -->|Parallel shadow| ME

    style SR fill:#4CAF50,color:#fff
    style DR fill:#2196F3,color:#fff
    style MS fill:#FF9800,color:#fff
    style CS fill:#9C27B0,color:#fff
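
Condensed into code, the experimentation components attach to the primary path at two points: an active A/B assignment can override the model choice before invocation, and the shadow call fires after the primary response is ready. The sketch below shows that wiring under stated assumptions: invoke_model() and the non-experiment fallback are placeholders for the routing components defined earlier in this document, while the ABRouter and ShadowRouter calls match the classes implemented above.

"""
orchestrator_wiring.py (illustrative sketch) - how ABRouter and ShadowRouter
attach to the primary routing path. invoke_model() and the fallback route are
placeholders; only the ABRouter / ShadowRouter calls match the classes above.
"""

def handle_message(session_id: str, intent: str, query: str,
                   ab_router, shadow_router, invoke_model) -> str:
    # 1. Experiment traffic: a live A/B assignment overrides normal routing.
    assignment = ab_router.assign(session_id, intent=intent)
    if assignment:
        model_id, max_tokens = assignment.model_id, assignment.max_tokens
    else:
        # 2. Otherwise fall back to the complexity/metric-based route
        #    (placeholder default; the real selection happens upstream).
        model_id, max_tokens = "anthropic.claude-3-haiku-20240307-v1:0", 1024

    # 3. The primary invocation serves the user; latency and token counts
    #    come back alongside the response text (placeholder helper).
    response_text, latency_ms, tokens = invoke_model(model_id, query, max_tokens)

    # 4. Record A/B outcome metrics for the offline significance analysis.
    if assignment:
        ab_router.record_outcome(
            assignment.experiment_id, assignment.variant_name, session_id,
            {"latency_ms": latency_ms, "output_tokens": tokens["output"]},
        )

    # 5. Fire-and-forget shadow invocation: never blocks the user response.
    shadow_router.invoke_shadow_async(
        query=query,
        primary_model_id=model_id,
        primary_response=response_text,
        primary_latency_ms=latency_ms,
        primary_tokens=tokens,
    )
    return response_text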

9. References

| Resource | Link |
|----------|------|
| Amazon Bedrock Inference APIs | https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html |
| Claude 3 Model Pricing | https://aws.amazon.com/bedrock/pricing/ |
| ElastiCache Redis Sorted Sets | https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/Sorted-Sets.html |
| DynamoDB Streams | https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Streams.html |
| Kinesis Data Streams | https://docs.aws.amazon.com/streams/latest/dev/introduction.html |
| A/B Testing Statistical Significance | https://docs.aws.amazon.com/wellarchitected/latest/machine-learning-lens/experimentation.html |
| CloudWatch Embedded Metric Format | https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch_Embedded_Metric_Format.html |