vLLM Low-Level Implementation And Critical Decisions
A low-level implementation guide for the vLLM scenario set: memory efficiency, dynamic scheduling, multi-turn reuse, adapter routing, long-context stability, and migration-safe serving contracts. Every code pattern in this file runs in the MangaAssist production stack on Llama-3-8B-Instruct (AWQ INT4) served by vLLM 0.4.3 on ml.g5.xlarge (NVIDIA A10G, 24 GB VRAM).
1. Scope
This guide expands the scenario narratives in 01-vllm-game-changer-scenarios.md into production code patterns with full implementation detail.
Related documents:
- Deployment lifecycle: 04-vllm-deployment-and-infrastructure.md — Docker build, SageMaker config, scaling policies, startup scripts
- Monitoring and troubleshooting: 05-vllm-monitoring-and-troubleshooting.md — Metrics, alerting rules, SLOs, runbooks
- Model preparation: 06-vllm-model-preparation-and-quantization.md — AWQ calibration, LoRA training, CI/CD pipeline
The target outcome is a production-grade self-hosted generation path for MangaAssist that is:
- fast enough for interactive chat (P50 TTFT < 200 ms, P99 total < 1,400 ms)
- efficient enough to justify self-hosting (85–90 concurrent sequences per GPU)
- stable under long multi-turn sessions (zero OOM restarts)
- observable enough to debug (per-request trace with adapter, cache hit, queue wait)
- modular enough to swap backends later (backend-neutral response contract)
2. Target Architecture
graph TD
A[Chat UI] --> B[API Gateway]
B --> C[Chat Orchestrator]
C --> D[Model Router]
D --> E[vLLM Gateway]
D --> F[Bedrock Adapter]
E --> G[Request Budgeter]
E --> H[Adapter Registry]
E --> I[Prompt Canonicalizer]
E --> J[vLLM Engine]
J --> K[Prefix Cache]
J --> L[PagedAttention KV Blocks]
J --> M[Continuous Scheduler]
C --> N[Trace and Metrics]
E --> N
J --> O[Streaming SSE or WebSocket]
3. Suggested Code Layout
app/
inference/
router.py
generation_policy.py
request_budgeter.py
serving/
vllm_gateway.py
vllm_engine.py
readiness.py
stream_adapter.py
prompts/
canonicalizer.py
prompt_versions.py
adapters/
registry.py
lora_selector.py
resilience/
oom_guard.py
fallback_policy.py
admission_control.py
observability/
tracing.py
inference_metrics.py
eval/
quantization_eval.py
adapter_regression_eval.py
shadow_compare.py
4. Shared Runtime Contracts
Request Contract
Every request to the self-hosted path should already be normalized by the orchestrator:
{
"request_id": "req_9f1b",
"session_id": "sess_77ad",
"route": "self_hosted_generation",
"intent": "manga_recommendation",
"adapter_id": "manga_domain_v3",
"prompt_version": "chat-manga-2.1",
"messages": [
{"role": "system", "content": "You are a manga shopping assistant."},
{"role": "user", "content": "I want something like Vinland Saga."}
],
"retrieval_chunks": [
{"chunk_id": "manga_441", "source_type": "catalog"}
],
"stream": true
}
Response Contract
Keep the response contract backend-neutral so the app does not care whether the answer came from vLLM or a managed provider:
{
"request_id": "req_9f1b",
"backend": "vllm",
"backend_version": "llama-8b-awq+manga_domain_v3",
"trace_id": "trace_4c1e",
"output_text": "If you liked Vinland Saga, start with...",
"finish_reason": "stop",
"usage": {
"input_tokens": 842,
"output_tokens": 137
}
}
5. Cross-Cutting Production Rules
- Create one top-level trace per customer-visible response.
- Keep one engine process per GPU worker. Do not run multiple unrelated Python workers against the same GPU allocator.
- Separate readiness from liveness. A process can be alive before weights, CUDA graphs, and warmup are ready.
- Apply admission control before the engine queue becomes the only backpressure mechanism.
- Keep the prompt contract deterministic anywhere you want prefix caching to help.
- Treat adapter choice and prompt version as first-class metadata in traces and logs.
6. Base Engine Construction
import logging
import os
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
logger = logging.getLogger(__name__)
def build_engine(model_path: str) -> AsyncLLMEngine:
"""
Build the vLLM async engine with production-tuned parameters.
Every parameter here is an operating decision, not a default:
- gpu_memory_utilization=0.92: Leaves 1.92 GB headroom (8%) for CUDA workspace,
driver allocations, and non-model tensors. 0.95 caused sporadic OOMs under
multi-adapter loads. 0.90 wasted ~480 MB of usable KV cache.
- max_num_seqs=128: Ceiling based on VRAM budget math. 14 GB KV cache at
block_size=16 with avg context ~1,100 tokens supports ~128 concurrent sequences.
Higher values cause preemptions; lower values waste throughput.
- max_num_batched_tokens=8192: Caps total tokens in a single scheduler step.
Prevents 1-2 large prompts (3,500+ tokens) from crowding out 50+ short requests.
- block_size=16: Tested 8, 16, and 32. Block size 8 had too much table overhead
for our avg response length (~140 tokens). Block size 32 wasted space on short
factual answers (~30 tokens). 16 was the sweet spot.
- enable_prefix_caching=True: Critical for chatbot workloads. The deterministic
system prompt + policy block prefix is shared across ~72% of requests.
- enforce_eager=False: CUDA graphs reduce kernel launch overhead. First request
after startup is slower (graph capture) but steady-state is 15-20% faster.
Environment variable overrides are supported for tuning without rebuilding:
see 04-vllm-deployment-and-infrastructure.md Section 10 for the full reference.
"""
args = AsyncEngineArgs(
model=model_path,
tensor_parallel_size=int(os.environ.get("TENSOR_PARALLEL_SIZE", "1")),
gpu_memory_utilization=float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.92")),
max_model_len=int(os.environ.get("MAX_MODEL_LEN", "4096")),
max_num_seqs=int(os.environ.get("MAX_NUM_SEQS", "128")),
max_num_batched_tokens=int(os.environ.get("MAX_NUM_BATCHED_TOKENS", "8192")),
block_size=int(os.environ.get("BLOCK_SIZE", "16")),
enable_prefix_caching=os.environ.get("ENABLE_PREFIX_CACHING", "true").lower() == "true",
quantization=os.environ.get("QUANTIZATION", "awq"),
enforce_eager=os.environ.get("ENFORCE_EAGER", "false").lower() == "true",
)
logger.info(
"Building vLLM engine: model=%s, gpu_mem=%.2f, max_seqs=%d, "
"max_batched_tokens=%d, block_size=%d, prefix_caching=%s",
model_path,
args.gpu_memory_utilization,
args.max_num_seqs,
args.max_num_batched_tokens,
args.block_size,
args.enable_prefix_caching,
)
return AsyncLLMEngine.from_engine_args(args)
The values above are not magic defaults. They are operating decisions shaped by the chatbot workload. See the VRAM budget breakdown in 06-vllm-model-preparation-and-quantization.md Section 3 for the math behind each number.
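For orientation, the KV-cache arithmetic behind max_num_seqs=128 can be sketched directly. The attention shape below is the published Llama-3-8B configuration (32 layers, 8 KV heads under GQA, head dimension 128); the fp16 KV assumption, the 14 GB budget, and the ~1,100-token average context come from the docstring above. This is a back-of-envelope sketch; vLLM's internal accounting differs slightly.
# KV-cache budget sketch for Llama-3-8B with fp16 KV entries.
LAYERS, KV_HEADS, HEAD_DIM, KV_BYTES = 32, 8, 128, 2

# K and V across all layers: 2 * 32 * 8 * 128 * 2 = 131,072 bytes (128 KiB) per token.
kv_bytes_per_token = 2 * LAYERS * KV_HEADS * HEAD_DIM * KV_BYTES

kv_budget_bytes = 14 * 1024**3           # 14 GB KV budget from the docstring
block_bytes = 16 * kv_bytes_per_token    # block_size=16 -> 2 MiB per block
total_blocks = kv_budget_bytes // block_bytes  # 7,168 blocks
token_capacity = total_blocks * 16             # ~114,700 tokens of KV

print(token_capacity // 1100)  # ~104 fully distinct ~1,100-token contexts
That is roughly 104 fully distinct contexts; because ~72% of requests share the cached system-prompt prefix, the effective ceiling stretches toward the configured 128 sequences.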
7. Scenario 1 - PagedAttention And KV-Cache Efficiency
Implementation Flow
- Route only eligible self-hosted generation traffic to vLLM.
- Enforce a clear max_model_len and message-budget policy before the request reaches the engine.
- Run one vLLM engine per GPU with explicit gpu_memory_utilization.
- Emit queue depth, active sequence count, and VRAM metrics continuously.
- Scale the endpoint on queueing and active sequence pressure, not only CPU or request count.
Key Components
- request_budgeter.py trims or summarizes low-value history before inference (a minimal sketch follows this list).
- vllm_engine.py owns engine lifecycle and engine args.
- inference_metrics.py emits active_sequences, queue_wait_ms, and gpu_memory_used_mb.
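A minimal sketch of that trim step, assuming a count_tokens callable supplied by the tokenizer layer; the production request_budgeter.py also summarizes dropped turns rather than discarding them outright:
from typing import Callable

def trim_history(
    messages: list[dict],
    count_tokens: Callable[[str], int],
    max_input_tokens: int = 3600,
) -> list[dict]:
    """Drop the oldest non-system turns until the prompt fits the token budget."""
    trimmed = list(messages)
    # Keep the system prompt (index 0) and the latest user turn (last element);
    # pop the oldest turn in between until the budget is satisfied.
    while (
        sum(count_tokens(m["content"]) for m in trimmed) > max_input_tokens
        and len(trimmed) > 2
    ):
        trimmed.pop(1)
    return trimmed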
Critical Decisions
- gpu_memory_utilization=0.92: Left headroom for CUDA workspace and non-model allocations instead of chasing dangerous full occupancy.
- block_size=16: Chosen as a good tradeoff between fragmentation reduction and block-table overhead for the observed response lengths.
- One process per GPU: Avoided allocator contention and made memory behavior more predictable than a multi-worker process model.
- Explicit token budgeting before inference: Prevented the engine from becoming the first place where prompt excess was discovered.
Why These Decisions Were Correct
If we had only dropped vLLM into the old serving path without upstream token budgeting and explicit memory headroom, we would have improved throughput but still created intermittent stability failures. The low-level decision quality mattered as much as the library choice.
8. Scenario 2 - Continuous Batching And Admission Control
Implementation Flow
- Normalize every request into a token estimate before enqueue.
- Reject or degrade requests when the queue exceeds the latency budget.
- Let vLLM refill sequence slots continuously.
- Stream tokens immediately after decode starts.
- Emit both engine latency and queue wait latency so scheduling gains are measurable.
Admission-Control Pattern
The admission controller sits between the API gateway and the vLLM engine. It prevents GPU saturation from becoming the only backpressure mechanism.
"""
Admission control for the vLLM gateway.
Prevents queue saturation and bounds user-visible wait time.
"""
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
logger = logging.getLogger(__name__)
class RejectionReason(Enum):
QUEUE_FULL = "queue_full"
WAIT_EXCEEDED = "wait_exceeded"
TOKEN_BUDGET_EXCEEDED = "token_budget_exceeded"
MEMORY_PRESSURE = "memory_pressure"
@dataclass
class AdmissionDecision:
allowed: bool
reason: RejectionReason | None = None
estimated_wait_ms: float = 0.0
priority: int = 0 # 0 = normal, 1 = high (repeat customer, etc.)
@dataclass
class AdmissionController:
"""
Gate requests before they reach the vLLM engine queue.
Design rationale:
- max_queue=100: At 128 max_num_seqs, a queue of 100 means worst-case
a new request waits for ~1 full batch cycle. Beyond 100, queue wait
exceeds our 200ms P95 SLO target.
- max_wait_ms=2000: If estimated wait > 2s, the user experience is already
degraded. Better to return a degraded response from cache or fallback.
- max_input_tokens=3600: Leaves 496 tokens for generation within
max_model_len=4096. Larger requests would crowd out others in the batch.
- gpu_cache_pressure_threshold=0.93: When KV cache is > 93% full, new
sequences will likely trigger preemptions. Reject early.
"""
max_queue: int = 100
max_wait_ms: int = 2000
max_input_tokens: int = 3600
gpu_cache_pressure_threshold: float = 0.93
_rejection_count: dict = field(default_factory=dict)
def evaluate(
self,
queue_depth: int,
estimated_wait_ms: float,
input_token_count: int,
gpu_cache_usage: float,
request_priority: int = 0,
) -> AdmissionDecision:
"""
Evaluate whether to admit a request.
Priority requests (repeat customers, paid tier) get a 20% higher
queue threshold and wait budget. This prevents VIP requests from
being rejected during moderate load while still protecting against
full saturation.
"""
effective_max_queue = int(self.max_queue * (1.2 if request_priority > 0 else 1.0))
effective_max_wait = int(self.max_wait_ms * (1.2 if request_priority > 0 else 1.0))
if gpu_cache_usage >= self.gpu_cache_pressure_threshold:
self._track_rejection(RejectionReason.MEMORY_PRESSURE)
return AdmissionDecision(
allowed=False,
reason=RejectionReason.MEMORY_PRESSURE,
estimated_wait_ms=estimated_wait_ms,
)
if input_token_count > self.max_input_tokens:
self._track_rejection(RejectionReason.TOKEN_BUDGET_EXCEEDED)
return AdmissionDecision(
allowed=False,
reason=RejectionReason.TOKEN_BUDGET_EXCEEDED,
estimated_wait_ms=estimated_wait_ms,
)
if queue_depth >= effective_max_queue:
self._track_rejection(RejectionReason.QUEUE_FULL)
return AdmissionDecision(
allowed=False,
reason=RejectionReason.QUEUE_FULL,
estimated_wait_ms=estimated_wait_ms,
)
if estimated_wait_ms >= effective_max_wait:
self._track_rejection(RejectionReason.WAIT_EXCEEDED)
return AdmissionDecision(
allowed=False,
reason=RejectionReason.WAIT_EXCEEDED,
estimated_wait_ms=estimated_wait_ms,
)
return AdmissionDecision(
allowed=True,
estimated_wait_ms=estimated_wait_ms,
priority=request_priority,
)
def _track_rejection(self, reason: RejectionReason) -> None:
self._rejection_count[reason] = self._rejection_count.get(reason, 0) + 1
logger.warning("Admission rejected: reason=%s, total=%d", reason.value, self._rejection_count[reason])
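A sketch of the call site in the gateway. The engine_stats accessors are illustrative stand-ins for whatever metrics source the gateway polls; the wait estimate is deliberately naive:
def admit_or_reject(controller: AdmissionController, request: dict, engine_stats) -> AdmissionDecision:
    # Naive wait estimate: requests ahead of this one, multiplied by the
    # observed per-batch cycle time. Good enough for admission decisions.
    estimated_wait_ms = engine_stats.queue_depth * engine_stats.avg_batch_cycle_ms
    return controller.evaluate(
        queue_depth=engine_stats.queue_depth,
        estimated_wait_ms=estimated_wait_ms,
        input_token_count=request["input_token_count"],
        gpu_cache_usage=engine_stats.gpu_cache_usage,
        request_priority=request.get("priority", 0),
    )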
What happens when admission rejects a request: The gateway returns a structured fallback response. It does NOT return an HTTP error. The orchestrator layer handles the fallback decision:
async def handle_rejected_request(
decision: AdmissionDecision,
request: dict,
) -> dict:
"""
Graceful degradation for rejected requests.
Three fallback strategies, tried in order:
1. Cached response: If a semantically similar question was answered recently
2. Bedrock fallback: Route to the managed API (higher cost, but available)
3. Degraded response: Short, honest "I'm busy" message
"""
if decision.reason == RejectionReason.TOKEN_BUDGET_EXCEEDED:
# Trim and retry — this is recoverable
trimmed = trim_to_budget(request, max_tokens=3600)
return await retry_with_trimmed(trimmed)
if decision.reason in (RejectionReason.QUEUE_FULL, RejectionReason.WAIT_EXCEEDED):
# Try managed fallback
return await route_to_bedrock_fallback(request)
# Memory pressure — do not send more work to this GPU
return {
"fallback": "degraded_response",
"message": "I'm handling a lot of questions right now. Let me give you a shorter answer.",
"reason": decision.reason.value,
}
Critical Decisions
- scheduler_delay_factor=0.0: We did not want artificial wait windows. The chatbot needed immediate admission when decode slots freed up.
- max_num_batched_tokens cap: Protected the system from a few very large prompts crowding out many medium ones.
- Queue timeout with graceful fallback: Better to downgrade or reroute than to let a user wait behind an overloaded GPU silently.
- Separate metrics for queue_wait_ms and generation_ms: Prevented us from blaming the model for scheduler problems.
Why These Decisions Were Correct
The mistake in many inference migrations is to celebrate higher throughput while hiding queue pain inside the tail. We optimized for user-visible latency first, then used continuous batching to recover throughput without fixed-window penalties.
9. Scenario 3 - Prefix Caching And Streaming
Implementation Flow
- Canonicalize the stable prefix of the prompt.
- Keep volatile fields after the shared prefix boundary.
- Enable prefix caching in the engine.
- Stream tokens through SSE or WebSocket to the chat UI.
- Log prefix_cache_hit, TTFT, and full completion latency on every request.
Canonicalization Rule
Only deterministic content belongs in the cacheable prefix:
def build_prompt_prefix(system_prompt: str, policy_block: str) -> str:
return "\n".join([
system_prompt.strip(),
policy_block.strip(),
"Follow catalog grounding rules.",
])
Do not put timestamps, random request IDs, or per-user personalization in the cacheable prefix if you want reuse to be real.
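Concretely, volatile content is appended strictly after that boundary. A sketch (the argument names are ours, not a fixed contract):
def build_full_prompt(prefix: str, retrieval_block: str, history: str, user_turn: str) -> str:
    # The deterministic prefix comes first so its leading tokens are
    # byte-identical across requests; everything volatile lives after
    # the cache boundary.
    return "\n".join([prefix, retrieval_block, history, user_turn])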
Streaming Adapter
"""
SSE streaming adapter for vLLM token generation.
Converts vLLM's async generator to Server-Sent Events for the chat UI.
"""
import json
import logging
import time
logger = logging.getLogger(__name__)
async def stream_tokens(
request_id: str,
generator,
send_event,
metrics_emitter,
) -> dict:
"""
Stream tokens from the vLLM engine to the client.
Design decisions:
- First token timing: We record TTFT from the moment generation begins,
not from request receipt. Queue wait is tracked separately to isolate
scheduler delay from model compute.
- Metadata at tail: Usage stats and trace metadata are sent in the final
event so they don't delay the first visible token.
- Error containment: If the generator raises mid-stream, we send a clean
termination event rather than dropping the connection. The client can
distinguish between complete and truncated responses.
"""
token_count = 0
first_token_time = None
start_time = time.monotonic()
try:
async for output in generator:
if first_token_time is None:
first_token_time = time.monotonic()
token_count += 1
await send_event({
"type": "token",
"delta": output.text,
})
total_ms = (time.monotonic() - start_time) * 1000
ttft_ms = (first_token_time - start_time) * 1000 if first_token_time else total_ms
# Send completion metadata as the final event
await send_event({
"type": "done",
"usage": {
"output_tokens": token_count,
"ttft_ms": round(ttft_ms, 1),
"total_ms": round(total_ms, 1),
},
})
return {"ttft_ms": ttft_ms, "total_ms": total_ms, "output_tokens": token_count}
except Exception as exc:
logger.error("Stream error for request %s: %s", request_id, exc)
await send_event({
"type": "error",
"message": "Generation was interrupted. Please try again.",
})
raise
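One detail the adapter leaves implicit: in vLLM 0.4.x, AsyncLLMEngine.generate() yields RequestOutput objects whose outputs[0].text is cumulative, not a per-step delta. A minimal shim that produces the delta-shaped chunks stream_tokens expects (TokenChunk and to_deltas are our names, not vLLM API):
from dataclasses import dataclass

@dataclass
class TokenChunk:
    text: str

async def to_deltas(results):
    """Convert vLLM's cumulative RequestOutput stream into text deltas."""
    emitted = 0
    async for request_output in results:
        full_text = request_output.outputs[0].text  # cumulative text so far
        if len(full_text) > emitted:
            yield TokenChunk(text=full_text[emitted:])
            emitted = len(full_text)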
Critical Decisions
- Deterministic prefix boundary: Prefix caching only works if the leading tokens are actually shared.
- Stream-first architecture: We optimized for TTFT, not just final completion time.
- Metadata at tail, not head: Final usage numbers and trace metadata were sent after generation so they did not block initial token delivery.
- Prefix cache hit tracking: Without direct measurement, it is easy to assume caching is helping when prompt drift has made it ineffective.
Why These Decisions Were Correct
The chatbot experience is judged early. Users notice whether the assistant starts responding. They do not care whether the total completion was 1.1 seconds or 1.3 seconds if the first token arrived quickly and steadily.
10. Scenario 4 - Multi-LoRA Adapter Routing
Implementation Flow
- The orchestrator selects an adapter based on route, locale, or specialization.
- The vLLM gateway resolves the adapter from a registry.
- The request is sent with one explicit adapter ID.
- Observability tags the response with base model version and adapter version.
- Offline evaluation runs adapter regressions before promotion.
Adapter Registry Pattern
from vllm.lora.request import LoRARequest
ADAPTERS = {
"manga_domain_v3": LoRARequest("manga", 1, "/models/lora/manga_domain_v3"),
"general_support_v2": LoRARequest("support", 2, "/models/lora/general_support_v2"),
"jp_style_v1": LoRARequest("jp_style", 3, "/models/lora/jp_style_v1"),
}
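A sketch of passing the resolved adapter to the engine. This assumes the vLLM 0.4.x AsyncLLMEngine.generate(prompt, sampling_params, request_id, lora_request=...) signature, and an engine built with enable_lora=True plus appropriate max_loras settings, which the Section 6 builder would need as an extension:
from vllm import SamplingParams

async def generate_with_adapter(engine, request_id: str, prompt: str, adapter_id: str):
    lora_request = ADAPTERS.get(adapter_id)
    if lora_request is None:
        raise KeyError(f"Unknown adapter_id: {adapter_id}")
    sampling = SamplingParams(temperature=0.7, max_tokens=512)
    final_output = None
    # engine.generate is an async generator; the last item carries the full result.
    async for output in engine.generate(prompt, sampling, request_id, lora_request=lora_request):
        final_output = output
    return final_output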
Critical Decisions
- One compatible base-model family: Avoided operational complexity from trying to mix incompatible adapter ecosystems.
- Single adapter per request: Kept routing explicit and evaluation simpler.
- Preload hottest adapters: Reduced adapter-switch overhead for the common paths.
- Version adapters independently: Allowed domain behavior updates without pretending the base model changed.
Why These Decisions Were Correct
Multi-LoRA only helps if adapter routing stays explicit and observable. Otherwise you reduce infrastructure cost but increase debugging ambiguity, which is a bad trade in a production chatbot.
11. Scenario 5 - AWQ, Context Budgeting, And OOM Containment
Implementation Flow
- Quantize the base model offline using a manga-domain calibration set.
- Evaluate the quantized artifact on language quality, grounding, and safety-sensitive prompts.
- At request time, reserve token budget for retrieved evidence and the latest turns.
- Summarize or trim older turns before crossing the model limit.
- Catch OOM failures and return a controlled degradation path instead of crashing the worker.
Context Budgeting Pattern
def allocate_token_budget(max_tokens: int) -> dict[str, int]:
return {
"system": 400,
"retrieval": 1400,
"recent_turns": 1200,
"older_summary": 400,
"generation_reserve": max_tokens - 3400,
}
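With max_model_len=4096 from Section 6, the fixed allocations total 3,400 tokens, leaving a generation_reserve of 696. That comfortably covers the ~140-token average response while still allowing longer multi-title recommendations.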
OOM Guard Pattern
The OOM guard is a three-layer defense: prevention (token budgeting), containment (runtime catch), and recovery (graceful degradation). This is the containment layer.
"""
OOM containment for the vLLM gateway.
Catches CUDA OOM at the gateway boundary to prevent worker crash loops.
"""
import functools
import logging
import traceback
logger = logging.getLogger(__name__)
def with_oom_guard(metrics_emitter=None):
"""
Decorator that catches CUDA OOM errors and returns a degraded response.
Why this exists:
- vLLM's internal memory management prevents most OOMs, but edge cases
remain: adapter loading during high concurrency, CUDA graph capture
on unusual sequence lengths, and driver-level allocation failures.
- Without this guard, a single OOM kills the worker process. SageMaker
restarts it, but that takes 60-90 seconds — an eternity for all
concurrent users on that GPU.
- With this guard, the single request that triggered OOM gets a degraded
response, and all other in-flight requests complete normally.
What this does NOT do:
- It does not catch OOMs inside the vLLM C++/CUDA kernels. Those crash
the process regardless. This catches Python-level RuntimeErrors from
PyTorch tensor allocations.
"""
def decorator(fn):
@functools.wraps(fn)
async def wrapper(*args, **kwargs):
try:
return await fn(*args, **kwargs)
except RuntimeError as exc:
error_msg = str(exc).lower()
if "out of memory" not in error_msg and "cuda" not in error_msg:
raise
logger.error(
"OOM caught at gateway boundary: %s\n%s",
exc,
traceback.format_exc(),
)
if metrics_emitter:
adapter_id = kwargs.get("adapter_id", "unknown")
metrics_emitter.emit_oom_caught(adapter_id)
return {
"fallback": "degraded_response",
"message": "I can answer this, but I need to shorten the context first.",
"error_type": "oom_contained",
"retry_hint": "Retry with fewer conversation turns.",
}
return wrapper
return decorator
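Applying the guard is a plain decorator at the gateway entry point. A sketch (call_engine is an illustrative stand-in for the real gateway call):
@with_oom_guard()  # metrics_emitter omitted in this sketch
async def generate_response(request: dict, adapter_id: str = "manga_domain_v3") -> dict:
    # Any torch-level OOM surfacing here becomes a degraded-response dict
    # instead of a crashed worker process.
    return await call_engine(request, adapter_id)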
Critical Decisions
- In-domain AWQ calibration: Generic calibration text would have preserved the wrong activations for Japanese manga vocabulary.
- Preserve evidence before chit-chat: Retrieval chunks and recent user intent mattered more than old conversational filler.
- OOM containment at the gateway boundary: Protected the worker process from a full crash loop.
- Quality gate before rollout: Quantization was not accepted on throughput alone.
Why These Decisions Were Correct
This scenario was about reliability, not only efficiency. Long sessions are common in a recommendation assistant, so the system needed to survive success cases rather than only benchmark cases.
12. Scenario 6 - OpenAI-Compatible Contract And Safe Migration
Implementation Flow
- The orchestrator calls a backend-neutral generation client.
- The generation client maps the request to either Bedrock or the vLLM gateway.
- Shadow mode can send the same normalized request to both backends.
- Comparison jobs score output shape, latency, and guardrail pass rates.
- Promotion changes routing config, not application business logic.
Backend-Neutral Client
"""
Backend-neutral generation client.
Abstracts the difference between vLLM (self-hosted) and Bedrock (managed).
The application layer never calls an inference backend directly.
"""
import logging
import time
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class GenerationResult:
request_id: str
backend: str
backend_version: str
trace_id: str
output_text: str
finish_reason: str
input_tokens: int
output_tokens: int
ttft_ms: float
total_ms: float
class GenerationClient:
"""
Unified generation interface.
Design decisions:
- Backend selection is driven by the routing config, not by the caller.
The caller sends a normalized request; the client resolves the backend.
- The response contract is identical regardless of backend. The application
never inspects backend-specific fields.
- Shadow mode sends the same request to both backends for comparison.
Results are logged to MLflow for offline analysis.
- Fallback: If vLLM is unavailable (OOM, scaling, health check failure),
the client automatically routes to Bedrock. The caller does not need
to handle this.
"""
def __init__(self, vllm_gateway, bedrock_adapter, shadow_mode: bool = False):
self._vllm = vllm_gateway
self._bedrock = bedrock_adapter
self._shadow_mode = shadow_mode
async def generate(self, request: dict) -> GenerationResult:
backend = request.get("backend", "vllm")
start = time.monotonic()
if backend == "vllm":
result = await self._call_vllm(request)
elif backend == "bedrock":
result = await self._call_bedrock(request)
else:
raise ValueError(f"Unsupported backend: {backend}")
if self._shadow_mode and backend == "vllm":
# Fire-and-forget shadow request to Bedrock for comparison
# Results are logged asynchronously; errors are swallowed
self._schedule_shadow(request)
return result
async def _call_vllm(self, request: dict) -> GenerationResult:
"""Route to self-hosted vLLM endpoint."""
raise NotImplementedError("See vllm_gateway.py for full implementation")
async def _call_bedrock(self, request: dict) -> GenerationResult:
"""Route to managed Bedrock endpoint."""
raise NotImplementedError("See bedrock_adapter.py for full implementation")
def _schedule_shadow(self, request: dict) -> None:
"""Schedule async shadow comparison. See eval/shadow_compare.py."""
pass
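For completeness, a minimal fire-and-forget sketch of what _schedule_shadow might do. The real comparison logic lives in eval/shadow_compare.py, and the _bedrock.generate call shape here is an assumption about the adapter interface:
import asyncio

def schedule_shadow_sketch(client: GenerationClient, request: dict) -> None:
    async def _run() -> None:
        try:
            shadow = await client._bedrock.generate(request)  # assumed adapter interface
            logger.info(
                "shadow_result request_id=%s backend=%s",
                request.get("request_id"), shadow.backend,
            )
        except Exception:
            # Swallowed by design: shadow traffic must never fail the
            # customer-visible request.
            logger.debug("Shadow comparison failed; ignored", exc_info=True)

    # Called from within the async generate() path, so a loop is running.
    asyncio.get_running_loop().create_task(_run())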
Critical Decisions
- No direct business-logic calls to the engine: All traffic flowed through one gateway so metadata, retries, and tracing stayed consistent.
- Stable response contract: Kept the application insulated from backend-specific fields.
- Shadow first, cut over later: Reduced migration risk and made rollback a routing decision.
- Always tag backend version: Allowed us to answer "which engine produced this response?" without guesswork.
Why These Decisions Were Correct
The best migration is the one that can be reversed quickly. The OpenAI-style contract and centralized gateway made vLLM an implementation choice instead of a deep application dependency.
13. Shared Rollout Checklist
- Benchmark on real chat workloads, not synthetic single-prompt demos.
- Validate TTFT, completion latency, queue wait, and GPU utilization together.
- Run soak tests with long conversations and adapter-switch traffic.
- Confirm prefix-cache hit rate after prompt changes.
- Gate AWQ or adapter changes on offline evaluation plus low-percent canary.
- Keep rollback at the routing layer wherever possible.
- Verify health checks pass within 90 seconds of container start (see 04-deployment Section 6).
- Validate alerting rules fire correctly in staging before production (see 05-monitoring Section 5).
14. Health And Readiness Separation
Liveness and readiness are separate concerns. Getting this wrong causes cascading failures.
"""
Health check implementation for the vLLM serving container.
Separates liveness (is the process alive?) from readiness (can it serve traffic?).
"""
import logging
import time
import torch
logger = logging.getLogger(__name__)
class HealthProbe:
"""
Three-tier health check for SageMaker and the load balancer.
Liveness (/ping):
- "Is this process alive and the GPU accessible?"
- If this fails, SageMaker should kill and replace the container.
- Must be fast (< 100ms). Do NOT load models or run inference here.
Readiness (/ready):
- "Can this instance serve inference requests?"
- Fails during: model loading, warmup, adapter loading, CUDA graph capture.
- SageMaker uses this to decide whether to route traffic here.
Deep health (/health/deep):
- "Is the GPU healthy and the model producing sensible output?"
- Runs a tiny inference request. Used for periodic validation, not per-request.
- Takes ~500ms. Only called by monitoring, not by the load balancer.
"""
def __init__(self, engine=None) -> None:
self._engine = engine
self._ready = False
self._startup_time = time.monotonic()
def mark_ready(self) -> None:
elapsed = time.monotonic() - self._startup_time
logger.info("Instance marked ready after %.1f seconds", elapsed)
self._ready = True
def liveness(self) -> dict:
"""Fast check: process alive, GPU accessible."""
gpu_ok = torch.cuda.is_available() and torch.cuda.device_count() > 0
return {
"status": "healthy" if gpu_ok else "unhealthy",
"gpu_available": gpu_ok,
"uptime_seconds": round(time.monotonic() - self._startup_time, 1),
}
def readiness(self) -> dict:
"""Full check: model loaded, warmup complete, adapters registered."""
if not self._ready:
return {
"status": "not_ready",
"reason": "Model loading or warmup in progress",
}
return {
"status": "ready",
"uptime_seconds": round(time.monotonic() - self._startup_time, 1),
}
async def deep_health(self) -> dict:
"""
Run a minimal inference to validate the model is producing output.
Only used by monitoring probes, not by the load balancer.
"""
if not self._ready or self._engine is None:
return {"status": "not_ready"}
try:
start = time.monotonic()
# Minimal inference: 5 tokens with deterministic sampling
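            # NOTE: this assumes self._engine is the gateway wrapper, which
            # awaits to completed text. The raw AsyncLLMEngine.generate() is an
            # async generator keyed by request_id, not a directly awaitable call.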
result = await self._engine.generate(
prompt="test",
sampling_params={"max_tokens": 5, "temperature": 0.0},
)
latency_ms = (time.monotonic() - start) * 1000
return {
"status": "healthy",
"inference_latency_ms": round(latency_ms, 1),
"output_length": len(result),
}
except Exception as exc:
logger.error("Deep health check failed: %s", exc)
return {
"status": "unhealthy",
"error": str(exc),
}
Why separation matters: In a previous incident, we combined liveness and readiness into one /ping endpoint. During model loading (~45s), SageMaker saw /ping failures and killed the container before it could finish loading. The container restarted, started loading again, got killed again — an infinite restart loop. Separating the checks fixed this immediately.
15. The Most Important Low-Level Decisions
If I had to compress the whole implementation into eight decisions, they would be:
- Budget tokens before the engine sees the request. The engine should never be the first place where prompt excess is discovered.
- Keep one engine process per GPU and leave explicit VRAM headroom (8%). Predictable memory behavior beats maximum theoretical utilization.
- Measure queue wait separately from generation latency. Without this, you blame the model for scheduler problems.
- Make prompt prefixes deterministic if you want caching to matter. A single timestamp in the cacheable prefix destroys hit rate.
- Treat adapter IDs and backend versions as observability primitives. You must be able to answer "which model produced this response?" for any request.
- Design the backend contract so migration and rollback are routing changes, not app rewrites. The best migration is the one that can be reversed in under a minute.
- Separate liveness from readiness. Combining them causes container restart loops during model loading.
- Contain OOMs at the gateway boundary. One user's oversized request should not crash the GPU for 127 other concurrent users.
16. Cross-References
| Document | What it covers | When to read it |
|---|---|---|
| 01-vllm-game-changer-scenarios.md | Business context and architecture story | Understanding why each decision was made |
| 03-vllm-interview-prep-deep-dive.md | Interview questions with deep-dive answers | Preparing to explain these decisions |
| 04-vllm-deployment-and-infrastructure.md | Docker, SageMaker, scaling, startup | Deploying and operating the stack |
| 05-vllm-monitoring-and-troubleshooting.md | Metrics, alerts, SLOs, diagnostics | Monitoring and debugging production issues |
| 06-vllm-model-preparation-and-quantization.md | AWQ calibration, LoRA training, CI/CD | Preparing and promoting model artifacts |