diff --git a/PRPs/ai_docs/forecast-champion-selector-backend-research.md b/PRPs/ai_docs/forecast-champion-selector-backend-research.md new file mode 100644 index 00000000..2d37603b --- /dev/null +++ b/PRPs/ai_docs/forecast-champion-selector-backend-research.md @@ -0,0 +1,222 @@ +# Forecast Champion Selector Backend Research + +Date: 2026-06-01 + +This note captures external-library and runtime facts used by +`PRPs/forecast-champion-selector-backend.md`. It is intentionally narrow: +only claims that affect backend implementation are recorded here. + +## Official Documentation References + +- FastAPI APIRouter / multi-file apps: + https://fastapi.tiangolo.com/tutorial/bigger-applications/ + - Reason: the new `app/features/model_selection/routes.py` must follow the + existing `APIRouter(prefix=..., tags=...)` slice pattern and be wired in + `app/main.py`. + +- Pydantic v2 strict mode and field-level overrides: + https://pydantic.dev/docs/validation/latest/concepts/strict_mode/ + - Reason: ForecastLabAI request schemas use `ConfigDict(strict=True)`, but + JSON request bodies still need date/datetime/UUID/Decimal fields to accept + JSON-native strings via `Field(strict=False, ...)`. + +- SQLAlchemy 2.0 PostgreSQL JSONB: + https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#json-types + - Reason: `model_selection_run` should store immutable request/response + snapshots (`candidate_models`, `ranking_result`, `winner_metrics`, + `forecast_result`, `business_summary`) as PostgreSQL JSONB. + +- Alembic `Operations.create_index`: + https://alembic.sqlalchemy.org/en/latest/ops.html#alembic.operations.Operations.create_index + - Reason: the migration should use explicit named indexes; any partial or + JSONB index must use Alembic operations rather than raw SQL. 
+ +- scikit-learn `TimeSeriesSplit`: + https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html + - Reason: the selector's recommended split defaults mirror the project's + own `SplitConfig` semantics and should not assume unsupported parameters. + +## Runtime Verification Commands + +Run from repository root on 2026-06-01. + +```bash +uv run python -c "import inspect; from sqlalchemy import select, table, column; import sqlalchemy; stmt=select(column('id')).select_from(table('t')).with_for_update(skip_locked=True); print('sqlalchemy', sqlalchemy.__version__); print('with_for_update_has_skip_locked', 'skip_locked' in str(inspect.signature(select(column('id')).with_for_update))); print(stmt)" +``` + +Observed: + +```text +sqlalchemy 2.0.46 +with_for_update_has_skip_locked True +SELECT id +FROM t FOR UPDATE +``` + +Note: generic SQL compilation does not render PostgreSQL-specific +`SKIP LOCKED`; use PostgreSQL dialect compilation in tests when asserting +that string. 
+ +```bash +uv run python -c "from datetime import date; import pydantic; from pydantic import BaseModel, ConfigDict, Field; M=type('M',(BaseModel,),{'__annotations__':{'d':date},'model_config':ConfigDict(strict=True),'d':Field(strict=False)}); print('pydantic', pydantic.__version__); print(M.model_validate({'d':'2026-06-01'}).d.isoformat())" +``` + +Observed: + +```text +pydantic 2.12.5 +2026-06-01 +``` + +```bash +uv run python -c "import inspect, sklearn; from sklearn.model_selection import TimeSeriesSplit; print('sklearn', sklearn.__version__); print(inspect.signature(TimeSeriesSplit)); t=TimeSeriesSplit(n_splits=3, test_size=2, gap=1); print(t)" +``` + +Observed: + +```text +sklearn 1.8.0 +(n_splits=5, *, max_train_size=None, test_size=None, gap=0) +TimeSeriesSplit(gap=1, max_train_size=None, n_splits=3, test_size=2) +``` + +```bash +uv run python -c "import inspect, fastapi; from fastapi import APIRouter, BackgroundTasks; print('fastapi', fastapi.__version__); print('APIRouter_prefix_param', 'prefix' in inspect.signature(APIRouter).parameters); print('BackgroundTasks_add_task', inspect.signature(BackgroundTasks.add_task))" +``` + +Observed: + +```text +fastapi 0.128.0 +APIRouter_prefix_param True +BackgroundTasks_add_task (self, func: ..., *args: P.args, **kwargs: P.kwargs) -> None +``` + +```bash +uv run python -c "import inspect, alembic; from alembic.operations import Operations; print('alembic', alembic.__version__); print(inspect.signature(Operations.create_index))" +``` + +Observed: + +```text +alembic 1.18.4 +(self, index_name, table_name, columns, *, schema=None, unique=False, if_not_exists=None, **kw) -> None +``` + +## Implementation Consequences + +- Use `Literal[...]` request fields for JSON string enums under + `ConfigDict(strict=True)`; convert to ORM enums at service boundaries. +- Use `Field(strict=False, ...)` on every request-body date/datetime/UUID/ + Decimal field, or `app/core/tests/test_strict_mode_policy.py` can fail. 
+- Persist selector decisions in JSONB snapshots because registry metrics are + free-form JSONB and metric key names differ across layers. +- Do not assume a batch backtest item contains fold-level chart data. Batch + metrics are intentionally pinned to `{wape, smape, mae, bias, sample_size}`. +- If an implementation compiles SQL for PostgreSQL-specific clauses, compile + with the PostgreSQL dialect rather than relying on generic SQL strings. + +## Verified Internal Service Contracts (read from source 2026-06-01) + +These are the in-repo signatures the selector orchestrates. They were the prior +draft's #1 residual risk; recorded here so they survive and can be re-verified on +refactor. Re-verify with `grep -n "async def run_backtest\|async def train_model\|async def predict" app/features/backtesting/service.py app/features/forecasting/service.py`. + +### BacktestingService — `app/features/backtesting/service.py:213` + +```python +# __init__(self) -> None — takes NO db; instantiate as BacktestingService() +async def run_backtest( + self, db: AsyncSession, store_id: int, product_id: int, + start_date: date, end_date: date, config: BacktestConfig, +) -> BacktestResponse +``` + +`BacktestConfig` (`backtesting/schemas.py:81`, `frozen=True, extra="forbid"`): +`split_config: SplitConfig`, `model_config_main: Annotated[ModelConfig, Field(discriminator="model_type")]`, +`include_baselines: bool = True`, `store_fold_details: bool = True`. + +`SplitConfig` (`:24`): `strategy: Literal["expanding","sliding"]="expanding"`, +`n_splits: int=5 (ge=2,le=20)`, `min_train_size: int=30 (ge=7)`, `gap: int=0 (ge=0,le=30)`, +`horizon: int=14 (ge=1,le=90)`; validator `horizon > gap`. + +### BacktestResponse — `backtesting/schemas.py:257` + +`main_model_results: ModelBacktestResult`, `baseline_results: list[ModelBacktestResult] | None`, +plus `backtest_id, store_id, product_id, config_hash, split_config, comparison_summary, +duration_ms, leakage_check_passed`. 
+ +`ModelBacktestResult` (`:180`): `model_type, config_hash, fold_results: list[FoldResult], +aggregated_metrics: dict[str,float], metric_std: dict[str,float], +bucketed_aggregated_metrics: dict|None, feature_aware: bool, exogenous_policy`. + +`FoldResult` (`:147`): `fold_index, split, dates: list[date], actuals: list[float], +predictions: list[float], metrics: dict[str,float], horizon_bucket_metrics`. + +**Metric keys (CORRECTION to the prior draft):** `aggregated_metrics` has **five** keys — +`{"mae", "rmse", "smape", "wape", "bias"}` (`backtesting/metrics.py:347`; PRP-36 added `rmse`). +`metric_std` keys are suffixed `"{name}_stability"` (a coefficient of variation, not a raw std). +`sample_size` is NOT in `aggregated_metrics` — derive from fold actuals length or n_folds. +Fold chart data path: `main_model_results.fold_results[i].{dates,actuals,predictions}` — populated +only when `config.store_fold_details=True`. + +### ForecastingService — `app/features/forecasting/service.py` + +```python +# __init__(self) -> None +async def train_model( # :247 + self, db: AsyncSession, store_id: int, product_id: int, + train_start_date: date, train_end_date: date, config: ModelConfig, + *, feature_frame_version: int = 1, feature_groups: list[str] | None = None, +) -> TrainResponse # TrainResponse.model_path is the artifact path + +async def predict( # :402 — NO db arg + self, store_id: int, product_id: int, horizon: int, model_path: str, +) -> PredictResponse # PredictResponse.forecasts: list[ForecastPoint] +``` + +`predict()` rejects feature-aware models (`service.py:491`) — feature-aware winners must route +through `/scenarios/simulate`; catch and warn rather than 500. + +### ModelConfig union — `forecasting/schemas.py:417` + +Plain PEP 604 union (`NaiveModelConfig | SeasonalNaiveModelConfig | … | ProphetLikeModelConfig`), +discriminated by each member's `model_type` Literal. 
Members are **flat** (`SeasonalNaiveModelConfig` +has `model_type` + `season_length`, NOT a nested `params`). No module-level `TypeAdapter`/helper. +Build from `{"model_type": ..., "params": {...}}` by FLATTENING: + +```python +from pydantic import TypeAdapter +from app.features.forecasting.schemas import ModelConfig +TypeAdapter(ModelConfig).validate_python({"model_type": c.model_type, **c.params}) +``` + +Members are `frozen=True, extra="forbid"` → bad params raise `ValidationError` (treat as a failed +candidate). `model_type` values: `naive, seasonal_naive, moving_average, weighted_moving_average, +seasonal_average, trend_regression_baseline, random_forest, lightgbm, xgboost, regression, +prophet_like` (`lightgbm`/`xgboost` are opt-in extras → may `ImportError`). + +### Data-platform ORM column names — `data_platform/models.py` + +`Store` (`:40`): `id` (int PK), `code` (business key — NOT `store_code`). `Product` (`:68`): `id`, +`sku`, `launch_date: date|None`. `SalesDaily` (`:172`): `date` (FK calendar.date), `store_id`, +`product_id`, `quantity` (Integer, CHECK ≥0), `unit_price`, `total_amount`; grain unique +`(date, store_id, product_id)`. `Promotion` (`:274`): `product_id` NOT NULL, `store_id` NULLABLE +(NULL = chain-wide, applies to all stores), date RANGE `[start_date, end_date]`, +`kind ∈ {pct_off,bogo,bundle,markdown}`. + +### Cross-cutting patterns + +- Exceptions (`app/core/exceptions.py`): `BadRequestError`(400), `NotFoundError`(404), + `DatabaseError`(500), `ConflictError`(409), `UnprocessableEntityError`(422); each + `(message=..., details=None)`. Routes map `ValueError→BadRequestError`, + `SQLAlchemyError→DatabaseError` (mirror `backtesting/routes.py:60`). +- `validate_date_range` is slice-local in `analytics/routes.py:36` (raises `BadRequestError`, + inverted-range + 730-day-max) — NOT importable cross-slice; reimplement locally. 
+- `TimestampMixin` (`app/shared/models.py`): `created_at`/`updated_at`, `server_default func.now()`, + `updated_at onupdate func.now()`. Mix in first: `class X(TimestampMixin, Base)`. +- JSONB import differs: migration `from sqlalchemy.dialects import postgresql` → + `postgresql.JSONB(astext_type=sa.Text())`; ORM `from sqlalchemy.dialects.postgresql import JSONB`. +- `app/main.py` wires routers as `from app.features.<slice>.routes import router as <slice>_router` + + `app.include_router(<slice>_router)` (NO prefix at include; the router carries it). +- Current alembic head observed: `c1d2e3f40512` (`create_batch_tables`). diff --git a/PRPs/forecast-champion-selector-backend.md b/PRPs/forecast-champion-selector-backend.md new file mode 100644 index 00000000..651fc009 --- /dev/null +++ b/PRPs/forecast-champion-selector-backend.md @@ -0,0 +1,970 @@ +name: "Forecast Champion Selector Backend" +description: | + Backend foundation for an interactive Forecast Champion Selector. Adds a + first-class `model_selection` vertical slice that validates a store/product + pair, recommends/selects backtest settings, runs candidate model comparison, + ranks results by WAPE/sMAPE/bias/MAE, persists an auditable selection record, + and optionally trains/predicts with the winning model. This PRP deliberately + scopes UI work out; it creates the stable backend contract the UI can consume. + +**Created:** 2026-06-01 · **Refined:** 2026-06-01 (signatures verified against live code) +**Current repo base observed:** `dev` at `1b4c3f3` (`Merge pull request #352 ...fix/agents-finalizer-fallback`) +**Current alembic head observed:** `c1d2e3f40512` (`create_batch_tables`) — verify with `uv run alembic heads` at implementation time and chain to whatever head exists THEN. +**Working-tree caveat observed:** `docker-compose.lan.yml` is an untracked local dogfood override; do not commit it. +**Tracking issue:** create before implementation, suggested title `feat(api): add forecast champion selector backend`. 
+**Suggested branch:** `feat/forecast-champion-selector-backend` (off `dev`, per `.claude/rules/branch-naming.md`). +**Commit scope:** `api` (cross-feature backend wiring + new slice + `app/main.py`) and `db` (migration). Every commit references the tracking issue. + +--- + +## VALIDATE — Missing Backend Surface Check + +The lower-level primitives exist; the business workflow does not. + +### Reusable backend primitives already present (verified) + +- `POST /backtesting/run` → single store/product/model backtest with fold metrics, + aggregated metrics, optional baselines, bucketed horizon metrics, leakage status. + `app/features/backtesting/routes.py:24` (router), `:60` (handler). + **Service entry point is `BacktestingService().run_backtest(db, store_id, product_id, start_date, end_date, config)`** — see verified signature below. +- `POST /forecasting/train` → trains one model; supports `feature_frame_version` (1|2) and + `feature_groups`. `app/features/forecasting/routes.py:25`. Service: + `ForecastingService().train_model(db, store_id, product_id, train_start_date, train_end_date, config, *, feature_frame_version=1, feature_groups=None) -> TrainResponse`. +- `POST /forecasting/predict` → predicts from a saved bundle. Service: + `ForecastingService().predict(store_id, product_id, horizon, model_path) -> PredictResponse` + (**no db arg** — loads bundle from disk; rejects feature-aware models, `service.py:491`). +- `POST /batch/forecasting` fan-out exists but pins metrics to five keys and does **not** + expose fold-level chart data — NOT suitable for this slice's chart payload. +- `GET /dimensions/stores`, `GET /dimensions/products` provide dimension metadata. +- `app/features/ops/service.py` is the canonical read-only cross-slice ORM aggregation precedent. 
+ +### Backend pieces missing for the full feature + +- No `app/features/model_selection/` slice; no `POST /model-selection/run`; no persisted + `model_selection_run` table; no orchestration of pair-validation → candidate backtests → + ranking → optional final train → optional predict; no pair-availability endpoint; no + backend ranking/confidence policy; no deterministic business explanation layer; no + chart-ready comparison payload. +- Batch/Job model allow-lists are narrower than forecasting's full `ModelConfig` union, and + job/batch training does not pass `feature_frame_version`/`feature_groups`. **Therefore this + slice calls the direct backtesting/forecasting services**, not batch/jobs. + +--- + +## BRAINSTORM / RERANK — Chosen Scope + +Chosen: **Option A — Backend foundation only** (new `model_selection` slice: pair +availability, candidate comparison, ranking/confidence, persisted audit, optional +train/predict, chart-ready payload). It covers every backend gap the eventual UI needs, +reuses mature primitives, creates a stable testable contract, and avoids frontend coupling. + +Non-goals (out of scope for this PRP): + +- No React page / shadcn UI / frontend routing. +- No agent tool, no `agent_require_approval` entry, no agent mutation surface. +- No alias auto-promotion (the selector may *recommend* a winner; alias mutation is a future + approval-gated PRP). +- No batch model-zoo retrofit. Use direct services for the single selected pair. + +--- + +## Goal + +**Feature Goal:** A backend-only Forecast Champion Selector vertical slice that, given one +store/product pair + window + horizon + candidate models, validates data availability, runs +comparable backtests for every candidate, deterministically ranks completed candidates, +computes a recommendation confidence with reasons, persists an auditable selection run, and +returns chart-ready comparison data plus optional final-model training and forecast output. 
+ +**Deliverable:** `app/features/model_selection/` slice (`models.py`, `schemas.py`, +`ranking.py`, `explanations.py`, `service.py`, `routes.py`, `tests/`) + one Alembic migration +creating `model_selection_run`, wired in `app/main.py`. + +**Success Definition:** `POST /model-selection/run` with the default five candidates against +a seeded pair returns HTTP 200 with a persisted `selection_id`, a non-empty deterministic +`ranking`, a `winner`, a `recommendation_confidence`, and a `chart_data` payload; the row is +retrievable by `GET /model-selection/{selection_id}`; all validation gates pass. + +## Why + +- Business users want to ask "which model should I use for this store/product?" without + manually coordinating `/backtesting/run`, `/forecasting/train`, `/forecasting/predict`. +- The UI needs **one stable backend contract** rather than re-implementing ranking in TypeScript. +- A persisted selection run makes the model choice auditable: which models competed, which + window, which policy, and why the winner won. +- Keeps the single-host architecture intact — no queue, no cloud SDK, no new service. 
+ +## What + +### New endpoints (all under `APIRouter(prefix="/model-selection", tags=["model-selection"])`) + +```http +GET /model-selection/availability?store_id=...&product_id=...&forecast_horizon=14 +POST /model-selection/run +GET /model-selection/{selection_id} +GET /model-selection/{selection_id}/ranking +POST /model-selection/{selection_id}/train-winner +POST /model-selection/{selection_id}/predict +``` + +### Core request shape (`POST /model-selection/run`) + +```json +{ + "store_id": 1, + "product_id": 1, + "selection_window": { "start_date": "2026-01-01", "end_date": "2026-05-31" }, + "forecast_horizon": 14, + "ranking_metric": "wape", + "split_config": { "strategy": "expanding", "n_splits": 5, "min_train_size": 30, "gap": 0, "horizon": 14 }, + "candidate_models": [ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + {"model_type": "moving_average", "params": {"window_size": 7}}, + {"model_type": "regression", "params": {}}, + {"model_type": "prophet_like", "params": {}} + ], + "feature_frame_version": 1, + "feature_groups": null, + "auto_train_winner": false, + "auto_predict": false +} +``` + +### LOCKED decisions (these remove every "choose one and test" ambiguity in the prior draft) + +1. **HTTP status codes:** `POST /model-selection/run` → **200** (synchronous, returns the + full result, mirrors `/backtesting/run` which is `status.HTTP_200_OK`). All GETs → 200. + `train-winner` / `predict` → 200. (201 is *not* used; the row is an audit side-effect, the + response is the computed result. Tests lock 200.) +2. **Availability gate:** if `availability.status == "unusable"`, **fail fast** — persist the + row as `status="failed"` with `error_message`, then raise `BadRequestError` (RFC 7807 **400**). + Nothing is ranked. +3. 
**All-candidates-fail (availability OK but every backtest errored):** **do NOT raise.** + Persist `status="failed"`, `ranking_result` with the failed entries, `winner=null`, and + return **200** with the failed-status response. Rationale: the run was validly attempted and + is an auditable outcome, not a client error. (Distinguish from #2: #2 is "we never started".) +4. **Per-candidate backtest config:** `BacktestConfig(split_config=req.split_config, + model_config_main=<candidate ModelConfig>, include_baselines=False, store_fold_details=True)`. + `include_baselines=False` because each candidate is itself a `model_config_main` run — we do + not want N redundant baseline runs. `store_fold_details=True` so fold chart data is populated. +5. **`split_config.horizon` MUST equal `forecast_horizon`** (model-validator on the request). + The window dates from `selection_window` become `run_backtest`'s `start_date`/`end_date`. +6. **Ranking determinism:** primary = `ranking_metric` (default `"wape"`), then the fixed + tie-break chain `wape → smape → abs(bias) → mae → model_type`. With the default, the sort key + is exactly `(wape, smape, abs(bias), mae, model_type)` (success-criteria order). A non-default + `ranking_metric` puts that metric first, remaining chain follows excluding the duplicate. +7. **`auto_predict=True` requires `auto_train_winner=True`** (request model-validator) — predict + needs a freshly trained `final_model.model_path` from this run. + +### Success Criteria + +- [ ] `app/features/model_selection/` slice exists and is wired in `app/main.py`. +- [ ] `POST /model-selection/run` with the default five candidates returns a persisted + `status="completed"` (or `"partial"`) selection with `winner`, `ranking`, confidence, and `chart_data`. 
+- [ ] `GET /model-selection/availability` returns: `first_sales_date`, `last_sales_date`, + `observed_days`, `expected_calendar_days`, `coverage_ratio`, `missing_days`, + `zero_sale_days`, `promotion_days` (or `null` + warning), `average_daily_demand`, + `status` ∈ `{ready, limited, unusable}`, and `recommended_split_config`. +- [ ] Ranking is deterministic per LOCKED decision #6. +- [ ] Partial success supported (LOCKED #3): failed candidates appear in `ranking` with error + detail and are excluded from winner selection; a valid candidate still wins. +- [ ] `auto_train_winner=True` stores `final_model.model_path` via the **direct** + `ForecastingService.train_model`, preserving `feature_frame_version` + `feature_groups`. +- [ ] `auto_predict=True` (with train) returns forecast points + total/average demand summary. +- [ ] New migration creates `model_selection_run` with JSONB snapshots and named indexes; + `downgrade` drops indexes then table cleanly. +- [ ] `app/core/tests/test_strict_mode_policy.py` stays green for all new strict request schemas. +- [ ] No agent tools / `agent_require_approval` entries; no frontend files; no cloud SDK. + +## All Needed Context + +### Documentation & References + +```yaml +# PRP conventions +- file: PRPs/templates/prp_base.md + why: Base template (Goal/Context/Blueprint/Validation). NOTE — the user referenced a + "PRPs/prp-readme.md.md"; it does NOT exist (`find PRPs -iname '*readme*'` empty on 2026-06-01). +- file: PRPs/PRP-33-batch-runner-mvp.md + why: Strongest backend vertical-slice precedent — migration assertions, strict-mode gotchas, + route/test detail. Mirror its structure. +- file: PRPs/PRP-28-forecast-explainability-driver-attribution.md + why: Read/composition-slice precedent consuming existing contracts; deterministic explanation layer. 
+- docfile: PRPs/ai_docs/forecast-champion-selector-backend-research.md + why: External-lib + runtime verification (FastAPI APIRouter, Pydantic strict, JSONB, Alembic + create_index, sklearn TimeSeriesSplit). Versions: pydantic 2.12.5, sqlalchemy 2.0.46, + sklearn 1.8.0, fastapi 0.128.0, alembic 1.18.4. + +# Verified service contracts to reuse (DO NOT re-derive — exact signatures below in Gotchas) +- file: app/features/backtesting/service.py + why: BacktestingService().run_backtest(db, store_id, product_id, start_date, end_date, config). :213 +- file: app/features/backtesting/schemas.py + why: SplitConfig :24, BacktestConfig :81, BacktestResponse :257, ModelBacktestResult :180, + FoldResult :147. aggregated_metrics keys = {mae,rmse,smape,wape,bias}. +- file: app/features/backtesting/routes.py + why: EXACT route error-mapping pattern to mirror (try/except ValueError->BadRequestError, + SQLAlchemyError->DatabaseError; service instantiated as BacktestingService()). :60-140 +- file: app/features/forecasting/service.py + why: ForecastingService().train_model :247 (db first; feature_frame_version/feature_groups + keyword-only after *), predict :402 (NO db). Lazy cross-slice import precedent :55-61, :967. +- file: app/features/forecasting/schemas.py + why: ModelConfig union :417-429 (flat members, model_type discriminator, NO module-level helper); + TrainResponse.model_path :540; PredictResponse.forecasts :605; ForecastPoint :574. +- file: app/features/data_platform/models.py + why: Store :40 (business key `code`, not store_code), Product :68 (`sku`, `launch_date`), + SalesDaily :172 (date/store_id/product_id/quantity/unit_price/total_amount), Promotion :274. +- file: app/features/ops/service.py + why: Read-only cross-slice ORM aggregation precedent — module-scope ORM-model imports, stateless + service, db: AsyncSession per method, func.min/max/count/sum + group_by style. :225, :456. 
+- file: app/features/analytics/routes.py + why: validate_date_range :36 (raises BadRequestError, inverted-range + 730-day-max). CANNOT be + cross-slice imported — reimplement the two checks locally raising BadRequestError. +- file: app/core/exceptions.py + why: BadRequestError(400) :152, NotFoundError(404) :64, DatabaseError(500) :108, + ConflictError(409) :130, UnprocessableEntityError(422) :174. Each: (message=..., details=None). +- file: app/core/problem_details.py + why: RFC 7807 envelope; never raise bare HTTPException with raw strings. +- file: app/core/config.py + why: get_settings() cached singleton :225; Settings(BaseSettings) :62; add a plain typed attr + with literal default; env var = UPPER_SNAKE of the field name. +- file: app/core/database.py + why: Base (ORM declarative base) + get_db dependency used by routes/tests. +- file: app/shared/models.py + why: TimestampMixin (created_at/updated_at, server_default func.now(), updated_at onupdate). Mix in first. +- file: app/main.py + why: Router wiring — `from app.features..routes import router as _router` (:18-26), + `app.include_router(_router)` with NO prefix at include (:137-155), inside create_app(). +- file: app/core/tests/test_strict_mode_policy.py + why: AST policy — scans app/features/*/schemas.py; any ConfigDict(strict=True) model field typed + date/datetime/time/UUID/Decimal (anywhere in the annotation) MUST carry Field(strict=False, ...). + +# Migration / test patterns +- file: alembic/versions/c1d2e3f40512_create_batch_tables.py + why: JSONB via `from sqlalchemy.dialects import postgresql` -> postgresql.JSONB(astext_type=sa.Text()); + named CheckConstraint; op.create_index (op.f for single-col, explicit name for composite); + sa.DateTime(timezone=True) server_default sa.text("now()"); downgrade drops indexes THEN table. 
+- file: app/features/batch/models.py + why: ORM JSONB via `from sqlalchemy.dialects.postgresql import JSONB` (bare); Mapped[]+mapped_column; + status as String + default=Enum.PENDING.value + CheckConstraint in __table_args__; TimestampMixin. +- file: app/features/batch/schemas.py + why: Strict request pattern — ConfigDict(strict=True), Literal[...] for JSON enums, Field(strict=False) + on date fields (:132-133), @model_validator cross-field checks. +- file: app/features/explainability/tests/test_routes.py + why: ASGITransport + AsyncClient + app.dependency_overrides[get_db]; RFC 7807 4-key body assert; async tests. +- file: app/features/explainability/tests/conftest.py + why: Integration fixture — real engine from get_settings().database_url, prefix-scoped teardown in finally. + +# External official docs (verified in research doc) +- url: https://fastapi.tiangolo.com/tutorial/bigger-applications/ + why: APIRouter prefix/tags multi-file pattern. +- url: https://pydantic.dev/docs/validation/latest/concepts/strict_mode/ + why: strict mode + field-level Field(strict=False) override (runtime-verified, pydantic 2.12.5). +- url: https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#json-types + why: JSONB column type for audit snapshots. +- url: https://alembic.sqlalchemy.org/en/latest/ops.html#alembic.operations.Operations.create_index + why: create_index signature (alembic 1.18.4: index_name, table_name, columns, *, unique, **kw). +- url: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html + why: split semantics (sklearn 1.8.0 signature: n_splits, *, max_train_size, test_size, gap). 
+``` + +### Current Codebase Tree (relevant slices) + +```bash +app/features/ +├── analytics/ # KPI/drilldown/timeseries; validate_date_range lives in routes.py (slice-local) +├── backtesting/ # single-pair single-model backtesting; fold/chart data via store_fold_details +├── batch/ # batch fan-out; pinned 5-key metrics; NO fold chart data +├── data_platform/ # shared ORM: Store, Product, SalesDaily, Promotion, InventorySnapshotDaily, ... +├── dimensions/ # store/product discovery +├── forecasting/ # direct train/predict; full ModelConfig union +├── jobs/ # train/predict/backtest job orchestration +├── ops/ # read-only cross-slice ORM aggregation precedent (OpsService) +└── registry/ # model runs, aliases, compare, artifact verify +alembic/versions/ # current head: c1d2e3f40512 (create_batch_tables) +``` + +### Desired Codebase Tree + +```bash +app/features/model_selection/ +├── __init__.py +├── models.py # ModelSelectionRun ORM + ModelSelectionStatus enum +├── schemas.py # strict request models + response models +├── ranking.py # PURE: normalize metrics, filter, rank, confidence +├── explanations.py # PURE: deterministic business summary + confidence_reasons +├── service.py # ModelSelectionService: availability + orchestration (lazy cross-slice imports) +├── routes.py # APIRouter(prefix="/model-selection") +└── tests/ + ├── __init__.py + ├── conftest.py + ├── test_models.py + ├── test_schemas.py + ├── test_ranking.py + ├── test_explanations.py + ├── test_service.py + ├── test_routes.py + └── test_routes_integration.py +alembic/versions/<revision>_create_model_selection_run.py +``` + +### Known Gotchas & VERIFIED Library/Internal Contracts + +```python +# ── VERIFIED INTERNAL SIGNATURES (exact, read 2026-06-01) ───────────────────── +# BacktestingService.__init__(self) -> None # takes NO db; instantiate as BacktestingService() +# await BacktestingService().run_backtest( +# db, store_id, product_id, start_date, end_date, config: BacktestConfig +# ) -> BacktestResponse # 
service.py:213 ; db is FIRST arg +# +# ForecastingService.__init__(self) -> None +# await ForecastingService().train_model( +# db, store_id, product_id, train_start_date, train_end_date, config: ModelConfig, +# *, feature_frame_version: int = 1, feature_groups: list[str] | None = None +# ) -> TrainResponse # service.py:247 ; .model_path is the artifact path +# await ForecastingService().predict( +# store_id, product_id, horizon, model_path # NO db arg — loads bundle from disk +# ) -> PredictResponse # service.py:402 ; .forecasts: list[ForecastPoint] +# # ForecastPoint: {date, forecast, lower_bound?, upper_bound?} +# GOTCHA: predict() REJECTS feature-aware models (service.py:491). For a feature-aware winner, +# auto_predict may raise; catch and surface a warning rather than failing the whole run. + +# ── METRIC KEYS — CORRECTED (draft was incomplete) ──────────────────────────── +# BacktestResponse.main_model_results.aggregated_metrics has FIVE keys: +# {"mae", "rmse", "smape", "wape", "bias"} # metrics.py:347 — draft MISSED "rmse" +# metric_std keys are SUFFIXED "{name}_stability" (a coefficient of variation, NOT raw std). +# sample_size is NOT in aggregated_metrics — derive it from fold actuals length +# (sum of len(fold.actuals) across fold_results) or n_folds; normalize in ranking.py. +# Fold chart data path: BacktestResponse.main_model_results.fold_results[i].{dates, actuals, predictions} +# populated ONLY when config.store_fold_details=True (LOCKED #4 sets it True). +# bucketed_aggregated_metrics lives on each ModelBacktestResult (optional, may be None). + +# ── ModelConfig CONSTRUCTION — members are FLAT, no nested "params" ──────────── +# The request uses {"model_type": "seasonal_naive", "params": {"season_length": 7}} but the +# ModelConfig members are FLAT (SeasonalNaiveModelConfig has model_type + season_length at top +# level). There is NO module-level TypeAdapter/helper. 
Build at the service boundary by FLATTENING: +# from pydantic import TypeAdapter +# from app.features.forecasting.schemas import ModelConfig +# _MODEL_CONFIG_ADAPTER = TypeAdapter(ModelConfig) +# cfg = _MODEL_CONFIG_ADAPTER.validate_python({"model_type": c.model_type, **c.params}) +# Members are frozen + extra="forbid", so unknown params raise a ValidationError (good — surfaces +# bad candidate params as a failed candidate with a reason). Do this import LAZILY in-method. +# Valid model_type values (full union, forecasting/schemas.py:417): naive, seasonal_naive, +# moving_average, weighted_moving_average, seasonal_average, trend_regression_baseline, +# random_forest, lightgbm, xgboost, regression, prophet_like. +# (lightgbm/xgboost are opt-in extras — may ImportError at runtime; treat as a failed candidate.) + +# ── CROSS-SLICE IMPORT RULE ─────────────────────────────────────────────────── +# Vertical-slice rule: app/features/X must not import app/features/Y at MODULE scope when it +# would close an alembic cold-boot cycle. model_selection is a NEW leaf (nothing imports it), but +# to match the BatchService/forecasting precedent and stay safe, import the SERVICE CLASSES +# (BacktestingService, ForecastingService) and the ModelConfig TypeAdapter LAZILY inside the +# methods that use them. Read ORM models (Store/Product/SalesDaily/Promotion) at module scope — +# that mirrors OpsService and is the sanctioned read-only ORM surface. + +# ── validate_date_range IS NOT IMPORTABLE ───────────────────────────────────── +# It lives in app/features/analytics/routes.py (slice-local). Reimplement the two checks locally +# (inverted range; max-span) raising app.core.exceptions.BadRequestError, OR rely on schema +# validators. Do NOT import across the slice boundary. +# NOTE: analytics' max-span is settings.analytics_max_date_range_days (configurable, ~730), not a +# hardcoded constant — pick your own local bound (or reuse the setting) when reimplementing. 
+ +# ── STRICT-MODE POLICY (app/core/tests/test_strict_mode_policy.py) ──────────── +# Every request model with model_config = ConfigDict(strict=True) MUST add Field(strict=False, ...) +# to EVERY field typed date|datetime|time|UUID|Decimal (incl. inside Optional/Annotated/list/dict). +# Use Literal[...] for JSON enum strings (NOT a str-Enum — strict won't coerce). The AST walker does +# NOT follow inheritance, so set ConfigDict(strict=True) on each concrete request model directly. + +# ── ORM / MIGRATION QUIRKS ──────────────────────────────────────────────────── +# JSONB import DIFFERS by layer: +# migration: from sqlalchemy.dialects import postgresql -> postgresql.JSONB(astext_type=sa.Text()) +# ORM: from sqlalchemy.dialects.postgresql import JSONB -> mapped_column(JSONB) +# Status enum enforced via CheckConstraint("status IN (...)", name="ck_...") in BOTH migration and +# ORM __table_args__; ORM column is String(N) with default=ModelSelectionStatus.PENDING.value. +# created_at/updated_at come from TimestampMixin (app/shared/models.py) — declare class as +# `class ModelSelectionRun(TimestampMixin, Base)` (mixin FIRST). Declare completed_at explicitly. +# Migration down_revision: chain to the CURRENT head at implementation time (observed c1d2e3f40512); +# run `uv run alembic heads` to confirm — do NOT hardcode this PRP's observed value blindly. + +# ── DATA-PLATFORM COLUMN NAMES (availability aggregation) ───────────────────── +# Store.id (int PK), Store.code (business key). Product.id, Product.sku, Product.launch_date (date|None). +# SalesDaily: .date (Date FK calendar.date), .store_id, .product_id, .quantity (Integer, CHECK >=0), +# .unit_price (Numeric), .total_amount (Numeric). Grain unique (date, store_id, product_id). +# => For ONE pair: count(distinct date) == count(*); zero_sale_days = count where quantity == 0. 
+# Promotion: per-product (product_id NOT NULL), store_id NULLABLE (NULL = CHAIN-WIDE, applies to all +# stores), date RANGE [start_date, end_date], kind in {pct_off,bogo,bundle,markdown}. To count +# promotion_days for (store, product) within the window, JOIN promotion to the pair's sales dates +# ON sd.date BETWEEN p.start_date AND p.end_date AND p.product_id=? AND (p.store_id=? OR p.store_id IS NULL), +# then COUNT(DISTINCT sd.date). If this proves complex/edge-casey, return promotion_days=None with a +# warning string (acceptable per Success Criteria) — do NOT sum (end-start) per row (double-counts overlaps). + +# ── RUNTIME-VERIFIED LIBRARY FACTS (research doc) ───────────────────────────── +# Pydantic 2.12.5 accepts Field(strict=False) date string under a strict model. sklearn 1.8.0 +# TimeSeriesSplit(n_splits, *, max_train_size, test_size, gap). FastAPI 0.128.0 APIRouter(prefix=...). +# Alembic 1.18.4 Operations.create_index(index_name, table_name, columns, *, unique, **kw). +``` + +## Implementation Blueprint + +### Data Models and Schemas + +`app/features/model_selection/models.py`: + +```python +from datetime import date, datetime +from enum import Enum +from typing import Any + +from sqlalchemy import CheckConstraint, Date, DateTime, Index, Integer, String +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base +from app.shared.models import TimestampMixin + + +class ModelSelectionStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + PARTIAL = "partial" + FAILED = "failed" + + +class ModelSelectionRun(TimestampMixin, Base): # TimestampMixin FIRST → created_at/updated_at + __tablename__ = "model_selection_run" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + selection_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = 
mapped_column(Integer, index=True) + start_date: Mapped[date] = mapped_column(Date) + end_date: Mapped[date] = mapped_column(Date) + forecast_horizon: Mapped[int] = mapped_column(Integer) + ranking_metric: Mapped[str] = mapped_column(String(20)) + status: Mapped[str] = mapped_column(String(20), default=ModelSelectionStatus.PENDING.value, index=True) + candidate_models: Mapped[list[dict[str, Any]]] = mapped_column(JSONB) + policy_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB) + availability_snapshot: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + ranking_result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + winner_model_type: Mapped[str | None] = mapped_column(String(40), nullable=True) + winner_metrics: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + final_model_path: Mapped[str | None] = mapped_column(String(512), nullable=True) + forecast_result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + business_summary: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + + __table_args__ = ( + CheckConstraint( + "status IN ('pending','running','completed','partial','failed')", + name="ck_model_selection_run_valid_status", + ), + Index("ix_model_selection_run_store_product_created", "store_id", "product_id", "created_at"), + Index("ix_model_selection_run_status_created", "status", "created_at"), + ) +``` + +`app/features/model_selection/schemas.py` — strict request models + response models: + +- `SelectionWindow(start_date, end_date)` — `ConfigDict(strict=True)`, both dates `Field(strict=False, ...)`. +- `CandidateModelConfig(model_type: Literal[<11 model_types>], params: dict[str, Any] = {})`. 
+- `RankingPolicy(minimum_sample_size: int = 0, high_confidence_rel_improvement: float = 0.10, + max_acceptable_abs_bias: float = ...)` — defaults; snapshotted into `policy_snapshot`. +- `ModelSelectionRunRequest` — `ConfigDict(strict=True)`; fields: `store_id`, `product_id`, + `selection_window`, `forecast_horizon` (int, ge=1, le=90), `ranking_metric: Literal["wape","smape","mae","bias"]="wape"`, + `split_config: SplitConfig` (reuse backtesting's? — see NOTE), `candidate_models: list` (min_length=1, max_length=10), + `feature_frame_version: int = 1` (ge=1, le=2), `feature_groups: list[str] | None = None`, + `ranking_policy: RankingPolicy = Field(default_factory=RankingPolicy)`, + `auto_train_winner: bool = False`, `auto_predict: bool = False`. + - `@model_validator(mode="after")`: `split_config.horizon == forecast_horizon` (LOCKED #5); + `auto_predict implies auto_train_winner` (LOCKED #7). + - NOTE on `split_config`: `backtesting.schemas.SplitConfig` is `frozen=True, extra="forbid"` + (NOT strict). Either (a) reuse it directly (import lazily is unnecessary for a schema type — + it's safe at module scope since backtesting.schemas has no cycle back to model_selection), or + (b) define a local `SplitSettings` mirror. **Prefer reusing `SplitConfig`** to avoid drift; it + already validates n_splits/min_train_size/gap/horizon. Since it is not strict-mode, its `date`-free + fields don't trip the strict-mode linter. +**Response + intermediate models (plain `BaseModel` — outputs don't need `strict=True`). 
These +fields ARE the stable contract the UI consumes; specify them exactly, do not improvise.** + +```python +# ── intermediate (service-internal, also embedded in JSONB) ─────────────────── +class CandidateResult(BaseModel): # what shape_candidate()/shape_failed_candidate() return + model_type: str + params: dict[str, Any] # ORIGINAL candidate params — REQUIRED so the winner can be rebuilt (pseudocode L667) + failed: bool + error: str | None = None # reason when failed=True + aggregated_metrics: dict[str, float] | None = None # raw 5-key dict from backtest (mae,rmse,smape,wape,bias) or None + sample_size: int = 0 # RULE: sum(len(fold.actuals)) across main_model_results.fold_results + config_hash: str | None = None + folds: list[FoldChart] = [] # per-fold chart points (empty when failed). NOTE: in schemas.py define FoldChart ABOVE this class (or use `from __future__ import annotations`), else this annotation raises NameError at import + +class FoldChart(BaseModel): + fold_index: int + dates: list[date] + actuals: list[float] + predictions: list[float] + +class ModelRankEntry(BaseModel): # one row in the ranking table (valid OR excluded) + rank: int | None # 1-based; None when excluded/failed + model_type: str + params: dict[str, Any] # carried through (see CandidateResult.params) + included: bool # False = failed or filtered out + exclusion_reason: str | None = None + metrics: dict[str, float] | None = None # normalized {wape,smape,mae,rmse,bias,sample_size} + +class RankingResult(BaseModel): # Pydantic (model_dump'd into ranking_result JSONB, L663) + winner: ModelRankEntry | None + entries: list[ModelRankEntry] # ALL candidates, ranked-then-failed, never hidden + confidence: Literal["high", "medium", "low"] + reasons: list[str] + +class WinnerSummary(BaseModel): + model_type: str + params: dict[str, Any] + metrics: dict[str, float] # normalized winner metrics + rank: int # always 1 + +class ChartData(BaseModel): # chart-ready comparison payload (Success Criteria deliverable) + wape_by_model: dict[str, float] # {model_type: wape} → WAPE bar chart + bias_by_model: dict[str, float] # {model_type: bias} → bias 
chart + fold_stability: dict[str, list[float]] # {model_type: per-fold wape} → stability lines + winner_actual_vs_predicted: list[FoldChart] # the WINNER's folds only → actual-vs-predicted overlay + +class PairAvailabilityResponse(BaseModel): + store_id: int + product_id: int + first_sales_date: date | None + last_sales_date: date | None + observed_days: int + expected_calendar_days: int + coverage_ratio: float + missing_days: int + zero_sale_days: int + promotion_days: int | None # None + a warning when not safely derivable + average_daily_demand: float # CAST float(...) — func.avg over Integer quantity returns Decimal + status: Literal["ready", "limited", "unusable"] + recommended_split_config: SplitConfig # reuse backtesting.schemas.SplitConfig + warnings: list[str] = [] + +class ForecastSummary(BaseModel): + points: list[dict[str, Any]] # ForecastPoint.model_dump(mode="json") list + total_demand: float + average_demand: float + horizon: int + +class ModelSelectionRunResponse(BaseModel): # THE /run + /{id} contract + selection_id: str + store_id: int + product_id: int + status: Literal["pending", "running", "completed", "partial", "failed"] + selection_window: SelectionWindow + forecast_horizon: int + ranking_metric: str + availability: PairAvailabilityResponse | None + ranking: list[ModelRankEntry] # == RankingResult.entries + winner: WinnerSummary | None + recommendation_confidence: Literal["high", "medium", "low"] | None # CANONICAL KEY (maps from RankingResult.confidence) + confidence_reasons: list[str] # == RankingResult.reasons + chart_data: ChartData | None + final_model: dict[str, Any] | None # {"model_path": ...} when auto_train_winner + forecast: ForecastSummary | None # when auto_predict + business_summary: dict[str, Any] | None + error_message: str | None + created_at: datetime + completed_at: datetime | None + +class TrainWinnerResponse(BaseModel): + selection_id: str + model_type: str + model_path: str + +class PredictWinnerResponse(BaseModel): + 
selection_id: str + forecast: ForecastSummary +``` + +> **NAMING (resolves the only internal-consistency nit):** the response key is +> **`recommendation_confidence`** (Success Criteria + manual probe + Goal all use it). +> `RankingResult.confidence` is the service-internal field; `_response()` maps +> `RankingResult.confidence → ModelSelectionRunResponse.recommendation_confidence` and +> `RankingResult.reasons → confidence_reasons`. Tests assert the response key +> `recommendation_confidence`. + +> **`self._response(row, ranking)` helper:** pure mapping `ModelSelectionRun` ORM row + +> `RankingResult` → `ModelSelectionRunResponse` (rehydrate `availability_snapshot`/`ranking_result`/ +> `business_summary`/`forecast_result` JSONB back into the response models; build `chart_data` from +> the per-candidate `CandidateResult.folds` + normalized metrics; map the confidence keys per above). + +### Implementation Tasks (dependency-ordered) + +```yaml +Task 1 — Migration + ORM: + RUN: uv run alembic heads # confirm current head (observed c1d2e3f40512) + CREATE alembic/versions/<revision>_create_model_selection_run.py: + - down_revision = "<current_head>" + - MIRROR alembic/versions/c1d2e3f40512_create_batch_tables.py exactly: + - from sqlalchemy.dialects import postgresql -> postgresql.JSONB(astext_type=sa.Text()) + - sa.DateTime(timezone=True), server_default=sa.text("now()") for created_at/updated_at + - CheckConstraint name="ck_model_selection_run_valid_status" + - op.create_index(op.f("ix_model_selection_run_selection_id"), ..., unique=True) + - op.create_index("ix_model_selection_run_store_product_created", ..., ["store_id","product_id","created_at"]) + - op.create_index("ix_model_selection_run_status_created", ..., ["status","created_at"]) + - downgrade(): drop indexes (reverse order) THEN op.drop_table("model_selection_run") + CREATE app/features/model_selection/models.py: # as blueprint above; mirror batch/models.py + +Task 2 — Schemas: + CREATE app/features/model_selection/schemas.py: + - 
all REQUEST models ConfigDict(strict=True); date fields Field(strict=False, ...) + - Literal[...] for model_type + ranking_metric (NOT str-Enum) + - candidate_models min_length=1 max_length=10 (or settings.model_selection_max_candidates) + - @model_validator: horizon match (LOCKED #5) + auto_predict implies auto_train_winner (LOCKED #7) + - reuse backtesting.schemas.SplitConfig (module-scope import OK; no cycle) + +Task 3 — Ranking pure logic: + CREATE app/features/model_selection/ranking.py: + - NormalizedMetrics dataclass {wape, smape, mae, rmse, bias, sample_size} + - normalize_metrics(aggregated_metrics, sample_size) -> NormalizedMetrics | None + (None when the primary metric is missing OR NaN — use math.isnan guard; np.nan can appear, + metrics.py:381; keys are mae/rmse/smape/wape/bias) + - input: list[CandidateResult] (Task-2 schema). Each entry CARRIES model_type + params through to + ModelRankEntry/WinnerSummary so the winner can be rebuilt (pseudocode L667 reads winner.params). 
+ - filter: not failed AND numeric primary metric AND sample_size >= policy.minimum_sample_size + - rank key (default ranking_metric="wape"): (wape, smape, abs(bias), mae, model_type) [LOCKED #6] + - confidence (PIN the rel-improvement formula — denominator is the SECOND-place value): + rel_improvement = (second.wape - winner.wape) / second.wape # guard second.wape == 0 → treat as 0.0 + HIGH : >=2 valid AND rel_improvement >= policy.high_confidence_rel_improvement (default 0.10) + AND abs(winner.bias) <= policy.max_acceptable_abs_bias AND winner.sample_size sufficient + MEDIUM: a valid winner exists but HIGH not met (narrow lead OR mild warnings) and >=2 valid + LOW : exactly one valid candidate, OR availability "limited", OR abs(bias) over threshold, + OR rel_improvement < some near-tie epsilon (document the epsilon as a module constant) + - emit human-readable reasons[] strings explaining the chosen level (consumed as confidence_reasons) + - return RankingResult(winner, entries[ALL ranked-then-failed, never hidden], confidence, reasons) + +Task 4 — Business explanation pure logic: + CREATE app/features/model_selection/explanations.py: + - explain_winner(ranking, availability) -> business_summary dict + confidence_reasons + warnings + - translate WAPE/sMAPE/MAE/bias into short deterministic English; NO LLM, NO external call + +Task 5 — Pair availability: + CREATE ModelSelectionService.get_availability(db, store_id, product_id, forecast_horizon, split_config?) 
-> PairAvailabilityResponse: + - verify Store and Product exist (NotFoundError if absent) via data_platform ORM (module-scope import OK) + - aggregate SalesDaily for the pair (SQLAlchemy 2.0 async, mirror OpsService style): + select(func.min(SalesDaily.date), func.max(SalesDaily.date), + func.count(func.distinct(SalesDaily.date)), func.sum(SalesDaily.quantity), + func.avg(SalesDaily.quantity), + func.count().filter(SalesDaily.quantity == 0)) # FILTER aggregate; valid async idiom + .where(SalesDaily.store_id == store_id, SalesDaily.product_id == product_id) + # CAST: func.avg over Integer quantity returns Decimal; wrap average_daily_demand in float(...). + # func.count().filter(...) is a Postgres FILTER aggregate (not shown in OpsService, but supported); + # alternatively a second scalar count with .where(quantity == 0). One round-trip is fine. + - expected_calendar_days = (max_date - min_date).days + 1 + - coverage_ratio = observed_days / expected_calendar_days (guard div-by-zero / no rows) + - missing_days = expected_calendar_days - observed_days + - promotion_days: JOIN promotion ON date BETWEEN start/end AND product_id match AND + (store_id == X OR store_id IS NULL); COUNT(DISTINCT date). On any doubt → None + warning. 
+ - status (LOCKED thresholds): + ready if observed_days >= min_train_size + horizon*n_splits AND coverage_ratio >= 0.8 + limited if observed_days >= min_train_size + horizon + unusable otherwise + - recommended_split_config: expanding, n_splits=min(5, feasible), min_train_size=30 (or adjusted), + gap=0, horizon=forecast_horizon + - NO rows for the pair -> status="unusable" with zeros/None and a warning + +Task 6 — Orchestration: + CREATE ModelSelectionService.run_selection(db, request) -> ModelSelectionRunResponse: + - persist ModelSelectionRun(selection_id=uuid4().hex, status="running", snapshots); flush + - availability = get_availability(...); persist snapshot + - if availability.status == "unusable": status="failed", error_message, flush, raise BadRequestError [LOCKED #2] + - for each candidate (LAZY import services + ModelConfig adapter): + try: cfg = flatten+validate ModelConfig; bt = await BacktestingService().run_backtest( + db, store_id, product_id, window.start, window.end, + BacktestConfig(split_config=req.split_config, model_config_main=cfg, + include_baselines=False, store_fold_details=True)) + collect aggregated_metrics, sample_size, fold dates/actuals/predictions for chart + except Exception as exc: append failed entry with reason=str(exc) [never hide — Anti-Patterns] + - ranking = rank_candidates(results, req.ranking_policy, req.ranking_metric) + - if ranking.winner is None: status="failed", persist ranking_result, flush, RETURN 200 response [LOCKED #3] + - if req.auto_train_winner: + train = await ForecastingService().train_model(db, store_id, product_id, window.start, window.end, + winner_cfg, feature_frame_version=req.feature_frame_version, feature_groups=req.feature_groups) + row.final_model_path = train.model_path + - if req.auto_predict: # requires auto_train_winner (validated) + try: pred = await ForecastingService().predict(store_id, product_id, req.forecast_horizon, row.final_model_path) + row.forecast_result = pred.model_dump(mode="json") 
+ except Exception: warning, leave forecast_result None + - business_summary = explain_winner(ranking, availability) + - status = "partial" if any candidate failed else "completed"; completed_at = datetime.now(UTC) + - persist all JSONB via model_dump(mode="json"); flush + refresh; return self._response(row, ranking) + ADD methods: get_selection(db, selection_id)->row|NotFoundError ; get_ranking ; train_winner ; predict_winner + +Task 7 — Routes: + CREATE app/features/model_selection/routes.py: + - router = APIRouter(prefix="/model-selection", tags=["model-selection"]) + - GET /availability ; POST /run (200) ; GET /{selection_id} ; GET /{selection_id}/ranking ; + POST /{selection_id}/train-winner ; POST /{selection_id}/predict + - MIRROR backtesting/routes.py error mapping EXACTLY: + service instantiated locally; try/except ValueError->BadRequestError(str(e)), + SQLAlchemyError->DatabaseError("...", details={"error": str(e)}); NotFoundError from service bubbles. + - structured logger.info events (see Integration Points) + MODIFY app/main.py: + - `from app.features.model_selection.routes import router as model_selection_router` (alpha order with siblings) + - `app.include_router(model_selection_router)` inside create_app(), near backtesting/forecasting (NO prefix arg) + +Task 8 — Tests (see Validation Loop for required names): + CREATE app/features/model_selection/tests/{conftest,test_models,test_schemas,test_ranking, + test_explanations,test_service,test_routes,test_routes_integration}.py + - unit route tests: ASGITransport + app.dependency_overrides[get_db]=AsyncMock; 4-key RFC7807 assert + - service tests: mock BacktestingService/ForecastingService (patch the lazy import targets) for + happy/partial/all-fail/auto-train/auto-predict paths + - integration tests (@pytest.mark.integration): real engine, prefix-scoped teardown in finally +``` + +### Pseudocode (CRITICAL details only) + +```python +# ranking.py — deterministic, pure +def rank_candidates(results, policy, 
ranking_metric="wape"): + valid, failed = [], [] + for r in results: + m = normalize_metrics(r.aggregated_metrics, r.sample_size) # keys: mae,rmse,smape,wape,bias + if m is None or m.sample_size < policy.minimum_sample_size: + failed.append(r.as_failed("missing/NaN primary metric or sample_size below minimum")) + continue + valid.append((r, m)) + if not valid: + return RankingResult(winner=None, entries=failed, confidence="low", reasons=["no valid candidate"]) + primary = lambda m: getattr(m, ranking_metric) if ranking_metric != "bias" else abs(m.bias) + ordered = sorted(valid, key=lambda p: (primary(p[1]), p[1].smape, abs(p[1].bias), p[1].mae, p[0].model_type)) + winner = ordered[0] + return build_ranking_result(ordered, failed, policy) # computes confidence vs 2nd place +``` + +```python +# service.py — orchestration (exact verified service calls) +async def run_selection(self, db, req): + from pydantic import TypeAdapter # lazy + from app.features.backtesting.schemas import BacktestConfig # lazy + from app.features.backtesting.service import BacktestingService # lazy + from app.features.forecasting.schemas import ModelConfig # lazy + from app.features.forecasting.service import ForecastingService # lazy + adapter = TypeAdapter(ModelConfig) + + row = ModelSelectionRun(selection_id=uuid.uuid4().hex, status="running", + store_id=req.store_id, product_id=req.product_id, + start_date=req.selection_window.start_date, end_date=req.selection_window.end_date, + forecast_horizon=req.forecast_horizon, ranking_metric=req.ranking_metric, + candidate_models=[c.model_dump(mode="json") for c in req.candidate_models], + policy_snapshot=req.ranking_policy.model_dump(mode="json")) + db.add(row); await db.flush() + + availability = await self.get_availability(db, req.store_id, req.product_id, req.forecast_horizon, req.split_config) + row.availability_snapshot = availability.model_dump(mode="json") + if availability.status == "unusable": + row.status = "failed"; row.error_message = 
"Insufficient data for model selection" + await db.flush(); raise BadRequestError(message=row.error_message) # LOCKED #2 + + results = [] + for c in req.candidate_models: + try: + cfg = adapter.validate_python({"model_type": c.model_type, **c.params}) # FLATTEN + bt = await BacktestingService().run_backtest( + db, req.store_id, req.product_id, + req.selection_window.start_date, req.selection_window.end_date, + BacktestConfig(split_config=req.split_config, model_config_main=cfg, + include_baselines=False, store_fold_details=True)) # LOCKED #4 + results.append(shape_candidate(c, bt)) + except Exception as exc: + results.append(shape_failed_candidate(c, exc)) + + ranking = rank_candidates(results, req.ranking_policy, req.ranking_metric) + row.ranking_result = ranking.model_dump(mode="json") + if ranking.winner is None: + row.status = "failed"; await db.flush(); return self._response(row, ranking) # LOCKED #3 (HTTP 200) + + winner_cfg = adapter.validate_python({"model_type": ranking.winner.model_type, **ranking.winner.params}) + if req.auto_train_winner: + train = await ForecastingService().train_model( + db, req.store_id, req.product_id, req.selection_window.start_date, req.selection_window.end_date, + winner_cfg, feature_frame_version=req.feature_frame_version, feature_groups=req.feature_groups) + row.final_model_path = train.model_path + if req.auto_predict and row.final_model_path: + try: + pred = await ForecastingService().predict(req.store_id, req.product_id, req.forecast_horizon, row.final_model_path) + row.forecast_result = pred.model_dump(mode="json") + except Exception as exc: # e.g. 
feature-aware reject (forecasting service.py:491) + row.forecast_result = None # surface a warning in business_summary + + row.winner_model_type = ranking.winner.model_type + row.winner_metrics = ranking.winner.metrics + row.business_summary = explain_winner(ranking, availability) + row.status = "partial" if any(r.failed for r in results) else "completed" + row.completed_at = datetime.now(UTC) + await db.flush(); await db.refresh(row) + return self._response(row, ranking) +``` + +### Integration Points + +```yaml +DATABASE: + - migration: add `model_selection_run` (JSONB snapshots: candidate_models, policy_snapshot, + availability_snapshot, ranking_result, winner_metrics, forecast_result, business_summary) + - indexes: ix_model_selection_run_selection_id (unique), ix_model_selection_run_store_product_created, + ix_model_selection_run_status_created +ROUTES: + - app/main.py: import + app.include_router(model_selection_router) (router carries its own prefix) +CONFIG (optional — only if used; then ADD to .env.example with UPPER_SNAKE + a comment, and a test): + - model_selection_max_candidates: int = 10 + - model_selection_min_coverage_ratio: float = 0.8 + - model_selection_default_min_train_size: int = 30 +OBSERVABILITY (structlog events, mirror ops/backtesting naming): + - model_selection.run_received / .availability_checked / .candidate_completed / + .candidate_failed / .run_completed / .run_failed +``` + +## Validation Loop + +### Level 1 — Focused syntax & policy + +```bash +uv run ruff check app/features/model_selection app/main.py alembic/versions +uv run ruff format --check app/features/model_selection app/main.py alembic/versions +uv run mypy app/features/model_selection app/main.py +uv run pyright app/features/model_selection app/main.py +uv run pytest app/core/tests/test_strict_mode_policy.py -v +``` + +### Level 2 — Focused unit tests + +```bash +uv run pytest app/features/model_selection/tests -v -m "not integration" +``` + +Required test names: + +- 
`test_schema_accepts_iso_dates_under_strict_model` (JSON path: `Model.model_validate({"start_date":"2026-01-01",...})`) +- `test_schema_rejects_auto_predict_without_train_winner` +- `test_schema_rejects_horizon_mismatch_between_split_and_forecast` +- `test_rank_candidates_wape_smape_abs_bias_mae_tie_break` +- `test_rank_candidates_excludes_missing_or_nan_metrics` +- `test_rank_candidates_normalizes_five_metric_keys_including_rmse` +- `test_confidence_high_when_winner_beats_second_by_10_percent` +- `test_availability_ready_limited_unusable_thresholds` +- `test_build_model_config_flattens_params` (e.g. seasonal_naive + {"season_length":7}) +- `test_run_selection_partial_success_chooses_valid_winner` +- `test_run_selection_all_candidates_fail_returns_failed_status_not_500` (LOCKED #3) +- `test_run_selection_unusable_availability_raises_bad_request` (LOCKED #2) +- `test_run_selection_auto_train_passes_feature_frame_version_and_groups` +- `test_routes_return_problem_json_on_bad_request` (4-key RFC 7807 body) +- `test_response_uses_recommendation_confidence_key` (NOT `confidence`; maps from `RankingResult.confidence`) +- `test_winner_entry_carries_params_for_rebuild` (`ModelRankEntry.params` / `WinnerSummary.params` preserved) +- `test_chart_data_has_wape_bias_fold_stability_and_winner_actual_vs_predicted` + +### Level 3 — Migration & integration + +```bash +docker compose up -d +uv run alembic upgrade head +uv run pytest app/features/model_selection/tests -v -m integration +uv run alembic downgrade -1 && uv run alembic upgrade head # downgrade/upgrade round-trips cleanly +``` + +Integration expectations: + +- `model_selection_run` exists with the three named indexes. +- `POST /model-selection/run` persists a row; `GET /model-selection/{selection_id}` returns the same id. +- Availability detects an inserted pair with enough history (`ready`) and a too-short pair (`limited`/`unusable`). 
+- Partial failure persists the failed candidate reason and still ranks a valid winner. + +### Level 4 — Full backend gates (must be green before PR) + +```bash +uv run ruff check . && uv run ruff format --check . +uv run mypy app/ && uv run pyright app/ +uv run pytest -v -m "not integration" +uv run pytest -v -m integration +``` + +> Known-local-noise: mypy/pyright report pre-existing `lightgbm`/`xgboost` optional-dep import +> errors in `forecasting/`+`registry/` (untouched here; CI installs the extras). Do not "fix" them. + +### Manual API probe (seeded DB; discover real store/product ids + date window first — IDs are not guaranteed 1-based, see memory `seeder-does-not-reset-id-sequences`) + +```bash +uv run uvicorn app.main:app --port 8123 & +curl -s "http://localhost:8123/model-selection/availability?store_id=5&product_id=8&forecast_horizon=14" | python3 -m json.tool +curl -s -X POST http://localhost:8123/model-selection/run -H "Content-Type: application/json" -d '{ + "store_id": 5, "product_id": 8, + "selection_window": {"start_date": "2026-01-01", "end_date": "2026-05-31"}, + "forecast_horizon": 14, + "split_config": {"strategy":"expanding","n_splits":5,"min_train_size":30,"gap":0,"horizon":14}, + "candidate_models": [ + {"model_type":"naive","params":{}}, + {"model_type":"seasonal_naive","params":{"season_length":7}}, + {"model_type":"moving_average","params":{"window_size":7}}, + {"model_type":"regression","params":{}}, + {"model_type":"prophet_like","params":{}} + ], + "auto_train_winner": false, "auto_predict": false +}' | python3 -m json.tool +``` + +Expected: HTTP 200; response carries `selection_id`, non-empty `ranking`, `winner.model_type`, +`recommendation_confidence`, `chart_data`. + +## Final Validation Checklist + +- [ ] New slice follows `app/features/<slice>/{models,schemas,service,routes,tests}.py`. +- [ ] Router wired in `app/main.py` (import alias + `include_router`, no prefix at include). 
+- [ ] Migration `down_revision` chains to the live head; downgrade drops indexes then table. +- [ ] Request schemas use `ConfigDict(strict=True)` + `Field(strict=False)` for every date field; strict-mode test green. +- [ ] All error responses use project exceptions (`BadRequestError`/`NotFoundError`/`DatabaseError`) → RFC 7807. +- [ ] Ranking + explanation logic is pure and unit-tested; normalizer handles all five metric keys incl. `rmse`. +- [ ] Availability covered for ready/limited/unusable + no-rows. +- [ ] `auto_train_winner` uses direct `ForecastingService.train_model` (db first, feature args keyword-only). +- [ ] `auto_predict` handles feature-aware-reject gracefully (warning, not 500). +- [ ] LOCKED decisions #1–#7 are implemented and tested. +- [ ] No frontend files, no agent mutation surface, no managed-cloud SDK. +- [ ] All four Level-4 gates pass; `gh issue view 353` confirms the referenced issue is open. + +## Anti-Patterns to Avoid + +- Don't implement the React UI; don't rank models in TypeScript — backend owns ranking/confidence. +- Don't use batch item metrics for fold-level chart data (batch has none) — use direct `BacktestingService` with `store_fold_details=True`. +- Don't import sibling feature *services* at module scope — lazy in-method (matches forecasting/BatchService precedent). ORM *models* at module scope is fine (OpsService precedent). +- Don't import `validate_date_range` from analytics — reimplement locally. +- Don't pass the candidate `params` as a nested dict to `ModelConfig` — FLATTEN (`{"model_type":..., **params}`). +- Don't assume four metric keys — there are five (`rmse` included); normalize, never index a raw shape blindly. +- Don't sum `(end_date - start_date)` for promotion days (double-counts overlaps; ignores chain-wide `store_id IS NULL`). +- Don't mutate aliases automatically; don't add an agent tool. +- Don't hide failed candidates — include them with `reason`. +- Don't use an LLM for explanations — deterministic text only.
+- Don't raise on all-candidates-fail (LOCKED #3 → persist failed + return 200); DO raise on unusable availability (LOCKED #2 → 400). +- Don't build SQL with string concatenation; don't weaken strict-mode or leakage tests. + +## Confidence Score + +**9.5/10** for one-pass backend implementation success. The prior draft self-rated 8/10 with +"service signatures must be rechecked at implementation time" as the top risk — that risk is now +**retired**: every `run_backtest` / `train_model` / `predict` signature, the corrected five-key +metric shape, the `ModelConfig` flattening, the strict-mode rule, the migration/JSONB/exception +patterns, and seven previously-ambiguous decisions are verified and locked here. An independent +quality-gate pass confirmed every cited signature/line-number/field-name against live source +("tried to break the cited signatures and could not") and its findings — the full response/ +intermediate contract (`CandidateResult`, `ModelRankEntry`, `RankingResult`, `WinnerSummary`, +`ChartData`, `ModelSelectionRunResponse`, …), the `recommendation_confidence` naming, the +`winner.params` carry-through, the `_response` mapping, and the rel-improvement denominator — are +now specified inline. + +Residual risks: + +- Per-candidate backtest runtime: five models × a multi-fold backtest is synchronous in-process. + On a slow host the `/run` request can be slow (acceptable for a single pair; mirrors + `/backtesting/run`). If it becomes a problem, a future PRP can move it behind the jobs slice. +- `promotion_days` derivation has real edge cases (chain-wide promos, overlapping ranges); the + PRP explicitly permits `null + warning` as a correct fallback. +- `lightgbm`/`xgboost` candidates can `ImportError` when extras are absent — they degrade to a + failed candidate with a reason (verified path), not a 500. 
diff --git a/alembic/env.py b/alembic/env.py index 4ce8f0e1..2cadd971 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -18,6 +18,7 @@ from app.features.data_platform import models as data_platform_models # noqa: F401 from app.features.explainability import models as explainability_models # noqa: F401 from app.features.jobs import models as jobs_models # noqa: F401 +from app.features.model_selection import models as model_selection_models # noqa: F401 from app.features.rag import models as rag_models # noqa: F401 from app.features.registry import models as registry_models # noqa: F401 from app.features.scenarios import models as scenarios_models # noqa: F401 diff --git a/alembic/versions/b667d321603c_create_model_selection_run.py b/alembic/versions/b667d321603c_create_model_selection_run.py new file mode 100644 index 00000000..e3dcaa2a --- /dev/null +++ b/alembic/versions/b667d321603c_create_model_selection_run.py @@ -0,0 +1,129 @@ +"""create_model_selection_run + +Revision ID: b667d321603c +Revises: c1d2e3f40512 +Create Date: 2026-06-01 05:58:51.986105 + +Creates the ``model_selection_run`` table for the Forecast Champion Selector +backend (issue #353). One row per ``POST /model-selection/run`` — an auditable +record of which candidate models competed for a (store, product) pair, over +which window/policy, and which model won. + +JSONB snapshot columns mirror the ``batch_job`` precedent +(``c1d2e3f40512_create_batch_tables``): every flexible payload (candidate +configs, policy, availability, ranking, per-candidate results incl. fold chart +data, winner metrics, forecast summary, business summary) is JSONB so the +eventual UI PRP can add keys without a schema migration. ``candidate_results`` +holds the full per-candidate detail (incl. fold actuals/predictions) so a +``GET`` rebuilds the same ``chart_data`` payload the originating ``/run`` +returned — without it the chart's fold-stability and actual-vs-predicted +overlays could not be reconstructed. 
+""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b667d321603c" +down_revision: str | None = "c1d2e3f40512" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply migration.""" + op.create_table( + "model_selection_run", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("selection_id", sa.String(length=32), nullable=False), + sa.Column("store_id", sa.Integer(), nullable=False), + sa.Column("product_id", sa.Integer(), nullable=False), + sa.Column("start_date", sa.Date(), nullable=False), + sa.Column("end_date", sa.Date(), nullable=False), + sa.Column("forecast_horizon", sa.Integer(), nullable=False), + sa.Column("ranking_metric", sa.String(length=20), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("candidate_models", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("policy_snapshot", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column("availability_snapshot", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("ranking_result", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("candidate_results", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("chart_data", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("winner_model_type", sa.String(length=40), nullable=True), + sa.Column("winner_metrics", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("final_model_path", sa.String(length=512), nullable=True), + sa.Column("forecast_result", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("business_summary", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("error_message", sa.String(length=2000), nullable=True), + 
def downgrade() -> None:
    """Revert migration.

    Drops the six indexes in the reverse of the order ``upgrade`` created
    them, then drops the ``model_selection_run`` table itself.
    """
    table = "model_selection_run"

    # Composite indexes were created with explicit literal names.
    for literal_name in (
        "ix_model_selection_run_status_created",
        "ix_model_selection_run_store_product_created",
    ):
        op.drop_index(literal_name, table_name=table)

    # Single-column indexes were created through ``op.f(...)``; drop them the
    # same way so the names line up with what ``upgrade`` emitted.
    for wrapped_name in (
        "ix_model_selection_run_status",
        "ix_model_selection_run_product_id",
        "ix_model_selection_run_store_id",
        "ix_model_selection_run_selection_id",
    ):
        op.drop_index(op.f(wrapped_name), table_name=table)

    op.drop_table(table)
b/app/features/agents/tests/test_service.py index 09413aa6..759e0284 100644 --- a/app/features/agents/tests/test_service.py +++ b/app/features/agents/tests/test_service.py @@ -3,7 +3,7 @@ import json from collections.abc import AsyncIterator from datetime import UTC, datetime, timedelta -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -1260,7 +1260,7 @@ def test_compact_for_finalizer_strips_verbose_keys_keeps_metrics(self) -> None: } ] - compact = AgentService._compact_for_finalizer(raw) + compact = cast(list[dict[str, Any]], AgentService._compact_for_finalizer(raw)) runs = compact[0]["result"]["runs"] # Identity + metrics survive for BOTH runs (so a ranking sees 18.93). diff --git a/app/features/model_selection/__init__.py b/app/features/model_selection/__init__.py new file mode 100644 index 00000000..01931b63 --- /dev/null +++ b/app/features/model_selection/__init__.py @@ -0,0 +1,9 @@ +"""Forecast Champion Selector — backend vertical slice (issue #353). + +Validates a (store, product) pair's data availability, runs comparable +backtests for a set of candidate forecasting models, deterministically ranks +them, selects a champion with a recommendation confidence, persists an +auditable selection run, and optionally trains/predicts with the winner. + +Backend-only by design — the UI is a deliberate follow-up PRP. +""" diff --git a/app/features/model_selection/explanations.py b/app/features/model_selection/explanations.py new file mode 100644 index 00000000..907b974a --- /dev/null +++ b/app/features/model_selection/explanations.py @@ -0,0 +1,97 @@ +"""Deterministic business-explanation layer for the champion selector (#353). + +Pure functions — NO LLM, NO external call. Translates the numeric ranking + +availability into short, deterministic English a business user can read. The +output dict is persisted into ``model_selection_run.business_summary`` and +echoed on the response. 
def _metric_phrase(metrics: dict[str, float] | None) -> str:
    """One-line plain-English metric summary for a ranked model."""
    if not metrics:
        return "no metrics available"
    return (
        f"WAPE {metrics['wape']:.1f}%, sMAPE {metrics['smape']:.1f}, "
        f"MAE {metrics['mae']:.2f}, bias {metrics['bias']:.2f}"
    )


def _lead_text(winner_wape: float, runner_wape: float, runner_model_type: str) -> str:
    """Describe the winner's WAPE relative to the runner-up, with the right direction."""
    if runner_wape > 0:
        lead = (runner_wape - winner_wape) / runner_wape
        if lead < 0:
            # The winner can trail on WAPE when the run was ranked by another
            # metric; say "higher" instead of printing a negative "lower" lead.
            return f"{-lead:.1%} higher WAPE than the runner-up ({runner_model_type})"
        return f"{lead:.1%} lower WAPE than the runner-up ({runner_model_type})"
    return f"a comparable WAPE to the runner-up ({runner_model_type})"


def explain_winner(
    ranking: RankingResult,
    availability: PairAvailabilityResponse | None,
) -> dict[str, Any]:
    """Build the deterministic ``business_summary`` payload.

    Args:
        ranking: Ranking outcome; ``winner``, ``entries``, ``confidence`` and
            ``reasons`` are read.
        availability: Availability snapshot for the pair, or ``None`` when none
            was computed.

    Returns:
        A dict with the keys ``headline``, ``winner``,
        ``recommendation_confidence``, ``confidence_reasons``, ``comparison``,
        ``data_notes`` and ``caveats``. When there is no winner, ``winner`` and
        ``comparison`` are ``None`` and the headline says no model could be
        recommended.
    """
    caveats = [
        "Backtest accuracy reflects historical fit, not a guarantee of future performance.",
        "Metrics measure correlation with past demand, not causation.",
    ]

    if availability is not None:
        data_notes = [
            f"Observed {availability.observed_days} of "
            f"{availability.expected_calendar_days} calendar days "
            f"({availability.coverage_ratio:.0%} coverage).",
            f"Average daily demand {availability.average_daily_demand:.2f}.",
        ]
        data_notes.extend(availability.warnings)
    else:
        data_notes = ["No availability snapshot was computed."]

    if ranking.winner is None:
        return {
            "headline": "No model could be recommended for this pair.",
            "winner": None,
            "recommendation_confidence": ranking.confidence,
            "confidence_reasons": ranking.reasons,
            "comparison": None,
            "data_notes": data_notes,
            "caveats": caveats,
        }

    winner = ranking.winner
    headline = f"Recommended model: {winner.model_type} ({ranking.confidence} confidence)."

    # The second included entry is treated as the runner-up; this relies on
    # ``ranking.entries`` listing included entries in rank order.
    included = [e for e in ranking.entries if e.included]
    runner_up = included[1] if len(included) > 1 else None
    if runner_up is not None and runner_up.metrics and winner.metrics:
        comparison: dict[str, Any] = {
            "runner_up_model_type": runner_up.model_type,
            "runner_up_summary": _metric_phrase(runner_up.metrics),
            "lead_text": _lead_text(
                winner.metrics["wape"], runner_up.metrics["wape"], runner_up.model_type
            ),
        }
    else:
        comparison = {
            "runner_up_model_type": None,
            "runner_up_summary": None,
            "lead_text": "no runner-up was available for comparison",
        }

    return {
        "headline": headline,
        "winner": {
            "model_type": winner.model_type,
            "summary": _metric_phrase(winner.metrics),
        },
        "recommendation_confidence": ranking.confidence,
        "confidence_reasons": ranking.reasons,
        "comparison": comparison,
        "data_notes": data_notes,
        "caveats": caveats,
    }
class ModelSelectionStatus(str, Enum):
    """Lifecycle states of a selection run.

    Transitions:
    - PENDING -> RUNNING -> {COMPLETED, PARTIAL, FAILED}
    - PARTIAL fires when >=1 candidate succeeded AND >=1 candidate failed.
    - FAILED fires when availability is unusable (fail-fast) OR every
      candidate's backtest errored (no valid winner).
    """

    # The ``str`` mixin makes every member an instance of ``str``.
    # The five values match the allow-list in the
    # ``ck_model_selection_run_valid_status`` CHECK constraint; keep the two in
    # sync when a state is added or renamed.
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    PARTIAL = "partial"
    FAILED = "failed"
+ """ + + __tablename__ = "model_selection_run" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + selection_id: Mapped[str] = mapped_column(String(32), unique=True, index=True) + store_id: Mapped[int] = mapped_column(Integer, index=True) + product_id: Mapped[int] = mapped_column(Integer, index=True) + start_date: Mapped[_dt.date] = mapped_column(Date) + end_date: Mapped[_dt.date] = mapped_column(Date) + forecast_horizon: Mapped[int] = mapped_column(Integer) + ranking_metric: Mapped[str] = mapped_column(String(20)) + status: Mapped[str] = mapped_column( + String(20), default=ModelSelectionStatus.PENDING.value, index=True + ) + candidate_models: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False) + policy_snapshot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + availability_snapshot: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + ranking_result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + candidate_results: Mapped[list[dict[str, Any]] | None] = mapped_column(JSONB, nullable=True) + chart_data: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + winner_model_type: Mapped[str | None] = mapped_column(String(40), nullable=True) + winner_metrics: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + final_model_path: Mapped[str | None] = mapped_column(String(512), nullable=True) + forecast_result: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + business_summary: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True) + error_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + completed_at: Mapped[_dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + __table_args__ = ( + CheckConstraint( + "status IN ('pending', 'running', 'completed', 'partial', 'failed')", + name="ck_model_selection_run_valid_status", + ), + Index( + 
"ix_model_selection_run_store_product_created", + "store_id", + "product_id", + "created_at", + ), + Index("ix_model_selection_run_status_created", "status", "created_at"), + ) diff --git a/app/features/model_selection/ranking.py b/app/features/model_selection/ranking.py new file mode 100644 index 00000000..ecca7587 --- /dev/null +++ b/app/features/model_selection/ranking.py @@ -0,0 +1,283 @@ +"""Pure ranking + confidence logic for the champion selector (issue #353). + +No DB, no I/O — every function here is deterministic and unit-tested directly. +The ranking key and confidence policy implement the PRP's LOCKED decision #6 +(deterministic tie-break chain) and the relative-improvement confidence model. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + +from app.features.model_selection.schemas import ( + CandidateResult, + ChartData, + ConfidenceLevel, + FoldChart, + ModelRankEntry, + RankingPolicy, + RankingResult, +) + +# Below this relative WAPE lead over second place, the winner is a near-tie and +# confidence is capped at LOW (the lead is not meaningful). +NEAR_TIE_EPSILON = 0.02 + +# The metric keys that MUST be finite for a candidate to be rankable. rmse is +# carried for the contract but not required (it never enters the sort key). 
@dataclass(frozen=True)
class NormalizedMetrics:
    """Immutable bundle of the five backtest metrics and the sample size."""

    wape: float
    smape: float
    mae: float
    rmse: float
    bias: float
    sample_size: int

    def as_dict(self) -> dict[str, float]:
        """Return the metrics as a 6-key dict in a fixed key order.

        The five metric values are passed through unchanged; ``sample_size`` is
        converted to ``float`` so every value has the same type.
        """
        payload: dict[str, float] = {
            name: getattr(self, name) for name in ("wape", "smape", "mae", "rmse", "bias")
        }
        payload["sample_size"] = float(self.sample_size)
        return payload
def rank_candidates(
    results: list[CandidateResult],
    policy: RankingPolicy,
    ranking_metric: str = "wape",
    availability_status: str | None = None,
) -> RankingResult:
    """Order the rankable candidates and select the top one as winner.

    A candidate is set aside (never dropped) when it failed, when its metrics
    cannot be normalized, or when its sample size is below
    ``policy.minimum_sample_size``. Set-aside candidates are appended after the
    ranked entries, each carrying the reason it was excluded.

    Args:
        results: Per-candidate backtest outcomes.
        policy: Thresholds for the minimum sample size and the confidence rules.
        ranking_metric: Name of the primary metric to sort by.
        availability_status: Availability status, forwarded to the confidence
            derivation.

    Returns:
        The ranking with winner, entries, confidence and reasons. When nothing
        is rankable the winner is ``None`` and confidence is ``"low"``.
    """

    def _assess(candidate: CandidateResult) -> NormalizedMetrics | str:
        # Returns usable metrics, or the exclusion reason as a string.
        if candidate.failed:
            return candidate.error or "candidate backtest failed"
        normalized = normalize_metrics(candidate.aggregated_metrics, candidate.sample_size)
        if normalized is None:
            return "missing or non-finite primary metric"
        if normalized.sample_size < policy.minimum_sample_size:
            return f"sample_size {normalized.sample_size} below minimum {policy.minimum_sample_size}"
        return normalized

    rankable: list[tuple[CandidateResult, NormalizedMetrics]] = []
    set_aside: list[ModelRankEntry] = []
    for candidate in results:
        outcome = _assess(candidate)
        if isinstance(outcome, str):
            set_aside.append(_excluded_entry(candidate, outcome))
        else:
            rankable.append((candidate, outcome))

    if not rankable:
        return RankingResult(
            winner=None,
            entries=set_aside,
            confidence="low",
            reasons=["No candidate produced a valid backtest."],
        )

    # list.sort is stable, exactly like sorted(); ties keep input order.
    rankable.sort(key=lambda pair: _sort_key(pair[1], pair[0].model_type, ranking_metric))

    ranked: list[ModelRankEntry] = []
    for position, (candidate, normalized) in enumerate(rankable, start=1):
        ranked.append(
            ModelRankEntry(
                rank=position,
                model_type=candidate.model_type,
                params=candidate.params,
                included=True,
                metrics=normalized.as_dict(),
            )
        )

    confidence, reasons = _confidence(rankable, policy, availability_status)

    return RankingResult(
        winner=ranked[0],
        entries=[*ranked, *set_aside],
        confidence=confidence,
        reasons=reasons,
    )
+ ) + return "low", reasons + if rel_improvement < NEAR_TIE_EPSILON: + reasons.append(f"Winner WAPE lead over second place is {rel_improvement:.1%} — a near tie.") + return "low", reasons + if rel_improvement >= policy.high_confidence_rel_improvement: + reasons.append( + f"Winner WAPE beats second place by {rel_improvement:.1%} " + f"(>= {policy.high_confidence_rel_improvement:.0%})." + ) + return "high", reasons + + reasons.append( + f"Winner leads second place by {rel_improvement:.1%}, below the " + f"{policy.high_confidence_rel_improvement:.0%} high-confidence threshold." + ) + return "medium", reasons + + +def _fold_wape(actuals: list[float], predictions: list[float]) -> float: + """WAPE (%) for one fold; 0.0 when the actual window sums to zero.""" + denominator = sum(abs(a) for a in actuals) + if denominator == 0: + return 0.0 + numerator = sum(abs(a - p) for a, p in zip(actuals, predictions, strict=False)) + return numerator / denominator * 100.0 + + +def build_chart_data(results: list[CandidateResult], ranking: RankingResult) -> ChartData: + """Assemble the chart-ready comparison payload from candidate results. + + Keyed by ``model_type``; when a candidate list repeats a model_type the last + occurrence wins (acceptable for v1 — duplicate model_types are uncommon). 
+ """ + by_type: dict[str, CandidateResult] = {r.model_type: r for r in results} + wape_by_model: dict[str, float] = {} + bias_by_model: dict[str, float] = {} + fold_stability: dict[str, list[float]] = {} + + for entry in ranking.entries: + if not entry.included or entry.metrics is None: + continue + wape_by_model[entry.model_type] = entry.metrics["wape"] + bias_by_model[entry.model_type] = entry.metrics["bias"] + result = by_type.get(entry.model_type) + if result is not None: + fold_stability[entry.model_type] = [ + _fold_wape(fold.actuals, fold.predictions) for fold in result.folds + ] + + winner_folds: list[FoldChart] = [] + if ranking.winner is not None: + winner_result = by_type.get(ranking.winner.model_type) + if winner_result is not None: + winner_folds = winner_result.folds + + return ChartData( + wape_by_model=wape_by_model, + bias_by_model=bias_by_model, + fold_stability=fold_stability, + winner_actual_vs_predicted=winner_folds, + ) diff --git a/app/features/model_selection/routes.py b/app/features/model_selection/routes.py new file mode 100644 index 00000000..f989aac0 --- /dev/null +++ b/app/features/model_selection/routes.py @@ -0,0 +1,174 @@ +"""FastAPI routes for the Forecast Champion Selector slice (issue #353). + +Endpoints (all under ``/model-selection``): +- GET /availability — pair data-availability assessment +- POST /run — run candidate comparison + ranking (200) +- GET /{selection_id} — fetch a persisted selection run +- GET /{selection_id}/ranking — fetch just the ranking block +- POST /{selection_id}/train-winner — train the winning model +- POST /{selection_id}/predict — forecast with the trained winner + +Error mapping mirrors ``app/features/backtesting/routes.py``: ``ValueError`` → +``BadRequestError`` (RFC 7807 400), ``SQLAlchemyError`` → ``DatabaseError`` (500). +``NotFoundError`` / ``BadRequestError`` raised inside the service are +``ForecastLabError`` subclasses and bubble straight to the global handler. 
@router.get(
    "/availability",
    response_model=PairAvailabilityResponse,
    status_code=status.HTTP_200_OK,
    summary="Assess data availability for a (store, product) pair",
)
async def get_availability(
    store_id: int = Query(..., ge=1, description="Store ID"),
    product_id: int = Query(..., ge=1, description="Product ID"),
    forecast_horizon: int = Query(14, ge=1, le=90, description="Forecast horizon in days"),
    db: AsyncSession = Depends(get_db),
) -> PairAvailabilityResponse:
    """Assess how much usable history exists for one (store, product) pair.

    Delegates to ``ModelSelectionService.get_availability``. A ``ValueError``
    from the service is re-raised as ``BadRequestError``; a ``SQLAlchemyError``
    is re-raised as ``DatabaseError``.
    """
    selector = ModelSelectionService()
    try:
        availability = await selector.get_availability(
            db, store_id, product_id, forecast_horizon
        )
    except ValueError as bad_input:
        raise BadRequestError(message=str(bad_input)) from bad_input
    except SQLAlchemyError as db_failure:
        raise DatabaseError(
            message="Failed to assess availability", details={"error": str(db_failure)}
        ) from db_failure
    return availability
@router.get(
    "/{selection_id}",
    response_model=ModelSelectionRunResponse,
    status_code=status.HTTP_200_OK,
    summary="Fetch a persisted selection run",
)
async def get_selection(
    selection_id: str,
    db: AsyncSession = Depends(get_db),
) -> ModelSelectionRunResponse:
    """Look up one persisted selection run by its ``selection_id``.

    Delegates to ``ModelSelectionService.get_selection``. Only
    ``SQLAlchemyError`` is translated here (to ``DatabaseError``); any other
    exception from the service, such as its not-found error, propagates
    unchanged.
    """
    selector = ModelSelectionService()
    try:
        persisted_run = await selector.get_selection(db, selection_id)
    except SQLAlchemyError as db_failure:
        raise DatabaseError(
            message="Failed to fetch selection run", details={"error": str(db_failure)}
        ) from db_failure
    return persisted_run
@router.post(
    "/{selection_id}/predict",
    response_model=PredictWinnerResponse,
    status_code=status.HTTP_200_OK,
    summary="Forecast with the trained winning model",
)
async def predict_winner(
    selection_id: str,
    db: AsyncSession = Depends(get_db),
) -> PredictWinnerResponse:
    """Generate a horizon forecast from the trained champion bundle."""
    selector = ModelSelectionService()
    try:
        # The response is built inside the ``try`` so that a ValueError raised
        # while constructing it is translated exactly like one from the service.
        summary = await selector.predict_winner(db, selection_id)
        payload = PredictWinnerResponse(selection_id=selection_id, forecast=summary)
    except ValueError as err:
        raise BadRequestError(message=str(err)) from err
    except SQLAlchemyError as err:
        failure = {"error": str(err)}
        raise DatabaseError(
            message="Failed to forecast with winning model", details=failure
        ) from err
    else:
        return payload
+ +Enum-like string fields use ``Literal[...]`` (NOT a ``str``-``Enum``) because +strict mode refuses to coerce a JSON string into a str-enum instance — the same +reason ``app/features/batch/schemas.py`` uses literals. + +Response/intermediate models are plain ``BaseModel`` (outputs need no strict +coercion). They form the stable backend contract the eventual UI consumes. + +``SplitConfig`` is reused directly from the backtesting slice (a schema type +with no import cycle back to this slice) to avoid configuration drift. +""" + +from __future__ import annotations + +from datetime import date, datetime +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from app.features.backtesting.schemas import SplitConfig + +# Valid forecasting model_type values — the full ``ModelConfig`` union +# (``app/features/forecasting/schemas.py``). ``lightgbm``/``xgboost`` are opt-in +# extras and may degrade to a failed candidate at runtime when the extra is +# absent (handled in the service, not rejected here). 
class SelectionWindow(BaseModel):
    """Inclusive date window the candidate backtests run over."""

    # Strict mode is on for the model as a whole; the two date fields opt out
    # with ``strict=False`` so that string input can still be validated into
    # ``date`` values.
    model_config = ConfigDict(strict=True)

    start_date: date = Field(strict=False, description="Window start (inclusive), YYYY-MM-DD")
    end_date: date = Field(strict=False, description="Window end (inclusive), YYYY-MM-DD")

    @model_validator(mode="after")
    def _check_order(self) -> SelectionWindow:
        """Reject an inverted/zero-length window (surfaced as RFC 7807 422).

        Runs after field validation. The comparison is ``<=``, so a window
        whose ``end_date`` equals its ``start_date`` is rejected as well as an
        inverted one: ``end_date`` must be strictly later than ``start_date``
        even though both bounds are described as inclusive.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If ``end_date`` is on or before ``start_date``.
        """
        if self.end_date <= self.start_date:
            raise ValueError("end_date must be after start_date")
        return self
class ModelSelectionRunRequest(BaseModel):
    """``POST /model-selection/run`` request body."""

    model_config = ConfigDict(strict=True)

    # Identity of the (store, product) pair; both ids must be >= 1.
    store_id: int = Field(..., ge=1, description="Store ID")
    product_id: int = Field(..., ge=1, description="Product ID")
    # Date range passed to every candidate backtest.
    selection_window: SelectionWindow
    forecast_horizon: int = Field(..., ge=1, le=90, description="Forecast horizon in days")
    ranking_metric: RankingMetric = "wape"
    # NOTE(review): when ``split_config`` is omitted the default ``SplitConfig()``
    # is used, and ``_check_consistency`` below still requires its ``horizon`` to
    # equal ``forecast_horizon``. Confirm what ``SplitConfig``'s default horizon
    # is; any other ``forecast_horizon`` would then need an explicit split_config.
    split_config: SplitConfig = Field(default_factory=SplitConfig)
    # Between 1 and 10 candidates per run.
    candidate_models: list[CandidateModelConfig] = Field(min_length=1, max_length=10)
    # Only versions 1 and 2 are accepted; ``feature_groups`` must stay unset
    # for version 1 (enforced in ``_check_consistency``).
    feature_frame_version: int = Field(default=1, ge=1, le=2)
    feature_groups: list[str] | None = Field(default=None)
    ranking_policy: RankingPolicy = Field(default_factory=RankingPolicy)
    # Optional follow-up steps; ``auto_predict`` is only valid together with
    # ``auto_train_winner`` (enforced in ``_check_consistency``).
    auto_train_winner: bool = Field(default=False)
    auto_predict: bool = Field(default=False)

    @model_validator(mode="after")
    def _check_consistency(self) -> ModelSelectionRunRequest:
        """Enforce LOCKED decisions #5 and #7 plus V1/feature-group consistency.

        Three cross-field rules are checked, in this order:

        1. ``split_config.horizon`` must equal ``forecast_horizon``.
        2. ``auto_predict`` requires ``auto_train_winner``.
        3. ``feature_groups`` may only be supplied when
           ``feature_frame_version`` is 2.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: On the first rule that is violated.
        """
        if self.split_config.horizon != self.forecast_horizon:
            raise ValueError(
                f"split_config.horizon ({self.split_config.horizon}) must equal "
                f"forecast_horizon ({self.forecast_horizon})"
            )
        if self.auto_predict and not self.auto_train_winner:
            raise ValueError("auto_predict requires auto_train_winner=True")
        if self.feature_frame_version == 1 and self.feature_groups is not None:
            raise ValueError(
                "feature_groups is only valid when feature_frame_version=2; "
                "omit it for V1 selection."
            )
        return self
class ModelSelectionRunResponse(BaseModel):
    """``POST /model-selection/run`` and ``GET /model-selection/{id}`` contract."""

    # The comments below describe how ``ModelSelectionService._response``
    # fills each field from a persisted run row and its ranking.

    selection_id: str
    store_id: int
    product_id: int
    status: SelectionStatusLiteral
    # Rebuilt from the row's ``start_date`` / ``end_date`` columns.
    selection_window: SelectionWindow
    forecast_horizon: int
    ranking_metric: str
    # ``None`` when the row carries no availability snapshot.
    availability: PairAvailabilityResponse | None
    # The ranking's ``entries`` list.
    ranking: list[ModelRankEntry]
    # Populated only when the ranking has a winner AND the run status is
    # "completed" or "partial"; ``None`` otherwise.
    winner: WinnerSummary | None
    # ``None`` when the ranking has neither entries nor a winner.
    recommendation_confidence: ConfidenceLevel | None
    # The ranking's ``reasons`` list.
    confidence_reasons: list[str]
    # ``None`` when the row has no stored chart data.
    chart_data: ChartData | None
    # ``{"model_path": <path>}`` once a winner has been trained, else ``None``.
    final_model: dict[str, Any] | None
    # ``None`` until a forecast result has been stored on the row.
    forecast: ForecastSummary | None
    business_summary: dict[str, Any] | None
    error_message: str | None
    created_at: datetime
    completed_at: datetime | None
+""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from sqlalchemy import and_, func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exceptions import BadRequestError, NotFoundError +from app.core.logging import get_logger +from app.features.backtesting.schemas import SplitConfig +from app.features.data_platform.models import Product, Promotion, SalesDaily, Store +from app.features.model_selection.explanations import explain_winner +from app.features.model_selection.models import ModelSelectionRun, ModelSelectionStatus +from app.features.model_selection.ranking import build_chart_data, rank_candidates +from app.features.model_selection.schemas import ( + AvailabilityStatus, + CandidateModelConfig, + CandidateResult, + ChartData, + FoldChart, + ForecastSummary, + ModelSelectionRunRequest, + ModelSelectionRunResponse, + PairAvailabilityResponse, + RankingResult, + SelectionWindow, + TrainWinnerResponse, + WinnerSummary, +) + +if TYPE_CHECKING: + from app.features.backtesting.schemas import BacktestResponse + from app.features.forecasting.schemas import PredictResponse + +logger = get_logger(__name__) + +# Availability policy constants (module-level; not operator-configurable in v1). 
+MIN_COVERAGE_RATIO = 0.8 +DEFAULT_MIN_TRAIN_SIZE = 30 +MAX_RECOMMENDED_SPLITS = 5 + +_TERMINAL_WITH_WINNER = frozenset( + {ModelSelectionStatus.COMPLETED.value, ModelSelectionStatus.PARTIAL.value} +) + + +class ModelSelectionService: + """Stateless orchestrator — a fresh ``db`` session per method.""" + + # ------------------------------------------------------------------------- + # Availability + # ------------------------------------------------------------------------- + + async def get_availability( + self, + db: AsyncSession, + store_id: int, + product_id: int, + forecast_horizon: int, + split_config: SplitConfig | None = None, + ) -> PairAvailabilityResponse: + """Assess whether a (store, product) pair has enough history to model.""" + store = await db.get(Store, store_id) + if store is None: + raise NotFoundError(message=f"Store {store_id} not found") + product = await db.get(Product, product_id) + if product is None: + raise NotFoundError(message=f"Product {product_id} not found") + + n_splits = split_config.n_splits if split_config else MAX_RECOMMENDED_SPLITS + min_train = split_config.min_train_size if split_config else DEFAULT_MIN_TRAIN_SIZE + + agg = ( + await db.execute( + select( + func.min(SalesDaily.date), + func.max(SalesDaily.date), + func.count(func.distinct(SalesDaily.date)), + func.avg(SalesDaily.quantity), + func.count().filter(SalesDaily.quantity == 0), + ).where( + SalesDaily.store_id == store_id, + SalesDaily.product_id == product_id, + ) + ) + ).one() + first_date, last_date, observed_raw, avg_qty, zero_raw = agg + observed_days = int(observed_raw or 0) + zero_sale_days = int(zero_raw or 0) + average_daily_demand = float(avg_qty) if avg_qty is not None else 0.0 + + warnings: list[str] = [] + + if first_date is None or last_date is None or observed_days == 0: + expected_calendar_days = 0 + coverage_ratio = 0.0 + missing_days = 0 + promotion_days: int | None = 0 + else: + expected_calendar_days = (last_date - first_date).days + 1 + 
async def _count_promotion_days(
    self,
    db: AsyncSession,
    store_id: int,
    product_id: int,
    warnings: list[str],
) -> int | None:
    """Count distinct sales dates inside any promotion for the pair.

    Sales rows of the pair are joined to promotions of the same product that
    are either bound to the same store or chain-wide
    (``promotion.store_id IS NULL``) and whose ``[start_date, end_date]``
    range contains the sales date. Counting DISTINCT sales dates — rather
    than summing ``(end - start)`` per promotion — keeps overlapping
    promotion ranges from being double-counted.

    Args:
        db: Session used to run the count query.
        store_id: Store whose sales rows are inspected.
        product_id: Product whose sales rows are inspected.
        warnings: Mutable list owned by the caller; a human-readable message
            is appended to it when the count cannot be derived.

    Returns:
        The number of distinct promoted sales dates (``0`` when the query
        yields nothing), or ``None`` when any error occurred.
    """
    try:
        count = await db.scalar(
            select(func.count(func.distinct(SalesDaily.date)))
            .select_from(SalesDaily)
            .join(
                Promotion,
                and_(
                    Promotion.product_id == SalesDaily.product_id,
                    or_(
                        Promotion.store_id == SalesDaily.store_id,
                        Promotion.store_id.is_(None),
                    ),
                    SalesDaily.date >= Promotion.start_date,
                    SalesDaily.date <= Promotion.end_date,
                ),
            )
            .where(
                SalesDaily.store_id == store_id,
                SalesDaily.product_id == product_id,
            )
        )
        return int(count or 0)
    except Exception as exc:  # promotion_days is best-effort; degrade gracefully
        warnings.append(f"promotion_days could not be derived: {exc}")
        # Previously the failure surfaced only inside the API response's
        # warning list; log it too so operators can see the degradation.
        # NOTE(review): if the failure came from the database, the enclosing
        # transaction may be left in an aborted state and later flushes on the
        # same session could fail — consider running the query in a SAVEPOINT.
        logger.warning(
            "model_selection.promotion_days_unavailable",
            store_id=store_id,
            product_id=product_id,
            error=str(exc),
        )
        return None
policy_snapshot=request.ranking_policy.model_dump(mode="json"), + ) + db.add(row) + await db.flush() + logger.info( + "model_selection.run_received", + selection_id=row.selection_id, + store_id=request.store_id, + product_id=request.product_id, + n_candidates=len(request.candidate_models), + ) + + availability = await self.get_availability( + db, + request.store_id, + request.product_id, + request.forecast_horizon, + request.split_config, + ) + row.availability_snapshot = availability.model_dump(mode="json") + logger.info( + "model_selection.availability_checked", + selection_id=row.selection_id, + status=availability.status, + observed_days=availability.observed_days, + ) + + if availability.status == "unusable": # LOCKED #2 — fail fast (400) + message = "Insufficient data for model selection (availability unusable)." + row.status = ModelSelectionStatus.FAILED.value + row.error_message = message + await db.flush() + logger.warning( + "model_selection.run_failed", + selection_id=row.selection_id, + reason="unusable_availability", + ) + raise BadRequestError(message=message) + + results: list[CandidateResult] = [] + backtesting_service = BacktestingService() + for candidate in request.candidate_models: + try: + cfg = adapter.validate_python( + {"model_type": candidate.model_type, **candidate.params} + ) + backtest = await backtesting_service.run_backtest( + db, + request.store_id, + request.product_id, + request.selection_window.start_date, + request.selection_window.end_date, + BacktestConfig( + split_config=request.split_config, + model_config_main=cfg, # type: ignore[arg-type] + include_baselines=False, + store_fold_details=True, + ), + ) + results.append(self._shape_candidate(candidate, backtest)) + logger.info( + "model_selection.candidate_completed", + selection_id=row.selection_id, + model_type=candidate.model_type, + ) + except Exception as exc: # never hide a failed candidate + results.append(self._shape_failed_candidate(candidate, exc)) + logger.warning( + 
"model_selection.candidate_failed", + selection_id=row.selection_id, + model_type=candidate.model_type, + error=str(exc), + ) + + row.candidate_results = [r.model_dump(mode="json") for r in results] + ranking = rank_candidates( + results, request.ranking_policy, request.ranking_metric, availability.status + ) + row.ranking_result = ranking.model_dump(mode="json") + + if ranking.winner is None: # LOCKED #3 — persist failed, return 200 + row.status = ModelSelectionStatus.FAILED.value + row.error_message = "No candidate produced a valid backtest." + row.business_summary = explain_winner(ranking, availability) + row.completed_at = datetime.now(UTC) + await db.flush() + await db.refresh(row) + logger.warning( + "model_selection.run_failed", + selection_id=row.selection_id, + reason="no_valid_winner", + ) + return self._response(row, ranking) + + winner_cfg = adapter.validate_python( + {"model_type": ranking.winner.model_type, **ranking.winner.params} + ) + + if request.auto_train_winner: + from app.features.forecasting.service import ForecastingService # lazy + + train = await ForecastingService().train_model( + db, + request.store_id, + request.product_id, + request.selection_window.start_date, + request.selection_window.end_date, + winner_cfg, # type: ignore[arg-type] + feature_frame_version=request.feature_frame_version, + feature_groups=request.feature_groups, + ) + row.final_model_path = train.model_path + + forecast_warning: str | None = None + if request.auto_predict and row.final_model_path: + from app.features.forecasting.service import ForecastingService # lazy + + try: + prediction = await ForecastingService().predict( + request.store_id, + request.product_id, + request.forecast_horizon, + row.final_model_path, + ) + row.forecast_result = self._forecast_summary( + prediction, request.forecast_horizon + ).model_dump(mode="json") + except Exception as exc: # e.g. 
async def train_winner(self, db: AsyncSession, selection_id: str) -> TrainWinnerResponse:
    """Train the champion of a persisted selection run and record its bundle path.

    The run is loaded by id, the winner is read from the stored ranking, the
    forecasting ``ModelConfig`` is rebuilt from the winner's model type and
    flat params, and the forecasting service trains it over the run's
    original date window. The resulting model path is written to the row.

    Raises:
        BadRequestError: If the stored ranking has no winner.
    """
    # Lazy imports, kept inside the method like the other cross-slice imports
    # in this service.
    from pydantic import TypeAdapter  # lazy

    from app.features.forecasting.schemas import ModelConfig  # lazy
    from app.features.forecasting.service import ForecastingService  # lazy

    run = await self._load(db, selection_id)
    champion = self._load_ranking(run).winner
    if champion is None:
        raise BadRequestError(message="Selection has no winning model to train.")

    config_adapter: TypeAdapter[object] = TypeAdapter(ModelConfig)
    # Flat params are merged after ``model_type``, mirroring how candidates
    # are turned into configs when a selection is run.
    raw_config = {"model_type": champion.model_type, **champion.params}
    model_config = config_adapter.validate_python(raw_config)

    trained = await ForecastingService().train_model(
        db,
        run.store_id,
        run.product_id,
        run.start_date,
        run.end_date,
        model_config,  # type: ignore[arg-type]
    )

    run.final_model_path = trained.model_path
    await db.flush()
    logger.info(
        "model_selection.winner_trained",
        selection_id=run.selection_id,
        model_type=champion.model_type,
    )
    return TrainWinnerResponse(
        selection_id=run.selection_id,
        model_type=champion.model_type,
        model_path=trained.model_path,
    )
def _forecast_summary(self, prediction: PredictResponse, horizon: int) -> ForecastSummary:
    """Roll a prediction response up into a ``ForecastSummary``.

    Each forecast point is dumped in JSON mode, the total demand is the sum
    of the points' ``forecast`` values, and the average is that total divided
    by the number of points (``0.0`` when there are none). ``horizon`` is
    passed through untouched.
    """
    forecasts = prediction.forecasts
    total_demand = float(sum(item.forecast for item in forecasts))
    if forecasts:
        average_demand = total_demand / len(forecasts)
    else:
        # No points: avoid dividing by zero.
        average_demand = 0.0
    return ForecastSummary(
        points=[item.model_dump(mode="json") for item in forecasts],
        total_demand=total_demand,
        average_demand=average_demand,
        horizon=horizon,
    )
def _response(
    self, row: ModelSelectionRun, ranking: RankingResult
) -> ModelSelectionRunResponse:
    """Map a persisted run row plus its ranking onto the API response model.

    Stored snapshots (availability, chart data, forecast) are re-validated
    into their schema types when present and left as ``None`` otherwise. The
    winner is exposed only for runs whose status is in
    ``_TERMINAL_WITH_WINNER``.
    """
    availability: PairAvailabilityResponse | None = None
    if row.availability_snapshot:
        availability = PairAvailabilityResponse.model_validate(row.availability_snapshot)

    chart_data: ChartData | None = None
    if row.chart_data:
        chart_data = ChartData.model_validate(row.chart_data)

    forecast: ForecastSummary | None = None
    if row.forecast_result:
        forecast = ForecastSummary.model_validate(row.forecast_result)

    winner: WinnerSummary | None = None
    champion = ranking.winner
    if champion is not None and row.status in _TERMINAL_WITH_WINNER:
        winner = WinnerSummary(
            model_type=champion.model_type,
            params=champion.params,
            metrics=champion.metrics or {},
            rank=1,
        )

    # Confidence is only reported when the ranking actually contains something.
    confidence = None
    if ranking.entries or ranking.winner:
        confidence = ranking.confidence

    final_model = None
    if row.final_model_path:
        final_model = {"model_path": row.final_model_path}

    window = SelectionWindow(start_date=row.start_date, end_date=row.end_date)
    return ModelSelectionRunResponse(
        selection_id=row.selection_id,
        store_id=row.store_id,
        product_id=row.product_id,
        status=row.status,  # type: ignore[arg-type]
        selection_window=window,
        forecast_horizon=row.forecast_horizon,
        ranking_metric=row.ranking_metric,
        availability=availability,
        ranking=ranking.entries,
        winner=winner,
        recommendation_confidence=confidence,
        confidence_reasons=ranking.reasons,
        chart_data=chart_data,
        final_model=final_model,
        forecast=forecast,
        business_summary=row.business_summary,
        error_message=row.error_message,
        created_at=row.created_at,
        completed_at=row.completed_at,
    )
+TEST_START = date(2024, 1, 1) + + +# ============================================================================= +# Unit factories +# ============================================================================= + + +def make_candidate_result( + model_type: str, + *, + wape: float = 20.0, + smape: float = 15.0, + mae: float = 5.0, + rmse: float = 6.0, + bias: float = 0.5, + sample_size: int = 28, + n_folds: int = 2, + points_per_fold: int = 14, + params: dict[str, Any] | None = None, + failed: bool = False, + error: str | None = None, + aggregated_metrics: dict[str, float] | None = None, +) -> CandidateResult: + """Build a ``CandidateResult`` for ranking/chart unit tests.""" + if failed: + return CandidateResult( + model_type=model_type, + params=params or {}, + failed=True, + error=error or "boom", + aggregated_metrics=None, + sample_size=0, + folds=[], + ) + folds = [ + FoldChart( + fold_index=i, + dates=[ + TEST_START + timedelta(days=i * points_per_fold + j) for j in range(points_per_fold) + ], + actuals=[10.0 + j for j in range(points_per_fold)], + predictions=[10.5 + j for j in range(points_per_fold)], + ) + for i in range(n_folds) + ] + metrics = aggregated_metrics or { + "mae": mae, + "rmse": rmse, + "smape": smape, + "wape": wape, + "bias": bias, + } + return CandidateResult( + model_type=model_type, + params=params or {}, + failed=False, + aggregated_metrics=metrics, + sample_size=sample_size, + config_hash="cafef00d", + folds=folds, + ) + + +def make_backtest_response( + *, + wape: float = 20.0, + smape: float = 15.0, + mae: float = 5.0, + rmse: float = 6.0, + bias: float = 0.5, + n_folds: int = 2, + points_per_fold: int = 14, +) -> SimpleNamespace: + """A duck-typed stand-in for ``BacktestResponse`` (what _shape_candidate reads).""" + folds = [ + SimpleNamespace( + fold_index=i, + dates=[ + TEST_START + timedelta(days=i * points_per_fold + j) for j in range(points_per_fold) + ], + actuals=[10.0 + j for j in range(points_per_fold)], + 
predictions=[10.5 + j for j in range(points_per_fold)], + ) + for i in range(n_folds) + ] + main = SimpleNamespace( + fold_results=folds, + aggregated_metrics={ + "mae": mae, + "rmse": rmse, + "smape": smape, + "wape": wape, + "bias": bias, + }, + metric_std={}, + ) + return SimpleNamespace(main_model_results=main, config_hash="bt00deadbeef", backtest_id="bt") + + +def make_availability( + *, + status: str = "ready", + store_id: int = 1, + product_id: int = 1, + horizon: int = 14, +) -> PairAvailabilityResponse: + """A ready/limited/unusable availability response for service unit tests.""" + return PairAvailabilityResponse( + store_id=store_id, + product_id=product_id, + first_sales_date=TEST_START, + last_sales_date=TEST_START + timedelta(days=119), + observed_days=120, + expected_calendar_days=120, + coverage_ratio=1.0, + missing_days=0, + zero_sale_days=0, + promotion_days=0, + average_daily_demand=12.0, + status=status, # type: ignore[arg-type] + recommended_split_config=SplitConfig( + strategy="expanding", n_splits=5, min_train_size=30, gap=0, horizon=horizon + ), + warnings=[], + ) + + +def make_mock_db() -> AsyncMock: + """Mock ``AsyncSession`` whose flush stamps ``created_at`` on added rows.""" + db = AsyncMock() + added: list[Any] = [] + + def _add(obj: Any) -> None: + added.append(obj) + + async def _flush() -> None: + for obj in added: + if isinstance(obj, ModelSelectionRun) and obj.created_at is None: + obj.created_at = datetime.now(UTC) + + db.add = MagicMock(side_effect=_add) + db.flush = AsyncMock(side_effect=_flush) + db.refresh = AsyncMock() + return db + + +# ============================================================================= +# Integration fixtures — real Postgres +# ============================================================================= + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Yield an async session; wipe model_selection + test data on teardown.""" + settings = get_settings() + 
engine = create_async_engine(settings.database_url, echo=False) + session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with session_maker() as session: + try: + yield session + finally: + store_ids = _registered_store_ids() + if store_ids: + await session.execute( + delete(ModelSelectionRun).where(ModelSelectionRun.store_id.in_(store_ids)) + ) + await session.execute( + delete(SalesDaily).where(SalesDaily.unit_price == Decimal("3.33")) + ) + await session.execute(delete(Product).where(Product.sku.like("TMSEL-%"))) + await session.execute(delete(Store).where(Store.code.like("TMSEL-%"))) + await session.commit() + + await engine.dispose() + + +# Track store ids created by the seeding fixtures so teardown can scope the +# model_selection_run cleanup precisely. +_SEEDED_STORE_IDS: list[int] = [] + + +def _registered_store_ids() -> list[int]: + return list(_SEEDED_STORE_IDS) + + +@pytest.fixture +async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: + """Test client with the database dependency overridden.""" + + async def override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_db] = override_get_db + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + app.dependency_overrides.pop(get_db, None) + + +async def _seed_pair(db: AsyncSession, n_days: int) -> dict[str, Any]: + """Seed a store/product/calendar + a clean weekly sales series of n_days.""" + suffix = uuid.uuid4().hex[:8] + store = Store(code=f"TMSEL-{suffix}", name="MSel Store", region="R", store_type="x") + product = Product( + sku=f"TMSEL-{suffix}", + name="MSel Product", + category="C", + base_price=Decimal("3.33"), + launch_date=TEST_START, + ) + db.add_all([store, product]) + await db.commit() + await db.refresh(store) + await db.refresh(product) + _SEEDED_STORE_IDS.append(store.id) + + weekly = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0] 
+ for i in range(n_days): + d = TEST_START + timedelta(days=i) + await db.merge( + Calendar( + date=d, + day_of_week=d.weekday(), + month=d.month, + quarter=(d.month - 1) // 3 + 1, + year=d.year, + is_holiday=False, + ) + ) + await db.commit() + + for i in range(n_days): + qty = int(weekly[i % 7]) + db.add( + SalesDaily( + date=TEST_START + timedelta(days=i), + store_id=store.id, + product_id=product.id, + quantity=qty, + unit_price=Decimal("3.33"), + total_amount=Decimal("3.33") * qty, + ) + ) + await db.commit() + return { + "store_id": store.id, + "product_id": product.id, + "start_date": TEST_START.isoformat(), + "end_date": (TEST_START + timedelta(days=n_days - 1)).isoformat(), + "n_days": n_days, + } + + +@pytest.fixture +async def ready_pair(db_session: AsyncSession) -> dict[str, Any]: + """A 120-day pair — ``ready`` for horizon=14, n_splits=5 (threshold 100).""" + return await _seed_pair(db_session, 120) + + +@pytest.fixture +async def limited_pair(db_session: AsyncSession) -> dict[str, Any]: + """A 50-day pair — ``limited`` (>= 44, < 100).""" + return await _seed_pair(db_session, 50) diff --git a/app/features/model_selection/tests/test_explanations.py b/app/features/model_selection/tests/test_explanations.py new file mode 100644 index 00000000..040b8aa3 --- /dev/null +++ b/app/features/model_selection/tests/test_explanations.py @@ -0,0 +1,44 @@ +"""Unit tests for the deterministic business-explanation layer.""" + +from __future__ import annotations + +from app.features.model_selection.explanations import explain_winner +from app.features.model_selection.ranking import rank_candidates +from app.features.model_selection.schemas import RankingPolicy +from app.features.model_selection.tests.conftest import make_availability, make_candidate_result + + +def test_explain_winner_produces_deterministic_summary() -> None: + results = [ + make_candidate_result("winner", wape=10.0), + make_candidate_result("second", wape=20.0), + ] + ranking = rank_candidates(results, 
RankingPolicy(), "wape", availability_status="ready") + summary = explain_winner(ranking, make_availability(status="ready")) + + assert "winner" in summary["headline"] + assert summary["winner"]["model_type"] == "winner" + assert summary["recommendation_confidence"] == ranking.confidence + assert summary["confidence_reasons"] == ranking.reasons + assert summary["comparison"]["runner_up_model_type"] == "second" + assert any("coverage" in note.lower() for note in summary["data_notes"]) + assert summary["caveats"] + + +def test_explain_winner_is_deterministic() -> None: + """Same input → byte-identical output (no LLM, no randomness).""" + results = [ + make_candidate_result("winner", wape=10.0), + make_candidate_result("second", wape=20.0), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape", availability_status="ready") + availability = make_availability(status="ready") + assert explain_winner(ranking, availability) == explain_winner(ranking, availability) + + +def test_explain_winner_handles_no_winner() -> None: + results = [make_candidate_result("x", failed=True, error="boom")] + ranking = rank_candidates(results, RankingPolicy(), "wape") + summary = explain_winner(ranking, make_availability(status="limited")) + assert summary["winner"] is None + assert "No model" in summary["headline"] diff --git a/app/features/model_selection/tests/test_models.py b/app/features/model_selection/tests/test_models.py new file mode 100644 index 00000000..4f69d9e9 --- /dev/null +++ b/app/features/model_selection/tests/test_models.py @@ -0,0 +1,41 @@ +"""Tests for the ModelSelectionRun ORM model + status enum. + +The status CHECK-constraint enforcement is exercised in the integration suite +(it requires the real Postgres CHECK); here we cover the enum values and the +in-Python ORM construction. 
+""" + +from __future__ import annotations + +from datetime import date + +from app.features.model_selection.models import ModelSelectionRun, ModelSelectionStatus + + +def test_status_enum_values() -> None: + assert {s.value for s in ModelSelectionStatus} == { + "pending", + "running", + "completed", + "partial", + "failed", + } + + +def test_model_selection_run_construction_defaults() -> None: + row = ModelSelectionRun( + selection_id="abc123", + store_id=1, + product_id=2, + start_date=date(2026, 1, 1), + end_date=date(2026, 5, 31), + forecast_horizon=14, + ranking_metric="wape", + status=ModelSelectionStatus.RUNNING.value, + candidate_models=[{"model_type": "naive", "params": {}}], + policy_snapshot={"minimum_sample_size": 0}, + ) + assert row.selection_id == "abc123" + assert row.status == "running" + assert row.winner_model_type is None + assert row.final_model_path is None diff --git a/app/features/model_selection/tests/test_ranking.py b/app/features/model_selection/tests/test_ranking.py new file mode 100644 index 00000000..3c01b25a --- /dev/null +++ b/app/features/model_selection/tests/test_ranking.py @@ -0,0 +1,205 @@ +"""Unit tests for the pure ranking + chart logic.""" + +from __future__ import annotations + +import math + +from app.features.model_selection.ranking import ( + build_chart_data, + normalize_metrics, + rank_candidates, +) +from app.features.model_selection.schemas import RankingPolicy +from app.features.model_selection.tests.conftest import make_candidate_result + + +def test_rank_candidates_wape_smape_abs_bias_mae_tie_break() -> None: + """Default sort key is (wape, smape, abs(bias), mae, model_type) (LOCKED #6).""" + # Same wape; B wins on smape; C loses smape but would win mae (irrelevant). 
+ results = [ + make_candidate_result("a_model", wape=10.0, smape=20.0, bias=1.0, mae=9.0), + make_candidate_result("b_model", wape=10.0, smape=15.0, bias=5.0, mae=8.0), + make_candidate_result("c_model", wape=10.0, smape=18.0, bias=0.0, mae=1.0), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape") + order = [e.model_type for e in ranking.entries if e.included] + assert order == ["b_model", "c_model", "a_model"] + assert ranking.winner is not None + assert ranking.winner.model_type == "b_model" + assert ranking.winner.rank == 1 + + +def test_rank_candidates_model_type_breaks_full_tie() -> None: + """Identical metrics fall back to model_type alphabetical for determinism.""" + results = [ + make_candidate_result("zeta", wape=5.0, smape=5.0, bias=0.0, mae=1.0), + make_candidate_result("alpha", wape=5.0, smape=5.0, bias=0.0, mae=1.0), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape") + assert ranking.winner is not None + assert ranking.winner.model_type == "alpha" + + +def test_rank_candidates_non_default_metric_puts_it_first() -> None: + """ranking_metric='mae' ranks by mae first.""" + results = [ + make_candidate_result("high_wape_low_mae", wape=50.0, mae=1.0), + make_candidate_result("low_wape_high_mae", wape=5.0, mae=99.0), + ] + ranking = rank_candidates(results, RankingPolicy(), "mae") + assert ranking.winner is not None + assert ranking.winner.model_type == "high_wape_low_mae" + + +def test_rank_candidates_excludes_missing_or_nan_metrics() -> None: + """A NaN/None primary metric drops the candidate to an excluded entry.""" + good = make_candidate_result("good", wape=10.0) + nan_metrics = make_candidate_result("nan_model", wape=float("nan")) + no_metrics = make_candidate_result("no_metrics", failed=False) + no_metrics.aggregated_metrics = None + ranking = rank_candidates([good, nan_metrics, no_metrics], RankingPolicy(), "wape") + + assert ranking.winner is not None + assert ranking.winner.model_type == "good" + excluded = 
{e.model_type: e for e in ranking.entries if not e.included} + assert set(excluded) == {"nan_model", "no_metrics"} + assert excluded["nan_model"].rank is None + assert excluded["nan_model"].exclusion_reason is not None + + +def test_rank_candidates_normalizes_five_metric_keys_including_rmse() -> None: + """normalize_metrics carries all five keys incl. rmse; entries echo them.""" + metrics = normalize_metrics( + {"mae": 1.0, "rmse": 2.0, "smape": 3.0, "wape": 4.0, "bias": 5.0}, sample_size=20 + ) + assert metrics is not None + assert metrics.rmse == 2.0 + as_dict = metrics.as_dict() + assert set(as_dict) == {"wape", "smape", "mae", "rmse", "bias", "sample_size"} + + ranking = rank_candidates([make_candidate_result("m", rmse=7.5)], RankingPolicy(), "wape") + assert ranking.entries[0].metrics is not None + assert ranking.entries[0].metrics["rmse"] == 7.5 + + +def test_normalize_metrics_rejects_inf_wape() -> None: + """An inf WAPE (all-zero actuals) is unrankable.""" + assert ( + normalize_metrics( + {"mae": 1.0, "rmse": 2.0, "smape": 3.0, "wape": math.inf, "bias": 0.0}, 10 + ) + is None + ) + + +def test_rank_candidates_excludes_below_minimum_sample_size() -> None: + """A candidate below the policy sample floor is excluded.""" + results = [ + make_candidate_result("ok", wape=10.0, sample_size=40), + make_candidate_result("tiny", wape=1.0, sample_size=5), + ] + ranking = rank_candidates(results, RankingPolicy(minimum_sample_size=30), "wape") + assert ranking.winner is not None + assert ranking.winner.model_type == "ok" + excluded = [e for e in ranking.entries if not e.included] + assert excluded[0].model_type == "tiny" + + +def test_confidence_high_when_winner_beats_second_by_10_percent() -> None: + """A >=10% relative WAPE lead with acceptable bias yields HIGH confidence.""" + results = [ + make_candidate_result("winner", wape=10.0, bias=0.1), + make_candidate_result("second", wape=20.0, bias=0.1), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape", 
availability_status="ready") + assert ranking.winner is not None + assert ranking.winner.model_type == "winner" + assert ranking.confidence == "high" + + +def test_confidence_low_for_single_valid_candidate() -> None: + ranking = rank_candidates([make_candidate_result("solo", wape=10.0)], RankingPolicy(), "wape") + assert ranking.confidence == "low" + + +def test_confidence_low_for_near_tie() -> None: + """A sub-epsilon lead is a near tie → LOW.""" + results = [ + make_candidate_result("a", wape=10.0), + make_candidate_result("b", wape=10.05), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape", availability_status="ready") + assert ranking.confidence == "low" + + +def test_confidence_medium_when_lead_below_high_threshold() -> None: + """A 5% lead (between epsilon and 10%) is MEDIUM.""" + results = [ + make_candidate_result("a", wape=9.5), + make_candidate_result("b", wape=10.0), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape", availability_status="ready") + assert ranking.confidence == "medium" + + +def test_confidence_low_when_availability_limited() -> None: + """Limited availability caps confidence at LOW even with a clear lead.""" + results = [ + make_candidate_result("winner", wape=10.0), + make_candidate_result("second", wape=20.0), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape", availability_status="limited") + assert ranking.confidence == "low" + + +def test_confidence_low_when_bias_over_threshold() -> None: + """A winner bias above the policy bound caps confidence at LOW.""" + results = [ + make_candidate_result("winner", wape=10.0, bias=50.0), + make_candidate_result("second", wape=20.0, bias=0.0), + ] + ranking = rank_candidates( + results, RankingPolicy(max_acceptable_abs_bias=1.0), "wape", availability_status="ready" + ) + assert ranking.confidence == "low" + + +def test_all_failed_candidates_yield_no_winner() -> None: + results = [ + make_candidate_result("x", failed=True, error="train error"), + 
make_candidate_result("y", failed=True, error="value error"), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape") + assert ranking.winner is None + assert ranking.confidence == "low" + assert all(not e.included for e in ranking.entries) + + +def test_winner_entry_carries_params_for_rebuild() -> None: + """The winner entry preserves the original candidate params.""" + results = [ + make_candidate_result("seasonal_naive", wape=10.0, params={"season_length": 7}), + make_candidate_result("naive", wape=20.0, params={}), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape") + assert ranking.winner is not None + assert ranking.winner.model_type == "seasonal_naive" + assert ranking.winner.params == {"season_length": 7} + + +def test_chart_data_has_wape_bias_fold_stability_and_winner_actual_vs_predicted() -> None: + """build_chart_data populates all four chart series.""" + results = [ + make_candidate_result("winner", wape=10.0, n_folds=3), + make_candidate_result("second", wape=20.0, n_folds=3), + ] + ranking = rank_candidates(results, RankingPolicy(), "wape") + chart = build_chart_data(results, ranking) + + assert set(chart.wape_by_model) == {"winner", "second"} + assert chart.wape_by_model["winner"] == 10.0 + assert set(chart.bias_by_model) == {"winner", "second"} + assert len(chart.fold_stability["winner"]) == 3 + assert all(isinstance(v, float) for v in chart.fold_stability["winner"]) + assert len(chart.winner_actual_vs_predicted) == 3 + assert chart.winner_actual_vs_predicted[0].actuals diff --git a/app/features/model_selection/tests/test_routes.py b/app/features/model_selection/tests/test_routes.py new file mode 100644 index 00000000..7cfb35f5 --- /dev/null +++ b/app/features/model_selection/tests/test_routes.py @@ -0,0 +1,180 @@ +"""Unit route tests — service methods mocked, exercised over the HTTP boundary. + +``get_db`` is overridden with a mock session; the service is patched at the +class level so the routes are tested in isolation. 
Error paths assert the +RFC 7807 problem-detail shape. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import UTC, date, datetime +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.core.database import get_db +from app.core.exceptions import BadRequestError, NotFoundError +from app.features.model_selection.schemas import ( + ModelRankEntry, + ModelSelectionRunResponse, + SelectionWindow, + WinnerSummary, +) +from app.features.model_selection.service import ModelSelectionService +from app.main import app + + +@asynccontextmanager +async def _client() -> AsyncGenerator[AsyncClient, None]: + async def override_get_db() -> AsyncGenerator[AsyncMock, None]: + yield AsyncMock() + + app.dependency_overrides[get_db] = override_get_db + try: + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + yield ac + finally: + app.dependency_overrides.pop(get_db, None) + + +def _assert_problem_detail(body: dict[str, Any], expected_status: int) -> None: + for key in ("type", "title", "status", "detail"): + assert key in body, f"missing RFC 7807 field: {key}" + assert body["status"] == expected_status + + +def _run_response() -> ModelSelectionRunResponse: + metrics = { + "wape": 10.0, + "smape": 8.0, + "mae": 4.0, + "rmse": 5.0, + "bias": 0.1, + "sample_size": 28.0, + } + return ModelSelectionRunResponse( + selection_id="sel123", + store_id=1, + product_id=1, + status="completed", + selection_window=SelectionWindow(start_date=date(2026, 1, 1), end_date=date(2026, 5, 31)), + forecast_horizon=14, + ranking_metric="wape", + availability=None, + ranking=[ + ModelRankEntry(rank=1, model_type="naive", params={}, included=True, metrics=metrics) + ], + winner=WinnerSummary(model_type="naive", params={}, metrics=metrics, rank=1), + recommendation_confidence="high", + 
confidence_reasons=["clear lead"], + chart_data=None, + final_model=None, + forecast=None, + business_summary=None, + error_message=None, + created_at=datetime.now(UTC), + completed_at=None, + ) + + +def _valid_run_body(**overrides: Any) -> dict[str, Any]: + body: dict[str, Any] = { + "store_id": 1, + "product_id": 1, + "selection_window": {"start_date": "2026-01-01", "end_date": "2026-05-31"}, + "forecast_horizon": 14, + "split_config": { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 14, + }, + "candidate_models": [{"model_type": "naive", "params": {}}], + } + body.update(overrides) + return body + + +async def test_run_returns_200(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + ModelSelectionService, "run_selection", AsyncMock(return_value=_run_response()) + ) + async with _client() as ac: + response = await ac.post("/model-selection/run", json=_valid_run_body()) + assert response.status_code == 200 + body = response.json() + assert body["selection_id"] == "sel123" + assert body["recommendation_confidence"] == "high" + assert "confidence" not in body + + +async def test_run_validation_error_returns_problem_json() -> None: + """auto_predict without auto_train_winner is rejected by the validator (422).""" + async with _client() as ac: + response = await ac.post( + "/model-selection/run", + json=_valid_run_body(auto_predict=True, auto_train_winner=False), + ) + assert response.status_code == 422 + _assert_problem_detail(response.json(), 422) + + +async def test_routes_return_problem_json_on_bad_request( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + ModelSelectionService, + "run_selection", + AsyncMock(side_effect=BadRequestError(message="availability unusable")), + ) + async with _client() as ac: + response = await ac.post("/model-selection/run", json=_valid_run_body()) + assert response.status_code == 400 + _assert_problem_detail(response.json(), 400) + + +async def 
test_get_selection_not_found_returns_problem_json( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + ModelSelectionService, + "get_selection", + AsyncMock(side_effect=NotFoundError(message="Selection run missing not found")), + ) + async with _client() as ac: + response = await ac.get("/model-selection/missing") + assert response.status_code == 404 + _assert_problem_detail(response.json(), 404) + + +async def test_availability_returns_200(monkeypatch: pytest.MonkeyPatch) -> None: + from app.features.model_selection.tests.conftest import make_availability + + monkeypatch.setattr( + ModelSelectionService, + "get_availability", + AsyncMock(return_value=make_availability(status="ready")), + ) + async with _client() as ac: + response = await ac.get( + "/model-selection/availability", + params={"store_id": 1, "product_id": 1, "forecast_horizon": 14}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ready" + + +async def test_availability_rejects_bad_query() -> None: + """store_id < 1 fails Query validation → 422 problem+json.""" + async with _client() as ac: + response = await ac.get( + "/model-selection/availability", + params={"store_id": 0, "product_id": 1}, + ) + assert response.status_code == 422 + _assert_problem_detail(response.json(), 422) diff --git a/app/features/model_selection/tests/test_routes_integration.py b/app/features/model_selection/tests/test_routes_integration.py new file mode 100644 index 00000000..a6440f71 --- /dev/null +++ b/app/features/model_selection/tests/test_routes_integration.py @@ -0,0 +1,138 @@ +"""Integration tests for the model_selection slice against real Postgres. + +Marked ``@pytest.mark.integration`` — require ``docker compose up -d`` + an +applied ``alembic upgrade head``. 
+""" + +from __future__ import annotations + +from typing import Any + +import pytest +from httpx import AsyncClient +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +pytestmark = pytest.mark.integration + + +def _run_body( + pair: dict[str, Any], extra_candidates: list[dict[str, Any]] | None = None +) -> dict[str, Any]: + candidates = [ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + {"model_type": "moving_average", "params": {"window_size": 7}}, + ] + if extra_candidates: + candidates.extend(extra_candidates) + return { + "store_id": pair["store_id"], + "product_id": pair["product_id"], + "selection_window": {"start_date": pair["start_date"], "end_date": pair["end_date"]}, + "forecast_horizon": 14, + "ranking_metric": "wape", + "split_config": { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 14, + }, + "candidate_models": candidates, + "auto_train_winner": False, + "auto_predict": False, + } + + +async def test_table_has_named_indexes(db_session: AsyncSession) -> None: + rows = await db_session.execute( + text("SELECT indexname FROM pg_indexes WHERE tablename = 'model_selection_run'") + ) + names = {row[0] for row in rows} + assert "ix_model_selection_run_selection_id" in names + assert "ix_model_selection_run_store_product_created" in names + assert "ix_model_selection_run_status_created" in names + + +async def test_availability_ready_pair(client: AsyncClient, ready_pair: dict[str, Any]) -> None: + response = await client.get( + "/model-selection/availability", + params={ + "store_id": ready_pair["store_id"], + "product_id": ready_pair["product_id"], + "forecast_horizon": 14, + }, + ) + assert response.status_code == 200 + body = response.json() + assert body["status"] == "ready" + assert body["observed_days"] == ready_pair["n_days"] + assert body["recommended_split_config"]["horizon"] == 14 + + +async def 
test_availability_limited_pair(client: AsyncClient, limited_pair: dict[str, Any]) -> None: + response = await client.get( + "/model-selection/availability", + params={ + "store_id": limited_pair["store_id"], + "product_id": limited_pair["product_id"], + "forecast_horizon": 14, + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "limited" + + +async def test_availability_unknown_pair_returns_404(client: AsyncClient) -> None: + response = await client.get( + "/model-selection/availability", + params={"store_id": 999999, "product_id": 999999, "forecast_horizon": 14}, + ) + assert response.status_code == 404 + assert response.json()["status"] == 404 + + +async def test_run_persists_and_get_returns_same( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + run = await client.post("/model-selection/run", json=_run_body(ready_pair)) + assert run.status_code == 200 + body = run.json() + assert body["status"] in {"completed", "partial"} + assert body["winner"] is not None + assert body["recommendation_confidence"] in {"high", "medium", "low"} + assert body["chart_data"] is not None + assert body["ranking"] + selection_id = body["selection_id"] + + fetched = await client.get(f"/model-selection/{selection_id}") + assert fetched.status_code == 200 + assert fetched.json()["selection_id"] == selection_id + + ranking = await client.get(f"/model-selection/{selection_id}/ranking") + assert ranking.status_code == 200 + assert ranking.json()["winner"] is not None + + +async def test_run_partial_with_bad_candidate( + client: AsyncClient, ready_pair: dict[str, Any] +) -> None: + """An invalid candidate param surfaces as a failed entry, not a 500.""" + body = _run_body( + ready_pair, + extra_candidates=[{"model_type": "moving_average", "params": {"window_size": 0}}], + ) + response = await client.post("/model-selection/run", json=body) + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "partial" + 
excluded = [e for e in payload["ranking"] if not e["included"]] + assert excluded + assert payload["winner"] is not None + + +async def test_get_missing_selection_returns_404(client: AsyncClient) -> None: + response = await client.get("/model-selection/does-not-exist") + assert response.status_code == 404 + assert response.json()["status"] == 404 diff --git a/app/features/model_selection/tests/test_schemas.py b/app/features/model_selection/tests/test_schemas.py new file mode 100644 index 00000000..3d34c510 --- /dev/null +++ b/app/features/model_selection/tests/test_schemas.py @@ -0,0 +1,81 @@ +"""Unit tests for model_selection request schemas (strict mode + validators).""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from app.features.model_selection.schemas import ( + ModelSelectionRunRequest, + SelectionWindow, +) + + +def _base_request_dict(**overrides: object) -> dict[str, object]: + payload: dict[str, object] = { + "store_id": 1, + "product_id": 1, + "selection_window": {"start_date": "2026-01-01", "end_date": "2026-05-31"}, + "forecast_horizon": 14, + "split_config": { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 14, + }, + "candidate_models": [{"model_type": "naive", "params": {}}], + } + payload.update(overrides) + return payload + + +def test_schema_accepts_iso_dates_under_strict_model() -> None: + """ISO-string dates validate through the strict ``validate_python`` path.""" + window = SelectionWindow.model_validate({"start_date": "2026-01-01", "end_date": "2026-02-01"}) + assert window.start_date.isoformat() == "2026-01-01" + + request = ModelSelectionRunRequest.model_validate(_base_request_dict()) + assert request.selection_window.end_date.isoformat() == "2026-05-31" + + +def test_schema_rejects_auto_predict_without_train_winner() -> None: + """LOCKED #7 — auto_predict requires auto_train_winner.""" + with pytest.raises(ValidationError, match="auto_predict 
requires auto_train_winner"): + ModelSelectionRunRequest.model_validate( + _base_request_dict(auto_predict=True, auto_train_winner=False) + ) + + +def test_schema_rejects_horizon_mismatch_between_split_and_forecast() -> None: + """LOCKED #5 — split_config.horizon must equal forecast_horizon.""" + bad = _base_request_dict(forecast_horizon=14) + bad["split_config"] = { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 7, + } + with pytest.raises(ValidationError, match="must equal"): + ModelSelectionRunRequest.model_validate(bad) + + +def test_schema_rejects_feature_groups_with_v1() -> None: + """V1 must not carry feature_groups (mirrors forecasting TrainRequest).""" + with pytest.raises(ValidationError, match="feature_groups is only valid"): + ModelSelectionRunRequest.model_validate( + _base_request_dict(feature_frame_version=1, feature_groups=["calendar"]) + ) + + +def test_selection_window_rejects_inverted_range() -> None: + """An end <= start window is rejected.""" + with pytest.raises(ValidationError, match="after start_date"): + SelectionWindow.model_validate({"start_date": "2026-02-01", "end_date": "2026-01-01"}) + + +def test_candidate_models_min_length_enforced() -> None: + """At least one candidate is required.""" + with pytest.raises(ValidationError): + ModelSelectionRunRequest.model_validate(_base_request_dict(candidate_models=[])) diff --git a/app/features/model_selection/tests/test_service.py b/app/features/model_selection/tests/test_service.py new file mode 100644 index 00000000..7d3da5f1 --- /dev/null +++ b/app/features/model_selection/tests/test_service.py @@ -0,0 +1,222 @@ +"""Unit tests for ModelSelectionService orchestration (mocked sibling services).""" + +from __future__ import annotations + +from datetime import date, timedelta +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest +from pydantic import TypeAdapter + +from 
app.core.exceptions import BadRequestError, NotFoundError +from app.features.forecasting.schemas import ModelConfig +from app.features.model_selection.schemas import ModelSelectionRunRequest +from app.features.model_selection.service import ModelSelectionService +from app.features.model_selection.tests.conftest import ( + make_availability, + make_backtest_response, + make_mock_db, +) + + +def _request(**overrides: Any) -> ModelSelectionRunRequest: + payload: dict[str, Any] = { + "store_id": 1, + "product_id": 1, + "selection_window": {"start_date": "2026-01-01", "end_date": "2026-05-31"}, + "forecast_horizon": 14, + "split_config": { + "strategy": "expanding", + "n_splits": 5, + "min_train_size": 30, + "gap": 0, + "horizon": 14, + }, + "candidate_models": [{"model_type": "naive", "params": {}}], + } + payload.update(overrides) + return ModelSelectionRunRequest.model_validate(payload) + + +def _patch_backtester( + monkeypatch: pytest.MonkeyPatch, *, side_effect: list[Any] +) -> SimpleNamespace: + instance = SimpleNamespace(run_backtest=AsyncMock(side_effect=side_effect)) + monkeypatch.setattr("app.features.backtesting.service.BacktestingService", lambda: instance) + return instance + + +def _patch_availability(monkeypatch: pytest.MonkeyPatch, status: str) -> None: + monkeypatch.setattr( + ModelSelectionService, + "get_availability", + AsyncMock(return_value=make_availability(status=status)), + ) + + +# ----------------------------------------------------------------------------- +# Flattening +# ----------------------------------------------------------------------------- + + +def test_build_model_config_flattens_params() -> None: + """The service's flatten-then-validate builds a typed ModelConfig.""" + adapter: TypeAdapter[Any] = TypeAdapter(ModelConfig) + cfg = adapter.validate_python({"model_type": "seasonal_naive", "season_length": 7}) + assert cfg.model_type == "seasonal_naive" + assert cfg.season_length == 7 + + +# 
----------------------------------------------------------------------------- +# Availability thresholds +# ----------------------------------------------------------------------------- + + +def _availability_db(observed: int) -> AsyncMock: + """Mock DB returning a contiguous `observed`-day aggregate for one pair.""" + first = date(2024, 1, 1) if observed else None + last = date(2024, 1, 1) + timedelta(days=observed - 1) if observed else None + db = AsyncMock() + db.get = AsyncMock(return_value=SimpleNamespace(id=1)) + result = AsyncMock() + result.one = lambda: (first, last, observed, 12.0, 0) + db.execute = AsyncMock(return_value=result) + db.scalar = AsyncMock(return_value=0) + return db + + +@pytest.mark.parametrize( + ("observed", "expected"), + [(120, "ready"), (50, "limited"), (20, "unusable")], +) +async def test_availability_ready_limited_unusable_thresholds(observed: int, expected: str) -> None: + service = ModelSelectionService() + db = _availability_db(observed) + availability = await service.get_availability(db, 1, 1, forecast_horizon=14) + assert availability.status == expected + + +async def test_availability_missing_store_raises_not_found() -> None: + service = ModelSelectionService() + db = AsyncMock() + db.get = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await service.get_availability(db, 999, 1, forecast_horizon=14) + + +# ----------------------------------------------------------------------------- +# Orchestration +# ----------------------------------------------------------------------------- + + +async def test_run_selection_partial_success_chooses_valid_winner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_availability(monkeypatch, "ready") + _patch_backtester( + monkeypatch, + side_effect=[make_backtest_response(wape=10.0), ValueError("insufficient data")], + ) + request = _request( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 
7}}, + ] + ) + response = await ModelSelectionService().run_selection(make_mock_db(), request) + + assert response.status == "partial" + assert response.winner is not None + assert response.winner.model_type == "naive" + failed = [e for e in response.ranking if not e.included] + assert [e.model_type for e in failed] == ["seasonal_naive"] + assert failed[0].exclusion_reason is not None + + +async def test_run_selection_all_candidates_fail_returns_failed_status_not_500( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """LOCKED #3 — every candidate failing persists FAILED and returns (no raise).""" + _patch_availability(monkeypatch, "ready") + _patch_backtester(monkeypatch, side_effect=[ValueError("boom-1"), ValueError("boom-2")]) + request = _request( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + ] + ) + response = await ModelSelectionService().run_selection(make_mock_db(), request) + + assert response.status == "failed" + assert response.winner is None + assert response.selection_id + assert all(not e.included for e in response.ranking) + + +async def test_run_selection_unusable_availability_raises_bad_request( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """LOCKED #2 — unusable availability fails fast with 400.""" + _patch_availability(monkeypatch, "unusable") + with pytest.raises(BadRequestError): + await ModelSelectionService().run_selection(make_mock_db(), _request()) + + +async def test_run_selection_auto_train_passes_feature_frame_version_and_groups( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _patch_availability(monkeypatch, "ready") + _patch_backtester(monkeypatch, side_effect=[make_backtest_response(wape=10.0)]) + train_mock = AsyncMock( + return_value=SimpleNamespace(model_path="artifacts/models/model_abc.joblib") + ) + monkeypatch.setattr( + "app.features.forecasting.service.ForecastingService", + lambda: SimpleNamespace(train_model=train_mock), + ) + request = 
_request( + feature_frame_version=2, + feature_groups=["calendar"], + auto_train_winner=True, + auto_predict=False, + ) + response = await ModelSelectionService().run_selection(make_mock_db(), request) + + assert response.final_model == {"model_path": "artifacts/models/model_abc.joblib"} + train_mock.assert_awaited_once() + assert train_mock.await_args is not None + assert train_mock.await_args.kwargs["feature_frame_version"] == 2 + assert train_mock.await_args.kwargs["feature_groups"] == ["calendar"] + + +async def test_response_uses_recommendation_confidence_key( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The response carries ``recommendation_confidence`` (not ``confidence``).""" + _patch_availability(monkeypatch, "ready") + _patch_backtester( + monkeypatch, + side_effect=[make_backtest_response(wape=10.0), make_backtest_response(wape=20.0)], + ) + request = _request( + candidate_models=[ + {"model_type": "naive", "params": {}}, + {"model_type": "seasonal_naive", "params": {"season_length": 7}}, + ] + ) + response = await ModelSelectionService().run_selection(make_mock_db(), request) + dumped = response.model_dump() + assert "recommendation_confidence" in dumped + assert "confidence" not in dumped + assert response.recommendation_confidence in {"high", "medium", "low"} + assert response.chart_data is not None + + +async def test_get_selection_missing_raises_not_found() -> None: + db = AsyncMock() + db.scalar = AsyncMock(return_value=None) + with pytest.raises(NotFoundError): + await ModelSelectionService().get_selection(db, uuid4().hex) diff --git a/app/main.py b/app/main.py index eb4f5145..1533ce50 100644 --- a/app/main.py +++ b/app/main.py @@ -26,6 +26,7 @@ from app.features.forecasting.routes import router as forecasting_router from app.features.ingest.routes import router as ingest_router from app.features.jobs.routes import router as jobs_router +from app.features.model_selection.routes import router as model_selection_router from 
app.features.ops.routes import router as ops_router from app.features.rag.routes import router as rag_router from app.features.registry.routes import router as registry_router @@ -145,6 +146,7 @@ def create_app() -> FastAPI: app.include_router(forecasting_router) app.include_router(explainability_router) app.include_router(backtesting_router) + app.include_router(model_selection_router) app.include_router(registry_router) app.include_router(rag_router) app.include_router(scenarios_router)