Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions app/features/model_selection/capabilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Pure model-capability catalog for the champion selector (issue #356, Slice A).

No DB, no I/O — :func:`build_model_catalog` is deterministic and unit-tested
directly (mirrors ``ranking.py`` / ``explanations.py``). It surfaces the
forecasting model union as a frontend-consumable catalog so the React
``MODEL_FAMILY_MAP`` / labels never drift from the Python authority.

Capability provenance (BACKEND-OWNED, verified 2026-06-01):
- ``family`` — ``forecasting.feature_metadata.model_family_for`` (lazy
cross-slice import inside the builder, per the slice's import discipline).
- ``feature_aware`` — the set whose forecasters set ``requires_features=True``
(RandomForest/Regression/LightGBM/XGBoost/ProphetLike), i.e. exactly the set
``ForecastingService.predict()`` rejects (``forecasting/service.py``).
- ``requires_extra`` — ``lightgbm``/``xgboost`` (opt-in extras that may
``ImportError`` when the extra is not installed).
- ``supports_auto_predict`` — ``not feature_aware`` (feature-aware winners
forecast through ``POST /scenarios/simulate``, not the plain predict path).
- ``default_params`` — the FLAT model-tuning defaults pinned from the live
``forecasting.schemas.ModelConfig`` members (the internal ``schema_version``
and ``feature_config_hash`` meta fields are intentionally omitted).
"""

from __future__ import annotations

from dataclasses import dataclass, field

from app.features.model_selection.schemas import (
CandidateModelInfo,
ModelCatalogResponse,
)

# Model types that depend on an opt-in extra; the backing package may be
# missing at runtime.
_REQUIRES_EXTRA: frozenset[str] = frozenset(("lightgbm", "xgboost"))

# Feature-aware model types: their forecasters set ``requires_features=True``
# and ``ForecastingService.predict()`` rejects them because they need an
# exogenous feature frame. Checked against the ``requires_features`` flags in
# ``forecasting/models.py``.
_FEATURE_AWARE: frozenset[str] = frozenset(
    (
        "regression",
        "prophet_like",
        "lightgbm",
        "xgboost",
        "random_forest",
    )
)

# Default candidate set documented by the backend ``POST /run`` contract; the
# UI pre-selects exactly these, in this order.
DEFAULT_CANDIDATE_MODEL_TYPES: list[str] = [
    "naive",
    "seasonal_naive",
    "moving_average",
    "regression",
    "prophet_like",
]


def _no_default_params() -> dict[str, object]:
    """Return a fresh empty parameter dict for entries without tunables."""
    return {}


@dataclass(frozen=True)
class _CatalogEntry:
    """Presentation metadata kept by this slice for a single model_type.

    Instances are frozen (attributes cannot be reassigned). When
    ``default_params`` is omitted, each instance receives its own new empty
    dict rather than a shared one.
    """

    label: str
    description: str
    default_params: dict[str, object] = field(default_factory=_no_default_params)


# Catalog rows in declaration order, one per model_type:
# ``(model_type, label, description, default_params)``.
# The model_type column must equal the ``ModelType`` Literal in ``schemas.py``
# exactly (asserted in ``test_capabilities.py``). ``default_params`` are the
# flat model-tuning defaults from the forecasting ``ModelConfig`` members
# (schema_version / feature_config_hash meta fields omitted), pinned 2026-06-01.
_CATALOG_ROWS: tuple[tuple[str, str, str, dict[str, object]], ...] = (
    (
        "naive",
        "Naive",
        "Repeats the last observed value.",
        {},
    ),
    (
        "seasonal_naive",
        "Seasonal Naive",
        "Repeats the value from one season ago.",
        {"season_length": 7},
    ),
    (
        "moving_average",
        "Moving Average",
        "Averages the last N observed values.",
        {"window_size": 7},
    ),
    (
        "weighted_moving_average",
        "Weighted Moving Average",
        "Recency-weighted average of the last N values.",
        {"window_size": 7, "weight_strategy": "linear", "decay": 0.7},
    ),
    (
        "seasonal_average",
        "Seasonal Average",
        "Averages the same season-position across recent cycles.",
        {"season_length": 7, "lookback_cycles": 4, "trim_outliers": False},
    ),
    (
        "trend_regression_baseline",
        "Trend Regression Baseline",
        "Ridge trend with optional day-of-week / month terms.",
        {"alpha": 1.0, "include_dow": True, "include_month": True},
    ),
    (
        "random_forest",
        "Random Forest",
        "Feature-aware random-forest regressor over lag/calendar features.",
        {"n_estimators": 100, "max_depth": 10, "min_samples_leaf": 2},
    ),
    (
        "lightgbm",
        "LightGBM",
        "Gradient-boosted trees (opt-in extra) over engineered features.",
        {"n_estimators": 100, "max_depth": 6, "learning_rate": 0.1},
    ),
    (
        "xgboost",
        "XGBoost",
        "Extreme gradient boosting (opt-in extra) over engineered features.",
        {"n_estimators": 100, "max_depth": 6, "learning_rate": 0.1},
    ),
    (
        "regression",
        "Gradient Boosting Regression",
        "Histogram gradient-boosting over lag, calendar, and exogenous features.",
        {"max_iter": 200, "learning_rate": 0.05, "max_depth": 6},
    ),
    (
        "prophet_like",
        "Prophet-like Additive",
        "Additive trend/seasonality Ridge over engineered features.",
        {"alpha": 1.0},
    ),
)

# Ordered map: model_type → presentation metadata, built from the rows above
# so insertion order follows the row order.
_CATALOG: dict[str, _CatalogEntry] = {
    model_type: _CatalogEntry(label=label, description=description, default_params=params)
    for model_type, label, description, params in _CATALOG_ROWS
}


def build_model_catalog() -> ModelCatalogResponse:
    """Assemble the backend-owned candidate-model catalog (pure, no I/O).

    Walks the slice-local ``_CATALOG`` in declaration order. For every entry
    the ``family`` comes from the forecasting slice's ``model_family_for`` and
    the capability flags come from the module-level ``_FEATURE_AWARE`` /
    ``_REQUIRES_EXTRA`` sets.

    Returns:
        The full catalog together with a copy of the documented default
        candidate set.
    """
    # Deferred cross-slice import (same approach as service.py) so that this
    # module does not close an alembic cold-boot import cycle through the
    # forecasting slice.
    from app.features.forecasting.feature_metadata import model_family_for

    def _to_info(key: str, entry: _CatalogEntry) -> CandidateModelInfo:
        """Turn one catalog entry into its response model."""
        needs_features = key in _FEATURE_AWARE
        return CandidateModelInfo(
            model_type=key,
            label=entry.label,
            # ``ModelFamily`` is a ``str, Enum``; its ``.value`` is already the
            # ``baseline|tree|additive`` literal the schema expects.
            family=model_family_for(key).value,
            feature_aware=needs_features,
            requires_extra=key in _REQUIRES_EXTRA,
            # Copy so the response never aliases the module-level defaults.
            default_params=dict(entry.default_params),
            supports_auto_predict=not needs_features,
            description=entry.description,
        )

    return ModelCatalogResponse(
        models=[_to_info(key, entry) for key, entry in _CATALOG.items()],
        default_candidate_model_types=list(DEFAULT_CANDIDATE_MODEL_TYPES),
    )
17 changes: 17 additions & 0 deletions app/features/model_selection/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from app.core.exceptions import BadRequestError, DatabaseError
from app.core.logging import get_logger
from app.features.model_selection.schemas import (
ModelCatalogResponse,
ModelSelectionRunRequest,
ModelSelectionRunResponse,
PairAvailabilityResponse,
Expand Down Expand Up @@ -62,6 +63,22 @@ async def get_availability(
) from exc


@router.get(
    "/models",
    response_model=ModelCatalogResponse,
    status_code=status.HTTP_200_OK,
    summary="List the backend-owned candidate-model capability catalog",
)
async def get_model_catalog() -> ModelCatalogResponse:
    """Return the static candidate-model catalog (no DB, no query params).

    Declared BEFORE ``GET /{selection_id}`` so Starlette matches the literal
    ``/models`` path and does not capture it as ``selection_id="models"``.
    """
    # The handler takes no parameters and the service is built without
    # arguments, so a throwaway instance is all this endpoint needs.
    return ModelSelectionService().get_model_catalog()


@router.post(
"/run",
response_model=ModelSelectionRunResponse,
Expand Down
26 changes: 26 additions & 0 deletions app/features/model_selection/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,32 @@ class ModelSelectionRunResponse(BaseModel):
completed_at: datetime | None


class CandidateModelInfo(BaseModel):
    """One selectable forecasting model in the capability catalog.

    Output-only (plain ``BaseModel`` — no strict coercion needed). The
    capability flags are BACKEND-OWNED: they derive from the forecasting
    authority (``model_family_for`` + each forecaster's ``requires_features``)
    so the frontend never re-derives families/feature-awareness in TypeScript.
    """

    # Model identifier; annotated as a plain ``str`` here, not a Literal.
    model_type: str
    # Display label for the model.
    label: str
    # Model family; only these three lowercase values validate.
    family: Literal["baseline", "tree", "additive"]
    # Whether the model is feature-aware (see the class docstring).
    feature_aware: bool
    requires_extra: bool  # lightgbm/xgboost — opt-in extra may be absent at runtime
    # Mapping of parameter name to its default value.
    default_params: dict[str, Any]
    supports_auto_predict: bool  # False for feature-aware models (predict() rejects them)
    # Descriptive text for the model.
    description: str


class ModelCatalogResponse(BaseModel):
    """``GET /model-selection/models`` — backend-owned candidate catalog."""

    # One entry per selectable model.
    models: list[CandidateModelInfo]
    # Model-type identifiers that make up the default candidate set.
    default_candidate_model_types: list[str]


class TrainWinnerResponse(BaseModel):
"""``POST /model-selection/{id}/train-winner`` response."""

Expand Down
14 changes: 14 additions & 0 deletions app/features/model_selection/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from app.core.logging import get_logger
from app.features.backtesting.schemas import SplitConfig
from app.features.data_platform.models import Product, Promotion, SalesDaily, Store
from app.features.model_selection.capabilities import build_model_catalog
from app.features.model_selection.explanations import explain_winner
from app.features.model_selection.models import ModelSelectionRun, ModelSelectionStatus
from app.features.model_selection.ranking import build_chart_data, rank_candidates
Expand All @@ -36,6 +37,7 @@
ChartData,
FoldChart,
ForecastSummary,
ModelCatalogResponse,
ModelSelectionRunRequest,
ModelSelectionRunResponse,
PairAvailabilityResponse,
Expand Down Expand Up @@ -64,6 +66,18 @@
class ModelSelectionService:
"""Stateless orchestrator — a fresh ``db`` session per method."""

# -------------------------------------------------------------------------
# Capability catalog
# -------------------------------------------------------------------------

def get_model_catalog(self) -> ModelCatalogResponse:
    """Provide the backend-owned candidate-model catalog (static, no I/O).

    Delegates straight to the pure :func:`capabilities.build_model_catalog`.
    The method lives on the service only for symmetry with
    ``get_availability`` / ``run``; it reads no instance state.
    """
    catalog: ModelCatalogResponse = build_model_catalog()
    return catalog

# -------------------------------------------------------------------------
# Availability
# -------------------------------------------------------------------------
Expand Down
102 changes: 102 additions & 0 deletions app/features/model_selection/tests/test_capabilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Unit tests for the pure model-capability catalog (issue #356, Slice A).

No DB, no I/O — exercises ``build_model_catalog`` directly, mirroring
``test_ranking.py``. These pin the BACKEND-OWNED capability contract the
frontend consumes read-only.
"""

from __future__ import annotations

import typing

from app.features.model_selection.capabilities import (
DEFAULT_CANDIDATE_MODEL_TYPES,
build_model_catalog,
)
from app.features.model_selection.schemas import ModelType

# Every member of the ``ModelType`` Literal, unpacked once at import time and
# used by the tests below as the expected set of catalog model types.
_EXPECTED_MODEL_TYPES = set(typing.get_args(ModelType))


def test_catalog_model_types_match_literal() -> None:
    """The catalog's model types are EXACTLY the ``ModelType`` Literal members."""
    entries = build_model_catalog().models
    assert {entry.model_type for entry in entries} == _EXPECTED_MODEL_TYPES
    # Equal lengths rule out duplicate entries; 11 pins the current model count.
    assert len(entries) == len(_EXPECTED_MODEL_TYPES) == 11


def test_catalog_families_are_valid_literals() -> None:
    """No catalog entry reports a family outside the three lowercase literals."""
    allowed_families = {"baseline", "tree", "additive"}
    seen_families = {entry.family for entry in build_model_catalog().models}
    assert seen_families <= allowed_families


def test_requires_extra_flags_lightgbm_xgboost_only() -> None:
    """requires_extra=True appears on lightgbm and xgboost and nowhere else."""
    entries = build_model_catalog().models
    flagged = {entry.model_type for entry in entries if entry.requires_extra}
    expected = {"lightgbm", "xgboost"}
    assert flagged == expected


def test_feature_aware_set_matches_predict_reject_set() -> None:
    """The feature_aware entries are exactly the five pinned model types."""
    expected = {
        "regression",
        "prophet_like",
        "lightgbm",
        "xgboost",
        "random_forest",
    }
    entries = build_model_catalog().models
    actual = {entry.model_type for entry in entries if entry.feature_aware}
    assert actual == expected


def test_feature_aware_models_do_not_support_auto_predict() -> None:
    """For every entry, supports_auto_predict is the negation of feature_aware."""
    # Collect offenders so a failure names the model types that break the rule.
    offenders = [
        entry.model_type
        for entry in build_model_catalog().models
        if entry.supports_auto_predict != (not entry.feature_aware)
    ]
    assert not offenders


def test_default_candidate_model_types_are_the_default_five() -> None:
    """The response's default candidates are the five pinned types, in order."""
    response = build_model_catalog()
    defaults = response.default_candidate_model_types
    expected_defaults = [
        "naive",
        "seasonal_naive",
        "moving_average",
        "regression",
        "prophet_like",
    ]
    assert defaults == expected_defaults
    # The exported module constant agrees with what the response carries.
    assert DEFAULT_CANDIDATE_MODEL_TYPES == defaults
    # Each default also exists as an entry in the catalog itself.
    known_types = {entry.model_type for entry in response.models}
    assert set(defaults) <= known_types


def test_default_params_match_forecasting_defaults() -> None:
    """default_params equal the values pinned here for four representative models.

    The expectations are literals in this test, so it guards the catalog
    against accidental edits; it also checks that no meta field leaks into
    any entry's params.
    """
    params_by_type = {
        entry.model_type: entry.default_params for entry in build_model_catalog().models
    }
    pinned = {
        "naive": {},
        "seasonal_naive": {"season_length": 7},
        "moving_average": {"window_size": 7},
        "regression": {"max_iter": 200, "learning_rate": 0.05, "max_depth": 6},
    }
    for model_type, expected in pinned.items():
        assert params_by_type[model_type] == expected
    # Internal/meta fields must never show up among the catalog params.
    meta_fields = {"schema_version", "feature_config_hash"}
    for params in params_by_type.values():
        assert meta_fields.isdisjoint(params)


def test_labels_and_descriptions_are_non_empty() -> None:
    """Every entry has a label and a description that are not blank."""
    for entry in build_model_catalog().models:
        for text in (entry.label, entry.description):
            # ``strip()`` makes whitespace-only copy count as empty.
            assert text.strip()
Loading
Loading