
Skill 2.1.4 — Model Coordination Architecture

MangaAssist context: Japanese manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M input/output tokens, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: a useful answer in under 3 seconds at 1M messages/day scale.


Table of Contents

  1. Mind Map: Model Coordination Systems
  2. Core Concepts
  3. Architecture Diagram: Multi-Model Orchestration
  4. Specialized FM Roles in MangaAssist
  5. Production Code: ModelRouter
  6. Production Code: EnsembleCoordinator
  7. Production Code: ModelSelectionFramework
  8. Comparison Table: Coordination Strategies
  9. Parallel vs Sequential Coordination
  10. Cost-Performance Tradeoff Analysis
  11. Interview Q&A
  12. Key Takeaways

Mind Map: Model Coordination Systems

Model Coordination Systems (Skill 2.1.4)
│
├── Specialized FMs
│   ├── Claude 3 Sonnet → Complex reasoning, recommendations, creative responses
│   ├── Claude 3 Haiku  → Classification, intent detection, simple Q&A
│   ├── Titan Embeddings → Semantic search, vector generation, similarity
│   ├── Role assignment   → Each model handles what it does best
│   └── Capability matrix → Map task types to model strengths
│
├── Model Ensembles
│   ├── Voting (majority/plurality)
│   │   ├── Multiple models answer same question
│   │   ├── Select most common answer
│   │   └── Good for classification tasks
│   ├── Weighted Averaging
│   │   ├── Assign confidence weights per model
│   │   ├── Combine scores proportionally
│   │   └── Good for scoring/ranking tasks
│   ├── Stacking (meta-learner)
│   │   ├── Base models generate features
│   │   ├── Meta-model combines predictions
│   │   └── Good for complex decision tasks
│   └── Cascade (escalation)
│       ├── Start with cheapest model
│       ├── Escalate if confidence is low
│       └── Good for cost optimization
│
├── Custom Aggregation
│   ├── Confidence-based selection
│   │   ├── Each model reports confidence score
│   │   ├── Select highest-confidence response
│   │   └── Fallback to ensemble if all low
│   ├── Response merging
│   │   ├── Combine complementary outputs
│   │   ├── Merge facts from multiple models
│   │   └── Deduplicate and reconcile conflicts
│   ├── Quality scoring
│   │   ├── Evaluate response coherence
│   │   ├── Check factual consistency
│   │   └── Score relevance to query
│   └── Cost-aware aggregation
│       ├── Factor in token costs per model
│       ├── Budget-constrained selection
│       └── Dynamic cost ceiling
│
├── Selection Frameworks
│   ├── Complexity-based routing
│   │   ├── Classify query difficulty (simple/medium/complex)
│   │   ├── Route simple → Haiku, complex → Sonnet
│   │   └── Use lightweight classifier for routing
│   ├── Cost-performance tradeoff
│   │   ├── Define SLAs per query type
│   │   ├── Optimize model choice for budget
│   │   └── Monitor cost-per-quality-unit
│   ├── Latency-aware selection
│   │   ├── Track p50/p95/p99 per model
│   │   ├── Route time-sensitive queries to faster models
│   │   └── Parallel dispatch for latency hedging
│   └── Domain-specific routing
│       ├── Manga recommendations → Sonnet
│       ├── Order status → Haiku
│       └── Product search → Titan embeddings + Haiku
│
└── Parallel vs Sequential Coordination
    ├── Parallel fan-out
    │   ├── Send query to all models simultaneously
    │   ├── Aggregate responses
    │   └── Tradeoff: higher cost, lower latency
    ├── Sequential pipeline
    │   ├── Output of model A feeds model B
    │   ├── Each stage refines the answer
    │   └── Tradeoff: lower cost, higher latency
    ├── Conditional branching
    │   ├── Route based on intermediate results
    │   ├── Only invoke expensive models when needed
    │   └── Tradeoff: balanced cost/latency
    └── Hybrid approaches
        ├── Parallel classification + sequential reasoning
        ├── Cached fast path + on-demand deep path
        └── Best for production systems like MangaAssist

Core Concepts

What Is Model Coordination?

Model coordination is the practice of orchestrating multiple foundation models to work together on a task rather than relying on a single model for everything. In MangaAssist, we use three distinct models — each with different strengths, speeds, and costs — and coordinate them through routing, ensembles, and aggregation to deliver the best possible answer within our 3-second SLA.

Why Not Just Use One Model?

| Challenge                  | Single Model (Sonnet Only)    | Coordinated System               |
|----------------------------|-------------------------------|----------------------------------|
| Cost at 1M msgs/day        | ~$45,000/day (all Sonnet)     | ~$8,000/day (80% Haiku routing)  |
| Latency p95                | 2.8s (Sonnet for everything)  | 0.9s (Haiku for simple queries)  |
| Quality on simple queries  | Overkill — wastes capacity    | Right-sized model per task       |
| Quality on complex queries | Good but no verification      | Cross-model validation           |
| Single point of failure    | Yes                           | Graceful degradation             |
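
The cost gap in the first row is just per-token arithmetic. A ballpark sketch, assuming illustrative token counts (roughly 10K input tokens for a Sonnet call carrying history plus RAG context, far less for a Haiku-only turn); the table's ~$8,000/day additionally assumes caching and shorter simple-path outputs:

def cost_per_query(input_cost_per_1m: float, output_cost_per_1m: float,
                   input_tokens: int, output_tokens: int) -> float:
    """USD cost of one Bedrock invocation at the listed per-1M-token prices."""
    return ((input_tokens / 1_000_000) * input_cost_per_1m
            + (output_tokens / 1_000_000) * output_cost_per_1m)

MESSAGES_PER_DAY = 1_000_000

# All Sonnet: long prompts (history + RAG context) on every message -> ~$45,000/day.
all_sonnet = MESSAGES_PER_DAY * cost_per_query(3.00, 15.00, 10_000, 1_000)

# Routed: 80% of traffic on Haiku with short prompts, 20% stays on Sonnet -> ~$9,600/day.
routed = MESSAGES_PER_DAY * (
    0.8 * cost_per_query(0.25, 1.25, 1_500, 300)
    + 0.2 * cost_per_query(3.00, 15.00, 10_000, 1_000)
)

print(f"all-Sonnet: ${all_sonnet:,.0f}/day, routed mix: ${routed:,.0f}/day")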

The Three Pillars of Model Coordination

┌─────────────────────────────────────────────────────────────────┐
│                  MODEL COORDINATION PILLARS                     │
├─────────────────┬─────────────────────┬─────────────────────────┤
│   ROUTING       │   ENSEMBLE          │   AGGREGATION           │
│                 │                     │                         │
│ Which model     │ How to combine      │ How to merge            │
│ handles which   │ multiple model      │ outputs into a          │
│ query?          │ predictions?        │ single response?        │
│                 │                     │                         │
│ - Complexity    │ - Voting            │ - Confidence selection  │
│ - Cost budget   │ - Weighted avg      │ - Response merging      │
│ - Latency SLA   │ - Stacking          │ - Quality scoring       │
│ - Domain type   │ - Cascade           │ - Deduplication         │
└─────────────────┴─────────────────────┴─────────────────────────┘

Architecture Diagram: Multi-Model Orchestration

                    MangaAssist Multi-Model Orchestration
                    ════════════════════════════════════

  User Message
       │
       ▼
┌──────────────────┐
│  API Gateway     │     WebSocket connection
│  WebSocket       │     Persistent session
└──────┬───────────┘
       │
       ▼
┌──────────────────┐     ┌─────────────────────┐
│  ECS Fargate     │────▶│  ElastiCache Redis  │  Session cache
│  Orchestrator    │◀────│  (TTL: 30 min)      │  + routing cache
└──────┬───────────┘     └─────────────────────┘
       │
       ▼
┌──────────────────────────────────────────────────────────────┐
│                    MODEL ROUTER (Step 1)                      │
│                                                              │
│  Input: raw user query + session context                     │
│                                                              │
│  ┌─────────────────────────────────┐                         │
│  │  Haiku Classifier (fast path)   │  Cost: ~$0.0001/query   │
│  │  - Intent detection             │  Latency: ~150ms        │
│  │  - Complexity scoring (1-10)    │                         │
│  │  - Domain classification        │                         │
│  └──────────┬──────────────────────┘                         │
│             │                                                │
│   ┌─────────┼──────────┬──────────────────┐                  │
│   ▼         ▼          ▼                  ▼                  │
│ Simple    Medium     Complex           Embedding             │
│ (1-3)     (4-6)      (7-10)           Required               │
└───┬─────────┬──────────┬──────────────────┬──────────────────┘
    │         │          │                  │
    ▼         ▼          ▼                  ▼
┌────────┐ ┌────────┐ ┌────────┐    ┌──────────────┐
│ Haiku  │ │ Haiku  │ │ Sonnet │    │ Titan        │
│ Direct │ │ + RAG  │ │ + RAG  │    │ Embeddings   │
│        │ │        │ │        │    │ v2           │
│ ~200ms │ │ ~600ms │ │ ~2.2s  │    │ ~100ms       │
│ $0.001 │ │ $0.003 │ │ $0.02  │    │ $0.0001      │
└───┬────┘ └───┬────┘ └───┬────┘    └──────┬───────┘
    │          │          │                │
    ▼          ▼          ▼                ▼
┌──────────────────────────────────────────────────────────────┐
│                 AGGREGATION LAYER (Step 2)                    │
│                                                              │
│  ┌─────────────────────────────────────────────────────┐     │
│  │  Response Quality Scorer                             │     │
│  │  - Confidence extraction from model output           │     │
│  │  - Relevance check against original query            │     │
│  │  - Coherence validation                              │     │
│  │  - Safety/guardrail check                            │     │
│  └─────────────────────────────────────────────────────┘     │
│                                                              │
│  ┌─────────────────────────────────────────────────────┐     │
│  │  Conflict Resolution                                 │     │
│  │  - If ensemble: merge or select best                 │     │
│  │  - If cascade: accept or escalate                    │     │
│  │  - If routing: pass through with metadata            │     │
│  └─────────────────────────────────────────────────────┘     │
└──────────────────────────┬───────────────────────────────────┘
                           │
                           ▼
                    ┌──────────────┐
                    │  Final       │
                    │  Response    │──▶  User
                    │  + Metadata  │
                    └──────────────┘

Data Flow Detail for MangaAssist

Example Query: "I loved Attack on Titan — recommend similar dark fantasy manga
                with great worldbuilding, and tell me if any are on sale"

Step 1 — Haiku Classifier (150ms, $0.0001):
   Intent: [recommendation, price_inquiry]
   Complexity: 8/10 (multi-intent, subjective criteria)
   Domain: [manga_recommendation, product_catalog]
   Route Decision: SONNET + RAG (complex reasoning needed)

Step 2 — Titan Embeddings (100ms, $0.0001, parallel with Step 1):
   Generate embedding for "dark fantasy manga great worldbuilding"
   Query OpenSearch Serverless for top-10 similar manga
   Return: [Berserk, Claymore, Vinland Saga, Made in Abyss, ...]

Step 3 — OpenSearch Retrieval (200ms):
   Semantic search: top-10 manga vectors
   Filter: in_stock=true
   Enrich: DynamoDB product details (price, sale status)

Step 4 — Sonnet Reasoning (1.8s, $0.02):
   Input: user query + retrieved manga + product details
   Output: personalized recommendation with sale information
   Confidence: 0.92

Step 5 — Aggregation (50ms):
   Quality score: 0.89 (relevant, coherent, addresses both intents)
   Safety check: PASS
   Total latency: 2.3s (within 3s SLA)
   Total cost: ~$0.021

Response: "Based on your love for Attack on Titan, here are my top picks..."

Specialized FM Roles in MangaAssist

Role Assignment Matrix

| Model               | Role                          | Task Types                                                  | Avg Latency | Cost/Query    | Quality Band          |
|---------------------|-------------------------------|-------------------------------------------------------------|-------------|---------------|-----------------------|
| Haiku               | Classifier + Simple responder | Intent detection, FAQ, order status, simple lookups         | 150-300ms   | $0.0003-0.002 | Good for simple       |
| Sonnet              | Reasoner + Generator          | Recommendations, comparisons, creative writing, complex Q&A | 1.5-2.5s    | $0.01-0.04    | Excellent for complex |
| Titan Embeddings v2 | Vectorizer                    | Semantic search, similarity matching, clustering            | 80-120ms    | $0.0001       | N/A (embedding)       |

Capability Mapping

┌─────────────────────────────────────────────────────────────┐
│              MODEL CAPABILITY MATRIX                        │
├──────────────────┬────────┬────────┬──────────────────────┤
│ Task             │ Haiku  │ Sonnet │ Titan Embed          │
├──────────────────┼────────┼────────┼──────────────────────┤
│ Intent classify  │ ★★★★★  │ ★★★★★  │ N/A                  │
│ Simple FAQ       │ ★★★★☆  │ ★★★★★  │ N/A                  │
│ Order status     │ ★★★★★  │ ★★★★★  │ N/A                  │
│ Manga recommend  │ ★★☆☆☆  │ ★★★★★  │ ★★★★☆ (retrieval)    │
│ Price comparison │ ★★★☆☆  │ ★★★★★  │ N/A                  │
│ Creative writing │ ★★☆☆☆  │ ★★★★★  │ N/A                  │
│ Semantic search  │ N/A    │ N/A    │ ★★★★★                │
│ Multi-turn logic │ ★★☆☆☆  │ ★★★★★  │ N/A                  │
│ Translation      │ ★★★★☆  │ ★★★★★  │ N/A                  │
│ Summarization    │ ★★★☆☆  │ ★★★★★  │ N/A                  │
├──────────────────┼────────┼────────┼──────────────────────┤
│ Cost (input/1M)  │ $0.25  │ $3.00  │ $0.02                │
│ Cost (output/1M) │ $1.25  │ $15.00 │ N/A                  │
│ Avg latency      │ 200ms  │ 2.0s   │ 100ms                │
└──────────────────┴────────┴────────┴──────────────────────┘
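
One way to carry this matrix in code is a plain lookup from task type to preferred model, which the router can consult before its intent rules. A sketch only; the task keys are illustrative and ModelId is the enum defined in the ModelRouter code below:

PREFERRED_MODEL_BY_TASK = {
    "intent_classify":  ModelId.HAIKU,        # Haiku matches Sonnet here at a fraction of the cost
    "simple_faq":       ModelId.HAIKU,
    "order_status":     ModelId.HAIKU,
    "manga_recommend":  ModelId.SONNET,       # candidates retrieved via Titan + OpenSearch first
    "price_comparison": ModelId.SONNET,
    "creative_writing": ModelId.SONNET,
    "semantic_search":  ModelId.TITAN_EMBED,
    "multi_turn_logic": ModelId.SONNET,
    "translation":      ModelId.HAIKU,        # escalate to Sonnet for nuanced passages
    "summarization":    ModelId.SONNET,
}

def preferred_model(task: str) -> "ModelId":
    # Unknown tasks default to the cheap model; the complexity override can still upgrade them.
    return PREFERRED_MODEL_BY_TASK.get(task, ModelId.HAIKU)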

Production Code: ModelRouter

"""
ModelRouter: Routes incoming queries to the optimal model based on
complexity, intent, cost budget, and latency requirements.

MangaAssist production implementation for ECS Fargate orchestrator.
"""

import asyncio
import json
import time
import hashlib
from enum import Enum
from dataclasses import dataclass, field
from typing import Optional

import boto3
import redis.asyncio as redis


# ---------------------------------------------------------------------------
# Configuration & Enums
# ---------------------------------------------------------------------------

class ModelId(Enum):
    """Available Bedrock model identifiers."""
    HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"
    SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
    TITAN_EMBED = "amazon.titan-embed-text-v2:0"


class QueryComplexity(Enum):
    """Complexity tiers that determine model routing."""
    SIMPLE = "simple"       # Complexity score 1-3
    MEDIUM = "medium"       # Complexity score 4-6
    COMPLEX = "complex"     # Complexity score 7-10


class Intent(Enum):
    """Recognized user intents for MangaAssist."""
    GREETING = "greeting"
    FAQ = "faq"
    ORDER_STATUS = "order_status"
    PRODUCT_SEARCH = "product_search"
    RECOMMENDATION = "recommendation"
    PRICE_INQUIRY = "price_inquiry"
    COMPARISON = "comparison"
    CREATIVE = "creative"
    COMPLAINT = "complaint"
    UNKNOWN = "unknown"


# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------

@dataclass
class RoutingDecision:
    """Encapsulates the routing decision for a query."""
    primary_model: ModelId
    needs_rag: bool
    needs_embeddings: bool
    complexity: QueryComplexity
    intents: list[Intent]
    estimated_latency_ms: int
    estimated_cost_usd: float
    confidence: float
    reasoning: str
    fallback_model: Optional[ModelId] = None


@dataclass
class ClassificationResult:
    """Output of the Haiku classifier step."""
    intents: list[Intent]
    complexity_score: int          # 1-10
    domain_tags: list[str]
    requires_context: bool
    requires_product_data: bool
    language: str                  # "en", "ja", etc.
    confidence: float
    raw_response: dict = field(default_factory=dict)


@dataclass
class ModelCostConfig:
    """Per-model cost and latency parameters."""
    input_cost_per_1m: float       # USD per 1M input tokens
    output_cost_per_1m: float      # USD per 1M output tokens
    avg_latency_ms: int
    max_tokens_default: int


# ---------------------------------------------------------------------------
# Cost configurations
# ---------------------------------------------------------------------------

MODEL_COSTS: dict[ModelId, ModelCostConfig] = {
    ModelId.HAIKU: ModelCostConfig(
        input_cost_per_1m=0.25,
        output_cost_per_1m=1.25,
        avg_latency_ms=200,
        max_tokens_default=512,
    ),
    ModelId.SONNET: ModelCostConfig(
        input_cost_per_1m=3.00,
        output_cost_per_1m=15.00,
        avg_latency_ms=2000,
        max_tokens_default=1024,
    ),
    ModelId.TITAN_EMBED: ModelCostConfig(
        input_cost_per_1m=0.02,
        output_cost_per_1m=0.0,
        avg_latency_ms=100,
        max_tokens_default=0,
    ),
}


# ---------------------------------------------------------------------------
# Routing rules — intent-to-model mapping
# ---------------------------------------------------------------------------

INTENT_ROUTING_RULES: dict[Intent, dict] = {
    Intent.GREETING: {
        "model": ModelId.HAIKU,
        "needs_rag": False,
        "needs_embeddings": False,
    },
    Intent.FAQ: {
        "model": ModelId.HAIKU,
        "needs_rag": True,
        "needs_embeddings": True,
    },
    Intent.ORDER_STATUS: {
        "model": ModelId.HAIKU,
        "needs_rag": False,
        "needs_embeddings": False,
    },
    Intent.PRODUCT_SEARCH: {
        "model": ModelId.HAIKU,
        "needs_rag": True,
        "needs_embeddings": True,
    },
    Intent.RECOMMENDATION: {
        "model": ModelId.SONNET,
        "needs_rag": True,
        "needs_embeddings": True,
    },
    Intent.PRICE_INQUIRY: {
        "model": ModelId.HAIKU,
        "needs_rag": True,
        "needs_embeddings": False,
    },
    Intent.COMPARISON: {
        "model": ModelId.SONNET,
        "needs_rag": True,
        "needs_embeddings": True,
    },
    Intent.CREATIVE: {
        "model": ModelId.SONNET,
        "needs_rag": False,
        "needs_embeddings": False,
    },
    Intent.COMPLAINT: {
        "model": ModelId.SONNET,
        "needs_rag": True,
        "needs_embeddings": False,
    },
    Intent.UNKNOWN: {
        "model": ModelId.HAIKU,
        "needs_rag": False,
        "needs_embeddings": False,
    },
}


# ---------------------------------------------------------------------------
# ModelRouter — Main orchestration class
# ---------------------------------------------------------------------------

class ModelRouter:
    """
    Routes user queries to the optimal Bedrock model based on:
    1. Intent classification (via Haiku fast path)
    2. Complexity scoring
    3. Cost budget constraints
    4. Latency SLA requirements

    Usage:
        router = ModelRouter(bedrock_client, redis_client)
        decision = await router.route(query="Recommend dark fantasy manga",
                                       session_id="sess-abc123")
    """

    def __init__(
        self,
        bedrock_client: boto3.client,
        redis_client: redis.Redis,
        latency_sla_ms: int = 3000,
        cost_ceiling_usd: float = 0.05,
    ):
        self.bedrock = bedrock_client
        self.redis = redis_client
        self.latency_sla_ms = latency_sla_ms
        self.cost_ceiling_usd = cost_ceiling_usd
        self._classification_cache_ttl = 300   # 5 minutes

    # ----- Public API -----

    async def route(
        self,
        query: str,
        session_id: str,
        conversation_history: list[dict] | None = None,
        force_model: ModelId | None = None,
    ) -> RoutingDecision:
        """
        Determine the optimal model for a given query.

        Parameters
        ----------
        query : str
            The raw user message.
        session_id : str
            Current session identifier for caching.
        conversation_history : list[dict], optional
            Previous turns for context-aware routing.
        force_model : ModelId, optional
            Override routing and force a specific model.

        Returns
        -------
        RoutingDecision
            Complete routing decision with model, cost, latency estimates.
        """
        start = time.monotonic()

        # Check cache first
        cached = await self._check_routing_cache(query, session_id)
        if cached:
            return cached

        # If forced, skip classification
        if force_model:
            return self._build_forced_decision(force_model, query)

        # Step 1: Classify the query with Haiku (fast path)
        classification = await self._classify_query(query, conversation_history)

        # Step 2: Apply routing rules
        decision = self._apply_routing_rules(classification)

        # Step 3: Apply complexity override — if complexity > 6, upgrade to Sonnet
        decision = self._apply_complexity_override(decision, classification)

        # Step 4: Apply cost and latency constraints
        decision = self._apply_constraints(decision)

        # Step 5: Cache the decision
        await self._cache_routing_decision(query, session_id, decision)

        elapsed_ms = int((time.monotonic() - start) * 1000)
        decision.reasoning += f" | Router overhead: {elapsed_ms}ms"

        return decision

    # ----- Classification -----

    async def _classify_query(
        self,
        query: str,
        conversation_history: list[dict] | None = None,
    ) -> ClassificationResult:
        """Use Haiku to classify intent, complexity, and domain."""

        system_prompt = """You are a query classifier for MangaAssist, a Japanese manga
store chatbot. Analyze the user query and return a JSON object with:

{
  "intents": ["recommendation", "price_inquiry"],
  "complexity_score": 8,
  "domain_tags": ["manga_recommendation", "product_catalog"],
  "requires_context": true,
  "requires_product_data": true,
  "language": "en",
  "confidence": 0.95
}

Valid intents: greeting, faq, order_status, product_search, recommendation,
price_inquiry, comparison, creative, complaint, unknown.
Complexity score: 1 (trivial) to 10 (very complex, multi-step reasoning).
Respond ONLY with valid JSON."""

        messages = []
        if conversation_history:
            messages.extend(conversation_history[-3:])  # Last 3 turns for context
        messages.append({"role": "user", "content": query})

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 256,
            "system": system_prompt,
            "messages": messages,
        })

        response = self.bedrock.invoke_model(
            modelId=ModelId.HAIKU.value,
            contentType="application/json",
            accept="application/json",
            body=body,
        )

        result = json.loads(response["body"].read())
        content = result["content"][0]["text"]

        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Fallback if Haiku returns malformed JSON
            parsed = {
                "intents": ["unknown"],
                "complexity_score": 5,
                "domain_tags": [],
                "requires_context": True,
                "requires_product_data": False,
                "language": "en",
                "confidence": 0.3,
            }

        # Map intent strings defensively: anything outside the enum becomes UNKNOWN
        # instead of raising ValueError on Intent(...).
        intents = []
        for raw_intent in parsed.get("intents", ["unknown"]):
            try:
                intents.append(Intent(raw_intent))
            except ValueError:
                intents.append(Intent.UNKNOWN)

        return ClassificationResult(
            intents=intents,
            complexity_score=parsed.get("complexity_score", 5),
            domain_tags=parsed.get("domain_tags", []),
            requires_context=parsed.get("requires_context", True),
            requires_product_data=parsed.get("requires_product_data", False),
            language=parsed.get("language", "en"),
            confidence=parsed.get("confidence", 0.5),
            raw_response=parsed,
        )

    # ----- Routing Rules Engine -----

    def _apply_routing_rules(
        self, classification: ClassificationResult
    ) -> RoutingDecision:
        """Map classification results to a routing decision using intent rules."""

        primary_intent = classification.intents[0] if classification.intents else Intent.UNKNOWN
        rule = INTENT_ROUTING_RULES.get(primary_intent, INTENT_ROUTING_RULES[Intent.UNKNOWN])

        # If multiple intents, check if any require Sonnet
        needs_sonnet = any(
            INTENT_ROUTING_RULES.get(i, {}).get("model") == ModelId.SONNET
            for i in classification.intents
        )
        needs_rag = any(
            INTENT_ROUTING_RULES.get(i, {}).get("needs_rag", False)
            for i in classification.intents
        )
        needs_embeddings = any(
            INTENT_ROUTING_RULES.get(i, {}).get("needs_embeddings", False)
            for i in classification.intents
        )

        selected_model = ModelId.SONNET if needs_sonnet else rule["model"]
        cost_config = MODEL_COSTS[selected_model]

        # Estimate cost based on average token counts
        avg_input_tokens = 800 if needs_rag else 200
        avg_output_tokens = 400
        estimated_cost = (
            (avg_input_tokens / 1_000_000) * cost_config.input_cost_per_1m
            + (avg_output_tokens / 1_000_000) * cost_config.output_cost_per_1m
        )

        # Estimate latency including RAG overhead
        rag_overhead = 300 if needs_rag else 0
        embed_overhead = 100 if needs_embeddings else 0
        estimated_latency = cost_config.avg_latency_ms + rag_overhead + embed_overhead

        complexity = self._score_to_complexity(classification.complexity_score)

        return RoutingDecision(
            primary_model=selected_model,
            needs_rag=needs_rag,
            needs_embeddings=needs_embeddings,
            complexity=complexity,
            intents=classification.intents,
            estimated_latency_ms=estimated_latency,
            estimated_cost_usd=estimated_cost,
            confidence=classification.confidence,
            reasoning=f"Intent={primary_intent.value}, complexity={classification.complexity_score}",
            fallback_model=ModelId.HAIKU if selected_model == ModelId.SONNET else None,
        )

    def _apply_complexity_override(
        self, decision: RoutingDecision, classification: ClassificationResult
    ) -> RoutingDecision:
        """Override model choice if complexity demands it."""

        if classification.complexity_score >= 7 and decision.primary_model == ModelId.HAIKU:
            decision.primary_model = ModelId.SONNET
            decision.fallback_model = ModelId.HAIKU
            cost_config = MODEL_COSTS[ModelId.SONNET]
            decision.estimated_latency_ms = cost_config.avg_latency_ms + (
                300 if decision.needs_rag else 0
            )
            decision.reasoning += " | Upgraded to Sonnet (high complexity)"

        return decision

    def _apply_constraints(self, decision: RoutingDecision) -> RoutingDecision:
        """Enforce cost ceiling and latency SLA."""

        # Cost constraint: downgrade if over budget
        if decision.estimated_cost_usd > self.cost_ceiling_usd:
            if decision.primary_model == ModelId.SONNET:
                decision.primary_model = ModelId.HAIKU
                decision.reasoning += " | Downgraded to Haiku (cost ceiling)"

        # Latency constraint: if estimated > SLA, prefer faster model
        if decision.estimated_latency_ms > self.latency_sla_ms:
            if decision.primary_model == ModelId.SONNET:
                decision.primary_model = ModelId.HAIKU
                decision.needs_rag = False  # Drop RAG to meet SLA
                decision.reasoning += " | Downgraded to Haiku (latency SLA)"

        return decision

    # ----- Caching -----

    async def _check_routing_cache(
        self, query: str, session_id: str
    ) -> Optional[RoutingDecision]:
        """Check Redis for a cached routing decision."""
        cache_key = f"route:{self._query_hash(query)}:{session_id}"
        cached = await self.redis.get(cache_key)
        if cached:
            data = json.loads(cached)
            return RoutingDecision(
                primary_model=ModelId(data["primary_model"]),
                needs_rag=data["needs_rag"],
                needs_embeddings=data["needs_embeddings"],
                complexity=QueryComplexity(data["complexity"]),
                intents=[Intent(i) for i in data["intents"]],
                estimated_latency_ms=data["estimated_latency_ms"],
                estimated_cost_usd=data["estimated_cost_usd"],
                confidence=data["confidence"],
                reasoning=data["reasoning"] + " | CACHED",
                fallback_model=ModelId(data["fallback_model"]) if data.get("fallback_model") else None,
            )
        return None

    async def _cache_routing_decision(
        self, query: str, session_id: str, decision: RoutingDecision
    ) -> None:
        """Cache routing decision in Redis."""
        cache_key = f"route:{self._query_hash(query)}:{session_id}"
        data = {
            "primary_model": decision.primary_model.value,
            "needs_rag": decision.needs_rag,
            "needs_embeddings": decision.needs_embeddings,
            "complexity": decision.complexity.value,
            "intents": [i.value for i in decision.intents],
            "estimated_latency_ms": decision.estimated_latency_ms,
            "estimated_cost_usd": decision.estimated_cost_usd,
            "confidence": decision.confidence,
            "reasoning": decision.reasoning,
            "fallback_model": decision.fallback_model.value if decision.fallback_model else None,
        }
        await self.redis.setex(cache_key, self._classification_cache_ttl, json.dumps(data))

    # ----- Helpers -----

    @staticmethod
    def _query_hash(query: str) -> str:
        """Stable hash for cache keys."""
        return hashlib.sha256(query.strip().lower().encode()).hexdigest()[:16]

    @staticmethod
    def _score_to_complexity(score: int) -> QueryComplexity:
        if score <= 3:
            return QueryComplexity.SIMPLE
        elif score <= 6:
            return QueryComplexity.MEDIUM
        return QueryComplexity.COMPLEX

    def _build_forced_decision(self, model: ModelId, query: str) -> RoutingDecision:
        """Build a decision when model is force-selected."""
        cost_config = MODEL_COSTS[model]
        return RoutingDecision(
            primary_model=model,
            needs_rag=False,
            needs_embeddings=False,
            complexity=QueryComplexity.MEDIUM,
            intents=[Intent.UNKNOWN],
            estimated_latency_ms=cost_config.avg_latency_ms,
            estimated_cost_usd=0.01,
            confidence=1.0,
            reasoning=f"Force-routed to {model.value}",
        )
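
A minimal wiring sketch for the router; the region, Redis endpoint, and query are illustrative (in production they come from the Fargate task configuration):

import asyncio
import boto3
import redis.asyncio as redis

async def main() -> None:
    bedrock = boto3.client("bedrock-runtime", region_name="ap-northeast-1")
    cache = redis.from_url("redis://localhost:6379/0")

    router = ModelRouter(bedrock, cache, latency_sla_ms=3000, cost_ceiling_usd=0.05)
    decision = await router.route(
        query="Recommend dark fantasy manga similar to Berserk",
        session_id="sess-demo-001",
    )
    print(decision.primary_model, decision.estimated_cost_usd, decision.reasoning)

if __name__ == "__main__":
    asyncio.run(main())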

Production Code: EnsembleCoordinator

"""
EnsembleCoordinator: Manages parallel or sequential invocation of multiple
models and combines their outputs using configurable aggregation strategies.
"""

import asyncio
import json
import time
from dataclasses import dataclass
from typing import Callable, Optional

import boto3


@dataclass
class ModelResponse:
    """Response from a single model invocation."""
    model_id: str
    content: str
    confidence: float
    latency_ms: int
    token_count_input: int
    token_count_output: int
    cost_usd: float
    metadata: dict


@dataclass
class EnsembleResult:
    """Aggregated result from multiple model invocations."""
    final_content: str
    strategy_used: str
    model_responses: list[ModelResponse]
    total_latency_ms: int
    total_cost_usd: float
    agreement_score: float          # 0-1: how much models agreed
    selected_model: Optional[str]   # Which model was chosen (if selection-based)


class EnsembleCoordinator:
    """
    Coordinates multiple Bedrock models in parallel or sequential patterns.

    Strategies:
      - "confidence_select": Pick the response with highest confidence
      - "voting":            Ask models to classify, pick majority
      - "cascade":           Start cheap, escalate if confidence < threshold
      - "parallel_merge":    Run all in parallel, merge complementary outputs
      - "quality_score":     Score each response, pick the best

    Usage:
        coordinator = EnsembleCoordinator(bedrock_client)
        result = await coordinator.run_ensemble(
            query="Compare Berserk and Vinland Saga",
            models=[ModelId.HAIKU, ModelId.SONNET],
            strategy="confidence_select",
        )
    """

    COST_MAP = {
        "anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25),
        "anthropic.claude-3-sonnet-20240229-v1:0": (3.00, 15.00),
    }

    def __init__(self, bedrock_client: boto3.client, timeout_ms: int = 5000):
        self.bedrock = bedrock_client
        self.timeout_ms = timeout_ms

    # ----- Public API -----

    async def run_ensemble(
        self,
        query: str,
        models: list[str],
        strategy: str = "confidence_select",
        system_prompt: str = "",
        cascade_threshold: float = 0.8,
        custom_scorer: Optional[Callable] = None,
    ) -> EnsembleResult:
        """
        Execute an ensemble of models and aggregate results.

        Parameters
        ----------
        query : str
            User query to send to all models.
        models : list[str]
            Bedrock model IDs to invoke.
        strategy : str
            Aggregation strategy name.
        system_prompt : str
            System prompt shared across all models.
        cascade_threshold : float
            Confidence threshold for cascade escalation.
        custom_scorer : callable, optional
            Custom function to score responses: (ModelResponse) -> float.
        """
        start = time.monotonic()

        if strategy == "cascade":
            responses = await self._run_cascade(
                query, models, system_prompt, cascade_threshold
            )
        else:
            # Parallel invocation for all other strategies
            responses = await self._run_parallel(query, models, system_prompt)

        # Apply aggregation strategy
        result = self._aggregate(responses, strategy, custom_scorer)

        result.total_latency_ms = int((time.monotonic() - start) * 1000)
        return result

    # ----- Invocation Patterns -----

    async def _run_parallel(
        self, query: str, models: list[str], system_prompt: str
    ) -> list[ModelResponse]:
        """Invoke all models in parallel."""
        tasks = [
            self._invoke_model(model_id, query, system_prompt) for model_id in models
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        responses = []
        for r in results:
            if isinstance(r, ModelResponse):
                responses.append(r)
            # Skip failures — logged elsewhere
        return responses

    async def _run_cascade(
        self,
        query: str,
        models: list[str],
        system_prompt: str,
        threshold: float,
    ) -> list[ModelResponse]:
        """
        Invoke models sequentially from cheapest to most expensive.
        Stop when confidence exceeds threshold.
        """
        responses = []
        for model_id in models:
            response = await self._invoke_model(model_id, query, system_prompt)
            responses.append(response)

            if response.confidence >= threshold:
                break  # Confident enough — no need to escalate

        return responses

    async def _invoke_model(
        self, model_id: str, query: str, system_prompt: str
    ) -> ModelResponse:
        """Invoke a single Bedrock model and return structured response."""
        start = time.monotonic()

        enhanced_system = system_prompt + """

After your response, on a new line, output your confidence as:
CONFIDENCE: 0.XX (a float between 0.0 and 1.0)"""

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "system": enhanced_system,
            "messages": [{"role": "user", "content": query}],
        })

        response = self.bedrock.invoke_model(
            modelId=model_id,
            contentType="application/json",
            accept="application/json",
            body=body,
        )

        result = json.loads(response["body"].read())
        content = result["content"][0]["text"]
        latency_ms = int((time.monotonic() - start) * 1000)

        # Extract confidence from response
        confidence = self._extract_confidence(content)

        # Clean content (remove CONFIDENCE line)
        clean_content = self._clean_content(content)

        # Calculate cost
        input_tokens = result.get("usage", {}).get("input_tokens", 0)
        output_tokens = result.get("usage", {}).get("output_tokens", 0)
        cost = self._calculate_cost(model_id, input_tokens, output_tokens)

        return ModelResponse(
            model_id=model_id,
            content=clean_content,
            confidence=confidence,
            latency_ms=latency_ms,
            token_count_input=input_tokens,
            token_count_output=output_tokens,
            cost_usd=cost,
            metadata={"raw_usage": result.get("usage", {})},
        )

    # ----- Aggregation Strategies -----

    def _aggregate(
        self,
        responses: list[ModelResponse],
        strategy: str,
        custom_scorer: Optional[Callable],
    ) -> EnsembleResult:
        """Route to the appropriate aggregation method."""

        if not responses:
            return EnsembleResult(
                final_content="I'm sorry, I couldn't generate a response.",
                strategy_used=strategy,
                model_responses=[],
                total_latency_ms=0,
                total_cost_usd=0.0,
                agreement_score=0.0,
                selected_model=None,
            )

        strategy_map = {
            "confidence_select": self._strategy_confidence_select,
            "voting": self._strategy_voting,
            "quality_score": self._strategy_quality_score,
            "parallel_merge": self._strategy_parallel_merge,
            "cascade": self._strategy_confidence_select,  # Cascade uses last response
        }

        aggregator = strategy_map.get(strategy, self._strategy_confidence_select)

        if strategy == "quality_score" and custom_scorer:
            return aggregator(responses, custom_scorer)
        return aggregator(responses)

    def _strategy_confidence_select(
        self, responses: list[ModelResponse], **kwargs
    ) -> EnsembleResult:
        """Select the response with the highest self-reported confidence."""
        best = max(responses, key=lambda r: r.confidence)

        agreement = self._calculate_agreement(responses)

        return EnsembleResult(
            final_content=best.content,
            strategy_used="confidence_select",
            model_responses=responses,
            total_latency_ms=max(r.latency_ms for r in responses),
            total_cost_usd=sum(r.cost_usd for r in responses),
            agreement_score=agreement,
            selected_model=best.model_id,
        )

    def _strategy_voting(
        self, responses: list[ModelResponse], **kwargs
    ) -> EnsembleResult:
        """
        For classification tasks — pick the most common answer.
        Assumes responses are short classification labels.
        """
        from collections import Counter

        votes = Counter(r.content.strip().lower() for r in responses)
        winner_text, winner_count = votes.most_common(1)[0]

        # Find the original response matching the winner
        best = next(
            (r for r in responses if r.content.strip().lower() == winner_text),
            responses[0],
        )

        agreement = winner_count / len(responses)

        return EnsembleResult(
            final_content=best.content,
            strategy_used="voting",
            model_responses=responses,
            total_latency_ms=max(r.latency_ms for r in responses),
            total_cost_usd=sum(r.cost_usd for r in responses),
            agreement_score=agreement,
            selected_model=best.model_id,
        )

    def _strategy_quality_score(
        self,
        responses: list[ModelResponse],
        scorer: Optional[Callable] = None,
    ) -> EnsembleResult:
        """Score each response with a custom scorer and pick the best."""
        if scorer is None:
            # Fallback to confidence-based
            return self._strategy_confidence_select(responses)

        scored = [(r, scorer(r)) for r in responses]
        best_response, best_score = max(scored, key=lambda x: x[1])

        return EnsembleResult(
            final_content=best_response.content,
            strategy_used="quality_score",
            model_responses=responses,
            total_latency_ms=max(r.latency_ms for r in responses),
            total_cost_usd=sum(r.cost_usd for r in responses),
            agreement_score=best_score,
            selected_model=best_response.model_id,
        )

    def _strategy_parallel_merge(
        self, responses: list[ModelResponse], **kwargs
    ) -> EnsembleResult:
        """
        Merge complementary outputs from multiple models.
        Useful when different models contribute different information.
        """
        # Simple merge: concatenate model-labelled sections (no semantic dedup here)
        sections = []
        for r in responses:
            model_name = r.model_id.split(".")[-1].split("-")[0].title()
            sections.append(f"[{model_name} perspective]\n{r.content}")

        merged = "\n\n---\n\n".join(sections)

        return EnsembleResult(
            final_content=merged,
            strategy_used="parallel_merge",
            model_responses=responses,
            total_latency_ms=max(r.latency_ms for r in responses),
            total_cost_usd=sum(r.cost_usd for r in responses),
            agreement_score=self._calculate_agreement(responses),
            selected_model=None,
        )

    # ----- Helpers -----

    @staticmethod
    def _extract_confidence(content: str) -> float:
        """Extract CONFIDENCE: X.XX from model output."""
        for line in reversed(content.split("\n")):
            if line.strip().startswith("CONFIDENCE:"):
                try:
                    return float(line.split(":")[1].strip())
                except (ValueError, IndexError):
                    pass
        return 0.5  # Default confidence

    @staticmethod
    def _clean_content(content: str) -> str:
        """Remove the CONFIDENCE line from content."""
        lines = content.split("\n")
        cleaned = [l for l in lines if not l.strip().startswith("CONFIDENCE:")]
        return "\n".join(cleaned).strip()

    def _calculate_cost(
        self, model_id: str, input_tokens: int, output_tokens: int
    ) -> float:
        """Calculate invocation cost in USD."""
        costs = self.COST_MAP.get(model_id, (0.25, 1.25))
        return (input_tokens / 1_000_000) * costs[0] + (output_tokens / 1_000_000) * costs[1]

    @staticmethod
    def _calculate_agreement(responses: list[ModelResponse]) -> float:
        """
        Estimate agreement between responses using a simple heuristic.
        Production systems would use semantic similarity.
        """
        if len(responses) <= 1:
            return 1.0

        # Simple word overlap heuristic (production: use embeddings)
        word_sets = [set(r.content.lower().split()) for r in responses]
        total_overlap = 0
        comparisons = 0

        for i in range(len(word_sets)):
            for j in range(i + 1, len(word_sets)):
                intersection = word_sets[i] & word_sets[j]
                union = word_sets[i] | word_sets[j]
                if union:
                    total_overlap += len(intersection) / len(union)
                comparisons += 1

        return total_overlap / comparisons if comparisons > 0 else 0.0
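
A usage sketch for the quality_score strategy with a custom scorer. The scorer here is deliberately simple (length-weighted confidence) and purely illustrative; production scorers would check relevance and factual consistency. ModelId comes from the ModelRouter module above:

def length_weighted_confidence(r: ModelResponse) -> float:
    """Toy scorer: favour confident answers that are reasonably complete."""
    length_factor = min(len(r.content) / 500, 1.0)
    return 0.7 * r.confidence + 0.3 * length_factor

async def compare_manga(bedrock) -> EnsembleResult:
    coordinator = EnsembleCoordinator(bedrock)
    return await coordinator.run_ensemble(
        query="Compare Berserk and Vinland Saga for a reader who liked Vagabond",
        models=[ModelId.HAIKU.value, ModelId.SONNET.value],
        strategy="quality_score",
        system_prompt="You are MangaAssist, a chatbot for a Japanese manga store.",
        custom_scorer=length_weighted_confidence,
    )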

Production Code: ModelSelectionFramework

"""
ModelSelectionFramework: High-level framework that combines routing,
ensemble coordination, and cost management into a single entry point.
"""

import asyncio
import json
import time
from dataclasses import dataclass
from typing import Optional

import boto3
import redis.asyncio as redis
from aws_lambda_powertools import Logger, Metrics, Tracer

# ModelRouter, EnsembleCoordinator, and their shared types/constants are defined
# in the two modules above (module names here are assumptions for illustration):
from model_router import (
    MODEL_COSTS,
    ModelId,
    ModelRouter,
    QueryComplexity,
    RoutingDecision,
)
from ensemble_coordinator import EnsembleCoordinator

logger = Logger(service="mangaassist-model-selection")
metrics = Metrics(namespace="MangaAssist/ModelCoordination")
tracer = Tracer(service="mangaassist-model-selection")


@dataclass
class CoordinationResult:
    """Final output of the model coordination pipeline."""
    response_text: str
    model_used: str
    total_latency_ms: int
    total_cost_usd: float
    routing_decision: str
    confidence: float
    sla_met: bool
    metadata: dict


class ModelSelectionFramework:
    """
    End-to-end model coordination for MangaAssist.

    Combines:
    - ModelRouter for intelligent query classification and routing
    - EnsembleCoordinator for multi-model aggregation when needed
    - Cost tracking and SLA enforcement
    - Automatic fallback and degradation

    Usage:
        framework = ModelSelectionFramework(
            bedrock_client=bedrock,
            redis_client=redis,
        )
        result = await framework.process(
            query="What's the best-selling manga this month?",
            session_id="sess-abc123",
        )
        print(result.response_text)
    """

    def __init__(
        self,
        bedrock_client: boto3.client,
        redis_client: redis.Redis,
        latency_sla_ms: int = 3000,
        daily_budget_usd: float = 10000.0,
    ):
        self.bedrock = bedrock_client
        self.redis = redis_client
        self.latency_sla_ms = latency_sla_ms
        self.daily_budget_usd = daily_budget_usd

        self.router = ModelRouter(bedrock_client, redis_client, latency_sla_ms)
        self.ensemble = EnsembleCoordinator(bedrock_client, timeout_ms=latency_sla_ms)

    @tracer.capture_method
    async def process(
        self,
        query: str,
        session_id: str,
        conversation_history: list[dict] | None = None,
        use_ensemble: bool = False,
    ) -> CoordinationResult:
        """
        Process a user query through the full coordination pipeline.

        Parameters
        ----------
        query : str
            The user's message.
        session_id : str
            Session identifier.
        conversation_history : list[dict], optional
            Previous conversation turns.
        use_ensemble : bool
            Force ensemble mode for A/B testing.
        """
        pipeline_start = time.monotonic()

        try:
            # Step 1: Route the query
            routing = await self.router.route(query, session_id, conversation_history)

            logger.info("Routing decision", extra={
                "model": routing.primary_model.value,
                "complexity": routing.complexity.value,
                "intents": [i.value for i in routing.intents],
                "estimated_cost": routing.estimated_cost_usd,
            })

            # Step 2: Check daily budget
            remaining_budget = await self._check_budget(session_id)
            if remaining_budget <= 0:
                # Emergency mode: force Haiku for everything
                routing.primary_model = ModelId.HAIKU
                routing.needs_rag = False
                logger.warning("Daily budget exhausted — forcing Haiku")

            # Step 3: Execute model invocation
            if use_ensemble and routing.complexity == QueryComplexity.COMPLEX:
                result = await self._run_ensemble_path(query, routing)
            else:
                result = await self._run_single_path(query, routing, conversation_history)

            # Step 4: Track cost
            await self._track_cost(session_id, result.total_cost_usd)

            # Step 5: Enforce SLA
            total_ms = int((time.monotonic() - pipeline_start) * 1000)
            result.total_latency_ms = total_ms
            result.sla_met = total_ms <= self.latency_sla_ms

            # Emit metrics
            metrics.add_metric(name="Latency", unit="Milliseconds", value=total_ms)
            metrics.add_metric(name="Cost", unit="None", value=result.total_cost_usd)
            metrics.add_metric(name="SLAMet", unit="Count", value=1 if result.sla_met else 0)

            if not result.sla_met:
                logger.warning("SLA breach", extra={
                    "latency_ms": total_ms,
                    "sla_ms": self.latency_sla_ms,
                    "model": result.model_used,
                })

            return result

        except Exception as e:
            logger.exception("Model coordination pipeline failed")
            # Emergency fallback: return a safe default
            return CoordinationResult(
                response_text="I'm having trouble processing your request. Please try again.",
                model_used="fallback",
                total_latency_ms=int((time.monotonic() - pipeline_start) * 1000),
                total_cost_usd=0.0,
                routing_decision="error_fallback",
                confidence=0.0,
                sla_met=False,
                metadata={"error": str(e)},
            )

    async def _run_single_path(
        self,
        query: str,
        routing: RoutingDecision,
        conversation_history: list[dict] | None,
    ) -> CoordinationResult:
        """Execute single-model path with optional RAG."""

        system_prompt = self._build_system_prompt(routing)

        messages = []
        if conversation_history:
            messages.extend(conversation_history[-5:])

        # If RAG is needed, enrich the query
        enriched_query = query
        if routing.needs_rag:
            context = await self._retrieve_context(query, routing.needs_embeddings)
            enriched_query = f"""Context from our manga catalog:
{context}

User question: {query}

Provide a helpful answer based on the context above."""

        messages.append({"role": "user", "content": enriched_query})

        start = time.monotonic()
        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": MODEL_COSTS[routing.primary_model].max_tokens_default,
            "system": system_prompt,
            "messages": messages,
        })

        try:
            response = self.bedrock.invoke_model(
                modelId=routing.primary_model.value,
                contentType="application/json",
                accept="application/json",
                body=body,
            )

            result = json.loads(response["body"].read())
            content = result["content"][0]["text"]
            usage = result.get("usage", {})
            latency_ms = int((time.monotonic() - start) * 1000)

            cost = self._calc_cost(
                routing.primary_model,
                usage.get("input_tokens", 0),
                usage.get("output_tokens", 0),
            )

            return CoordinationResult(
                response_text=content,
                model_used=routing.primary_model.value,
                total_latency_ms=latency_ms,
                total_cost_usd=cost,
                routing_decision=routing.reasoning,
                confidence=routing.confidence,
                sla_met=True,
                metadata={"usage": usage, "routing": routing.reasoning},
            )

        except Exception as e:
            # Fallback to backup model
            if routing.fallback_model:
                logger.warning(f"Primary model failed, falling back to {routing.fallback_model}")
                fallback_routing = RoutingDecision(
                    primary_model=routing.fallback_model,
                    needs_rag=False,
                    needs_embeddings=False,
                    complexity=routing.complexity,
                    intents=routing.intents,
                    estimated_latency_ms=200,
                    estimated_cost_usd=0.001,
                    confidence=0.5,
                    reasoning="Fallback after primary failure",
                )
                return await self._run_single_path(query, fallback_routing, conversation_history)
            raise

    async def _run_ensemble_path(
        self, query: str, routing: RoutingDecision
    ) -> CoordinationResult:
        """Execute ensemble path for complex queries."""

        models = [ModelId.HAIKU.value, ModelId.SONNET.value]
        system_prompt = self._build_system_prompt(routing)

        ensemble_result = await self.ensemble.run_ensemble(
            query=query,
            models=models,
            strategy="confidence_select",
            system_prompt=system_prompt,
        )

        return CoordinationResult(
            response_text=ensemble_result.final_content,
            model_used=ensemble_result.selected_model or "ensemble",
            total_latency_ms=ensemble_result.total_latency_ms,
            total_cost_usd=ensemble_result.total_cost_usd,
            routing_decision=f"ensemble:{ensemble_result.strategy_used}",
            confidence=ensemble_result.agreement_score,
            sla_met=True,
            metadata={
                "ensemble_strategy": ensemble_result.strategy_used,
                "agreement": ensemble_result.agreement_score,
                "model_count": len(ensemble_result.model_responses),
            },
        )

    async def _retrieve_context(self, query: str, use_embeddings: bool) -> str:
        """Retrieve relevant context from OpenSearch via embeddings."""
        if use_embeddings:
            # Generate embedding with Titan
            embed_body = json.dumps({
                "inputText": query,
                "dimensions": 1024,
                "normalize": True,
            })
            embed_response = self.bedrock.invoke_model(
                modelId=ModelId.TITAN_EMBED.value,
                contentType="application/json",
                accept="application/json",
                body=embed_body,
            )
            embedding = json.loads(embed_response["body"].read())["embedding"]

            # Query OpenSearch (placeholder — actual implementation uses opensearch-py)
            # results = await self.opensearch.search(embedding, top_k=5)
            return "[Retrieved manga catalog context would appear here]"

        return "[Keyword-based retrieval context]"

    def _build_system_prompt(self, routing: RoutingDecision) -> str:
        """Build system prompt based on routing context."""
        base = """You are MangaAssist, a helpful chatbot for a Japanese manga store.
You help customers find manga, check prices, track orders, and get recommendations.
Be friendly, knowledgeable about manga, and concise in your responses."""

        if routing.complexity == QueryComplexity.COMPLEX:
            base += "\nThis is a complex query — provide detailed, well-reasoned answers."
        elif routing.complexity == QueryComplexity.SIMPLE:
            base += "\nThis is a simple query — be brief and direct."

        return base

    async def _check_budget(self, session_id: str) -> float:
        """Check remaining daily budget from Redis."""
        today_key = f"budget:{time.strftime('%Y-%m-%d')}"
        spent = await self.redis.get(today_key)
        spent_usd = float(spent) if spent else 0.0
        return self.daily_budget_usd - spent_usd

    async def _track_cost(self, session_id: str, cost_usd: float) -> None:
        """Increment daily cost counter in Redis."""
        today_key = f"budget:{time.strftime('%Y-%m-%d')}"
        await self.redis.incrbyfloat(today_key, cost_usd)
        await self.redis.expire(today_key, 86400)  # TTL: 24 hours

    @staticmethod
    def _calc_cost(model: ModelId, input_tokens: int, output_tokens: int) -> float:
        config = MODEL_COSTS[model]
        return (
            (input_tokens / 1_000_000) * config.input_cost_per_1m
            + (output_tokens / 1_000_000) * config.output_cost_per_1m
        )
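
How the Fargate orchestrator might call the framework once per WebSocket message. A sketch; the client setup, Redis endpoint, and handler signature are illustrative, not part of the framework above:

import boto3
import redis.asyncio as redis

bedrock = boto3.client("bedrock-runtime", region_name="ap-northeast-1")
cache = redis.from_url("redis://mangaassist-cache.internal:6379/0")
framework = ModelSelectionFramework(bedrock, cache, latency_sla_ms=3000)

async def handle_message(session_id: str, text: str) -> str:
    """Invoked once per incoming API Gateway WebSocket message."""
    result = await framework.process(query=text, session_id=session_id)
    # model_used, total_cost_usd, and sla_met are available for per-message telemetry.
    return result.response_text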

Comparison Table: Coordination Strategies

| Strategy            | How It Works                             | Best For                                         | Latency                      | Cost                                | Quality                          | MangaAssist Use Case              |
|---------------------|------------------------------------------|--------------------------------------------------|------------------------------|-------------------------------------|----------------------------------|-----------------------------------|
| Single Model        | One model handles everything             | Low complexity, MVP                              | Low (one call)               | Low (one call)                      | Depends on model                 | Early prototype                   |
| Routing             | Classifier sends query to best-fit model | Mixed workloads with clear categories            | Low (classifier + one model) | Medium (classifier overhead)        | High (right model for each task) | Primary strategy — 80% of queries |
| Ensemble (voting)   | Multiple models vote on answer           | Classification, yes/no decisions                 | High (parallel calls)        | High (N models)                     | Very High (consensus)            | Content moderation decisions      |
| Ensemble (cascade)  | Cheap model first, escalate if unsure    | Cost optimization with quality floor             | Variable (1-N calls)         | Low-Medium (often just cheap model) | High (escalation when needed)    | Product search fallback           |
| Ensemble (stacking) | Base models feed a meta-model            | Complex decisions needing multiple perspectives  | Very High (base + meta)      | Very High (N+1 models)              | Highest (learned combination)    | Not used (too slow for 3s SLA)    |
| Parallel merge      | Run models in parallel, merge outputs    | Complementary expertise                          | High (parallel, then merge)  | High (N models)                     | High (comprehensive)             | A/B testing new models            |

Decision Matrix: When to Use Each Strategy

                        Query Complexity
                    Low         Medium        High
                ┌───────────┬───────────┬───────────┐
    Strict      │  Haiku    │  Haiku    │  Haiku    │
    (<$0.005)   │  Single   │  + Cache  │  + Cache  │  ← Cost budget
                ├───────────┼───────────┼───────────┤
    Normal      │  Haiku    │  Routing  │  Sonnet   │
    (<$0.05)    │  Single   │  (Haiku/  │  Single   │  ← Cost budget
                │           │  Sonnet)  │  + RAG    │
                ├───────────┼───────────┼───────────┤
    Generous    │  Haiku    │  Routing  │  Ensemble │
    (<$0.50)    │  Single   │  + Cache  │  (Cascade │  ← Cost budget
                │           │           │  or Vote) │
                └───────────┴───────────┴───────────┘
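
The matrix can be encoded as a plain lookup table. The sketch below is illustrative only: the BudgetTier and Complexity enums and the strategy labels are hypothetical names, not the identifiers used in the router code earlier in this skill.

from enum import Enum

class Complexity(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

class BudgetTier(Enum):
    STRICT = "strict"      # < $0.005 per query
    NORMAL = "normal"      # < $0.05 per query
    GENEROUS = "generous"  # < $0.50 per query

# Hypothetical encoding of the decision matrix above:
# (budget tier, query complexity) -> coordination strategy
DECISION_MATRIX = {
    (BudgetTier.STRICT, Complexity.LOW): "haiku_single",
    (BudgetTier.STRICT, Complexity.MEDIUM): "haiku_plus_cache",
    (BudgetTier.STRICT, Complexity.HIGH): "haiku_plus_cache",
    (BudgetTier.NORMAL, Complexity.LOW): "haiku_single",
    (BudgetTier.NORMAL, Complexity.MEDIUM): "route_haiku_or_sonnet",
    (BudgetTier.NORMAL, Complexity.HIGH): "sonnet_plus_rag",
    (BudgetTier.GENEROUS, Complexity.LOW): "haiku_single",
    (BudgetTier.GENEROUS, Complexity.MEDIUM): "routing_plus_cache",
    (BudgetTier.GENEROUS, Complexity.HIGH): "ensemble_cascade_or_vote",
}

def pick_strategy(budget: BudgetTier, complexity: Complexity) -> str:
    """Look up the coordination strategy for a query."""
    return DECISION_MATRIX[(budget, complexity)]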

Parallel vs Sequential Coordination

Parallel Fan-Out

         ┌─── Haiku ────────┐
         │   (200ms)         │
Query ───┼─── Sonnet ───────┼──── Aggregator ──── Response
         │   (2000ms)        │     (50ms)
         └─── Titan Embed ──┘
              (100ms)

Total latency: max(200, 2000, 100) + 50 = 2050ms
Total cost:    sum(Haiku, Sonnet, Titan) = $0.022

Pros: Lowest possible latency (parallel execution), full model coverage. Cons: Highest cost (every model runs), wastes compute on simple queries. When to use: Critical queries where quality > cost, A/B testing.
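A minimal asyncio sketch of the fan-out pattern. The three call_* functions are stand-ins that simulate model latency with sleeps; in production they would wrap bedrock-runtime invoke_model calls for each model.

import asyncio
import time

async def call_haiku(query: str) -> str:
    await asyncio.sleep(0.2)   # ~200ms stand-in for a Haiku call
    return f"haiku:{query}"

async def call_sonnet(query: str) -> str:
    await asyncio.sleep(2.0)   # ~2000ms stand-in for a Sonnet call
    return f"sonnet:{query}"

async def call_titan_embed(query: str) -> str:
    await asyncio.sleep(0.1)   # ~100ms stand-in for Titan Embeddings
    return f"embedding:{query}"

async def fan_out(query: str) -> dict:
    """Dispatch to all models at once; total latency ~= slowest model."""
    start = time.perf_counter()
    haiku, sonnet, embedding = await asyncio.gather(
        call_haiku(query), call_sonnet(query), call_titan_embed(query)
    )
    elapsed_ms = (time.perf_counter() - start) * 1000
    # The aggregation step (~50ms in the diagram) would merge these outputs.
    return {"haiku": haiku, "sonnet": sonnet, "embedding": embedding,
            "latency_ms": elapsed_ms}

# asyncio.run(fan_out("Recommend a seinen series like Vinland Saga"))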

Sequential Pipeline

Query ──▶ Haiku Classifier ──▶ Route Decision ──▶ Selected Model ──▶ Response
          (150ms)               (10ms)             (200-2000ms)

Total latency: 150 + 10 + model_latency = 360-2160ms
Total cost:    Haiku classifier + selected model = $0.001-0.021

Pros: Cost-efficient (only invoke what's needed), predictable pipeline. Cons: Higher latency (serial steps), single point of failure at classifier. When to use: MangaAssist primary path — best cost/quality tradeoff.
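Reusing the call_haiku / call_sonnet stubs from the fan-out sketch above, the sequential pipeline reduces to classify first, then invoke only the selected model. The classify heuristic here is a toy placeholder for the Haiku classifier.

async def classify(query: str) -> str:
    await asyncio.sleep(0.15)  # ~150ms stand-in for Haiku classification
    return "complex" if len(query.split()) > 12 else "simple"  # toy heuristic

async def sequential_pipeline(query: str) -> str:
    """Classify first, then invoke only the model the route decision picked."""
    complexity = await classify(query)
    if complexity == "simple":
        return await call_haiku(query)   # ~150 + 200ms, ~$0.001
    return await call_sonnet(query)      # ~150 + 2000ms, ~$0.021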

Cascade (Sequential Escalation)

Query ──▶ Haiku ──▶ Confidence check ──▶ [If low] ──▶ Sonnet ──▶ Response
          (200ms)   (10ms)                              (2000ms)

If confidence >= 0.8: Total = 210ms,  Cost = $0.001
If confidence <  0.8: Total = 2210ms, Cost = $0.022

Pros: Optimal cost (cheap model handles most), quality guarantee. Cons: Worst-case latency is high (both models), harder to tune threshold. When to use: Price-sensitive workloads where 70%+ queries are simple.
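A sketch of the cascade, again reusing the call_haiku / call_sonnet stubs above. The confidence scorer is a placeholder: in practice the score could be self-reported by the model, derived from log-prob heuristics, or produced by a lightweight verifier.

def estimate_confidence(answer: str) -> float:
    # Placeholder scorer; replace with a real confidence signal.
    return 0.9 if answer else 0.0

async def cascade(query: str, threshold: float = 0.8) -> str:
    """Try Haiku first; escalate to Sonnet only when confidence is low."""
    answer = await call_haiku(query)
    if estimate_confidence(answer) >= threshold:
        return answer                    # ~210ms, ~$0.001
    return await call_sonnet(query)      # ~2210ms, ~$0.022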

Hybrid: MangaAssist Production Pattern

                    ┌────────────────────────────────────────────┐
                    │         HYBRID COORDINATION                │
                    │                                            │
  Query ──▶ Haiku Classifier (150ms, parallel with Titan Embed)  │
                    │                                            │
              ┌─────┴─────┐                                      │
              │  SIMPLE   │──▶ Haiku Direct (200ms)              │
              │  (60%)    │    Total: 350ms, $0.001              │
              ├───────────┤                                      │
              │  MEDIUM   │──▶ Haiku + RAG (600ms)               │
              │  (25%)    │    Total: 750ms, $0.003              │
              ├───────────┤                                      │
              │  COMPLEX  │──▶ Sonnet + RAG (2200ms)             │
              │  (15%)    │    Total: 2350ms, $0.022             │
              └───────────┘                                      │
                    │                                            │
                    │  Weighted avg cost: 0.6($0.001) +          │
                    │    0.25($0.003) + 0.15($0.022) = $0.005    │
                    │  Weighted avg latency: 0.6(350) +          │
                    │    0.25(750) + 0.15(2350) = 750ms          │
                    └────────────────────────────────────────────┘

Cost-Performance Tradeoff Analysis

Daily Cost Projection at 1M Messages/Day

| Strategy | Haiku % | Sonnet % | Daily Cost | Avg Latency | Quality Score |
|---|---|---|---|---|---|
| All Sonnet | 0% | 100% | $22,000 | 2.0s | 9.5/10 |
| All Haiku | 100% | 0% | $1,500 | 0.2s | 7.0/10 |
| Routing (60/25/15) | 85% | 15% | $4,800 | 0.75s | 8.8/10 |
| Cascade (threshold=0.8) | 78% | 22% | $6,200 | 0.9s | 9.0/10 |
| Ensemble (all parallel) | 0% | 100% (+Haiku) | $23,500 | 2.0s | 9.7/10 |

Cost Optimization Formula

Daily Cost = Messages_per_day * (
    pct_simple  * cost_haiku_simple  +
    pct_medium  * cost_haiku_rag     +
    pct_complex * cost_sonnet_rag    +
    1.0         * cost_classifier
)

For MangaAssist:
= 1,000,000 * (
    0.60 * $0.001  +      # Simple queries via Haiku
    0.25 * $0.003  +      # Medium queries via Haiku+RAG
    0.15 * $0.022  +      # Complex queries via Sonnet+RAG
    1.00 * $0.0001        # Haiku classifier for every query
)
= 1,000,000 * $0.00475
≈ $4,750/day (~$5,000/day with headroom)
≈ $145,000/month (~$150,000/month with headroom)
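
The same arithmetic as a small sketch, with the traffic mix and per-query costs treated as configurable assumptions rather than fixed values:

def daily_cost(messages_per_day: int,
               mix: dict[str, float],
               per_query_cost: dict[str, float],
               classifier_cost: float = 0.0001) -> float:
    """Expected daily spend given a traffic mix and per-query costs."""
    expected_per_message = sum(mix[tier] * per_query_cost[tier] for tier in mix)
    return messages_per_day * (expected_per_message + classifier_cost)

# MangaAssist assumptions from the worked example above:
cost = daily_cost(
    1_000_000,
    mix={"simple": 0.60, "medium": 0.25, "complex": 0.15},
    per_query_cost={"simple": 0.001, "medium": 0.003, "complex": 0.022},
)
# cost == 4750.0  ->  ~$4,750/day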

Interview Q&A

Q1: Why use Haiku as a classifier rather than a rule-based router?

Haiku as a classifier provides semantic understanding that rule-based systems lack. A regex or keyword-based router would misclassify "I hate that my favorite manga ended — recommend something to fill the void" because "hate" might trigger a complaint route. Haiku understands this is a recommendation query with emotional context. At $0.0001 per classification, the cost is negligible compared to the routing quality improvement. The key tradeoff: Haiku adds ~150ms to every request, but saves money by preventing expensive Sonnet calls for simple queries.

Q2: How do you handle model disagreement in an ensemble?

We use a confidence-weighted selection strategy. Each model reports its confidence score, and we select the response with the highest confidence. If confidence scores are close (within 0.1), we fall back to the more expensive model (Sonnet) under the assumption that its reasoning is more reliable. For classification tasks, we use majority voting. The agreement score is tracked as a metric — if it consistently drops below 0.7, it triggers an alert to review the models' alignment.
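A minimal sketch of that selection rule. The candidate schema (model, text, confidence) is illustrative, not the production data model.

def select_response(candidates: list[dict]) -> dict:
    """Pick the highest-confidence response; prefer Sonnet on near-ties.

    Each candidate is assumed to look like:
    {"model": "sonnet" | "haiku", "text": "...", "confidence": 0.0-1.0}
    """
    ranked = sorted(candidates, key=lambda c: c["confidence"], reverse=True)
    best = ranked[0]
    runner_up = ranked[1] if len(ranked) > 1 else None
    if runner_up and abs(best["confidence"] - runner_up["confidence"]) < 0.1:
        # Near-tie: fall back to the more capable (more expensive) model.
        for candidate in (best, runner_up):
            if candidate["model"] == "sonnet":
                return candidate
    return best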

Q3: What happens when the daily budget is exhausted?

The framework enters degraded mode: all queries route to Haiku regardless of complexity, RAG is disabled (reducing OpenSearch costs too), and response quality monitoring is enhanced to track the impact. This ensures the service stays up with reduced quality rather than failing entirely. An alert fires to the on-call team, and the budget can be manually increased via a Redis override key.

Q4: How does this system handle the 3-second SLA?

The SLA is enforced at multiple levels: (1) The router estimates latency before choosing a model and will downgrade to Haiku if the estimated latency exceeds the SLA. (2) The Bedrock client has a timeout set to sla_ms - router_overhead_ms. (3) If a Sonnet call is slow, the cascade pattern allows returning the Haiku result as a fallback. (4) Redis caching of routing decisions eliminates 150ms on repeat queries. The p95 latency target is 2.5s to leave headroom.
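One way to implement points (2) and (3) is a latency hedge: give Sonnet only the SLA budget that remains after router overhead, with a Haiku answer ready as the fallback. This sketch reuses the hypothetical call_haiku / call_sonnet stubs from the fan-out example.

async def answer_within_sla(query: str, sla_ms: int = 3000,
                            router_overhead_ms: int = 200) -> str:
    """Cap Sonnet at the remaining SLA budget; fall back to Haiku on timeout."""
    haiku_task = asyncio.create_task(call_haiku(query))  # hedge, started early
    budget_s = (sla_ms - router_overhead_ms) / 1000
    try:
        answer = await asyncio.wait_for(call_sonnet(query), timeout=budget_s)
        haiku_task.cancel()  # Sonnet made it in time; discard the hedge
        return answer
    except asyncio.TimeoutError:
        # Sonnet missed its budget: return the cheaper answer instead of
        # breaching the SLA.
        return await haiku_task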

Q5: Why not just use stacking (meta-learner) for the highest quality?

Stacking requires running all base models plus a meta-model, resulting in latency of 200ms + 2000ms + 1500ms (meta) = 3700ms — exceeding the 3s SLA. Additionally, stacking requires labeled training data for the meta-learner, which we don't have at launch. Our routing + cascade approach achieves 90%+ of stacking quality at 40% of the cost and 50% of the latency. We may revisit stacking for offline batch tasks like catalog enrichment.


Key Takeaways

  1. Route, don't ensemble by default — For MangaAssist's SLA and cost constraints, intelligent routing (60% simple/25% medium/15% complex) beats running all models for every query.

  2. Haiku as classifier is the cornerstone — At $0.0001 per classification and 150ms latency, it's the cheapest way to make intelligent routing decisions with semantic understanding.

  3. Cascade > parallel for cost optimization — Start cheap (Haiku), escalate to Sonnet only when confidence is low. This naturally optimizes cost while maintaining quality.

  4. Budget enforcement prevents runaway costs — Tracking daily spend in Redis and degrading gracefully keeps the service alive even when costs spike.

  5. SLA enforcement is multi-layered — Don't rely on a single timeout. Enforce at routing, invocation, and aggregation layers for reliable latency guarantees.

  6. Measure agreement for quality signals — When models disagree, it indicates ambiguous queries or model drift. Track this as a key metric.

  7. Cache routing decisions, not responses — Caching the routing decision (which model to use) is safe and saves classifier costs. Caching actual responses requires more careful invalidation.