
07: POC Implementations

AIP-C01 Mapping

Task 5.2 → All Skills (5.2.1–5.2.5): Five proof-of-concept implementations demonstrating key troubleshooting capabilities. Each POC targets one skill area with deployable code.


POC 1: Context Window Overflow Detector (Skill 5.2.1)

Objective

Detect and alert when prompt assembly approaches or exceeds the token budget, before the FM silently truncates content.

Architecture

flowchart LR
    A[Chat Request] --> B[Orchestrator]
    B --> C[Token Budget<br>Calculator]
    C --> D{Budget<br>Status}
    D -->|Under 80%| E[Normal path]
    D -->|80-95%| F[Compress +<br>log warning]
    D -->|Over 95%| G[Emergency trim +<br>emit alarm]

    F --> H[CloudWatch Metric:<br>BudgetUtilization]
    G --> H
    H --> I[CloudWatch Alarm:<br>HighBudgetUtilization]

Full Working Code

"""
POC 1: Context Window Overflow Detector
Deployable as a Lambda or ECS sidecar.
Monitors token budget utilization and emits CloudWatch metrics.
"""

import json
import time
import logging
import boto3
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")

MODEL_TOKEN_LIMITS = {
    "anthropic.claude-3-5-sonnet-20241022-v2:0": 200_000,
    "anthropic.claude-3-haiku-20240307-v1:0": 200_000,
}
PRACTICAL_LIMIT_RATIO = 0.025  # Use 2.5% of model limit as practical budget (5000 tokens)


@dataclass
class BudgetCheckResult:
    total_tokens: int
    budget_limit: int
    utilization_pct: float
    status: str  # "normal", "warning", "critical"
    largest_section: str
    largest_section_tokens: int
    recommendation: Optional[str] = None


def estimate_tokens(text: str) -> int:
    """Character-based estimation — 4 chars per token for English, 2 for Japanese."""
    if not text:
        return 0
    # Detect Japanese content
    jp_chars = sum(1 for c in text if '\u3040' <= c <= '\u9fff')
    jp_ratio = jp_chars / len(text) if text else 0
    chars_per_token = 2 if jp_ratio > 0.3 else 4
    return len(text) // chars_per_token


def check_budget(
    system_prompt: str,
    conversation_history: list,
    rag_context: str,
    user_message: str,
    model_id: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
) -> BudgetCheckResult:
    """Check token budget utilization for a prompt assembly."""

    model_limit = MODEL_TOKEN_LIMITS.get(model_id, 200_000)
    practical_limit = int(model_limit * PRACTICAL_LIMIT_RATIO)

    sections = {
        "system_prompt": estimate_tokens(system_prompt),
        "conversation_history": sum(estimate_tokens(msg.get("content", "")) for msg in conversation_history),
        "rag_context": estimate_tokens(rag_context),
        "user_message": estimate_tokens(user_message),
        "output_reserve": 1500,  # Reserved for response
    }

    total = sum(sections.values())
    utilization = total / practical_limit if practical_limit > 0 else 1.0

    # Find largest section
    content_sections = {k: v for k, v in sections.items() if k != "output_reserve"}
    largest = max(content_sections, key=content_sections.get)

    if utilization > 0.95:
        status = "critical"
        recommendation = (
            f"CRITICAL: {utilization:.0%} of budget used. "
            f"Largest section '{largest}' ({content_sections[largest]} tokens). "
            "Compress conversation history or reduce RAG context."
        )
    elif utilization > 0.80:
        status = "warning"
        recommendation = (
            f"WARNING: {utilization:.0%} of budget used. "
            f"Consider compressing '{largest}' section."
        )
    else:
        status = "normal"
        recommendation = None

    result = BudgetCheckResult(
        total_tokens=total,
        budget_limit=practical_limit,
        utilization_pct=round(utilization * 100, 1),
        status=status,
        largest_section=largest,
        largest_section_tokens=content_sections[largest],
        recommendation=recommendation,
    )

    # Emit CloudWatch metrics: a dimensioned series for drill-down, plus an
    # undimensioned copy so a single alarm can watch overall utilization
    # (alarms match on the exact dimension set).
    cloudwatch.put_metric_data(
        Namespace="MangaAssist/ContentHandling",
        MetricData=[
            {
                "MetricName": "BudgetUtilization",
                "Value": result.utilization_pct,
                "Unit": "Percent",
                "Dimensions": [
                    {"Name": "Status", "Value": status},
                    {"Name": "LargestSection", "Value": largest},
                ],
            },
            {
                "MetricName": "BudgetUtilization",
                "Value": result.utilization_pct,
                "Unit": "Percent",
            },
        ],
    )

    if status != "normal":
        logger.warning(json.dumps({
            "log_type": "budget_check",
            "status": status,
            "utilization_pct": result.utilization_pct,
            "total_tokens": total,
            "budget_limit": practical_limit,
            "sections": sections,
            "recommendation": recommendation,
        }))

    return result


# — Lambda handler for testing —
def lambda_handler(event, context):
    result = check_budget(
        system_prompt=event.get("system_prompt", ""),
        conversation_history=event.get("conversation_history", []),
        rag_context=event.get("rag_context", ""),
        user_message=event.get("user_message", ""),
        model_id=event.get("model_id", "anthropic.claude-3-5-sonnet-20241022-v2:0"),
    )
    return {
        "statusCode": 200,
        "body": json.dumps({
            "status": result.status,
            "utilization_pct": result.utilization_pct,
            "total_tokens": result.total_tokens,
            "budget_limit": result.budget_limit,
            "recommendation": result.recommendation,
        }),
    }

Deployment Steps

  1. Package as Lambda with boto3 (already in Lambda runtime)
  2. Add IAM permission: cloudwatch:PutMetricData
  3. Create CloudWatch alarm: BudgetUtilization > 95 for 1 datapoint → SNS topic → PagerDuty (sketched below)
  4. Integrate into orchestrator: call check_budget() before every FM invocation
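
A minimal sketch of the alarm in step 3. It targets the undimensioned BudgetUtilization copy emitted above; the SNS topic ARN is a placeholder for whatever topic is wired to PagerDuty:

import boto3

cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")

# Placeholder ARN; replace with the SNS topic subscribed by PagerDuty.
ONCALL_TOPIC_ARN = "arn:aws:sns:ap-northeast-1:123456789012:mangaassist-oncall"

cloudwatch.put_metric_alarm(
    AlarmName="MangaAssist-HighBudgetUtilization",
    Namespace="MangaAssist/ContentHandling",
    MetricName="BudgetUtilization",   # the undimensioned copy
    Statistic="Maximum",
    Period=60,
    EvaluationPeriods=1,
    Threshold=95,
    ComparisonOperator="GreaterThanThreshold",
    TreatMissingData="notBreaching",
    AlarmActions=[ONCALL_TOPIC_ARN],
)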

Expected Output

{
  "status": "warning",
  "utilization_pct": 87.3,
  "total_tokens": 4365,
  "budget_limit": 5000,
  "recommendation": "WARNING: 87% of budget used. Consider compressing 'conversation_history' section."
}

POC 2: Bedrock API Health Dashboard (Skill 5.2.2)

Objective

Real-time dashboard showing Bedrock API health: error rates, latency percentiles, throttle rates, and circuit breaker state.

Architecture

flowchart TD
    A[Bedrock Calls] --> B[Structured Logger]
    B --> C[CloudWatch Logs]
    C --> D[Logs Insights<br>Queries]
    D --> E[CloudWatch Dashboard]

    B --> F[CloudWatch Metrics]
    F --> E

    F --> G[CloudWatch Alarms]
    G --> H[SNS → PagerDuty]

    subgraph Dashboard Widgets
        I[Error Rate<br>Line Chart]
        J[Latency P50/P95/P99<br>Line Chart]
        K[Throttle Count<br>Bar Chart]
        L[Circuit Breaker State<br>Status Widget]
        M[Model Tier Usage<br>Pie Chart]
    end

Full Working Code — Metric Emitter

"""
POC 2: Bedrock API Health Metrics Emitter
Wraps Bedrock calls and emits standardized CloudWatch metrics.
"""

import json
import time
import logging
import boto3
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger(__name__)

cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")
NAMESPACE = "MangaAssist/BedrockHealth"


@dataclass
class BedrockHealthMetrics:
    """Metrics collected per Bedrock call."""
    model_id: str
    intent: str
    latency_ms: float
    status: str  # success, error, throttled, timeout
    error_code: Optional[str] = None
    input_tokens: int = 0
    output_tokens: int = 0
    retry_count: int = 0
    circuit_state: str = "closed"


def emit_health_metrics(metrics: BedrockHealthMetrics):
    """Emit a batch of CloudWatch metrics for one Bedrock call."""

    dimensions = [
        {"Name": "ModelId", "Value": metrics.model_id.split("/")[-1]},
        {"Name": "Intent", "Value": metrics.intent},
    ]

    metric_data = [
        {
            "MetricName": "InvocationLatency",
            "Value": metrics.latency_ms,
            "Unit": "Milliseconds",
            "Dimensions": dimensions,
        },
        {
            "MetricName": "InvocationCount",
            "Value": 1,
            "Unit": "Count",
            "Dimensions": dimensions,
        },
        {
            "MetricName": "InputTokens",
            "Value": metrics.input_tokens,
            "Unit": "Count",
            "Dimensions": dimensions,
        },
        {
            "MetricName": "OutputTokens",
            "Value": metrics.output_tokens,
            "Unit": "Count",
            "Dimensions": dimensions,
        },
    ]

    # Status-specific metrics
    if metrics.status == "success":
        metric_data.append({
            "MetricName": "SuccessCount",
            "Value": 1, "Unit": "Count",
            "Dimensions": dimensions,
        })
    elif metrics.status == "throttled":
        metric_data.append({
            "MetricName": "ThrottleCount",
            "Value": 1, "Unit": "Count",
            "Dimensions": dimensions,
        })
    elif metrics.status in ("error", "timeout"):
        # Base-dimension copy so error-rate math against InvocationCount works,
        # plus a per-error-code copy for drill-down.
        metric_data.append({
            "MetricName": "ErrorCount",
            "Value": 1, "Unit": "Count",
            "Dimensions": dimensions,
        })
        metric_data.append({
            "MetricName": "ErrorCount",
            "Value": 1, "Unit": "Count",
            "Dimensions": [*dimensions, {"Name": "ErrorCode", "Value": metrics.error_code or "Unknown"}],
        })

    if metrics.retry_count > 0:
        metric_data.append({
            "MetricName": "RetryCount",
            "Value": metrics.retry_count,
            "Unit": "Count",
            "Dimensions": dimensions,
        })

    # Circuit breaker state (as a number for graphing: 0=closed, 1=half_open, 2=open)
    state_values = {"closed": 0, "half_open": 1, "open": 2}
    metric_data.append({
        "MetricName": "CircuitBreakerState",
        "Value": state_values.get(metrics.circuit_state, 0),
        "Unit": "None",
        "Dimensions": [{"Name": "ModelId", "Value": metrics.model_id.split("/")[-1]}],
    })

    cloudwatch.put_metric_data(Namespace=NAMESPACE, MetricData=metric_data)


# — CloudWatch Dashboard Definition (JSON) —
DASHBOARD_BODY = {
    "widgets": [
        {
            "type": "metric",
            "properties": {
                "title": "Bedrock Error Rate (%)",
                "metrics": [
                    [NAMESPACE, "ErrorCount", {"stat": "Sum", "period": 60}],
                    [NAMESPACE, "InvocationCount", {"stat": "Sum", "period": 60}],
                ],
                "view": "timeSeries",
                "period": 60,
            },
        },
        {
            "type": "metric",
            "properties": {
                "title": "Latency (P50 / P95 / P99)",
                "metrics": [
                    [NAMESPACE, "InvocationLatency", {"stat": "p50"}],
                    [NAMESPACE, "InvocationLatency", {"stat": "p95"}],
                    [NAMESPACE, "InvocationLatency", {"stat": "p99"}],
                ],
                "view": "timeSeries",
                "period": 300,
            },
        },
        {
            "type": "metric",
            "properties": {
                "title": "Throttle Events",
                "metrics": [[NAMESPACE, "ThrottleCount", {"stat": "Sum"}]],
                "view": "bar",
                "period": 300,
            },
        },
        {
            "type": "metric",
            "properties": {
                "title": "Circuit Breaker State",
                "metrics": [[NAMESPACE, "CircuitBreakerState"]],
                "view": "timeSeries",
                "period": 60,
                "annotations": {
                    "horizontal": [
                        {"value": 2, "label": "OPEN", "color": "#d62728"},
                        {"value": 1, "label": "HALF_OPEN", "color": "#ff7f0e"},
                    ]
                },
            },
        },
    ],
}


def create_dashboard():
    """Create the Bedrock Health dashboard in CloudWatch."""
    cw = boto3.client("cloudwatch", region_name="ap-northeast-1")
    cw.put_dashboard(
        DashboardName="MangaAssist-BedrockHealth",
        DashboardBody=json.dumps(DASHBOARD_BODY),
    )
    logger.info("Dashboard 'MangaAssist-BedrockHealth' created")

Deployment Steps

  1. Integrate emit_health_metrics() into BedrockClientWrapper after each call
  2. Run create_dashboard() once to provision the CloudWatch dashboard
  3. Create alarms: ErrorCount/InvocationCount > 5% → PagerDuty (sketched below), ThrottleCount > 10/min → Slack
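
A hedged sketch of the error-rate alarm in step 3 using CloudWatch metric math. It assumes one alarm per (ModelId, Intent) pair and relies on the base-dimension ErrorCount copy emitted above; the dimension values and topic ARN are placeholders:

import boto3

cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")

dimensions = [
    {"Name": "ModelId", "Value": "anthropic.claude-3-5-sonnet-20241022-v2:0"},  # placeholder
    {"Name": "Intent", "Value": "product_recommendation"},                      # placeholder
]


def metric_stat(metric_name: str) -> dict:
    """Helper for a per-minute Sum of one BedrockHealth metric."""
    return {
        "Metric": {
            "Namespace": "MangaAssist/BedrockHealth",
            "MetricName": metric_name,
            "Dimensions": dimensions,
        },
        "Period": 60,
        "Stat": "Sum",
    }


cloudwatch.put_metric_alarm(
    AlarmName="MangaAssist-BedrockErrorRate-recommendation",
    EvaluationPeriods=5,
    DatapointsToAlarm=3,
    Threshold=5.0,
    ComparisonOperator="GreaterThanThreshold",
    TreatMissingData="notBreaching",
    AlarmActions=["arn:aws:sns:ap-northeast-1:123456789012:mangaassist-oncall"],  # placeholder
    Metrics=[
        {"Id": "error_rate", "Expression": "100 * errors / invocations",
         "Label": "ErrorRate", "ReturnData": True},
        {"Id": "errors", "MetricStat": metric_stat("ErrorCount"), "ReturnData": False},
        {"Id": "invocations", "MetricStat": metric_stat("InvocationCount"), "ReturnData": False},
    ],
)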

POC 3: Prompt A/B Testing Framework (Skill 5.2.3)

Objective

Run controlled A/B tests comparing two prompt versions on live traffic with statistical significance testing.

Architecture

flowchart TD
    A[Chat Request] --> B{A/B Assignment<br>Hash session_id}
    B -->|Group A 50%| C[Prompt v1<br>Current]
    B -->|Group B 50%| D[Prompt v2<br>Candidate]

    C --> E[FM Response A]
    D --> F[FM Response B]

    E --> G[Score + Log]
    F --> G

    G --> H[DynamoDB:<br>Experiment Results]
    H --> I[Analysis Lambda<br>nightly]
    I --> J{Significant<br>Result?}
    J -->|p < 0.05| K[Promote winner /<br>reject loser]
    J -->|p ≥ 0.05| L[Continue experiment<br>or extend]

Full Working Code

"""
POC 3: Prompt A/B Testing Framework
Session-based deterministic assignment with statistical significance testing.
"""

import hashlib
import json
import time
import math
import logging
from dataclasses import dataclass, field
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class Experiment:
    experiment_id: str
    prompt_a_version: str
    prompt_b_version: str
    traffic_split: float = 0.5  # Fraction going to B
    status: str = "running"     # running, concluded, cancelled
    created_at: float = field(default_factory=time.time)

    # Results — populated during analysis
    sample_a: int = 0
    sample_b: int = 0
    mean_score_a: float = 0.0
    mean_score_b: float = 0.0
    p_value: Optional[float] = None
    winner: Optional[str] = None


@dataclass
class ABObservation:
    session_id: str
    experiment_id: str
    group: str  # "A" or "B"
    prompt_version: str
    quality_score: float  # 0.0 to 1.0 composite score
    latency_ms: float
    timestamp: float = field(default_factory=time.time)


class PromptABTester:
    """Deterministic A/B testing for prompt versions.

    Assignment is hash-based on session_id for consistency:
    same user always sees the same version within one experiment.
    """

    def __init__(self):
        self.experiments: dict = {}
        self.observations: list = []

    def create_experiment(
        self,
        experiment_id: str,
        prompt_a_version: str,
        prompt_b_version: str,
        traffic_split: float = 0.5,
    ) -> Experiment:
        exp = Experiment(
            experiment_id=experiment_id,
            prompt_a_version=prompt_a_version,
            prompt_b_version=prompt_b_version,
            traffic_split=traffic_split,
        )
        self.experiments[experiment_id] = exp
        return exp

    def assign_group(self, experiment_id: str, session_id: str) -> str:
        """Deterministically assign a session to group A or B."""
        exp = self.experiments.get(experiment_id)
        if not exp or exp.status != "running":
            return "A"  # Default to control

        hash_input = f"{experiment_id}:{session_id}"
        hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:8], 16)
        normalized = hash_value / 0xFFFFFFFF

        return "B" if normalized < exp.traffic_split else "A"

    def get_prompt_version(self, experiment_id: str, session_id: str) -> str:
        """Get the prompt version for this session."""
        exp = self.experiments.get(experiment_id)
        if not exp:
            raise KeyError(f"Unknown experiment: {experiment_id}")
        group = self.assign_group(experiment_id, session_id)
        return exp.prompt_b_version if group == "B" else exp.prompt_a_version

    def record_observation(self, observation: ABObservation):
        """Record an A/B test observation."""
        self.observations.append(observation)

    def analyze(self, experiment_id: str) -> dict:
        """Analyze experiment results with Welch's t-test."""
        exp = self.experiments.get(experiment_id)
        if not exp:
            return {"error": "Experiment not found"}

        obs_a = [o for o in self.observations if o.experiment_id == experiment_id and o.group == "A"]
        obs_b = [o for o in self.observations if o.experiment_id == experiment_id and o.group == "B"]

        if len(obs_a) < 30 or len(obs_b) < 30:
            return {
                "status": "insufficient_data",
                "sample_a": len(obs_a),
                "sample_b": len(obs_b),
                "minimum_required": 30,
            }

        scores_a = [o.quality_score for o in obs_a]
        scores_b = [o.quality_score for o in obs_b]

        mean_a = sum(scores_a) / len(scores_a)
        mean_b = sum(scores_b) / len(scores_b)
        var_a = sum((x - mean_a) ** 2 for x in scores_a) / (len(scores_a) - 1)
        var_b = sum((x - mean_b) ** 2 for x in scores_b) / (len(scores_b) - 1)

        # Welch's t-test
        se = math.sqrt(var_a / len(scores_a) + var_b / len(scores_b))
        if se == 0:
            return {"status": "no_variance", "mean_a": mean_a, "mean_b": mean_b}

        t_stat = (mean_b - mean_a) / se

        # Welch-Satterthwaite degrees of freedom
        num = (var_a / len(scores_a) + var_b / len(scores_b)) ** 2
        denom = (
            (var_a / len(scores_a)) ** 2 / (len(scores_a) - 1) +
            (var_b / len(scores_b)) ** 2 / (len(scores_b) - 1)
        )
        df = num / denom if denom > 0 else 1

        # Approximate p-value using normal distribution for large samples
        p_value = 2 * (1 - self._normal_cdf(abs(t_stat)))

        # Update experiment
        exp.sample_a = len(obs_a)
        exp.sample_b = len(obs_b)
        exp.mean_score_a = round(mean_a, 4)
        exp.mean_score_b = round(mean_b, 4)
        exp.p_value = round(p_value, 6)

        if p_value < 0.05:
            exp.winner = "B" if mean_b > mean_a else "A"
        else:
            exp.winner = None

        return {
            "experiment_id": experiment_id,
            "sample_a": len(obs_a),
            "sample_b": len(obs_b),
            "mean_score_a": exp.mean_score_a,
            "mean_score_b": exp.mean_score_b,
            "score_delta": round(mean_b - mean_a, 4),
            "t_statistic": round(t_stat, 3),
            "degrees_of_freedom": round(df, 1),
            "p_value": exp.p_value,
            "significant": p_value < 0.05,
            "winner": exp.winner,
            "recommendation": self._recommendation(exp),
        }

    def _recommendation(self, exp: Experiment) -> str:
        if exp.p_value is None:
            return "Insufficient data — continue experiment"
        if exp.p_value >= 0.05:
            return "Not significant — continue experiment or declare tie"
        if exp.winner == "B":
            return f"PROMOTE version {exp.prompt_b_version} (p={exp.p_value})"
        return f"KEEP version {exp.prompt_a_version} (p={exp.p_value})"

    @staticmethod
    def _normal_cdf(x: float) -> float:
        """Approximate standard normal CDF using Abramowitz and Stegun formula."""
        t = 1.0 / (1.0 + 0.2316419 * abs(x))
        d = 0.3989422804014327  # 1/sqrt(2*pi)
        p = d * math.exp(-x * x / 2.0) * (
            t * (0.319381530 + t * (-0.356563782 + t * (1.781477937 + t * (-1.821255978 + t * 1.330274429))))
        )
        return 1.0 - p if x > 0 else p

Deployment Steps

  1. Deploy experiment configuration to DynamoDB (experiment table)
  2. Integrate assign_group() and get_prompt_version() into orchestrator
  3. Integrate record_observation() after response scoring
  4. Schedule nightly Lambda to run analyze() and emit results to CloudWatch + Slack (end-to-end usage sketched below)
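
An end-to-end usage sketch for steps 2–4. It runs in memory with simulated scores; a real deployment would persist experiments and observations in the DynamoDB table from step 1:

import json
import random

tester = PromptABTester()
tester.create_experiment(
    experiment_id="exp-recommend-001",
    prompt_a_version="recommend_v2.1.0",
    prompt_b_version="recommend_v2.2.0-candidate",
)

# Orchestrator side: pick the prompt version per session, then record the score.
for i in range(200):
    session_id = f"sess-{i:05d}"
    group = tester.assign_group("exp-recommend-001", session_id)
    version = tester.get_prompt_version("exp-recommend-001", session_id)

    # ... invoke the FM with `version`, score the response ...
    simulated_score = random.gauss(0.78 if group == "B" else 0.74, 0.08)

    tester.record_observation(ABObservation(
        session_id=session_id,
        experiment_id="exp-recommend-001",
        group=group,
        prompt_version=version,
        quality_score=max(0.0, min(1.0, simulated_score)),
        latency_ms=1400.0,
    ))

# The nightly analysis Lambda would call this and publish the result.
print(json.dumps(tester.analyze("exp-recommend-001"), indent=2))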

POC 4: Embedding Drift Monitor (Skill 5.2.4)

Objective

Detect when stored vectors in the index no longer match what the current embedding model produces for the same text, indicating the documents need re-embedding.

Architecture

flowchart LR
    A[CloudWatch<br>EventBridge] -->|Daily 3 AM| B[Drift Monitor<br>Lambda]
    B --> C[OpenSearch:<br>Sample 200 docs]
    B --> D[Bedrock Embeddings:<br>Re-embed samples]
    B --> E[Compare:<br>cosine distance]
    E --> F[CloudWatch Metrics]
    F --> G[Alarm: drift > 0.15]
    G --> H[SNS → Ops team]

Full Working Code

"""
POC 4: Embedding Drift Monitor Lambda
Scheduled daily to detect embedding drift in OpenSearch index.
"""

import json
import math
import time
import logging
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Configuration
REGION = "ap-northeast-1"
OS_ENDPOINT = "your-opensearch-endpoint.aoss.amazonaws.com"
INDEX_NAME = "mangaassist-products"
EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"
SAMPLE_SIZE = 200
DRIFT_THRESHOLD = 0.15

# Clients
bedrock = boto3.client("bedrock-runtime", region_name=REGION)
cloudwatch = boto3.client("cloudwatch", region_name=REGION)

credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(
    credentials.access_key, credentials.secret_key,
    REGION, "aoss", session_token=credentials.token,
)

os_client = OpenSearch(
    hosts=[{"host": OS_ENDPOINT, "port": 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)


def embed_text(text: str) -> list:
    """Get fresh embedding from Bedrock Titan Embed."""
    body = json.dumps({"inputText": text[:2000]})
    response = bedrock.invoke_model(
        modelId=EMBEDDING_MODEL_ID,
        contentType="application/json",
        accept="application/json",
        body=body,
    )
    result = json.loads(response["body"].read())
    return result.get("embedding", [])


def cosine_distance(a: list, b: list) -> float:
    """Cosine distance: 0 = identical, 2 = opposite."""
    if len(a) != len(b):
        return 1.0  # Dimension mismatch
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 1.0
    return 1.0 - dot / (norm_a * norm_b)


def sample_documents():
    """Random sample from OpenSearch."""
    body = {
        "size": SAMPLE_SIZE,
        "query": {
            "function_score": {
                "query": {"match_all": {}},
                "random_score": {"seed": int(time.time())},
            }
        },
        "_source": ["text", "embedding", "doc_id"],
    }
    response = os_client.search(index=INDEX_NAME, body=body)
    return [
        {
            "doc_id": hit["_source"].get("doc_id", hit["_id"]),
            "text": hit["_source"].get("text", ""),
            "stored_embedding": hit["_source"].get("embedding", []),
        }
        for hit in response["hits"]["hits"]
    ]


def lambda_handler(event, context):
    """Main drift check handler."""

    docs = sample_documents()
    if not docs:
        logger.warning("No documents sampled")
        return {"statusCode": 200, "body": "No documents to check"}

    distances = []
    drifted_ids = []

    for doc in docs:
        if not doc["text"] or not doc["stored_embedding"]:
            continue

        fresh = embed_text(doc["text"])
        if not fresh:
            continue

        dist = cosine_distance(doc["stored_embedding"], fresh)
        distances.append(dist)

        if dist > DRIFT_THRESHOLD:
            drifted_ids.append(doc["doc_id"])

    if not distances:
        return {"statusCode": 200, "body": "No valid comparisons"}

    distances.sort()
    mean_drift = sum(distances) / len(distances)
    p95_drift = distances[int(len(distances) * 0.95)]
    max_drift = distances[-1]

    # Emit metrics
    cloudwatch.put_metric_data(
        Namespace="MangaAssist/RetrievalHealth",
        MetricData=[
            {"MetricName": "EmbeddingDriftMean", "Value": mean_drift, "Unit": "None"},
            {"MetricName": "EmbeddingDriftP95", "Value": p95_drift, "Unit": "None"},
            {"MetricName": "EmbeddingDriftMax", "Value": max_drift, "Unit": "None"},
            {"MetricName": "DriftedDocumentCount", "Value": len(drifted_ids), "Unit": "Count"},
        ],
    )

    result = {
        "sample_size": len(distances),
        "mean_drift": round(mean_drift, 4),
        "p95_drift": round(p95_drift, 4),
        "max_drift": round(max_drift, 4),
        "drifted_count": len(drifted_ids),
        "alert": p95_drift > DRIFT_THRESHOLD,
    }

    logger.info(json.dumps({"log_type": "drift_report", **result}))

    if result["alert"]:
        logger.warning("DRIFT ALERT: p95_drift=%.4f > threshold=%.2f", p95_drift, DRIFT_THRESHOLD)

    return {"statusCode": 200, "body": json.dumps(result)}

Deployment Steps

  1. Package Lambda with opensearch-py, requests-aws4auth
  2. IAM: bedrock:InvokeModel, aoss:APIAccessAll, cloudwatch:PutMetricData
  3. EventBridge rule: rate(1 day) targeting this Lambda (steps 3–4 sketched below)
  4. CloudWatch alarm: EmbeddingDriftP95 > 0.15 → SNS → ops team
  5. Lambda timeout: 5 minutes (200 embeddings × ~100 ms each ≈ 20 s of Bedrock calls, with headroom for retries and cold start)
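
A sketch of steps 3–4, assuming the Lambda is already deployed; the function ARN and topic ARN below are placeholders:

import boto3

REGION = "ap-northeast-1"
events = boto3.client("events", region_name=REGION)
cloudwatch = boto3.client("cloudwatch", region_name=REGION)

LAMBDA_ARN = "arn:aws:lambda:ap-northeast-1:123456789012:function:drift-monitor"  # placeholder
OPS_TOPIC_ARN = "arn:aws:sns:ap-northeast-1:123456789012:mangaassist-ops"         # placeholder

# Step 3: run the drift monitor daily.
events.put_rule(Name="mangaassist-drift-monitor-daily", ScheduleExpression="rate(1 day)")
events.put_targets(
    Rule="mangaassist-drift-monitor-daily",
    Targets=[{"Id": "drift-monitor-lambda", "Arn": LAMBDA_ARN}],
)
# (The Lambda also needs a resource-based permission allowing events.amazonaws.com to invoke it.)

# Step 4: alert the ops team when p95 drift exceeds the threshold.
cloudwatch.put_metric_alarm(
    AlarmName="MangaAssist-EmbeddingDriftP95",
    Namespace="MangaAssist/RetrievalHealth",
    MetricName="EmbeddingDriftP95",
    Statistic="Maximum",
    Period=86400,
    EvaluationPeriods=1,
    Threshold=0.15,
    ComparisonOperator="GreaterThanThreshold",
    TreatMissingData="notBreaching",
    AlarmActions=[OPS_TOPIC_ARN],
)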

POC 5: Prompt Observability Pipeline (Skill 5.2.5)

Objective

End-to-end observability for prompt execution: from assembly through FM invocation to response validation, with X-Ray traces and CloudWatch metrics.

Architecture

flowchart TD
    A[Request] --> B[Orchestrator<br>with X-Ray SDK]

    subgraph Traced Execution
        B --> C["Span: intent_classification<br>(SageMaker)"]
        C --> D["Span: prompt_assembly<br>(template + context)"]
        D --> E["Span: bedrock_invocation<br>(FM call)"]
        E --> F["Span: response_validation<br>(schema check)"]
        F --> G["Span: response_delivery<br>(WebSocket/HTTP)"]
    end

    G --> H[X-Ray Trace<br>Complete]
    H --> I[X-Ray Console:<br>Service Map + Traces]

    B --> J[CloudWatch Metrics:<br>Per-span latency]
    F --> K[CloudWatch Metrics:<br>Schema compliance]

Full Working Code

"""
POC 5: Prompt Observability Pipeline
Lightweight X-Ray + CloudWatch instrumentation for prompt execution.
"""

import json
import time
import uuid
import logging
import boto3
from contextlib import contextmanager
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

cloudwatch = boto3.client("cloudwatch", region_name="ap-northeast-1")
NAMESPACE = "MangaAssist/PromptObservability"


@dataclass
class Span:
    name: str
    start: float = field(default_factory=time.time)
    end: float = 0.0
    duration_ms: float = 0.0
    success: bool = True
    error: str = ""
    metadata: dict = field(default_factory=dict)

    def close(self, success: bool = True, error: str = ""):
        self.end = time.time()
        self.duration_ms = (self.end - self.start) * 1000
        self.success = success
        self.error = error


@dataclass
class PromptTrace:
    trace_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    session_id: str = ""
    intent: str = ""
    template_name: str = ""
    template_version: str = ""
    spans: list = field(default_factory=list)

    # Quality signals
    schema_valid: bool = True
    hallucination_risk: bool = False
    confusion_detected: bool = False


class PromptObservabilityPipeline:
    """Instruments prompt execution for observability."""

    def __init__(self):
        self._traces: dict = {}

    def start_trace(self, session_id: str, intent: str, template_name: str, template_version: str) -> PromptTrace:
        trace = PromptTrace(
            session_id=session_id,
            intent=intent,
            template_name=template_name,
            template_version=template_version,
        )
        self._traces[trace.trace_id] = trace
        return trace

    @contextmanager
    def span(self, trace: PromptTrace, name: str):
        """Context manager for timing a span."""
        s = Span(name=name)
        trace.spans.append(s)
        try:
            yield s
        except Exception as e:
            s.close(success=False, error=str(e))
            raise
        else:
            s.close(success=True)

    def finish_trace(self, trace: PromptTrace):
        """Finalize trace: emit metrics and structured log."""
        self._traces.pop(trace.trace_id, None)

        total_ms = sum(s.duration_ms for s in trace.spans)

        # Structured log
        log_entry = {
            "log_type": "prompt_trace",
            "trace_id": trace.trace_id,
            "session_id": trace.session_id,
            "intent": trace.intent,
            "template_name": trace.template_name,
            "template_version": trace.template_version,
            "total_duration_ms": round(total_ms, 2),
            "schema_valid": trace.schema_valid,
            "hallucination_risk": trace.hallucination_risk,
            "spans": [
                {
                    "name": s.name,
                    "duration_ms": round(s.duration_ms, 2),
                    "success": s.success,
                    "error": s.error,
                }
                for s in trace.spans
            ],
        }
        logger.info(json.dumps(log_entry))

        # CloudWatch metrics
        dimensions = [
            {"Name": "Intent", "Value": trace.intent},
            {"Name": "Template", "Value": trace.template_name},
        ]

        metric_data = [
            {
                "MetricName": "TotalExecutionDuration",
                "Value": total_ms,
                "Unit": "Milliseconds",
                "Dimensions": dimensions,
            },
        ]

        for s in trace.spans:
            metric_data.append({
                "MetricName": f"Span_{s.name}",
                "Value": s.duration_ms,
                "Unit": "Milliseconds",
                "Dimensions": dimensions,
            })

        if not trace.schema_valid:
            metric_data.append({
                "MetricName": "SchemaViolation", "Value": 1,
                "Unit": "Count", "Dimensions": dimensions,
            })

        cloudwatch.put_metric_data(Namespace=NAMESPACE, MetricData=metric_data)
        return log_entry


# — Example usage —
def handle_chat_request(session_id: str, user_message: str):
    """Example of instrumented chat request handling."""
    pipeline = PromptObservabilityPipeline()
    trace = pipeline.start_trace(
        session_id=session_id,
        intent="product_recommendation",
        template_name="recommend_v2",
        template_version="2.1.0",
    )

    with pipeline.span(trace, "intent_classification") as s:
        # intent = classify_intent(user_message)
        s.metadata["confidence"] = 0.92
        time.sleep(0.05)  # Simulated

    with pipeline.span(trace, "prompt_assembly") as s:
        # prompt = assemble_prompt(intent, user_message, context)
        s.metadata["token_count"] = 3200
        time.sleep(0.01)  # Simulated

    with pipeline.span(trace, "bedrock_invocation") as s:
        # response = bedrock_client.invoke(prompt)
        s.metadata["model_id"] = "claude-3-5-sonnet"
        time.sleep(1.5)  # Simulated

    with pipeline.span(trace, "response_validation") as s:
        # valid = schema_validator.validate(response)
        trace.schema_valid = True
        time.sleep(0.002)  # Simulated

    result = pipeline.finish_trace(trace)
    return result

Deployment Steps

  1. Integrate PromptObservabilityPipeline into the ECS orchestrator service
  2. For full X-Ray integration: add aws-xray-sdk and instrument each span as an X-Ray subsegment (sketched below)
  3. IAM: cloudwatch:PutMetricData, xray:PutTraceSegments
  4. Create CloudWatch dashboard combining span latency charts and schema violation counters
  5. Create Logs Insights saved queries for trace analysis (see the query sketch after the trace output below)
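
A sketch of step 2: a subclass that mirrors each in-process Span as an X-Ray subsegment. It assumes the aws-xray-sdk package is installed, an X-Ray daemon/sidecar is running in the ECS task, and the service framework has already opened a segment (e.g. via xray_recorder middleware):

from contextlib import contextmanager
from aws_xray_sdk.core import xray_recorder


class XRayPromptObservabilityPipeline(PromptObservabilityPipeline):
    """Subclass sketch: same timing as the base class, plus X-Ray subsegments."""

    @contextmanager
    def span(self, trace: PromptTrace, name: str):
        s = Span(name=name)
        trace.spans.append(s)
        # Requires an open segment; otherwise the recorder raises.
        with xray_recorder.in_subsegment(name) as subsegment:
            subsegment.put_annotation("intent", trace.intent)
            subsegment.put_annotation("template_version", trace.template_version)
            try:
                yield s
            except Exception as exc:
                s.close(success=False, error=str(exc))
                raise
            else:
                s.close(success=True)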

Expected Trace Output

{
  "log_type": "prompt_trace",
  "trace_id": "a1b2c3d4-...",
  "session_id": "sess-12345",
  "intent": "product_recommendation",
  "template_name": "recommend_v2",
  "template_version": "2.1.0",
  "total_duration_ms": 1562.3,
  "schema_valid": true,
  "spans": [
    {"name": "intent_classification", "duration_ms": 50.1, "success": true},
    {"name": "prompt_assembly", "duration_ms": 10.2, "success": true},
    {"name": "bedrock_invocation", "duration_ms": 1500.0, "success": true},
    {"name": "response_validation", "duration_ms": 2.0, "success": true}
  ]
}
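
For the saved queries in step 5, a sketch of one Logs Insights query over these prompt_trace logs. The log group name is a placeholder, and it assumes the structured log line reaches CloudWatch as plain JSON so its fields are auto-discovered:

import time
import boto3

logs = boto3.client("logs", region_name="ap-northeast-1")

# Placeholder log group for the orchestrator service.
LOG_GROUP = "/ecs/mangaassist-orchestrator"

QUERY = """
fields @timestamp, intent, template_version, total_duration_ms
| filter log_type = "prompt_trace"
| stats avg(total_duration_ms) as avg_ms, pct(total_duration_ms, 95) as p95_ms by intent, template_version
| sort p95_ms desc
"""

# Query the last 24 hours of prompt traces.
query_id = logs.start_query(
    logGroupName=LOG_GROUP,
    startTime=int(time.time()) - 86400,
    endTime=int(time.time()),
    queryString=QUERY,
)["queryId"]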