
PO-06: WebSocket Streaming Performance

User Story

As a frontend engineer, I want WebSocket delivery jitter kept under 50ms and 10K+ concurrent connections supported per node, so that users see a smooth, real-time streaming experience where tokens appear consistently without stuttering or delays.

Acceptance Criteria

  • Token delivery jitter (the variation in the interval between consecutive deltas) is under 50ms at p95.
  • WebSocket connection setup completes in under 100ms after session init.
  • Each WebSocket handler node supports 10,000+ concurrent connections.
  • Heartbeat pings detect dead connections within 60 seconds.
  • HTTPS fallback adds no more than 200ms of latency per response compared with WebSocket delivery.
  • Backpressure mechanism prevents slow clients from blocking the streaming pipeline.

High-Level Design

Streaming Delivery Pipeline

graph LR
    subgraph "Backend"
        A[Bedrock<br>Token Stream] --> B[Orchestrator<br>Buffer & Forward]
        B --> C[WebSocket Handler<br>Fan-out per session]
    end

    subgraph "Edge"
        C --> D[ALB<br>Sticky Sessions]
        D --> E[CloudFront<br>WebSocket Passthrough]
    end

    subgraph "Client"
        E --> F[Browser<br>WebSocket API]
        F --> G[Chat Widget<br>Progressive Render]
    end

Optimization Strategy

graph TD
    subgraph "Connection Management"
        A1[Connection Pool<br>per handler node]
        A2[Graceful Handoff<br>on scale-in]
        A3[Dead Connection<br>Detection]
    end

    subgraph "Message Delivery"
        B1[Token Batching<br>Micro-batches at 50ms]
        B2[Backpressure<br>Async write buffers]
        B3[Delta Compression<br>Minimize payload]
    end

    subgraph "Infrastructure"
        C1[ALB Sticky Sessions<br>Route to same node]
        C2[ECS Fargate<br>Long-lived tasks]
        C3[HTTPS Fallback<br>Long-polling]
    end

    A1 --> D[Smooth Streaming<br>< 50ms jitter]
    B1 --> D
    C1 --> D
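
The "Backpressure / Async write buffers" box can be read as a per-session outbox: the orchestrator enqueues deltas without ever awaiting the client's socket, and a dedicated writer task drains the queue at the client's pace. This is a minimal sketch under stated assumptions, not this design's implementation: the `SessionOutbox` name, the `send` callback, and the coalesce-when-full policy are all hypothetical. Coalescing is offered as an alternative to dropping deltas: a slow client receives fewer, larger deltas instead of a gap in the text.

```python
import asyncio
from collections import deque
from collections.abc import Awaitable, Callable


class SessionOutbox:
    """Hypothetical per-session write buffer: the producer never awaits the socket."""

    def __init__(self, send: Callable[[str], Awaitable[None]], max_pending: int = 8):
        self._send = send                # e.g. a wrapper around ws.send_str (assumption)
        self._max_pending = max_pending  # queued deltas before coalescing starts
        self._pending: deque[str] = deque()
        self._wakeup = asyncio.Event()
        self._closed = False
        self.coalesced = 0               # times a delta was merged instead of queued

    def put(self, text: str) -> None:
        """Queue a delta from the producer side; never blocks."""
        if self._pending and len(self._pending) >= self._max_pending:
            # Full: merge into the newest queued delta rather than dropping text
            self._pending[-1] += text
            self.coalesced += 1
        else:
            self._pending.append(text)
        self._wakeup.set()

    def close(self) -> None:
        """Signal end of stream; the writer exits once the queue is drained."""
        self._closed = True
        self._wakeup.set()

    async def run_writer(self) -> None:
        """Drain queued deltas in order, at whatever pace send() allows."""
        while True:
            await self._wakeup.wait()
            self._wakeup.clear()
            while self._pending:
                await self._send(self._pending.popleft())
            if self._closed:
                return
```

The producer calls `put()` for every batch and `close()` after the last one; a stalled client only grows its own queue up to `max_pending` entries, and the text it eventually receives is still complete.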

Low-Level Design

1. WebSocket Connection Manager

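Each ECS Fargate task runs a WebSocket handler that tracks its connections (capped at 12,000 per node in the code example) in an in-memory registry keyed by session ID, with one heartbeat timer per connection.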

graph TD
    subgraph "WebSocket Handler (per ECS Task)"
        A[Connection Registry<br>dict + asyncio.Lock] --> B[Session: ws_001<br>Client WebSocket]
        A --> C[Session: ws_002<br>Client WebSocket]
        A --> D[Session: ws_003<br>Client WebSocket]
    end

    subgraph "Lifecycle"
        E[New Connection] --> F[Register in Registry]
        F --> G[Start Heartbeat Timer<br>30s interval]
        G --> H{Ping/Pong?}
        H -->|Pong received| G
        H -->|No pong in 60s| I[Close + Deregister]
    end

    subgraph "Metrics"
        J[Active Connections]
        K[Messages/sec]
        L[Connection Duration]
    end
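
Because liveness is only checked when the heartbeat timer fires, the worst-case time to detect a dead peer is a multiple of the ping interval, not the pong timeout itself. The helper below is a hypothetical back-of-the-envelope model, assuming the last pong lands just after a ping and a tick-based `elapsed > timeout` check like the one in `_heartbeat_loop`: with a 30s interval, a 60s timeout only closes the connection at the third tick (~90s), while any timeout from 30s up to just under 60s closes it at the second tick (~60s), which is what the 60-second acceptance criterion needs.

```python
import math


def worst_case_detection_s(ping_interval_s: float, pong_timeout_s: float) -> float:
    """Seconds from a peer's last pong until a tick-based check closes it.

    Model (assumption): the last pong arrives just after the ping at t=0,
    liveness is checked only at t = k * ping_interval_s, and the connection
    is closed at the first tick where elapsed time strictly exceeds the timeout.
    """
    # Smallest integer k with k * interval > timeout
    k = math.floor(pong_timeout_s / ping_interval_s) + 1
    return k * ping_interval_s
```

A 30-second ping interval also keeps otherwise idle connections below the ALB's default 60-second idle timeout.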

Code Example: WebSocket Connection Manager

import asyncio
import logging
import time
from dataclasses import dataclass

from aiohttp import web

logger = logging.getLogger(__name__)


@dataclass
class WebSocketConnection:
    session_id: str
    ws: web.WebSocketResponse
    connected_at: float
    last_pong: float
    messages_sent: int = 0
    bytes_sent: int = 0
    write_buffer_size: int = 0  # bytes passed to send_str() that have not returned yet
    heartbeat_task: asyncio.Task | None = None


class WebSocketConnectionManager:
    """Manages concurrent WebSocket connections with heartbeat and backpressure.

    Handlers should create sockets with web.WebSocketResponse(autoping=False) and
    call handle_pong() on WSMsgType.PONG (answering client pings themselves); with
    the default autoping=True, aiohttp consumes pongs before the handler sees them.
    Handlers should call deregister(session_id, ws) when their receive loop ends.
    """

    HEARTBEAT_INTERVAL = 30  # seconds between pings
    # Liveness is checked once per heartbeat tick, so a 45s timeout closes a dead
    # peer at the second tick after its last pong (~60s); 60s would take ~90s.
    PONG_TIMEOUT = 45  # seconds
    MAX_WRITE_BUFFER = 64 * 1024  # 64 KB in flight per connection
    MAX_CONNECTIONS_PER_NODE = 12000

    def __init__(self):
        self._connections: dict[str, WebSocketConnection] = {}
        self._lock = asyncio.Lock()

    @property
    def active_connections(self) -> int:
        return len(self._connections)

    async def register(
        self, session_id: str, ws: web.WebSocketResponse
    ) -> bool:
        """Register a new connection, replacing any previous one for the session."""
        now = time.monotonic()
        conn = WebSocketConnection(
            session_id=session_id,
            ws=ws,
            connected_at=now,
            last_pong=now,
        )

        async with self._lock:
            existing = self._connections.get(session_id)
            # A reconnect replaces its old entry, so it does not count against the cap
            if existing is None and self.active_connections >= self.MAX_CONNECTIONS_PER_NODE:
                logger.warning("Max connections reached, rejecting %s", session_id)
                return False
            self._connections[session_id] = conn
            # Keep a reference to the task so it cannot be garbage-collected
            conn.heartbeat_task = asyncio.create_task(self._heartbeat_loop(conn))

        if existing is not None:
            # Close outside the lock: close() waits for the closing handshake
            await self._close(existing)
        return True

    async def deregister(
        self, session_id: str, ws: web.WebSocketResponse | None = None
    ) -> None:
        """Remove a connection. If ws is given, remove it only if it is still current."""
        async with self._lock:
            conn = self._connections.get(session_id)
            if conn is None or (ws is not None and conn.ws is not ws):
                return  # already replaced by a reconnect; keep the new entry
            del self._connections[session_id]
        await self._close(conn)

    async def _close(self, conn: WebSocketConnection) -> None:
        """Stop the connection's heartbeat and close its socket."""
        task = conn.heartbeat_task
        if task is not None and task is not asyncio.current_task():
            task.cancel()
        if not conn.ws.closed:
            await conn.ws.close()

    async def send_delta(
        self, session_id: str, delta_payload: str
    ) -> bool:
        """Send a streaming delta to a session; drop it if the client is backed up."""
        conn = self._connections.get(session_id)
        if conn is None or conn.ws.closed:
            return False

        payload_size = len(delta_payload.encode("utf-8"))

        # Backpressure: the counter exceeds one payload only when several send_delta()
        # calls for this session overlap, i.e. send_str() is held up by a client that
        # is not draining its socket. A dropped delta leaves a gap in the delta index.
        if conn.write_buffer_size > self.MAX_WRITE_BUFFER:
            logger.warning(
                "Backpressure: dropping delta for %s, buffer=%d",
                session_id,
                conn.write_buffer_size,
            )
            return False

        conn.write_buffer_size += payload_size
        try:
            await conn.ws.send_str(delta_payload)
        except Exception as e:
            logger.warning("Failed to send to %s: %s", session_id, e)
            await self.deregister(session_id, conn.ws)
            return False
        finally:
            conn.write_buffer_size -= payload_size

        conn.messages_sent += 1
        conn.bytes_sent += payload_size
        return True

    async def _heartbeat_loop(self, conn: WebSocketConnection) -> None:
        """Ping the client periodically and close it if pongs stop arriving."""
        while True:
            await asyncio.sleep(self.HEARTBEAT_INTERVAL)
            if conn.ws.closed:
                return

            # Check if we received a pong recently
            if time.monotonic() - conn.last_pong > self.PONG_TIMEOUT:
                logger.info("Pong timeout for %s, closing", conn.session_id)
                await self.deregister(conn.session_id, conn.ws)
                return

            try:
                await conn.ws.ping()
            except Exception:
                await self.deregister(conn.session_id, conn.ws)
                return

    async def handle_pong(self, session_id: str) -> None:
        """Update the last-pong timestamp; call this on WSMsgType.PONG."""
        conn = self._connections.get(session_id)
        if conn:
            conn.last_pong = time.monotonic()

2. Token Micro-Batching

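Instead of sending every token as a separate WebSocket frame, batch tokens into micro-batches that are flushed every 50ms or as soon as max_batch_size tokens accumulate, whichever comes first. This reduces per-frame overhead and smooths delivery, at the cost of at most one batch interval of added delay per token.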

sequenceDiagram
    participant Bedrock
    participant Batcher as Token Batcher
    participant WebSocket as WebSocket Handler
    participant Client

    loop Tokens arriving within one 50ms window
        Bedrock-->>Batcher: Token "If"
        Bedrock-->>Batcher: Token " you"
        Bedrock-->>Batcher: Token " loved"
    end

    Note over Batcher: 50ms batch window
    Batcher->>WebSocket: Batch: "If you loved"
    WebSocket->>Client: Delta: "If you loved"

    loop Next batch
        Bedrock-->>Batcher: Token " Naruto"
        Bedrock-->>Batcher: Token ","
    end

    Batcher->>WebSocket: Batch: " Naruto,"
    WebSocket->>Client: Delta: " Naruto,"
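
To see how the window length and batch cap shape the output, the grouping rule can be simulated offline on token arrival timestamps. This is a hypothetical, synchronous model of the same rule that TokenBatcher below implements (flush on a fixed 50ms tick, or immediately once max_batch_size tokens are pending), not the batcher itself; the timestamps in the usage check are made up for illustration.

```python
def simulate_batches(
    arrivals: list[tuple[float, str]],  # (arrival time in ms, token), sorted by time
    interval_ms: float = 50.0,
    max_batch_size: int = 10,
) -> list[tuple[float, str]]:
    """Return (emit time in ms, batched text) under fixed-tick micro-batching."""
    batches: list[tuple[float, str]] = []
    pending: list[str] = []
    next_tick = interval_ms
    for t, token in arrivals:
        # Fire every timer tick that elapses before this token arrives
        while t >= next_tick:
            if pending:
                batches.append((next_tick, "".join(pending)))
                pending = []
            next_tick += interval_ms
        pending.append(token)
        if len(pending) >= max_batch_size:
            # Cap reached: flush immediately at this token's arrival time
            batches.append((t, "".join(pending)))
            pending = []
    if pending:
        # End of stream (approximated as the last arrival): flush the remainder
        batches.append((arrivals[-1][0], "".join(pending)))
    return batches
```

With tokens at 10, 25, 40, 70 and 85ms, the five tokens from the diagram become two frames: "If you loved" at the 50ms tick and " Naruto," at end of stream. No token waits longer than one interval.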

Code Example: Token Batcher

import asyncio
import json
from collections.abc import Awaitable, Callable


class TokenBatcher:
    """Batches LLM output tokens into micro-batches for smoother delivery."""

    def __init__(
        self,
        batch_interval_ms: int = 50,
        max_batch_size: int = 10,
        # Async callback invoked as on_batch(session_id, delta_event_json, text)
        on_batch: Callable[[str, str, str], Awaitable[None]] | None = None,
    ):
        self.batch_interval_s = batch_interval_ms / 1000
        self.max_batch_size = max_batch_size
        self.on_batch = on_batch
        self._buffer: list[str] = []
        self._flush_task: asyncio.Task | None = None
        # Serializes timer flushes and size-triggered flushes so deltas are
        # handed to on_batch one at a time, in index order
        self._flush_lock = asyncio.Lock()
        self._session_id: str = ""
        self._response_id: str = ""
        self._delta_index: int = 0

    async def start(self, session_id: str, response_id: str) -> None:
        """Start the batching loop for a response."""
        self._session_id = session_id
        self._response_id = response_id
        self._buffer = []
        self._delta_index = 0
        self._flush_task = asyncio.create_task(self._flush_loop())

    async def add_token(self, token: str) -> None:
        """Add a token to the current batch."""
        self._buffer.append(token)

        # Flush immediately if batch is full
        if len(self._buffer) >= self.max_batch_size:
            await self._flush()

    async def finish(self) -> None:
        """Stop the flush loop, then flush any remaining tokens."""
        if self._flush_task:
            # Cancel while holding the lock so the loop is never interrupted
            # mid-send, after it has already taken tokens out of the buffer
            async with self._flush_lock:
                self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None

        await self._flush()  # no-op when the buffer is empty

    async def _flush_loop(self) -> None:
        """Periodically flush the token buffer."""
        while True:
            await asyncio.sleep(self.batch_interval_s)
            await self._flush()

    async def _flush(self) -> None:
        """Send the current buffer as a single delta event."""
        async with self._flush_lock:
            if not self._buffer:
                return

            text = "".join(self._buffer)
            self._buffer.clear()

            delta_event = json.dumps({
                "type": "chat.response.delta",
                "session_id": self._session_id,
                "response_id": self._response_id,
                "delta": text,
                "index": self._delta_index,
            })
            self._delta_index += 1

            if self.on_batch:
                await self.on_batch(self._session_id, delta_event, text)
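
Since every delta carries a monotonically increasing `index`, and `send_delta` may drop a delta under backpressure, the receiving side can detect gaps instead of silently rendering a hole in the text. The chat widget itself runs in the browser; this Python sketch only illustrates the bookkeeping, and the `DeltaAssembler` class and its `missing` list are hypothetical.

```python
import json


class DeltaAssembler:
    """Rebuilds response text from chat.response.delta events and records gaps."""

    def __init__(self) -> None:
        self.text = ""
        self.next_index = 0
        self.missing: list[int] = []  # delta indices that never arrived

    def apply(self, raw_event: str) -> str:
        """Apply one JSON delta event; return the text to append in the UI."""
        event = json.loads(raw_event)
        if event.get("type") != "chat.response.delta":
            return ""
        index = event["index"]
        if index < self.next_index:
            return ""  # duplicate or late delta: its slot is already rendered or skipped
        if index > self.next_index:
            # Deltas next_index .. index-1 were dropped (e.g. by backpressure)
            self.missing.extend(range(self.next_index, index))
        self.next_index = index + 1
        self.text += event["delta"]
        return event["delta"]
```

A non-empty `missing` list is the client's cue to re-fetch the complete text once the response finishes, rather than leaving the rendered answer with holes.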

3. HTTPS Long-Polling Fallback

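For clients that cannot maintain a WebSocket connection, provide an HTTPS fallback. The design below uses long-polling: the client posts the message, then repeatedly requests new chunks until the response is complete (server-sent events could be served from the same buffer). Because the buffer lives in memory on a single task, every poll for a response must reach the task that holds it.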

sequenceDiagram
    participant Client
    participant API as API Gateway
    participant Buffer as Response Buffer
    participant Orchestrator

    Client->>API: POST /chat/message
    API-->>Client: 202 {response_id}

    loop Poll until complete
        Client->>API: GET /chat/message/{response_id}
        API->>Buffer: Get buffered chunks
        alt New chunks available
            Buffer-->>API: Buffered text
            API-->>Client: 200 {partial_response, status: "streaming"}
        else No new chunks
            Buffer-->>API: No new data
            API-->>Client: 200 {status: "streaming"}
        end
    end

    Client->>API: GET /chat/message/{response_id}
    API->>Buffer: Get final response
    Buffer-->>API: Complete response
    API-->>Client: 200 {full_response, status: "completed"}
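
On the client, the polling loop only needs to carry a cursor: send the last `chunk_index` back as `from_index`, append `new_text`, and stop at a terminal status. In this hedged sketch, `fetch` stands in for the HTTP GET and is assumed to return the same dict shape as `ResponseBuffer.read`; the `poll_response` name and the `max_polls` cap are hypothetical.

```python
from collections.abc import Awaitable, Callable

# fetch(response_id, from_index) -> dict shaped like ResponseBuffer.read() (assumption)
Fetch = Callable[[str, int], Awaitable[dict]]


async def poll_response(
    fetch: Fetch, response_id: str, max_polls: int = 1000
) -> tuple[str, int]:
    """Long-poll until the response is terminal; return (full text, number of polls)."""
    text_parts: list[str] = []
    cursor = 0
    for polls in range(1, max_polls + 1):
        result = await fetch(response_id, cursor)
        text_parts.append(result["new_text"])
        cursor = result["chunk_index"]  # next poll asks only for newer chunks
        if result["status"] in ("completed", "error"):
            return "".join(text_parts), polls
    raise TimeoutError(f"{response_id} not finished after {max_polls} polls")
```

An empty poll (the server-side wait timed out) simply leaves the cursor where it was and triggers the next request immediately.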

Code Example: Response Buffer for HTTPS Fallback

import asyncio
import time
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BufferedResponse:
    response_id: str
    session_id: str
    chunks: list[str] = field(default_factory=list)
    full_text: str = ""
    status: str = "pending"  # pending, streaming, completed, error
    metadata: dict = field(default_factory=dict)
    created_at: float = field(default_factory=time.monotonic)
    completed_at: float | None = None


class ResponseBuffer:
    """In-memory buffer for HTTPS fallback streaming responses."""

    BUFFER_TTL_SECONDS = 300  # 5 minutes

    def __init__(self):
        self._buffers: dict[str, BufferedResponse] = {}
        self._events: dict[str, asyncio.Event] = {}

    def create(self, response_id: str, session_id: str) -> None:
        """Create a buffer for a new response."""
        self._buffers[response_id] = BufferedResponse(
            response_id=response_id,
            session_id=session_id,
        )
        self._events[response_id] = asyncio.Event()

    async def append_chunk(self, response_id: str, chunk: str) -> None:
        """Append a streaming chunk to the buffer."""
        buf = self._buffers.get(response_id)
        if buf is None:
            return

        buf.chunks.append(chunk)
        buf.full_text += chunk
        buf.status = "streaming"

        # Signal waiting long-poll clients
        event = self._events.get(response_id)
        if event:
            event.set()
            # Reset for next wait
            self._events[response_id] = asyncio.Event()

    async def complete(
        self, response_id: str, metadata: dict
    ) -> None:
        """Mark a response as completed."""
        buf = self._buffers.get(response_id)
        if buf is None:
            return

        buf.status = "completed"
        buf.metadata = metadata
        buf.completed_at = time.monotonic()

        event = self._events.get(response_id)
        if event:
            event.set()

    async def read(
        self,
        response_id: str,
        from_index: int = 0,
        timeout_ms: int = 10000,
    ) -> Optional[dict]:
        """
        Read chunks after from_index (the chunk_index returned by the previous read).
        Long-polls for up to timeout_ms if no new data is available yet.
        """
        buf = self._buffers.get(response_id)
        if buf is None:
            return None

        # Wait only if there is nothing new and the response can still change
        if from_index >= len(buf.chunks) and buf.status not in ("completed", "error"):
            event = self._events.get(response_id)
            if event:
                try:
                    await asyncio.wait_for(
                        event.wait(), timeout=timeout_ms / 1000
                    )
                except asyncio.TimeoutError:
                    pass  # nothing arrived; the client simply polls again

        # Return whatever is available now
        return {
            "response_id": response_id,
            "new_text": "".join(buf.chunks[from_index:]),
            "full_text": buf.full_text,
            "chunk_index": len(buf.chunks),
            "status": buf.status,
            "metadata": buf.metadata if buf.status == "completed" else {},
        }

    def cleanup_expired(self) -> int:
        """Remove expired buffers to free memory."""
        now = time.monotonic()
        expired = [
            rid for rid, buf in self._buffers.items()
            if now - buf.created_at > self.BUFFER_TTL_SECONDS
        ]
        for rid in expired:
            del self._buffers[rid]
            self._events.pop(rid, None)
        return len(expired)

4. ALB Sticky Session Configuration

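Once upgraded, a WebSocket connection stays on the ECS task that accepted it for its entire lifetime; the ALB does not move an open connection. Sticky sessions cover the requests around it: reconnects and HTTPS fallback polls should return to the task that holds the session's registry entry and in-memory response buffer.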

graph TD
    subgraph "ALB Configuration"
        A[Target Group<br>Stickiness: enabled] --> B[Cookie Duration: 3600s<br>App cookie: WSID]
        B --> C[Health Check:<br>/health every 10s]
    end

    subgraph "Routing"
        D[Client A<br>Cookie: WSID=task-1] --> E[ALB]
        F[Client B<br>Cookie: WSID=task-2] --> E
        E --> G[Task 1]
        E --> H[Task 2]
    end

    style A fill:#2d8,stroke:#333
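
The target-group settings in the diagram map onto Elastic Load Balancing v2 target group attributes. A hedged sketch: the attribute keys are the documented ELBv2 ones, the WSID cookie name and 3600s duration come from the diagram, and the 300s deregistration delay (the ELB default) is an assumed value for graceful handoff on scale-in. `stickiness_attributes` and `apply_alb_config` are hypothetical helpers; the second one uses boto3's `modify_target_group_attributes` and `modify_target_group` and is not exercised here.

```python
def stickiness_attributes(
    cookie_name: str = "WSID",
    duration_s: int = 3600,
    drain_s: int = 300,
) -> list[dict[str, str]]:
    """Target group attributes for app-cookie stickiness plus connection draining."""
    return [
        {"Key": "stickiness.enabled", "Value": "true"},
        # app_cookie: when a task's response sets the application cookie, the ALB
        # issues its own AWSALBAPP cookie that pins the client to that task
        {"Key": "stickiness.type", "Value": "app_cookie"},
        {"Key": "stickiness.app_cookie.cookie_name", "Value": cookie_name},
        {"Key": "stickiness.app_cookie.duration_seconds", "Value": str(duration_s)},
        # How long a deregistering (scaling-in) task keeps existing connections
        {"Key": "deregistration_delay.timeout_seconds", "Value": str(drain_s)},
    ]


def apply_alb_config(target_group_arn: str) -> None:
    """Push the settings with boto3 (requires AWS credentials)."""
    import boto3  # imported lazily so the attribute builder works without boto3

    elbv2 = boto3.client("elbv2")
    elbv2.modify_target_group_attributes(
        TargetGroupArn=target_group_arn, Attributes=stickiness_attributes()
    )
    # Health check from the diagram: GET /health every 10 seconds
    elbv2.modify_target_group(
        TargetGroupArn=target_group_arn,
        HealthCheckPath="/health",
        HealthCheckIntervalSeconds=10,
    )
```

ELBv2 expects every attribute value as a string, which is why the integers are converted with `str()`.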

Metrics and Monitoring

| Metric | Target | Alarm Threshold |
|---|---|---|
| ws.delivery_jitter_ms | p95 < 50ms | p95 > 100ms |
| ws.connection_setup_ms | p95 < 100ms | p95 > 200ms |
| ws.active_connections | < 10K/node | > 10K/node |
| ws.dead_connections_rate | < 1%/min | > 5%/min |
| ws.messages_per_second | Monitor trend | Sudden drop > 50% |
| ws.backpressure_drops | < 0.1% | > 1% |
| fallback.poll_latency_ms | p95 < 200ms | p95 > 500ms |
| fallback.poll_count_per_response | avg < 5 | avg > 10 |

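
The table does not pin down how ws.delivery_jitter_ms is computed. One reasonable reading of the acceptance criterion, assumed here: record the client-side arrival time of each delta, take the gaps between consecutive arrivals, treat the absolute change between consecutive gaps as one jitter sample, and report the nearest-rank p95. The function name is hypothetical.

```python
import math


def delivery_jitter_p95_ms(arrival_ms: list[float]) -> float:
    """p95 delivery jitter for one response, given delta arrival timestamps in ms."""
    gaps = [b - a for a, b in zip(arrival_ms, arrival_ms[1:])]
    # Jitter sample: how much each inter-arrival gap differs from the previous one
    samples = sorted(abs(g2 - g1) for g1, g2 in zip(gaps, gaps[1:]))
    if not samples:
        return 0.0  # fewer than three deltas: no gap-to-gap variation to measure
    rank = math.ceil(0.95 * len(samples))  # nearest-rank percentile (1-based)
    return float(samples[rank - 1])
```

A stream flushed on every 50ms tick yields equal gaps and zero jitter; a single 110ms stall among 50ms gaps produces a 60ms sample, which on its own exceeds the 50ms target when it falls in the top 5%.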
graph LR
    subgraph "WebSocket Health"
        A[Active Connections<br>per node]
        B[Delivery Jitter<br>p95]
        C[Backpressure<br>Drop Rate]
        D[Dead Connection<br>Rate]
    end

    A --> E{> 10K?}
    E -->|Yes| F[Scale out<br>ECS tasks]
    C --> G{> 1%?}
    G -->|Yes| H[Investigate slow<br>clients or network]