diff --git a/alembic/versions/e4f5a6b7c8d9_add_model_selection_decision_promotion.py b/alembic/versions/e4f5a6b7c8d9_add_model_selection_decision_promotion.py new file mode 100644 index 00000000..73f2b648 --- /dev/null +++ b/alembic/versions/e4f5a6b7c8d9_add_model_selection_decision_promotion.py @@ -0,0 +1,90 @@ +"""add model_selection decision + promotion columns + +Revision ID: e4f5a6b7c8d9 +Revises: d3e4f5a6b7c8 +Create Date: 2026-06-01 11:00:00.000000 + +Slice C of the Forecast Champion Selector (issue #362). Adds the decision + +operationalization columns to ``model_selection_run`` — all ADDITIVE: + +- ``trained_model_type`` / ``is_override`` / ``override_reason`` — which model + the final bundle holds and whether it was a non-recommended override; +- ``champion_run_id`` / ``promoted_alias`` / ``promotion_decision`` — the + approval-gated registry handoff (registry ``model_run.run_id``, alias name, + and the audited decision record); +- ``feature_frame_version`` — M1, the request's V (1 or 2) persisted at + run-creation so train/promote carry the REAL version end-to-end. The + server_default ``'1'`` backfills legacy rows ONLY (not a code hardcode). + +No CheckConstraint change. ``downgrade`` drops all seven columns. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "e4f5a6b7c8d9" +down_revision: str | None = "d3e4f5a6b7c8" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply migration — seven additive columns on model_selection_run.""" + op.add_column( + "model_selection_run", + sa.Column("trained_model_type", sa.String(length=40), nullable=True), + ) + op.add_column( + "model_selection_run", + sa.Column( + "is_override", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + ) + op.add_column( + "model_selection_run", + sa.Column("override_reason", sa.String(length=2000), nullable=True), + ) + op.add_column( + "model_selection_run", + sa.Column("champion_run_id", sa.String(length=32), nullable=True), + ) + op.add_column( + "model_selection_run", + sa.Column("promoted_alias", sa.String(length=100), nullable=True), + ) + op.add_column( + "model_selection_run", + sa.Column( + "promotion_decision", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "model_selection_run", + sa.Column( + "feature_frame_version", + sa.Integer(), + nullable=False, + server_default=sa.text("1"), + ), + ) + + +def downgrade() -> None: + """Revert migration — drop the seven Slice C columns.""" + op.drop_column("model_selection_run", "feature_frame_version") + op.drop_column("model_selection_run", "promotion_decision") + op.drop_column("model_selection_run", "promoted_alias") + op.drop_column("model_selection_run", "champion_run_id") + op.drop_column("model_selection_run", "override_reason") + op.drop_column("model_selection_run", "is_override") + op.drop_column("model_selection_run", "trained_model_type") diff --git a/app/features/model_selection/decision.py b/app/features/model_selection/decision.py new file mode 100644 index 00000000..335a3ee8 --- /dev/null +++ b/app/features/model_selection/decision.py @@ -0,0 +1,137 @@ +"""Deterministic forecast-decision layer for the champion selector (Slice C). + +Pure functions — NO LLM, NO DB, NO I/O (mirror ``explanations.py``). Translate a +horizon forecast into an inventory-decision heuristic a planner can act on: +peak/low demand day, a CLEARLY-LABELED safety-stock heuristic, and bias-risk +wording. + +The safety-stock formula is the demand-variability-only form (King 2011, +constant lead time): + + safety_stock = z(service_level) * sigma_daily * sqrt(lead_time_days) + expected_demand_over_lead_time = average_demand * lead_time_days + reorder_point = expected_demand_over_lead_time + safety_stock + +``z`` comes from a fixed one-sided service-level lookup (NO scipy); an +in-between service level falls back to the nearest table key. Every field is +labeled ``method="heuristic"`` and carries a caveat — this output NEVER feeds +ranking (LOCKED #3). +""" + +from __future__ import annotations + +import statistics +from datetime import date +from typing import Any + +from app.features.model_selection.schemas import ForecastDecision + +# LOCKED #4 — the canonical bias sentence, kept byte-identical to the frontend +# ``BIAS_EXPLANATION`` constant (``components/champion-selector/copy.ts``) so the +# wording never drifts between the two surfaces. +BIAS_EXPLANATION = ( + "Positive bias means the model under-forecasts (risk of stockouts); " + "negative bias means it over-forecasts (risk of overstock)." +) + +# One-sided service-level z values (NO scipy dependency). Source: King 2011 +# safety-stock z-from-service-level table. +_Z_TABLE: dict[float, float] = {0.90: 1.2816, 0.95: 1.6449, 0.975: 1.9600, 0.99: 2.3263} + +_CAVEATS = [ + "Safety stock is a deterministic heuristic (demand variability only; constant lead time).", + "Not a substitute for a full inventory-optimisation model.", +] + + +def z_for_service_level(service_level: float) -> float: + """Return the one-sided z for a service level (exact key, else nearest). + + An exact table key returns its z directly; any other level snaps to the + nearest table key (documented heuristic — the table is coarse on purpose). + """ + if service_level in _Z_TABLE: + return _Z_TABLE[service_level] + nearest = min(_Z_TABLE, key=lambda key: abs(key - service_level)) + return _Z_TABLE[nearest] + + +def _coerce_date(value: object) -> date | None: + """Coerce a point's ``date`` (ISO string in JSON-mode dumps, or a date).""" + if isinstance(value, date): + return value + if isinstance(value, str): + try: + return date.fromisoformat(value) + except ValueError: + return None + return None + + +def forecast_peak_low( + points: list[dict[str, Any]], +) -> tuple[date | None, float | None, date | None, float | None]: + """Return ``(peak_date, peak_demand, low_date, low_demand)`` over points. + + Picks the max/min ``forecast`` value; ``(None, None, None, None)`` on an + empty forecast. Ties resolve to the first occurrence (deterministic). + """ + if not points: + return (None, None, None, None) + peak = max(points, key=lambda p: float(p["forecast"])) + low = min(points, key=lambda p: float(p["forecast"])) + return ( + _coerce_date(peak.get("date")), + float(peak["forecast"]), + _coerce_date(low.get("date")), + float(low["forecast"]), + ) + + +def _bias_direction(winner_bias: float | None) -> str: + """Plain-English direction phrase for a winner's bias sign.""" + if winner_bias is None: + return "has no recorded bias measurement" + if winner_bias > 0: + return "under-forecasts (risk of stockouts)" + if winner_bias < 0: + return "over-forecasts (risk of overstock)" + return "is roughly unbiased" + + +def compute_forecast_decision( + points: list[dict[str, Any]], + average_demand: float, + lead_time_days: int, + service_level: float, + winner_bias: float | None, +) -> ForecastDecision: + """Build the deterministic, labeled inventory-decision heuristic. + + ``sigma_daily`` is the POPULATION stdev of the forecast values; a flat or + single-point forecast yields ``sigma=0`` → ``safety_stock=0`` (honest, not + an error). + """ + values = [float(p["forecast"]) for p in points] + sigma_daily = statistics.pstdev(values) if len(values) > 1 else 0.0 + z = z_for_service_level(service_level) + safety_stock = z * sigma_daily * (lead_time_days**0.5) + expected_lt = average_demand * lead_time_days + bias_dir = _bias_direction(winner_bias) + if winner_bias is None: + bias_text = f"{BIAS_EXPLANATION} For this winner, bias {bias_dir}." + else: + bias_text = ( + f"{BIAS_EXPLANATION} For this winner, bias {winner_bias:.2f} indicates it {bias_dir}." + ) + return ForecastDecision( + lead_time_days=lead_time_days, + service_level=service_level, + z_value=z, + sigma_daily_demand=sigma_daily, + expected_demand_over_lead_time=expected_lt, + safety_stock=safety_stock, + reorder_point=expected_lt + safety_stock, + bias_risk_text=bias_text, + caveats=list(_CAVEATS), + ) diff --git a/app/features/model_selection/models.py b/app/features/model_selection/models.py index a39d5763..8d987a58 100644 --- a/app/features/model_selection/models.py +++ b/app/features/model_selection/models.py @@ -15,7 +15,16 @@ from enum import Enum from typing import Any -from sqlalchemy import CheckConstraint, Date, DateTime, ForeignKey, Index, Integer, String +from sqlalchemy import ( + Boolean, + CheckConstraint, + Date, + DateTime, + ForeignKey, + Index, + Integer, + String, +) from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -110,6 +119,24 @@ class ModelSelectionRun(TimestampMixin, Base): completed_at: Mapped[_dt.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) + # Slice C (forecast decision + operationalization) — all additive. + # ``trained_model_type`` records which model the final bundle holds (the + # ranked winner, or a user override); ``is_override`` flags a non-recommended + # choice; the promotion columns capture the approval-gated registry handoff. + trained_model_type: Mapped[str | None] = mapped_column(String(40), nullable=True) + is_override: Mapped[bool] = mapped_column( + Boolean, default=False, server_default="false", nullable=False + ) + override_reason: Mapped[str | None] = mapped_column(String(2000), nullable=True) + champion_run_id: Mapped[str | None] = mapped_column(String(32), nullable=True) + promoted_alias: Mapped[str | None] = mapped_column(String(100), nullable=True) + promotion_decision: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + # M1 — V2 promotion support: the request's feature_frame_version persisted at + # run-creation so train/promote carry the REAL version end-to-end. The + # server_default '1' backfills legacy rows ONLY (it is not a code hardcode). + feature_frame_version: Mapped[int] = mapped_column( + Integer, default=1, server_default="1", nullable=False + ) __table_args__ = ( CheckConstraint( diff --git a/app/features/model_selection/routes.py b/app/features/model_selection/routes.py index 7597464e..61a0cb8e 100644 --- a/app/features/model_selection/routes.py +++ b/app/features/model_selection/routes.py @@ -6,7 +6,9 @@ - GET /{selection_id} — fetch a persisted selection run - GET /{selection_id}/ranking — fetch just the ranking block - POST /{selection_id}/train-winner — train the winning model -- POST /{selection_id}/predict — forecast with the trained winner +- POST /{selection_id}/train-selected — train a user-chosen candidate (override) +- POST /{selection_id}/predict — forecast with the trained winner + decision +- POST /{selection_id}/promote — promote the trained champion to a registry alias Error mapping mirrors ``app/features/backtesting/routes.py``: ``ValueError`` → ``BadRequestError`` (RFC 7807 400), ``SQLAlchemyError`` → ``DatabaseError`` (500). @@ -16,7 +18,7 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, Query, Response, status +from fastapi import APIRouter, Body, Depends, Query, Response, status from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession @@ -24,13 +26,17 @@ from app.core.exceptions import BadRequestError, DatabaseError from app.core.logging import get_logger from app.features.model_selection.schemas import ( + ForecastDecisionParams, ModelCatalogResponse, ModelSelectionRunRequest, ModelSelectionRunResponse, PairAvailabilityResponse, PredictWinnerResponse, + PromoteRequest, + PromoteResponse, RankingResult, SubmitRunResponse, + TrainSelectedRequest, TrainWinnerResponse, ) from app.features.model_selection.service import ModelSelectionService @@ -241,24 +247,92 @@ async def train_winner( ) from exc +@router.post( + "/{selection_id}/train-selected", + response_model=TrainWinnerResponse, + status_code=status.HTTP_200_OK, + summary="Train a user-chosen candidate (override)", +) +async def train_selected( + selection_id: str, + request: TrainSelectedRequest, + db: AsyncSession = Depends(get_db), +) -> TrainWinnerResponse: + """Train a chosen candidate (override). A non-candidate ``model_type`` → 400. + + Overriding the recommended winner returns ``is_override=true`` plus an + ``override_warning`` and records the override reason on the run. + """ + service = ModelSelectionService() + try: + return await service.train_selected( + db, selection_id, request.model_type, request.override_reason + ) + except ValueError as exc: + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + raise DatabaseError( + message="Failed to train selected model", details={"error": str(exc)} + ) from exc + + @router.post( "/{selection_id}/predict", response_model=PredictWinnerResponse, status_code=status.HTTP_200_OK, - summary="Forecast with the trained winning model", + summary="Forecast with the trained model + inventory decision", ) async def predict_winner( selection_id: str, + request: ForecastDecisionParams | None = Body(default=None), db: AsyncSession = Depends(get_db), ) -> PredictWinnerResponse: - """Generate a horizon forecast from the trained champion bundle.""" + """Generate a horizon forecast + a labeled safety-stock decision heuristic. + + The body is OPTIONAL — an empty body uses ``ForecastDecisionParams`` + defaults (lead_time_days=7, service_level=0.95). A feature-aware model 400s + (use the What-If Planner instead). + """ + params = request or ForecastDecisionParams() service = ModelSelectionService() try: - forecast = await service.predict_winner(db, selection_id) - return PredictWinnerResponse(selection_id=selection_id, forecast=forecast) + forecast, decision = await service.predict_winner( + db, selection_id, params.lead_time_days, params.service_level + ) + return PredictWinnerResponse( + selection_id=selection_id, forecast=forecast, decision=decision + ) except ValueError as exc: raise BadRequestError(message=str(exc)) from exc except SQLAlchemyError as exc: raise DatabaseError( message="Failed to forecast with winning model", details={"error": str(exc)} ) from exc + + +@router.post( + "/{selection_id}/promote", + response_model=PromoteResponse, + status_code=status.HTTP_200_OK, + summary="Promote the trained champion to a registry alias (approval-gated)", +) +async def promote( + selection_id: str, + request: PromoteRequest, + db: AsyncSession = Depends(get_db), +) -> PromoteResponse: + """Register a SUCCESS model run + alias for the trained champion. + + Approval-gated + audited: requires ``approved_by``; a non-recommended model + requires ``acknowledge_non_recommended=true`` (else 422); promoting before + training → 422; a bad ``alias_name`` → 422 at the schema boundary. + """ + service = ModelSelectionService() + try: + return await service.promote(db, selection_id, request) + except ValueError as exc: + raise BadRequestError(message=str(exc)) from exc + except SQLAlchemyError as exc: + raise DatabaseError( + message="Failed to promote champion", details={"error": str(exc)} + ) from exc diff --git a/app/features/model_selection/schemas.py b/app/features/model_selection/schemas.py index 050d3ead..f494882d 100644 --- a/app/features/model_selection/schemas.py +++ b/app/features/model_selection/schemas.py @@ -164,6 +164,49 @@ class AvailabilityQuery(BaseModel): forecast_horizon: int = Field(default=14, ge=1, le=90) +class TrainSelectedRequest(BaseModel): + """``POST /model-selection/{id}/train-selected`` body (Slice C). + + Trains a USER-CHOSEN candidate (override). Only ``str``/``None`` fields so + ``ConfigDict(strict=True)`` needs no ``Field(strict=False)`` override (no + date/uuid/decimal field) — keeps ``test_strict_mode_policy`` green. + """ + + model_config = ConfigDict(strict=True) + + model_type: ModelType + override_reason: str | None = Field(default=None, max_length=2000) + + +class ForecastDecisionParams(BaseModel): + """Optional ``POST /model-selection/{id}/predict`` body (Slice C). + + Drives the deterministic safety-stock heuristic in ``decision.py``. All + fields are JSON-native (``int``/``float``) → no ``Field(strict=False)``. + """ + + model_config = ConfigDict(strict=True) + + lead_time_days: int = Field(default=7, ge=1, le=365) + service_level: float = Field(default=0.95, ge=0.5, lt=1.0) + + +class PromoteRequest(BaseModel): + """``POST /model-selection/{id}/promote`` body (Slice C). + + Approval-gated promotion of a trained champion to a registry alias. + ``alias_name`` mirrors the registry regex so a bad name 422s at the schema + boundary. ``approved_by`` is required — promotion is never automatic. + """ + + model_config = ConfigDict(strict=True) + + alias_name: str = Field(..., min_length=1, max_length=100, pattern=r"^[a-z0-9][a-z0-9\-_]*$") + approved_by: str = Field(..., min_length=1, max_length=100) + acknowledge_non_recommended: bool = False + description: str | None = Field(default=None, max_length=500) + + # ============================================================================= # Intermediate models (service-internal; embedded in JSONB snapshots) # ============================================================================= @@ -259,12 +302,41 @@ class PairAvailabilityResponse(BaseModel): class ForecastSummary(BaseModel): - """Forecast output rolled up for the response.""" + """Forecast output rolled up for the response. + + Slice C adds ``peak_*`` / ``low_*`` as ADDITIVE Optional fields (default + ``None``) so legacy JSONB snapshots written by the Slice A/B auto-predict + path still validate on reload. + """ points: list[dict[str, Any]] total_demand: float average_demand: float horizon: int + peak_date: date | None = None + peak_demand: float | None = None + low_date: date | None = None + low_demand: float | None = None + + +class ForecastDecision(BaseModel): + """Deterministic, CLEARLY-LABELED inventory-decision heuristic (Slice C). + + Computed by ``decision.compute_forecast_decision`` from the forecast points + + lead time + service level. ``method`` is fixed ``"heuristic"`` and every + use carries a caveat — this NEVER feeds ranking (LOCKED #3). + """ + + method: Literal["heuristic"] = "heuristic" + lead_time_days: int + service_level: float + z_value: float + sigma_daily_demand: float + expected_demand_over_lead_time: float + safety_stock: float + reorder_point: float + bias_risk_text: str + caveats: list[str] class CandidateProgress(BaseModel): @@ -364,15 +436,40 @@ class ModelCatalogResponse(BaseModel): class TrainWinnerResponse(BaseModel): - """``POST /model-selection/{id}/train-winner`` response.""" + """``POST /model-selection/{id}/train-winner`` and ``/train-selected``. + + Slice C adds ``is_override`` / ``override_warning`` as ADDITIVE fields with + back-compatible defaults — ``train-winner`` still returns + ``is_override=False`` + ``override_warning=None`` (its shape is unchanged for + existing callers/tests). + """ selection_id: str model_type: str model_path: str + is_override: bool = False + override_warning: str | None = None class PredictWinnerResponse(BaseModel): - """``POST /model-selection/{id}/predict`` response.""" + """``POST /model-selection/{id}/predict`` response. + + Slice C adds ``decision`` (the labeled safety-stock heuristic) as an + ADDITIVE Optional field; ``forecast`` now also carries peak/low. + """ selection_id: str forecast: ForecastSummary + decision: ForecastDecision | None = None + + +class PromoteResponse(BaseModel): + """``POST /model-selection/{id}/promote`` response (Slice C).""" + + selection_id: str + alias_name: str + run_id: str + run_status: str + model_type: str + is_override: bool + promoted_at: datetime diff --git a/app/features/model_selection/service.py b/app/features/model_selection/service.py index 743e647e..baef2875 100644 --- a/app/features/model_selection/service.py +++ b/app/features/model_selection/service.py @@ -19,7 +19,8 @@ import uuid from collections.abc import Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any from sqlalchemy import and_, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker @@ -31,12 +32,17 @@ ConflictError, GatewayTimeoutError, NotFoundError, + UnprocessableEntityError, ) from app.core.logging import get_logger from app.features.backtesting.schemas import SplitConfig from app.features.data_platform.models import Product, Promotion, SalesDaily, Store from app.features.model_selection import runner from app.features.model_selection.capabilities import build_model_catalog +from app.features.model_selection.decision import ( + compute_forecast_decision, + forecast_peak_low, +) from app.features.model_selection.explanations import explain_winner from app.features.model_selection.models import ( TERMINAL_SELECTION_STATES, @@ -53,11 +59,14 @@ CandidateResult, ChartData, FoldChart, + ForecastDecision, ForecastSummary, ModelCatalogResponse, ModelSelectionRunRequest, ModelSelectionRunResponse, PairAvailabilityResponse, + PromoteRequest, + PromoteResponse, RankingResult, SelectionProgress, SelectionWindow, @@ -269,6 +278,7 @@ async def run_selection( ranking_metric=request.ranking_metric, candidate_models=[c.model_dump() for c in request.candidate_models], policy_snapshot=request.ranking_policy.model_dump(mode="json"), + feature_frame_version=request.feature_frame_version, ) db.add(row) await db.flush() @@ -463,6 +473,7 @@ async def submit_run( availability_snapshot=availability.model_dump(mode="json"), started_at=now, total_candidates=len(request.candidate_models), + feature_frame_version=request.feature_frame_version, ) db.add(row) # Flush the parent INSERT before the children — there is no ORM @@ -920,8 +931,15 @@ async def train_winner(self, db: AsyncSession, selection_id: str) -> TrainWinner row.start_date, row.end_date, cfg, # type: ignore[arg-type] + feature_frame_version=row.feature_frame_version, # M1 — train as configured (V1/V2) ) row.final_model_path = train.model_path + # Slice C — additive: the winner is the trained model and is NOT an + # override. The response shape is unchanged (is_override/override_warning + # default to False/None). + row.trained_model_type = ranking.winner.model_type + row.is_override = False + row.override_reason = None await db.flush() logger.info( "model_selection.winner_trained", @@ -934,8 +952,87 @@ async def train_winner(self, db: AsyncSession, selection_id: str) -> TrainWinner model_path=train.model_path, ) - async def predict_winner(self, db: AsyncSession, selection_id: str) -> ForecastSummary: - """Forecast with the trained winning model (requires train-winner first).""" + async def train_selected( + self, + db: AsyncSession, + selection_id: str, + model_type: str, + override_reason: str | None, + ) -> TrainWinnerResponse: + """Train a USER-CHOSEN candidate (override) — Slice C. + + ``model_type`` must be one of the run's CONFIGURED candidates + (``candidate_models``), NOT only the ranked/included entries: a candidate + that FAILED its backtest is still override-trainable (training is + independent of backtesting). A model never offered as a candidate → 400. + """ + from pydantic import TypeAdapter # lazy + + from app.features.forecasting.schemas import ModelConfig # lazy + from app.features.forecasting.service import ForecastingService # lazy + + row = await self._load(db, selection_id) + ranking = self._load_ranking(row) + + configured = { + str(c.get("model_type")) for c in (row.candidate_models or []) if c.get("model_type") + } + if model_type not in configured: + raise BadRequestError( + message=( + f"Model '{model_type}' was not a candidate in this selection. " + f"Candidates: {sorted(configured)}." + ) + ) + + params = self._params_for_trained_type(row, model_type) + adapter: TypeAdapter[object] = TypeAdapter(ModelConfig) + cfg = adapter.validate_python({"model_type": model_type, **params}) + train = await ForecastingService().train_model( + db, + row.store_id, + row.product_id, + row.start_date, + row.end_date, + cfg, # type: ignore[arg-type] + feature_frame_version=row.feature_frame_version, # M1 — V1/V2 as configured + ) + row.final_model_path = train.model_path + row.trained_model_type = model_type + winner_type = ranking.winner.model_type if ranking.winner else None + row.is_override = (model_type != winner_type) if winner_type is not None else True + row.override_reason = override_reason + await db.flush() + + warning = self._override_warning(model_type, ranking) if row.is_override else None + logger.info( + "model_selection.winner_selected_override", + selection_id=row.selection_id, + model_type=model_type, + is_override=row.is_override, + ) + return TrainWinnerResponse( + selection_id=row.selection_id, + model_type=model_type, + model_path=train.model_path, + is_override=row.is_override, + override_warning=warning, + ) + + async def predict_winner( + self, + db: AsyncSession, + selection_id: str, + lead_time_days: int, + service_level: float, + ) -> tuple[ForecastSummary, ForecastDecision | None]: + """Forecast with the trained model + compute the decision heuristic. + + Returns a ``(forecast, decision)`` tuple — the ROUTE assembles the + ``PredictWinnerResponse``. ``decision`` (safety stock etc.) NEVER feeds + ranking. A feature-aware model 400s inside ``ForecastingService.predict`` + (bubbles as ``ValueError`` → 400). + """ from app.features.forecasting.service import ForecastingService # lazy row = await self._load(db, selection_id) @@ -947,14 +1044,131 @@ async def predict_winner(self, db: AsyncSession, selection_id: str) -> ForecastS row.store_id, row.product_id, row.forecast_horizon, row.final_model_path ) summary = self._forecast_summary(prediction, row.forecast_horizon) + winner_bias: float | None = None + if row.winner_metrics is not None and row.winner_metrics.get("bias") is not None: + winner_bias = float(row.winner_metrics["bias"]) + decision = compute_forecast_decision( + summary.points, + summary.average_demand, + lead_time_days, + service_level, + winner_bias, + ) row.forecast_result = summary.model_dump(mode="json") await db.flush() logger.info( "model_selection.winner_predicted", selection_id=row.selection_id, horizon=row.forecast_horizon, + lead_time_days=lead_time_days, + ) + return summary, decision + + async def promote( + self, db: AsyncSession, selection_id: str, req: PromoteRequest + ) -> PromoteResponse: + """Approval-gated, audited promotion of a trained champion (Slice C). + + Orchestrates the registry in ONE request transaction (create_run → + RUNNING → register artifact → SUCCESS → create_alias), then persists the + audit record on ``model_selection_run``. Promotion is NEVER automatic and + performs NO comparison. + """ + from app.features.registry.schemas import ( # lazy + AliasCreate, + RunCreate, + RunStatus, + RunUpdate, + ) + from app.features.registry.service import RegistryService # lazy + + row = await self._load(db, selection_id) + if not row.final_model_path or not row.trained_model_type: + raise UnprocessableEntityError(message="Train the model before promoting.") + if row.is_override and not req.acknowledge_non_recommended: + raise UnprocessableEntityError( + message=( + "Promoting a non-recommended model requires acknowledge_non_recommended=true." + ) + ) + + registry = RegistryService() + params = self._params_for_trained(row) + # ``RunCreate``/``RunUpdate`` use Pydantic ``Field(None, ...)`` defaults + + # the ``model_config`` alias; mypy's pydantic plugin resolves these but + # pyright (no plugin) cannot — mirror the established + # ``registry_tools.py`` suppression. ``model_config_data=`` is the field + # name (populate_by_name=True), NOT the ``model_config`` ConfigDict alias. + run = await registry.create_run( + db, + RunCreate( # pyright: ignore[reportCallIssue] + model_type=row.trained_model_type, + model_config_data={ # pyright: ignore[reportCallIssue] + "model_type": row.trained_model_type, + **params, + }, + data_window_start=row.start_date, + data_window_end=row.end_date, + store_id=row.store_id, + product_id=row.product_id, + runtime_info_extras={"feature_frame_version": row.feature_frame_version}, + ), + ) + await registry.update_run( + db, + run.run_id, + RunUpdate(status=RunStatus.RUNNING), # pyright: ignore[reportCallIssue] + ) + artifact_uri, artifact_hash, artifact_size = self._register_artifact( + row.final_model_path, run.run_id + ) + await registry.update_run( + db, + run.run_id, + RunUpdate( # pyright: ignore[reportCallIssue] + status=RunStatus.SUCCESS, + metrics=row.winner_metrics, + artifact_uri=artifact_uri, + artifact_hash=artifact_hash, + artifact_size_bytes=artifact_size, + ), + ) + alias = await registry.create_alias( + db, + AliasCreate(alias_name=req.alias_name, run_id=run.run_id, description=req.description), + ) + + promoted_at = datetime.now(UTC) + row.champion_run_id = run.run_id + row.promoted_alias = alias.alias_name + row.promotion_decision = { + "decision_id": uuid.uuid4().hex, + "alias": alias.alias_name, + "champion_run_id": run.run_id, + "approved_by": req.approved_by, + "approved_at": promoted_at.isoformat(), + "decision": "promoted", + "reason": req.description, + "trained_model_type": row.trained_model_type, + "is_override": row.is_override, + } + await db.flush() + logger.info( + "model_selection.champion_promoted", + selection_id=row.selection_id, + alias=alias.alias_name, + run_id=run.run_id, + approved_by=req.approved_by, + ) + return PromoteResponse( + selection_id=row.selection_id, + alias_name=alias.alias_name, + run_id=run.run_id, + run_status=alias.run_status.value, + model_type=row.trained_model_type, + is_override=row.is_override, + promoted_at=promoted_at, ) - return summary # ------------------------------------------------------------------------- # Pure mappers @@ -1001,10 +1215,78 @@ def _forecast_summary(self, prediction: PredictResponse, horizon: int) -> Foreca points = [point.model_dump(mode="json") for point in prediction.forecasts] total = float(sum(point.forecast for point in prediction.forecasts)) average = total / len(prediction.forecasts) if prediction.forecasts else 0.0 + peak_date, peak_demand, low_date, low_demand = forecast_peak_low(points) return ForecastSummary( - points=points, total_demand=total, average_demand=average, horizon=horizon + points=points, + total_demand=total, + average_demand=average, + horizon=horizon, + peak_date=peak_date, + peak_demand=peak_demand, + low_date=low_date, + low_demand=low_demand, + ) + + @staticmethod + def _params_for_trained_type(row: ModelSelectionRun, model_type: str) -> dict[str, Any]: + """Return the configured params for a candidate ``model_type`` (or {}).""" + for candidate in row.candidate_models or []: + if candidate.get("model_type") == model_type: + params = candidate.get("params") or {} + return dict(params) + return {} + + def _params_for_trained(self, row: ModelSelectionRun) -> dict[str, Any]: + """Return the params of the model actually trained on this run.""" + if row.trained_model_type is None: + return {} + return self._params_for_trained_type(row, row.trained_model_type) + + @staticmethod + def _override_warning(chosen_type: str, ranking: RankingResult) -> str: + """Deterministic warning copy when a non-recommended model is trained.""" + winner = ranking.winner + if winner is None: + return f"You trained '{chosen_type}', but no model was recommended for this selection." + chosen_entry = next( + (e for e in ranking.entries if e.model_type == chosen_type and e.included), + None, + ) + winner_wape = (winner.metrics or {}).get("wape") + if chosen_entry and chosen_entry.metrics and winner_wape is not None: + chosen_wape = chosen_entry.metrics.get("wape") + if chosen_wape is not None: + gap = chosen_wape - winner_wape + return ( + f"You trained '{chosen_type}' instead of the recommended " + f"'{winner.model_type}'. Its backtest WAPE is {chosen_wape:.1f}% " + f"vs the recommended {winner_wape:.1f}% " + f"(a {gap:+.1f} percentage-point gap)." + ) + return ( + f"You trained '{chosen_type}' instead of the recommended " + f"'{winner.model_type}'. '{chosen_type}' was not successfully evaluated " + "in the backtest, so no WAPE comparison is available." ) + @staticmethod + def _register_artifact(final_model_path: str, run_id: str) -> tuple[str, str, int]: + """Copy the trained bundle into registry storage and return (uri, hash, size). + + Mirrors the demo pipeline's register step (``demo/pipeline.py``): the + forecasting bundle lives under ``forecast_model_artifacts_dir``; copying + it into ``registry_artifact_root`` makes the promoted run's artifact + verifiable via ``GET /registry/runs/{id}/verify``. + """ + from app.features.registry.storage import LocalFSProvider # lazy + + source = Path(final_model_path) + if not source.exists(): + raise BadRequestError(message=f"Trained artifact missing at {final_model_path}") + artifact_uri = f"champion-selector/{run_id}-{source.name}" + file_hash, file_size = LocalFSProvider().save(source, artifact_uri) + return artifact_uri, file_hash, file_size + async def _load(self, db: AsyncSession, selection_id: str) -> ModelSelectionRun: row = await db.scalar( select(ModelSelectionRun).where(ModelSelectionRun.selection_id == selection_id) diff --git a/app/features/model_selection/tests/conftest.py b/app/features/model_selection/tests/conftest.py index 3d0c0ae7..0a417631 100644 --- a/app/features/model_selection/tests/conftest.py +++ b/app/features/model_selection/tests/conftest.py @@ -19,7 +19,7 @@ import pytest from httpx import ASGITransport, AsyncClient -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from app.core.config import get_settings @@ -32,6 +32,7 @@ FoldChart, PairAvailabilityResponse, ) +from app.features.registry.models import DeploymentAlias, ModelRun from app.main import app # Integration test window. @@ -204,6 +205,18 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: finally: store_ids = _registered_store_ids() if store_ids: + # Slice C — clean up registry runs/aliases a promote() created + # (cross-slice teardown). Delete aliases first (FK to + # model_run.id), then the runs, scoped to the seeded store ids. + run_id_rows = await session.execute( + select(ModelRun.id).where(ModelRun.store_id.in_(store_ids)) + ) + run_ids = [r[0] for r in run_id_rows] + if run_ids: + await session.execute( + delete(DeploymentAlias).where(DeploymentAlias.run_id.in_(run_ids)) + ) + await session.execute(delete(ModelRun).where(ModelRun.id.in_(run_ids))) await session.execute( delete(ModelSelectionRun).where(ModelSelectionRun.store_id.in_(store_ids)) ) diff --git a/app/features/model_selection/tests/test_decision.py b/app/features/model_selection/tests/test_decision.py new file mode 100644 index 00000000..f1ab66b2 --- /dev/null +++ b/app/features/model_selection/tests/test_decision.py @@ -0,0 +1,155 @@ +"""Unit tests for the pure forecast-decision module (Slice C). + +``decision.py`` has NO DB/IO — every function is deterministic and tested here +directly (z-table, safety-stock formula, peak/low, bias wording). +""" + +from __future__ import annotations + +import statistics +from datetime import date + +import pytest + +from app.features.model_selection.decision import ( + BIAS_EXPLANATION, + compute_forecast_decision, + forecast_peak_low, + z_for_service_level, +) + + +def _points(values: list[float], start: date = date(2026, 1, 1)) -> list[dict[str, object]]: + """Build forecast points (JSON-mode shape: ISO date string + forecast).""" + return [ + { + "date": (start.fromordinal(start.toordinal() + i)).isoformat(), + "forecast": v, + "lower_bound": None, + "upper_bound": None, + } + for i, v in enumerate(values) + ] + + +# ----------------------------------------------------------------------------- +# z-table +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("service_level", "expected_z"), + [(0.90, 1.2816), (0.95, 1.6449), (0.975, 1.9600), (0.99, 2.3263)], +) +def test_decision_z_table_exact(service_level: float, expected_z: float) -> None: + assert z_for_service_level(service_level) == expected_z + + +@pytest.mark.parametrize( + ("service_level", "expected_z"), + [(0.92, 1.2816), (0.93, 1.6449), (0.96, 1.6449), (0.98, 1.9600)], +) +def test_decision_z_table_nearest(service_level: float, expected_z: float) -> None: + """In-between service levels snap to the nearest table key.""" + assert z_for_service_level(service_level) == expected_z + + +# ----------------------------------------------------------------------------- +# Safety stock +# ----------------------------------------------------------------------------- + + +def test_safety_stock_formula_matches_z_sigma_sqrt_l() -> None: + values = [10.0, 12.0, 8.0, 11.0, 9.0] + decision = compute_forecast_decision( + _points(values), average_demand=10.0, lead_time_days=7, service_level=0.95, winner_bias=0.5 + ) + sigma = statistics.pstdev(values) + expected_ss = 1.6449 * sigma * (7**0.5) + assert decision.method == "heuristic" + assert decision.z_value == 1.6449 + assert decision.sigma_daily_demand == pytest.approx(sigma) + assert decision.safety_stock == pytest.approx(expected_ss) + assert decision.expected_demand_over_lead_time == pytest.approx(70.0) + assert decision.reorder_point == pytest.approx(70.0 + expected_ss) + assert decision.caveats # always carries a caveat + + +def test_flat_forecast_safety_stock_zero() -> None: + """A flat (zero-variance) forecast → sigma 0 → safety stock 0 (honest).""" + decision = compute_forecast_decision( + _points([10.0, 10.0, 10.0]), + average_demand=10.0, + lead_time_days=7, + service_level=0.95, + winner_bias=0.0, + ) + assert decision.sigma_daily_demand == 0.0 + assert decision.safety_stock == 0.0 + + +def test_single_point_forecast_safety_stock_zero() -> None: + decision = compute_forecast_decision( + _points([42.0]), + average_demand=42.0, + lead_time_days=3, + service_level=0.95, + winner_bias=None, + ) + assert decision.sigma_daily_demand == 0.0 + assert decision.safety_stock == 0.0 + + +# ----------------------------------------------------------------------------- +# Peak / low +# ----------------------------------------------------------------------------- + + +def test_forecast_peak_low_picks_max_and_min() -> None: + points = _points([10.0, 25.0, 5.0, 18.0]) + peak_date, peak_demand, low_date, low_demand = forecast_peak_low(points) + assert peak_demand == 25.0 + assert low_demand == 5.0 + assert peak_date == date(2026, 1, 2) + assert low_date == date(2026, 1, 3) + + +def test_forecast_peak_low_empty_returns_none() -> None: + assert forecast_peak_low([]) == (None, None, None, None) + + +# ----------------------------------------------------------------------------- +# Bias wording (LOCKED #4 — reuses BIAS_EXPLANATION) +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("bias", "fragment"), + [ + (1.5, "under-forecasts (risk of stockouts)"), + (-1.5, "over-forecasts (risk of overstock)"), + (0.0, "roughly unbiased"), + ], +) +def test_bias_risk_text_under_over_neutral(bias: float, fragment: str) -> None: + decision = compute_forecast_decision( + _points([10.0, 12.0]), + average_demand=11.0, + lead_time_days=7, + service_level=0.95, + winner_bias=bias, + ) + assert BIAS_EXPLANATION in decision.bias_risk_text + assert fragment in decision.bias_risk_text + + +def test_bias_risk_text_handles_missing_bias() -> None: + decision = compute_forecast_decision( + _points([10.0, 12.0]), + average_demand=11.0, + lead_time_days=7, + service_level=0.95, + winner_bias=None, + ) + assert BIAS_EXPLANATION in decision.bias_risk_text + assert "no recorded bias" in decision.bias_risk_text diff --git a/app/features/model_selection/tests/test_models.py b/app/features/model_selection/tests/test_models.py index 7264aec9..589d7630 100644 --- a/app/features/model_selection/tests/test_models.py +++ b/app/features/model_selection/tests/test_models.py @@ -78,3 +78,38 @@ def test_model_selection_candidate_construction() -> None: assert cand.status == "pending" assert cand.result is None assert cand.error_message is None + + +# ============================================================================= +# Slice C — decision + promotion columns (in-Python construction) +# ============================================================================= + + +def test_model_selection_run_slice_c_columns_construct() -> None: + run = ModelSelectionRun( + selection_id="selC", + status=ModelSelectionStatus.COMPLETED.value, + store_id=3, + product_id=8, + start_date=date(2026, 1, 1), + end_date=date(2026, 5, 31), + forecast_horizon=14, + ranking_metric="wape", + candidate_models=[{"model_type": "naive", "params": {}}], + policy_snapshot={}, + trained_model_type="naive", + is_override=True, + override_reason="domain seasonality", + champion_run_id="run_abc123", + promoted_alias="champion-test", + promotion_decision={"decision": "promoted", "approved_by": "gabor"}, + feature_frame_version=2, + ) + assert run.trained_model_type == "naive" + assert run.is_override is True + assert run.override_reason == "domain seasonality" + assert run.champion_run_id == "run_abc123" + assert run.promoted_alias == "champion-test" + assert run.promotion_decision is not None + assert run.promotion_decision["approved_by"] == "gabor" + assert run.feature_frame_version == 2 diff --git a/app/features/model_selection/tests/test_routes.py b/app/features/model_selection/tests/test_routes.py index 2effbc62..6193012d 100644 --- a/app/features/model_selection/tests/test_routes.py +++ b/app/features/model_selection/tests/test_routes.py @@ -228,3 +228,180 @@ async def test_models_route_not_captured_by_selection_id( response = await ac.get("/model-selection/models") assert response.status_code == 200 assert "models" in response.json() + + +# ============================================================================= +# Slice C — train-selected / predict-with-decision / promote routes +# ============================================================================= + +from app.core.exceptions import UnprocessableEntityError # noqa: E402 +from app.features.model_selection.schemas import ( # noqa: E402 + ForecastDecision, + ForecastSummary, + PromoteResponse, + TrainWinnerResponse, +) + + +def _forecast_summary() -> ForecastSummary: + return ForecastSummary( + points=[{"date": "2026-06-01", "forecast": 10.0}], + total_demand=10.0, + average_demand=10.0, + horizon=14, + peak_date=date(2026, 6, 1), + peak_demand=10.0, + low_date=date(2026, 6, 1), + low_demand=10.0, + ) + + +def _forecast_decision() -> ForecastDecision: + return ForecastDecision( + lead_time_days=7, + service_level=0.95, + z_value=1.6449, + sigma_daily_demand=0.0, + expected_demand_over_lead_time=70.0, + safety_stock=0.0, + reorder_point=70.0, + bias_risk_text="bias text", + caveats=["heuristic"], + ) + + +async def test_train_selected_returns_200(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "train_selected", + AsyncMock( + return_value=TrainWinnerResponse( + selection_id="sel123", + model_type="seasonal_naive", + model_path="artifacts/models/model_x.joblib", + is_override=True, + override_warning="you overrode the recommendation", + ) + ), + ) + async with _client() as ac: + response = await ac.post( + "/model-selection/sel123/train-selected", + json={"model_type": "seasonal_naive", "override_reason": "domain"}, + ) + assert response.status_code == 200 + body = response.json() + assert body["is_override"] is True + assert body["override_warning"] + + +async def test_train_selected_bad_model_type_returns_400(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "train_selected", + AsyncMock(side_effect=BadRequestError(message="not a candidate")), + ) + async with _client() as ac: + response = await ac.post( + "/model-selection/sel123/train-selected", + json={"model_type": "naive"}, + ) + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +async def test_predict_no_body_uses_defaults_returns_200(monkeypatch: pytest.MonkeyPatch) -> None: + """Empty body → ForecastDecisionParams defaults → 200 with a decision.""" + predict_mock = AsyncMock(return_value=(_forecast_summary(), _forecast_decision())) + monkeypatch.setattr(ModelSelectionService, "predict_winner", predict_mock) + async with _client() as ac: + response = await ac.post("/model-selection/sel123/predict") + assert response.status_code == 200 + body = response.json() + assert body["decision"]["lead_time_days"] == 7 + assert body["forecast"]["peak_demand"] == 10.0 + # service called with the default lead time + service level + assert predict_mock.await_args is not None + assert predict_mock.await_args.args[2] == 7 + assert predict_mock.await_args.args[3] == 0.95 + + +async def test_predict_with_body_returns_200(monkeypatch: pytest.MonkeyPatch) -> None: + predict_mock = AsyncMock(return_value=(_forecast_summary(), _forecast_decision())) + monkeypatch.setattr(ModelSelectionService, "predict_winner", predict_mock) + async with _client() as ac: + response = await ac.post( + "/model-selection/sel123/predict", + json={"lead_time_days": 14, "service_level": 0.99}, + ) + assert response.status_code == 200 + assert predict_mock.await_args is not None + assert predict_mock.await_args.args[2] == 14 + assert predict_mock.await_args.args[3] == 0.99 + + +async def test_predict_feature_aware_returns_400(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "predict_winner", + AsyncMock(side_effect=ValueError("Feature-aware models forecast through /scenarios")), + ) + async with _client() as ac: + response = await ac.post("/model-selection/sel123/predict") + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +async def test_promote_returns_200(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "promote", + AsyncMock( + return_value=PromoteResponse( + selection_id="sel123", + alias_name="champion-test", + run_id="run_abc", + run_status="success", + model_type="naive", + is_override=False, + promoted_at=datetime(2026, 6, 1, tzinfo=UTC), + ) + ), + ) + async with _client() as ac: + response = await ac.post( + "/model-selection/sel123/promote", + json={"alias_name": "champion-test", "approved_by": "gabor"}, + ) + assert response.status_code == 200 + body = response.json() + assert body["alias_name"] == "champion-test" + assert body["run_status"] == "success" + + +async def test_promote_bad_alias_name_returns_422() -> None: + """A bad alias_name is rejected by the schema regex (422) before the service.""" + async with _client() as ac: + response = await ac.post( + "/model-selection/sel123/promote", + json={"alias_name": "Bad Alias!", "approved_by": "gabor"}, + ) + assert response.status_code == 422 + _assert_problem_detail(response.json(), 422) + + +async def test_promote_before_train_returns_422(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, + "promote", + AsyncMock( + side_effect=UnprocessableEntityError(message="Train the model before promoting.") + ), + ) + async with _client() as ac: + response = await ac.post( + "/model-selection/sel123/promote", + json={"alias_name": "champion-test", "approved_by": "gabor"}, + ) + assert response.status_code == 422 + _assert_problem_detail(response.json(), 422) diff --git a/app/features/model_selection/tests/test_routes_integration.py b/app/features/model_selection/tests/test_routes_integration.py index b74b98c2..bc03d122 100644 --- a/app/features/model_selection/tests/test_routes_integration.py +++ b/app/features/model_selection/tests/test_routes_integration.py @@ -270,3 +270,134 @@ async def test_legacy_sync_run_has_no_progress_children( body = fetched.json() assert body["progress"] is None assert body["candidate_progress"] == [] + + +# --------------------------------------------------------------------- Slice C + +from uuid import uuid4 # noqa: E402 + + +async def test_decision_journey_override_predict( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + """Sync run → train-selected (override) → predict (decision + peak/low).""" + run = await client.post("/model-selection/run", json=_run_body(ready_pair)) + assert run.status_code == 200 + body = run.json() + selection_id = body["selection_id"] + winner_type = body["winner"]["model_type"] + + # Pick a candidate that is NOT the winner to exercise the override path. + candidate_types = [c["model_type"] for c in _run_body(ready_pair)["candidate_models"]] + override_type = next(t for t in candidate_types if t != winner_type) + + trained = await client.post( + f"/model-selection/{selection_id}/train-selected", + json={"model_type": override_type, "override_reason": "domain seasonality"}, + ) + assert trained.status_code == 200 + tbody = trained.json() + assert tbody["is_override"] is True + assert tbody["override_warning"] + + predicted = await client.post( + f"/model-selection/{selection_id}/predict", + json={"lead_time_days": 7, "service_level": 0.95}, + ) + assert predicted.status_code == 200 + pbody = predicted.json() + assert pbody["forecast"]["peak_demand"] is not None + assert pbody["forecast"]["low_demand"] is not None + assert pbody["decision"]["method"] == "heuristic" + assert pbody["decision"]["lead_time_days"] == 7 + assert "safety_stock" in pbody["decision"] + assert "reorder_point" in pbody["decision"] + + +async def test_promote_creates_registry_run_and_alias_with_real_v( + client: AsyncClient, ready_pair: dict[str, Any], db_session: AsyncSession +) -> None: + """V2-configured run → train-winner → promote → SUCCESS registry run + alias. + + Asserts the registry run's runtime_info carries the REAL feature_frame_version + (2), and the selection persisted champion_run_id/promoted_alias/decision. + """ + body = _run_body(ready_pair) + body["feature_frame_version"] = 2 # baseline winner ignores V at train, column persists 2 + run = await client.post("/model-selection/run", json=body) + assert run.status_code == 200 + selection_id = run.json()["selection_id"] + + trained = await client.post(f"/model-selection/{selection_id}/train-winner") + assert trained.status_code == 200 + + alias_name = f"champ-{uuid4().hex[:8]}" + promoted = await client.post( + f"/model-selection/{selection_id}/promote", + json={"alias_name": alias_name, "approved_by": "integration", "description": "Q3"}, + ) + assert promoted.status_code == 200 + pbody = promoted.json() + assert pbody["alias_name"] == alias_name + assert pbody["run_status"] == "success" + run_id = pbody["run_id"] + + # The alias resolves to the SUCCESS run via the registry endpoint. + alias_resp = await client.get(f"/registry/aliases/{alias_name}") + assert alias_resp.status_code == 200 + assert alias_resp.json()["run_status"] == "success" + + # The registry run carries the REAL feature_frame_version (2). + run_detail = await client.get(f"/registry/runs/{run_id}") + assert run_detail.status_code == 200 + assert run_detail.json()["runtime_info"]["feature_frame_version"] == 2 + + # The selection persisted the promotion audit. + selection = await client.get(f"/model-selection/{selection_id}") + sbody = selection.json() + rows = await db_session.execute( + text( + "SELECT champion_run_id, promoted_alias, promotion_decision " + "FROM model_selection_run WHERE selection_id = :sid" + ), + {"sid": selection_id}, + ) + champion_run_id, promoted_alias, promotion_decision = rows.one() + assert champion_run_id == run_id + assert promoted_alias == alias_name + assert promotion_decision["approved_by"] == "integration" + assert promotion_decision["decision"] == "promoted" + assert sbody["status"] in {"completed", "partial"} + + +async def test_promote_before_train_returns_422( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + run = await client.post("/model-selection/run", json=_run_body(ready_pair)) + selection_id = run.json()["selection_id"] + promoted = await client.post( + f"/model-selection/{selection_id}/promote", + json={"alias_name": f"champ-{uuid4().hex[:8]}", "approved_by": "x"}, + ) + assert promoted.status_code == 422 + assert promoted.json()["status"] == 422 + + +async def test_seven_decision_columns_exist(db_session: AsyncSession) -> None: + rows = await db_session.execute( + text( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name = 'model_selection_run'" + ) + ) + cols = {row[0] for row in rows} + for col in ( + "trained_model_type", + "is_override", + "override_reason", + "champion_run_id", + "promoted_alias", + "promotion_decision", + "feature_frame_version", + ): + assert col in cols, f"missing Slice C column: {col}" diff --git a/app/features/model_selection/tests/test_schemas.py b/app/features/model_selection/tests/test_schemas.py index 87fb093d..0ad530de 100644 --- a/app/features/model_selection/tests/test_schemas.py +++ b/app/features/model_selection/tests/test_schemas.py @@ -161,3 +161,94 @@ def test_submit_run_response_carries_monitor_and_cancel_urls() -> None: assert submit.progress is not None assert submit.progress.pending == 1 assert submit.candidate_progress[0].model_type == "naive" + + +# ============================================================================= +# Slice C — decision + promotion schemas +# ============================================================================= + +from app.features.model_selection.schemas import ( # noqa: E402 + ForecastDecisionParams, + ForecastSummary, + PredictWinnerResponse, + PromoteRequest, + TrainSelectedRequest, + TrainWinnerResponse, +) + + +def test_train_selected_request_accepts_model_type() -> None: + req = TrainSelectedRequest.model_validate( + {"model_type": "seasonal_naive", "override_reason": "domain"} + ) + assert req.model_type == "seasonal_naive" + assert req.override_reason == "domain" + + +def test_train_selected_request_rejects_unknown_model_type() -> None: + with pytest.raises(ValidationError): + TrainSelectedRequest.model_validate({"model_type": "not_a_model"}) + + +def test_forecast_decision_params_defaults() -> None: + params = ForecastDecisionParams() + assert params.lead_time_days == 7 + assert params.service_level == 0.95 + + +@pytest.mark.parametrize("service_level", [0.49, 1.0, 1.5]) +def test_forecast_decision_params_rejects_out_of_bound_service_level(service_level: float) -> None: + with pytest.raises(ValidationError): + ForecastDecisionParams.model_validate({"service_level": service_level}) + + +def test_forecast_decision_params_validate_python_path() -> None: + """Exercise the validate_python path (matches FastAPI's body coercion).""" + params = ForecastDecisionParams.model_validate({"lead_time_days": 14, "service_level": 0.99}) + assert params.lead_time_days == 14 + + +@pytest.mark.parametrize("alias", ["Bad Alias", "UPPER", "-leading", "has space"]) +def test_promote_request_rejects_bad_alias_name(alias: str) -> None: + with pytest.raises(ValidationError): + PromoteRequest.model_validate({"alias_name": alias, "approved_by": "gabor"}) + + +def test_promote_request_accepts_valid_alias_and_defaults() -> None: + req = PromoteRequest.model_validate({"alias_name": "champion-store5", "approved_by": "gabor"}) + assert req.alias_name == "champion-store5" + assert req.acknowledge_non_recommended is False + assert req.description is None + + +def test_promote_request_requires_approved_by() -> None: + with pytest.raises(ValidationError): + PromoteRequest.model_validate({"alias_name": "champion-x"}) + + +def test_train_winner_response_back_compat_defaults() -> None: + """train-winner callers that omit the Slice C fields still validate.""" + resp = TrainWinnerResponse.model_validate( + {"selection_id": "s", "model_type": "naive", "model_path": "p"} + ) + assert resp.is_override is False + assert resp.override_warning is None + + +def test_forecast_summary_peak_low_optional() -> None: + """Legacy ForecastSummary (no peak/low) still validates.""" + summary = ForecastSummary.model_validate( + {"points": [], "total_demand": 0.0, "average_demand": 0.0, "horizon": 14} + ) + assert summary.peak_date is None + assert summary.peak_demand is None + + +def test_predict_winner_response_decision_optional() -> None: + resp = PredictWinnerResponse.model_validate( + { + "selection_id": "s", + "forecast": {"points": [], "total_demand": 0.0, "average_demand": 0.0, "horizon": 14}, + } + ) + assert resp.decision is None diff --git a/app/features/model_selection/tests/test_service.py b/app/features/model_selection/tests/test_service.py index 2fd3002e..c60080b1 100644 --- a/app/features/model_selection/tests/test_service.py +++ b/app/features/model_selection/tests/test_service.py @@ -365,3 +365,409 @@ async def test_cancel_run_409_when_settle_races_cancel( monkeypatch.setattr(_runner, "cancel_selection", lambda _sid: False) with pytest.raises(ConflictError): await ModelSelectionService().cancel_run(db, "sel_race") + + +# ============================================================================= +# Slice C — train-selected (override) / predict-decision / promote +# ============================================================================= + +from app.core.exceptions import UnprocessableEntityError # noqa: E402 +from app.features.forecasting.schemas import ForecastPoint # noqa: E402 +from app.features.model_selection.schemas import PromoteRequest # noqa: E402 + + +def _ranking_dict( + *, + winner_type: str = "naive", + extra_included: list[dict[str, Any]] | None = None, +) -> dict[str, Any]: + """A persisted ``ranking_result`` JSONB with a winner + ranked entries.""" + winner_metrics = {"wape": 10.0, "smape": 8.0, "mae": 4.0, "bias": 0.5} + entries: list[dict[str, Any]] = [ + { + "rank": 1, + "model_type": winner_type, + "params": {}, + "included": True, + "exclusion_reason": None, + "metrics": winner_metrics, + } + ] + if extra_included: + entries.extend(extra_included) + return { + "winner": { + "rank": 1, + "model_type": winner_type, + "params": {}, + "included": True, + "exclusion_reason": None, + "metrics": winner_metrics, + }, + "entries": entries, + "confidence": "high", + "reasons": [], + } + + +def _decision_row( + *, + candidate_models: list[dict[str, Any]] | None = None, + ranking_result: dict[str, Any] | None = None, + feature_frame_version: int = 1, + final_model_path: str | None = None, + trained_model_type: str | None = None, + is_override: bool = False, + winner_metrics: dict[str, Any] | None = None, +) -> ModelSelectionRun: + """Build an in-memory ModelSelectionRun for decision-layer unit tests.""" + return ModelSelectionRun( + selection_id="sel_decision", + status=ModelSelectionStatus.COMPLETED.value, + store_id=3, + product_id=8, + start_date=date(2026, 1, 1), + end_date=date(2026, 5, 31), + forecast_horizon=14, + ranking_metric="wape", + candidate_models=candidate_models or [{"model_type": "naive", "params": {}}], + policy_snapshot={}, + ranking_result=ranking_result, + feature_frame_version=feature_frame_version, + final_model_path=final_model_path, + trained_model_type=trained_model_type, + is_override=is_override, + winner_metrics=winner_metrics, + ) + + +def _row_db(row: ModelSelectionRun) -> AsyncMock: + db = AsyncMock() + db.scalar = AsyncMock(return_value=row) + db.flush = AsyncMock() + return db + + +def _patch_train(monkeypatch: pytest.MonkeyPatch) -> AsyncMock: + train_mock = AsyncMock( + return_value=SimpleNamespace(model_path="artifacts/models/model_sel.joblib") + ) + monkeypatch.setattr( + "app.features.forecasting.service.ForecastingService", + lambda: SimpleNamespace(train_model=train_mock), + ) + return train_mock + + +async def test_train_selected_trains_chosen_candidate(monkeypatch: pytest.MonkeyPatch) -> None: + train_mock = _patch_train(monkeypatch) + row = _decision_row( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + ], + ranking_result=_ranking_dict(winner_type="naive"), + ) + resp = await ModelSelectionService().train_selected(_row_db(row), "sel_decision", "naive", None) + assert resp.model_type == "naive" + assert resp.is_override is False + assert resp.override_warning is None + assert row.trained_model_type == "naive" + assert row.is_override is False + train_mock.assert_awaited_once() + + +async def test_train_selected_rejects_non_candidate_model_type_400( + monkeypatch: pytest.MonkeyPatch, +) -> None: + train_mock = _patch_train(monkeypatch) + row = _decision_row( + candidate_models=[{"model_type": "naive", "params": {}}], + ranking_result=_ranking_dict(winner_type="naive"), + ) + with pytest.raises(BadRequestError): + await ModelSelectionService().train_selected(_row_db(row), "sel_decision", "lightgbm", None) + train_mock.assert_not_awaited() + + +async def test_train_selected_sets_is_override_and_warning_for_non_winner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_train(monkeypatch) + row = _decision_row( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + ], + ranking_result=_ranking_dict( + winner_type="naive", + extra_included=[ + { + "rank": 2, + "model_type": "seasonal_naive", + "params": {"season_length": 7}, + "included": True, + "exclusion_reason": None, + "metrics": {"wape": 15.0, "smape": 9.0, "mae": 5.0, "bias": 0.2}, + } + ], + ), + ) + resp = await ModelSelectionService().train_selected( + _row_db(row), "sel_decision", "seasonal_naive", "domain seasonality" + ) + assert resp.is_override is True + assert resp.override_warning is not None + assert "seasonal_naive" in resp.override_warning + assert "naive" in resp.override_warning + assert row.override_reason == "domain seasonality" + + +async def test_train_selected_failed_candidate_still_trainable( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A candidate that FAILED its backtest (no ranked metrics) stays trainable.""" + _patch_train(monkeypatch) + row = _decision_row( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "moving_average", "params": {}}, + ], + # moving_average failed its backtest → not in ranking entries. + ranking_result=_ranking_dict(winner_type="naive"), + ) + resp = await ModelSelectionService().train_selected( + _row_db(row), "sel_decision", "moving_average", None + ) + assert resp.is_override is True + assert resp.override_warning is not None + assert "not successfully evaluated" in resp.override_warning + + +async def test_train_selected_threads_feature_frame_version_into_train_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + train_mock = _patch_train(monkeypatch) + row = _decision_row( + candidate_models=[{"model_type": "prophet_like", "params": {}}], + ranking_result=_ranking_dict(winner_type="prophet_like"), + feature_frame_version=2, + ) + await ModelSelectionService().train_selected(_row_db(row), "sel_decision", "prophet_like", None) + assert train_mock.await_args is not None + assert train_mock.await_args.kwargs["feature_frame_version"] == 2 + + +async def test_train_winner_now_persists_trained_model_type_not_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Regression — train-winner persists trained_model_type, is_override=False.""" + train_mock = _patch_train(monkeypatch) + row = _decision_row( + candidate_models=[{"model_type": "naive", "params": {}}], + ranking_result=_ranking_dict(winner_type="naive"), + feature_frame_version=2, + ) + resp = await ModelSelectionService().train_winner(_row_db(row), "sel_decision") + assert resp.model_type == "naive" + assert resp.is_override is False + assert resp.override_warning is None + assert row.trained_model_type == "naive" + assert row.is_override is False + assert train_mock.await_args is not None + assert train_mock.await_args.kwargs["feature_frame_version"] == 2 + + +def _predict_points() -> list[ForecastPoint]: + base = date(2026, 6, 1) + values = [10.0, 25.0, 8.0, 12.0] + return [ + ForecastPoint(date=base.fromordinal(base.toordinal() + i), forecast=v) + for i, v in enumerate(values) + ] + + +async def test_predict_attaches_decision_and_peak_low(monkeypatch: pytest.MonkeyPatch) -> None: + predict_mock = AsyncMock(return_value=SimpleNamespace(forecasts=_predict_points())) + monkeypatch.setattr( + "app.features.forecasting.service.ForecastingService", + lambda: SimpleNamespace(predict=predict_mock), + ) + row = _decision_row( + final_model_path="artifacts/models/model_sel.joblib", + trained_model_type="naive", + winner_metrics={"wape": 10.0, "bias": 0.5}, + ) + forecast, decision = await ModelSelectionService().predict_winner( + _row_db(row), "sel_decision", 7, 0.95 + ) + assert decision is not None + assert decision.lead_time_days == 7 + assert decision.method == "heuristic" + assert forecast.peak_demand == 25.0 + assert forecast.low_demand == 8.0 + assert forecast.peak_date == date(2026, 6, 2) + + +async def test_predict_winner_untrained_raises_400() -> None: + row = _decision_row(final_model_path=None) + with pytest.raises(BadRequestError): + await ModelSelectionService().predict_winner(_row_db(row), "sel_decision", 7, 0.95) + + +def _patch_registry(monkeypatch: pytest.MonkeyPatch) -> dict[str, AsyncMock]: + from app.features.registry.schemas import RunStatus + + run_resp = SimpleNamespace(run_id="run_abc123def456") + alias_resp = SimpleNamespace(alias_name="champion-test", run_status=RunStatus.SUCCESS) + create_run = AsyncMock(return_value=run_resp) + update_run = AsyncMock(return_value=run_resp) + create_alias = AsyncMock(return_value=alias_resp) + monkeypatch.setattr( + "app.features.registry.service.RegistryService", + lambda: SimpleNamespace( + create_run=create_run, update_run=update_run, create_alias=create_alias + ), + ) + monkeypatch.setattr( + ModelSelectionService, + "_register_artifact", + staticmethod(lambda final_model_path, run_id: ("champion-selector/x.joblib", "h", 100)), + ) + return {"create_run": create_run, "update_run": update_run, "create_alias": create_alias} + + +async def test_promote_orchestrates_create_run_success_and_alias( + monkeypatch: pytest.MonkeyPatch, +) -> None: + mocks = _patch_registry(monkeypatch) + row = _decision_row( + final_model_path="artifacts/models/model_sel.joblib", + trained_model_type="naive", + is_override=False, + winner_metrics={"wape": 10.0}, + feature_frame_version=1, + ) + req = PromoteRequest(alias_name="champion-test", approved_by="gabor") + resp = await ModelSelectionService().promote(_row_db(row), "sel_decision", req) + assert resp.run_id == "run_abc123def456" + assert resp.run_status == "success" + assert resp.alias_name == "champion-test" + mocks["create_run"].assert_awaited_once() + mocks["create_alias"].assert_awaited_once() + # two update_run calls: RUNNING then SUCCESS + assert mocks["update_run"].await_count == 2 + + +async def test_promote_persists_promotion_decision_audit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_registry(monkeypatch) + row = _decision_row( + final_model_path="artifacts/models/model_sel.joblib", + trained_model_type="naive", + winner_metrics={"wape": 10.0}, + ) + req = PromoteRequest(alias_name="champion-test", approved_by="gabor", description="Q3") + await ModelSelectionService().promote(_row_db(row), "sel_decision", req) + assert row.champion_run_id == "run_abc123def456" + assert row.promoted_alias == "champion-test" + assert row.promotion_decision is not None + assert row.promotion_decision["approved_by"] == "gabor" + assert row.promotion_decision["decision"] == "promoted" + assert row.promotion_decision["reason"] == "Q3" + + +async def test_promote_carries_real_feature_frame_version_v2( + monkeypatch: pytest.MonkeyPatch, +) -> None: + mocks = _patch_registry(monkeypatch) + row = _decision_row( + final_model_path="artifacts/models/model_sel.joblib", + trained_model_type="prophet_like", + winner_metrics={"wape": 10.0}, + feature_frame_version=2, + ) + req = PromoteRequest(alias_name="champion-v2", approved_by="gabor") + await ModelSelectionService().promote(_row_db(row), "sel_decision", req) + assert mocks["create_run"].await_args is not None + run_create = mocks["create_run"].await_args.args[1] + assert run_create.runtime_info_extras["feature_frame_version"] == 2 + + +async def test_promote_defaults_feature_frame_version_1_for_legacy_run( + monkeypatch: pytest.MonkeyPatch, +) -> None: + mocks = _patch_registry(monkeypatch) + row = _decision_row( + final_model_path="artifacts/models/model_sel.joblib", + trained_model_type="naive", + winner_metrics={"wape": 10.0}, + feature_frame_version=1, # legacy / server_default + ) + req = PromoteRequest(alias_name="champion-legacy", approved_by="gabor") + await ModelSelectionService().promote(_row_db(row), "sel_decision", req) + assert mocks["create_run"].await_args is not None + run_create = mocks["create_run"].await_args.args[1] + assert run_create.runtime_info_extras["feature_frame_version"] == 1 + + +async def test_promote_requires_trained_model_422(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_registry(monkeypatch) + row = _decision_row(final_model_path=None, trained_model_type=None) + req = PromoteRequest(alias_name="champion-test", approved_by="gabor") + with pytest.raises(UnprocessableEntityError): + await ModelSelectionService().promote(_row_db(row), "sel_decision", req) + + +async def test_promote_non_recommended_requires_ack_422(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_registry(monkeypatch) + row = _decision_row( + final_model_path="artifacts/models/model_sel.joblib", + trained_model_type="seasonal_naive", + is_override=True, + winner_metrics={"wape": 10.0}, + ) + req = PromoteRequest( + alias_name="champion-test", approved_by="gabor", acknowledge_non_recommended=False + ) + with pytest.raises(UnprocessableEntityError): + await ModelSelectionService().promote(_row_db(row), "sel_decision", req) + + +def _capturing_run_db() -> AsyncMock: + db = AsyncMock() + rows: list[Any] = [] + db.add = MagicMock(side_effect=lambda o: rows.append(o)) + + async def _flush() -> None: + for obj in rows: + if isinstance(obj, ModelSelectionRun) and obj.created_at is None: + obj.created_at = datetime.now(UTC) + + db.flush = AsyncMock(side_effect=_flush) + db.refresh = AsyncMock() + db._rows = rows + return db + + +async def test_run_creation_persists_request_feature_frame_version_sync( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_availability(monkeypatch, "ready") + _patch_backtester(monkeypatch, side_effect=[make_backtest_response(wape=10.0)]) + db = _capturing_run_db() + await ModelSelectionService().run_selection(db, _request(feature_frame_version=2)) + runs = [r for r in db._rows if isinstance(r, ModelSelectionRun)] + assert runs[0].feature_frame_version == 2 + + +async def test_run_creation_persists_request_feature_frame_version_async( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_availability(monkeypatch, "ready") + monkeypatch.setattr(ModelSelectionService, "_run_in_background", AsyncMock()) + db = _submit_mock_db() + await ModelSelectionService().submit_run(db, _request(feature_frame_version=2)) + runs = [o for o in db._added if isinstance(o, ModelSelectionRun)] + assert runs[0].feature_frame_version == 2 diff --git a/docs/user-guide/champion-selector-guide.md b/docs/user-guide/champion-selector-guide.md new file mode 100644 index 00000000..1bef1afb --- /dev/null +++ b/docs/user-guide/champion-selector-guide.md @@ -0,0 +1,126 @@ +# Champion Selector Guide + +The **Champion Selector** turns "which forecasting model is best for this +store + product?" into a guided, end-to-end workflow: compare candidate models +on a leakage-safe backtest, read a recommendation, **decide** (accept it or +override), train the chosen model, generate and interpret its forecast, and — +only with explicit approval — **promote** it to a registry alias. + +It lives at **`/visualize/champion`** in the dashboard and is served by the +`/model-selection/*` REST API (Swagger at **/docs** is the authoritative +contract). + +> **The golden rule of promotion:** the app *recommends* a champion, but a +> human *approves* it, and that decision is **recorded**. Promotion is never +> automatic. + +## The journey at a glance + +``` +Select → Run comparison → Results → Decide / override → Train → Forecast → Interpret → Promote +``` + +### 1 · Select & check availability + +Pick a store, a product, a time period, a forecast horizon (1–90 days), and the +candidate models to compare. The page checks **data availability** for the pair +and recommends a cross-validation split. A pair with too little history is +flagged *unusable* and the comparison is refused (`400`). + +### 2 · Run the comparison + +`POST /model-selection/runs` submits an asynchronous run (returns `202` with a +monitor URL); the page polls it to a terminal state. Each candidate is +backtested with time-series cross-validation; results are ranked deterministically. + +**Ranking** is by **WAPE** by default, with a fixed tie-break chain: +*WAPE, then sMAPE, then |bias|, then MAE.* The winner, runners-up, and any +failed candidates are all shown. + +### 3 · Decide — accept or override + +The recommended winner is pre-selected. You can: + +- **Accept the recommendation** → trains the ranked winner. +- **Override to another candidate** → you must confirm an explicit warning (the + recommended model and the WAPE gap are named) and may record a reason. The + override is flagged (`is_override=true`) and audited. A candidate that *failed* + its backtest is still override-trainable (training is independent of backtesting). + +`POST /model-selection/{id}/train-selected` trains the chosen model; +`train-winner` trains the recommendation. + +### 4 · Forecast + +`POST /model-selection/{id}/predict` generates the horizon forecast for the +trained model. The response carries the **peak** and **low** demand days plus a +**decision** block (see below). + +> **Capability limit.** A *feature-aware* model (`regression`, `prophet_like`, +> `lightgbm`, `xgboost`, `random_forest`) cannot auto-forecast here — it needs a +> future feature frame. The page shows a blocked state and routes you to the +> **What-If Planner** (Scenarios) instead of faking a forecast. + +### 5 · Interpret + +The **business interpretation** panel restates *why the model won*, the +**expected demand over the lead time**, and the **bias risk**: + +> Positive bias means the model under-forecasts (risk of stockouts); negative +> bias means it over-forecasts (risk of overstock). + +The **safety stock** panel shows a clearly-labeled, deterministic heuristic: + +``` +safety_stock = z(service_level) · σ_daily · √(lead_time_days) +expected_demand = average_demand · lead_time_days +reorder_point = expected_demand + safety_stock +``` + +`σ_daily` is the standard deviation of the daily forecast; `z` comes from a fixed +service-level table (90% → 1.2816, 95% → 1.6449, 97.5% → 1.9600, 99% → 2.3263), +snapping to the nearest level in between. Adjust the lead time / service level and +recompute. + +> **This is a heuristic** (demand variability only, constant lead time) — not a +> full inventory-optimisation model, and it **never** influences the model +> ranking. + +### 6 · Promote (approval-gated, audited) + +`POST /model-selection/{id}/promote` registers the trained model as a registry +`model_run` (transitioned to **SUCCESS** with a verified artifact) and points a +**registry alias** at it. It records a `promotion_decision` audit +(`approved_by`, the alias, the run id, the decision, the reason, and whether it +was an override). + +Promotion requires: + +- a valid **alias name** (`^[a-z0-9][a-z0-9\-_]*$`) — a bad name is rejected `422`; +- an **approver** (`approved_by`) — promotion is never anonymous; +- for a **non-recommended** (override) model, an explicit + `acknowledge_non_recommended=true` — else `422`; +- a **trained** model first — promoting before training is `422`. + +Re-promoting the same alias name repoints the existing alias (registry upsert +semantics). **Compare and promote stay separate** — promote performs no +ranking or comparison; it only registers and aliases the already-trained champion. + +## Endpoint reference + +| Method | Path | Purpose | +|--------|------|---------| +| POST | `/model-selection/runs` | Submit an async comparison (202) | +| GET | `/model-selection/{id}` | Poll progress / fetch terminal results | +| POST | `/model-selection/{id}/train-winner` | Train the ranked winner | +| POST | `/model-selection/{id}/train-selected` | Train a chosen candidate (override) | +| POST | `/model-selection/{id}/predict` | Forecast + inventory decision | +| POST | `/model-selection/{id}/promote` | Promote to a registry alias (audited) | + +## Notes & caveats + +- Backtest accuracy reflects historical fit, not a guarantee of future + performance; metrics measure correlation with past demand, not causation. +- The decision layer is **deterministic** — no LLM is involved. +- V2 (richer feature frame) runs promote as V2: the registry run records the + real `feature_frame_version`. diff --git a/docs/user-guide/feature-reference.md b/docs/user-guide/feature-reference.md index 71e2d21b..521f795b 100644 --- a/docs/user-guide/feature-reference.md +++ b/docs/user-guide/feature-reference.md @@ -133,6 +133,23 @@ Tracks every trained model so runs are reproducible and comparable. A run moves through `pending → running → success` (or `failed`), and an alias is a human-friendly pointer (like `production` or `champion`) to a chosen successful run. +## Champion Selector + +An end-to-end "which model is best, and now what?" workflow over one (store, +product) pair: compare candidate models, accept or override the recommendation, +train, forecast, interpret, and promote to a registry alias. + +- `POST /model-selection/runs` — submit an async candidate comparison (`202`). +- `GET /model-selection/{id}` — poll progress / fetch the ranked results + winner. +- `POST /model-selection/{id}/train-winner` — train the ranked winner. +- `POST /model-selection/{id}/train-selected` — train a chosen candidate (override + audit). +- `POST /model-selection/{id}/predict` — forecast the trained model + a labeled + safety-stock decision heuristic (feature-aware models are blocked → use Scenarios). +- `POST /model-selection/{id}/promote` — approval-gated, audited promotion to a + registry alias (requires an approver; a non-recommended model needs an explicit ack). + +See the full walkthrough in **[champion-selector-guide.md](./champion-selector-guide.md)**. + ## Jobs Long-running work — training, prediction, backtesting — submitted as jobs. diff --git a/frontend/src/components/champion-selector/decision/business-interpretation-panel.test.tsx b/frontend/src/components/champion-selector/decision/business-interpretation-panel.test.tsx new file mode 100644 index 00000000..27fd9c50 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/business-interpretation-panel.test.tsx @@ -0,0 +1,47 @@ +import { afterEach, describe, expect, it } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { BusinessInterpretationPanel } from './business-interpretation-panel' +import type { ForecastDecision } from '@/types/api' + +afterEach(cleanup) + +const decision: ForecastDecision = { + method: 'heuristic', + lead_time_days: 7, + service_level: 0.95, + z_value: 1.6449, + sigma_daily_demand: 1.4, + expected_demand_over_lead_time: 70, + safety_stock: 6.1, + reorder_point: 76.1, + bias_risk_text: 'Positive bias means the model under-forecasts (risk of stockouts).', + caveats: ['Safety stock is a deterministic heuristic.'], +} + +const businessSummary = { + headline: 'Recommended model: naive (high confidence).', + winner: { model_type: 'naive', summary: 'WAPE 10.0%' }, + comparison: { lead_text: '15% lower WAPE than the runner-up' }, + data_notes: ['Observed 120 of 120 calendar days.'], +} + +describe('BusinessInterpretationPanel', () => { + it('renders the headline, expected demand, and bias risk', () => { + render( + , + ) + const text = screen.getByTestId('business-interpretation-panel').textContent ?? '' + expect(text).toContain('Recommended model: naive') + expect(screen.getByTestId('business-expected-demand').textContent).toContain('70.0') + expect(screen.getByTestId('business-bias-risk').textContent).toContain( + 'under-forecasts', + ) + }) + + it('falls back to the bias explanation when no decision is present', () => { + render() + expect( + screen.getByText(/Positive bias means the model under-forecasts/), + ).toBeTruthy() + }) +}) diff --git a/frontend/src/components/champion-selector/decision/business-interpretation-panel.tsx b/frontend/src/components/champion-selector/decision/business-interpretation-panel.tsx new file mode 100644 index 00000000..5cb9d1d0 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/business-interpretation-panel.tsx @@ -0,0 +1,84 @@ +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { BIAS_EXPLANATION } from '@/components/champion-selector/copy' +import type { ForecastDecision } from '@/types/api' + +interface BusinessInterpretationPanelProps { + /** The deterministic backend `business_summary` (read-only). */ + businessSummary: Record | null + /** The decision heuristic (carries bias-risk text + expected demand). */ + decision: ForecastDecision | null +} + +function str(value: unknown): string | null { + return typeof value === 'string' ? value : null +} + +/** + * Slice C — business interpretation. Renders the SAME `business_summary` the + * backend computed (read-only — Slice B's winner card owns the headline) and + * ADDS the decision-layer fields (expected demand + bias risk + caveats). + */ +export function BusinessInterpretationPanel({ + businessSummary, + decision, +}: BusinessInterpretationPanelProps) { + const headline = str(businessSummary?.['headline']) + const winner = businessSummary?.['winner'] as Record | null | undefined + const winnerSummary = str(winner?.['summary']) + const comparison = businessSummary?.['comparison'] as Record | null | undefined + const leadText = str(comparison?.['lead_text']) + const dataNotes = Array.isArray(businessSummary?.['data_notes']) + ? (businessSummary?.['data_notes'] as unknown[]).filter((x): x is string => typeof x === 'string') + : [] + + return ( + + + Business interpretation + {headline && {headline}} + + + {winnerSummary && ( +

+ Why it won: + {winnerSummary} + {leadText ? ` — ${leadText}.` : '.'} +

+ )} + + {decision && ( +
+

+ Expected demand over lead time: + {decision.expected_demand_over_lead_time.toFixed(1)} units ( + {decision.lead_time_days} days). +

+

+ {decision.bias_risk_text} +

+
+ )} + + {!decision && ( +

{BIAS_EXPLANATION}

+ )} + + {dataNotes.length > 0 && ( +
    + {dataNotes.map((note, i) => ( +
  • {note}
  • + ))} +
+ )} + + {decision?.caveats?.length ? ( +
    + {decision.caveats.map((caveat, i) => ( +
  • {caveat}
  • + ))} +
+ ) : null} +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/constants.ts b/frontend/src/components/champion-selector/decision/constants.ts new file mode 100644 index 00000000..1005d481 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/constants.ts @@ -0,0 +1,19 @@ +/** + * Non-component constants for the Slice C decision panels. Kept in a `.ts` + * module so `react-refresh/only-export-components` never trips on them. + */ + +/** Service levels the safety-stock z-table supports exactly (others snap nearest). */ +export const SERVICE_LEVEL_OPTIONS = [0.9, 0.95, 0.975, 0.99] as const + +/** Capability-limited blocked state for a feature-aware winner (LOCKED #5). */ +export const FEATURE_AWARE_BLOCKED_COPY = + 'Forecast not available for feature-aware models — use the What-If Planner ' + + '(Scenarios) to forecast through explicit assumptions.' + +/** The promotion-is-audited note shown in the promote dialog. */ +export const PROMOTE_AUDIT_NOTE = + 'Promotion is explicit and recorded — the approver and decision are saved as ' + + 'an audit record on this run. It is never automatic.' + +export const SAFETY_STOCK_HEADER = 'Safety stock (heuristic)' diff --git a/frontend/src/components/champion-selector/decision/daily-forecast-table.test.tsx b/frontend/src/components/champion-selector/decision/daily-forecast-table.test.tsx new file mode 100644 index 00000000..c4d2a73b --- /dev/null +++ b/frontend/src/components/champion-selector/decision/daily-forecast-table.test.tsx @@ -0,0 +1,26 @@ +import { afterEach, describe, expect, it } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { DailyForecastTable } from './daily-forecast-table' +import type { ModelSelectionForecastSummary } from '@/types/api' + +afterEach(cleanup) + +const forecast: ModelSelectionForecastSummary = { + points: [ + { date: '2026-06-01', forecast: 10.5, lower_bound: 8, upper_bound: 12 }, + { date: '2026-06-02', forecast: 14.2, lower_bound: null, upper_bound: null }, + ], + total_demand: 24.7, + average_demand: 12.35, + horizon: 2, +} + +describe('DailyForecastTable', () => { + it('renders one row per forecast point with the forecast value', () => { + render() + const text = screen.getByTestId('daily-forecast-table').textContent ?? '' + expect(text).toContain('2026-06-01') + expect(text).toContain('10.50') + expect(text).toContain('14.20') + }) +}) diff --git a/frontend/src/components/champion-selector/decision/daily-forecast-table.tsx b/frontend/src/components/champion-selector/decision/daily-forecast-table.tsx new file mode 100644 index 00000000..96404c8f --- /dev/null +++ b/frontend/src/components/champion-selector/decision/daily-forecast-table.tsx @@ -0,0 +1,57 @@ +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table' +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' +import type { ModelSelectionForecastSummary } from '@/types/api' + +interface DailyForecastTableProps { + forecast: ModelSelectionForecastSummary +} + +function cell(value: unknown): string { + return typeof value === 'number' && Number.isFinite(value) ? value.toFixed(2) : '—' +} + +/** Slice C — the per-day forecast table (date, forecast, lower, upper). */ +export function DailyForecastTable({ forecast }: DailyForecastTableProps) { + return ( + + + Daily forecast + + + + + + Date + Forecast + Lower + Upper + + + + {forecast.points.map((point, index) => ( + + {String(point['date'] ?? '—')} + + {cell(point['forecast'])} + + + {cell(point['lower_bound'])} + + + {cell(point['upper_bound'])} + + + ))} + +
+
+
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/decision-section.tsx b/frontend/src/components/champion-selector/decision/decision-section.tsx new file mode 100644 index 00000000..c1dbd95c --- /dev/null +++ b/frontend/src/components/champion-selector/decision/decision-section.tsx @@ -0,0 +1,180 @@ +import { useMemo, useState } from 'react' +import { Button } from '@/components/ui/button' +import { Card, CardContent } from '@/components/ui/card' +import { getErrorMessage } from '@/lib/api' +import { + usePredictWinner, + usePromoteChampion, + useTrainSelected, + useTrainWinner, +} from '@/hooks/use-model-selection' +import type { + ModelCatalogResponse, + ModelSelectionRunResponse, + PredictWinnerResponse, + TrainWinnerResponse, +} from '@/types/api' +import { WinnerDecisionPanel } from './winner-decision-panel' +import { TrainForecastActions } from './train-forecast-actions' +import { ForecastSummaryCard } from './forecast-summary-card' +import { ForecastChart } from './forecast-chart' +import { DailyForecastTable } from './daily-forecast-table' +import { BusinessInterpretationPanel } from './business-interpretation-panel' +import { SafetyStockPanel } from './safety-stock-panel' +import { PromoteChampionDialog } from './promote-champion-dialog' + +interface DecisionSectionProps { + selectionId: string + run: ModelSelectionRunResponse + catalog: ModelCatalogResponse | undefined +} + +/** + * Slice C — the decision section rendered below a terminal winning run. + * + * Owns the train / predict / promote mutations (so the page keeps its hooks + * unconditional). Mount it with `key={selectionId}` so a fresh run resets the + * train/forecast/promote state. + */ +export function DecisionSection({ selectionId, run, catalog }: DecisionSectionProps) { + const winnerModelType = run.winner?.model_type ?? null + + const [trainResult, setTrainResult] = useState(null) + const [predictResult, setPredictResult] = useState(null) + const [leadTimeDays, setLeadTimeDays] = useState(7) + const [serviceLevel, setServiceLevel] = useState(0.95) + const [promoteOpen, setPromoteOpen] = useState(false) + const [promoteError, setPromoteError] = useState(null) + const [promotedAlias, setPromotedAlias] = useState(null) + + const trainWinner = useTrainWinner(selectionId) + const trainSelected = useTrainSelected(selectionId) + const predict = usePredictWinner(selectionId) + const promote = usePromoteChampion(selectionId) + + // Every candidate the run offered (winner + runners-up + failed), de-duped. + const candidateModelTypes = useMemo(() => { + const seen = new Set() + for (const entry of run.ranking) seen.add(entry.model_type) + if (winnerModelType) seen.add(winnerModelType) + return [...seen] + }, [run.ranking, winnerModelType]) + + // Capability of the model that WILL be (or was) trained — drives the blocked + // forecast state for a feature-aware winner (LOCKED #5). + const activeModelType = trainResult?.model_type ?? winnerModelType + const supportsAutoPredict = useMemo(() => { + const info = catalog?.models.find((m) => m.model_type === activeModelType) + return info?.supports_auto_predict ?? true + }, [catalog, activeModelType]) + + const trained = trainResult !== null || run.final_model !== null + + if (winnerModelType === null) return null + + function handleTrain(modelType: string, overrideReason: string | null) { + setPredictResult(null) + setPromotedAlias(null) + const onSuccess = (data: TrainWinnerResponse) => setTrainResult(data) + if (modelType === winnerModelType) { + trainWinner.mutate(undefined, { onSuccess }) + } else { + trainSelected.mutate({ model_type: modelType, override_reason: overrideReason }, { onSuccess }) + } + } + + function handleForecast() { + predict.mutate( + { lead_time_days: leadTimeDays, service_level: serviceLevel }, + { onSuccess: (data) => setPredictResult(data) }, + ) + } + + function handlePromote(body: Parameters[0]) { + setPromoteError(null) + promote.mutate(body, { + onSuccess: (data) => setPromotedAlias(data.alias_name), + onError: (err) => setPromoteError(getErrorMessage(err)), + }) + } + + const forecast = predictResult?.forecast ?? null + const decision = predictResult?.decision ?? null + const isOverride = trainResult?.is_override ?? false + + return ( +
+ + + + + + {predict.isError && ( +

+ {getErrorMessage(predict.error)} +

+ )} +
+
+ + {forecast && ( + <> + + + + + + + )} + + {trained && ( + + +

+ Promote the trained champion to a registry alias (approval-gated). +

+ +
+
+ )} + + +
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/forecast-chart.test.tsx b/frontend/src/components/champion-selector/decision/forecast-chart.test.tsx new file mode 100644 index 00000000..c28c6726 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/forecast-chart.test.tsx @@ -0,0 +1,33 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { ForecastChart } from './forecast-chart' +import type { ModelSelectionForecastSummary } from '@/types/api' + +// Recharts' ResponsiveContainer needs ResizeObserver in jsdom. +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) +}) + +afterEach(cleanup) + +const forecast: ModelSelectionForecastSummary = { + points: [ + { date: '2026-06-01', forecast: 10, lower_bound: 8, upper_bound: 12 }, + { date: '2026-06-02', forecast: 14, lower_bound: 11, upper_bound: 17 }, + ], + total_demand: 24, + average_demand: 12, + horizon: 2, +} + +describe('ForecastChart', () => { + it('renders the chart container from forecast points', () => { + render() + expect(screen.getByTestId('forecast-chart')).toBeTruthy() + }) +}) diff --git a/frontend/src/components/champion-selector/decision/forecast-chart.tsx b/frontend/src/components/champion-selector/decision/forecast-chart.tsx new file mode 100644 index 00000000..fccd54b8 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/forecast-chart.tsx @@ -0,0 +1,43 @@ +import { TimeSeriesChart } from '@/components/charts/time-series-chart' +import type { ModelSelectionForecastSummary } from '@/types/api' + +interface ForecastChartProps { + forecast: ModelSelectionForecastSummary +} + +interface ChartRow { + date: string + forecast: number + lower?: number + upper?: number +} + +/** Slice C — the horizon forecast curve (optional interval band). */ +export function ForecastChart({ forecast }: ForecastChartProps) { + const rows: ChartRow[] = forecast.points.map((point) => { + const lower = point['lower_bound'] + const upper = point['upper_bound'] + return { + date: String(point['date'] ?? ''), + forecast: Number(point['forecast'] ?? 0), + lower: typeof lower === 'number' ? lower : undefined, + upper: typeof upper === 'number' ? upper : undefined, + } + }) + const hasInterval = rows.some((row) => row.lower !== undefined && row.upper !== undefined) + + return ( +
+ +
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/forecast-summary-card.test.tsx b/frontend/src/components/champion-selector/decision/forecast-summary-card.test.tsx new file mode 100644 index 00000000..d9e57324 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/forecast-summary-card.test.tsx @@ -0,0 +1,37 @@ +import { afterEach, describe, expect, it } from 'vitest' +import { cleanup, render, screen } from '@testing-library/react' +import { ForecastSummaryCard } from './forecast-summary-card' +import type { ModelSelectionForecastSummary } from '@/types/api' + +afterEach(cleanup) + +const forecast: ModelSelectionForecastSummary = { + points: [], + total_demand: 140, + average_demand: 10, + horizon: 14, + peak_date: '2026-06-02', + peak_demand: 25, + low_date: '2026-06-03', + low_demand: 5, +} + +describe('ForecastSummaryCard', () => { + it('renders total, peak, and low tiles', () => { + render() + const text = screen.getByTestId('forecast-summary-card').textContent ?? '' + expect(text).toContain('140.0') + expect(text).toContain('25.0') + expect(text).toContain('2026-06-02') + expect(text).toContain('14d') + }) + + it('renders an em-dash for null peak/low', () => { + render( + , + ) + expect(screen.getByTestId('forecast-summary-card').textContent).toContain('—') + }) +}) diff --git a/frontend/src/components/champion-selector/decision/forecast-summary-card.tsx b/frontend/src/components/champion-selector/decision/forecast-summary-card.tsx new file mode 100644 index 00000000..dddab510 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/forecast-summary-card.tsx @@ -0,0 +1,48 @@ +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' +import type { ModelSelectionForecastSummary } from '@/types/api' + +interface ForecastSummaryCardProps { + forecast: ModelSelectionForecastSummary +} + +function Tile({ label, value, sub }: { label: string; value: string; sub?: string }) { + return ( +
+

{label}

+

{value}

+ {sub &&

{sub}

} +
+ ) +} + +function num(value: number | null | undefined): string { + return typeof value === 'number' && Number.isFinite(value) ? value.toFixed(1) : '—' +} + +/** Slice C — total / average / peak / low / horizon KPI tiles (null-safe). */ +export function ForecastSummaryCard({ forecast }: ForecastSummaryCardProps) { + return ( + + + Forecast summary + + +
+ + + + + +
+
+
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/promote-champion-dialog.test.tsx b/frontend/src/components/champion-selector/decision/promote-champion-dialog.test.tsx new file mode 100644 index 00000000..c7dbe718 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/promote-champion-dialog.test.tsx @@ -0,0 +1,72 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { PromoteChampionDialog } from './promote-champion-dialog' + +afterEach(cleanup) + +function renderDialog(overrides: Partial[0]> = {}) { + const props = { + open: true, + onOpenChange: vi.fn(), + isOverride: false, + isPromoting: false, + promoteError: null, + promotedAlias: null, + onConfirm: vi.fn(), + ...overrides, + } + render() + return props +} + +describe('PromoteChampionDialog', () => { + it('keeps confirm disabled until alias + approver are valid', () => { + renderDialog() + expect(screen.getByTestId('promote-confirm-action').hasAttribute('disabled')).toBe(true) + fireEvent.change(screen.getByTestId('promote-alias-input'), { + target: { value: 'champion-x' }, + }) + fireEvent.change(screen.getByTestId('promote-approver-input'), { + target: { value: 'gabor' }, + }) + expect(screen.getByTestId('promote-confirm-action').hasAttribute('disabled')).toBe(false) + }) + + it('flags an invalid alias name', () => { + renderDialog() + fireEvent.change(screen.getByTestId('promote-alias-input'), { + target: { value: 'Bad Alias' }, + }) + expect(screen.getByTestId('promote-alias-error')).toBeTruthy() + }) + + it('requires the ack checkbox for a non-recommended (override) model', () => { + renderDialog({ isOverride: true }) + fireEvent.change(screen.getByTestId('promote-alias-input'), { + target: { value: 'champion-x' }, + }) + fireEvent.change(screen.getByTestId('promote-approver-input'), { + target: { value: 'gabor' }, + }) + // still disabled until the ack is checked + expect(screen.getByTestId('promote-confirm-action').hasAttribute('disabled')).toBe(true) + fireEvent.click(screen.getByTestId('promote-ack-checkbox')) + expect(screen.getByTestId('promote-confirm-action').hasAttribute('disabled')).toBe(false) + }) + + it('calls onConfirm with the promote body', () => { + const props = renderDialog() + fireEvent.change(screen.getByTestId('promote-alias-input'), { + target: { value: 'champion-x' }, + }) + fireEvent.change(screen.getByTestId('promote-approver-input'), { + target: { value: 'gabor' }, + }) + fireEvent.click(screen.getByTestId('promote-confirm-action')) + expect(props.onConfirm).toHaveBeenCalledWith({ + alias_name: 'champion-x', + approved_by: 'gabor', + acknowledge_non_recommended: false, + }) + }) +}) diff --git a/frontend/src/components/champion-selector/decision/promote-champion-dialog.tsx b/frontend/src/components/champion-selector/decision/promote-champion-dialog.tsx new file mode 100644 index 00000000..79e7e486 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/promote-champion-dialog.tsx @@ -0,0 +1,163 @@ +import { useState } from 'react' +import { CheckCircle2, ShieldAlert } from 'lucide-react' +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from '@/components/ui/alert-dialog' +import { Checkbox } from '@/components/ui/checkbox' +import { Input } from '@/components/ui/input' +import type { PromoteRequest } from '@/types/api' +import { PROMOTE_AUDIT_NOTE } from './constants' + +const ALIAS_RE = /^[a-z0-9][a-z0-9\-_]*$/ + +interface PromoteChampionDialogProps { + open: boolean + onOpenChange: (open: boolean) => void + /** True when a non-recommended model was trained (requires explicit ack). */ + isOverride: boolean + defaultAliasName?: string + isPromoting: boolean + /** Error message from the last promote attempt (null on success/idle). */ + promoteError: string | null + /** The alias name on a successful promotion (null until promoted). */ + promotedAlias: string | null + onConfirm: (body: PromoteRequest) => void +} + +/** + * Slice C — the approval-gated promote dialog. Requires an approver and a valid + * alias name; a non-recommended model additionally requires the ack checkbox. + * Mirrors `forecast-intelligence/promote-confirmation-dialog.tsx`, but calls the + * model_selection `promote` flow (compare and promote stay separate). + */ +export function PromoteChampionDialog({ + open, + onOpenChange, + isOverride, + defaultAliasName = '', + isPromoting, + promoteError, + promotedAlias, + onConfirm, +}: PromoteChampionDialogProps) { + const [aliasName, setAliasName] = useState(defaultAliasName) + const [approvedBy, setApprovedBy] = useState('') + const [ack, setAck] = useState(false) + + const aliasValid = ALIAS_RE.test(aliasName.trim()) + const canConfirm = + aliasValid && + approvedBy.trim().length > 0 && + (!isOverride || ack) && + !isPromoting + + function handleConfirm() { + if (!canConfirm) return + onConfirm({ + alias_name: aliasName.trim(), + approved_by: approvedBy.trim(), + acknowledge_non_recommended: isOverride ? ack : false, + }) + } + + return ( + { + if (!next) setAck(false) + onOpenChange(next) + }} + > + + + Promote champion to a registry alias + {PROMOTE_AUDIT_NOTE} + + +
+
+ + setAliasName(event.target.value)} + placeholder="e.g. champion-store5-prod8" + autoComplete="off" + data-testid="promote-alias-input" + /> + {aliasName.length > 0 && !aliasValid && ( +

+ Lowercase letters, digits, hyphens and underscores only (must start + with a letter or digit). +

+ )} +
+ +
+ + setApprovedBy(event.target.value)} + placeholder="your name" + autoComplete="off" + data-testid="promote-approver-input" + /> +
+ + {isOverride && ( + + )} + + {promoteError && ( +

+ {promoteError} +

+ )} + + {promotedAlias && ( +
+ + Promoted to alias {promotedAlias}. +
+ )} +
+ + + Close + + {isPromoting ? 'Promoting…' : 'Promote'} + + +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/safety-stock-panel.test.tsx b/frontend/src/components/champion-selector/decision/safety-stock-panel.test.tsx new file mode 100644 index 00000000..a5a27e9f --- /dev/null +++ b/frontend/src/components/champion-selector/decision/safety-stock-panel.test.tsx @@ -0,0 +1,54 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { SafetyStockPanel } from './safety-stock-panel' +import type { ForecastDecision } from '@/types/api' + +afterEach(cleanup) + +const decision: ForecastDecision = { + method: 'heuristic', + lead_time_days: 7, + service_level: 0.95, + z_value: 1.6449, + sigma_daily_demand: 1.41, + expected_demand_over_lead_time: 70, + safety_stock: 6.13, + reorder_point: 76.13, + bias_risk_text: 'bias text', + caveats: ['heuristic'], +} + +function renderPanel(overrides: Partial[0]> = {}) { + const props = { + decision, + leadTimeDays: 7, + serviceLevel: 0.95, + isRecomputing: false, + onLeadTimeChange: vi.fn(), + onServiceLevelChange: vi.fn(), + onRecompute: vi.fn(), + ...overrides, + } + render() + return props +} + +describe('SafetyStockPanel', () => { + it('renders the labeled heuristic header and stats', () => { + renderPanel() + const text = screen.getByTestId('safety-stock-panel').textContent ?? '' + expect(text).toContain('Safety stock (heuristic)') + expect(text).toContain('1.6449') + expect(text).toContain('6.1') + }) + + it('fires onLeadTimeChange and onRecompute', () => { + const props = renderPanel() + fireEvent.change(screen.getByTestId('safety-stock-lead-time'), { + target: { value: '14' }, + }) + expect(props.onLeadTimeChange).toHaveBeenCalledWith(14) + fireEvent.click(screen.getByTestId('safety-stock-recompute')) + expect(props.onRecompute).toHaveBeenCalledOnce() + }) +}) diff --git a/frontend/src/components/champion-selector/decision/safety-stock-panel.tsx b/frontend/src/components/champion-selector/decision/safety-stock-panel.tsx new file mode 100644 index 00000000..11f1b43e --- /dev/null +++ b/frontend/src/components/champion-selector/decision/safety-stock-panel.tsx @@ -0,0 +1,115 @@ +import { Loader2, RefreshCw } from 'lucide-react' +import { Button } from '@/components/ui/button' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Input } from '@/components/ui/input' +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select' +import type { ForecastDecision } from '@/types/api' +import { SAFETY_STOCK_HEADER, SERVICE_LEVEL_OPTIONS } from './constants' + +interface SafetyStockPanelProps { + decision: ForecastDecision | null + leadTimeDays: number + serviceLevel: number + isRecomputing: boolean + onLeadTimeChange: (value: number) => void + onServiceLevelChange: (value: number) => void + onRecompute: () => void +} + +function Stat({ label, value }: { label: string; value: string }) { + return ( +
+

{label}

+

{value}

+
+ ) +} + +/** + * Slice C — the CLEARLY-LABELED safety-stock heuristic. Lead time + service + * level inputs recompute the forecast decision. Never influences ranking. + */ +export function SafetyStockPanel({ + decision, + leadTimeDays, + serviceLevel, + isRecomputing, + onLeadTimeChange, + onServiceLevelChange, + onRecompute, +}: SafetyStockPanelProps) { + return ( + + + {SAFETY_STOCK_HEADER} + + A deterministic reorder heuristic (demand variability only, constant lead + time). Adjust the inputs and recompute. + + + +
+
+ Lead time (days) + onLeadTimeChange(Number(event.target.value) || 0)} + className="w-32" + data-testid="safety-stock-lead-time" + /> +
+
+ Service level + +
+ +
+ + {decision && ( +
+ + + + +
+ )} +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/decision/train-forecast-actions.test.tsx b/frontend/src/components/champion-selector/decision/train-forecast-actions.test.tsx new file mode 100644 index 00000000..3ff27f99 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/train-forecast-actions.test.tsx @@ -0,0 +1,48 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { TrainForecastActions } from './train-forecast-actions' + +afterEach(cleanup) + +describe('TrainForecastActions', () => { + it('shows the blocked state for a feature-aware winner', () => { + render( + {}} + />, + ) + expect(screen.getByTestId('forecast-blocked-state').textContent).toContain( + 'What-If Planner', + ) + expect(screen.queryByTestId('forecast-button')).toBeNull() + }) + + it('fires onForecast when the trained forecast button is clicked', () => { + const onForecast = vi.fn() + render( + , + ) + fireEvent.click(screen.getByTestId('forecast-button')) + expect(onForecast).toHaveBeenCalledOnce() + }) + + it('disables the forecast button until a model is trained', () => { + render( + {}} + />, + ) + expect(screen.getByTestId('forecast-button').hasAttribute('disabled')).toBe(true) + }) +}) diff --git a/frontend/src/components/champion-selector/decision/train-forecast-actions.tsx b/frontend/src/components/champion-selector/decision/train-forecast-actions.tsx new file mode 100644 index 00000000..0ba12605 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/train-forecast-actions.tsx @@ -0,0 +1,54 @@ +import { Loader2, LineChart, Ban } from 'lucide-react' +import { Button } from '@/components/ui/button' +import { FEATURE_AWARE_BLOCKED_COPY } from './constants' + +interface TrainForecastActionsProps { + /** From the Slice A catalog (`supports_auto_predict = not feature_aware`). */ + supportsAutoPredict: boolean + /** True once a model bundle has been trained for the selection. */ + trained: boolean + isPredicting: boolean + onForecast: () => void +} + +/** + * Slice C — the Forecast action + the capability-limited blocked state. + * + * A feature-aware winner cannot auto-predict (LOCKED #5): instead of faking a + * forecast we surface the limitation and route the user to the What-If Planner. + */ +export function TrainForecastActions({ + supportsAutoPredict, + trained, + isPredicting, + onForecast, +}: TrainForecastActionsProps) { + if (!supportsAutoPredict) { + return ( +
+ + {FEATURE_AWARE_BLOCKED_COPY} +
+ ) + } + + return ( + + ) +} diff --git a/frontend/src/components/champion-selector/decision/winner-decision-panel.test.tsx b/frontend/src/components/champion-selector/decision/winner-decision-panel.test.tsx new file mode 100644 index 00000000..5aca0b38 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/winner-decision-panel.test.tsx @@ -0,0 +1,48 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { WinnerDecisionPanel } from './winner-decision-panel' +import type { TrainWinnerResponse } from '@/types/api' + +afterEach(cleanup) + +describe('WinnerDecisionPanel', () => { + it('trains the recommended winner without a confirm dialog', () => { + const onTrain = vi.fn() + render( + , + ) + expect(screen.getByTestId('decision-train-button').textContent).toContain( + 'Train recommended', + ) + fireEvent.click(screen.getByTestId('decision-train-button')) + expect(onTrain).toHaveBeenCalledWith('naive', null) + }) + + it('renders the override warning from a train result', () => { + const trainResult: TrainWinnerResponse = { + selection_id: 's', + model_type: 'seasonal_naive', + model_path: 'p', + is_override: true, + override_warning: 'You trained seasonal_naive instead of naive.', + } + render( + {}} + />, + ) + expect(screen.getByTestId('decision-override-warning').textContent).toContain( + 'seasonal_naive', + ) + }) +}) diff --git a/frontend/src/components/champion-selector/decision/winner-decision-panel.tsx b/frontend/src/components/champion-selector/decision/winner-decision-panel.tsx new file mode 100644 index 00000000..5b0d58d5 --- /dev/null +++ b/frontend/src/components/champion-selector/decision/winner-decision-panel.tsx @@ -0,0 +1,158 @@ +import { useState } from 'react' +import { Loader2, Trophy, TriangleAlert } from 'lucide-react' +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from '@/components/ui/alert-dialog' +import { Button } from '@/components/ui/button' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Input } from '@/components/ui/input' +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select' +import type { TrainWinnerResponse } from '@/types/api' + +interface WinnerDecisionPanelProps { + winnerModelType: string + /** Every candidate offered in the run (winner + runners-up + failed). */ + candidateModelTypes: string[] + isTraining: boolean + trainResult: TrainWinnerResponse | null + /** Train the chosen model — the page routes winner vs. override. */ + onTrain: (modelType: string, overrideReason: string | null) => void +} + +/** + * Slice C — accept the recommended winner OR override to another candidate. + * + * Picking a non-winner opens a confirm dialog (explicit warning + an optional + * reason) before training. Presentational — the page owns the train mutations. + */ +export function WinnerDecisionPanel({ + winnerModelType, + candidateModelTypes, + isTraining, + trainResult, + onTrain, +}: WinnerDecisionPanelProps) { + const [selected, setSelected] = useState(winnerModelType) + const [overrideReason, setOverrideReason] = useState('') + const [confirmOpen, setConfirmOpen] = useState(false) + + const isOverride = selected !== winnerModelType + + function handleTrainClick() { + if (isOverride) { + setConfirmOpen(true) + return + } + onTrain(selected, null) + } + + function handleConfirmOverride() { + onTrain(selected, overrideReason.trim() || null) + setConfirmOpen(false) + } + + return ( + + + 5 · Decide & train + + Train the recommended champion, or override to another candidate. The + recommended model is {winnerModelType}. + + + +
+
+ Model to train + +
+ +
+ + {trainResult?.override_warning && ( +
+ + {trainResult.override_warning} +
+ )} + + {trainResult && !trainResult.override_warning && ( +

+ Trained {trainResult.model_type}. +

+ )} +
+ + + + + Train a non-recommended model? + + You picked {selected} instead of the + recommended {winnerModelType}. This is an + override and is recorded on the run. + + +
+ Reason (optional) + setOverrideReason(event.target.value)} + placeholder="e.g. domain seasonality outweighs the WAPE lead" + data-testid="override-reason-input" + /> +
+ + Cancel + + Train override + + +
+
+
+ ) +} diff --git a/frontend/src/hooks/use-model-selection.test.ts b/frontend/src/hooks/use-model-selection.test.ts index 4209a072..5074351b 100644 --- a/frontend/src/hooks/use-model-selection.test.ts +++ b/frontend/src/hooks/use-model-selection.test.ts @@ -13,8 +13,12 @@ import { useCancelSelectionRun, useModelCatalog, usePairAvailability, + usePredictWinner, + usePromoteChampion, useSelectionRun, useSubmitSelectionRun, + useTrainSelected, + useTrainWinner, } from './use-model-selection' import type { ModelCatalogResponse, @@ -271,3 +275,90 @@ describe('useCancelSelectionRun', () => { expect((call[1] as RequestInit).method).toBe('DELETE') }) }) + +// --------------------------------------------------------------- Slice C hooks + +function jsonResponse(body: unknown) { + return new Response(JSON.stringify(body), { + status: 200, + headers: { 'content-type': 'application/json' }, + }) +} + +describe('useTrainWinner', () => { + it('POSTs /train-winner (no body) and invalidates the run query', async () => { + const fetchMock = vi.fn().mockResolvedValue( + jsonResponse({ selection_id: 'sel_c', model_type: 'naive', model_path: 'p', is_override: false, override_warning: null }), + ) + vi.stubGlobal('fetch', fetchMock) + const { result } = renderHook(() => useTrainWinner('sel_c'), { + wrapper: makeWrapper(makeClient()), + }) + await act(async () => { + result.current.mutate() + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const call = fetchMock.mock.calls[0]! + expect(String(call[0])).toContain('/model-selection/sel_c/train-winner') + expect((call[1] as RequestInit).method).toBe('POST') + }) +}) + +describe('useTrainSelected', () => { + it('POSTs /train-selected with the override body', async () => { + const fetchMock = vi.fn().mockResolvedValue( + jsonResponse({ selection_id: 'sel_c', model_type: 'seasonal_naive', model_path: 'p', is_override: true, override_warning: 'w' }), + ) + vi.stubGlobal('fetch', fetchMock) + const { result } = renderHook(() => useTrainSelected('sel_c'), { + wrapper: makeWrapper(makeClient()), + }) + await act(async () => { + result.current.mutate({ model_type: 'seasonal_naive', override_reason: 'domain' }) + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const call = fetchMock.mock.calls[0]! + expect(String(call[0])).toContain('/model-selection/sel_c/train-selected') + expect((call[1] as RequestInit).method).toBe('POST') + expect(String((call[1] as RequestInit).body)).toContain('seasonal_naive') + }) +}) + +describe('usePredictWinner', () => { + it('POSTs /predict with the decision params body', async () => { + const fetchMock = vi.fn().mockResolvedValue( + jsonResponse({ selection_id: 'sel_c', forecast: { points: [], total_demand: 0, average_demand: 0, horizon: 14 }, decision: null }), + ) + vi.stubGlobal('fetch', fetchMock) + const { result } = renderHook(() => usePredictWinner('sel_c'), { + wrapper: makeWrapper(makeClient()), + }) + await act(async () => { + result.current.mutate({ lead_time_days: 7, service_level: 0.95 }) + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const call = fetchMock.mock.calls[0]! + expect(String(call[0])).toContain('/model-selection/sel_c/predict') + expect((call[1] as RequestInit).method).toBe('POST') + }) +}) + +describe('usePromoteChampion', () => { + it('POSTs /promote with the promote body', async () => { + const fetchMock = vi.fn().mockResolvedValue( + jsonResponse({ selection_id: 'sel_c', alias_name: 'champion-x', run_id: 'r', run_status: 'success', model_type: 'naive', is_override: false, promoted_at: '2026-06-01T00:00:00Z' }), + ) + vi.stubGlobal('fetch', fetchMock) + const { result } = renderHook(() => usePromoteChampion('sel_c'), { + wrapper: makeWrapper(makeClient()), + }) + await act(async () => { + result.current.mutate({ alias_name: 'champion-x', approved_by: 'gabor' }) + }) + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const call = fetchMock.mock.calls[0]! + expect(String(call[0])).toContain('/model-selection/sel_c/promote') + expect((call[1] as RequestInit).method).toBe('POST') + expect(String((call[1] as RequestInit).body)).toContain('champion-x') + }) +}) diff --git a/frontend/src/hooks/use-model-selection.ts b/frontend/src/hooks/use-model-selection.ts index 2cf7286f..bc861a1b 100644 --- a/frontend/src/hooks/use-model-selection.ts +++ b/frontend/src/hooks/use-model-selection.ts @@ -2,18 +2,24 @@ import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query' import { api } from '@/lib/api' import { isTerminalSelectionStatus } from '@/components/champion-selector/results/constants' import type { + ForecastDecisionParams, ModelCatalogResponse, ModelSelectionRunRequest, ModelSelectionRunResponse, PairAvailability, + PredictWinnerResponse, + PromoteRequest, + PromoteResponse, SubmitRunResponse, + TrainSelectedRequest, + TrainWinnerResponse, } from '@/types/api' /** * Model-selection query hooks (Champion Selector). * * Slice A: catalog + availability GETs. Slice B: async submit / poll / cancel. - * Train/predict/promotion are owned by Slice C. + * Slice C: train (winner / override) / predict (decision) / promote. */ /** @@ -118,3 +124,67 @@ export function useCancelSelectionRun() { }, }) } + +/** + * Invalidate the polled run query so a terminal run re-fetches the new + * `final_model_path` / `forecast` / promotion after a Slice C mutation. + */ +function invalidateRun( + queryClient: ReturnType, + selectionId: string, +) { + void queryClient.invalidateQueries({ + queryKey: ['model-selection', 'run', selectionId], + }) +} + +/** Train the ranked winner (`POST /{id}/train-winner`, no body). */ +export function useTrainWinner(selectionId: string) { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: () => + api(`/model-selection/${selectionId}/train-winner`, { + method: 'POST', + }), + onSuccess: () => invalidateRun(queryClient, selectionId), + }) +} + +/** Train a user-chosen candidate (`POST /{id}/train-selected`, override). */ +export function useTrainSelected(selectionId: string) { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (body: TrainSelectedRequest) => + api(`/model-selection/${selectionId}/train-selected`, { + method: 'POST', + body, + }), + onSuccess: () => invalidateRun(queryClient, selectionId), + }) +} + +/** Forecast with the trained model + decision (`POST /{id}/predict`). */ +export function usePredictWinner(selectionId: string) { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (body: ForecastDecisionParams) => + api(`/model-selection/${selectionId}/predict`, { + method: 'POST', + body, + }), + onSuccess: () => invalidateRun(queryClient, selectionId), + }) +} + +/** Promote the trained champion to a registry alias (`POST /{id}/promote`). */ +export function usePromoteChampion(selectionId: string) { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (body: PromoteRequest) => + api(`/model-selection/${selectionId}/promote`, { + method: 'POST', + body, + }), + onSuccess: () => invalidateRun(queryClient, selectionId), + }) +} diff --git a/frontend/src/pages/visualize/champion.tsx b/frontend/src/pages/visualize/champion.tsx index 6157148e..30b624c8 100644 --- a/frontend/src/pages/visualize/champion.tsx +++ b/frontend/src/pages/visualize/champion.tsx @@ -25,6 +25,7 @@ import { WinnerCard } from '@/components/champion-selector/results/winner-card' import { ComparisonCharts } from '@/components/champion-selector/results/comparison-charts' import { ModelDetailDrawer } from '@/components/champion-selector/results/model-detail-drawer' import { CancelRunDialog } from '@/components/champion-selector/results/cancel-run-dialog' +import { DecisionSection } from '@/components/champion-selector/decision/decision-section' import { isTerminalSelectionStatus } from '@/components/champion-selector/results/constants' import { Button } from '@/components/ui/button' import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' @@ -369,6 +370,16 @@ export default function ChampionSelectorPage() { open={drawerOpen} onOpenChange={setDrawerOpen} /> + {/* Slice C — decide → train → forecast → interpret → promote. Keyed by + selectionId so a fresh run resets the decision state. */} + {selectionId && run.winner && ( + + )} )} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 63ebe3f4..88a204e1 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -1330,6 +1330,11 @@ export interface ModelSelectionForecastSummary { total_demand: number average_demand: number horizon: number + // Slice C — additive peak/low day (null on legacy snapshots). + peak_date?: string | null + peak_demand?: number | null + low_date?: string | null + low_demand?: number | null } // Slice B — live async progress on a selection run. @@ -1384,3 +1389,66 @@ export interface SubmitRunResponse extends ModelSelectionRunResponse { monitor_url: string cancel_url: string } + +// Slice C — forecast decision, override, and promotion contracts. + +/** `POST /model-selection/{id}/train-selected` body (override). */ +export interface TrainSelectedRequest { + model_type: string + override_reason?: string | null +} + +/** Optional `POST /model-selection/{id}/predict` body. */ +export interface ForecastDecisionParams { + lead_time_days: number + service_level: number +} + +/** Deterministic, labeled inventory-decision heuristic (never feeds ranking). */ +export interface ForecastDecision { + method: 'heuristic' + lead_time_days: number + service_level: number + z_value: number + sigma_daily_demand: number + expected_demand_over_lead_time: number + safety_stock: number + reorder_point: number + bias_risk_text: string + caveats: string[] +} + +/** `POST /model-selection/{id}/train-winner` and `/train-selected` response. */ +export interface TrainWinnerResponse { + selection_id: string + model_type: string + model_path: string + is_override: boolean + override_warning: string | null +} + +/** `POST /model-selection/{id}/predict` response (forecast + decision). */ +export interface PredictWinnerResponse { + selection_id: string + forecast: ModelSelectionForecastSummary + decision: ForecastDecision | null +} + +/** `POST /model-selection/{id}/promote` body (approval-gated). */ +export interface PromoteRequest { + alias_name: string + approved_by: string + acknowledge_non_recommended?: boolean + description?: string | null +} + +/** `POST /model-selection/{id}/promote` response. */ +export interface PromoteResponse { + selection_id: string + alias_name: string + run_id: string + run_status: string + model_type: string + is_override: boolean + promoted_at: string // ISO datetime +}