
MCP Server Implementation Patterns: Stateless, Stateful, and Connection Management

MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.


Skill Mapping

| Attribute | Detail |
| --- | --- |
| Skill | 2.1.7 |
| Description | Develop model extension frameworks to enhance FM capabilities |
| Sub-focus | MCP client libraries for consistent access, connection management, capability advertisement |
| AWS Services | Lambda, ECS Fargate, Bedrock, API Gateway, ElastiCache Redis, CloudWatch |
| MangaAssist Relevance | Reliable tool connectivity under 1M messages/day, <3s end-to-end latency |

Mind Map

mindmap
  root((MCP Implementation Patterns))
    Stateless MCP
      Request-Reply Model
      No Server-Side Session
      Lambda Natural Fit
      Idempotent Operations
      Cache-Assisted State
    Stateful MCP
      Persistent Connections
      Session Context
      In-Memory Caches
      Connection Affinity
      ECS Natural Fit
    Connection Management
      Connection Pooling
      Keepalive Strategies
      Circuit Breakers
      Retry Policies
      Backpressure
    Capability Advertisement
      Tool Registration
      Schema Validation
      Version Negotiation
      Dynamic Discovery
      Capability Refresh
    Health Monitoring
      Liveness Probes
      Readiness Probes
      Dependency Checks
      Degraded Mode
      Metric Emission

1. Stateless MCP Pattern

Stateless MCP servers process each request independently. The server holds no per-session data between invocations. This is the dominant pattern for Lambda MCP servers, where each invocation may land on a different (or cold) container.

1.1 Stateless Architecture

flowchart LR
    subgraph "Request 1"
        C1[MCP Client] -->|tools/call| L1[Lambda Container A]
        L1 -->|query| DB1[(DynamoDB)]
        DB1 -->|result| L1
        L1 -->|response| C1
    end

    subgraph "Request 2 (same tool, same user)"
        C2[MCP Client] -->|tools/call| L2[Lambda Container B]
        L2 -->|query| DB2[(DynamoDB)]
        DB2 -->|result| L2
        L2 -->|response| C2
    end

    style L1 fill:#d86613,color:#fff
    style L2 fill:#d86613,color:#fff
    style DB1 fill:#3b48cc,color:#fff
    style DB2 fill:#3b48cc,color:#fff

Each request is self-contained: Containers A and B share nothing, and the MCP client assumes no affinity between requests.

1.2 Stateless Lambda MCP Server with External State

Even "stateless" servers often need context (user session, conversation history). The pattern is to externalize state to Redis or DynamoDB and fetch it per-request.

"""
Stateless Lambda MCP server for MangaAssist order tracking.
State (order data, session context) is fetched from DynamoDB
on every invocation. No server-side session is maintained.
"""

import json
import logging
import os
import time
from decimal import Decimal

import boto3
from boto3.dynamodb.conditions import Key

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Module-level clients (reused across warm invocations -- this is
# CONNECTION reuse, not STATE reuse; the key distinction)
dynamodb = boto3.resource("dynamodb", region_name=os.environ.get("AWS_REGION", "ap-northeast-1"))
orders_table = dynamodb.Table(os.environ.get("ORDERS_TABLE", "mangaassist-orders"))
sessions_table = dynamodb.Table(os.environ.get("SESSIONS_TABLE", "mangaassist-sessions"))

PROTOCOL_VERSION = "2024-11-05"
SERVER_NAME = "mangaassist-order-tracking"
SERVER_VERSION = "1.2.0"

TOOL_DEFINITIONS = [
    {
        "name": "order_tracking",
        "description": (
            "Look up manga order status by order ID or customer email. "
            "Returns current status, tracking number, estimated delivery, "
            "and item list."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "order_id": {"type": "string", "description": "Order ID (MNG-XXXXXXXX)"},
                "customer_email": {"type": "string", "description": "Customer email"},
                "session_id": {
                    "type": "string",
                    "description": "Chat session ID for context continuity"
                }
            },
            "anyOf": [
                {"required": ["order_id"]},
                {"required": ["customer_email"]}
            ]
        }
    },
    {
        "name": "order_history",
        "description": (
            "Retrieve recent order history for a customer. Returns the "
            "last N orders with status summary."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "customer_email": {"type": "string"},
                "limit": {"type": "integer", "default": 5, "minimum": 1, "maximum": 20}
            },
            "required": ["customer_email"]
        }
    }
]


class DecimalEncoder(json.JSONEncoder):
    """Handle DynamoDB Decimal types in JSON serialization."""
    def default(self, obj):
        if isinstance(obj, Decimal):
            return float(obj) if obj % 1 else int(obj)
        return super().default(obj)


def _fetch_session_context(session_id: str) -> dict:
    """Fetch conversation context from DynamoDB sessions table.
    This is how a stateless server accesses 'state' -- by reading
    it from an external store every time."""
    if not session_id:
        return {}
    try:
        resp = sessions_table.get_item(Key={"session_id": session_id})
        item = resp.get("Item", {})
        return {
            "last_order_id": item.get("last_order_id"),
            "customer_email": item.get("customer_email"),
            "conversation_turns": item.get("turn_count", 0)
        }
    except Exception as e:
        logger.warning(f"Session fetch failed: {e}")
        return {}


def _update_session_context(session_id: str, updates: dict):
    """Write back any context updates for future requests."""
    if not session_id:
        return
    try:
        sessions_table.update_item(
            Key={"session_id": session_id},
            UpdateExpression="SET " + ", ".join(f"#{k} = :{k}" for k in updates),
            ExpressionAttributeNames={f"#{k}": k for k in updates},
            ExpressionAttributeValues={f":{k}": v for k, v in updates.items()},
        )
    except Exception as e:
        logger.warning(f"Session update failed: {e}")


def execute_order_tracking(arguments: dict) -> dict:
    """Look up a single order by ID or email."""
    order_id = arguments.get("order_id")
    email = arguments.get("customer_email")
    session_id = arguments.get("session_id", "")

    # Fetch session context for continuity
    ctx = _fetch_session_context(session_id)

    # If no order_id provided, check session context
    if not order_id and not email:
        order_id = ctx.get("last_order_id")
        email = ctx.get("customer_email")

    if not order_id and not email:
        return {
            "content": [{"type": "text", "text": json.dumps({
                "error": "Please provide an order ID or your email address.",
                "hint": "Order IDs look like MNG-12345678"
            })}],
            "isError": False  # Not a server error; user needs to provide info
        }

    start = time.time()

    try:
        if order_id:
            resp = orders_table.get_item(Key={"order_id": order_id})
            items = [resp["Item"]] if "Item" in resp else []
        else:
            resp = orders_table.query(
                IndexName="email-index",
                KeyConditionExpression=Key("customer_email").eq(email),
                ScanIndexForward=False,
                Limit=1
            )
            items = resp.get("Items", [])

        if not items:
            return {
                "content": [{"type": "text", "text": json.dumps({
                    "found": False,
                    "message": f"No order found for {'order ' + order_id if order_id else email}"
                })}],
                "isError": False
            }

        order = items[0]
        elapsed_ms = round((time.time() - start) * 1000, 1)

        result = {
            "order_id": order.get("order_id"),
            "status": order.get("status"),
            "tracking_number": order.get("tracking_number", ""),
            "carrier": order.get("carrier", ""),
            "estimated_delivery": order.get("estimated_delivery", ""),
            "items": order.get("items", []),
            "total_jpy": order.get("total_jpy", 0),
            "order_date": order.get("order_date", ""),
            "metadata": {"latency_ms": elapsed_ms}
        }

        # Update session context
        _update_session_context(session_id, {
            "last_order_id": order.get("order_id"),
            "customer_email": order.get("customer_email", email),
        })

        return {
            "content": [{"type": "text", "text": json.dumps(result, cls=DecimalEncoder, ensure_ascii=False)}],
            "isError": False
        }

    except Exception as e:
        logger.error(f"Order lookup failed: {e}", exc_info=True)
        return {
            "content": [{"type": "text", "text": f"Order lookup error: {str(e)}"}],
            "isError": True
        }


def execute_order_history(arguments: dict) -> dict:
    """Retrieve recent orders for a customer."""
    email = arguments["customer_email"]
    limit = min(arguments.get("limit", 5), 20)

    try:
        resp = orders_table.query(
            IndexName="email-index",
            KeyConditionExpression=Key("customer_email").eq(email),
            ScanIndexForward=False,
            Limit=limit
        )
        items = resp.get("Items", [])

        orders = []
        for item in items:
            orders.append({
                "order_id": item.get("order_id"),
                "status": item.get("status"),
                "total_jpy": item.get("total_jpy", 0),
                "item_count": len(item.get("items", [])),
                "order_date": item.get("order_date", ""),
            })

        return {
            "content": [{"type": "text", "text": json.dumps({
                "orders": orders,
                "total_orders": len(orders),
                "customer_email": email
            }, cls=DecimalEncoder, ensure_ascii=False)}],
            "isError": False
        }
    except Exception as e:
        logger.error(f"Order history failed: {e}", exc_info=True)
        return {
            "content": [{"type": "text", "text": f"Order history error: {str(e)}"}],
            "isError": True
        }


TOOL_HANDLERS = {
    "order_tracking": execute_order_tracking,
    "order_history": execute_order_history,
}

MCP_HANDLERS = {
    "initialize": lambda p: {
        "protocolVersion": PROTOCOL_VERSION,
        "capabilities": {"tools": {"listChanged": False}},
        "serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION}
    },
    "notifications/initialized": lambda p: None,
    "tools/list": lambda p: {"tools": TOOL_DEFINITIONS},
    "tools/call": lambda p: TOOL_HANDLERS.get(
        p.get("name", ""), lambda a: {
            "content": [{"type": "text", "text": "Unknown tool"}], "isError": True
        }
    )(p.get("arguments", {})),
}


def lambda_handler(event, context):
    """Lambda entry point for MCP protocol."""
    try:
        body = json.loads(event.get("body", "{}")) if isinstance(event.get("body"), str) else event
        method = body.get("method", "")
        msg_id = body.get("id")
        params = body.get("params", {})

        handler = MCP_HANDLERS.get(method)
        if not handler:
            return {"statusCode": 400, "body": json.dumps({
                "jsonrpc": "2.0", "id": msg_id,
                "error": {"code": -32601, "message": f"Unknown method: {method}"}
            })}

        result = handler(params)
        if msg_id is None:
            return {"statusCode": 204, "body": ""}

        # Note: passing default= alongside cls would override DecimalEncoder's
        # default() method, so Decimals are handled by the encoder class alone
        return {"statusCode": 200, "body": json.dumps({
            "jsonrpc": "2.0", "id": msg_id, "result": result
        }, cls=DecimalEncoder)}

    except Exception as e:
        logger.error(f"Handler error: {e}", exc_info=True)
        return {"statusCode": 500, "body": json.dumps({
            "jsonrpc": "2.0", "id": None,
            "error": {"code": -32603, "message": "Internal error"}
        })}

1.3 Stateless Pattern Characteristics

| Property | Stateless MCP Behavior |
| --- | --- |
| Session affinity | None required; any container can serve any request |
| Scaling | Trivial horizontal scaling (more Lambda concurrency or ECS tasks) |
| Failure recovery | Automatic; a retry hits a different container and gets the same result |
| State location | DynamoDB, Redis, or S3 -- always external |
| Idempotency | Natural; the same input always produces the same output |
| Cold start impact | Moderate; connection setup adds 100-300 ms |
| Best for | Lookups, searches, CRUD, calculations |
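Because stateless tool calls are idempotent, the client can retry them blindly. A minimal sketch; the flaky lookup and its return values are made up for illustration:

```python
def retry(call, attempts=3):
    """Retry an idempotent operation. Safe for stateless MCP tools because a
    repeated call returns the same result instead of duplicating work."""
    last_exc = None
    for _ in range(attempts):
        try:
            return call()
        except Exception as exc:  # in practice, catch only transient errors
            last_exc = exc
    raise last_exc

calls = {"count": 0}

def flaky_lookup():
    """Hypothetical lookup that fails twice with a transient error."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise TimeoutError("transient network error")
    return {"order_id": "MNG-12345678", "status": "shipped"}

result = retry(flaky_lookup)  # succeeds on the third attempt
```

A stateful server could not be retried this casually: a retry landing on a different task would miss the in-memory context the first attempt built up.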

2. Stateful MCP Pattern

Stateful MCP servers maintain in-process data structures between requests. They are deployed on ECS Fargate, where the process is long-lived. For MangaAssist, the recommendation engine and the session context service are stateful MCP servers.

2.1 Stateful Architecture

flowchart TB
    subgraph "ECS Task (long-lived process)"
        APP[Starlette App]
        POOL["Connection Pools<br/>- OpenSearch pool (10 conns)<br/>- Redis pool (20 conns)"]
        CACHE["In-Memory Caches<br/>- Genre embeddings (5 vectors)<br/>- Popular titles (top 100)<br/>- Schema cache"]
        SESSIONS["Active Sessions Map<br/>session_id -> context"]
        APP --> POOL
        APP --> CACHE
        APP --> SESSIONS
    end

    R1[Request 1: user-A] --> APP
    R2[Request 2: user-B] --> APP
    R3[Request 3: user-A] --> APP

    APP -->|"Request 3 sees user-A's cached preferences"| SESSIONS

    style APP fill:#d86613,color:#fff
    style POOL fill:#3b48cc,color:#fff
    style CACHE fill:#1a8c1a,color:#fff
    style SESSIONS fill:#8c4fff,color:#fff

2.2 Stateful ECS MCP Server with Connection Pooling

"""
Stateful ECS MCP server for MangaAssist session context management.
Maintains in-memory session state, connection pools, and preloaded caches.
Designed for ECS Fargate deployment with ALB health checks.
"""

import asyncio
import json
import logging
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Optional

import uvicorn
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from contextlib import asynccontextmanager

import redis.asyncio as aioredis
import boto3

logger = logging.getLogger("mcp-session-context")
logging.basicConfig(level=logging.INFO)

# ---------- Configuration ----------

REDIS_ENDPOINT = os.environ.get("REDIS_ENDPOINT", "localhost")
MAX_SESSIONS = int(os.environ.get("MAX_SESSIONS", "10000"))
SESSION_TTL_SECONDS = int(os.environ.get("SESSION_TTL", "1800"))  # 30 min
PORT = int(os.environ.get("PORT", "8080"))

PROTOCOL_VERSION = "2024-11-05"
SERVER_NAME = "mangaassist-session-context"
SERVER_VERSION = "1.0.0"


# ---------- LRU Session Cache ----------

class LRUSessionCache:
    """In-memory LRU cache for active chat sessions.
    Evicts least-recently-used sessions when capacity is reached.
    This is the 'stateful' part -- data lives in this process's memory."""

    def __init__(self, max_size: int = 10000):
        self._cache: OrderedDict[str, dict] = OrderedDict()
        self._max_size = max_size
        self._hits = 0
        self._misses = 0

    def get(self, session_id: str) -> Optional[dict]:
        if session_id in self._cache:
            self._cache.move_to_end(session_id)
            self._hits += 1
            return self._cache[session_id]
        self._misses += 1
        return None

    def put(self, session_id: str, data: dict):
        if session_id in self._cache:
            self._cache.move_to_end(session_id)
            self._cache[session_id] = data
        else:
            if len(self._cache) >= self._max_size:
                evicted_key, _ = self._cache.popitem(last=False)
                logger.debug(f"Evicted session: {evicted_key}")
            self._cache[session_id] = data

    def remove(self, session_id: str):
        self._cache.pop(session_id, None)

    @property
    def size(self) -> int:
        return len(self._cache)

    @property
    def hit_rate(self) -> float:
        total = self._hits + self._misses
        return self._hits / total if total > 0 else 0.0

    def stats(self) -> dict:
        return {
            "size": self.size,
            "max_size": self._max_size,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": round(self.hit_rate, 3)
        }


# ---------- Server State ----------

@dataclass
class ServerState:
    redis: Optional[aioredis.Redis] = None
    session_cache: LRUSessionCache = field(default_factory=lambda: LRUSessionCache(MAX_SESSIONS))
    startup_time: float = 0
    request_count: int = 0
    error_count: int = 0

state = ServerState()


# ---------- Tool Definitions ----------

TOOL_DEFINITIONS = [
    {
        "name": "session_context_get",
        "description": (
            "Retrieve the current chat session context including conversation "
            "history summary, user preferences, and recently viewed manga. "
            "Use this at the start of each turn to maintain continuity."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "session_id": {"type": "string", "description": "Chat session ID"},
                "include_history": {
                    "type": "boolean", "default": True,
                    "description": "Include conversation summary"
                }
            },
            "required": ["session_id"]
        }
    },
    {
        "name": "session_context_update",
        "description": (
            "Update session context after a conversation turn. Records "
            "user intent, viewed items, and interaction metadata."
        ),
        "inputSchema": {
            "type": "object",
            "properties": {
                "session_id": {"type": "string"},
                "user_intent": {"type": "string", "description": "Detected user intent"},
                "viewed_items": {
                    "type": "array", "items": {"type": "string"},
                    "description": "ISBNs or item IDs the user viewed"
                },
                "preferences_update": {
                    "type": "object",
                    "description": "Updated preference key-value pairs"
                }
            },
            "required": ["session_id"]
        }
    },
    {
        "name": "session_context_clear",
        "description": "Clear a chat session context (e.g., when user says 'start over').",
        "inputSchema": {
            "type": "object",
            "properties": {
                "session_id": {"type": "string"}
            },
            "required": ["session_id"]
        }
    }
]


# ---------- Tool Execution ----------

async def execute_session_get(arguments: dict) -> dict:
    """Get session context, checking in-memory cache first, then Redis."""
    session_id = arguments["session_id"]
    include_history = arguments.get("include_history", True)

    # Check in-memory LRU cache first (fast path)
    cached = state.session_cache.get(session_id)
    if cached:
        result = {**cached, "source": "memory_cache"}
        if not include_history:
            result.pop("conversation_summary", None)
        return {
            "content": [{"type": "text", "text": json.dumps(result, ensure_ascii=False)}],
            "isError": False
        }

    # Fall back to Redis (slower path)
    try:
        raw = await state.redis.get(f"session:{session_id}")
        if raw:
            data = json.loads(raw)
            # Populate in-memory cache for next access
            state.session_cache.put(session_id, data)
            result = {**data, "source": "redis_cache"}
            if not include_history:
                result.pop("conversation_summary", None)
            return {
                "content": [{"type": "text", "text": json.dumps(result, ensure_ascii=False)}],
                "isError": False
            }
    except Exception as e:
        logger.warning(f"Redis get failed: {e}")

    # No session found -- return empty context
    new_session = {
        "session_id": session_id,
        "created_at": time.time(),
        "turn_count": 0,
        "user_preferences": {},
        "viewed_items": [],
        "conversation_summary": "",
        "source": "new"
    }
    state.session_cache.put(session_id, new_session)

    return {
        "content": [{"type": "text", "text": json.dumps(new_session, ensure_ascii=False)}],
        "isError": False
    }


async def execute_session_update(arguments: dict) -> dict:
    """Update session context in both memory and Redis."""
    session_id = arguments["session_id"]
    user_intent = arguments.get("user_intent", "")
    viewed_items = arguments.get("viewed_items", [])
    prefs_update = arguments.get("preferences_update", {})

    # Get current session
    current = state.session_cache.get(session_id)
    if not current:
        try:
            raw = await state.redis.get(f"session:{session_id}")
            current = json.loads(raw) if raw else {}
        except Exception:
            current = {}

    if not current:
        current = {
            "session_id": session_id,
            "created_at": time.time(),
            "turn_count": 0,
            "user_preferences": {},
            "viewed_items": [],
            "conversation_summary": "",
        }

    # Apply updates
    current["turn_count"] = current.get("turn_count", 0) + 1
    current["last_updated"] = time.time()
    current["last_intent"] = user_intent

    if viewed_items:
        existing = current.get("viewed_items", [])
        existing.extend(viewed_items)
        current["viewed_items"] = existing[-50:]  # Keep last 50

    if prefs_update:
        prefs = current.get("user_preferences", {})
        prefs.update(prefs_update)
        current["user_preferences"] = prefs

    # Update summary with latest intent
    summary = current.get("conversation_summary", "")
    if user_intent:
        summary_parts = summary.split("; ") if summary else []
        summary_parts.append(f"Turn {current['turn_count']}: {user_intent}")
        current["conversation_summary"] = "; ".join(summary_parts[-10:])

    # Write to memory cache
    state.session_cache.put(session_id, current)

    # Write-through to Redis so the session survives task restarts and is
    # visible to other tasks (this await does add a little response latency)
    try:
        await state.redis.setex(
            f"session:{session_id}",
            SESSION_TTL_SECONDS,
            json.dumps(current, ensure_ascii=False, default=str)
        )
    except Exception as e:
        logger.warning(f"Redis write-through failed: {e}")

    return {
        "content": [{"type": "text", "text": json.dumps({
            "updated": True,
            "session_id": session_id,
            "turn_count": current["turn_count"]
        })}],
        "isError": False
    }


async def execute_session_clear(arguments: dict) -> dict:
    """Clear session from both memory cache and Redis."""
    session_id = arguments["session_id"]

    state.session_cache.remove(session_id)
    try:
        await state.redis.delete(f"session:{session_id}")
    except Exception as e:
        logger.warning(f"Redis delete failed: {e}")

    return {
        "content": [{"type": "text", "text": json.dumps({
            "cleared": True, "session_id": session_id
        })}],
        "isError": False
    }


TOOL_HANDLERS = {
    "session_context_get": execute_session_get,
    "session_context_update": execute_session_update,
    "session_context_clear": execute_session_clear,
}


# ---------- MCP Protocol Handler ----------

async def mcp_endpoint(request: Request) -> JSONResponse:
    """Handle MCP JSON-RPC messages."""
    state.request_count += 1
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(
            {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}},
            status_code=400
        )

    method = body.get("method", "")
    msg_id = body.get("id")
    params = body.get("params", {})

    if method == "initialize":
        return JSONResponse({"jsonrpc": "2.0", "id": msg_id, "result": {
            "protocolVersion": PROTOCOL_VERSION,
            "capabilities": {"tools": {"listChanged": False}},
            "serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION}
        }})

    elif method == "notifications/initialized":
        # Notifications get no JSON-RPC reply; a 204 must carry an empty body,
        # so use a plain Response rather than serializing None as JSON
        return Response(status_code=204)

    elif method == "tools/list":
        return JSONResponse({"jsonrpc": "2.0", "id": msg_id, "result": {"tools": TOOL_DEFINITIONS}})

    elif method == "tools/call":
        tool_name = params.get("name", "")
        handler = TOOL_HANDLERS.get(tool_name)
        if not handler:
            state.error_count += 1
            return JSONResponse({"jsonrpc": "2.0", "id": msg_id, "error": {
                "code": -32601, "message": f"Unknown tool: {tool_name}"
            }}, status_code=400)

        try:
            result = await handler(params.get("arguments", {}))
        except Exception as e:
            state.error_count += 1
            logger.error(f"Tool execution failed: {e}", exc_info=True)
            result = {
                "content": [{"type": "text", "text": f"Internal error: {e}"}],
                "isError": True
            }
        return JSONResponse({"jsonrpc": "2.0", "id": msg_id, "result": result})

    else:
        return JSONResponse({"jsonrpc": "2.0", "id": msg_id, "error": {
            "code": -32601, "message": f"Method not found: {method}"
        }}, status_code=400)


# ---------- Health Check Endpoint ----------

async def health_endpoint(request: Request) -> JSONResponse:
    """ALB health check -- reports server health and dependency status."""
    checks = {
        "server": "ok",
        "uptime_s": round(time.time() - state.startup_time),
        "request_count": state.request_count,
        "error_count": state.error_count,
        "session_cache": state.session_cache.stats(),
    }

    # Check Redis
    try:
        await state.redis.ping()
        checks["redis"] = "ok"
    except Exception:
        checks["redis"] = "degraded"

    all_ok = checks.get("redis") == "ok"
    checks["status"] = "healthy" if all_ok else "degraded"
    return JSONResponse(checks, status_code=200 if all_ok else 503)


# ---------- Readiness Endpoint ----------

async def ready_endpoint(request: Request) -> JSONResponse:
    """Readiness probe -- only returns 200 when server is fully initialized."""
    if state.redis is None:
        return JSONResponse({"ready": False, "reason": "initializing"}, status_code=503)
    return JSONResponse({"ready": True})


# ---------- Metrics Endpoint ----------

async def metrics_endpoint(request: Request) -> JSONResponse:
    """Prometheus-compatible metrics for CloudWatch Container Insights."""
    cache_stats = state.session_cache.stats()
    return JSONResponse({
        "mcp_requests_total": state.request_count,
        "mcp_errors_total": state.error_count,
        "session_cache_size": cache_stats["size"],
        "session_cache_hit_rate": cache_stats["hit_rate"],
        "uptime_seconds": round(time.time() - state.startup_time),
    })


# ---------- Application Lifecycle ----------

@asynccontextmanager
async def lifespan(app):
    """Startup: connect Redis, load caches. Shutdown: drain and close."""
    logger.info("Starting MCP session context server...")
    state.startup_time = time.time()

    # Initialize Redis connection pool
    state.redis = aioredis.Redis(
        host=REDIS_ENDPOINT, port=6379,
        decode_responses=True, socket_timeout=2,
        max_connections=20,  # Connection pool size
    )

    # Verify Redis connectivity
    try:
        await state.redis.ping()
        logger.info("Redis connection verified")
    except Exception as e:
        logger.error(f"Redis connection failed: {e}")

    logger.info(f"Server ready on port {PORT}")
    yield

    # Shutdown
    logger.info("Shutting down...")
    if state.redis:
        await state.redis.close()
    logger.info("Shutdown complete")


app = Starlette(
    routes=[
        Route("/mcp", mcp_endpoint, methods=["POST"]),
        Route("/mcp/health", health_endpoint, methods=["GET"]),
        Route("/mcp/ready", ready_endpoint, methods=["GET"]),
        Route("/mcp/metrics", metrics_endpoint, methods=["GET"]),
    ],
    lifespan=lifespan,
)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level="info")
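The LRUSessionCache in the server above leans entirely on OrderedDict for its recency bookkeeping. A compressed illustration of the same eviction mechanics with a capacity of three; the session IDs are arbitrary:

```python
from collections import OrderedDict

cache: OrderedDict[str, dict] = OrderedDict()
MAX_SIZE = 3

def put(session_id: str, data: dict) -> None:
    if session_id in cache:
        cache.move_to_end(session_id)   # refresh recency on overwrite
    elif len(cache) >= MAX_SIZE:
        cache.popitem(last=False)       # evict the least-recently-used entry
    cache[session_id] = data

for sid in ("s1", "s2", "s3"):
    put(sid, {"turn_count": 0})
cache.move_to_end("s1")                 # a get() performs this same refresh
put("s4", {"turn_count": 0})            # capacity hit: evicts s2, the LRU entry
```

The move_to_end on read is what distinguishes LRU from FIFO: s1 survives the eviction above because it was touched most recently, even though it was inserted first.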

3. Connection Management Patterns

3.1 Connection Management Architecture

graph TB
    subgraph "MCP Client (Orchestrator)"
        CLIENT[MCP Client Library]
        POOL_C[HTTP Connection Pool<br/>max_connections=50<br/>keepalive=30s]
        CB[Circuit Breaker<br/>per MCP server]
        RETRY[Retry Handler<br/>exponential backoff]
        CLIENT --> POOL_C
        CLIENT --> CB
        CLIENT --> RETRY
    end

    subgraph "Transport"
        ALB[ALB<br/>idle_timeout=60s]
        APIGW[API Gateway<br/>timeout=29s]
    end

    subgraph "Lambda MCP"
        L_POOL["Module-Level Clients<br/>(reused on warm start)<br/>- boto3 DynamoDB client<br/>- OpenSearch client<br/>- Redis client"]
    end

    subgraph "ECS MCP"
        E_POOL["Connection Pools<br/>- OpenSearch: 10 conns<br/>- Redis: 20 conns<br/>- DynamoDB: SDK default"]
    end

    POOL_C -->|"Lambda tools"| APIGW
    POOL_C -->|"ECS tools"| ALB
    APIGW --> L_POOL
    ALB --> E_POOL

    style CLIENT fill:#ff9900,color:#000
    style CB fill:#c7131b,color:#fff
    style L_POOL fill:#d86613,color:#fff
    style E_POOL fill:#1a8c1a,color:#fff

3.2 Circuit Breaker for MCP Servers

"""
Circuit breaker implementation for MCP server connections.
Prevents cascading failures when an MCP server is unhealthy.
"""

import asyncio
import time
import logging
from enum import Enum
from dataclasses import dataclass, field

logger = logging.getLogger("mcp-circuit-breaker")


class CircuitState(Enum):
    CLOSED = "closed"        # Normal operation, requests flow through
    OPEN = "open"            # Failures exceeded threshold, requests blocked
    HALF_OPEN = "half_open"  # Testing if service recovered


@dataclass
class CircuitBreaker:
    """Per-server circuit breaker to protect against cascading failures."""
    server_name: str
    failure_threshold: int = 5       # Failures before opening
    recovery_timeout_s: float = 30   # Seconds before trying half-open
    success_threshold: int = 3       # Successes in half-open before closing

    # Internal state
    _state: CircuitState = field(default=CircuitState.CLOSED, init=False)
    _failure_count: int = field(default=0, init=False)
    _success_count: int = field(default=0, init=False)
    _last_failure_time: float = field(default=0, init=False)
    _total_blocked: int = field(default=0, init=False)

    @property
    def state(self) -> CircuitState:
        if self._state == CircuitState.OPEN:
            elapsed = time.time() - self._last_failure_time
            if elapsed >= self.recovery_timeout_s:
                self._state = CircuitState.HALF_OPEN
                self._success_count = 0
                logger.info(f"Circuit {self.server_name}: OPEN -> HALF_OPEN")
        return self._state

    def allow_request(self) -> bool:
        """Check if a request should be allowed through."""
        current = self.state
        if current == CircuitState.CLOSED:
            return True
        if current == CircuitState.HALF_OPEN:
            return True  # Allow probe requests
        # OPEN
        self._total_blocked += 1
        return False

    def record_success(self):
        """Record a successful request."""
        if self._state == CircuitState.HALF_OPEN:
            self._success_count += 1
            if self._success_count >= self.success_threshold:
                self._state = CircuitState.CLOSED
                self._failure_count = 0
                logger.info(f"Circuit {self.server_name}: HALF_OPEN -> CLOSED")
        elif self._state == CircuitState.CLOSED:
            self._failure_count = 0  # Reset on success

    def record_failure(self):
        """Record a failed request."""
        self._failure_count += 1
        self._last_failure_time = time.time()

        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.OPEN
            logger.warning(f"Circuit {self.server_name}: HALF_OPEN -> OPEN")
        elif self._failure_count >= self.failure_threshold:
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit {self.server_name}: CLOSED -> OPEN "
                f"(failures={self._failure_count})"
            )

    def stats(self) -> dict:
        return {
            "server": self.server_name,
            "state": self.state.value,
            "failures": self._failure_count,
            "total_blocked": self._total_blocked,
        }


class MCPConnectionManager:
    """Manages connections to multiple MCP servers with circuit breakers,
    retry logic, and connection pooling."""

    def __init__(self):
        self._circuits: dict[str, CircuitBreaker] = {}
        self._server_latencies: dict[str, list] = {}

    def register_server(self, server_name: str, **cb_kwargs):
        """Register a new MCP server with its circuit breaker."""
        self._circuits[server_name] = CircuitBreaker(server_name, **cb_kwargs)
        self._server_latencies[server_name] = []

    async def call_with_protection(
        self,
        server_name: str,
        call_fn,  # async callable
        max_retries: int = 2,
        base_delay_s: float = 0.1,
    ) -> dict:
        """Execute an MCP call with circuit breaker and retry protection."""
        circuit = self._circuits.get(server_name)
        if not circuit:
            raise ValueError(f"Unknown server: {server_name}")

        if not circuit.allow_request():
            return {
                "content": [{
                    "type": "text",
                    "text": f"Server {server_name} is temporarily unavailable (circuit open)"
                }],
                "isError": True,
                "_circuit_open": True
            }

        last_error = None
        for attempt in range(max_retries + 1):
            try:
                start = time.time()
                result = await call_fn()
                elapsed_ms = (time.time() - start) * 1000

                circuit.record_success()
                self._record_latency(server_name, elapsed_ms)
                return result

            except Exception as e:
                last_error = e
                logger.warning(
                    f"MCP call to {server_name} failed "
                    f"(attempt {attempt + 1}/{max_retries + 1}): {e}"
                )
                if attempt < max_retries:
                    delay = base_delay_s * (2 ** attempt)  # Exponential backoff
                    await asyncio.sleep(delay)

        # All retries exhausted
        circuit.record_failure()
        return {
            "content": [{"type": "text", "text": f"Tool call failed after {max_retries + 1} attempts: {last_error}"}],
            "isError": True
        }

    def _record_latency(self, server_name: str, latency_ms: float):
        """Track latency for monitoring."""
        latencies = self._server_latencies.get(server_name, [])
        latencies.append(latency_ms)
        if len(latencies) > 100:
            latencies = latencies[-100:]
        self._server_latencies[server_name] = latencies

    def get_stats(self) -> dict:
        """Get health stats for all managed servers."""
        stats = {}
        for name, circuit in self._circuits.items():
            latencies = self._server_latencies.get(name, [])
            avg_latency = sum(latencies) / len(latencies) if latencies else 0
            p99_latency = sorted(latencies)[int(len(latencies) * 0.99)] if latencies else 0
            stats[name] = {
                **circuit.stats(),
                "avg_latency_ms": round(avg_latency, 1),
                "p99_latency_ms": round(p99_latency, 1),
                "sample_count": len(latencies),
            }
        return stats

3.3 Connection Pool Configuration

| Component | Pool Size | Keepalive | Timeout | Retry |
|---|---|---|---|---|
| MCP Client -> API GW | 50 connections | 30s | 10s | 2 retries, exponential |
| MCP Client -> ALB | 50 connections | 60s | 15s | 2 retries, exponential |
| Lambda -> OpenSearch | 1 per container | N/A (per-invoke) | 5s | 1 retry |
| Lambda -> DynamoDB | SDK default (10) | SDK managed | 3s | SDK default (3) |
| Lambda -> Redis | 1 per container | N/A | 1s | 0 retries |
| ECS -> OpenSearch | 10 per task | 120s | 10s | 2 retries |
| ECS -> Redis | 20 per task | 300s | 2s | 1 retry |
| ECS -> DynamoDB | SDK default (10) | SDK managed | 3s | SDK default (3) |

4. Capability Advertisement Pattern

Capability advertisement is the mechanism by which MCP servers declare what tools they expose and what protocol features they support. This happens during the initialize handshake.
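As a concrete illustration, a minimal initialize exchange for MangaAssist might carry payloads like the following. The field names follow the MCP initialize schema; the specific client/server names and versions are examples:

```python
# Client -> server: the "params" of the JSON-RPC initialize request.
initialize_params = {
    "protocolVersion": "2024-11-05",
    "capabilities": {},  # client-side capabilities, if any
    "clientInfo": {"name": "mangaassist-orchestrator", "version": "1.4.0"},
}

# Server -> client: the initialize result advertising what the server supports.
initialize_result = {
    "protocolVersion": "2024-11-05",
    "capabilities": {
        "tools": {"listChanged": False},  # tools exist, list is static
        "logging": {},                    # server emits log messages
    },
    "serverInfo": {"name": "manga-catalog-mcp", "version": "2.1.0"},
}
```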

4.1 Capability Negotiation Flow

sequenceDiagram
    participant C as MCP Client
    participant S as MCP Server

    C->>S: initialize {protocolVersion, capabilities, clientInfo}
    Note over S: Validate protocol version<br/>Select compatible capabilities
    S->>C: {protocolVersion, capabilities, serverInfo}

    Note over C: Check server capabilities<br/>Store tool support flags

    C->>S: notifications/initialized {}
    Note over S: Server fully ready

    C->>S: tools/list {}
    S->>C: {tools: [{name, description, inputSchema}, ...]}

    Note over C: Cache tool definitions<br/>Validate input schemas<br/>Ready for tool calls

    loop Every 5 minutes (optional)
        C->>S: tools/list {}
        S->>C: Updated tool list
        Note over C: Detect capability drift
    end

4.2 Capability Advertisement Implementation

"""
Capability advertisement and version negotiation for MCP servers.
Handles protocol version compatibility and feature flagging.
"""

from dataclasses import dataclass, field
from typing import Optional
import logging

logger = logging.getLogger("mcp-capabilities")

# Supported protocol versions (newest first)
SUPPORTED_VERSIONS = ["2024-11-05", "2024-09-01"]


@dataclass
class ServerCapabilities:
    """MCP server capability declaration."""
    tools: bool = True
    tools_list_changed: bool = False  # Can tools change at runtime?
    resources: bool = False           # Does server expose resources?
    resources_subscribe: bool = False # Can client subscribe to resource changes?
    prompts: bool = False             # Does server expose prompt templates?
    logging_support: bool = True      # Does server emit log messages?


@dataclass
class NegotiatedSession:
    """Result of capability negotiation between client and server."""
    protocol_version: str
    client_name: str
    client_version: str
    server_capabilities: ServerCapabilities
    session_id: Optional[str] = None


def negotiate_capabilities(
    client_params: dict,
    server_caps: ServerCapabilities,
    server_name: str,
    server_version: str,
) -> dict:
    """
    Process an MCP initialize request and return the negotiated response.

    Handles:
    1. Protocol version negotiation (pick highest mutually supported)
    2. Capability advertisement (declare what server supports)
    3. Server identification

    Returns:
        dict: MCP initialize response result
    """
    client_version = client_params.get("protocolVersion", "")
    client_info = client_params.get("clientInfo", {})

    # Version negotiation: use the client's version if we support it,
    # otherwise fall back to our highest supported version
    if client_version in SUPPORTED_VERSIONS:
        negotiated_version = client_version
    else:
        negotiated_version = SUPPORTED_VERSIONS[0]
        logger.warning(
            f"Client requested unsupported version {client_version}, "
            f"using {negotiated_version}"
        )

    # Build capabilities response
    caps = {}
    if server_caps.tools:
        caps["tools"] = {"listChanged": server_caps.tools_list_changed}
    if server_caps.resources:
        res_cap = {}
        if server_caps.resources_subscribe:
            res_cap["subscribe"] = True
        caps["resources"] = res_cap
    if server_caps.prompts:
        caps["prompts"] = {"listChanged": False}
    if server_caps.logging_support:
        caps["logging"] = {}

    logger.info(
        f"Negotiated MCP session: version={negotiated_version}, "
        f"client={client_info.get('name', 'unknown')}"
    )

    return {
        "protocolVersion": negotiated_version,
        "capabilities": caps,
        "serverInfo": {
            "name": server_name,
            "version": server_version,
        }
    }


# ---------- Tool Schema Validation ----------

def validate_tool_call(
    tool_name: str,
    arguments: dict,
    tool_definitions: list[dict],
) -> tuple[bool, str]:
    """
    Validate a tool call against registered tool definitions.
    Returns (is_valid, error_message).
    """
    tool_def = None
    for t in tool_definitions:
        if t["name"] == tool_name:
            tool_def = t
            break

    if tool_def is None:
        return False, f"Tool '{tool_name}' is not registered on this server"

    schema = tool_def.get("inputSchema", {})
    required = schema.get("required", [])

    # Check required fields
    for req_field in required:
        if req_field not in arguments:
            return False, f"Missing required argument: '{req_field}'"

    # Check types of provided arguments
    properties = schema.get("properties", {})
    for key, value in arguments.items():
        if key in properties:
            expected_type = properties[key].get("type")
            if expected_type and not _type_matches(value, expected_type):
                return False, (
                    f"Argument '{key}' has wrong type: "
                    f"expected {expected_type}, got {type(value).__name__}"
                )

    return True, ""


def _type_matches(value, expected_type: str) -> bool:
    """Check if a Python value matches a JSON Schema type."""
    # bool is a subclass of int in Python, so reject booleans explicitly for
    # the numeric types; otherwise True would validate as an integer.
    if isinstance(value, bool) and expected_type in ("integer", "number"):
        return False
    type_map = {
        "string": str,
        "integer": int,
        "number": (int, float),
        "boolean": bool,
        "array": list,
        "object": dict,
    }
    expected = type_map.get(expected_type)
    if expected is None:
        return True  # Unknown type, allow it
    return isinstance(value, expected)
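For reference, a hypothetical MangaAssist tool definition in the shape that `tools/list` returns (name, description, and a JSON Schema `inputSchema`) — the kind of entry the validator above checks arguments against:

```python
# Hypothetical tool definition; field names follow the MCP tool schema.
MANGA_CATALOG_SEARCH = {
    "name": "manga_catalog_search",
    "description": "Search the MangaAssist catalog by title, author, or genre.",
    "inputSchema": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Free-text search terms"},
            "genre": {"type": "string", "description": "Optional genre filter"},
            "max_results": {"type": "integer", "description": "Result cap, default 10"},
        },
        "required": ["query"],
    },
}
```

With this definition, a call missing `query` fails the required-field check, and passing `max_results` as a string fails the type check.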

5. Health Check Patterns

5.1 Health Check Architecture

flowchart TB
    subgraph "ECS MCP Server Health Checks"
        ALB_HC[ALB Health Check<br/>GET /mcp/health<br/>interval=15s, threshold=3]
        ECS_HC[ECS Health Check<br/>COMMAND healthcheck.py<br/>interval=30s, retries=3]
        CW[CloudWatch Alarm<br/>HealthyHostCount < 1]
    end

    subgraph "Health Check Logic"
        LIVE[Liveness: /mcp/health<br/>Is process alive?<br/>Can it accept requests?]
        READY[Readiness: /mcp/ready<br/>Are dependencies connected?<br/>Are caches warm?]
        DEEP[Deep: /mcp/health?deep=true<br/>Redis ping OK?<br/>OpenSearch ping OK?<br/>Session cache loaded?]
    end

    subgraph "Responses"
        H200[200 OK<br/>All checks pass]
        H503[503 Degraded<br/>Dependency down<br/>but can serve cached]
        H500[500 Unhealthy<br/>Cannot serve requests]
    end

    ALB_HC --> LIVE
    ECS_HC --> READY
    CW --> DEEP

    LIVE --> H200
    LIVE --> H503
    READY --> H200
    READY --> H500
    DEEP --> H200
    DEEP --> H503

    style H200 fill:#1a8c1a,color:#fff
    style H503 fill:#d86613,color:#fff
    style H500 fill:#c7131b,color:#fff

5.2 Comprehensive Health Check Implementation

"""
Health check endpoints for ECS MCP servers.
Three tiers: liveness (is the process running?), readiness (are
dependencies ready?), and deep (full dependency verification).
"""

import time
import logging
import asyncio
from starlette.requests import Request
from starlette.responses import JSONResponse

logger = logging.getLogger("mcp-health")


class HealthChecker:
    """Three-tier health check system for MCP servers."""

    def __init__(self, server_state):
        self.state = server_state
        self._startup_complete = False
        self._degraded_since: float = 0

    async def liveness(self, request: Request) -> JSONResponse:
        """Liveness probe: is the process alive and able to serve traffic?
        Used by ALB target group health checks."""
        uptime = round(time.time() - self.state.startup_time)
        return JSONResponse({
            "status": "alive",
            "uptime_s": uptime,
            "request_count": self.state.request_count,
        }, status_code=200)

    async def readiness(self, request: Request) -> JSONResponse:
        """Readiness probe: has the server fully initialized?
        Used by ECS to determine when to route traffic to a new task."""
        if not self._startup_complete:
            return JSONResponse(
                {"ready": False, "reason": "still initializing"},
                status_code=503
            )
        return JSONResponse({"ready": True}, status_code=200)

    async def deep_health(self, request: Request) -> JSONResponse:
        """Deep health check: verify all backend dependencies.
        Used by monitoring and alerting systems."""
        checks = {}
        overall_healthy = True

        # Check Redis
        try:
            start = time.time()
            await asyncio.wait_for(self.state.redis.ping(), timeout=2)
            checks["redis"] = {
                "status": "ok",
                "latency_ms": round((time.time() - start) * 1000, 1)
            }
        except Exception as e:
            checks["redis"] = {"status": "error", "message": str(e)}
            overall_healthy = False

        # Check session cache
        cache_stats = self.state.session_cache.stats()
        checks["session_cache"] = {
            "status": "ok",
            "size": cache_stats["size"],
            "hit_rate": cache_stats["hit_rate"],
        }
        if cache_stats["hit_rate"] < 0.5 and cache_stats["size"] > 100:
            checks["session_cache"]["status"] = "warning"

        # Error rate check
        if self.state.request_count > 0:
            error_rate = self.state.error_count / self.state.request_count
            checks["error_rate"] = {
                "rate": round(error_rate, 4),
                "status": "ok" if error_rate < 0.05 else "warning"
            }

        status_code = 200 if overall_healthy else 503
        status_text = "healthy" if overall_healthy else "degraded"

        if not overall_healthy and self._degraded_since == 0:
            self._degraded_since = time.time()
        elif overall_healthy:
            self._degraded_since = 0

        return JSONResponse({
            "status": status_text,
            "checks": checks,
            "uptime_s": round(time.time() - self.state.startup_time),
            "degraded_since": self._degraded_since or None,
        }, status_code=status_code)

    def mark_ready(self):
        """Call once all initialization is complete."""
        self._startup_complete = True
        logger.info("Server marked as ready")
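These probes map onto the orchestration settings shown in 5.1. A sketch of the corresponding ECS container `healthCheck` and ALB target group configuration, expressed as Python dicts (illustrative fragments, not a complete task definition; the port is an assumption):

```python
# ECS container-level health check: hits the readiness probe so a new task
# only receives traffic once initialization is complete.
ecs_health_check = {
    "command": ["CMD-SHELL", "curl -sf http://localhost:8080/mcp/ready || exit 1"],
    "interval": 30,     # seconds between checks (matches 5.1)
    "timeout": 5,
    "retries": 3,
    "startPeriod": 60,  # grace period while connection pools and caches warm
}

# ALB target group health check: hits the lightweight liveness probe.
alb_health_check = {
    "path": "/mcp/health",
    "interval_seconds": 15,
    "healthy_threshold": 3,
    "unhealthy_threshold": 3,
    "matcher": "200",
}
```

Pointing the ALB at liveness and ECS at readiness keeps the two concerns separate: the ALB drops a dead process quickly, while ECS withholds traffic from tasks that are alive but not yet initialized.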

6. Stateless vs Stateful MCP Comparison

| Dimension | Stateless MCP | Stateful MCP |
|---|---|---|
| Server process | Ephemeral (Lambda) or shared (ECS) | Long-lived (ECS) |
| Session data | External (DynamoDB/Redis) | In-memory + write-through |
| Connection pools | New per cold start | Persistent, pre-warmed |
| Scaling | Add concurrency/tasks | Sticky sessions or shared cache |
| Failure impact | Minimal (retry on new container) | Session loss if task dies |
| Latency | Higher (external state fetch) | Lower (in-memory cache) |
| Complexity | Lower | Higher (cache coherence, drain) |
| MangaAssist fit | Order tracking, catalog search | Session context, recommendations |
| Cost | Per-invocation | Per-hour |
| Cold start | 200-800ms | 30-60s (but then always warm) |

6.1 Hybrid Pattern Decision Tree

flowchart TD
    Q1{Does the tool need<br/>data from previous turns?}
    Q1 -->|Yes| Q2{Is latency critical?<br/>Must be <100ms overhead?}
    Q1 -->|No| STATELESS["Stateless MCP<br/>(Lambda recommended)"]

    Q2 -->|Yes| STATEFUL["Stateful MCP<br/>(ECS with in-memory cache)"]
    Q2 -->|No| HYBRID["Hybrid: Stateless MCP<br/>+ Redis for session state"]

    STATELESS --> S_TOOLS["MangaAssist tools:<br/>- manga_catalog_search<br/>- order_tracking<br/>- price_check<br/>- inventory_status"]

    STATEFUL --> E_TOOLS["MangaAssist tools:<br/>- session_context_get/update<br/>- manga_recommendations<br/>- reading_list_analysis"]

    HYBRID --> H_TOOLS["MangaAssist tools:<br/>- order_history (Lambda + Redis)<br/>- wishlist_manage (Lambda + DDB)"]

    style STATELESS fill:#d86613,color:#fff
    style STATEFUL fill:#1a8c1a,color:#fff
    style HYBRID fill:#8c4fff,color:#fff
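The hybrid branch can be sketched as a stateless handler whose per-session context lives in an external store. In this sketch a plain dict stands in for Redis, and the tool name and store interface are illustrative:

```python
def handle_order_history(session_id: str, arguments: dict, store: dict) -> dict:
    """Hybrid-pattern tool: the handler itself holds no state between calls;
    per-session context lives in an external store (Redis in production,
    a dict here for demonstration)."""
    key = f"session:{session_id}:recent_orders"
    recent = store.get(key, [])

    limit = int(arguments.get("limit", 5))
    result = recent[:limit]

    # Write-through: record the last tool this session used, so any
    # container serving the next request sees the same context.
    store[f"session:{session_id}:last_tool"] = "order_history"

    return {
        "content": [{"type": "text", "text": f"{len(result)} recent orders"}],
        "isError": False,
    }
```

Because every read and write goes through the store, any Lambda container can serve any session; the only cost versus the stateful pattern is the extra round trip to Redis.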

7. MCP Server Registration and Discovery

In production, MangaAssist needs a registry so the orchestrator can discover MCP servers dynamically rather than hardcoding endpoints.

"""
MCP Server Registry using DynamoDB.
Allows dynamic registration, discovery, and health tracking
of MCP servers in the MangaAssist fleet.
"""

import json
import time
import logging
from typing import Optional
from decimal import Decimal

import boto3
from boto3.dynamodb.conditions import Attr

logger = logging.getLogger("mcp-registry")


class MCPServerRegistry:
    """Centralized registry for MCP server discovery and health tracking."""

    def __init__(self, table_name: str = "mangaassist-mcp-registry", region: str = "ap-northeast-1"):
        dynamodb = boto3.resource("dynamodb", region_name=region)
        self.table = dynamodb.Table(table_name)

    def register_server(
        self,
        server_name: str,
        endpoint: str,
        server_type: str,  # "lambda" or "ecs"
        tools: list[dict],
        version: str,
        metadata: Optional[dict] = None,
    ):
        """Register or update an MCP server in the registry."""
        item = {
            "server_name": server_name,
            "endpoint": endpoint,
            "server_type": server_type,
            "tools": json.dumps(tools),
            "tool_names": [t["name"] for t in tools],
            "version": version,
            "registered_at": Decimal(str(time.time())),
            "last_heartbeat": Decimal(str(time.time())),
            "status": "active",
            "metadata": metadata or {},
        }
        self.table.put_item(Item=item)
        logger.info(f"Registered MCP server: {server_name} ({len(tools)} tools)")

    def discover_tool(self, tool_name: str) -> Optional[dict]:
        """Find which server hosts a specific tool.

        Note: a scan is acceptable for a small fleet; at larger scale,
        index tool names (e.g. via a GSI) to avoid full-table scans.
        """
        response = self.table.scan(
            FilterExpression=Attr("tool_names").contains(tool_name)
            & Attr("status").eq("active")
        )
        items = response.get("Items", [])
        if not items:
            return None
        # Return the most recently heartbeat-ed server
        items.sort(key=lambda x: float(x.get("last_heartbeat", 0)), reverse=True)
        server = items[0]
        return {
            "server_name": server["server_name"],
            "endpoint": server["endpoint"],
            "server_type": server["server_type"],
            "version": server["version"],
        }

    def list_active_servers(self) -> list[dict]:
        """List all active MCP servers."""
        response = self.table.scan(
            FilterExpression=Attr("status").eq("active")
        )
        servers = []
        for item in response.get("Items", []):
            servers.append({
                "server_name": item["server_name"],
                "endpoint": item["endpoint"],
                "server_type": item["server_type"],
                "tools": json.loads(item.get("tools", "[]")),
                "version": item["version"],
                "last_heartbeat": float(item.get("last_heartbeat", 0)),
            })
        return servers

    def heartbeat(self, server_name: str):
        """Update the heartbeat timestamp for a server."""
        self.table.update_item(
            Key={"server_name": server_name},
            UpdateExpression="SET last_heartbeat = :ts",
            ExpressionAttributeValues={":ts": Decimal(str(time.time()))},
        )

    def deregister_server(self, server_name: str):
        """Mark a server as inactive."""
        self.table.update_item(
            Key={"server_name": server_name},
            UpdateExpression="SET #s = :status",
            ExpressionAttributeNames={"#s": "status"},
            ExpressionAttributeValues={":status": "inactive"},
        )
        logger.info(f"Deregistered MCP server: {server_name}")
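The registry class assumes a table keyed by `server_name` (its `Key={"server_name": ...}` lookups depend on it). A sketch of the corresponding CreateTable parameters; the on-demand billing mode is an assumption:

```python
# DynamoDB CreateTable parameters matching MCPServerRegistry's access
# pattern: point lookups and updates by server_name, scans for discovery.
REGISTRY_TABLE_PARAMS = {
    "TableName": "mangaassist-mcp-registry",
    "KeySchema": [{"AttributeName": "server_name", "KeyType": "HASH"}],
    "AttributeDefinitions": [
        {"AttributeName": "server_name", "AttributeType": "S"},
    ],
    # Assumption: on-demand billing suits the registry's low, spiky traffic.
    "BillingMode": "PAY_PER_REQUEST",
}
```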

Key Takeaways

| # | Takeaway |
|---|---|
| 1 | Stateless MCP servers externalize all state -- every request fetches context from DynamoDB or Redis, making them perfectly suited for Lambda where containers are ephemeral. |
| 2 | Stateful MCP servers use LRU caches and connection pools -- the session context server maintains an in-memory OrderedDict with 10K session capacity, achieving sub-millisecond reads for cache hits. |
| 3 | Connection pooling is configured per-backend -- OpenSearch gets 10 connections per ECS task, Redis gets 20, and the MCP client maintains 50 HTTP connections with 30s keepalive. |
| 4 | Circuit breakers prevent cascading failures -- each MCP server gets its own circuit breaker (5 failures to open, 30s recovery, 3 successes to close) to isolate failures. |
| 5 | Capability advertisement happens at initialize -- servers declare their protocol version, supported features (tools, resources, prompts), and whether tool lists can change at runtime. |
| 6 | Health checks are three-tiered -- liveness (process alive), readiness (dependencies connected), and deep (full dependency verification with latency measurement). |
| 7 | Write-through caching bridges stateless and stateful -- the session context server writes to both in-memory cache and Redis, so a task restart only costs a Redis read, not data loss. |
| 8 | Dynamic discovery via DynamoDB registry -- the orchestrator finds MCP servers by tool name, enabling zero-downtime server updates and blue-green deployments. |