
Ensemble & Aggregation Strategies for Multi-Model FM Coordination

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.


Skill Mapping

| Certification | Domain | Task | Skill | This File |
|---|---|---|---|---|
| AWS AIP-C01 | Domain 2 — Implementation & Integration | Task 2.1 — Select and implement FM integration approaches | Skill 2.1.4 — Design model coordination systems | Ensemble methods for FM outputs, voting strategies, confidence weighting, response quality scoring |

Skill scope: Deep-dive into ensemble aggregation — how MangaAssist combines outputs from multiple foundation models (Sonnet, Haiku, Titan Embeddings) to deliver higher-quality, more reliable chatbot responses at production scale. Covers voting, confidence weighting, quality scoring, and cost-aware aggregation with full production Python implementations.


Mind Map: Ensemble Aggregation Strategies

mindmap
  root((Ensemble Aggregation<br/>Strategies))
    Voting Strategies
      Majority Voting
        N models produce answers
        Most-common wins
        Best for classification
      Weighted Voting
        Assign weight per model
        Weight by past accuracy
        Weight by domain expertise
      Plurality Voting
        Top vote even without majority
        Tie-breaking heuristics
        Fallback to highest-confidence
    Confidence Weighting
      Per-Model Confidence
        Token log-probabilities
        Self-reported confidence
        Calibrated probability
      Dynamic Weight Adjustment
        Exponential moving average
        Performance-decay weighting
        Intent-specific weights
      Threshold Gating
        Minimum confidence floor
        Auto-escalation on low confidence
        Abstention when all models uncertain
    Response Quality Scoring
      Coherence Metrics
        Perplexity scoring
        Sentence flow analysis
        Logical consistency check
      Factual Grounding
        RAG citation verification
        Entity cross-referencing
        Hallucination probability
      Relevance Scoring
        Query-response similarity
        Intent alignment check
        Topic drift detection
    Cost-Aware Aggregation
      Token Budget Allocation
        Per-query cost ceiling
        Model cost ranking
        Budget-proportional routing
      Cascade Economics
        Haiku-first strategy
        Sonnet escalation triggers
        Cost-per-quality-unit tracking
      Batch Optimization
        Parallel fan-out budgeting
        Partial result acceptance
        Early termination on consensus

1. Core Ensemble Architecture for MangaAssist

Multi-Model Fan-Out and Aggregation Flow

graph TD
    subgraph "User Request"
        U[Customer Query<br/>'Recommend manga like Attack on Titan']
    end

    subgraph "Router Layer"
        R[Query Classifier<br/>Haiku — intent + complexity]
        R --> |complexity: high| FO[Fan-Out Controller]
        R --> |complexity: low| HAIKU_ONLY[Haiku Direct Response]
    end

    subgraph "Parallel Invocation"
        FO --> |async| S[Claude 3 Sonnet<br/>Deep recommendation reasoning]
        FO --> |async| H[Claude 3 Haiku<br/>Quick pattern-match recommendations]
        FO --> |async| RAG[OpenSearch RAG<br/>Similar-title vector search]
    end

    subgraph "Aggregation Layer"
        S --> AGG[Ensemble Aggregator]
        H --> AGG
        RAG --> AGG
        AGG --> QS[Quality Scorer]
        QS --> MERGE[Response Merger]
    end

    subgraph "Output"
        MERGE --> RESP[Final Response<br/>Combined high-quality answer]
        HAIKU_ONLY --> RESP
    end

    U --> R

    style AGG fill:#264653,stroke:#2a9d8f,color:#fff
    style QS fill:#2a9d8f,stroke:#264653,color:#fff
    style MERGE fill:#e9c46a,stroke:#f4a261,color:#000

When To Ensemble vs. Single-Model

| Query Type | Strategy | Rationale | Cost Impact |
|---|---|---|---|
| Simple FAQ ("What are your shipping rates?") | Haiku only | Deterministic answer, no ambiguity | $0.0003 |
| Product recommendation | Sonnet + RAG ensemble | Creative reasoning + factual grounding | $0.012 |
| Japanese content query | Haiku + Sonnet weighted | Haiku for translation, Sonnet for nuance | $0.008 |
| Ambiguous/complex query | Full 3-model ensemble | Maximum coverage, aggregated confidence | $0.018 |
| Classification (intent) | Haiku majority vote (3x) | Fast, cheap, high agreement rate | $0.0009 |
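
The table above collapses naturally into a lookup the orchestrator can consult after classification. A minimal sketch (the query-type labels and the route_query helper are illustrative, not the production classifier's exact label set):

# Hypothetical mapping distilled from the routing table above
STRATEGY_BY_QUERY_TYPE: dict[str, str] = {
    "simple_faq": "haiku_only",
    "product_recommendation": "sonnet_rag_ensemble",
    "japanese_content": "haiku_sonnet_weighted",
    "ambiguous_complex": "full_ensemble",
    "intent_classification": "haiku_majority_vote_3x",
}


def route_query(query_type: str) -> str:
    """Return the ensemble strategy for a classified query type."""
    # Unknown query types fall back to the cheapest path
    return STRATEGY_BY_QUERY_TYPE.get(query_type, "haiku_only")


assert route_query("product_recommendation") == "sonnet_rag_ensemble"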

2. Voting Strategies

2.1 Majority Voting Implementation

Majority voting is the simplest ensemble method: query N models and take the answer most of them agree on. MangaAssist uses it primarily for intent classification, invoking Haiku three times at different non-zero temperatures so the voters produce diverse classification attempts.

"""
MangaAssist Majority Voting Ensemble for Intent Classification.

Invokes Claude 3 Haiku N times with temperature variation to classify
user intent. Majority vote determines the final classification.
"""

import asyncio
import json
import logging
import time
from collections import Counter
from dataclasses import dataclass, field
from typing import Optional

import boto3

logger = logging.getLogger(__name__)

VALID_INTENTS = [
    "product_search", "recommendation", "order_status",
    "shipping_info", "content_question", "greeting",
    "complaint", "return_request", "price_inquiry",
]


@dataclass
class VoteResult:
    """Result of a single model vote."""
    intent: str
    confidence: float
    latency_ms: float
    model_id: str
    raw_response: str


@dataclass
class MajorityVoteOutcome:
    """Outcome of majority voting across N models."""
    winning_intent: str
    vote_count: int
    total_votes: int
    agreement_ratio: float
    all_votes: list[VoteResult] = field(default_factory=list)
    tie_broken: bool = False
    latency_ms: float = 0.0


class MajorityVotingEnsemble:
    """
    Ensemble classifier using majority voting across multiple
    Haiku invocations with temperature variation.

    Production config for MangaAssist:
    - 3 voters (odd number avoids ties)
    - Temperatures: [0.1, 0.3, 0.5] for diversity
    - Agreement threshold: 2/3 minimum
    - Fallback: highest-confidence single vote if no majority
    """

    def __init__(
        self,
        bedrock_client=None,
        num_voters: int = 3,
        temperatures: Optional[list[float]] = None,
        model_id: str = "anthropic.claude-3-haiku-20240307-v1:0",
        agreement_threshold: float = 0.67,
    ):
        self.bedrock = bedrock_client or boto3.client(
            "bedrock-runtime", region_name="ap-northeast-1"
        )
        self.num_voters = num_voters
        self.temperatures = temperatures or [0.1, 0.3, 0.5]
        self.model_id = model_id
        self.agreement_threshold = agreement_threshold

    async def _invoke_voter(
        self, query: str, temperature: float, voter_id: int
    ) -> VoteResult:
        """Invoke a single Haiku voter with specified temperature."""
        start = time.monotonic()
        prompt = (
            f"Classify the following customer query into exactly one intent.\n"
            f"Valid intents: {', '.join(VALID_INTENTS)}\n\n"
            f"Query: {query}\n\n"
            f"Respond with JSON: {{\"intent\": \"<intent>\", \"confidence\": <0.0-1.0>}}"
        )

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 100,
            "temperature": temperature,
            "messages": [{"role": "user", "content": prompt}],
        })

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.bedrock.invoke_model(
                modelId=self.model_id,
                body=body,
                contentType="application/json",
                accept="application/json",
            ),
        )

        result = json.loads(response["body"].read())
        text = result["content"][0]["text"]
        elapsed_ms = (time.monotonic() - start) * 1000

        try:
            parsed = json.loads(text)
            intent = parsed.get("intent", "unknown")
            confidence = float(parsed.get("confidence", 0.0))
        except (json.JSONDecodeError, ValueError):
            intent = "unknown"
            confidence = 0.0

        if intent not in VALID_INTENTS:
            intent = "unknown"
            confidence = 0.0

        return VoteResult(
            intent=intent,
            confidence=confidence,
            latency_ms=elapsed_ms,
            model_id=f"{self.model_id}:voter-{voter_id}",
            raw_response=text,
        )

    async def vote(self, query: str) -> MajorityVoteOutcome:
        """Run majority voting across all voters in parallel."""
        start = time.monotonic()

        tasks = [
            self._invoke_voter(query, temp, idx)
            for idx, temp in enumerate(self.temperatures[:self.num_voters])
        ]
        votes: list[VoteResult] = await asyncio.gather(*tasks)

        valid_votes = [v for v in votes if v.intent != "unknown"]
        if not valid_votes:
            logger.warning("All voters returned unknown intent for: %s", query)
            return MajorityVoteOutcome(
                winning_intent="unknown",
                vote_count=0,
                total_votes=len(votes),
                agreement_ratio=0.0,
                all_votes=votes,
                tie_broken=False,
                latency_ms=(time.monotonic() - start) * 1000,
            )

        # Count votes
        intent_counts = Counter(v.intent for v in valid_votes)
        most_common = intent_counts.most_common()

        # Check for tie
        tie_broken = False
        if len(most_common) > 1 and most_common[0][1] == most_common[1][1]:
            # Tie-break: pick intent with highest average confidence
            tie_broken = True
            tied_intents = [
                ic[0] for ic in most_common if ic[1] == most_common[0][1]
            ]
            avg_conf = {}
            for intent in tied_intents:
                intent_votes = [v for v in valid_votes if v.intent == intent]
                avg_conf[intent] = sum(v.confidence for v in intent_votes) / len(
                    intent_votes
                )
            winning_intent = max(avg_conf, key=avg_conf.get)
            vote_count = intent_counts[winning_intent]
        else:
            winning_intent = most_common[0][0]
            vote_count = most_common[0][1]

        agreement_ratio = vote_count / len(valid_votes)
        if agreement_ratio < self.agreement_threshold:
            logger.warning(
                "Agreement %.2f below threshold %.2f; treating result as low-confidence",
                agreement_ratio, self.agreement_threshold,
            )
        elapsed_ms = (time.monotonic() - start) * 1000

        logger.info(
            "Majority vote: intent=%s agreement=%.2f votes=%d/%d latency=%.0fms",
            winning_intent, agreement_ratio, vote_count,
            len(valid_votes), elapsed_ms,
        )

        return MajorityVoteOutcome(
            winning_intent=winning_intent,
            vote_count=vote_count,
            total_votes=len(valid_votes),
            agreement_ratio=agreement_ratio,
            all_votes=votes,
            tie_broken=tie_broken,
            latency_ms=elapsed_ms,
        )
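
A quick usage sketch for the voter (assumes AWS credentials with Bedrock access in ap-northeast-1; the sample query is illustrative):

import asyncio


async def main() -> None:
    ensemble = MajorityVotingEnsemble(num_voters=3)
    outcome = await ensemble.vote("Where is my One Piece vol. 42 order?")
    # Expect something like intent=order_status with 3/3 agreement
    print(outcome.winning_intent, outcome.agreement_ratio)


asyncio.run(main())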

2.2 Weighted Voting

Weighted voting extends majority voting by assigning each model a weight based on its historical accuracy for a given intent category. MangaAssist updates the weights dynamically with an exponential moving average (EMA) of per-intent accuracy: with alpha = 0.1, a model at weight 0.80 that answers correctly moves to 0.1 × 1.0 + 0.9 × 0.80 = 0.82.

graph LR
    subgraph "Weight Assignment"
        HA[Haiku<br/>w=0.3 general<br/>w=0.8 classification] --> WV
        SO[Sonnet<br/>w=0.9 recommendation<br/>w=0.7 classification] --> WV
        RA[RAG<br/>w=0.5 factual<br/>w=0.2 creative] --> WV
    end

    WV[Weighted<br/>Vote Aggregator] --> OUT[Final Decision]

    style WV fill:#264653,stroke:#2a9d8f,color:#fff

| Model | Intent: product_search | Intent: recommendation | Intent: content_question | Intent: order_status |
|---|---|---|---|---|
| Haiku | 0.85 | 0.60 | 0.70 | 0.90 |
| Sonnet | 0.80 | 0.95 | 0.90 | 0.75 |
| RAG (OpenSearch) | 0.70 | 0.50 | 0.85 | 0.10 |

"""
MangaAssist Weighted Voting with Dynamic Weight Updates.

Weights are per-model, per-intent and updated using exponential
moving average based on user feedback and evaluation scores.
"""

import logging
import time
from dataclasses import dataclass, field
from decimal import Decimal

import boto3

logger = logging.getLogger(__name__)


@dataclass
class WeightedVoteResult:
    """Result with model weight applied."""
    intent: str
    raw_confidence: float
    model_weight: float
    weighted_score: float
    model_id: str


@dataclass
class WeightedVoteOutcome:
    """Aggregated weighted voting result."""
    winning_intent: str
    weighted_score: float
    all_scores: dict[str, float] = field(default_factory=dict)
    individual_votes: list[WeightedVoteResult] = field(default_factory=list)


class DynamicWeightManager:
    """
    Manages per-model, per-intent weights in DynamoDB.
    Updates using exponential moving average (EMA).

    DynamoDB table: model_weights
    PK: model_id
    SK: intent
    Attributes: weight (float), sample_count (int), last_updated (str)
    """

    def __init__(
        self,
        table_name: str = "mangaassist_model_weights",
        ema_alpha: float = 0.1,
        dynamodb_resource=None,
    ):
        self.dynamodb = dynamodb_resource or boto3.resource(
            "dynamodb", region_name="ap-northeast-1"
        )
        self.table = self.dynamodb.Table(table_name)
        self.ema_alpha = ema_alpha
        self._cache: dict[str, float] = {}

    def get_weight(self, model_id: str, intent: str) -> float:
        """Retrieve current weight for model-intent pair."""
        cache_key = f"{model_id}:{intent}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        try:
            resp = self.table.get_item(
                Key={"model_id": model_id, "intent": intent}
            )
            weight = float(resp.get("Item", {}).get("weight", 0.5))
        except Exception:
            weight = 0.5  # Default weight for unknown pairs

        self._cache[cache_key] = weight
        return weight

    def update_weight(
        self, model_id: str, intent: str, was_correct: bool
    ) -> float:
        """
        Update weight using EMA after feedback.

        new_weight = alpha * observation + (1 - alpha) * old_weight
        observation = 1.0 if correct, 0.0 if incorrect
        """
        current = self.get_weight(model_id, intent)
        observation = 1.0 if was_correct else 0.0
        new_weight = self.ema_alpha * observation + (1 - self.ema_alpha) * current

        # Clamp to [0.05, 0.99] to prevent dead/dominant weights
        new_weight = max(0.05, min(0.99, new_weight))

        self.table.update_item(
            Key={"model_id": model_id, "intent": intent},
            UpdateExpression=(
                "SET weight = :w, sample_count = if_not_exists(sample_count, :zero) + :one, "
                "last_updated = :ts"
            ),
            ExpressionAttributeValues={
                # DynamoDB numbers must be Decimal, not float or str
                ":w": Decimal(str(new_weight)),
                ":zero": 0,
                ":one": 1,
                ":ts": str(int(time.time())),
            },
        )

        cache_key = f"{model_id}:{intent}"
        self._cache[cache_key] = new_weight

        logger.info(
            "Weight update: model=%s intent=%s correct=%s %.3f -> %.3f",
            model_id, intent, was_correct, current, new_weight,
        )
        return new_weight


class WeightedVotingEnsemble:
    """
    Aggregates model responses using per-intent weights.
    For recommendation queries: Sonnet weight dominates.
    For factual queries: RAG weight dominates.
    """

    def __init__(self, weight_manager: DynamicWeightManager):
        self.weights = weight_manager

    def aggregate(
        self,
        votes: list[dict],
        detected_intent: str,
    ) -> WeightedVoteOutcome:
        """
        Aggregate votes with dynamic weights.

        Each vote dict: {
            "model_id": str,
            "response": str,
            "confidence": float,
            "intent_guess": str,
        }
        """
        weighted_results = []
        intent_scores: dict[str, float] = {}

        for vote in votes:
            model_id = vote["model_id"]
            intent = vote.get("intent_guess", detected_intent)
            confidence = vote["confidence"]
            weight = self.weights.get_weight(model_id, detected_intent)
            weighted_score = confidence * weight

            result = WeightedVoteResult(
                intent=intent,
                raw_confidence=confidence,
                model_weight=weight,
                weighted_score=weighted_score,
                model_id=model_id,
            )
            weighted_results.append(result)

            intent_scores[intent] = intent_scores.get(intent, 0.0) + weighted_score

        winning_intent = max(intent_scores, key=intent_scores.get)

        return WeightedVoteOutcome(
            winning_intent=winning_intent,
            weighted_score=intent_scores[winning_intent],
            all_scores=intent_scores,
            individual_votes=weighted_results,
        )
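
Because aggregate() only touches get_weight(), the voting math can be exercised locally with an in-memory stand-in for DynamoDB (the stub class and sample votes below are illustrative):

class InMemoryWeightManager(DynamicWeightManager):
    """Test double: serves fixed weights without touching DynamoDB."""

    def __init__(self, weights: dict[tuple[str, str], float]):
        # Deliberately skip DynamicWeightManager.__init__ (no boto3 needed)
        self._weights = weights

    def get_weight(self, model_id: str, intent: str) -> float:
        return self._weights.get((model_id, intent), 0.5)


manager = InMemoryWeightManager({
    ("haiku", "recommendation"): 0.60,
    ("sonnet", "recommendation"): 0.95,
})
ensemble = WeightedVotingEnsemble(manager)
outcome = ensemble.aggregate(
    votes=[
        {"model_id": "haiku", "response": "Try Soul Eater",
         "confidence": 0.9, "intent_guess": "product_search"},
        {"model_id": "sonnet", "response": "Try Berserk",
         "confidence": 0.8, "intent_guess": "recommendation"},
    ],
    detected_intent="recommendation",
)
# Sonnet wins: 0.8 * 0.95 = 0.76 beats Haiku's 0.9 * 0.60 = 0.54
print(outcome.winning_intent)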

3. Confidence Weighting and Calibration

3.1 Confidence Extraction from Bedrock Models

Claude 3 models on Bedrock do not directly expose token-level log-probabilities through the standard API. MangaAssist uses a self-reported confidence approach: the prompt explicitly asks the model to include a confidence score, which is then calibrated against historical accuracy.

graph TD
    subgraph "Confidence Pipeline"
        RAW[Raw Model Response<br/>with self-reported confidence] --> EXT[Confidence Extractor<br/>Parse JSON confidence field]
        EXT --> CAL[Calibration Layer<br/>Map raw to calibrated probability]
        CAL --> THR[Threshold Gate<br/>Accept / Escalate / Reject]
    end

    subgraph "Calibration Data"
        FB[User Feedback<br/>thumbs up/down] --> CAL_DB[(Calibration<br/>DynamoDB Table)]
        EVAL[Offline Evaluation<br/>labeled test sets] --> CAL_DB
        CAL_DB --> CAL
    end

    THR --> |confidence >= 0.85| ACCEPT[Accept Response]
    THR --> |0.50 <= confidence < 0.85| ESC[Escalate to Sonnet]
    THR --> |confidence < 0.50| REJ[Reject + Human Fallback]

    style CAL fill:#264653,stroke:#2a9d8f,color:#fff
    style THR fill:#e76f51,stroke:#f4a261,color:#fff

3.2 Calibration Table

Raw self-reported confidence from Claude models tends to be overconfident. MangaAssist maintains a calibration mapping updated weekly from evaluation data.

| Raw Confidence Range | Calibrated Probability | Observed Accuracy | Action |
|---|---|---|---|
| 0.95 - 1.00 | 0.88 - 0.92 | 89% | Auto-accept |
| 0.85 - 0.94 | 0.75 - 0.87 | 80% | Accept with logging |
| 0.70 - 0.84 | 0.55 - 0.74 | 62% | Escalate to Sonnet |
| 0.50 - 0.69 | 0.35 - 0.54 | 41% | Dual-model ensemble |
| 0.00 - 0.49 | 0.10 - 0.34 | 22% | Reject, human fallback |

"""
MangaAssist Confidence Calibration System.

Maps raw self-reported model confidence to calibrated probabilities
using isotonic regression fitted on historical evaluation data.
"""

import bisect
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class CalibratedConfidence:
    """Calibrated confidence result."""
    raw_confidence: float
    calibrated_probability: float
    action: str  # "accept", "escalate", "ensemble", "reject"
    calibration_version: str


class ConfidenceCalibrator:
    """
    Isotonic regression-based calibration for FM confidence scores.

    Calibration curve is stored in DynamoDB and refreshed weekly
    by the evaluation pipeline. Uses piecewise linear interpolation
    between calibration breakpoints.
    """

    # Default breakpoints (raw -> calibrated) from initial evaluation
    DEFAULT_BREAKPOINTS = [
        (0.00, 0.10), (0.30, 0.18), (0.50, 0.35),
        (0.70, 0.55), (0.80, 0.68), (0.85, 0.75),
        (0.90, 0.82), (0.95, 0.88), (1.00, 0.92),
    ]

    ACTION_THRESHOLDS = {
        "accept": 0.75,
        "escalate": 0.55,
        "ensemble": 0.35,
        "reject": 0.0,
    }

    def __init__(
        self,
        breakpoints: list[tuple[float, float]] | None = None,
        version: str = "v1.0",
    ):
        # Copy via sorted() so the shared DEFAULT_BREAKPOINTS list is never mutated
        self.breakpoints = sorted(
            breakpoints or self.DEFAULT_BREAKPOINTS, key=lambda x: x[0]
        )
        self._raw_vals = [b[0] for b in self.breakpoints]
        self._cal_vals = [b[1] for b in self.breakpoints]
        self.version = version

    def calibrate(self, raw_confidence: float) -> CalibratedConfidence:
        """Map raw confidence to calibrated probability via interpolation."""
        raw = max(0.0, min(1.0, raw_confidence))

        # Find interpolation segment
        idx = bisect.bisect_right(self._raw_vals, raw)
        if idx == 0:
            calibrated = self._cal_vals[0]
        elif idx >= len(self._raw_vals):
            calibrated = self._cal_vals[-1]
        else:
            # Linear interpolation
            x0, x1 = self._raw_vals[idx - 1], self._raw_vals[idx]
            y0, y1 = self._cal_vals[idx - 1], self._cal_vals[idx]
            t = (raw - x0) / (x1 - x0) if x1 != x0 else 0.0
            calibrated = y0 + t * (y1 - y0)

        # Determine action
        if calibrated >= self.ACTION_THRESHOLDS["accept"]:
            action = "accept"
        elif calibrated >= self.ACTION_THRESHOLDS["escalate"]:
            action = "escalate"
        elif calibrated >= self.ACTION_THRESHOLDS["ensemble"]:
            action = "ensemble"
        else:
            action = "reject"

        return CalibratedConfidence(
            raw_confidence=raw,
            calibrated_probability=round(calibrated, 4),
            action=action,
            calibration_version=self.version,
        )

    def recalibrate_from_evaluations(
        self,
        evaluation_pairs: list[tuple[float, bool]],
        num_bins: int = 10,
    ) -> list[tuple[float, float]]:
        """
        Rebuild calibration curve from evaluation data.

        evaluation_pairs: [(raw_confidence, was_correct), ...]
        Uses binned frequency estimation (simplified isotonic regression).
        """
        bins: dict[int, list[bool]] = {i: [] for i in range(num_bins)}

        for raw, correct in evaluation_pairs:
            bin_idx = min(int(raw * num_bins), num_bins - 1)
            bins[bin_idx].append(correct)

        new_breakpoints = []
        for bin_idx in range(num_bins):
            raw_midpoint = (bin_idx + 0.5) / num_bins
            if bins[bin_idx]:
                calibrated = sum(bins[bin_idx]) / len(bins[bin_idx])
            else:
                # Empty bin: no samples, use a conservative scaled default
                calibrated = raw_midpoint * 0.8  # Conservative default
            new_breakpoints.append((raw_midpoint, calibrated))

        # Enforce a monotonic non-decreasing curve (isotonic constraint);
        # clamping to the running maximum guarantees monotonicity in one pass
        for i in range(1, len(new_breakpoints)):
            if new_breakpoints[i][1] < new_breakpoints[i - 1][1]:
                new_breakpoints[i] = (
                    new_breakpoints[i][0], new_breakpoints[i - 1][1]
                )

        self.breakpoints = new_breakpoints
        self._raw_vals = [b[0] for b in self.breakpoints]
        self._cal_vals = [b[1] for b in self.breakpoints]

        logger.info("Recalibrated with %d data points, %d bins", len(evaluation_pairs), num_bins)
        return new_breakpoints
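
The calibrator is pure Python, so spot-checking it locally is cheap. For example, a raw score of 0.92 interpolates between the (0.90, 0.82) and (0.95, 0.88) default breakpoints:

calibrator = ConfidenceCalibrator()
result = calibrator.calibrate(0.92)
# t = (0.92 - 0.90) / (0.95 - 0.90) = 0.4, so 0.82 + 0.4 * (0.88 - 0.82) = 0.844
print(result.calibrated_probability)  # 0.844
print(result.action)                  # "accept" (0.844 >= 0.75 threshold)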

4. Response Quality Scoring

4.1 Multi-Dimensional Quality Framework

MangaAssist scores every ensemble candidate response across four dimensions before selecting the final output.

graph TD
    subgraph "Quality Dimensions"
        COH[Coherence<br/>Is the response logically structured?]
        REL[Relevance<br/>Does it address the query?]
        GRD[Grounding<br/>Is it factually supported by RAG?]
        TON[Tone<br/>Appropriate for manga store context?]
    end

    subgraph "Scoring"
        COH --> |0.0-1.0| SC[Quality Score<br/>Weighted Average]
        REL --> |0.0-1.0| SC
        GRD --> |0.0-1.0| SC
        TON --> |0.0-1.0| SC
    end

    SC --> |score >= 0.7| PASS[Accept Response]
    SC --> |score < 0.7| FAIL[Reject / Re-generate]

    style SC fill:#264653,stroke:#2a9d8f,color:#fff

4.2 Quality Scoring Weights by Query Type

| Query Type | Coherence Weight | Relevance Weight | Grounding Weight | Tone Weight |
|---|---|---|---|---|
| Recommendation | 0.20 | 0.35 | 0.30 | 0.15 |
| Factual (release date, price) | 0.10 | 0.25 | 0.55 | 0.10 |
| Creative (describe manga) | 0.30 | 0.25 | 0.15 | 0.30 |
| Complaint handling | 0.15 | 0.30 | 0.10 | 0.45 |
| Order status | 0.10 | 0.30 | 0.50 | 0.10 |

"""
MangaAssist Response Quality Scorer.

Evaluates FM responses across coherence, relevance, grounding,
and tone dimensions. Uses Haiku as a lightweight judge model.
"""

import asyncio
import json
import logging
import time
from dataclasses import dataclass
from enum import Enum

import boto3

logger = logging.getLogger(__name__)


class QueryType(str, Enum):
    RECOMMENDATION = "recommendation"
    FACTUAL = "factual"
    CREATIVE = "creative"
    COMPLAINT = "complaint"
    ORDER_STATUS = "order_status"
    GENERAL = "general"


@dataclass
class QualityDimension:
    """Score for a single quality dimension."""
    name: str
    score: float
    weight: float
    explanation: str


@dataclass
class QualityScore:
    """Complete quality assessment of a response."""
    overall_score: float
    dimensions: list[QualityDimension]
    passes_threshold: bool
    query_type: QueryType
    evaluation_latency_ms: float


# Weights per query type
QUALITY_WEIGHTS: dict[QueryType, dict[str, float]] = {
    QueryType.RECOMMENDATION: {
        "coherence": 0.20, "relevance": 0.35,
        "grounding": 0.30, "tone": 0.15,
    },
    QueryType.FACTUAL: {
        "coherence": 0.10, "relevance": 0.25,
        "grounding": 0.55, "tone": 0.10,
    },
    QueryType.CREATIVE: {
        "coherence": 0.30, "relevance": 0.25,
        "grounding": 0.15, "tone": 0.30,
    },
    QueryType.COMPLAINT: {
        "coherence": 0.15, "relevance": 0.30,
        "grounding": 0.10, "tone": 0.45,
    },
    QueryType.ORDER_STATUS: {
        "coherence": 0.10, "relevance": 0.30,
        "grounding": 0.50, "tone": 0.10,
    },
    QueryType.GENERAL: {
        "coherence": 0.25, "relevance": 0.30,
        "grounding": 0.25, "tone": 0.20,
    },
}

QUALITY_THRESHOLD = 0.70


class ResponseQualityScorer:
    """
    Uses Claude 3 Haiku as a lightweight judge to score responses.

    The judge prompt evaluates each dimension on a 0-10 scale,
    which is normalized to 0.0-1.0. Haiku judge adds ~150ms
    overhead but catches 73% of low-quality responses before
    they reach the customer.
    """

    def __init__(
        self,
        bedrock_client=None,
        judge_model_id: str = "anthropic.claude-3-haiku-20240307-v1:0",
        threshold: float = QUALITY_THRESHOLD,
    ):
        self.bedrock = bedrock_client or boto3.client(
            "bedrock-runtime", region_name="ap-northeast-1"
        )
        self.judge_model_id = judge_model_id
        self.threshold = threshold

    async def score_response(
        self,
        query: str,
        response: str,
        query_type: QueryType,
        rag_context: str = "",
    ) -> QualityScore:
        """Score a single response across all quality dimensions."""
        start = time.monotonic()

        prompt = self._build_judge_prompt(query, response, rag_context)

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 500,
            "temperature": 0.0,
            "messages": [{"role": "user", "content": prompt}],
        })

        loop = asyncio.get_running_loop()
        api_response = await loop.run_in_executor(
            None,
            lambda: self.bedrock.invoke_model(
                modelId=self.judge_model_id,
                body=body,
                contentType="application/json",
                accept="application/json",
            ),
        )

        result = json.loads(api_response["body"].read())
        text = result["content"][0]["text"]

        dimensions = self._parse_judge_response(text, query_type)
        overall = sum(d.score * d.weight for d in dimensions)
        elapsed_ms = (time.monotonic() - start) * 1000

        return QualityScore(
            overall_score=round(overall, 4),
            dimensions=dimensions,
            passes_threshold=overall >= self.threshold,
            query_type=query_type,
            evaluation_latency_ms=elapsed_ms,
        )

    def _build_judge_prompt(
        self, query: str, response: str, rag_context: str
    ) -> str:
        """Build the judge evaluation prompt."""
        return (
            "You are a quality evaluator for a Japanese manga store chatbot.\n"
            "Score the following response on four dimensions (0-10 each).\n\n"
            f"CUSTOMER QUERY: {query}\n\n"
            f"CHATBOT RESPONSE: {response}\n\n"
            f"RAG CONTEXT (if available): {rag_context or 'None'}\n\n"
            "Score each dimension:\n"
            "1. coherence: Is the response logically structured and readable?\n"
            "2. relevance: Does it directly address the customer's question?\n"
            "3. grounding: Is the information factually supported by the RAG context?\n"
            "4. tone: Is the tone appropriate for a friendly manga store?\n\n"
            "Respond with JSON only:\n"
            '{"coherence": {"score": N, "reason": "..."}, '
            '"relevance": {"score": N, "reason": "..."}, '
            '"grounding": {"score": N, "reason": "..."}, '
            '"tone": {"score": N, "reason": "..."}}'
        )

    def _parse_judge_response(
        self, text: str, query_type: QueryType
    ) -> list[QualityDimension]:
        """Parse judge response into scored dimensions."""
        weights = QUALITY_WEIGHTS.get(query_type, QUALITY_WEIGHTS[QueryType.GENERAL])

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            logger.warning("Judge returned non-JSON: %s", text[:200])
            return [
                QualityDimension(
                    name=dim, score=0.5, weight=w,
                    explanation="Judge parse failure — default score"
                )
                for dim, w in weights.items()
            ]

        dimensions = []
        for dim_name, weight in weights.items():
            dim_data = parsed.get(dim_name, {})
            try:
                raw_score = float(dim_data.get("score", 5))
            except (TypeError, ValueError):
                raw_score = 5.0  # Judge returned a malformed score
            normalized = max(0.0, min(1.0, raw_score / 10.0))
            reason = dim_data.get("reason", "No explanation provided")

            dimensions.append(QualityDimension(
                name=dim_name,
                score=normalized,
                weight=weight,
                explanation=reason,
            ))

        return dimensions

    async def rank_responses(
        self,
        query: str,
        responses: list[dict],
        query_type: QueryType,
        rag_context: str = "",
    ) -> list[tuple[dict, QualityScore]]:
        """Score multiple responses and return ranked by quality."""
        tasks = [
            self.score_response(query, r["response"], query_type, rag_context)
            for r in responses
        ]
        scores = await asyncio.gather(*tasks)

        ranked = sorted(
            zip(responses, scores),
            key=lambda x: x[1].overall_score,
            reverse=True,
        )
        return ranked
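
The final score is a plain weighted average, so the threshold math is easy to verify by hand. A worked check using the recommendation weights (the dimension scores are made up for illustration):

# Hypothetical judge output, already normalized to 0.0-1.0
scores = {"coherence": 0.8, "relevance": 0.9, "grounding": 0.6, "tone": 1.0}
weights = QUALITY_WEIGHTS[QueryType.RECOMMENDATION]

overall = sum(scores[name] * w for name, w in weights.items())
# 0.8*0.20 + 0.9*0.35 + 0.6*0.30 + 1.0*0.15 = 0.805
print(round(overall, 3), overall >= QUALITY_THRESHOLD)  # 0.805 True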

5. Cost-Aware Ensemble Aggregation

5.1 Budget-Constrained Model Selection

MangaAssist processes 1M messages/day. Running the full three-model ensemble on every query would cost roughly $18,000/day, so the cost-aware aggregator enforces both a per-query cost ceiling and a daily budget ceiling.

graph TD
    subgraph "Cost Decision Tree"
        Q[Incoming Query] --> CC{Check Daily<br/>Budget Remaining}
        CC -->|< 10% remaining| HAIKU[Force Haiku Only<br/>Cheapest path]
        CC -->|>= 10% remaining| CX{Query Complexity?}
        CX -->|Simple| H2[Haiku Only<br/>~$0.0003/query]
        CX -->|Medium| CASC[Cascade<br/>Haiku first, Sonnet if needed<br/>~$0.003 avg]
        CX -->|Complex| ENS{Ensemble Budget<br/>Check}
        ENS -->|Per-query budget OK| FULL[Full Ensemble<br/>Haiku + Sonnet + RAG<br/>~$0.018/query]
        ENS -->|Per-query over budget| PARTIAL[Partial Ensemble<br/>Sonnet + RAG only<br/>~$0.015/query]
    end

    style CC fill:#e76f51,stroke:#f4a261,color:#fff
    style HAIKU fill:#2a9d8f,stroke:#264653,color:#fff

5.2 Cost Comparison: Ensemble Strategies

| Strategy | Avg Cost/Query | Daily Cost (1M msgs) | Quality Score | Latency (p50) | Best For |
|---|---|---|---|---|---|
| Haiku Only | $0.0003 | $300 | 0.72 | 180ms | Simple FAQ, classification |
| Sonnet Only | $0.012 | $12,000 | 0.91 | 650ms | Complex reasoning |
| Cascade (Haiku -> Sonnet) | $0.003 | $3,000 | 0.86 | 280ms | Cost-optimized general use |
| Full Ensemble (3-model) | $0.018 | $18,000 | 0.94 | 800ms | Maximum quality |
| Budget Ensemble (adaptive) | $0.005 | $5,000 | 0.89 | 350ms | Production default |

"""
MangaAssist Cost-Aware Ensemble Aggregator.

Dynamically selects ensemble strategy based on query complexity,
remaining daily budget, and per-query cost ceiling.
"""

import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum

import boto3

logger = logging.getLogger(__name__)


class EnsembleStrategy(str, Enum):
    HAIKU_ONLY = "haiku_only"
    SONNET_ONLY = "sonnet_only"
    CASCADE = "cascade"
    FULL_ENSEMBLE = "full_ensemble"
    BUDGET_ENSEMBLE = "budget_ensemble"


@dataclass
class CostEnvelope:
    """Cost tracking for a single query."""
    strategy: EnsembleStrategy
    haiku_input_tokens: int = 0
    haiku_output_tokens: int = 0
    sonnet_input_tokens: int = 0
    sonnet_output_tokens: int = 0
    embedding_tokens: int = 0
    total_cost_usd: float = 0.0

    def calculate_cost(self) -> float:
        """Calculate total cost based on Bedrock pricing."""
        haiku_cost = (
            self.haiku_input_tokens * 0.25 / 1_000_000
            + self.haiku_output_tokens * 1.25 / 1_000_000
        )
        sonnet_cost = (
            self.sonnet_input_tokens * 3.0 / 1_000_000
            + self.sonnet_output_tokens * 15.0 / 1_000_000
        )
        embed_cost = self.embedding_tokens * 0.02 / 1_000_000

        self.total_cost_usd = haiku_cost + sonnet_cost + embed_cost
        return self.total_cost_usd


@dataclass
class BudgetState:
    """Tracks daily budget consumption."""
    daily_budget_usd: float = 5000.0
    spent_today_usd: float = 0.0
    queries_today: int = 0
    per_query_ceiling_usd: float = 0.02

    @property
    def remaining_usd(self) -> float:
        return max(0.0, self.daily_budget_usd - self.spent_today_usd)

    @property
    def remaining_pct(self) -> float:
        return self.remaining_usd / self.daily_budget_usd

    def record_spend(self, cost: float) -> None:
        self.spent_today_usd += cost
        self.queries_today += 1


class CostAwareAggregator:
    """
    Selects and executes ensemble strategy based on budget constraints.

    Decision logic:
    1. If budget < 10% remaining -> Haiku only (preserve budget)
    2. If query is simple -> Haiku only (no need for ensemble)
    3. If query is medium -> Cascade (Haiku first, escalate if uncertain)
    4. If query is complex and budget allows -> Full ensemble
    5. If query is complex but budget tight -> Budget ensemble (2 models)
    """

    # Complexity thresholds (from Haiku classifier)
    COMPLEXITY_SIMPLE = 0.3
    COMPLEXITY_MEDIUM = 0.7

    def __init__(
        self,
        budget: BudgetState,
        bedrock_client=None,
        redis_client=None,
    ):
        self.budget = budget
        self.bedrock = bedrock_client or boto3.client(
            "bedrock-runtime", region_name="ap-northeast-1"
        )
        self.redis = redis_client

    def select_strategy(
        self, complexity_score: float, query_type: str
    ) -> EnsembleStrategy:
        """Select ensemble strategy based on complexity and budget."""
        # Budget emergency mode
        if self.budget.remaining_pct < 0.10:
            logger.warning(
                "Budget emergency: %.1f%% remaining, forcing Haiku only",
                self.budget.remaining_pct * 100,
            )
            return EnsembleStrategy.HAIKU_ONLY

        # Simple queries never need ensemble
        if complexity_score < self.COMPLEXITY_SIMPLE:
            return EnsembleStrategy.HAIKU_ONLY

        # Medium complexity: cascade
        if complexity_score < self.COMPLEXITY_MEDIUM:
            return EnsembleStrategy.CASCADE

        # Complex queries: check per-query budget
        estimated_full_cost = 0.018
        estimated_budget_cost = 0.008

        if estimated_full_cost <= self.budget.per_query_ceiling_usd:
            return EnsembleStrategy.FULL_ENSEMBLE

        if estimated_budget_cost <= self.budget.per_query_ceiling_usd:
            return EnsembleStrategy.BUDGET_ENSEMBLE

        # Budget too tight for any ensemble
        return EnsembleStrategy.CASCADE

    async def execute_cascade(
        self, query: str, system_prompt: str
    ) -> tuple[str, CostEnvelope]:
        """
        Cascade strategy: try Haiku first, escalate to Sonnet
        if Haiku confidence is below threshold.
        """
        cost = CostEnvelope(strategy=EnsembleStrategy.CASCADE)

        # Step 1: Haiku attempt
        haiku_resp, haiku_conf, haiku_tokens = await self._invoke_model(
            query, system_prompt, "anthropic.claude-3-haiku-20240307-v1:0"
        )
        cost.haiku_input_tokens = haiku_tokens["input"]
        cost.haiku_output_tokens = haiku_tokens["output"]

        # If Haiku is confident enough, return immediately
        if haiku_conf >= 0.85:
            cost.calculate_cost()
            self.budget.record_spend(cost.total_cost_usd)
            return haiku_resp, cost

        # Step 2: Escalate to Sonnet
        sonnet_resp, sonnet_conf, sonnet_tokens = await self._invoke_model(
            query, system_prompt, "anthropic.claude-3-sonnet-20240229-v1:0"
        )
        cost.sonnet_input_tokens = sonnet_tokens["input"]
        cost.sonnet_output_tokens = sonnet_tokens["output"]

        cost.calculate_cost()
        self.budget.record_spend(cost.total_cost_usd)
        return sonnet_resp, cost

    async def _invoke_model(
        self, query: str, system_prompt: str, model_id: str
    ) -> tuple[str, float, dict]:
        """Invoke a single model and extract response + confidence."""
        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "temperature": 0.1,
            "system": system_prompt,
            "messages": [{"role": "user", "content": query}],
        })

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.bedrock.invoke_model(
                modelId=model_id,
                body=body,
                contentType="application/json",
                accept="application/json",
            ),
        )

        result = json.loads(response["body"].read())
        text = result["content"][0]["text"]
        usage = result.get("usage", {})
        tokens = {
            "input": usage.get("input_tokens", 0),
            "output": usage.get("output_tokens", 0),
        }

        # Extract self-reported confidence from the response if present
        confidence = 0.5  # Default when the model reports no confidence
        match = re.search(r'"confidence"\s*:\s*([\d.]+)', text)
        if match:
            try:
                confidence = float(match.group(1))
            except ValueError:
                pass

        return text, confidence, tokens
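
CostEnvelope makes the cascade economics auditable per query. A worked example with illustrative token counts for a query that escalated from Haiku to Sonnet:

envelope = CostEnvelope(
    strategy=EnsembleStrategy.CASCADE,
    haiku_input_tokens=400, haiku_output_tokens=150,
    sonnet_input_tokens=450, sonnet_output_tokens=300,
)
cost = envelope.calculate_cost()
# Haiku:  400 * 0.25/1M + 150 * 1.25/1M = $0.0002875
# Sonnet: 450 * 3.00/1M + 300 * 15.0/1M = $0.0058500
print(f"${cost:.4f}")  # ~$0.0061 -- the Sonnet escalation dominates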

6. Ensemble Response Merging

6.1 Response Merge Strategies

When multiple models produce valid but complementary responses, MangaAssist merges them rather than discarding the runner-up.

graph TD
    subgraph "Merge Strategies"
        SEL[Selection<br/>Pick best single response] --> |"Simplest, fast"| OUT[Final Response]
        UNI[Union Merge<br/>Combine unique facts] --> |"More complete"| OUT
        REF[Refinement<br/>Use one response to refine another] --> |"Highest quality"| OUT
        SUM[Summarization<br/>Synthesize all into new response] --> |"Most expensive"| OUT
    end

    style OUT fill:#264653,stroke:#2a9d8f,color:#fff

| Merge Strategy | Latency Impact | Quality Impact | Cost | Use Case |
|---|---|---|---|---|
| Selection | +0ms | Baseline | None | Clear winner by quality score |
| Union Merge | +50ms | +15% completeness | ~$0.001 (Haiku dedup) | Factual queries with complementary info |
| Refinement | +200ms | +25% quality | ~$0.005 (Sonnet refine) | Recommendation queries |
| Summarization | +400ms | +30% quality | ~$0.010 (Sonnet synthesize) | Complex multi-faceted queries |

"""
MangaAssist Response Merger.

Combines multiple model responses using configurable strategies:
selection, union merge, refinement, or summarization.
"""

import asyncio
import json
import logging
import time
from dataclasses import dataclass
from enum import Enum

import boto3

logger = logging.getLogger(__name__)


class MergeStrategy(str, Enum):
    SELECTION = "selection"
    UNION = "union"
    REFINEMENT = "refinement"
    SUMMARIZATION = "summarization"


@dataclass
class MergedResponse:
    """Result of response merging."""
    final_text: str
    strategy_used: MergeStrategy
    source_count: int
    merge_latency_ms: float
    merge_cost_usd: float


class ResponseMerger:
    """
    Merges multiple FM responses into a single high-quality output.
    Strategy selection is based on query type and response similarity.
    """

    def __init__(self, bedrock_client=None):
        self.bedrock = bedrock_client or boto3.client(
            "bedrock-runtime", region_name="ap-northeast-1"
        )

    async def merge(
        self,
        query: str,
        ranked_responses: list[tuple[str, float]],
        strategy: MergeStrategy = MergeStrategy.SELECTION,
    ) -> MergedResponse:
        """
        Merge responses based on strategy.

        ranked_responses: [(response_text, quality_score), ...] sorted desc
        """
        start = time.monotonic()

        if strategy == MergeStrategy.SELECTION:
            result = self._select_best(ranked_responses)
            cost = 0.0
        elif strategy == MergeStrategy.UNION:
            result, cost = await self._union_merge(query, ranked_responses)
        elif strategy == MergeStrategy.REFINEMENT:
            result, cost = await self._refine(query, ranked_responses)
        elif strategy == MergeStrategy.SUMMARIZATION:
            result, cost = await self._summarize(query, ranked_responses)
        else:
            result = self._select_best(ranked_responses)
            cost = 0.0

        elapsed = (time.monotonic() - start) * 1000

        return MergedResponse(
            final_text=result,
            strategy_used=strategy,
            source_count=len(ranked_responses),
            merge_latency_ms=elapsed,
            merge_cost_usd=cost,
        )

    def _select_best(
        self, ranked_responses: list[tuple[str, float]]
    ) -> str:
        """Simply return the highest-scored response."""
        if not ranked_responses:
            return "I'm sorry, I couldn't generate a response."
        return ranked_responses[0][0]

    async def _union_merge(
        self, query: str, ranked_responses: list[tuple[str, float]]
    ) -> tuple[str, float]:
        """Combine unique facts from all responses using Haiku."""
        all_responses = "\n---\n".join(
            f"Response {i+1} (score={s:.2f}): {r}"
            for i, (r, s) in enumerate(ranked_responses)
        )

        prompt = (
            "You are merging multiple chatbot responses for a manga store customer.\n"
            f"Customer query: {query}\n\n"
            f"Responses to merge:\n{all_responses}\n\n"
            "Create a single response that includes all unique, relevant facts "
            "from the responses above. Remove duplicates. Keep the friendly tone. "
            "Do not mention that multiple sources were used."
        )

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "temperature": 0.1,
            "messages": [{"role": "user", "content": prompt}],
        })

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.bedrock.invoke_model(
                modelId="anthropic.claude-3-haiku-20240307-v1:0",
                body=body,
                contentType="application/json",
                accept="application/json",
            ),
        )

        result = json.loads(response["body"].read())
        text = result["content"][0]["text"]
        usage = result.get("usage", {})
        cost = (
            usage.get("input_tokens", 0) * 0.25 / 1_000_000
            + usage.get("output_tokens", 0) * 1.25 / 1_000_000
        )

        return text, cost

    async def _refine(
        self, query: str, ranked_responses: list[tuple[str, float]]
    ) -> tuple[str, float]:
        """Use Sonnet to refine the best response with context from others."""
        best_response = ranked_responses[0][0] if ranked_responses else ""
        supporting = "\n".join(
            f"- {r}" for r, _ in ranked_responses[1:3]
        )

        prompt = (
            "You are improving a chatbot response for a Japanese manga store.\n"
            f"Customer query: {query}\n\n"
            f"Primary response:\n{best_response}\n\n"
            f"Additional context from other models:\n{supporting}\n\n"
            "Refine the primary response by incorporating any useful details "
            "from the additional context. Keep it concise and friendly."
        )

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "temperature": 0.2,
            "messages": [{"role": "user", "content": prompt}],
        })

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.bedrock.invoke_model(
                modelId="anthropic.claude-3-sonnet-20240229-v1:0",
                body=body,
                contentType="application/json",
                accept="application/json",
            ),
        )

        result = json.loads(response["body"].read())
        text = result["content"][0]["text"]
        usage = result.get("usage", {})
        cost = (
            usage.get("input_tokens", 0) * 3.0 / 1_000_000
            + usage.get("output_tokens", 0) * 15.0 / 1_000_000
        )

        return text, cost

    async def _summarize(
        self, query: str, ranked_responses: list[tuple[str, float]]
    ) -> tuple[str, float]:
        """Synthesize all responses into a new comprehensive answer."""
        all_text = "\n\n".join(
            f"Source {i+1}: {r}" for i, (r, _) in enumerate(ranked_responses)
        )

        prompt = (
            "You are a manga store chatbot assistant. Synthesize the following "
            "responses into one comprehensive, accurate, and friendly answer.\n\n"
            f"Customer asked: {query}\n\n"
            f"Source responses:\n{all_text}\n\n"
            "Write a single synthesized response. Do not reference the sources."
        )

        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "temperature": 0.3,
            "messages": [{"role": "user", "content": prompt}],
        })

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: self.bedrock.invoke_model(
                modelId="anthropic.claude-3-sonnet-20240229-v1:0",
                body=body,
                contentType="application/json",
                accept="application/json",
            ),
        )

        result = json.loads(response["body"].read())
        text = result["content"][0]["text"]
        usage = result.get("usage", {})
        cost = (
            usage.get("input_tokens", 0) * 3.0 / 1_000_000
            + usage.get("output_tokens", 0) * 15.0 / 1_000_000
        )

        return text, cost
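
The merger leaves strategy selection to the caller. One plausible heuristic keys off the query type and how close the quality scores are; the thresholds in this sketch are illustrative, not the production rules:

def choose_merge_strategy(
    query_type: str, top_score: float, score_spread: float
) -> MergeStrategy:
    """Pick a merge strategy from query type and quality-score shape.

    score_spread: top score minus runner-up score.
    """
    if top_score >= 0.9 or score_spread > 0.2:
        # One response clearly dominates -- merging adds cost, not quality
        return MergeStrategy.SELECTION
    if query_type == "factual":
        # Close factual answers often carry complementary facts
        return MergeStrategy.UNION
    if query_type == "recommendation":
        return MergeStrategy.REFINEMENT
    return MergeStrategy.SUMMARIZATION


print(choose_merge_strategy("recommendation", 0.82, 0.05))  # MergeStrategy.REFINEMENT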

7. End-to-End Ensemble Pipeline

7.1 Orchestration Sequence

sequenceDiagram
    participant C as Customer
    participant GW as API Gateway
    participant ECS as ECS Fargate
    participant RC as Router (Haiku)
    participant AGG as Aggregator
    participant H as Haiku
    participant S as Sonnet
    participant OS as OpenSearch RAG
    participant QS as Quality Scorer
    participant CACHE as Redis Cache

    C->>GW: "Recommend dark fantasy manga"
    GW->>ECS: WebSocket message
    ECS->>CACHE: Check semantic cache
    CACHE-->>ECS: Cache miss

    ECS->>RC: Classify complexity
    RC-->>ECS: complexity=0.82 (high)

    ECS->>AGG: Select strategy
    Note over AGG: Budget OK, complexity high<br/>Strategy: FULL_ENSEMBLE

    par Parallel Invocation
        AGG->>H: Generate recommendation
        AGG->>S: Generate recommendation
        AGG->>OS: Vector search similar titles
    end

    H-->>AGG: Haiku response (180ms)
    OS-->>AGG: RAG results (120ms)
    S-->>AGG: Sonnet response (650ms)

    AGG->>QS: Score all 3 responses
    QS-->>AGG: Ranked by quality

    Note over AGG: Strategy: REFINEMENT merge<br/>Sonnet best + RAG facts

    AGG-->>ECS: Merged response
    ECS->>CACHE: Store in semantic cache
    ECS-->>GW: Final response
    GW-->>C: "For dark fantasy manga, I'd recommend..."

7.2 Latency Breakdown

| Stage | Duration | Cumulative | Notes |
|---|---|---|---|
| Cache check | 5ms | 5ms | Redis HNSW vector search |
| Complexity classification | 150ms | 155ms | Haiku single invocation |
| Parallel model invocation | 650ms | 805ms | Bounded by slowest (Sonnet) |
| Quality scoring | 160ms | 965ms | Haiku judge, all 3 candidates |
| Response merging | 220ms | 1,185ms | Refinement via Sonnet |
| Cache store | 8ms | 1,193ms | Async, non-blocking |
| Total | | ~1,200ms | Well under the 3-second target |
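
Because Sonnet bounds the parallel stage, a hard deadline on the fan-out protects the 3-second target: a straggler degrades the ensemble to whichever responses arrived rather than blocking the reply. A minimal asyncio sketch (the 2.5 s budget and fake model coroutines are illustrative):

import asyncio


async def fan_out_with_deadline(tasks: list, deadline_s: float = 2.5) -> list:
    """Gather model calls in parallel, dropping any that miss the deadline."""
    results = await asyncio.gather(
        *(asyncio.wait_for(t, timeout=deadline_s) for t in tasks),
        return_exceptions=True,
    )
    # Keep only successful responses; timeouts surface as exceptions
    return [r for r in results if not isinstance(r, Exception)]


async def demo() -> None:
    async def model(name: str, delay: float) -> str:
        await asyncio.sleep(delay)
        return name

    answers = await fan_out_with_deadline(
        [model("haiku", 0.18), model("sonnet", 0.65), model("slow", 3.0)]
    )
    print(answers)  # ['haiku', 'sonnet'] -- the 3 s straggler is dropped


asyncio.run(demo())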

8. Embedding Drift Detection for Ensemble Stability

8.1 Why Embedding Drift Matters

When Titan Embeddings produces vectors that shift over time (model updates, data distribution changes), the semantic cache and RAG component of the ensemble begin to disagree with the voting models. MangaAssist monitors embedding drift to keep the ensemble synchronized.

graph TD
    subgraph "Drift Detection Pipeline"
        REF[Reference Embeddings<br/>1000 canonical queries] --> COMP[Weekly Comparison]
        CURR[Current Embeddings<br/>Same 1000 queries] --> COMP
        COMP --> DRIFT{Cosine Drift<br/>> 0.05?}
        DRIFT -->|No| OK[Embeddings Stable<br/>No action needed]
        DRIFT -->|Yes| ALARM[CloudWatch Alarm<br/>EmbeddingDriftHigh]
        ALARM --> REBUILD[Trigger Cache<br/>Rebuild + Re-index]
    end

    style ALARM fill:#e76f51,stroke:#f4a261,color:#fff
    style OK fill:#2a9d8f,stroke:#264653,color:#fff
"""
MangaAssist Embedding Drift Monitor.

Detects when the Titan Embeddings model produces vectors that
have shifted from the reference baseline, indicating the need
for cache rebuild and OpenSearch re-indexing.
"""

import json
import logging
import math
from dataclasses import dataclass

import boto3

logger = logging.getLogger(__name__)


@dataclass
class DriftReport:
    """Embedding drift analysis result."""
    mean_cosine_drift: float
    max_cosine_drift: float
    pct_above_threshold: float
    drift_detected: bool
    sample_size: int
    threshold: float


class EmbeddingDriftMonitor:
    """
    Compares current embeddings against a stored reference set.
    Uses cosine distance as the drift metric.
    """

    def __init__(
        self,
        bedrock_client=None,
        s3_client=None,
        reference_bucket: str = "mangaassist-ml-artifacts",
        reference_key: str = "embeddings/reference_vectors.json",
        drift_threshold: float = 0.05,
        model_id: str = "amazon.titan-embed-text-v2:0",
    ):
        self.bedrock = bedrock_client or boto3.client(
            "bedrock-runtime", region_name="ap-northeast-1"
        )
        self.s3 = s3_client or boto3.client("s3")
        self.reference_bucket = reference_bucket
        self.reference_key = reference_key
        self.drift_threshold = drift_threshold
        self.model_id = model_id

    def _cosine_distance(self, vec_a: list[float], vec_b: list[float]) -> float:
        """Compute cosine distance (1 - cosine similarity)."""
        dot = sum(a * b for a, b in zip(vec_a, vec_b))
        norm_a = math.sqrt(sum(a * a for a in vec_a))
        norm_b = math.sqrt(sum(b * b for b in vec_b))
        if norm_a == 0 or norm_b == 0:
            return 1.0
        similarity = dot / (norm_a * norm_b)
        return 1.0 - similarity

    def _embed_text(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        body = json.dumps({
            "inputText": text,
            "dimensions": 1024,
            "normalize": True,
        })
        response = self.bedrock.invoke_model(
            modelId=self.model_id,
            body=body,
            contentType="application/json",
            accept="application/json",
        )
        result = json.loads(response["body"].read())
        return result["embedding"]

    def check_drift(
        self,
        reference_pairs: list[dict],
    ) -> DriftReport:
        """
        Compare current embeddings with reference vectors.

        reference_pairs: [{"text": "...", "reference_vector": [...]}, ...]
        """
        drifts = []
        for pair in reference_pairs:
            current_vec = self._embed_text(pair["text"])
            ref_vec = pair["reference_vector"]
            drift = self._cosine_distance(current_vec, ref_vec)
            drifts.append(drift)

        mean_drift = sum(drifts) / len(drifts) if drifts else 0.0
        max_drift = max(drifts) if drifts else 0.0
        above_threshold = sum(
            1 for d in drifts if d > self.drift_threshold
        )
        pct_above = above_threshold / len(drifts) if drifts else 0.0

        detected = mean_drift > self.drift_threshold

        report = DriftReport(
            mean_cosine_drift=round(mean_drift, 6),
            max_cosine_drift=round(max_drift, 6),
            pct_above_threshold=round(pct_above, 4),
            drift_detected=detected,
            sample_size=len(drifts),
            threshold=self.drift_threshold,
        )

        if detected:
            logger.warning(
                "Embedding drift detected: mean=%.4f max=%.4f pct_above=%.1f%%",
                mean_drift, max_drift, pct_above * 100,
            )
        else:
            logger.info(
                "Embedding stable: mean_drift=%.4f threshold=%.4f",
                mean_drift, self.drift_threshold,
            )

        return report
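
The drift metric itself needs no AWS access, so it can be sanity-checked with toy vectors (placeholder clients are passed purely to skip boto3 session creation):

monitor = EmbeddingDriftMonitor(bedrock_client=object(), s3_client=object())

ref = [1.0, 0.0, 0.0]
shifted = [0.8, 0.6, 0.0]  # unit vector rotated away from the reference

print(monitor._cosine_distance(ref, ref))                # 0.0 -- identical direction
print(round(monitor._cosine_distance(ref, shifted), 2))  # 0.2 -- far above the 0.05 threshold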

9. Comparison Table: Aggregation Methods

| Method | Accuracy Gain | Latency Impact | Cost Impact | Complexity | When to Use |
|---|---|---|---|---|---|
| Majority Voting | +5-8% | Nx single model | Nx single model | Low | Classification, yes/no decisions |
| Weighted Voting | +8-12% | Nx + weight lookup | Nx + DynamoDB read | Medium | Intent routing, multi-class |
| Confidence Weighting | +10-15% | +50ms calibration | +$0.001 | Medium | Cascade escalation decisions |
| Quality Scoring | +15-20% | +150ms judge call | +$0.003 (Haiku judge) | High | Customer-facing responses |
| Response Merging | +20-30% | +200-400ms merge | +$0.005-0.010 | High | Complex queries, recommendations |
| Cost-Aware Selection | Cost neutral | +10ms decision | Saves 40-60% vs full ensemble | Medium | Budget-constrained production |
| Embedding Drift Detection | Prevents degradation | Offline (weekly) | ~$0.50 per check | Low | Maintaining ensemble consistency |

10. Anti-Patterns and Pitfalls

10.1 Common Ensemble Mistakes

| Anti-Pattern | Problem | MangaAssist Solution |
|---|---|---|
| Ensembling identical models | No diversity = no benefit | Use Haiku + Sonnet + RAG (different architectures) |
| Equal weights for all models | Ignores domain expertise | Dynamic per-intent weights from DynamoDB |
| No cost ceiling | Budget blowout at scale | BudgetState with per-query and daily caps |
| Synchronous fan-out | Latency = sum of all models | asyncio.gather for parallel invocation |
| Ignoring calibration | Raw confidence is misleading | Isotonic regression calibration layer |
| Always ensembling | Overkill for simple queries | Complexity-based strategy selection |
| No drift monitoring | Silent quality degradation | Weekly embedding drift checks |

10.2 Ensemble Decision Flowchart

graph TD
    START[New Query] --> CACHE{Semantic Cache<br/>Hit?}
    CACHE -->|Hit| SERVE[Serve Cached Response]
    CACHE -->|Miss| CLASS[Classify Complexity<br/>via Haiku]
    CLASS --> SIMPLE{Simple?}
    SIMPLE -->|Yes| HAIKU[Haiku Only]
    SIMPLE -->|No| MED{Medium?}
    MED -->|Yes| CASC[Cascade<br/>Haiku -> Sonnet if needed]
    MED -->|No| BUDGET{Budget<br/>Available?}
    BUDGET -->|Yes| FULL[Full Ensemble<br/>3-model parallel]
    BUDGET -->|No| PARTIAL[Budget Ensemble<br/>2-model]

    HAIKU --> QS[Quality Score Check]
    CASC --> QS
    FULL --> MERGE[Response Merge]
    PARTIAL --> MERGE
    MERGE --> QS
    QS --> |>= 0.7| STORE[Cache + Return]
    QS --> |< 0.7| REGEN[Regenerate or Escalate]

    style FULL fill:#264653,stroke:#2a9d8f,color:#fff
    style CASC fill:#2a9d8f,stroke:#264653,color:#fff
    style BUDGET fill:#e76f51,stroke:#f4a261,color:#fff

Key Takeaways

  1. Ensemble strategies must match query complexity — simple FAQ queries should never trigger a 3-model ensemble; use the complexity classifier to route appropriately.

  2. Confidence calibration is non-negotiable — raw self-reported confidence from Claude models is systematically overconfident; the isotonic regression calibration layer converts raw scores into actionable probabilities.

  3. Cost-aware aggregation is essential at scale — at 1M messages/day, the difference between Haiku-only ($300/day) and full ensemble ($18,000/day) is 60x; the budget-aware aggregator dynamically selects strategies to hit the $5,000/day target.

  4. Response merging adds measurable quality — refinement merging (Sonnet refines best response using supporting evidence) adds 200ms but improves quality scores by 25%, which is worth it for high-value recommendation queries.

  5. Weighted voting with EMA updates adapts to drift — static weights become stale as model behavior changes; exponential moving average updates from user feedback keep weights calibrated to current model performance.

  6. Embedding drift silently degrades ensemble quality — when Titan Embeddings vectors shift, the RAG component starts disagreeing with the classification models; weekly drift monitoring catches this before customers notice.

  7. Parallel invocation bounds latency — asyncio.gather ensures the ensemble latency is bounded by the slowest model (Sonnet at 650ms), not the sum of all models, keeping MangaAssist well under the 3-second target.