Streaming and WebSocket Implementation
MangaAssist context: JP Manga store chatbot on AWS — Bedrock Claude 3 (Sonnet at $3/$15 per 1M tokens input/output, Haiku at $0.25/$1.25), OpenSearch Serverless (vector store), DynamoDB (sessions/products), ECS Fargate (orchestrator), API Gateway WebSocket, ElastiCache Redis. Target: useful answer in under 3 seconds, 1M messages/day scale.
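To put the pricing and the 1M messages/day target side by side, here is a rough daily-cost estimate. The per-message token counts (800 input, 300 output) are hypothetical placeholders, not measurements from MangaAssist; only the per-token prices come from the context line above.

```python
# Prices from the context line above (USD per 1M tokens).
PRICES = {
    "sonnet": {"input": 3.00, "output": 15.00},
    "haiku": {"input": 0.25, "output": 1.25},
}


def daily_cost(model: str, messages_per_day: int,
               input_tokens: int, output_tokens: int) -> float:
    """Estimated USD per day if every message went to a single model."""
    price = PRICES[model]
    tokens_cost = input_tokens * price["input"] + output_tokens * price["output"]
    return messages_per_day * tokens_cost / 1_000_000


# ASSUMPTION: 800 input / 300 output tokens per message (illustrative only).
sonnet_usd = daily_cost("sonnet", 1_000_000, 800, 300)
haiku_usd = daily_cost("haiku", 1_000_000, 800, 300)
print(f"Sonnet: ${sonnet_usd:,.0f}/day, Haiku: ${haiku_usd:,.0f}/day")
```

Under those assumed token counts the gap is roughly an order of magnitude, which is the motivation for routing short messages to Haiku.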
Skill Mapping
| Dimension | Value |
|---|---|
| Certification | AWS Certified Generative AI Developer - Professional (AIP-C01) |
| Task | 2.4 — Select and implement FM API integration patterns |
| Skill | 2.4.2 — Real-Time AI Interaction |
| This File | 02 — Streaming WebSocket Implementation (WebSocket API config, Lambda handlers, Bedrock stream relay, heartbeat) |
Skill Scope
This file provides the implementation details for the WebSocket API configuration, Lambda route handlers ($connect, $disconnect, $default), the Bedrock streaming relay through API Gateway WebSocket, and the heartbeat mechanism that keeps connections alive during long model inference. This is the "how to build it" companion to file 01's architectural overview.
Mind Map
mindmap
root((Streaming<br/>WebSocket<br/>Implementation))
WebSocket API Config
Route definitions
Stage settings
Throttling limits
Access logging
Custom domain
Lambda Route Handlers
$connect — auth + register
$disconnect — cleanup
$default — message routing
Error handling per route
Cold start mitigation
Bedrock Stream Relay
Stream initiation
Chunk parsing
Character boundary alignment
PostToConnection batching
Error propagation
Heartbeat System
Idle detection
Ping-pong protocol
Server-initiated heartbeat
Client-side reconnect
TTL management
Infrastructure as Code
CDK WebSocket API
Lambda permissions
DynamoDB connection table
CloudWatch alarms
Custom metrics
1. WebSocket API Configuration
1.1 API Gateway WebSocket API Structure
flowchart TD
subgraph "API Gateway WebSocket API"
direction TB
WS[wss://manga-ws.example.com]
WS --> R1[$connect Route<br/>Auth + Registration]
WS --> R2[$disconnect Route<br/>Cleanup]
WS --> R3[$default Route<br/>Message Handling]
R1 --> L1[manga-ws-connect<br/>Lambda]
R2 --> L2[manga-ws-disconnect<br/>Lambda]
R3 --> L3[manga-ws-default<br/>Lambda]
subgraph "Stage: prod"
S1[Throttle: 500 req/s]
S2[Burst: 1000 requests]
S3[Idle timeout: 10 min]
S4[Access logging: enabled]
end
end
L1 --> DDB[(DynamoDB<br/>Connections)]
L3 --> ECS[ECS Fargate<br/>Orchestrator]
L2 --> DDB
1.2 CDK Infrastructure
"""
CDK stack for MangaAssist WebSocket API.
Defines the WebSocket API, Lambda handlers, DynamoDB connection table, and IAM roles.
"""
from aws_cdk import (
Stack,
Duration,
RemovalPolicy,
CfnOutput,
aws_apigatewayv2 as apigwv2,
aws_apigatewayv2_integrations as integrations,
aws_lambda as lambda_,
aws_dynamodb as dynamodb,
aws_iam as iam,
aws_logs as logs,
)
from constructs import Construct
class MangaWebSocketStack(Stack):
"""WebSocket API stack for MangaAssist real-time chat."""
def __init__(self, scope: Construct, construct_id: str, **kwargs):
super().__init__(scope, construct_id, **kwargs)
# --- DynamoDB Connection Table ---
connections_table = dynamodb.Table(
self, "ConnectionsTable",
table_name="manga-ws-connections",
partition_key=dynamodb.Attribute(
name="connectionId",
type=dynamodb.AttributeType.STRING,
),
billing_mode=dynamodb.BillingMode.PAY_PER_REQUEST,
removal_policy=RemovalPolicy.DESTROY,
time_to_live_attribute="ttl",
)
connections_table.add_global_secondary_index(
index_name="userId-index",
partition_key=dynamodb.Attribute(
name="userId",
type=dynamodb.AttributeType.STRING,
),
)
# --- Lambda Handlers ---
common_env = {
"CONNECTIONS_TABLE": connections_table.table_name,
"POWERTOOLS_SERVICE_NAME": "manga-ws",
"LOG_LEVEL": "INFO",
}
connect_handler = lambda_.Function(
self, "ConnectHandler",
function_name="manga-ws-connect",
runtime=lambda_.Runtime.PYTHON_3_12,
handler="handlers.connect.handler",
code=lambda_.Code.from_asset("lambda/websocket"),
timeout=Duration.seconds(10),
memory_size=256,
environment=common_env,
)
disconnect_handler = lambda_.Function(
self, "DisconnectHandler",
function_name="manga-ws-disconnect",
runtime=lambda_.Runtime.PYTHON_3_12,
handler="handlers.disconnect.handler",
code=lambda_.Code.from_asset("lambda/websocket"),
timeout=Duration.seconds(10),
memory_size=256,
environment=common_env,
)
default_handler = lambda_.Function(
self, "DefaultHandler",
function_name="manga-ws-default",
runtime=lambda_.Runtime.PYTHON_3_12,
handler="handlers.default_route.handler",
code=lambda_.Code.from_asset("lambda/websocket"),
timeout=Duration.seconds(60), # Long timeout for streaming
memory_size=512,
environment={
**common_env,
"ECS_ENDPOINT": "http://internal-manga-alb.us-east-1.elb.amazonaws.com",
},
)
# Grant DynamoDB access
connections_table.grant_read_write_data(connect_handler)
connections_table.grant_read_write_data(disconnect_handler)
connections_table.grant_read_write_data(default_handler)
# --- WebSocket API ---
web_socket_api = apigwv2.WebSocketApi(
self, "MangaWebSocketApi",
api_name="manga-assist-ws",
connect_route_options=apigwv2.WebSocketRouteOptions(
integration=integrations.WebSocketLambdaIntegration(
"ConnectIntegration", connect_handler,
),
),
disconnect_route_options=apigwv2.WebSocketRouteOptions(
integration=integrations.WebSocketLambdaIntegration(
"DisconnectIntegration", disconnect_handler,
),
),
default_route_options=apigwv2.WebSocketRouteOptions(
integration=integrations.WebSocketLambdaIntegration(
"DefaultIntegration", default_handler,
),
),
)
stage = apigwv2.WebSocketStage(
self, "ProdStage",
web_socket_api=web_socket_api,
stage_name="prod",
auto_deploy=True,
)
# Grant Lambda permission to post back to WebSocket connections
default_handler.add_to_role_policy(iam.PolicyStatement(
actions=["execute-api:ManageConnections"],
resources=[
f"arn:aws:execute-api:{self.region}:{self.account}:"
f"{web_socket_api.api_id}/{stage.stage_name}/POST/@connections/*"
],
))
CfnOutput(self, "WebSocketUrl", value=stage.url)
CfnOutput(self, "WebSocketCallbackUrl",
value=f"https://{web_socket_api.api_id}.execute-api."
f"{self.region}.amazonaws.com/{stage.stage_name}")
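The stack builds two strings by hand: the `ManageConnections` resource ARN and the HTTPS callback URL used by PostToConnection. A small sketch of how the two relate is below; the helper names and the example API ID, account, and region are hypothetical and used only for illustration.

```python
def manage_connections_arn(region: str, account: str, api_id: str, stage: str) -> str:
    """ARN pattern granting PostToConnection on every connection of one stage."""
    return f"arn:aws:execute-api:{region}:{account}:{api_id}/{stage}/POST/@connections/*"


def callback_url(region: str, api_id: str, stage: str) -> str:
    """HTTPS endpoint for the apigatewaymanagementapi client (not the wss:// URL)."""
    return f"https://{api_id}.execute-api.{region}.amazonaws.com/{stage}"


# Hypothetical identifiers, for illustration only.
arn = manage_connections_arn("us-east-1", "123456789012", "abc123", "prod")
url = callback_url("us-east-1", "abc123", "prod")
print(arn)
print(url)
```

Clients connect over `wss://`, while the backend posts messages over `https://` to the same API ID and stage, so both values derive from the same three inputs.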
2. Lambda Route Handlers
2.1 $connect Handler
"""
$connect route handler for MangaAssist WebSocket API.
Authenticates the connection and registers it in DynamoDB.
"""
import os
import time
import logging
import json
from typing import Dict, Any
import boto3
logger = logging.getLogger(__name__)
table = boto3.resource("dynamodb").Table(os.environ["CONNECTIONS_TABLE"])
def handler(event: Dict[str, Any], context) -> Dict[str, Any]:
"""
Handle WebSocket $connect events.
Authentication options:
1. Query string token: wss://...?token=<jwt>
2. Header-based: Sec-WebSocket-Protocol carrying auth token
Note: browser WebSocket clients cannot set custom headers such as
Authorization on the handshake. Use query string or protocol header.
"""
connection_id = event["requestContext"]["connectionId"]
query_params = event.get("queryStringParameters") or {}
logger.info("WebSocket connect | connId=%s", connection_id)
# Extract authentication token
token = query_params.get("token", "")
if not token:
# Try Sec-WebSocket-Protocol header (header name casing varies by client)
headers = {k.lower(): v for k, v in (event.get("headers") or {}).items()}
token = headers.get("sec-websocket-protocol", "")
# Validate token (simplified — use Cognito or custom authorizer in prod)
user_id = _validate_token(token)
if not user_id:
logger.warning("Auth failed | connId=%s", connection_id)
return {"statusCode": 403, "body": "Unauthorized"}
# Register connection
now = int(time.time() * 1000)
try:
table.put_item(Item={
"connectionId": connection_id,
"userId": user_id,
"sessionId": query_params.get("sessionId", ""),
"connectedAt": now,
"lastActiveAt": now,
"clientInfo": {
"userAgent": event.get("headers", {}).get("User-Agent", ""),
"sourceIp": event["requestContext"].get("identity", {}).get("sourceIp", ""),
},
"ttl": int(time.time()) + 7200, # 2-hour TTL
})
except Exception as exc:
logger.error("Failed to register connection: %s", exc)
return {"statusCode": 500, "body": "Registration failed"}
logger.info("Connection registered | connId=%s | userId=%s", connection_id, user_id)
return {"statusCode": 200, "body": "Connected"}
def _validate_token(token: str) -> str:
"""
Validate authentication token and return user ID.
In production, verify JWT against Cognito or custom auth.
"""
if not token:
return ""
# Simplified: decode JWT claim (use PyJWT with verification in prod)
try:
import base64
parts = token.split(".")
if len(parts) >= 2:
payload = json.loads(
base64.urlsafe_b64decode(parts[1] + "==").decode("utf-8")
)
return payload.get("sub", "")
except Exception:
pass
return ""
2.2 $disconnect Handler
"""
$disconnect route handler — cleans up connection state.
"""
import os
import time
import logging
from typing import Dict, Any
import boto3
logger = logging.getLogger(__name__)
table = boto3.resource("dynamodb").Table(os.environ["CONNECTIONS_TABLE"])
def handler(event: Dict[str, Any], context) -> Dict[str, Any]:
"""
Handle WebSocket $disconnect events.
Called when:
- Client closes the connection normally
- Client loses network connectivity (detected by API GW)
- Idle timeout (10 minutes) reached
- Server-side disconnect via DeleteConnection API
"""
connection_id = event["requestContext"]["connectionId"]
disconnect_reason = event["requestContext"].get("disconnectStatusCode", "unknown")
logger.info(
"WebSocket disconnect | connId=%s | reason=%s",
connection_id, disconnect_reason,
)
try:
# Get connection info before deleting (for logging)
response = table.get_item(Key={"connectionId": connection_id})
item = response.get("Item", {})
if item:
duration_ms = int(time.time() * 1000) - item.get("connectedAt", 0)
logger.info(
"Cleaning up | connId=%s | userId=%s | duration=%dms",
connection_id,
item.get("userId", "unknown"),
duration_ms,
)
table.delete_item(Key={"connectionId": connection_id})
except Exception as exc:
logger.error("Disconnect cleanup failed: %s", exc)
return {"statusCode": 200, "body": "Disconnected"}
2.3 $default Route Handler
"""
$default route handler — processes all WebSocket messages.
Routes to appropriate action handlers and manages streaming responses.
"""
import os
import json
import time
import logging
from typing import Dict, Any
import boto3
import urllib.request
logger = logging.getLogger(__name__)
table = boto3.resource("dynamodb").Table(os.environ["CONNECTIONS_TABLE"])
ECS_ENDPOINT = os.environ.get("ECS_ENDPOINT", "http://localhost:8080")
def handler(event: Dict[str, Any], context) -> Dict[str, Any]:
"""
Handle all WebSocket messages routed to $default.
Message format:
{
"action": "chat" | "ping" | "typing" | "history",
"data": { ... }
}
"""
connection_id = event["requestContext"]["connectionId"]
domain = event["requestContext"]["domainName"]
stage = event["requestContext"]["stage"]
callback_url = f"https://{domain}/{stage}"
# Parse message body
try:
body = json.loads(event.get("body") or "{}")
except json.JSONDecodeError:
return _send_error(callback_url, connection_id, "Invalid JSON")
action = body.get("action", "chat")
logger.info("Message received | connId=%s | action=%s", connection_id, action)
# Update activity timestamp
_update_activity(connection_id)
# Route to action handler
if action == "ping":
return _handle_ping(callback_url, connection_id)
elif action == "chat":
return _handle_chat(callback_url, connection_id, body)
elif action == "typing":
return {"statusCode": 200} # No-op acknowledgment
elif action == "history":
return _handle_history(callback_url, connection_id, body)
else:
return _send_error(callback_url, connection_id, f"Unknown action: {action}")
def _handle_ping(callback_url: str, connection_id: str) -> Dict:
"""Respond to client ping with pong."""
_post_to_connection(callback_url, connection_id, {
"type": "pong",
"timestamp": int(time.time() * 1000),
})
return {"statusCode": 200}
def _handle_chat(callback_url: str, connection_id: str, body: Dict) -> Dict:
"""
Forward chat message to ECS orchestrator for streaming response.
The ECS service handles:
1. Session context loading
2. Bedrock streaming invocation
3. Chunk-by-chunk relay back via PostToConnection
"""
data = body.get("data", {})
message = data.get("message", "")
session_id = data.get("sessionId", "")
if not message:
return _send_error(callback_url, connection_id, "Empty message")
# Send "processing" indicator to client
_post_to_connection(callback_url, connection_id, {
"type": "status",
"status": "processing",
"message": "考え中...", # "Thinking..." in Japanese
})
# Forward to ECS orchestrator
try:
request_payload = json.dumps({
"connectionId": connection_id,
"callbackUrl": callback_url,
"sessionId": session_id,
"message": message,
"language": data.get("preferredLanguage", "ja"),
"modelPreference": data.get("modelPreference", "auto"),
}, ensure_ascii=False).encode("utf-8")
req = urllib.request.Request(
f"{ECS_ENDPOINT}/stream-chat",
data=request_payload,
headers={"Content-Type": "application/json"},
method="POST",
)
with urllib.request.urlopen(req, timeout=55) as resp:
result = json.loads(resp.read().decode("utf-8"))
logger.info(
"Stream chat initiated | connId=%s | status=%s",
connection_id, result.get("status"),
)
except Exception as exc:
logger.error("ECS request failed: %s", exc)
_post_to_connection(callback_url, connection_id, {
"type": "error",
"error": "Service temporarily unavailable",
"errorJp": "サービスが一時的に利用できません",
})
return {"statusCode": 200}
def _handle_history(callback_url: str, connection_id: str, body: Dict) -> Dict:
"""Return conversation history for the session."""
session_id = body.get("data", {}).get("sessionId", "")
if not session_id:
return _send_error(callback_url, connection_id, "Missing sessionId")
sessions_table = boto3.resource("dynamodb").Table("manga-assist-sessions")
from boto3.dynamodb.conditions import Key
response = sessions_table.query(
KeyConditionExpression=Key("sessionId").eq(session_id),
ScanIndexForward=True,
Limit=50,
)
_post_to_connection(callback_url, connection_id, {
"type": "history",
"sessionId": session_id,
"messages": response.get("Items", []),
})
return {"statusCode": 200}
def _post_to_connection(callback_url: str, connection_id: str, payload: Dict) -> bool:
"""Send a message to a specific WebSocket connection."""
try:
apigw = boto3.client(
"apigatewaymanagementapi",
endpoint_url=callback_url,
)
apigw.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps(payload, ensure_ascii=False, default=str).encode("utf-8"), # default=str serializes DynamoDB Decimal values
)
return True
except Exception as exc:
logger.error("PostToConnection failed | conn=%s | error=%s", connection_id, exc)
return False
def _send_error(callback_url: str, connection_id: str, error: str) -> Dict:
"""Send an error message to the client."""
_post_to_connection(callback_url, connection_id, {
"type": "error",
"error": error,
})
return {"statusCode": 400}
def _update_activity(connection_id: str) -> None:
"""Update last active timestamp to prevent premature TTL."""
try:
now = int(time.time() * 1000)
table.update_item(
Key={"connectionId": connection_id},
UpdateExpression="SET lastActiveAt = :ts, #ttl = :ttl",
ExpressionAttributeNames={"#ttl": "ttl"},
ExpressionAttributeValues={
":ts": now,
":ttl": int(time.time()) + 7200,
},
)
except Exception as exc:
logger.warning("Activity update failed: %s", exc)
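The `$default` handler parses the body and dispatches on `action` in one function. A sketch of the parsing step pulled out on its own, so it can be unit-tested without AWS, might look like the following. `parse_envelope` and `ALLOWED_ACTIONS` are hypothetical names; the action list and the default of `"chat"` mirror the handler above.

```python
import json
from typing import Any, Dict, Optional, Tuple

ALLOWED_ACTIONS = {"chat", "ping", "typing", "history"}


def parse_envelope(event: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], str]:
    """Return (envelope, "") on success or (None, error_message) on failure."""
    raw = event.get("body") or "{}"  # API Gateway may deliver body=None
    try:
        body = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return None, "Invalid JSON"
    if not isinstance(body, dict):
        return None, "Message must be a JSON object"
    action = body.get("action", "chat")
    if not isinstance(action, str) or action not in ALLOWED_ACTIONS:
        return None, f"Unknown action: {action}"
    data = body.get("data") or {}
    if not isinstance(data, dict):
        return None, "data must be an object"
    return {"action": action, "data": data}, ""


envelope, error = parse_envelope({"body": '{"action": "ping"}'})
print(envelope, error)
```

Returning an error string instead of raising keeps the handler's control flow flat: it can pass the message straight to `_send_error`.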
3. ECS Streaming Orchestrator
3.1 Stream Chat Endpoint
"""
ECS Fargate streaming orchestrator endpoint.
Receives requests from Lambda, invokes Bedrock streaming,
and relays chunks back through WebSocket.
"""
import json
import time
import re
import logging
from typing import Dict, Any
import boto3
from flask import Flask, request, jsonify
logger = logging.getLogger(__name__)
app = Flask(__name__)
bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
SYSTEM_PROMPTS = {
"ja": (
"あなたはMangaAssistです。日本のマンガ書店のチャットボットです。"
"お客様のマンガ探し、シリーズや作者に関する質問への回答、おすすめの提案を行います。"
"丁寧な敬語を使用してください。"
),
"en": (
"You are MangaAssist, a helpful chatbot for a Japanese manga store. "
"Help customers find manga, answer questions, and provide recommendations."
),
}
@app.route("/stream-chat", methods=["POST"])
def stream_chat():
"""
Handle streaming chat requests from the WebSocket Lambda.
Flow:
1. Load session context (DynamoDB in this listing; the Redis cache lookup is omitted)
2. Build prompt
3. Invoke Bedrock streaming
4. Relay each chunk to WebSocket via PostToConnection
5. Save session and return summary
"""
payload = request.get_json()
connection_id = payload["connectionId"]
callback_url = payload["callbackUrl"]
session_id = payload.get("sessionId", "")
message = payload["message"]
language = payload.get("language", "ja")
logger.info("Stream chat | conn=%s | session=%s", connection_id, session_id)
apigw = boto3.client(
"apigatewaymanagementapi",
endpoint_url=callback_url,
)
start_time = time.time()
try:
# 1. Load conversation context
history = _load_context(session_id)
# 2. Build messages array
messages = []
for turn in history[-6:]: # Last 6 messages (3 user/assistant turns)
messages.append({"role": turn["role"], "content": turn["content"]})
messages.append({"role": "user", "content": message})
# 3. Select model
model_id = _select_model(message, language)
# 4. Invoke Bedrock streaming
bedrock_body = {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 2048,
"temperature": 0.3,
"system": SYSTEM_PROMPTS.get(language, SYSTEM_PROMPTS["en"]),
"messages": messages,
}
response = bedrock.invoke_model_with_response_stream(
modelId=model_id,
contentType="application/json",
accept="application/json",
body=json.dumps(bedrock_body),
)
# 5. Relay stream chunks
full_text = ""
chunk_count = 0
text_buffer = ""
last_send = time.time()
first_token = True
ttft_ms = 0
for event in response["body"]:
chunk_data = event.get("chunk")
if not chunk_data:
continue
payload_bytes = chunk_data["bytes"]
event_payload = json.loads(payload_bytes.decode("utf-8"))
event_type = event_payload.get("type", "")
if event_type == "content_block_delta":
delta_text = event_payload.get("delta", {}).get("text", "")
if not delta_text:
continue
if first_token:
ttft_ms = (time.time() - start_time) * 1000
first_token = False
text_buffer += delta_text
full_text += delta_text
# Batch: flush when the buffer reaches 20 chars or 50 ms have passed since the last send
now = time.time()
if len(text_buffer) >= 20 or (now - last_send) > 0.05:
_send_chunk(apigw, connection_id, text_buffer, chunk_count)
chunk_count += 1
text_buffer = ""
last_send = now
elif event_type == "message_delta":
# Final usage and stop reason
pass
# Flush remaining buffer
if text_buffer:
_send_chunk(apigw, connection_id, text_buffer, chunk_count)
chunk_count += 1
total_ms = (time.time() - start_time) * 1000
# Send completion signal
_send_to_ws(apigw, connection_id, {
"type": "done",
"totalLength": len(full_text),
"chunks": chunk_count,
"ttftMs": round(ttft_ms),
"totalMs": round(total_ms),
})
# Save to session
_save_turn(session_id, message, full_text)
logger.info(
"Stream complete | conn=%s | ttft=%.0fms | total=%.0fms | chunks=%d",
connection_id, ttft_ms, total_ms, chunk_count,
)
return jsonify({"status": "completed", "chunks": chunk_count})
except Exception as exc:
logger.error("Stream chat error: %s", exc, exc_info=True)
try:
_send_to_ws(apigw, connection_id, {
"type": "error",
"error": "Stream processing failed",
"errorJp": "ストリーム処理に失敗しました",
})
except Exception:
pass
return jsonify({"status": "error", "error": str(exc)}), 500
def _send_chunk(apigw, connection_id: str, text: str, index: int) -> None:
"""Send a text chunk to the WebSocket client."""
_send_to_ws(apigw, connection_id, {
"type": "chunk",
"text": text,
"index": index,
})
def _send_to_ws(apigw, connection_id: str, payload: dict) -> None:
"""Send any payload to a WebSocket connection."""
try:
apigw.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
)
except apigw.exceptions.GoneException:
logger.info("Client disconnected | conn=%s", connection_id)
raise # Stop streaming
except Exception as exc:
logger.warning("PostToConnection error | conn=%s | %s", connection_id, exc)
def _select_model(message: str, language: str) -> str:
"""Select model based on message complexity."""
chars = len(message)
jp_chars = len(re.findall(r"[\u3000-\u9fff]", message))
if chars < 50 and jp_chars < 20:
return "anthropic.claude-3-haiku-20240307-v1:0"
return "anthropic.claude-3-sonnet-20240229-v1:0"
def _load_context(session_id: str) -> list:
"""Load session context from DynamoDB."""
if not session_id:
return []
try:
ddb = boto3.resource("dynamodb").Table("manga-assist-sessions")
from boto3.dynamodb.conditions import Key
resp = ddb.query(
KeyConditionExpression=Key("sessionId").eq(session_id),
ScanIndexForward=False,
Limit=6,
)
return list(reversed(resp.get("Items", [])))
except Exception:
return []
def _save_turn(session_id: str, user_msg: str, assistant_msg: str) -> None:
"""Save the conversation turn to DynamoDB."""
if not session_id:
return
try:
ddb = boto3.resource("dynamodb").Table("manga-assist-sessions")
ts = int(time.time() * 1000)
ddb.put_item(Item={"sessionId": session_id, "timestamp": ts, "role": "user", "content": user_msg})
ddb.put_item(Item={"sessionId": session_id, "timestamp": ts + 1, "role": "assistant", "content": assistant_msg})
except Exception as exc:
logger.error("Session save failed: %s", exc)
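The batching rule inside `stream_chat` (flush at 20 characters or after 50 ms) is easier to test when isolated from Bedrock and API Gateway. The sketch below is one way to factor it out; `ChunkBatcher` is a hypothetical class, and the injected `clock` and `send` callables are assumptions added for testability rather than part of the original listing.

```python
import time
from typing import Callable


class ChunkBatcher:
    """Buffers stream deltas and flushes by size or elapsed time."""

    def __init__(
        self,
        send: Callable[[str, int], None],
        max_chars: int = 20,
        max_wait_s: float = 0.05,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._send = send
        self._max_chars = max_chars
        self._max_wait_s = max_wait_s
        self._clock = clock
        self._buffer = ""
        self._last_send = clock()
        self.sent = 0  # number of chunks emitted so far

    def add(self, text: str) -> None:
        """Append a delta; flush if the size or time threshold is crossed."""
        if not text:
            return
        self._buffer += text
        elapsed = self._clock() - self._last_send
        if len(self._buffer) >= self._max_chars or elapsed > self._max_wait_s:
            self.flush()

    def flush(self) -> None:
        """Send whatever is buffered (call once more when the stream ends)."""
        if not self._buffer:
            return
        self._send(self._buffer, self.sent)
        self.sent += 1
        self._buffer = ""
        self._last_send = self._clock()


# Usage with a fake clock so the timing rule is deterministic.
ticks = [0.0]
out = []
batcher = ChunkBatcher(lambda text, index: out.append((index, text)), clock=lambda: ticks[0])
batcher.add("こん")
batcher.add("にちは")   # still under both thresholds, stays buffered
ticks[0] = 0.06
batcher.add("!")        # more than 50 ms elapsed, so the buffer is flushed
batcher.add("a" * 25)   # 25 chars >= 20, flushed immediately
batcher.add("x")
batcher.flush()         # end of stream
print(out)
```

As in the original loop, the time rule is only evaluated when a new delta arrives, so a final explicit `flush()` is still required.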
4. Heartbeat and Keep-Alive System
4.1 Server-Side Heartbeat Flow
sequenceDiagram
participant Client as Browser
participant APIGW as API Gateway<br/>WebSocket
participant Lambda as Lambda
Note over Client,Lambda: Connection idle for 9 minutes
Client->>APIGW: {"action":"ping"}
APIGW->>Lambda: $default route
Lambda->>APIGW: PostToConnection<br/>{"type":"pong"}
APIGW->>Client: {"type":"pong"} message
Note over Client,Lambda: Timer reset — 10 more minutes
Note over Client,Lambda: If no ping for 10 min:
APIGW->>Lambda: $disconnect (idle timeout)
Lambda->>Lambda: Cleanup connection record
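The timing in the diagram comes down to two deadlines: the 10-minute idle timeout, and the 2-hour cap on connection duration that the connection table's 2-hour TTL mirrors. A small sketch of that arithmetic is below; the function names are hypothetical and all times are plain seconds.

```python
IDLE_TIMEOUT_S = 10 * 60        # idle timeout shown in the diagram above
MAX_CONNECTION_S = 2 * 60 * 60  # connection duration cap (matches the 2-hour TTL)


def seconds_until_close(connected_at: float, last_activity_at: float, now: float) -> float:
    """Time left before the earlier of the idle deadline and the hard cap."""
    idle_deadline = last_activity_at + IDLE_TIMEOUT_S
    hard_deadline = connected_at + MAX_CONNECTION_S
    return max(0.0, min(idle_deadline, hard_deadline) - now)


def ping_attempts_within_idle_window(heartbeat_interval_s: int) -> int:
    """How many pings are scheduled strictly before the idle timeout fires."""
    return (IDLE_TIMEOUT_S - 1) // heartbeat_interval_s


print(ping_attempts_within_idle_window(120))  # 2-minute heartbeat
print(ping_attempts_within_idle_window(540))  # 9-minute heartbeat
```

A 2-minute interval leaves several attempts inside each idle window, whereas a 9-minute interval leaves one, so a single lost ping would drop the connection. Heartbeats cannot extend a connection past the hard cap, which is why the client also needs reconnect logic.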
4.2 Client-Side Heartbeat Reference
"""
Client-side heartbeat protocol reference implementation (Python).
In production, this runs in the browser as JavaScript.
Included here for testing and documentation.
"""
import json
import time
import logging
import threading
import websocket # websocket-client library
logger = logging.getLogger(__name__)
class MangaAssistWSClient:
"""
WebSocket client with heartbeat and auto-reconnect.
Reference implementation for the MangaAssist chat frontend.
"""
def __init__(
self,
url: str,
token: str,
session_id: str = "",
heartbeat_interval: int = 120, # seconds (< 10min idle timeout)
max_reconnect_attempts: int = 5,
):
self.url = f"{url}?token={token}&sessionId={session_id}"
self.heartbeat_interval = heartbeat_interval
self.max_reconnect_attempts = max_reconnect_attempts
self._ws = None
self._heartbeat_thread = None
self._heartbeat_stop = threading.Event()
self._reconnect_count = 0
def connect(self):
"""Establish WebSocket connection with heartbeat."""
self._ws = websocket.WebSocketApp(
self.url,
on_open=self._on_open,
on_message=self._on_message,
on_close=self._on_close,
on_error=self._on_error,
)
threading.Thread(target=self._ws.run_forever, daemon=True).start()
def send_chat(self, message: str, language: str = "ja") -> None:
"""Send a chat message."""
self._ws.send(json.dumps({
"action": "chat",
"data": {
"message": message,
"preferredLanguage": language,
},
}, ensure_ascii=False))
def _on_open(self, ws):
"""Connection established — start heartbeat."""
logger.info("WebSocket connected")
self._reconnect_count = 0
self._start_heartbeat()
def _on_message(self, ws, message):
"""Handle incoming messages by type."""
try:
data = json.loads(message)
msg_type = data.get("type", "")
if msg_type == "pong":
logger.debug("Heartbeat pong received")
elif msg_type == "chunk":
print(data.get("text", ""), end="", flush=True)
elif msg_type == "done":
print() # Newline after stream completes
logger.info(
"Stream complete | ttft=%dms | total=%dms",
data.get("ttftMs", 0), data.get("totalMs", 0),
)
elif msg_type == "error":
logger.error("Server error: %s", data.get("error"))
elif msg_type == "heartbeat":
logger.debug("Server heartbeat received")
elif msg_type == "status":
logger.info("Status: %s", data.get("message"))
except json.JSONDecodeError:
logger.warning("Invalid message: %s", message[:100])
def _on_close(self, ws, close_code, close_msg):
"""Handle connection close — attempt reconnect with backoff."""
logger.info("WebSocket closed | code=%s | msg=%s", close_code, close_msg)
self._heartbeat_stop.set()
if self._reconnect_count < self.max_reconnect_attempts:
self._reconnect_count += 1
delay = min(2 ** self._reconnect_count, 30)
logger.info("Reconnecting in %ds (attempt %d)", delay, self._reconnect_count)
time.sleep(delay)
self.connect()
def _on_error(self, ws, error):
"""Handle WebSocket errors."""
logger.error("WebSocket error: %s", error)
def _start_heartbeat(self):
"""Start the heartbeat thread."""
self._heartbeat_stop.clear()
self._heartbeat_thread = threading.Thread(
target=self._heartbeat_loop, daemon=True,
)
self._heartbeat_thread.start()
def _heartbeat_loop(self):
"""Send periodic pings to keep the connection alive."""
while not self._heartbeat_stop.is_set():
self._heartbeat_stop.wait(self.heartbeat_interval)
if self._heartbeat_stop.is_set():
break
try:
self._ws.send(json.dumps({"action": "ping"}))
logger.debug("Heartbeat ping sent")
except Exception as exc:
logger.warning("Heartbeat send failed: %s", exc)
break
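The reconnect delay in `_on_close` is `min(2 ** attempt, 30)`. The sketch below tabulates that schedule and adds optional jitter, which is an addition not present in the client above; it is a common way to keep many clients from reconnecting at the same instant. `backoff_delay` and its parameters are hypothetical names.

```python
import random
from typing import Callable


def backoff_delay(
    attempt: int,
    base: float = 2.0,
    cap: float = 30.0,
    jitter: float = 0.0,
    rng: Callable[[], float] = random.random,
) -> float:
    """Exponential backoff capped at `cap`, optionally shortened by random jitter."""
    delay = min(base ** attempt, cap)
    if jitter:
        # Shrink the delay by up to `jitter` (for example 0.5 means up to 50%).
        delay *= 1 - jitter * rng()
    return delay


# Schedule for the client's five reconnect attempts, without jitter.
schedule = [backoff_delay(n) for n in range(1, 6)]
print(schedule)  # [2.0, 4.0, 8.0, 16.0, 30.0]
```

With the default of five attempts the client gives up after about a minute of cumulative waiting.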
5. Server-Sent Events (SSE) Alternative
5.1 SSE Architecture
For clients that cannot use WebSockets (REST API consumers, CLI tools), SSE provides a simpler one-way streaming alternative.
flowchart LR
subgraph "SSE Flow"
A[Client] -->|POST /chat/stream<br/>Accept: text/event-stream| B[Lambda<br/>Function URL]
B --> C[Lambda<br/>Response Streaming]
C --> D[Bedrock<br/>Streaming]
D -->|chunks| C
C -->|SSE events| B
B -->|SSE events| A
end
"""
SSE endpoint using Lambda response streaming.
Alternative to WebSocket for HTTP-based clients.
"""
import json
import logging
import boto3
logger = logging.getLogger(__name__)
bedrock = boto3.client("bedrock-runtime")
def handler(event, context):
"""
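Lambda function URL with response streaming enabled.
Configure: FunctionUrlConfig.InvokeMode = RESPONSE_STREAM
Caveat: managed runtimes stream natively only on Node.js. In Python this
generator-returning shape needs a custom runtime or the Lambda Web Adapter.
The function streams SSE events directly to the client
without buffering the full response.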
"""
body = json.loads(event.get("body") or "{}")
message = body.get("message", "")
bedrock_body = {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 2048,
"temperature": 0.3,
"messages": [{"role": "user", "content": message}],
}
response = bedrock.invoke_model_with_response_stream(
modelId="anthropic.claude-3-haiku-20240307-v1:0",
contentType="application/json",
accept="application/json",
body=json.dumps(bedrock_body),
)
def generate():
yield _sse_event("stream_start", {"model": "haiku"})
for event_data in response["body"]:
chunk = event_data.get("chunk")
if not chunk:
continue
payload = json.loads(chunk["bytes"].decode("utf-8"))
if payload.get("type") == "content_block_delta":
text = payload.get("delta", {}).get("text", "")
if text:
yield _sse_event("content", {"text": text})
yield _sse_event("stream_end", {})
return {
"statusCode": 200,
"headers": {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
"body": generate(),
}
def _sse_event(event_type: str, data: dict) -> str:
"""Format a Server-Sent Event."""
return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
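To show what a consumer does with the `event:` / `data:` framing produced by `_sse_event`, here is a minimal round-trip sketch. `parse_sse` is a hypothetical helper that handles only the two fields used in this file (no `id:`, `retry:`, or comment lines), and `sse_event` repeats the formatter above so the block runs on its own.

```python
import json
from typing import Iterator, Tuple


def sse_event(event_type: str, data: dict) -> str:
    """Format a Server-Sent Event (same shape as _sse_event above)."""
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def _field_value(line: str, field: str) -> str:
    """Strip the field name, the colon, and at most one leading space."""
    value = line[len(field) + 1:]
    return value[1:] if value.startswith(" ") else value


def parse_sse(stream: str) -> Iterator[Tuple[str, dict]]:
    """Yield (event_type, data) pairs from a complete SSE text stream."""
    for block in stream.split("\n\n"):  # a blank line terminates each event
        event_type = "message"          # SSE default when no event: field is given
        data_lines = []
        for line in block.split("\n"):
            if line.startswith("event:"):
                event_type = _field_value(line, "event")
            elif line.startswith("data:"):
                data_lines.append(_field_value(line, "data"))
        if data_lines:
            yield event_type, json.loads("\n".join(data_lines))


stream = (
    sse_event("stream_start", {"model": "haiku"})
    + sse_event("content", {"text": "こんにちは\n"})
    + sse_event("stream_end", {})
)
events = list(parse_sse(stream))
print(events)
```

Because `json.dumps` escapes newlines inside strings, a text delta containing `\n` cannot accidentally terminate the event.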
6. Japanese Text Chunking Considerations
6.1 Character Boundary Awareness
"""
Japanese-aware text chunking for WebSocket streaming.
Ensures chunks don't split multi-byte characters or break
at awkward positions in Japanese text.
"""
import re
def align_chunk_boundary(buffer: str, min_size: int = 5) -> tuple:
"""
Split buffer at a natural Japanese boundary.
Returns (chunk_to_send, remaining_buffer).
Prefers splitting after: sentence endings, clause breaks, closing brackets, whitespace.
"""
if len(buffer) < min_size:
return "", buffer
# Japanese sentence-ending patterns (in priority order)
break_patterns = [
r"[。！？!?\n]", # Sentence endings (full-width and ASCII)
r"[、，,]", # Clause breaks
r"[）)」』】]", # Closing brackets
r"[\s]", # Whitespace
]
for pattern in break_patterns:
matches = list(re.finditer(pattern, buffer))
if matches:
# Split after the last match that leaves at least min_size chars
for match in reversed(matches):
pos = match.end()
if pos >= min_size:
return buffer[:pos], buffer[pos:]
# No natural break found — send everything
return buffer, ""
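As a usage sketch, the boundary splitter can sit between the model's deltas and the sender, carrying the unsent remainder forward. The block below restates the splitter compactly so it runs on its own; `split_at_boundary` and `rechunk` are hypothetical names, and the pattern list assumes both full-width and ASCII punctuation should count as break points.

```python
import re
from typing import Iterable, Iterator, Tuple

# Break patterns in priority order (full-width and ASCII variants).
_BREAKS = [r"[。！？!?\n]", r"[、，,]", r"[）)」』】]", r"\s"]


def split_at_boundary(buffer: str, min_size: int = 5) -> Tuple[str, str]:
    """Return (chunk_to_send, remainder), preferring natural break points."""
    if len(buffer) < min_size:
        return "", buffer
    for pattern in _BREAKS:
        for match in reversed(list(re.finditer(pattern, buffer))):
            if match.end() >= min_size:
                return buffer[:match.end()], buffer[match.end():]
    return buffer, ""  # no natural break: send everything


def rechunk(deltas: Iterable[str], min_size: int = 5) -> Iterator[str]:
    """Re-slice raw model deltas into boundary-aligned chunks."""
    buffer = ""
    for delta in deltas:
        buffer += delta
        chunk, buffer = split_at_boundary(buffer, min_size)
        if chunk:
            yield chunk
    if buffer:  # flush the tail at end of stream
        yield buffer


deltas = ["こんにちは。今日", "は晴れ", "です。"]
chunks = list(rechunk(deltas))
print(chunks)  # ['こんにちは。', '今日は晴れ', 'です。']
```

The first delta is cut after the sentence ending and its tail is held back, while the second chunk has no break point and is sent whole once it reaches `min_size`.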
Key Takeaways
| # | Takeaway |
|---|---|
| 1 | Browser WebSocket clients cannot set an Authorization header on the handshake, so authenticate $connect with a query string token or the Sec-WebSocket-Protocol header. |
| 2 | API Gateway WebSocket idle timeout is 10 minutes — send heartbeat pings every 2 minutes from the client to keep the connection alive. |
| 3 | PostToConnection is how the backend sends data to WebSocket clients — the Lambda/ECS service calls this API for each stream chunk. |
| 4 | GoneException from PostToConnection signals client disconnect — catch it to stop Bedrock streaming and avoid wasted inference cost. |
| 5 | Chunk batching (20 chars or 50ms) reduces PostToConnection API calls while keeping perceived streaming smoothness. |
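| 6 | Raise the $default Lambda timeout from the 3s default (60s here) to cover a full streaming session including Bedrock inference. API Gateway's own integration timeout is 29s, so chunk delivery should rely on PostToConnection rather than on the route response. |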
| 7 | SSE via Lambda response streaming provides a simpler alternative for HTTP clients that cannot use WebSockets. |
| 8 | Japanese text chunking should align on sentence endings, clause breaks, or closing brackets to avoid awkward mid-phrase splits in the UI. |