
Prompt Chaining & Orchestration for MangaAssist

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.

Skill Mapping

| Attribute | Detail |
|---|---|
| Domain | 2 — Implementation & Integration of GenAI Applications |
| Task | 2.5 — Application Integration Patterns |
| Skill | 2.5.5 — Advanced GenAI Applications |
| Focus | Chain builder, parallel chains, conditional routing, map-reduce pattern |
| MangaAssist Scope | Prompt chain orchestration for multi-step manga queries, response synthesis, and quality control |

Mind Map

mindmap
  root((Prompt Chaining<br/>& Orchestration))
    Chain Types
      Sequential Chain
        Step-by-Step Processing
        Context Accumulation
        Early Exit on Failure
      Parallel Chain
        Independent Sub-Tasks
        Concurrent Execution
        Result Merging
      Conditional Chain
        Intent-Based Routing
        Confidence Thresholds
        Fallback Paths
      Map-Reduce Chain
        Split Input
        Parallel Map
        Reduce / Aggregate
    Chain Builder
      Fluent API
        Step Definition
        Connection Wiring
        Validation Rules
      Configuration
        Timeout per Step
        Retry Policy
        Model Selection
      Observability
        Step Tracing
        Token Tracking
        Latency Breakdown
    Orchestration
      State Management
        Immutable State Passing
        Checkpoint / Resume
        Error State Propagation
      Budget Management
        Time Budget Allocation
        Token Budget Tracking
        Cost Ceiling Enforcement
      Quality Control
        Output Validation
        Confidence Scoring
        Human-in-the-Loop Gates
    Production Patterns
      Cache-Augmented Chains
        Redis Prompt Cache
        Embedding Cache
        Response Cache
      Streaming Chains
        Chunk-by-Chunk Delivery
        Progressive Rendering
        Backpressure Handling
      Graceful Degradation
        Partial Response Delivery
        Model Fallback
        Static Fallback Content
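The Conditional Chain branch above pairs intent-based routing with confidence thresholds and fallback paths. Below is a minimal sketch of how the three fit together. It assumes the classifier returns a confidence score alongside the intent. The `classify_intent` prompt later in this document returns only an intent, so the score, the 0.7 threshold, and the `route` helper are all hypothetical. Chain names follow the `MangaChains` factory. Low-confidence or unmapped intents fall back to the decomposing `complex_query` chain, whose default plan already handles a generic "chat" task.

```python
from dataclasses import dataclass

# Chain names follow the MangaChains factory; threshold and fallback are assumptions
INTENT_TO_CHAIN = {
    "search": "manga_search",
    "recommend": "manga_recommendation",
}
FALLBACK_CHAIN = "complex_query"
CONFIDENCE_THRESHOLD = 0.7


@dataclass(frozen=True)
class Classification:
    intent: str        # e.g. "search", "recommend", "order", "chat"
    confidence: float  # hypothetical 0..1 score from the classifier


def route(c: Classification) -> str:
    """Return the chain to run; unclear or unmapped intents take the fallback path."""
    if c.confidence < CONFIDENCE_THRESHOLD:
        return FALLBACK_CHAIN
    return INTENT_TO_CHAIN.get(c.intent, FALLBACK_CHAIN)


print(route(Classification("recommend", 0.92)))  # manga_recommendation
print(route(Classification("search", 0.40)))     # complex_query
print(route(Classification("order", 0.95)))      # complex_query
```

Keeping the fallback as a real chain, rather than an error, means a misclassified query still gets an answer, just through the slower decompose path.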

Chain Architecture Overview

graph TB
    subgraph Input["Chain Input"]
        UM[User Message]
        CTX[Session Context<br/>from DynamoDB]
        PREF[User Preferences<br/>from Redis]
    end

    subgraph ChainTypes["Chain Patterns"]
        SC[Sequential Chain<br/>Step A → B → C]
        PC[Parallel Chain<br/>Step A ∥ B → Merge]
        CC[Conditional Chain<br/>If A then B else C]
        MR[Map-Reduce Chain<br/>Split → Map → Reduce]
    end

    subgraph Execution["Chain Execution Engine"]
        CE[Chain Executor]
        SM[State Manager]
        BM[Budget Manager]
        TK[Token Tracker]
    end

    subgraph Output["Chain Output"]
        RS[Response Synthesis]
        QC[Quality Check]
        WS[WebSocket Delivery]
    end

    UM --> CE
    CTX --> CE
    PREF --> CE
    CE --> SC
    CE --> PC
    CE --> CC
    CE --> MR
    SC --> SM
    PC --> SM
    CC --> SM
    MR --> SM
    SM --> BM
    BM --> TK
    SM --> RS
    RS --> QC
    QC --> WS

    style CE fill:#232f3e,color:#ff9900
    style SM fill:#339af0,color:#fff
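In the architecture above, the Budget Manager sits between state updates and the Token Tracker. One way to factor it out is sketched below as a hypothetical `BudgetManager`, which is not part of the MangaAssist code. Its defaults mirror the executor's limits of 2800 ms, 4096 tokens, and $0.05. The clock is injectable, so the arithmetic can be tested without real waiting. The key rule is that a step gets either its own timeout or whatever is left of the chain budget, whichever is smaller.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class BudgetManager:
    """Hypothetical per-execution budget tracker (time, tokens, cost)."""
    time_budget_ms: float = 2800.0
    token_budget: int = 4096
    cost_ceiling_usd: float = 0.05
    clock: Callable[[], float] = time.monotonic  # seconds; injectable for tests
    tokens_used: int = 0
    cost_usd: float = 0.0
    _start: float = field(init=False)

    def __post_init__(self) -> None:
        self._start = self.clock()

    def elapsed_ms(self) -> float:
        return (self.clock() - self._start) * 1000

    def step_timeout_ms(self, step_timeout_ms: float) -> float:
        """A step may use its own timeout, capped at what is left of the chain budget."""
        remaining = self.time_budget_ms - self.elapsed_ms()
        return max(0.0, min(step_timeout_ms, remaining))

    def record(self, tokens: int, cost_usd: float) -> None:
        """Add one step's usage to the running totals."""
        self.tokens_used += tokens
        self.cost_usd += cost_usd

    def exhausted(self) -> str | None:
        """Return which budget stops the chain ('time', 'tokens', 'cost'), or None."""
        if self.elapsed_ms() >= self.time_budget_ms:
            return "time"
        if self.tokens_used >= self.token_budget:
            return "tokens"
        if self.cost_usd >= self.cost_ceiling_usd:
            return "cost"
        return None


# Demo with a fake clock: 2.5 s into a 2.8 s budget, a 2000 ms step gets 300 ms
now = [0.0]
budget = BudgetManager(clock=lambda: now[0])
now[0] = 2.5
print(budget.step_timeout_ms(2000.0))  # 300.0
```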

Chain Builder — Fluent API

"""
Prompt Chain Builder for MangaAssist.
Fluent API for constructing sequential, parallel, conditional, and map-reduce chains.
"""

import asyncio
import copy
import json
import logging
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Optional

logger = logging.getLogger(__name__)


class ChainStepType(Enum):
    PROMPT = "prompt"          # Call an FM model
    TOOL = "tool"              # Call an external tool (OpenSearch, DynamoDB)
    TRANSFORM = "transform"    # Pure function transformation
    CONDITIONAL = "conditional" # Branch based on condition
    PARALLEL = "parallel"      # Execute sub-steps concurrently
    MAP_REDUCE = "map_reduce"  # Split → parallel map → reduce


@dataclass
class ChainState:
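    """State passed between chain steps.

    Treated as immutable by convention: with_update() returns a deep copy
    with the changes applied instead of mutating the current state.
    """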
    chain_id: str
    step_index: int = 0
    data: dict[str, Any] = field(default_factory=dict)
    history: list[dict[str, Any]] = field(default_factory=list)
    tokens_used: int = 0
    cost_usd: float = 0.0
    elapsed_ms: float = 0.0
    errors: list[str] = field(default_factory=list)

    def with_update(self, **kwargs: Any) -> "ChainState":
        """Create a new state with updated fields (immutable pattern)."""
        new_state = copy.deepcopy(self)
        for key, value in kwargs.items():
            if key == "data" and isinstance(value, dict):
                new_state.data.update(value)
            elif hasattr(new_state, key):
                setattr(new_state, key, value)
        return new_state


@dataclass
class ChainStepConfig:
    """Configuration for a single chain step."""
    name: str
    step_type: ChainStepType
    model_id: str | None = None
    prompt_template: str | None = None
    tool_name: str | None = None
    handler: Callable | None = None
    timeout_ms: float = 2000.0
    max_retries: int = 1
    max_tokens: int = 512
    temperature: float = 0.3
    condition: Callable[[ChainState], bool] | None = None
    sub_steps: list["ChainStepConfig"] = field(default_factory=list)
    # Map-reduce specific
    splitter: Callable[[ChainState], list[Any]] | None = None
    reducer: Callable[[list[Any]], Any] | None = None


@dataclass
class ChainStepResult:
    """Result from executing a single chain step."""
    step_name: str
    output: Any
    tokens_in: int = 0
    tokens_out: int = 0
    cost_usd: float = 0.0
    latency_ms: float = 0.0
    model_id: str | None = None
    retries: int = 0
    error: str | None = None


class ChainBuilder:
    """
    Fluent API for building prompt chains in MangaAssist.

    Usage:
        chain = (
            ChainBuilder("manga_search")
            .add_prompt("classify_intent", model="haiku", template="...")
            .add_conditional(
                "route",
                condition=lambda s: "search" in str(s.data.get("classify_intent", "")),
            )
            .add_parallel("enrich", sub_steps=[...])
            .add_prompt("synthesize", model="sonnet", template="...")
            .build()
        )
    """

    MODEL_ALIASES = {
        "sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
        "haiku": "anthropic.claude-3-haiku-20240307-v1:0",
    }

    def __init__(self, chain_name: str):
        self.chain_name = chain_name
        self.steps: list[ChainStepConfig] = []

    def _resolve_model(self, model: str) -> str:
        return self.MODEL_ALIASES.get(model, model)

    def add_prompt(
        self,
        name: str,
        template: str,
        model: str = "haiku",
        max_tokens: int = 512,
        temperature: float = 0.3,
        timeout_ms: float = 2000.0,
    ) -> "ChainBuilder":
        """Add an FM prompt step to the chain."""
        self.steps.append(ChainStepConfig(
            name=name,
            step_type=ChainStepType.PROMPT,
            model_id=self._resolve_model(model),
            prompt_template=template,
            max_tokens=max_tokens,
            temperature=temperature,
            timeout_ms=timeout_ms,
        ))
        return self

    def add_tool(
        self,
        name: str,
        tool_name: str,
        handler: Callable,
        timeout_ms: float = 1000.0,
    ) -> "ChainBuilder":
        """Add a tool invocation step (e.g., OpenSearch, DynamoDB)."""
        self.steps.append(ChainStepConfig(
            name=name,
            step_type=ChainStepType.TOOL,
            tool_name=tool_name,
            handler=handler,
            timeout_ms=timeout_ms,
        ))
        return self

    def add_transform(
        self,
        name: str,
        handler: Callable[[ChainState], Any],
    ) -> "ChainBuilder":
        """Add a pure transformation step (no FM call)."""
        self.steps.append(ChainStepConfig(
            name=name,
            step_type=ChainStepType.TRANSFORM,
            handler=handler,
        ))
        return self

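    def add_conditional(
        self,
        name: str,
        condition: Callable[[ChainState], bool],
        if_true: ChainStepConfig | None = None,
        if_false: ChainStepConfig | None = None,
    ) -> "ChainBuilder":
        """Add a conditional branching step.

        The executor runs sub_steps[0] when the condition is true and
        sub_steps[1] when it is false, so a missing true branch is filled
        with a no-op transform to keep if_false in the second slot.
        """
        sub_steps: list[ChainStepConfig] = []
        if if_true is not None or if_false is not None:
            sub_steps.append(if_true if if_true is not None else ChainStepConfig(
                name=f"{name}_noop",
                step_type=ChainStepType.TRANSFORM,
                handler=lambda state: None,
            ))
        if if_false is not None:
            sub_steps.append(if_false)
        self.steps.append(ChainStepConfig(
            name=name,
            step_type=ChainStepType.CONDITIONAL,
            condition=condition,
            sub_steps=sub_steps,
        ))
        return self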

    def add_parallel(
        self,
        name: str,
        sub_steps: list[ChainStepConfig],
        timeout_ms: float = 2000.0,
    ) -> "ChainBuilder":
        """Add parallel execution of multiple steps."""
        self.steps.append(ChainStepConfig(
            name=name,
            step_type=ChainStepType.PARALLEL,
            sub_steps=sub_steps,
            timeout_ms=timeout_ms,
        ))
        return self

    def add_map_reduce(
        self,
        name: str,
        splitter: Callable[[ChainState], list[Any]],
        map_step: ChainStepConfig,
        reducer: Callable[[list[Any]], Any],
        timeout_ms: float = 3000.0,
    ) -> "ChainBuilder":
        """Add a map-reduce pattern step."""
        self.steps.append(ChainStepConfig(
            name=name,
            step_type=ChainStepType.MAP_REDUCE,
            splitter=splitter,
            reducer=reducer,
            sub_steps=[map_step],
            timeout_ms=timeout_ms,
        ))
        return self

    def build(self) -> "PromptChain":
        """Build and return the configured prompt chain."""
        return PromptChain(
            name=self.chain_name,
            steps=list(self.steps),
        )


class PromptChain:
    """A compiled prompt chain ready for execution."""

    def __init__(self, name: str, steps: list[ChainStepConfig]):
        self.name = name
        self.steps = steps
        self.chain_id = str(uuid.uuid4())[:8]

    def describe(self) -> dict[str, Any]:
        """Return a human-readable description of the chain."""
        return {
            "name": self.name,
            "chain_id": self.chain_id,
            "steps": [
                {
                    "name": s.name,
                    "type": s.step_type.value,
                    "model": s.model_id,
                    "timeout_ms": s.timeout_ms,
                }
                for s in self.steps
            ],
        }
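`ChainState` relies on convention for immutability. It is a regular dataclass, so nothing stops code from calling `state.errors.append(...)` directly, and `with_update()` pays for a full `deepcopy` on every step. If stricter guarantees are wanted, a frozen dataclass with a read-only data view is one option. The `FrozenChainState` class below is a hypothetical illustration, not a drop-in replacement. It copies shallowly, so each step must still treat nested values, such as lists of search results, as read-only.

```python
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import FrozenInstanceError, dataclass, field, replace
from types import MappingProxyType
from typing import Any


@dataclass(frozen=True)
class FrozenChainState:
    """Hypothetical frozen alternative to ChainState (illustration only)."""
    chain_id: str
    step_index: int = 0
    # Read-only view: state.data["x"] = 1 raises TypeError
    data: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({}))
    errors: tuple[str, ...] = ()

    def with_data(self, **updates: Any) -> FrozenChainState:
        """Return a new state whose data merges in `updates` (shallow copy)."""
        return replace(self, data=MappingProxyType({**self.data, **updates}))

    def with_error(self, message: str) -> FrozenChainState:
        """Return a new state with `message` appended to errors."""
        return replace(self, errors=(*self.errors, message))


s0 = FrozenChainState(chain_id="a1b2c3d4")
s1 = s0.with_data(user_message="鬼滅の刃 1巻ありますか?")
s2 = s1.with_data(classify_intent='{"intent": "search"}').with_error("demo")
print(dict(s0.data), s2.step_index, s2.errors)  # {} 0 ('demo',)

try:
    s2.step_index = 5
except FrozenInstanceError:
    print("frozen")  # frozen
```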

Chain Execution Engine

graph TB
    subgraph Engine["Chain Execution Engine"]
        EX[Executor<br/>Step-by-step runner]
        SM[State Manager<br/>Immutable state updates]
        BM[Budget Manager<br/>Time & token tracking]
        ER[Error Handler<br/>Retry & fallback]
    end

    subgraph StepExecution["Step Execution"]
        PE[Prompt Executor<br/>Bedrock invoke_model]
        TE[Tool Executor<br/>OpenSearch / DynamoDB]
        XF[Transform Executor<br/>Pure functions]
        CE[Conditional Executor<br/>Branch selection]
        PX[Parallel Executor<br/>asyncio.gather]
        MX[Map-Reduce Executor<br/>Split / Map / Reduce]
    end

    EX --> PE
    EX --> TE
    EX --> XF
    EX --> CE
    EX --> PX
    EX --> MX

    PE --> SM
    TE --> SM
    XF --> SM
    CE --> SM
    PX --> SM
    MX --> SM

    SM --> BM
    BM -->|Over Budget| ER
    PE -->|Error| ER
    TE -->|Error| ER

    style EX fill:#232f3e,color:#ff9900
    style BM fill:#ff6b6b,color:#fff

Chain Executor Implementation

"""
Chain Execution Engine for MangaAssist prompt chains.
Handles all chain step types with budget management and error recovery.
"""


class ChainExecutor:
    """
    Executes prompt chains with budget management, error handling,
    and full observability.
    """

    def __init__(
        self,
        bedrock_client: Any,
        tool_registry: dict[str, Callable] | None = None,
        total_budget_ms: float = 2800.0,
        total_token_budget: int = 4096,
        cost_ceiling_usd: float = 0.05,
    ):
        # bedrock_client is assumed to be an async wrapper around the
        # bedrock-runtime InvokeModel API, returning {"text": ..., "usage": ...}
        self.bedrock_client = bedrock_client
        self.tool_registry = tool_registry or {}
        self.total_budget_ms = total_budget_ms
        self.total_token_budget = total_token_budget
        self.cost_ceiling_usd = cost_ceiling_usd

    async def execute(
        self, chain: PromptChain, initial_data: dict[str, Any]
    ) -> tuple[ChainState, list[ChainStepResult]]:
        """Execute a full prompt chain and return final state + results.

        Results are kept in a local list rather than on self, so one executor
        instance can serve many concurrent sessions without mixing their steps.
        """
        state = ChainState(chain_id=chain.chain_id, data=dict(initial_data))
        step_results: list[ChainStepResult] = []
        chain_start = time.monotonic()

        for i, step_config in enumerate(chain.steps):
            state = state.with_update(step_index=i)
            elapsed = (time.monotonic() - chain_start) * 1000

            # Budget checks: time, tokens, and cost are enforced before every step
            stop_reason = None
            if elapsed >= self.total_budget_ms:
                stop_reason = f"Time budget exhausted at step {i}/{len(chain.steps)} ({elapsed:.0f}ms)"
            elif state.tokens_used >= self.total_token_budget:
                stop_reason = f"Token budget {self.total_token_budget} exhausted at step {i}"
            elif state.cost_usd >= self.cost_ceiling_usd:
                stop_reason = f"Cost ceiling ${self.cost_ceiling_usd} reached (${state.cost_usd:.4f})"
            if stop_reason:
                logger.warning("Chain '%s' stopped: %s", chain.name, stop_reason)
                state = state.with_update(errors=[*state.errors, stop_reason])
                break

            remaining_ms = self.total_budget_ms - elapsed
            try:
                result = await self._execute_step(
                    step_config, state, min(remaining_ms, step_config.timeout_ms)
                )
            except Exception as e:
                logger.error("Step '%s' failed: %s", step_config.name, e)
                result = ChainStepResult(
                    step_name=step_config.name, output=None, error=str(e)
                )
            step_results.append(result)

            # Record failures (including timeouts, which return rather than raise)
            # and continue: downstream steps can inspect state.errors
            errors = state.errors
            if result.error:
                errors = [*errors, f"Step '{step_config.name}' failed: {result.error}"]

            # Merge the step output into a new state object
            state = state.with_update(
                data={step_config.name: result.output},
                tokens_used=state.tokens_used + result.tokens_in + result.tokens_out,
                cost_usd=state.cost_usd + result.cost_usd,
                elapsed_ms=(time.monotonic() - chain_start) * 1000,
                errors=errors,
                history=[*state.history, {
                    "step": step_config.name,
                    "type": step_config.step_type.value,
                    "latency_ms": result.latency_ms,
                    "error": result.error,
                }],
            )

        state = state.with_update(
            elapsed_ms=(time.monotonic() - chain_start) * 1000
        )
        return state, step_results

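    async def _execute_step(
        self,
        config: ChainStepConfig,
        state: ChainState,
        timeout_ms: float,
    ) -> ChainStepResult:
        """Execute a single chain step, retrying failures up to config.max_retries.

        The timeout covers all attempts together, so retries can never push a
        step past its share of the chain's time budget.
        """
        start = time.monotonic()
        retries = 0

        async def run_with_retries() -> ChainStepResult:
            nonlocal retries
            while True:
                try:
                    return await self._dispatch_step(config, state)
                except Exception as e:
                    if retries >= config.max_retries:
                        raise
                    retries += 1
                    logger.info(
                        "Retrying step '%s' after error: %s (retry %d/%d)",
                        config.name, e, retries, config.max_retries,
                    )

        try:
            result = await asyncio.wait_for(
                run_with_retries(), timeout=timeout_ms / 1000.0
            )
        except asyncio.TimeoutError:
            return ChainStepResult(
                step_name=config.name,
                output=None,
                latency_ms=(time.monotonic() - start) * 1000,
                retries=retries,
                error=f"Step timed out after {timeout_ms:.0f}ms",
            )

        result.latency_ms = (time.monotonic() - start) * 1000
        result.retries = retries
        return result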

    async def _dispatch_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Route step execution based on step type."""
        if config.step_type == ChainStepType.PROMPT:
            return await self._execute_prompt_step(config, state)
        elif config.step_type == ChainStepType.TOOL:
            return await self._execute_tool_step(config, state)
        elif config.step_type == ChainStepType.TRANSFORM:
            return await self._execute_transform_step(config, state)
        elif config.step_type == ChainStepType.CONDITIONAL:
            return await self._execute_conditional_step(config, state)
        elif config.step_type == ChainStepType.PARALLEL:
            return await self._execute_parallel_step(config, state)
        elif config.step_type == ChainStepType.MAP_REDUCE:
            return await self._execute_map_reduce_step(config, state)
        else:
            raise ValueError(f"Unknown step type: {config.step_type}")

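    # USD per 1M tokens as (input, output), matching the Claude 3 pricing above
    MODEL_PRICING = {
        "anthropic.claude-3-sonnet-20240229-v1:0": (3.00, 15.00),
        "anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25),
    }

    async def _execute_prompt_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Execute a prompt step by calling Bedrock."""
        # Render {{key}} placeholders from state data. Non-string values are
        # serialized as JSON; ensure_ascii=False keeps Japanese titles readable.
        prompt = config.prompt_template or ""
        for key, value in state.data.items():
            placeholder = f"{{{{{key}}}}}"
            if placeholder in prompt:
                rendered = value if isinstance(value, str) else json.dumps(
                    value, ensure_ascii=False, default=str
                )
                prompt = prompt.replace(placeholder, rendered)

        # Assumed async wrapper signature; forwards the per-step temperature
        response = await self.bedrock_client.invoke(
            model_id=config.model_id,
            prompt=prompt,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
        )

        usage = response.get("usage", {})
        tokens_in = usage.get("input_tokens", 0)
        tokens_out = usage.get("output_tokens", 0)

        # Bedrock reports token usage, not cost, so price the call locally;
        # unknown model IDs are priced at zero
        price_in, price_out = self.MODEL_PRICING.get(config.model_id or "", (0.0, 0.0))
        cost_usd = (tokens_in * price_in + tokens_out * price_out) / 1_000_000

        return ChainStepResult(
            step_name=config.name,
            output=response.get("text", ""),
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost_usd=cost_usd,
            model_id=config.model_id,
        )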

    async def _execute_tool_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Execute a tool step (OpenSearch, DynamoDB, etc.)."""
        handler = config.handler or self.tool_registry.get(config.tool_name or "")
        if not handler:
            raise ValueError(f"No handler for tool: {config.tool_name}")

        result = await handler(state.data)
        return ChainStepResult(
            step_name=config.name,
            output=result,
        )

    async def _execute_transform_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Execute a pure transformation step."""
        if config.handler is None:
            raise ValueError(f"No handler for transform: {config.name}")
        result = config.handler(state)
        return ChainStepResult(
            step_name=config.name,
            output=result,
        )

    async def _execute_conditional_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Execute a conditional branching step."""
        if config.condition is None:
            raise ValueError(f"No condition for conditional step: {config.name}")

        branch = config.condition(state)
        if branch and len(config.sub_steps) > 0:
            return await self._dispatch_step(config.sub_steps[0], state)
        elif not branch and len(config.sub_steps) > 1:
            return await self._dispatch_step(config.sub_steps[1], state)
        else:
            return ChainStepResult(
                step_name=config.name,
                output=None,
                error="No matching branch",
            )

    async def _execute_parallel_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Execute sub-steps concurrently and merge outputs keyed by sub-step name."""
        # Enforce each sub-step's own timeout so one slow call cannot hold up the rest
        tasks = [
            asyncio.wait_for(
                self._dispatch_step(sub, state), timeout=sub.timeout_ms / 1000.0
            )
            for sub in config.sub_steps
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        outputs: dict[str, Any] = {}
        total_tokens_in = 0
        total_tokens_out = 0
        total_cost = 0.0

        for sub, r in zip(config.sub_steps, results):
            if isinstance(r, BaseException):
                # Timeouts carry an empty message, so fall back to the type name
                outputs[sub.name] = {"error": str(r) or type(r).__name__}
            else:
                outputs[sub.name] = r.output
                total_tokens_in += r.tokens_in
                total_tokens_out += r.tokens_out
                total_cost += r.cost_usd

        return ChainStepResult(
            step_name=config.name,
            output=outputs,
            tokens_in=total_tokens_in,
            tokens_out=total_tokens_out,
            cost_usd=total_cost,
        )

    async def _execute_map_reduce_step(
        self, config: ChainStepConfig, state: ChainState
    ) -> ChainStepResult:
        """Execute a map-reduce pattern step."""
        if not config.splitter or not config.reducer:
            raise ValueError("Map-reduce step requires splitter and reducer")
        if not config.sub_steps:
            raise ValueError("Map-reduce step requires a map sub-step")
        map_step = config.sub_steps[0]

        # Split
        chunks = config.splitter(state)
        if not chunks:
            return ChainStepResult(step_name=config.name, output=None)

        # Map (parallel): each call sees the shared state plus its own 'chunk'
        # key and is bounded by the map step's timeout
        map_tasks = [
            asyncio.wait_for(
                self._dispatch_step(map_step, state.with_update(data={"chunk": chunk})),
                timeout=map_step.timeout_ms / 1000.0,
            )
            for chunk in chunks
        ]
        map_results = await asyncio.gather(*map_tasks, return_exceptions=True)
        succeeded = [r for r in map_results if not isinstance(r, BaseException)]

        # Failed or timed-out chunks become None; the reducer decides how to treat them
        map_outputs = [
            None if isinstance(r, BaseException) else r.output for r in map_results
        ]

        # Reduce
        reduced = config.reducer(map_outputs)

        return ChainStepResult(
            step_name=config.name,
            output=reduced,
            tokens_in=sum(r.tokens_in for r in succeeded),
            tokens_out=sum(r.tokens_out for r in succeeded),
            cost_usd=sum(r.cost_usd for r in succeeded),
        )
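Both the parallel and map-reduce executors hand every sub-task to `asyncio.gather` at once. A map step that fans out over a long genre list can therefore launch dozens of simultaneous Bedrock calls from a single request, which risks throttling at 1M messages/day. A common mitigation is to cap in-flight calls with `asyncio.Semaphore`. The `bounded_gather` helper below is a hypothetical sketch of that approach. It keeps `gather`'s result ordering and `return_exceptions=True` behavior. The limit of 2 and the fake map function are for illustration only.

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Iterable
from typing import TypeVar

T = TypeVar("T")


async def bounded_gather(aws: Iterable[Awaitable[T]], limit: int) -> list[T | BaseException]:
    """Run awaitables with at most `limit` in flight; results keep input order.

    Exceptions are returned in place, like asyncio.gather(return_exceptions=True).
    """
    semaphore = asyncio.Semaphore(limit)

    async def run(aw: Awaitable[T]) -> T:
        async with semaphore:
            return await aw

    return await asyncio.gather(*(run(aw) for aw in aws), return_exceptions=True)


# Demo: five hypothetical per-genre map calls, at most two at a time
stats = {"in_flight": 0, "peak": 0}


async def fake_map(genre: str) -> str:
    stats["in_flight"] += 1
    stats["peak"] = max(stats["peak"], stats["in_flight"])
    await asyncio.sleep(0.01)  # stands in for a Bedrock call
    stats["in_flight"] -= 1
    return genre.upper()


genres = ["shonen", "shojo", "seinen", "josei", "isekai"]
results = asyncio.run(bounded_gather([fake_map(g) for g in genres], limit=2))
print(results, stats["peak"])  # ['SHONEN', 'SHOJO', 'SEINEN', 'JOSEI', 'ISEKAI'] 2
```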

Production Chain Examples for MangaAssist

graph LR
    subgraph SearchChain["Manga Search Chain"]
        S1[Classify Intent<br/>Haiku 200ms]
        S2[Generate Embedding<br/>Bedrock Titan 100ms]
        S3[Vector Search<br/>OpenSearch 300ms]
        S4[Format Response<br/>Haiku 400ms]
        S1 --> S2 --> S3 --> S4
    end

    subgraph RecommendChain["Recommendation Chain"]
        R1[Fetch Preferences<br/>DynamoDB 50ms]
        R2[Fetch History<br/>DynamoDB 50ms]
        R3[Generate Recs<br/>Sonnet 1500ms]
        R4[Validate & Format<br/>Haiku 300ms]
        R1 --> R3
        R2 --> R3
        R3 --> R4
    end

    subgraph ComplexChain["Complex Multi-Intent Chain"]
        C1[Decompose Query<br/>Haiku 300ms]
        C2[Parallel: Search + Order<br/>800ms max]
        C3[Merge Results<br/>Transform 10ms]
        C4[Synthesize Response<br/>Haiku 400ms]
        C1 --> C2 --> C3 --> C4
    end

    style S1 fill:#51cf66,color:#fff
    style R3 fill:#339af0,color:#fff
    style C2 fill:#ffd43b,color:#333
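The per-step latencies in these three chains combine by a simple rule: sequential steps add up, while a parallel block costs only its slowest branch. The small sketch below uses hypothetical `Step`/`Seq`/`Par` types to compute the critical path from the diagram's numbers. This lets a chain definition be checked against the executor's 2800 ms budget before deployment. The complex chain's parallel block is modeled as a single 800 ms step because the diagram gives only its cap.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Step:
    name: str
    latency_ms: float


@dataclass(frozen=True)
class Seq:
    parts: tuple[Node, ...]  # run one after another


@dataclass(frozen=True)
class Par:
    parts: tuple[Node, ...]  # run concurrently


Node = Union[Step, Seq, Par]


def critical_path_ms(node: Node) -> float:
    """Sequential parts add up; a parallel block costs only its slowest branch."""
    if isinstance(node, Step):
        return node.latency_ms
    if isinstance(node, Seq):
        return sum(critical_path_ms(p) for p in node.parts)
    return max(critical_path_ms(p) for p in node.parts)


# Latencies taken from the Production Chain Examples diagram
search = Seq((Step("classify_intent", 200), Step("embed", 100),
              Step("vector_search", 300), Step("format_response", 400)))
recommend = Seq((Par((Step("fetch_preferences", 50), Step("fetch_history", 50))),
                 Step("generate_recs", 1500), Step("validate_format", 300)))
complex_query = Seq((Step("decompose", 300), Step("search_and_order", 800),
                     Step("merge", 10), Step("synthesize", 400)))

for name, chain in [("search", search), ("recommend", recommend), ("complex", complex_query)]:
    print(name, critical_path_ms(chain))
# search 1000
# recommend 1850
# complex 1510
```

All three fit inside the 2800 ms executor budget. The recommendation chain has the least headroom because its Sonnet step dominates.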

MangaAssist Chain Definitions

"""
Production prompt chain definitions for MangaAssist chatbot.
Pre-built chains for common manga store operations.
"""


class MangaChains:
    """Factory for MangaAssist prompt chains."""

    @staticmethod
    def search_chain() -> PromptChain:
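        """Build the manga search chain: classify → extract terms → search → format.

        Query embedding (Titan in the diagram) is assumed to happen inside the
        OpenSearch tool handler; there is no separate embedding step here.
        """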
        return (
            ChainBuilder("manga_search")
            .add_prompt(
                name="classify_intent",
                model="haiku",
                template=(
                    "Classify this manga store query. Return JSON with "
                    "'intent' (search/recommend/order/chat), 'language' (ja/en), "
                    "'search_terms' (array).\n\nQuery: {{user_message}}"
                ),
                max_tokens=150,
                timeout_ms=500,
            )
            .add_transform(
                name="extract_search_terms",
                handler=lambda state: _safe_json_parse(
                    state.data.get("classify_intent", "{}"),
                    default={"search_terms": [state.data.get("user_message", "")]},
                ),
            )
            .add_tool(
                name="vector_search",
                tool_name="opensearch",
                handler=_opensearch_search_handler,
                timeout_ms=500,
            )
            .add_prompt(
                name="format_response",
                model="haiku",
                template=(
                    "Format these manga search results into a helpful response. "
                    "Match the user's language. Include title (JP + romaji), "
                    "author, price, availability.\n\n"
                    "User query: {{user_message}}\n"
                    "Search results: {{vector_search}}\n"
                    "Respond naturally and suggest related titles."
                ),
                max_tokens=500,
                timeout_ms=800,
            )
            .build()
        )

    @staticmethod
    def recommendation_chain() -> PromptChain:
        """Build the recommendation chain with parallel preference fetching."""
        return (
            ChainBuilder("manga_recommendation")
            .add_parallel(
                name="fetch_user_data",
                sub_steps=[
                    ChainStepConfig(
                        name="fetch_preferences",
                        step_type=ChainStepType.TOOL,
                        tool_name="dynamodb_preferences",
                        handler=_fetch_preferences_handler,
                        timeout_ms=200,
                    ),
                    ChainStepConfig(
                        name="fetch_history",
                        step_type=ChainStepType.TOOL,
                        tool_name="dynamodb_history",
                        handler=_fetch_history_handler,
                        timeout_ms=200,
                    ),
                ],
                timeout_ms=300,
            )
            .add_prompt(
                name="generate_recommendations",
                model="sonnet",
                template=(
                    "You are a manga recommendation expert for a Japanese manga store. "
                    "Based on the user's preferences and history, suggest 3-5 titles.\n\n"
                    "User request: {{user_message}}\n"
                    "Preferences and purchase history: {{fetch_user_data}}\n\n"
                    "For each recommendation include:\n"
                    "- Title in Japanese and English\n"
                    "- Author name\n"
                    "- Why they'd enjoy it\n"
                    "- Price and availability\n"
                    "Respond in the user's language."
                ),
                max_tokens=800,
                timeout_ms=2000,
            )
            .add_prompt(
                name="quality_check",
                model="haiku",
                template=(
                    "Check this manga recommendation response for quality. "
                    "Verify: (1) has 3+ titles, (2) includes Japanese names, "
                    "(3) no hallucinated prices, (4) matches user language.\n\n"
                    "Response: {{generate_recommendations}}\n\n"
                    "If quality is acceptable, return the response unchanged. "
                    "If not, fix the issues and return the corrected version."
                ),
                max_tokens=800,
                timeout_ms=500,
            )
            .build()
        )

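    @staticmethod
    def complex_query_chain() -> PromptChain:
        """Build the complex multi-intent chain with parallel sub-tasks."""
        return (
            ChainBuilder("complex_query")
            .add_prompt(
                name="decompose",
                model="haiku",
                template=(
                    "Decompose this complex manga store query into sub-tasks. "
                    "Return JSON array of objects with 'task' and 'type' "
                    "(search/recommend/order/info).\n\n"
                    "Query: {{user_message}}"
                ),
                max_tokens=200,
                timeout_ms=500,
            )
            .add_transform(
                name="plan_parallel",
                handler=lambda state: _safe_json_parse(
                    state.data.get("decompose", "[]"),
                    default=[{"task": state.data.get("user_message", ""), "type": "chat"}],
                ),
            )
            # Fan out the planned sub-tasks concurrently. Here each sub-task is
            # answered by a Haiku prompt; dispatching 'search' and 'order' tasks
            # to the OpenSearch/DynamoDB tools is not shown.
            .add_map_reduce(
                name="execute_subtasks",
                splitter=lambda state: [
                    t for t in (state.data.get("plan_parallel") or [])
                    if isinstance(t, dict)
                ],
                map_step=ChainStepConfig(
                    name="answer_subtask",
                    step_type=ChainStepType.PROMPT,
                    model_id="anthropic.claude-3-haiku-20240307-v1:0",
                    prompt_template=(
                        "Answer this manga store sub-task concisely.\n\n"
                        "Customer query: {{user_message}}\n"
                        "Sub-task: {{chunk}}"
                    ),
                    max_tokens=300,
                    timeout_ms=800,
                ),
                reducer=lambda results: [r for r in results if r is not None],
                timeout_ms=900,
            )
            .add_prompt(
                name="synthesize",
                model="haiku",
                template=(
                    "Synthesize a unified response from these sub-task results "
                    "for a manga store customer.\n\n"
                    "Original query: {{user_message}}\n"
                    "Sub-task results: {{execute_subtasks}}\n\n"
                    "Provide a coherent, helpful response."
                ),
                max_tokens=600,
                timeout_ms=800,
            )
            .build()
        )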

    @staticmethod
    def map_reduce_catalog_chain() -> PromptChain:
        """Map-reduce chain for processing large catalog queries."""
        return (
            ChainBuilder("catalog_analysis")
            .add_map_reduce(
                name="analyze_genres",
                splitter=lambda state: state.data.get("genres", []),
                map_step=ChainStepConfig(
                    name="analyze_single_genre",
                    step_type=ChainStepType.PROMPT,
                    model_id="anthropic.claude-3-haiku-20240307-v1:0",
                    prompt_template=(
                        "Summarize the top 3 manga in the '{{chunk}}' genre. "
                        "Include titles, authors, and sales rank."
                    ),
                    max_tokens=300,
                    timeout_ms=800,
                ),
                reducer=lambda results: "\n\n".join(
                    str(r) for r in results if r is not None
                ),
                timeout_ms=2000,
            )
            .add_prompt(
                name="final_summary",
                model="haiku",
                template=(
                    "Create a comprehensive catalog summary from these genre analyses:\n\n"
                    "{{analyze_genres}}\n\n"
                    "Highlight bestsellers and new releases."
                ),
                max_tokens=500,
                timeout_ms=800,
            )
            .build()
        )


# --- Helper functions ---

def _safe_json_parse(text: Any, default: Any = None) -> Any:
    """Safely parse JSON from FM output, returning default on failure.

    Non-string input (e.g. None when the upstream step timed out) also
    returns the default instead of raising TypeError.
    """
    if not isinstance(text, str):
        return default
    try:
        # Handle markdown code blocks
        if "```json" in text:
            text = text.split("```json")[1].split("```")[0]
        elif "```" in text:
            text = text.split("```")[1].split("```")[0]
        return json.loads(text.strip())
    except (json.JSONDecodeError, IndexError):
        logger.warning("Failed to parse JSON from FM output: %s...", text[:100])
        return default


async def _opensearch_search_handler(data: dict[str, Any]) -> list[dict]:
    """OpenSearch search handler (production would use real client)."""
    return [{"title": "Search result", "source": "opensearch"}]


async def _fetch_preferences_handler(data: dict[str, Any]) -> dict:
    """DynamoDB preferences handler."""
    return {"genres": ["action", "romance"], "language": "ja"}


async def _fetch_history_handler(data: dict[str, Any]) -> list[dict]:
    """DynamoDB purchase history handler."""
    return [{"title": "鬼滅の刃", "volume": 1}]
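The executor that runs the `map_reduce` step is not shown in this section. The sketch below is one way it could work, assuming plain `asyncio`: each map call runs concurrently under its own timeout (the `timeout_ms=800` on the Haiku map step), the whole fan-out runs under an outer timeout (the `timeout_ms=2000` on the map-reduce step), and a failed or timed-out call becomes `None`, which the reducer's `if r is not None` filter then drops. `map_reduce`, `join_reducer`, and `fake_genre_summary` are hypothetical names; the fake mapper stands in for a Bedrock call.

```python
import asyncio
from typing import Any, Awaitable, Callable, Sequence


async def _guarded(coro: Awaitable[Any], timeout_s: float) -> Any:
    """Await one map call; a timeout or error yields None instead of failing the chain."""
    try:
        return await asyncio.wait_for(coro, timeout=timeout_s)
    except Exception:  # asyncio.TimeoutError is an Exception subclass
        return None


async def map_reduce(
    items: Sequence[Any],
    mapper: Callable[[Any], Awaitable[Any]],
    reducer: Callable[[list[Any]], Any],
    per_item_timeout_s: float,
    total_timeout_s: float,
) -> Any:
    """Run mapper over all items concurrently, then hand the results to reducer."""
    calls = [_guarded(mapper(item), per_item_timeout_s) for item in items]
    try:
        # gather preserves input order, so results line up with items.
        results = await asyncio.wait_for(asyncio.gather(*calls), timeout=total_timeout_s)
    except asyncio.TimeoutError:
        results = []  # the whole map phase overran: reduce nothing rather than hang
    return reducer(list(results))


def join_reducer(results: list[Any]) -> str:
    """Same reduction as the chain above: drop failed map calls, join the rest."""
    return "\n\n".join(str(r) for r in results if r is not None)


async def fake_genre_summary(genre: str) -> str:
    """Hypothetical stand-in for the Haiku map step; 'horror' simulates a slow call."""
    await asyncio.sleep(0.5 if genre == "horror" else 0.01)
    return f"summary for {genre}"


if __name__ == "__main__":
    out = asyncio.run(
        map_reduce(["action", "romance", "horror"], fake_genre_summary,
                   join_reducer, per_item_timeout_s=0.1, total_timeout_s=2.0)
    )
    print(out)  # the "horror" call times out and is dropped by the reducer
```

The same fan-out fits the recommendation chain's parallel lookups: passing the preferences and history handlers' coroutines to `asyncio.gather` makes the wait roughly the slower lookup's latency rather than the sum of both.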

Chain Observability

graph TB
    subgraph Traces["Chain Execution Traces"]
        T1[Step 1: classify_intent<br/>Haiku | 180ms | 50 tokens]
        T2[Step 2: vector_search<br/>OpenSearch | 250ms]
        T3[Step 3: format_response<br/>Haiku | 350ms | 120 tokens]
        T1 --> T2 --> T3
    end

    subgraph Metrics["Chain Metrics"]
        TL[Total Latency<br/>780ms]
        TK[Total Tokens<br/>170]
        CT[Total Cost<br/>$0.000043]
        SR[Success Rate<br/>99.2%]
    end

    subgraph Alerts["Chain Alerts"]
        LA[Latency Alert<br/>P95 > 2500ms]
        CA[Cost Alert<br/>> $0.02 per chain]
        EA[Error Alert<br/>Step failure rate > 5%]
    end

    T3 --> TL
    T3 --> TK
    T3 --> CT
    TL --> LA
    CT --> CA
    T3 --> SR
    SR --> EA

    style TL fill:#51cf66,color:#fff
    style LA fill:#ff6b6b,color:#fff
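The diagram's total cost can be checked against the Haiku prices in the MangaAssist context. Its $0.000043 for 170 tokens matches billing every token at the input rate (170 × $0.25 / 1M ≈ $0.0000425). Output tokens cost five times as much ($1.25 / 1M), so the real figure depends on how the 50 and 120 tokens split between input and output, which the diagram does not state. A minimal sketch of the per-step arithmetic, using hypothetical helper names `step_cost_usd` and `chain_cost_usd`:

```python
# Per-1M-token Bedrock prices quoted in the MangaAssist context (USD).
PRICES_PER_1M = {
    "haiku": {"input": 0.25, "output": 1.25},
    "sonnet": {"input": 3.00, "output": 15.00},
}


def step_cost_usd(model: str, tokens_in: int, tokens_out: int) -> float:
    """Cost of one FM call; input and output tokens are billed at different rates."""
    price = PRICES_PER_1M[model]
    return (tokens_in * price["input"] + tokens_out * price["output"]) / 1_000_000


def chain_cost_usd(steps: list[tuple[str, int, int]]) -> float:
    """Sum FM step costs; non-FM steps such as vector_search add no token cost."""
    return sum(step_cost_usd(model, t_in, t_out) for model, t_in, t_out in steps)


if __name__ == "__main__":
    # The diagram's token counts (50 + 120), read two ways because the split is not given.
    all_input = chain_cost_usd([("haiku", 50, 0), ("haiku", 120, 0)])
    all_output = chain_cost_usd([("haiku", 0, 50), ("haiku", 0, 120)])
    print(f"all input: ${all_input:.7f}  all output: ${all_output:.7f}")
```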

Chain Metrics Collector

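"""
Observability and metrics collection for MangaAssist prompt chains.
Tracks latency, tokens, cost, and errors per chain and step.
"""

from __future__ import annotations

import statistics
from dataclasses import dataclass, field
from typing import Any

# ChainState and ChainStepResult come from the chain executor module; import them
# from there. With postponed annotations, this module still loads on its own
# because those types appear only in type hints.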


@dataclass
class ChainMetrics:
    """Aggregated metrics for a prompt chain."""
    chain_name: str
    execution_count: int = 0
    success_count: int = 0
    error_count: int = 0
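    # Raw samples accumulate for the life of the process; at 1M messages/day,
    # flush or window them periodically to bound memory.
    latencies_ms: list[float] = field(default_factory=list)
    token_counts: list[int] = field(default_factory=list)
    costs_usd: list[float] = field(default_factory=list)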
    step_metrics: dict[str, dict[str, Any]] = field(default_factory=dict)


class ChainMetricsCollector:
    """Collects and aggregates metrics across chain executions."""

    def __init__(self):
        self._metrics: dict[str, ChainMetrics] = {}

    def record_execution(
        self,
        chain_name: str,
        state: ChainState,
        results: list[ChainStepResult],
    ) -> None:
        """Record metrics from a chain execution."""
        if chain_name not in self._metrics:
            self._metrics[chain_name] = ChainMetrics(chain_name=chain_name)

        metrics = self._metrics[chain_name]
        metrics.execution_count += 1

        has_error = any(r.error for r in results)
        if has_error:
            metrics.error_count += 1
        else:
            metrics.success_count += 1

        metrics.latencies_ms.append(state.elapsed_ms)
        metrics.token_counts.append(state.tokens_used)
        metrics.costs_usd.append(state.cost_usd)

        # Per-step metrics
        for result in results:
            if result.step_name not in metrics.step_metrics:
                metrics.step_metrics[result.step_name] = {
                    "latencies": [],
                    "errors": 0,
                    "tokens": [],
                }
            step_m = metrics.step_metrics[result.step_name]
            step_m["latencies"].append(result.latency_ms)
            step_m["tokens"].append(result.tokens_in + result.tokens_out)
            if result.error:
                step_m["errors"] += 1

    def get_summary(self, chain_name: str) -> dict[str, Any] | None:
        """Get aggregated metrics summary for a chain."""
        metrics = self._metrics.get(chain_name)
        if not metrics or metrics.execution_count == 0:
            return None

        sorted_lat = sorted(metrics.latencies_ms)
        # Nearest-rank P95: index ceil(0.95 * n) - 1, via integer ceiling division.
        p95_idx = -(-len(sorted_lat) * 95 // 100) - 1

        return {
            "chain_name": chain_name,
            "executions": metrics.execution_count,
            "success_rate": round(
                metrics.success_count / metrics.execution_count, 3
            ),
            "latency": {
                "avg_ms": round(statistics.mean(metrics.latencies_ms), 2),
                "p50_ms": round(statistics.median(metrics.latencies_ms), 2),
                "p95_ms": round(sorted_lat[p95_idx], 2),
                "max_ms": round(max(metrics.latencies_ms), 2),
            },
            "tokens": {
                "avg_per_chain": round(statistics.mean(metrics.token_counts)),
                "total": sum(metrics.token_counts),
            },
            "cost": {
                "avg_per_chain_usd": round(
                    statistics.mean(metrics.costs_usd), 6
                ),
                "total_usd": round(sum(metrics.costs_usd), 4),
                # Assumes every one of the 1M messages/day runs this chain.
                "projected_daily_usd": round(
                    statistics.mean(metrics.costs_usd) * 1_000_000, 2
                ),
            },
            "step_breakdown": {
                name: {
                    "avg_latency_ms": round(statistics.mean(m["latencies"]), 2) if m["latencies"] else 0,
                    "error_rate": round(m["errors"] / max(len(m["latencies"]), 1), 3),
                    "avg_tokens": round(statistics.mean(m["tokens"])) if m["tokens"] else 0,
                }
                for name, m in metrics.step_metrics.items()
            },
        }

    def get_all_summaries(self) -> list[dict[str, Any]]:
        """Get summaries for all tracked chains."""
        return [
            s for name in self._metrics
            if (s := self.get_summary(name)) is not None
        ]
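The collector only aggregates; the diagram's alerts still need a rule that reads those aggregates. A minimal sketch, assuming alerts are evaluated periodically against the dict returned by `get_summary()` and using the diagram's thresholds (P95 > 2500 ms, more than $0.02 per chain, step failure rate above 5%). `evaluate_alerts` and the threshold constants are hypothetical names.

```python
from typing import Any

# Thresholds from the "Chain Alerts" diagram.
P95_LATENCY_MS = 2500.0
MAX_COST_PER_CHAIN_USD = 0.02
MAX_STEP_FAILURE_RATE = 0.05


def evaluate_alerts(summary: dict[str, Any]) -> list[str]:
    """Return the names of alerts a get_summary()-style dict should fire."""
    alerts: list[str] = []
    if summary["latency"]["p95_ms"] > P95_LATENCY_MS:
        alerts.append("latency")
    # get_summary() reports only the average chain cost, so it stands in for "per chain".
    if summary["cost"]["avg_per_chain_usd"] > MAX_COST_PER_CHAIN_USD:
        alerts.append("cost")
    for step_name, step in summary["step_breakdown"].items():
        if step["error_rate"] > MAX_STEP_FAILURE_RATE:
            alerts.append(f"error:{step_name}")
    return alerts


if __name__ == "__main__":
    # Illustrative input shaped like get_summary() output (not measured data).
    sample = {
        "latency": {"p95_ms": 2600.0},
        "cost": {"avg_per_chain_usd": 0.000043},
        "step_breakdown": {
            "classify_intent": {"error_rate": 0.0},
            "vector_search": {"error_rate": 0.08},
        },
    }
    print(evaluate_alerts(sample))
```

Evaluating per step rather than on the chain-level success rate lets the error alert name the failing step directly.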

Key Takeaways

| # | Takeaway | MangaAssist Application |
|---|----------|-------------------------|
| 1 | The Chain Builder fluent API makes complex prompt orchestration readable and maintainable. Each step has explicit type, model, and timeout. | MangaAssist defines 4 standard chains (search, recommend, complex, map-reduce) as reusable building blocks. |
| 2 | Sequential chains are the default for simple queries (classify → search → format), keeping latency predictable. | The search chain completes in ~780ms total: 180ms classify + 250ms OpenSearch + 350ms format. |
| 3 | Parallel chains cut latency for independent sub-tasks, such as fetching preferences and history concurrently. | The recommendation chain fetches user data in parallel (300ms total vs. 400ms sequential), then uses Sonnet for generation. |
| 4 | Conditional chains enable intent-based routing without wasting tokens on unnecessary steps. | If intent is "greeting", skip OpenSearch and go directly to Haiku, saving 500ms and 50+ tokens. |
| 5 | Map-reduce chains handle catalog-wide operations by splitting across genres and reducing to a summary. | Analyzing top manga across 5 genres runs 5 parallel Haiku calls, then merges the results, staying within the 3-second budget. |
| 6 | Budget management (time, tokens, cost) must be enforced per step and per chain to prevent runaway costs. | The chain executor stops execution if elapsed time exceeds 2800ms or cost exceeds $0.05 per chain. |
| 7 | Chain observability provides per-step latency and token breakdowns, enabling targeted optimization. | Step-level metrics reveal that vector_search is the bottleneck at P95, leading to OpenSearch index tuning. |
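Takeaway 6 describes the executor's stop conditions, but the executor code is not part of this section. A minimal sketch of a per-chain budget object, assuming the executor checks it before each step and clamps each step's own timeout to the time remaining; `ChainBudget` and `BudgetExceeded` are hypothetical names, and the 2800 ms / $0.05 defaults come from the takeaway.

```python
import time
from dataclasses import dataclass, field


class BudgetExceeded(Exception):
    """Raised when a chain has already overrun its time or cost budget."""


@dataclass
class ChainBudget:
    max_elapsed_ms: float = 2800.0
    max_cost_usd: float = 0.05
    started_at: float = field(default_factory=time.monotonic)
    cost_usd: float = 0.0

    @property
    def elapsed_ms(self) -> float:
        return (time.monotonic() - self.started_at) * 1000

    def charge(self, cost_usd: float) -> None:
        """Record the cost of a completed step."""
        self.cost_usd += cost_usd

    def check(self, next_step: str) -> None:
        """Call before each step; stop the chain if either ceiling is exceeded."""
        if self.elapsed_ms > self.max_elapsed_ms:
            raise BudgetExceeded(f"time budget exceeded before {next_step}: {self.elapsed_ms:.0f} ms")
        if self.cost_usd > self.max_cost_usd:
            raise BudgetExceeded(f"cost budget exceeded before {next_step}: ${self.cost_usd:.4f}")

    def step_timeout_ms(self, requested_ms: float) -> float:
        """Clamp a step's configured timeout to whatever time the chain has left."""
        return max(0.0, min(requested_ms, self.max_elapsed_ms - self.elapsed_ms))


if __name__ == "__main__":
    budget = ChainBudget()
    for step, requested_ms, cost in [("classify_intent", 800, 0.00001), ("format_response", 800, 0.00003)]:
        budget.check(step)
        timeout_ms = budget.step_timeout_ms(requested_ms)
        # Invoke the step with timeout_ms here (omitted in this sketch).
        budget.charge(cost)
    print(f"{budget.elapsed_ms:.1f} ms, ${budget.cost_usd:.5f}")
```

Clamping the step timeout is what makes the budget per-step as well as per-chain: a step that starts at 2500 ms gets at most 300 ms, even if its configured `timeout_ms` is 800.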