
US-MLE-01: Intent Classifier Retraining Pipeline

User Story

As an ML Engineer at Amazon-scale on the MangaAssist team, I want to own a weekly retraining pipeline for the multilingual intent classifier (DistilBERT) that ingests new labels, validates them, retrains, evaluates across slices, promotes through shadow/canary, and detects drift, So that the chatbot's downstream routing (which MCP to call, which model tier to use) keeps pace with evolving query patterns, new manga categories (manhwa, manhua, light-novel adjuncts), and bilingual JP/EN traffic without quality regression.

Acceptance Criteria

  • Weekly scheduled retrain runs on Friday 03:00 JST off-peak window; finishes by 11:00 JST.
  • Retrain ingests labels at label_version > previous_run.max_label_version from the label platform with EN κ ≥ 0.75 and JP κ ≥ 0.75 verified per IAA.
  • Validation accuracy ≥ 0.90 (macro across 14 intent classes) on the frozen golden set.
  • Per-slice macro-F1 ≥ 0.85 on each of: EN/JP/mixed × {new-user, returning, power-user} × {peak, off-peak}.
  • No regression > 2σ on safety-critical intents (return_request, escalation, complaint) vs the production model.
  • Adversarial test suite (typos, JP/EN code-switching, Unicode tricks) regression ≤ 1% absolute vs baseline.
  • Online p95 latency on ml.g5.xlarge real-time endpoint ≤ 15 ms.
  • Online p99 latency ≤ 35 ms.
  • Promoted model is rollback-able to previous version via SageMaker traffic shift in ≤ 60 seconds.
  • Drift detection emits CloudWatch alarms within 24 hours of input drift (PSI > 0.2) and within 7 days of concept drift (Δ-F1 > 0.03).
  • Per-request cost ≤ $0.000018 ($1.80 per 100K requests) at the contracted traffic mix.
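The per-request cost budget in the last criterion can be sanity-checked with back-of-envelope arithmetic. A minimal sketch — the hourly rate is an input you look up from the AWS price list, not a quoted price:

```python
def cost_per_request(instance_count: int, hourly_rate_usd: float, avg_rps: float) -> float:
    """Steady-state serving cost per request for an always-on endpoint.
    hourly_rate_usd and avg_rps are caller-supplied assumptions."""
    return (instance_count * hourly_rate_usd) / (avg_rps * 3600.0)


# ~3.5M inferences/day works out to ~40.5 average RPS
avg_rps = 3_500_000 / 86_400
```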

Architecture (HLD)

The Production Surface

The intent classifier sits on the chatbot's hot path immediately after the rule-based pre-filter (per the cost-optimization story US-02). Every message the rule-based pre-filter cannot handle with confidence ≥ 0.85 falls through to this model. At the contracted traffic of ~10M chat messages/day, with the pre-filter handling ~65%, this model serves ~3.5M inferences/day, with diurnal peaks of ~80 RPS and JP-region peaks of ~140 RPS during 21:00–24:00 JST (the manga-reading peak).

The model is DistilBERT-multilingual (distilbert-base-multilingual-cased, 134M parameters) fine-tuned on a 14-class intent taxonomy:

catalog_search           recommendation_request   order_lookup
review_query             support_policy           pricing_inquiry
trending_request         personalization_setup    chitchat
return_request           escalation               complaint
account_management       multilingual_unknown

The output is a softmax distribution over the 14 classes; the consumer (template router, MCP dispatcher) reads the top-1 and the entropy-derived confidence. Three intents (return_request, escalation, complaint) are safety-critical — misrouting them is a CX incident, so they have a higher quality contract than the rest.
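The entropy-derived confidence can be sketched as normalized softmax entropy — a minimal illustration, assuming confidence = 1 − H(p)/log(K); the production handler may compute it differently:

```python
import numpy as np


def entropy_confidence(logits, n_classes: int = 14) -> tuple[int, float]:
    """Top-1 class index plus confidence = 1 - H(p) / log(K):
    a uniform softmax scores ~0, a one-hot softmax scores ~1."""
    z = np.asarray(logits, dtype=float)
    z -= z.max()                           # numerical stability
    p = np.exp(z)
    p /= p.sum()
    h = -(p * np.log(p + 1e-12)).sum()     # Shannon entropy of the softmax
    return int(p.argmax()), float(1.0 - h / np.log(n_classes))
```

A flat distribution over the 14 classes yields confidence near 0; a sharply peaked one yields confidence near 1, which is what the template router thresholds on.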

End-to-End ML Lifecycle Diagram

flowchart TB
    subgraph DATA[Data Layer]
        L1[Vendor Labels<br/>Appen + Sama<br/>~5K/week]
        L2[Programmatic Labels<br/>Rule-based pre-filter<br/>~50K/week]
        L3[LLM-Distilled Labels<br/>Sonnet weak labels<br/>~10K/week]
        L4[Label Platform<br/>Iceberg on S3<br/>PIT-correct]
        L1 --> L4
        L2 --> L4
        L3 --> L4
    end

    subgraph FEAT[Feature Layer]
        F1[Feature Store<br/>SageMaker Feature Store<br/>+ Iceberg offline]
        F2[Feature Catalog<br/>schema_v3.4]
        F1 -.pinned to.-> F2
    end

    subgraph TRAIN[Training Pipeline - SageMaker Pipelines]
        T1[Step 1<br/>Data Validation]
        T2[Step 2<br/>Feature Materialization<br/>PIT join]
        T3[Step 3<br/>Train/Val/Test Split<br/>stratified by language x intent]
        T4[Step 4<br/>Training<br/>ml.g5.xlarge spot]
        T5[Step 5<br/>Offline Eval<br/>5 modes]
        T6[Step 6<br/>Slice Analysis]
        T7[Step 7<br/>Bias Audit]
        T8[Step 8<br/>Model Registration]
        T1 --> T2 --> T3 --> T4 --> T5 --> T6 --> T7 --> T8
    end

    subgraph SERVE[Serving Layer]
        S1[Model Registry<br/>v47 prod, v48 candidate]
        S2[Shadow Endpoint<br/>v48 parallel]
        S3[Canary 1% -> 5% -> 25%]
        S4[Production Endpoint<br/>ml.g5.xlarge AS]
        S5[Auto-Abort<br/>Daemon]
        S1 --> S2 --> S3 --> S4
        S5 -.monitors.-> S3
    end

    subgraph DRIFT[Drift Detection]
        D1[Drift Hub<br/>PSI/KS/Chi-sq]
        D2[CloudWatch Alarms]
        D1 --> D2
        D2 -.triggers.-> T1
    end

    L4 --> T1
    F1 --> T2
    T8 --> S1
    S4 -.predictions.-> D1

    style L4 fill:#9cf,stroke:#333
    style F1 fill:#9cf,stroke:#333
    style S1 fill:#fd2,stroke:#333
    style S5 fill:#f66,stroke:#333
    style D1 fill:#fd2,stroke:#333

Data Contracts and Volume

| Asset | Schema | Version | Snapshot Cadence | Volume | Owner |
|---|---|---|---|---|---|
| intent_labels | Iceberg table | label_v3 | Continuous (label platform) | ~3.5M cumulative; ~65K added/week | Label Platform PM |
| intent_features | feature group | schema_v3.4 | 1 h batch + online realtime | ~3.5M serving rows/day | ML Eng (this story) |
| intent_classifier | model package | model_v47 in prod | Weekly | 134M params, ~530 MB artifact | ML Eng (this story) |
| intent_predictions_log | Iceberg | log_v2 | Continuous | ~3.5M rows/day; 90 d hot, 12 mo Glacier | Data Platform |
| intent_drift_metrics | CloudWatch | n/a | 5 min (input/pred); daily (label/concept) | ~70K data points/day | Drift Hub |

Model Registry + Promotion Gates

flowchart LR
    R47[Registry v47<br/>prod, label_v=2840]:::prod
    R48[Registry v48<br/>candidate, label_v=2905]:::cand

    R48 --> G1{Stage 1<br/>Offline Gate}
    G1 -->|pass| G2{Stage 2<br/>Shadow 24h}
    G1 -.fail.-> RB1[Rollback v47]
    G2 -->|pass| G3{Stage 3<br/>Canary 7d/3d/7d}
    G2 -.fail.-> RB2[Rollback v47]
    G3 -->|pass| G4[Stage 4<br/>Full Promote v48]
    G3 -.fail.-> RB3[Rollback v47]

    classDef prod fill:#2d8,stroke:#333
    classDef cand fill:#fd2,stroke:#333
    style RB1 fill:#f66,stroke:#333
    style RB2 fill:#f66,stroke:#333
    style RB3 fill:#f66,stroke:#333
    style G4 fill:#2d8,stroke:#333

The four-stage promotion gate is the standard from deep-dives/00-foundations-and-primitives-for-ml-engineering.md §5.1. Story-specific thresholds:

  • Stage 1 (offline): val accuracy ≥ 0.90, slice gate passes, adversarial regression ≤ 1%.
  • Stage 2 (shadow): per-request prediction agreement with v47 ≥ 75% (low because v48 is allowed to change predictions; we are checking that it does so reasonably). Shadow latency ≤ 1.3× v47.
  • Stage 3 (canary): online routing-correctness (measured by downstream MCP success rate) ≥ v47 baseline; safety-critical intent recall ≥ 0.92.
  • Stage 4 (full): traffic shift to 100% with old v47 retained for 14 days as rollback target.
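The Stage 1 thresholds above reduce to a small gate function. A sketch — the metric key names are hypothetical stand-ins for the offline-eval report's fields:

```python
from dataclasses import dataclass, field


@dataclass
class GateResult:
    passed: bool
    failures: list = field(default_factory=list)


def stage1_offline_gate(metrics: dict) -> GateResult:
    """Evaluate the Stage 1 (offline) thresholds; returns which checks failed."""
    checks = {
        "val_accuracy >= 0.90": metrics["val_accuracy"] >= 0.90,
        "slice_gate": bool(metrics["slice_gate_passed"]),
        "adversarial_regression <= 0.01": metrics["adversarial_regression"] <= 0.01,
    }
    failures = [name for name, ok in checks.items() if not ok]
    return GateResult(passed=not failures, failures=failures)
```

Recording the failing check names (rather than a bare boolean) is what makes the gate debuggable from the pipeline execution log.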

Low-Level Design

1. Feature / Data Pipeline

The intent classifier reads three feature groups from the SageMaker Feature Store (offline mirror in Iceberg on S3, online cache in DynamoDB-backed feature group):

  • Message-level features (computed online): tokenized message, language detector output, character set distribution, message length, has-emoji flag, presence-of-product-tokens flag.
  • Session-level features (computed online with 5-minute Redis aggregation): turn count, last 3 intents in session, session-start timestamp delta.
  • User-level features (computed via 1-hour batch): registration recency, lifetime intent histogram (privacy-redacted, only top-3 intent counts), preferred language flag, returning-vs-new bucket.

The feature catalog version schema_v3.4 is pinned; a feature schema upgrade is a coordinated change across this story, US-MLE-02, and US-MLE-06 (which all read overlapping features).

# feature_pipeline.py
import random
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import boto3
import pandas as pd
import sagemaker
from sagemaker.feature_store.feature_group import FeatureGroup


@dataclass
class IntentTrainingExample:
    example_id: str               # sha256(canonical_message + user_id + captured_at)
    user_id: str                  # hashed
    session_id: str
    captured_at: datetime
    label_value: str              # one of 14 intent classes
    label_version: int
    label_source: str
    message_text: str             # canonical, lowercased, NFKC-normalized
    language_detected: str        # "en" / "ja" / "mixed"


class IntentFeaturePipeline:
    """PIT-correct feature read for intent-classifier training.

    Implements primitive §2.1 from the foundations doc. The serving lag for each
    feature group is read from the feature catalog, never hardcoded.
    """

    def __init__(self, region: str = "ap-northeast-1"):
        self.session = sagemaker.Session(boto3.Session(region_name=region))
        # Platform wrapper over the offline store: supports as-of reads, which
        # the online feature-store runtime (latest-value only) does not.
        self.fs = FeatureStoreOfflineReader(self.session)
        self.catalog = FeatureCatalog.load("schema_v3.4")
        self.serving_lag = self.catalog.serving_lag_per_group()

    def materialize(
        self,
        examples: list[IntentTrainingExample],
        feature_groups: list[str],
    ) -> "pd.DataFrame":
        """Read PIT-correct features for every example."""
        rows = []
        for ex in examples:
            row = {"example_id": ex.example_id, "label": ex.label_value}
            for fg_name in feature_groups:
                lag = self.serving_lag[fg_name]
                as_of = ex.captured_at - lag
                features = self.fs.get_record(
                    feature_group_name=fg_name,
                    record_identifier_value_as_string=ex.user_id,
                    as_of_timestamp=as_of,
                )
                row.update(features)
            rows.append(row)
        return pd.DataFrame(rows)

    def leak_detector(
        self,
        examples: list[IntentTrainingExample],
        sample_pct: float = 0.01,
    ) -> "LeakReport":
        """Re-read features just before captured_at vs after the serving lag.

        A feature whose pre-capture values correlate suspiciously with the
        label is leaking. Returns a per-feature leak score; the pipeline
        aborts if any feature scores above 0.85.
        """
        sampled = random.sample(examples, int(len(examples) * sample_pct))
        # 1 µs before capture (timedelta has no finer resolution) vs the 1 h
        # batch serving lag after capture.
        before = self.materialize_at_time_offset(sampled, offset=timedelta(microseconds=-1))
        after = self.materialize_at_time_offset(sampled, offset=timedelta(hours=1))
        return LeakReport.compute(before, after, labels=[ex.label_value for ex in sampled])

The leak detector has caught two real production issues on the platform: a session-level last_intent_in_session feature that was being populated with the current turn's intent (a labeled row leaks its own label into its own features) and a user-level returning_vs_new feature that was being computed against the training-snapshot-time user table rather than the captured_at-time user table.
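For reference, the FeatureCatalog contract the pipeline leans on can be sketched as follows. The feature-group names here are hypothetical; the lags mirror the three feature-group cadences listed above (online, 5-minute Redis aggregation, 1-hour batch), and the real catalog is a registry-backed service rather than static config:

```python
from dataclasses import dataclass
from datetime import timedelta


@dataclass
class FeatureCatalog:
    """Sketch of the catalog contract used by IntentFeaturePipeline:
    serving lag per feature group is catalog data, never hardcoded."""
    version: str
    lags: dict

    @classmethod
    def load(cls, version: str) -> "FeatureCatalog":
        # Illustrative static config; group names are hypothetical.
        return cls(version=version, lags={
            "intent_message_features": timedelta(0),          # computed online
            "intent_session_features": timedelta(minutes=5),  # Redis aggregation
            "intent_user_features": timedelta(hours=1),       # hourly batch
        })

    def serving_lag_per_group(self) -> dict:
        return dict(self.lags)
```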

2. Training Pipeline (SageMaker Pipelines)

The training step DAG is implemented in training_pipeline.py:

# training_pipeline.py
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import ProcessingStep, TrainingStep
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.huggingface import HuggingFaceProcessor
from sagemaker.huggingface.estimator import HuggingFace


def build_pipeline(role: str, region: str = "ap-northeast-1") -> Pipeline:
    label_version = ParameterInteger(name="MaxLabelVersion", default_value=0)
    feature_schema = ParameterString(name="FeatureSchemaVersion", default_value="schema_v3.4")

    # Step 1: Data validation
    data_val = ProcessingStep(
        name="IntentDataValidation",
        processor=HuggingFaceProcessor(
            transformers_version="4.36",
            pytorch_version="2.1",
            role=role,
            instance_type="ml.m5.4xlarge",
            instance_count=1,
        ),
        code="src/data_validation.py",
        job_arguments=[
            "--label-version", label_version.to_string(),
            "--min-iaa-en", "0.75",
            "--min-iaa-jp", "0.75",
            "--required-class-coverage", "14",
            "--min-examples-per-class-jp", "200",
            "--min-examples-per-class-en", "200",
        ],
    )

    # Step 2: Feature materialization (PIT-correct)
    feature_mat = ProcessingStep(
        name="IntentFeatureMaterialization",
        processor=HuggingFaceProcessor(
            transformers_version="4.36",
            pytorch_version="2.1",
            role=role,
            instance_type="ml.m5.4xlarge",
            instance_count=1,
        ),
        code="src/feature_materialize.py",
        depends_on=[data_val],
        job_arguments=[
            "--feature-schema", feature_schema.to_string(),
            "--leak-detector-pct", "0.01",
            "--leak-score-abort-threshold", "0.85",
        ],
    )

    # Step 3: Train/Val/Test split, stratified by language x intent
    split = ProcessingStep(
        name="IntentTrainValTestSplit",
        processor=HuggingFaceProcessor(
            transformers_version="4.36",
            pytorch_version="2.1",
            role=role,
            instance_type="ml.m5.4xlarge",
            instance_count=1,
        ),
        code="src/split.py",
        depends_on=[feature_mat],
        job_arguments=[
            "--strategy", "stratified",
            "--strata", "language,intent",
            "--val-frac", "0.15",
            "--test-frac", "0.15",
            "--seed", "42",
        ],
    )

    # Step 4: Training
    estimator = HuggingFace(
        entry_point="train.py",
        source_dir="src/",
        instance_type="ml.g5.xlarge",
        instance_count=1,
        role=role,
        transformers_version="4.36",
        pytorch_version="2.1",
        py_version="py310",
        use_spot_instances=True,
        max_wait=14400,
        max_run=10800,
        checkpoint_s3_uri="s3://manga-ml-checkpoints-apne1/intent/",
        checkpoint_local_path="/opt/ml/checkpoints",
        hyperparameters={
            "model_name_or_path": "distilbert-base-multilingual-cased",
            "num_train_epochs": 4,
            "per_device_train_batch_size": 64,
            "per_device_eval_batch_size": 128,
            "learning_rate": 2e-5,
            "warmup_ratio": 0.1,
            "weight_decay": 0.01,
            "fp16": True,
            "save_strategy": "steps",
            "save_steps": 500,
            "evaluation_strategy": "steps",
            "eval_steps": 500,
            "metric_for_best_model": "macro_f1",
            "load_best_model_at_end": True,
            "report_to": "sagemaker",
            "seed": 42,
        },
    )
    train = TrainingStep(name="IntentTrain", estimator=estimator, depends_on=[split])

    # Step 5: Offline evaluation
    offline_eval = ProcessingStep(
        name="IntentOfflineEval",
        processor=HuggingFaceProcessor(
            transformers_version="4.36",
            pytorch_version="2.1",
            role=role,
            instance_type="ml.g5.xlarge",
            instance_count=1,
        ),
        code="src/offline_eval.py",
        depends_on=[train],
        job_arguments=[
            "--golden-set-uri", "s3://manga-ml-eval-apne1/intent/golden/v3.parquet",
            "--adversarial-set-uri", "s3://manga-ml-eval-apne1/intent/adversarial/v2.parquet",
            "--counterfactual-set-uri", "s3://manga-ml-eval-apne1/intent/replay/last7d.parquet",
            "--baseline-model-version", "47",
        ],
    )

    # Step 6: Slice analysis
    slice_analysis = ProcessingStep(
        name="IntentSliceAnalysis",
        processor=HuggingFaceProcessor(
            transformers_version="4.36",
            pytorch_version="2.1",
            role=role,
            instance_type="ml.m5.4xlarge",
            instance_count=1,
        ),
        code="src/slice_analysis.py",
        depends_on=[offline_eval],
        job_arguments=[
            "--slices", "language,intent,cohort,device,catalog_tier,time_of_day",
            "--safety-critical-intents", "return_request,escalation,complaint",
            "--regression-sigma-floor", "2.0",
        ],
    )

    # Step 7: Bias audit
    bias = ProcessingStep(
        name="IntentBiasAudit",
        processor=HuggingFaceProcessor(
            transformers_version="4.36",
            pytorch_version="2.1",
            role=role,
            instance_type="ml.m5.4xlarge",
            instance_count=1,
        ),
        code="src/bias_audit.py",
        depends_on=[slice_analysis],
    )

    # Stage 1 offline gate. A ConditionStep cannot compare against a raw
    # S3Output URI; the accuracy metric is read out of the eval step's metrics
    # JSON via a PropertyFile (which must also be registered in
    # IntentOfflineEval's property_files).
    from sagemaker.workflow.functions import JsonGet
    from sagemaker.workflow.properties import PropertyFile

    eval_metrics = PropertyFile(
        name="IntentEvalMetrics", output_name="metrics", path="metrics.json"
    )
    promote_gate = ConditionStep(
        name="IntentOfflineGate",
        conditions=[
            ConditionGreaterThanOrEqualTo(
                left=JsonGet(
                    step_name=offline_eval.name,
                    property_file=eval_metrics,
                    json_path="golden.val_accuracy",
                ),
                right=0.90,
            ),
        ],
        # register_model_step / abort_step are pipeline helpers defined in src/.
        if_steps=[register_model_step(estimator)],
        else_steps=[abort_step()],
        depends_on=[bias],
    )

    return Pipeline(
        name="IntentClassifierWeeklyRetrain",
        parameters=[label_version, feature_schema],
        steps=[data_val, feature_mat, split, train, offline_eval, slice_analysis, bias, promote_gate],
    )

3. Offline Evaluation (5 modes)

The offline-evaluation step exercises all five modes from primitive §4.1:

# offline_eval.py
import pandas as pd
import scipy.stats
from sklearn.metrics import f1_score


class IntentOfflineEvaluator:
    def __init__(self, model, baseline_model, label_encoder):
        self.model = model
        self.baseline = baseline_model
        self.le = label_encoder

    def evaluate(self) -> EvalReport:
        return EvalReport(
            golden=self.eval_golden(),
            slice=self.eval_slice(),
            adversarial=self.eval_adversarial(),
            counterfactual=self.eval_counterfactual_replay(),
            offline_online_corr=self.eval_offline_online_corr(),
        )

    def eval_golden(self) -> GoldenReport:
        """Frozen 1500-example golden set, refreshed quarterly. The hard contract."""
        golden = pd.read_parquet("s3://manga-ml-eval-apne1/intent/golden/v3.parquet")
        preds = self.model.predict(golden["message"].tolist())
        macro_f1 = f1_score(golden["label"], preds, average="macro", labels=self.le.classes_)
        per_class_f1 = f1_score(golden["label"], preds, average=None, labels=self.le.classes_)
        return GoldenReport(
            macro_f1=macro_f1,
            per_class_f1=dict(zip(self.le.classes_, per_class_f1)),
            min_class_f1=per_class_f1.min(),
            min_class=self.le.classes_[per_class_f1.argmin()],
        )

    def eval_slice(self) -> SliceReport:
        """Cartesian product over language x intent x cohort x device x time."""
        slices_df = pd.read_parquet("s3://manga-ml-eval-apne1/intent/slice/v3.parquet")
        results = {}
        for slice_key in slice_combinations(["language", "intent", "cohort", "device", "tod"]):
            subset = slices_df.query(slice_key.query_str)
            if len(subset) < 30:
                continue  # skip too-small slices
            preds = self.model.predict(subset["message"].tolist())
            f1 = f1_score(subset["label"], preds, average="macro")
            baseline_preds = self.baseline.predict(subset["message"].tolist())
            baseline_f1 = f1_score(subset["label"], baseline_preds, average="macro")
            results[slice_key.name] = SliceResult(
                f1=f1,
                baseline_f1=baseline_f1,
                delta=f1 - baseline_f1,
                n=len(subset),
                is_safety_critical=slice_key.has_intent(SAFETY_CRITICAL_INTENTS),
            )
        return SliceReport(slices=results)

    def eval_adversarial(self) -> AdversarialReport:
        """Curated 800-example set: typos, JP/EN code-switching, Unicode tricks,
        adversarial paraphrases, prompt-injection attempts redirected to intent."""
        adv = pd.read_parquet("s3://manga-ml-eval-apne1/intent/adversarial/v2.parquet")
        preds = self.model.predict(adv["message"].tolist())
        baseline_preds = self.baseline.predict(adv["message"].tolist())
        f1 = f1_score(adv["label"], preds, average="macro")
        baseline_f1 = f1_score(adv["label"], baseline_preds, average="macro")
        return AdversarialReport(
            f1=f1,
            baseline_f1=baseline_f1,
            regression=baseline_f1 - f1,
            per_attack_f1=self._per_attack_breakdown(adv, preds),
        )

    def eval_counterfactual_replay(self) -> CounterfactualReport:
        """5K production replay sessions from last 7d, stratified by intent."""
        replay = pd.read_parquet("s3://manga-ml-eval-apne1/intent/replay/last7d.parquet")
        preds = self.model.predict(replay["message"].tolist())
        baseline_preds = self.baseline.predict(replay["message"].tolist())
        # Replay measures (cost, quality) shift, since labels are not available
        agreement = (preds == baseline_preds).mean()
        latency_delta = self._measure_latency_delta(replay)
        return CounterfactualReport(
            agreement_rate=agreement,
            latency_delta_p95_ms=latency_delta,
            disagreement_distribution=self._disagreement_breakdown(replay, preds, baseline_preds),
        )

    def eval_offline_online_corr(self) -> OfflineOnlineCorrReport:
        """Reads the last 6 retrains' (offline_macro_f1, online_routing_correctness)
        pairs from the model registry and computes Pearson correlation."""
        history = ModelRegistry.get_history(model_target="intent_classifier", last_n=6)
        offline = [h.metrics["macro_f1"] for h in history]
        online = [h.online_metrics["routing_correctness_28d"] for h in history]
        corr = scipy.stats.pearsonr(offline, online).statistic
        return OfflineOnlineCorrReport(
            corr=corr,
            trustworthy=corr >= 0.6,
            history_pairs=list(zip(offline, online)),
        )

4. Online Serving

The model serves on a SageMaker real-time endpoint, ml.g5.xlarge instance with TorchServe handler. Endpoint configuration:

# endpoint_config.py
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.model_monitor import DataCaptureConfig


def deploy_intent_endpoint(model_package_arn: str, version: str, role: str) -> str:
    model = HuggingFaceModel(
        model_data=model_package_arn,
        role=role,
        transformers_version="4.36",
        pytorch_version="2.1",
        py_version="py310",
        env={
            "HF_MODEL_ID": "intent_classifier",
            "HF_TASK": "text-classification",
            "MAX_BATCH_SIZE": "32",
            "MAX_BATCH_DELAY_MS": "10",
            "MODEL_VERSION": version,
        },
    )

    predictor = model.deploy(
        initial_instance_count=2,
        instance_type="ml.g5.xlarge",
        endpoint_name=f"intent-classifier-{version}",
        async_inference_config=None,
        data_capture_config=DataCaptureConfig(
            enable_capture=True,
            sampling_percentage=10,
            destination_s3_uri="s3://manga-ml-capture-apne1/intent/",
            capture_options=["REQUEST", "RESPONSE"],
        ),
        # Blue/green + auto-rollback shown in illustrative class form; the same
        # shape is the DeploymentConfig applied on endpoint create/update.
        deployment_config=DeploymentConfig(
            blue_green_update_policy=BlueGreenUpdatePolicy(
                traffic_routing_config=TrafficRoutingConfig(
                    type="LINEAR",
                    canary_size=10,  # 10% canary first
                    linear_step_size=10,
                    wait_in_seconds_per_step=300,
                ),
                termination_wait_in_seconds=600,
            ),
            auto_rollback_configuration=AutoRollbackConfiguration(
                alarms=[
                    "intent-endpoint-error-rate-canary",
                    "intent-endpoint-p99-latency-canary",
                ],
            ),
        ),
    )
    return predictor.endpoint_name

The endpoint is fronted by a per-region application autoscaler:

# autoscaling.py
import boto3

client = boto3.client("application-autoscaling")

client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId="endpoint/intent-classifier-v48/variant/AllTraffic",
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=2,           # minimum for HA
    MaxCapacity=12,          # peak ~140 RPS / ~12 RPS per instance ≈ 12
)

client.put_scaling_policy(
    PolicyName="IntentEndpointTargetTracking",
    ServiceNamespace="sagemaker",
    ResourceId="endpoint/intent-classifier-v48/variant/AllTraffic",
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 70.0,
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance",
        },
        "ScaleInCooldown": 600,
        "ScaleOutCooldown": 60,
    },
)

5. Shadow + Canary Promotion

Shadow mode is implemented at the application gateway layer (not the SageMaker endpoint). The gateway calls v47 and v48 in parallel; v47's response goes to the user, v48's response goes to a comparison log. The comparison log writes to Kinesis Firehose → Iceberg intent_shadow_compare table, partitioned by hour:

# shadow_gateway.py
import asyncio
import time


async def classify_intent(message: str, request_id: str) -> IntentResult:
    shadow_started = time.monotonic()
    prod_task = asyncio.create_task(call_endpoint("intent-classifier-v47", message))
    shadow_task = asyncio.create_task(call_endpoint("intent-classifier-v48", message))

    prod_result = await prod_task
    try:
        shadow_result = await asyncio.wait_for(shadow_task, timeout=0.05)  # 50 ms cap
        await emit_shadow_compare(
            request_id=request_id,
            prod=prod_result,
            shadow=shadow_result,
            shadow_latency_ms=(time.monotonic() - shadow_started) * 1000,
        )
    except asyncio.TimeoutError:
        await emit_shadow_compare_timeout(request_id=request_id)

    return prod_result
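The Stage 2 gate inputs (agreement ≥ 75%, shadow p95 latency ≤ 1.3× v47) fall out of a simple aggregation over the intent_shadow_compare table. A sketch with hypothetical column names:

```python
import pandas as pd


def shadow_gate_metrics(compare_log: pd.DataFrame) -> dict:
    """Stage 2 gate inputs from the shadow-compare log.
    Column names here are assumptions about the log schema."""
    agreement = float((compare_log["prod_intent"] == compare_log["shadow_intent"]).mean())
    p95_ratio = float(
        compare_log["shadow_latency_ms"].quantile(0.95)
        / compare_log["prod_latency_ms"].quantile(0.95)
    )
    return {"agreement": agreement, "p95_latency_ratio": p95_ratio}
```

The gate passes when `agreement >= 0.75` and `p95_latency_ratio <= 1.3`.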

Canary promotion is implemented via the SageMaker endpoint deployment config above. The auto-rollback alarms intent-endpoint-error-rate-canary and intent-endpoint-p99-latency-canary are CloudWatch composite alarms that AND the canary instance metric with the production instance metric (canary error rate > 2× production error rate AND > 0.5%). The auto-rollback daemon described in primitive §5.3 also runs against the canary, using the online routing-correctness metric (downstream MCP success rate per request) instead of inference-side metrics.
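A sketch of how the ANDed canary alarm could be declared with CloudWatch's composite-alarm API — the two child alarm names are assumptions standing in for the real resources:

```python
def and_rule(*alarm_names: str) -> str:
    """Build a CloudWatch composite AlarmRule requiring every child alarm
    to be in ALARM state simultaneously."""
    return " AND ".join(f'ALARM("{name}")' for name in alarm_names)


def create_canary_error_composite() -> None:
    import boto3  # deferred so the rule builder stays dependency-free
    boto3.client("cloudwatch").put_composite_alarm(
        AlarmName="intent-endpoint-error-rate-canary",
        # Hypothetical child alarms: the 2x-prod ratio and the 0.5% floor.
        AlarmRule=and_rule(
            "intent-canary-error-rate-gt-2x-prod",
            "intent-canary-error-rate-gt-0.5pct",
        ),
    )
```

ANDing a ratio alarm with an absolute floor avoids paging on noisy 2× ratios when both error rates are near zero.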

6. Drift Detection

Drift detection consumes the intent_predictions_log Iceberg table and the captured features from data-capture. Per §6.1 of the foundations doc:

| Drift Kind | Detector | Cadence | Threshold |
|---|---|---|---|
| Input drift | PSI on each top-30 feature; KL on token distribution | 5 min | PSI > 0.2 sustained 24 h |
| Label drift | χ² on intent class proportions vs reference | Daily | p < 0.01 sustained 7 d |
| Prediction drift | KS on confidence distribution; per-class call rate | 5 min | KS > 0.15 sustained 24 h |
| Concept drift | Rolling-holdout F1 on most-recent-2-week labeled subset | Daily | Δ-F1 > 0.03 sustained 7 d |

# drift_check.py
import numpy as np
import pandas as pd


def compute_input_drift(
    current_window: pd.DataFrame,
    reference_window: pd.DataFrame,
    feature_cols: list[str],
) -> dict[str, float]:
    """PSI per feature against the v47 training holdout reference."""
    psi_scores = {}
    for col in feature_cols:
        psi_scores[col] = population_stability_index(
            current=current_window[col],
            reference=reference_window[col],
            n_bins=10,
        )
    return psi_scores


def population_stability_index(current, reference, n_bins=10):
    # Bin edges from the reference distribution's percentiles, with
    # open-ended outer bins so out-of-range values are still counted.
    breakpoints = np.percentile(reference, np.linspace(0, 100, n_bins + 1))
    breakpoints[0] = -np.inf
    breakpoints[-1] = np.inf
    cur_counts, _ = np.histogram(current, bins=breakpoints)
    ref_counts, _ = np.histogram(reference, bins=breakpoints)
    # Smooth numerator and denominator so empty bins neither divide by zero
    # nor produce log(0), and the proportions still sum to 1.
    cur_pct = (cur_counts + 1e-6) / (cur_counts.sum() + n_bins * 1e-6)
    ref_pct = (ref_counts + 1e-6) / (ref_counts.sum() + n_bins * 1e-6)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

Drift signals route to CloudWatch and to the on-call slack channel #manga-ml-oncall. Triage runbook is Runbooks/intent-drift-triage.md.
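Getting the PSI scores into CloudWatch so the sustained-window alarms can evaluate them is one `put_metric_data` call per window. A sketch — the namespace and metric/dimension names are illustrative, not the real contract:

```python
def drift_metric_payload(psi_scores: dict) -> list:
    """One CloudWatch datum per feature PSI score."""
    return [
        {
            "MetricName": "feature_psi",
            "Dimensions": [{"Name": "feature", "Value": feature}],
            "Value": float(score),
        }
        for feature, score in psi_scores.items()
    ]


def emit_drift_metrics(psi_scores: dict) -> None:
    import boto3  # deferred import keeps the payload builder testable offline
    boto3.client("cloudwatch").put_metric_data(
        Namespace="MangaML/IntentDrift",  # hypothetical namespace
        MetricData=drift_metric_payload(psi_scores),
    )
```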

7. Retraining Trigger Logic

Two retrain triggers:

  • Scheduled: EventBridge rule fires every Friday 03:00 JST (18:00 UTC Thursday). Triggers the pipeline with MaxLabelVersion = max(label_version) of the label-platform.
  • Drift-triggered: drift hub publishes to an SNS topic when concept_drift > 0.03 sustained 7 days OR safety_critical_slice_f1 < 0.85 for any of return_request, escalation, complaint. SNS triggers the pipeline ad-hoc.

Both triggers go through a promotion-eligibility gate that prevents two retrains running concurrently and blocks retrains when global_ml_freeze=true:

# trigger_lambda.py
import logging

import boto3

log = logging.getLogger(__name__)


def lambda_handler(event, context):
    eligibility = check_promotion_eligibility(model_target="intent_classifier")
    if not eligibility.eligible:
        log.warning(f"Skipping retrain: {eligibility.reason}")
        return {"statusCode": 200, "body": "skipped"}

    pipeline_arn = boto3.client("sagemaker").start_pipeline_execution(
        PipelineName="IntentClassifierWeeklyRetrain",
        PipelineParameters=[
            {"Name": "MaxLabelVersion", "Value": str(get_max_label_version())},
            {"Name": "FeatureSchemaVersion", "Value": "schema_v3.4"},
        ],
    )["PipelineExecutionArn"]
    log.info(f"Started retrain: {pipeline_arn}")
    return {"statusCode": 200, "body": pipeline_arn}


def check_promotion_eligibility(model_target: str) -> Eligibility:
    if get_ssm_flag("global_ml_freeze", default=False):
        return Eligibility(False, "global_ml_freeze=true")
    if not get_ssm_flag(f"{model_target}_promotion_enabled", default=False):
        return Eligibility(False, f"{model_target}_promotion_enabled=false")
    if has_running_pipeline(model_target):
        return Eligibility(False, "concurrent retrain in progress")
    return Eligibility(True, "eligible")
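get_ssm_flag above is a thin wrapper over SSM Parameter Store. A plausible sketch — the /manga-ml/flags/ parameter path is an assumption:

```python
def parse_flag(raw: str) -> bool:
    """Interpret an SSM string parameter as a boolean."""
    return raw.strip().lower() in ("true", "1", "yes")


def get_ssm_flag(name: str, default: bool = False) -> bool:
    import boto3  # deferred so parse_flag stays testable offline
    ssm = boto3.client("ssm")
    try:
        raw = ssm.get_parameter(Name=f"/manga-ml/flags/{name}")["Parameter"]["Value"]
    except ssm.exceptions.ParameterNotFound:
        return default
    return parse_flag(raw)
```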

8. Multilingual Handling (JP/EN-Specific)

Three concerns are JP-specific and have project-wide significance:

Tokenization. DistilBERT-multilingual's WordPiece tokenizer over-segments Japanese (kanji are mostly single-character tokens, which destroys word-level signal). Mitigation: a pre-tokenizer pass with MeCab + Sudachi runs on JP-detected messages before WordPiece; empirically this gains +0.04 macro-F1 on JP-only slices. The pre-tokenizer is part of the model artifact (a serialized vocabulary plus a Sudachi rules dict), not an external service; the endpoint loads it at startup, so there is no network hop on the hot path.

Code-switching. ~14% of MangaAssist's JP-region messages are JP/EN mixed (e.g., "鬼滅の刃 のpre-orderはいつ?"). The intent classifier's tokenizer handles this natively because it is multilingual-trained, but the rule-based pre-filter (US-02) does not — so code-switched messages disproportionately fall to this model. The slice gate explicitly tracks language=mixed to ensure quality stays above 0.85 macro-F1 on this slice.
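The language=mixed slice presupposes a detector that flags code-switched messages. A crude character-class sketch — production would rely on the feature pipeline's trained language detector, not this heuristic:

```python
def detect_language(text: str) -> str:
    """Classify a message as "en", "ja", or "mixed" by counting
    Japanese-script characters vs ASCII letters."""
    jp = sum(
        1 for ch in text
        if "\u3040" <= ch <= "\u30ff"    # hiragana + katakana
        or "\u4e00" <= ch <= "\u9fff"    # CJK unified ideographs
    )
    en = sum(1 for ch in text if ch.isascii() and ch.isalpha())
    if jp and en:
        return "mixed"
    if jp:
        return "ja"
    return "en"
```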

Honorifics and politeness markers. JP messages with elevated politeness ("〜していただけますでしょうか" vs "〜してくれない?") express the same intent but the formal register is much rarer in the training set. The feature pipeline normalizes politeness markers (collapses 〜ます/〜だ/〜である to a single politeness-feature flag, leaves the verb stem). The data-augmentation script in src/augment.py generates politeness-shifted versions of training examples to ensure class balance across registers.
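A simplified sketch of the politeness normalization — surface-regex only, with an illustrative marker list; the real pipeline operates on MeCab/Sudachi morphemes rather than raw substrings:

```python
import re

# Longest-first so longer markers match before their prefixes; this list
# is illustrative, not the production marker inventory.
POLITE_RE = re.compile(
    r"(いただけますでしょうか|いただけます|ました|でした|ください|ます|です)"
)


def normalize_politeness(text: str) -> tuple[str, bool]:
    """Collapse surface politeness markers into a boolean feature and strip
    them, leaving the stem for the tokenizer."""
    is_polite = bool(POLITE_RE.search(text))
    return POLITE_RE.sub("", text), is_polite
```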


Monitoring & Metrics

| Category | Metric | Target | Alarm threshold |
|---|---|---|---|
| Online — Latency | p50 inference | ≤ 7 ms | > 12 ms for 5 min |
| Online — Latency | p95 inference | ≤ 15 ms | > 22 ms for 5 min |
| Online — Latency | p99 inference | ≤ 35 ms | > 50 ms for 5 min |
| Online — Throughput | RPS | match traffic | scale-out lag > 60 s |
| Online — Throughput | Endpoint instance count | 2–12 | stuck at max for 30 min |
| Online — Errors | 5xx error rate | < 0.05% | > 0.5% for 5 min |
| Online — Errors | Invocation timeout rate | < 0.1% | > 1% for 5 min |
| Quality — Per-class | Routing correctness (downstream MCP success) per intent | ≥ 0.92 safety-critical, ≥ 0.85 others | < 0.90 safety-critical for 1 h |
| Quality — Per-slice | Macro-F1 per EN/JP/mixed slice | ≥ 0.85 | < 0.80 for 24 h |
| Drift | PSI per top-30 feature | < 0.2 | > 0.2 for 24 h |
| Drift | KS on confidence distribution | < 0.15 | > 0.15 for 24 h |
| Drift | Δ-F1 vs reference | < 0.03 | > 0.03 for 7 d |
| Cost | $ / 1K inferences | ≤ $0.018 | > $0.025 for 24 h |
| Cost | Training $ / run | ≤ $25 | > $40 |
| Pipeline | Weekly retrain success rate | ≥ 95% | < 90% over 30 d rolling |
| Pipeline | Pipeline wall-clock | ≤ 8 h | > 10 h |
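The PSI drift check above is the one gate that is pure arithmetic, so it is worth writing out. A minimal sketch, assuming the caller has already binned both distributions (e.g., deciles of the reference); the `> 0.2` alarm matches the monitoring threshold:

```python
import math

def population_stability_index(expected, actual, eps=1e-6):
    """PSI between two binned distributions (lists of bin proportions):
    PSI = sum((a_i - e_i) * ln(a_i / e_i)). Bin construction is the
    caller's responsibility; eps guards against empty bins."""
    assert abs(sum(expected) - 1) < 1e-6 and abs(sum(actual) - 1) < 1e-6
    psi = 0.0
    for e, a in zip(expected, actual):
        e, a = max(e, eps), max(a, eps)
        psi += (a - e) * math.log(a / e)
    return psi

ref = [0.25, 0.25, 0.25, 0.25]
cur = [0.10, 0.20, 0.30, 0.40]
print(round(population_stability_index(ref, cur), 3))  # → 0.228, over the 0.2 alarm
```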

Risks & Mitigations

| Risk | Impact | Mitigation |
|---|---|---|
| WordPiece over-segmentation on JP regresses JP slice | −0.05 macro-F1 on JP; ~6% of traffic mis-routed | MeCab+Sudachi pre-tokenizer; JP slice gate enforces ≥ 0.85; bilingual IAA tracking |
| Programmatic-label rule changes correlate with label leak | Inflates training accuracy ~+0.04, regresses production −0.02 | Leak detector at 1ns/serving-lag boundary; quarterly audit of programmatic rule change log against training-data deltas |
| Vendor IAA drops below 0.75 mid-batch undetected | Subtle quality regression in next training cycle | Per-annotator running κ dashboards; reject-and-resend at vendor expense; per-batch κ gate in data-validation step |
| Spot reclaim during 4-epoch training run | Wasted compute, missed Friday window | Checkpoints every 500 steps to S3; max_wait=14400s; on-demand fallback if > 3 reclaims |
| Concept drift on safety-critical intent unnoticed | Misroute of return_request / escalation / complaint; CX SEV-2 | Concept-drift detection runs daily on these intents specifically; SEV-3 page within 24 h of detection |
| New manga genre (e.g., manhwa) creates intent-class confusion | catalog_search accuracy regresses on genre-bearing queries | Coordinated retrain triggered by US-MLE-05 (embedding adapter on category expansion) |
| LLM-distilled labels concentrate model bias | Training set inherits LLM-classifier failure modes | LLM-distilled labels capped at 25% of any class's training examples; quarterly vendor-relabel audit |
| Endpoint deployment rolls back during peak | Customer-visible 5xx spike for ~3–5 minutes | Blue/green deploy with linear traffic shift; auto-rollback on canary alarms; deployments scheduled outside 21:00–24:00 JST |
| Feature-schema migration without coordinated US-MLE-02 update | Training-serving skew across two stories | Feature-schema upgrade is a Coordinated Change Request requiring sign-off from all consuming stories |
| Adversarial prompt-injection bypasses intent classifier into LLM | Cost / safety risk in downstream system | Adversarial set v2 includes prompt-injection redirected to intent; quarterly red-team |

Deep Dive — Why This Works at Amazon-Scale on the Manga Workload

The MangaAssist intent classifier is small (134M params) and serves enormous traffic (~3.5M inferences/day after rule-based filtering). Four workload properties make the design above the right shape rather than the obvious "just train BERT-large weekly":

Workload property 1: bilingual traffic with ~6× JP/EN imbalance. ~85% of traffic is JP, ~14% EN, ~1% mixed; but the catalog is 60% English-titled. The intent-class distribution in JP traffic is meaningfully different from EN: recommendation_request is overrepresented in JP (52% vs 38%), support_policy underrepresented (3% vs 11%). A model that optimizes a single global metric will under-serve JP power users on rare classes (escalation, complaint). The slice gate is therefore not optional — it is the only thing preventing global metric improvements from being bought by sacrificing rare-class JP cohorts. The bilingual IAA track ensures the input-side label distribution does not drift either; vendor κ is reported per-language because the EN and JP annotation teams have separate calibration histories.

Workload property 2: heavy-tailed traffic with peak 3× off-peak. Off-peak JP traffic (07:00–18:00 JST) is dominated by catalog_search and recommendation_request. Peak traffic (21:00–24:00 JST) shifts toward pricing_inquiry, order_lookup, and return_request. The slice-by-time-of-day analysis reveals that a weekly model retrained without time-of-day stratification over-fits to the peak distribution, which contributes the bulk of the messages; the mitigation is the stratified split that includes time-of-day as a stratum. This is the same lesson the Cost-Optimization stories learned about Redis traffic: mean traffic patterns hide peak-time regressions.

Workload property 3: label velocity matches retrain cadence. Vendor produces ~5K labels/week; programmatic rule-based produces ~50K/week; LLM-distilled produces ~10K/week. The combined ~65K/week new examples on top of a ~3.5M cumulative training set is a 1.8% data-set turnover per week. This is the right velocity for weekly retrain — the model sees enough new signal to adapt, but not so much that one bad batch overwhelms the prior. A daily retrain would over-react to noise; a monthly retrain would under-react to drift.

Workload property 4: failure cost is asymmetric across intents. Mis-routing chitchat to recommendation_request is irritating (cost: customer asks again). Mis-routing escalation to chitchat is a CX incident (cost: complaint, possible regulator letter, brand damage). The safety-critical-intent gate puts a higher quality contract on three classes specifically, and the SEV-3 page on those classes is set lower than the global SEV. This is a workload-specific guardrail; a generic intent classifier would treat all 14 classes equally.

These four properties together explain why the per-slice + bilingual-IAA + safety-critical-gate machinery exists. A naïve "retrain weekly when drift fires" without slice gating would silently regress JP power users on escalation and not be caught until a customer complaint reached the CX dashboard.


Real-World Validation

Industry analogues. Amazon's A9 ranking team uses an analogous three-stage promotion (offline → shadow → online holdout) for ranking model promotion, with specific safety carve-outs for safety-critical query classes. Google's intent system for Assistant uses per-language slice gates, with separate IAA tracks for each supported language; their published documentation describes a similar concept-drift detector running on a rolling 2-week labeled subset. Meta's interaction classifier for ranking uses a multi-source label aggregation (vendor + implicit + programmatic) similar to this story's three sources; their published mid-2024 paper on label engineering reports IAA floors at 0.72 (slightly below this story's 0.75; the difference is justified by their use of ordinal labels vs this story's nominal labels).

Math validation — cost. On ml.g5.xlarge at $1.408/hr in ap-northeast-1, a 12-instance peak fleet costs $1.408 × 12 × 24 × 30 = $12,165/month; with autoscaling between 2 and 12 averaging 4 instances, expected monthly cost is ~$4,055. Per-inference cost: $4,055 / (3.5M × 30) = $0.0000386, or $3.86 per 100K inferences. Amortized over the full contracted traffic of ~10M messages/day (the rule-based pre-filter absorbs ~65% at negligible cost), that is $4,055 / (10M × 30) = $0.0000135 per request, or $1.35 per 100K requests, below the $0.000018-per-request ($1.80 per 100K) acceptance criterion. Training cost: ~$25/run × 4 runs/month = $100/month, immaterial against the serving cost.
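The fleet-cost arithmetic is easy to get wrong by a factor of instances or days, so a quick re-derivation helps; every input figure below comes from the text, and the per-message figure assumes amortization over the full ~10M messages/day:

```python
# Re-derivation of the serving-cost arithmetic; inputs are from the text.
HOURLY_USD = 1.408                                # ml.g5.xlarge, ap-northeast-1
peak_monthly = HOURLY_USD * 12 * 24 * 30          # 12-instance peak fleet
avg_monthly = HOURLY_USD * 4 * 24 * 30            # autoscaled ~4-instance average
per_inference = avg_monthly / (3.5e6 * 30)        # model inferences only
per_message = avg_monthly / (10e6 * 30)           # amortized over all chat messages
print(f"peak ${peak_monthly:,.0f}/mo  avg ${avg_monthly:,.0f}/mo")
print(f"${per_inference:.7f}/inference  ${per_message:.7f}/message")
```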

Math validation — latency. DistilBERT-multilingual on g5.xlarge (single A10 GPU, 24GB VRAM) with batch size 32 and FP16 produces ~50ms batch latency, dividing to ~1.6ms per request. Add ~5ms for tokenization (including MeCab pre-tokenize on JP messages), ~2ms for SageMaker request overhead, ~3ms for application gateway. Total p50: ~12ms, p95: ~15ms. The target is achievable; verified against a load-test in pre-production.

Math validation — label volume. 14 classes × 200 examples/class minimum per language × 2 languages (EN, JP) = 5,600 minimum examples per retrain to satisfy data-validation. Vendor produces ~5K labels/week, of which ~15% are JP-only escalations and ~85% are split EN/JP. Effective fresh JP examples: ~5K × 0.85 × 0.50 = 2,125/week; effective fresh EN examples likewise ~2,125/week. The 200/class minimum requires 14 × 200 = 2,800 per language, which one vendor batch alone does not satisfy — but cumulative data over the 4-week window does. Hence the 4-week sliding window with weekly retrain.
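The same check in code: one weekly vendor batch falls short of the per-language minimum, while the 4-week window clears it (all figures from the text):

```python
# Label-volume arithmetic check.
classes, min_per_class = 14, 200
per_language_minimum = classes * min_per_class        # per-language floor
fresh_per_language = 5000 * 0.85 * 0.50               # fresh examples/week/language
four_week_cumulative = 4 * fresh_per_language         # window used by the pipeline
print(per_language_minimum, fresh_per_language, four_week_cumulative)
```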


Cross-Story Interactions

| Edge | Direction | Contract | Conflict mode |
|---|---|---|---|
| US-MLE-01 → US-MLE-02 (reranker) | provides intent label | intent label as feature; pinned to US-MLE-01 model_version | If US-MLE-01 mis-routes a query, the US-MLE-02 reranker reads the wrong intent and ranks accordingly. Mitigation: US-MLE-02's fallback behavior under intent_confidence < 0.6 is documented there. |
| US-MLE-01 → US-MLE-06 (recommendation) | provides intent label | intent label as feature in the two-tower context tower | Same as above. US-MLE-06 reads intent_confidence and, below 0.6, falls back to the user-only tower (no context tower). |
| US-MLE-01 → Cost-Optimization US-08 (traffic-based) | provides intent label | intent label drives degradation-tier assignment | US-08 reads from the intent-label cache; US-MLE-01 model_version is pinned in the cache key. |
| US-MLE-01 ← US-MLE-05 (embedding adapter) | catalog category expansion triggers retrain | when US-MLE-05 promotes a new embedding model after catalog expansion, US-MLE-01 retrains with augmented training data covering the new genres | If the two retrains are not coordinated, US-MLE-01 continues mis-classifying manhwa queries as catalog_search for the entire week between scheduled retrains. The drift hub catches this and triggers a coordinated retrain. |
| US-MLE-01 ↔ US-MLE-07 (spam) | shared adversarial test set | adversarial set v2 includes prompt-injection examples | When US-MLE-07 augments its adversarial set, those examples are reviewed for intent-redirect attacks and added to US-MLE-01's adversarial set. |
| US-MLE-01 → Cost-Optimization US-02 (intent classifier cost) | this story owns the model; US-02 optimizes its cost | model artifact reused; US-02 owns autoscaling config, this story owns model quality | A US-02 cost-driven autoscale-floor change must not violate this story's p95 latency contract. Coordinated change request. |

Rollback & Experimentation

Shadow Mode Plan

  • Duration: 24 hours minimum, 72 hours if Friday traffic is below 70% of recent peak (lighter-traffic days produce lower-power statistics).
  • Sample size: ~3.5M predictions × 1 day = 3.5M comparisons; statistical power ≫ what is needed to detect 1% prediction-disagreement shift.
  • Pass criteria: per-request prediction agreement with v47 between 60% and 90% (too low = model is doing something genuinely different; too high = model has not learned anything new). Latency p99 ≤ 1.3× v47.
  • Slice criteria: agreement on each language slice within ±5 percentage points of the global agreement rate. A model that agrees 80% globally but only 50% on JP fails shadow.
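The shadow pass and slice criteria above reduce to a small gate function. A minimal sketch, assuming the shadow job can be fed (language_slice, production_prediction, shadow_prediction) triples; `shadow_gate` and the record shape are hypothetical, since the real job reads the shadow prediction log:

```python
from collections import defaultdict

def shadow_gate(records, lo=0.60, hi=0.90, slice_band=0.05):
    """Pass iff global agreement is in [lo, hi] AND every language slice's
    agreement is within slice_band of the global rate."""
    total = agree = 0
    per_slice = defaultdict(lambda: [0, 0])       # slice -> [agreements, count]
    for lang, prod, shadow in records:
        hit = int(prod == shadow)
        total += 1
        agree += hit
        per_slice[lang][0] += hit
        per_slice[lang][1] += 1
    global_rate = agree / total
    if not (lo <= global_rate <= hi):             # too similar or too different
        return False, global_rate
    for hits, count in per_slice.values():        # per-slice ±5pp band
        if abs(hits / count - global_rate) > slice_band:
            return False, global_rate
    return True, global_rate
```

With 80% agreement on both EN and JP the gate passes; with 80% EN but 50% JP it fails even though the global rate (65%) is in band, which is exactly the "agrees globally but not on JP" case the criterion targets.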

Canary Thresholds

  • Phase A (1% traffic, 7 days): routing-correctness on safety-critical intents ≥ baseline; full-population csat ≥ baseline - 0.05; auto-abort on the four canary-daemon conditions.
  • Phase B (5% traffic, 3 days): same as Phase A plus per-slice csat regression ≤ 0.10.
  • Phase C (25% traffic, 7 days): same as Phase B plus cost metric within ±10% of baseline.

Kill-Switch Flags

  • intent_classifier_promotion_enabled (default: false; SSM Parameter Store /manga-ml/intent/promotion_enabled) — when false, weekly trigger logs and exits without starting the pipeline. Used during code freezes.
  • intent_classifier_canary_pause (default: false) — when true, traffic shift halts at current canary stage and waits for manual unblock.
  • global_ml_freeze (default: false) — overrides the above; applies to all 8 stories. See README's kill-switch precedence section.

Quality Regression Criteria (Hard Rollback)

A canary that satisfies any one of these conditions is automatically rolled back:

  • Macro-F1 on any of return_request, escalation, complaint regresses by > 0.02 absolute for any 1-hour window.
  • Routing-correctness (downstream MCP success rate) regresses by > 0.03 absolute for any 1-hour window.
  • p99 latency exceeds 50ms for any 5-minute window during peak.
  • 5xx error rate exceeds 0.5% for any 5-minute window.
  • Customer-reported incident on the production model that appears traceable to a prediction (manual override by SRE on-call).

The rollback is via SageMaker traffic shift to v47 (SLA: 60 seconds). The pipeline registers v48 as status: failed_promotion in the registry and does not retry until manual review.
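The hard-rollback conditions above are a simple disjunction, which is worth making explicit because "any one condition" is the whole point. A sketch with illustrative key names (not the real CloudWatch alarm schema); rates are fractions, so 0.5% is 0.005:

```python
def should_hard_rollback(window_metrics):
    """True iff ANY automatic-rollback condition fires. `window_metrics`
    holds the worst observed value per alarm window; keys are illustrative."""
    return any([
        window_metrics.get("safety_macro_f1_drop_1h", 0.0) > 0.02,
        window_metrics.get("routing_correctness_drop_1h", 0.0) > 0.03,
        window_metrics.get("p99_latency_ms_5min_peak", 0.0) > 50,
        window_metrics.get("rate_5xx_5min", 0.0) > 0.005,
        bool(window_metrics.get("manual_sre_override", False)),
    ])
```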


Multi-Reviewer Validation Findings & Resolutions

S1 — Must Fix Before Production

ML Scientist lens: The training-set's bilingual stratification needs explicit weighting in the loss function, not just in the split. A class-balanced cross-entropy with per-language re-weighting prevents the model from optimizing global macro-F1 by sacrificing JP minority classes. Resolution: train.py uses class_weight="balanced" per-language; the data-validation step computes per-language class weights and passes them as a hyperparameter.
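The per-language re-weighting in the resolution can be sketched as follows. This applies sklearn's `class_weight="balanced"` formula (w = n_samples / (n_classes × n_class_samples)) within each language, as the resolution describes; `per_language_class_weights` is a hypothetical helper, and the plumbing of the weights into train.py as a hyperparameter is not shown:

```python
from collections import Counter

def per_language_class_weights(labels):
    """Balanced class weights computed separately per language.
    `labels` is an iterable of (language, intent_class) pairs."""
    labels = list(labels)
    pair_counts = Counter(labels)                      # (lang, cls) -> n
    lang_totals = Counter(lang for lang, _ in labels)  # lang -> n_lang
    n_classes = len({cls for _, cls in labels})
    weights = {}
    for (lang, cls), n in pair_counts.items():
        weights.setdefault(lang, {})[cls] = lang_totals[lang] / (n_classes * n)
    return weights
```

A rare class in a language gets a proportionally larger weight there, so JP minority classes cannot be traded away for global macro-F1.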

SRE lens: Spot-instance training can fail Friday's window if there are >3 reclaims; the on-demand fallback isn't documented. Resolution: training_pipeline.py checks elapsed time at each checkpoint; if cumulative reclaim time exceeds 60 minutes, the next retry uses on-demand. Documented in Runbooks/intent-spot-fallback.md.

Application Security / Privacy lens: Adversarial set v2 includes prompt-injection examples that contain plausibly real PII (e.g., synthetic order numbers). These must be tagged and excluded from any shared corpora. Resolution: Adversarial set entries carry a synthetic_pii=true flag; data-loading pipeline filters them out of any non-evaluation use; quarterly audit reads the audit log for unexpected access.

S2 — Address Before Scale

Data Engineering lens: intent_predictions_log retention of 90 days hot is generous; consider 60 days hot + 12 months in S3 Glacier with a restore runbook. Resolution: scheduled S3 lifecycle policy applied; runbook Runbooks/intent-log-glacier-restore.md written.

FinOps lens: At 12 instances peak, the endpoint cost is bounded, but the rule-based pre-filter has not been re-tuned in 6 months. Higher pre-filter coverage could reduce traffic to this model by another 5–10 percentage points and save another ~$200/month. Resolution: Cost-Optimization US-02's quarterly pre-filter audit is calendared; the ML Engineer participates in the pre-filter coverage review.

ML Scientist lens (S2): Offline-online correlation has been 0.71 over the last 6 retrains, comfortably above the 0.6 trustworthy floor. However, as the catalog and intent distribution evolve, this number may drift. Quarterly recalibration runs to refresh the offline gate's calibration to current online metric distribution. Resolution: Quarterly calibration job calendared; results land in registry metadata.

S3 — Future Work

Principal Architect lens: Model is a single multilingual DistilBERT. As the catalog expands further (Korean manhwa, Chinese manhua), per-language specialized models may outperform a single multilingual model. Tracked as a v2 backlog item; revisit if JP-only or KO-only macro-F1 falls below 0.83 for >2 retrain cycles.

ML Scientist lens (S3): Experimenting with constrained-output decoding (limit predictions to subset of intents based on user-segment or product-page context). Could reduce safety-critical mis-routings but adds complexity. Defer until current model is stable for 90 days post-launch.

SRE lens (S3): Multi-region active-active. Currently this model serves only ap-northeast-1; an EU region extension for KO/EN traffic would require cross-region label and feature replication. Tracked separately under data-residency expansion roadmap.