PO-06: WebSocket Streaming Performance
User Story
As a frontend engineer, I want to minimize WebSocket delivery jitter to under 50ms and support 10K+ concurrent connections per node, so that users see a smooth, real-time streaming experience in which tokens appear at a steady pace without stuttering or delays.
Acceptance Criteria
- Token delivery jitter (the variation in arrival time between consecutive deltas) is under 50ms at p95.
- WebSocket connection setup completes in under 100ms after session init.
- Each WebSocket handler node supports 10,000+ concurrent connections.
- Heartbeat pings detect dead connections within 60 seconds.
- HTTPS fallback adds no more than 200ms additional latency per response compared to WebSocket.
- Backpressure mechanism prevents slow clients from blocking the streaming pipeline.
High-Level Design
Streaming Delivery Pipeline
graph LR
subgraph "Backend"
A[Bedrock<br>Token Stream] --> B[Orchestrator<br>Buffer & Forward]
B --> C[WebSocket Handler<br>Fan-out per session]
end
subgraph "Edge"
C --> D[ALB<br>Sticky Sessions]
D --> E[CloudFront<br>WebSocket Passthrough]
end
subgraph "Client"
E --> F[Browser<br>WebSocket API]
F --> G[Chat Widget<br>Progressive Render]
end
Optimization Strategy
graph TD
subgraph "Connection Management"
A1[Connection Pool<br>per handler node]
A2[Graceful Handoff<br>on scale-in]
A3[Dead Connection<br>Detection]
end
subgraph "Message Delivery"
B1[Token Batching<br>Micro-batches at 50ms]
B2[Backpressure<br>Async write buffers]
B3[Delta Compression<br>Minimize payload]
end
subgraph "Infrastructure"
C1[ALB Sticky Sessions<br>Route to same node]
C2[ECS Fargate<br>Long-lived tasks]
C3[HTTPS Fallback<br>Long-polling]
end
A1 --> D[Smooth Streaming<br>< 50ms jitter]
B1 --> D
C1 --> D
Low-Level Design
1. WebSocket Connection Manager
Each ECS Fargate task runs a WebSocket handler that manages thousands of concurrent connections.
graph TD
subgraph "WebSocket Handler (per ECS Task)"
A[Connection Registry<br>ConcurrentHashMap] --> B[Session: ws_001<br>Client WebSocket]
A --> C[Session: ws_002<br>Client WebSocket]
A --> D[Session: ws_003<br>Client WebSocket]
end
subgraph "Lifecycle"
E[New Connection] --> F[Register in Registry]
F --> G[Start Heartbeat Timer<br>30s interval]
G --> H{Ping/Pong?}
H -->|Pong received| G
H -->|No pong in 60s| I[Close + Deregister]
end
subgraph "Metrics"
J[Active Connections]
K[Messages/sec]
L[Connection Duration]
end
Code Example: WebSocket Connection Manager
import asyncio
import logging
import time
from dataclasses import dataclass
from typing import Optional
from aiohttp import web

logger = logging.getLogger(__name__)


@dataclass
class WebSocketConnection:
    session_id: str
    ws: web.WebSocketResponse
    connected_at: float
    last_pong: float
    messages_sent: int = 0
    bytes_sent: int = 0
    write_buffer_size: int = 0
    heartbeat_task: Optional[asyncio.Task] = None


class WebSocketConnectionManager:
    """Manages concurrent WebSocket connections with heartbeat and backpressure."""

    HEARTBEAT_INTERVAL = 30  # seconds
    PONG_TIMEOUT = 60  # seconds
    MAX_WRITE_BUFFER = 64 * 1024  # 64 KB per connection
    MAX_CONNECTIONS_PER_NODE = 12000

    def __init__(self):
        self._connections: dict[str, WebSocketConnection] = {}
        self._lock = asyncio.Lock()

    @property
    def active_connections(self) -> int:
        return len(self._connections)

    async def register(
        self, session_id: str, ws: web.WebSocketResponse
    ) -> bool:
        """Register a new WebSocket connection."""
        now = time.monotonic()
        conn = WebSocketConnection(
            session_id=session_id,
            ws=ws,
            connected_at=now,
            last_pong=now,
        )
        async with self._lock:
            # Check capacity under the lock so concurrent registrations
            # cannot overshoot the limit
            if self.active_connections >= self.MAX_CONNECTIONS_PER_NODE:
                logger.warning("Max connections reached, rejecting new connection")
                return False
            # Close existing connection for same session (reconnect case)
            existing = self._connections.get(session_id)
            if existing and not existing.ws.closed:
                await existing.ws.close()
            self._connections[session_id] = conn
        # Start heartbeat for this connection; keep a reference so the
        # task is not garbage-collected while running
        conn.heartbeat_task = asyncio.create_task(
            self._heartbeat_loop(session_id, conn)
        )
        return True

    async def deregister(self, session_id: str) -> None:
        """Remove a connection from the registry."""
        async with self._lock:
            conn = self._connections.pop(session_id, None)
        if conn:
            # Cancel the heartbeat unless we are running inside it
            task = conn.heartbeat_task
            if task and task is not asyncio.current_task():
                task.cancel()
            if not conn.ws.closed:
                await conn.ws.close()

    async def send_delta(
        self, session_id: str, delta_payload: str
    ) -> bool:
        """Send a streaming delta to a specific session with backpressure."""
        conn = self._connections.get(session_id)
        if conn is None or conn.ws.closed:
            return False
        payload_size = len(delta_payload.encode("utf-8"))
        # Backpressure: if too many bytes are already in flight to this
        # client, drop the delta instead of blocking the pipeline
        if conn.write_buffer_size > self.MAX_WRITE_BUFFER:
            logger.warning(
                f"Backpressure: dropping delta for {session_id}, "
                f"buffer={conn.write_buffer_size}"
            )
            return False
        try:
            conn.write_buffer_size += payload_size
            await conn.ws.send_str(delta_payload)
            conn.write_buffer_size = max(0, conn.write_buffer_size - payload_size)
            conn.messages_sent += 1
            conn.bytes_sent += payload_size
            return True
        except Exception as e:
            logger.warning(f"Failed to send to {session_id}: {e}")
            await self.deregister(session_id)
            return False

    async def _heartbeat_loop(
        self, session_id: str, conn: WebSocketConnection
    ) -> None:
        """Send pings and detect dead connections."""
        while True:
            await asyncio.sleep(self.HEARTBEAT_INTERVAL)
            # Stop if this connection was replaced by a reconnect or removed
            if self._connections.get(session_id) is not conn or conn.ws.closed:
                return
            # Check if we received a pong recently
            if time.monotonic() - conn.last_pong > self.PONG_TIMEOUT:
                logger.info(f"Pong timeout for {session_id}, closing")
                await self.deregister(session_id)
                return
            try:
                await conn.ws.ping()
            except Exception:
                await self.deregister(session_id)
                return

    async def handle_pong(self, session_id: str) -> None:
        """Update last pong timestamp when client responds."""
        conn = self._connections.get(session_id)
        if conn:
            conn.last_pong = time.monotonic()
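One possible wiring of this manager into an aiohttp route, as a sketch: the route path, `session_id` query parameter, and close code are assumptions. `autoping=False` keeps PING/PONG frames visible to the read loop, since the manager drives heartbeats itself and needs `handle_pong` to be called.

```python
from aiohttp import web, WSMsgType

def make_ws_route(manager):
    """Build an aiohttp handler around a WebSocketConnectionManager-like object."""
    async def websocket_route(request: web.Request) -> web.WebSocketResponse:
        # autoping=False surfaces PING/PONG frames in the read loop below
        ws = web.WebSocketResponse(autoping=False)
        await ws.prepare(request)
        session_id = request.query.get("session_id", "")
        if not session_id or not await manager.register(session_id, ws):
            await ws.close(code=1013)  # 1013: "try again later" (node at capacity)
            return ws
        try:
            async for msg in ws:
                if msg.type == WSMsgType.PING:
                    await ws.pong(msg.data)  # answer client pings manually
                elif msg.type == WSMsgType.PONG:
                    await manager.handle_pong(session_id)
        finally:
            await manager.deregister(session_id)
        return ws
    return websocket_route

# Usage sketch:
# app = web.Application()
# app.add_routes([web.get("/ws", make_ws_route(manager))])
```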
2. Token Micro-Batching
Instead of sending every individual token as a separate WebSocket frame, batch tokens into small groups (every 50ms) to reduce frame overhead and smooth delivery.
sequenceDiagram
participant Bedrock
participant Batcher as Token Batcher
participant WebSocket as WebSocket Handler
participant Client
loop Every ~30ms (token generation)
Bedrock-->>Batcher: Token "If"
Bedrock-->>Batcher: Token " you"
Bedrock-->>Batcher: Token " loved"
end
Note over Batcher: 50ms batch window
Batcher->>WebSocket: Batch: "If you loved"
WebSocket->>Client: Delta: "If you loved"
loop Next batch
Bedrock-->>Batcher: Token " Naruto"
Bedrock-->>Batcher: Token ","
end
Batcher->>WebSocket: Batch: " Naruto,"
WebSocket->>Client: Delta: " Naruto,"
Code Example: Token Batcher
import asyncio
import json
from typing import Awaitable, Callable


class TokenBatcher:
    """Batches LLM output tokens into micro-batches for smoother delivery."""

    def __init__(
        self,
        batch_interval_ms: int = 50,
        max_batch_size: int = 10,
        on_batch: Callable[[str, str, str], Awaitable[None]] | None = None,
    ):
        self.batch_interval_s = batch_interval_ms / 1000
        self.max_batch_size = max_batch_size
        self.on_batch = on_batch
        self._buffer: list[str] = []
        self._flush_task: asyncio.Task | None = None
        self._session_id: str = ""
        self._response_id: str = ""
        self._delta_index: int = 0

    async def start(self, session_id: str, response_id: str) -> None:
        """Start the batching loop for a response."""
        self._session_id = session_id
        self._response_id = response_id
        self._buffer = []
        self._delta_index = 0
        self._flush_task = asyncio.create_task(self._flush_loop())

    async def add_token(self, token: str) -> None:
        """Add a token to the current batch."""
        self._buffer.append(token)
        # Flush immediately if batch is full
        if len(self._buffer) >= self.max_batch_size:
            await self._flush()

    async def finish(self) -> None:
        """Flush remaining tokens and stop the batcher."""
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
        # Flush any remaining tokens
        if self._buffer:
            await self._flush()

    async def _flush_loop(self) -> None:
        """Periodically flush the token buffer."""
        while True:
            await asyncio.sleep(self.batch_interval_s)
            if self._buffer:
                await self._flush()

    async def _flush(self) -> None:
        """Send the current buffer as a single delta event."""
        if not self._buffer:
            return
        text = "".join(self._buffer)
        self._buffer.clear()
        delta_event = json.dumps({
            "type": "chat.response.delta",
            "session_id": self._session_id,
            "response_id": self._response_id,
            "delta": text,
            "index": self._delta_index,
        })
        self._delta_index += 1
        if self.on_batch:
            await self.on_batch(self._session_id, delta_event, text)
3. HTTPS Long-Polling Fallback
For clients that cannot maintain a WebSocket connection, provide an HTTPS fallback with server-sent events or long-polling.
sequenceDiagram
participant Client
participant API as API Gateway
participant Buffer as Response Buffer
participant Orchestrator
Client->>API: POST /chat/message
API-->>Client: 202 {response_id}
loop Poll until complete
Client->>API: GET /chat/message/{response_id}
API->>Buffer: Get buffered chunks
alt New chunks available
Buffer-->>API: Buffered text
API-->>Client: 200 {partial_response, status: "streaming"}
else No new chunks
Buffer-->>API: No new data
API-->>Client: 200 {status: "streaming"}
end
end
Client->>API: GET /chat/message/{response_id}
API->>Buffer: Get final response
Buffer-->>API: Complete response
API-->>Client: 200 {full_response, status: "completed"}
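The client side of the polling loop above can be sketched as follows. The `fetch` callable stands in for the GET request and returns the parsed JSON body from the sequence diagram; the function name and `from_index` convention are illustrative, not a fixed API.

```python
import asyncio
from typing import Awaitable, Callable

async def poll_until_complete(
    fetch: Callable[[int], Awaitable[dict]],
    max_polls: int = 120,
) -> str:
    """Accumulate a streamed response over repeated long-polls.

    `fetch(from_index)` stands in for GET /chat/message/{response_id},
    returning the JSON body shown in the sequence diagram.
    """
    text = ""
    index = 0
    for _ in range(max_polls):
        body = await fetch(index)
        text += body.get("new_text", "")
        index = body.get("chunk_index", index)
        if body.get("status") == "completed":
            return text
    raise TimeoutError("response did not complete within max_polls")
```

Because each poll carries the last seen `chunk_index`, the client never re-downloads text it already rendered, which keeps fallback payloads close to WebSocket delta sizes.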
Code Example: Response Buffer for HTTPS Fallback
import asyncio
import time
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BufferedResponse:
    response_id: str
    session_id: str
    chunks: list[str] = field(default_factory=list)
    full_text: str = ""
    status: str = "pending"  # pending, streaming, completed, error
    metadata: dict = field(default_factory=dict)
    created_at: float = field(default_factory=time.monotonic)
    completed_at: float | None = None
    last_read_index: int = 0


class ResponseBuffer:
    """In-memory buffer for HTTPS fallback streaming responses."""

    BUFFER_TTL_SECONDS = 300  # 5 minutes

    def __init__(self):
        self._buffers: dict[str, BufferedResponse] = {}
        self._events: dict[str, asyncio.Event] = {}

    def create(self, response_id: str, session_id: str) -> None:
        """Create a buffer for a new response."""
        self._buffers[response_id] = BufferedResponse(
            response_id=response_id,
            session_id=session_id,
        )
        self._events[response_id] = asyncio.Event()

    async def append_chunk(self, response_id: str, chunk: str) -> None:
        """Append a streaming chunk to the buffer."""
        buf = self._buffers.get(response_id)
        if buf is None:
            return
        buf.chunks.append(chunk)
        buf.full_text += chunk
        buf.status = "streaming"
        # Signal waiting long-poll clients
        event = self._events.get(response_id)
        if event:
            event.set()
            # Reset for next wait
            self._events[response_id] = asyncio.Event()

    async def complete(
        self, response_id: str, metadata: dict
    ) -> None:
        """Mark a response as completed."""
        buf = self._buffers.get(response_id)
        if buf is None:
            return
        buf.status = "completed"
        buf.metadata = metadata
        buf.completed_at = time.monotonic()
        event = self._events.get(response_id)
        if event:
            event.set()

    async def read(
        self,
        response_id: str,
        from_index: int = 0,
        timeout_ms: int = 10000,
    ) -> Optional[dict]:
        """
        Read new chunks from the buffer.

        Long-polls for up to timeout_ms if no new data is available.
        """
        buf = self._buffers.get(response_id)
        if buf is None:
            return None
        # If there are new chunks, return immediately
        if from_index < len(buf.chunks) or buf.status == "completed":
            return self._snapshot(buf, from_index)
        # Long-poll: wait for new data
        event = self._events.get(response_id)
        if event:
            try:
                await asyncio.wait_for(
                    event.wait(), timeout=timeout_ms / 1000
                )
            except asyncio.TimeoutError:
                pass
        # Return whatever is available
        return self._snapshot(buf, from_index)

    def _snapshot(self, buf: BufferedResponse, from_index: int) -> dict:
        """Build the response payload for chunks past from_index."""
        new_chunks = buf.chunks[from_index:]
        return {
            "response_id": buf.response_id,
            "new_text": "".join(new_chunks),
            "full_text": buf.full_text,
            "chunk_index": len(buf.chunks),
            "status": buf.status,
            "metadata": buf.metadata if buf.status == "completed" else {},
        }

    def cleanup_expired(self) -> int:
        """Remove expired buffers to free memory."""
        now = time.monotonic()
        expired = [
            rid for rid, buf in self._buffers.items()
            if now - buf.created_at > self.BUFFER_TTL_SECONDS
        ]
        for rid in expired:
            del self._buffers[rid]
            self._events.pop(rid, None)
        return len(expired)
4. ALB Sticky Session Configuration
WebSocket connections must be routed to the same ECS task for the duration of the session.
graph TD
subgraph "ALB Configuration"
A[Target Group<br>Stickiness: enabled] --> B[Cookie Duration: 3600s<br>App cookie: WSID]
B --> C[Health Check:<br>/health every 10s]
end
subgraph "Routing"
D[Client A<br>Cookie: WSID=task-1] --> E[ALB]
F[Client B<br>Cookie: WSID=task-2] --> E
E --> G[Task 1]
E --> H[Task 2]
end
style A fill:#2d8,stroke:#333
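The stickiness settings in the diagram map to ELBv2 target-group attributes. As a sketch, they could be applied with boto3's `modify_target_group_attributes`; the target group ARN is a placeholder, and infrastructure-as-code (Terraform/CDK) would normally own these values.

```python
# Application-cookie stickiness attributes for the WebSocket target group,
# matching the diagram: app cookie "WSID", 3600s duration.
STICKINESS_ATTRIBUTES = [
    {"Key": "stickiness.enabled", "Value": "true"},
    {"Key": "stickiness.type", "Value": "app_cookie"},
    {"Key": "stickiness.app_cookie.cookie_name", "Value": "WSID"},
    {"Key": "stickiness.app_cookie.duration_seconds", "Value": "3600"},
]

# Hypothetical application of the attributes (ARN is a placeholder):
# import boto3
# elbv2 = boto3.client("elbv2")
# elbv2.modify_target_group_attributes(
#     TargetGroupArn="arn:aws:elasticloadbalancing:...",
#     Attributes=STICKINESS_ATTRIBUTES,
# )
```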
Metrics and Monitoring
| Metric | Target | Alarm Threshold |
|---|---|---|
| ws.delivery_jitter_ms | p95 < 50ms | p95 > 100ms |
| ws.connection_setup_ms | p95 < 100ms | p95 > 200ms |
| ws.active_connections | < 10K/node | > 10K/node |
| ws.dead_connections_rate | < 1%/min | > 5%/min |
| ws.messages_per_second | Monitor trend | Sudden drop > 50% |
| ws.backpressure_drops | < 0.1% | > 1% |
| fallback.poll_latency_ms | p95 < 200ms | p95 > 500ms |
| fallback.poll_count_per_response | avg < 5 | avg > 10 |
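One way to compute ws.delivery_jitter_ms from client-observed delta arrival timestamps is the nearest-rank p95 over the gaps between consecutive deltas; the function name and percentile method here are illustrative.

```python
import math

def p95_jitter_ms(arrival_times_ms: list[float]) -> float:
    """p95 of the gaps between consecutive delta arrivals, in milliseconds.

    Uses the nearest-rank method; requires at least two arrivals.
    """
    if len(arrival_times_ms) < 2:
        raise ValueError("need at least two delta arrivals")
    gaps = sorted(b - a for a, b in zip(arrival_times_ms, arrival_times_ms[1:]))
    rank = max(0, math.ceil(0.95 * len(gaps)) - 1)
    return gaps[rank]
```

For example, deltas arriving at 0, 50, 100, 160, and 200ms produce gaps of 50, 50, 60, and 40ms, so the p95 jitter is 60ms, within the 100ms alarm threshold but above the 50ms target.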
graph LR
subgraph "WebSocket Health"
A[Active Connections<br>per node]
B[Delivery Jitter<br>p95]
C[Backpressure<br>Drop Rate]
D[Dead Connection<br>Rate]
end
A --> E{> 10K?}
E -->|Yes| F[Scale out<br>ECS tasks]
C --> G{> 1%?}
G -->|Yes| H[Investigate slow<br>clients or network]