diff --git a/app/features/model_selection/capabilities.py b/app/features/model_selection/capabilities.py new file mode 100644 index 00000000..5c513496 --- /dev/null +++ b/app/features/model_selection/capabilities.py @@ -0,0 +1,157 @@ +"""Pure model-capability catalog for the champion selector (issue #356, Slice A). + +No DB, no I/O — :func:`build_model_catalog` is deterministic and unit-tested +directly (mirrors ``ranking.py`` / ``explanations.py``). It surfaces the +forecasting model union as a frontend-consumable catalog so the React +``MODEL_FAMILY_MAP`` / labels never drift from the Python authority. + +Capability provenance (BACKEND-OWNED, verified 2026-06-01): +- ``family`` — ``forecasting.feature_metadata.model_family_for`` (lazy + cross-slice import inside the builder, per the slice's import discipline). +- ``feature_aware`` — the set whose forecasters set ``requires_features=True`` + (RandomForest/Regression/LightGBM/XGBoost/ProphetLike), i.e. exactly the set + ``ForecastingService.predict()`` rejects (``forecasting/service.py``). +- ``requires_extra`` — ``lightgbm``/``xgboost`` (opt-in extras that may + ``ImportError`` when the extra is not installed). +- ``supports_auto_predict`` — ``not feature_aware`` (feature-aware winners + forecast through ``POST /scenarios/simulate``, not the plain predict path). +- ``default_params`` — the FLAT model-tuning defaults pinned from the live + ``forecasting.schemas.ModelConfig`` members (the internal ``schema_version`` + and ``feature_config_hash`` meta fields are intentionally omitted). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from app.features.model_selection.schemas import ( + CandidateModelInfo, + ModelCatalogResponse, +) + +# Models gated behind the matching opt-in extra (may be absent at runtime). +_REQUIRES_EXTRA: frozenset[str] = frozenset({"lightgbm", "xgboost"}) + +# Feature-aware models — their forecasters set ``requires_features=True`` and +# ``ForecastingService.predict()`` rejects them (they need an exogenous feature +# frame). Verified against ``forecasting/models.py`` requires_features flags. +_FEATURE_AWARE: frozenset[str] = frozenset( + {"regression", "prophet_like", "lightgbm", "xgboost", "random_forest"} +) + +# The default candidate set the backend ``POST /run`` contract documents — the +# UI pre-selects exactly these. +DEFAULT_CANDIDATE_MODEL_TYPES: list[str] = [ + "naive", + "seasonal_naive", + "moving_average", + "regression", + "prophet_like", +] + + +@dataclass(frozen=True) +class _CatalogEntry: + """Slice-local presentation metadata for one model_type.""" + + label: str + description: str + default_params: dict[str, object] = field(default_factory=lambda: {}) + + +# Ordered map: model_type → presentation metadata. The KEYS must equal the +# ``ModelType`` Literal in ``schemas.py`` exactly (asserted in +# ``test_capabilities.py``). ``default_params`` are the flat model-tuning +# defaults from the forecasting ``ModelConfig`` members (schema_version / +# feature_config_hash meta fields omitted), pinned 2026-06-01. +_CATALOG: dict[str, _CatalogEntry] = { + "naive": _CatalogEntry( + label="Naive", + description="Repeats the last observed value.", + ), + "seasonal_naive": _CatalogEntry( + label="Seasonal Naive", + description="Repeats the value from one season ago.", + default_params={"season_length": 7}, + ), + "moving_average": _CatalogEntry( + label="Moving Average", + description="Averages the last N observed values.", + default_params={"window_size": 7}, + ), + "weighted_moving_average": _CatalogEntry( + label="Weighted Moving Average", + description="Recency-weighted average of the last N values.", + default_params={"window_size": 7, "weight_strategy": "linear", "decay": 0.7}, + ), + "seasonal_average": _CatalogEntry( + label="Seasonal Average", + description="Averages the same season-position across recent cycles.", + default_params={"season_length": 7, "lookback_cycles": 4, "trim_outliers": False}, + ), + "trend_regression_baseline": _CatalogEntry( + label="Trend Regression Baseline", + description="Ridge trend with optional day-of-week / month terms.", + default_params={"alpha": 1.0, "include_dow": True, "include_month": True}, + ), + "random_forest": _CatalogEntry( + label="Random Forest", + description="Feature-aware random-forest regressor over lag/calendar features.", + default_params={"n_estimators": 100, "max_depth": 10, "min_samples_leaf": 2}, + ), + "lightgbm": _CatalogEntry( + label="LightGBM", + description="Gradient-boosted trees (opt-in extra) over engineered features.", + default_params={"n_estimators": 100, "max_depth": 6, "learning_rate": 0.1}, + ), + "xgboost": _CatalogEntry( + label="XGBoost", + description="Extreme gradient boosting (opt-in extra) over engineered features.", + default_params={"n_estimators": 100, "max_depth": 6, "learning_rate": 0.1}, + ), + "regression": _CatalogEntry( + label="Gradient Boosting Regression", + description="Histogram gradient-boosting over lag, calendar, and exogenous features.", + default_params={"max_iter": 200, "learning_rate": 0.05, "max_depth": 6}, + ), + "prophet_like": _CatalogEntry( + label="Prophet-like Additive", + description="Additive trend/seasonality Ridge over engineered features.", + default_params={"alpha": 1.0}, + ), +} + + +def build_model_catalog() -> ModelCatalogResponse: + """Build the backend-owned candidate-model catalog (pure, no I/O). + + Iterates the slice-local ``_CATALOG`` in declaration order, deriving each + entry's ``family`` from the forecasting authority and its capability flags + from the module-level sets. Returns the full catalog plus the documented + default candidate set. + """ + # Lazy cross-slice import (mirror service.py) — avoids closing an alembic + # cold-boot import cycle through the forecasting slice. + from app.features.forecasting.feature_metadata import model_family_for + + models: list[CandidateModelInfo] = [] + for model_type, meta in _CATALOG.items(): + feature_aware = model_type in _FEATURE_AWARE + models.append( + CandidateModelInfo( + model_type=model_type, + label=meta.label, + # ``ModelFamily`` is a ``str, Enum`` whose ``.value`` is already + # typed as the ``baseline|tree|additive`` literal the schema wants. + family=model_family_for(model_type).value, + feature_aware=feature_aware, + requires_extra=model_type in _REQUIRES_EXTRA, + default_params=dict(meta.default_params), + supports_auto_predict=not feature_aware, + description=meta.description, + ) + ) + return ModelCatalogResponse( + models=models, + default_candidate_model_types=list(DEFAULT_CANDIDATE_MODEL_TYPES), + ) diff --git a/app/features/model_selection/routes.py b/app/features/model_selection/routes.py index f989aac0..f4f833c7 100644 --- a/app/features/model_selection/routes.py +++ b/app/features/model_selection/routes.py @@ -24,6 +24,7 @@ from app.core.exceptions import BadRequestError, DatabaseError from app.core.logging import get_logger from app.features.model_selection.schemas import ( + ModelCatalogResponse, ModelSelectionRunRequest, ModelSelectionRunResponse, PairAvailabilityResponse, @@ -62,6 +63,22 @@ async def get_availability( ) from exc +@router.get( + "/models", + response_model=ModelCatalogResponse, + status_code=status.HTTP_200_OK, + summary="List the backend-owned candidate-model capability catalog", +) +async def get_model_catalog() -> ModelCatalogResponse: + """Return the static candidate-model catalog (no DB, no query params). + + Declared BEFORE ``GET /{selection_id}`` so Starlette matches the literal + ``/models`` path and does not capture it as ``selection_id="models"``. + """ + service = ModelSelectionService() + return service.get_model_catalog() + + @router.post( "/run", response_model=ModelSelectionRunResponse, diff --git a/app/features/model_selection/schemas.py b/app/features/model_selection/schemas.py index 9fc10d37..d3bc45dd 100644 --- a/app/features/model_selection/schemas.py +++ b/app/features/model_selection/schemas.py @@ -288,6 +288,32 @@ class ModelSelectionRunResponse(BaseModel): completed_at: datetime | None +class CandidateModelInfo(BaseModel): + """One selectable forecasting model in the capability catalog. + + Output-only (plain ``BaseModel`` — no strict coercion needed). The + capability flags are BACKEND-OWNED: they derive from the forecasting + authority (``model_family_for`` + each forecaster's ``requires_features``) + so the frontend never re-derives families/feature-awareness in TypeScript. + """ + + model_type: str + label: str + family: Literal["baseline", "tree", "additive"] + feature_aware: bool + requires_extra: bool # lightgbm/xgboost — opt-in extra may be absent at runtime + default_params: dict[str, Any] + supports_auto_predict: bool # False for feature-aware models (predict() rejects them) + description: str + + +class ModelCatalogResponse(BaseModel): + """``GET /model-selection/models`` — backend-owned candidate catalog.""" + + models: list[CandidateModelInfo] + default_candidate_model_types: list[str] + + class TrainWinnerResponse(BaseModel): """``POST /model-selection/{id}/train-winner`` response.""" diff --git a/app/features/model_selection/service.py b/app/features/model_selection/service.py index ff7111e8..b8536068 100644 --- a/app/features/model_selection/service.py +++ b/app/features/model_selection/service.py @@ -26,6 +26,7 @@ from app.core.logging import get_logger from app.features.backtesting.schemas import SplitConfig from app.features.data_platform.models import Product, Promotion, SalesDaily, Store +from app.features.model_selection.capabilities import build_model_catalog from app.features.model_selection.explanations import explain_winner from app.features.model_selection.models import ModelSelectionRun, ModelSelectionStatus from app.features.model_selection.ranking import build_chart_data, rank_candidates @@ -36,6 +37,7 @@ ChartData, FoldChart, ForecastSummary, + ModelCatalogResponse, ModelSelectionRunRequest, ModelSelectionRunResponse, PairAvailabilityResponse, @@ -64,6 +66,18 @@ class ModelSelectionService: """Stateless orchestrator — a fresh ``db`` session per method.""" + # ------------------------------------------------------------------------- + # Capability catalog + # ------------------------------------------------------------------------- + + def get_model_catalog(self) -> ModelCatalogResponse: + """Return the backend-owned candidate-model catalog (static, no I/O). + + Thin pass-through to the pure :func:`capabilities.build_model_catalog`; + kept on the service for symmetry with ``get_availability`` / ``run``. + """ + return build_model_catalog() + # ------------------------------------------------------------------------- # Availability # ------------------------------------------------------------------------- diff --git a/app/features/model_selection/tests/test_capabilities.py b/app/features/model_selection/tests/test_capabilities.py new file mode 100644 index 00000000..3ff73804 --- /dev/null +++ b/app/features/model_selection/tests/test_capabilities.py @@ -0,0 +1,102 @@ +"""Unit tests for the pure model-capability catalog (issue #356, Slice A). + +No DB, no I/O — exercises ``build_model_catalog`` directly, mirroring +``test_ranking.py``. These pin the BACKEND-OWNED capability contract the +frontend consumes read-only. +""" + +from __future__ import annotations + +import typing + +from app.features.model_selection.capabilities import ( + DEFAULT_CANDIDATE_MODEL_TYPES, + build_model_catalog, +) +from app.features.model_selection.schemas import ModelType + +_EXPECTED_MODEL_TYPES = set(typing.get_args(ModelType)) + + +def test_catalog_model_types_match_literal() -> None: + """The catalog covers EXACTLY the ``ModelType`` Literal — no drift.""" + catalog = build_model_catalog() + catalog_types = {m.model_type for m in catalog.models} + assert catalog_types == _EXPECTED_MODEL_TYPES + # 11 models, no duplicates. + assert len(catalog.models) == len(_EXPECTED_MODEL_TYPES) == 11 + + +def test_catalog_families_are_valid_literals() -> None: + """Every family is one of the three lowercase literals from forecasting.""" + catalog = build_model_catalog() + for model in catalog.models: + assert model.family in {"baseline", "tree", "additive"} + + +def test_requires_extra_flags_lightgbm_xgboost_only() -> None: + """Only the opt-in extras (lightgbm/xgboost) carry requires_extra=True.""" + catalog = build_model_catalog() + extras = {m.model_type for m in catalog.models if m.requires_extra} + assert extras == {"lightgbm", "xgboost"} + + +def test_feature_aware_set_matches_predict_reject_set() -> None: + """feature_aware == the forecasters with requires_features=True.""" + catalog = build_model_catalog() + feature_aware = {m.model_type for m in catalog.models if m.feature_aware} + assert feature_aware == { + "regression", + "prophet_like", + "lightgbm", + "xgboost", + "random_forest", + } + + +def test_feature_aware_models_do_not_support_auto_predict() -> None: + """supports_auto_predict is the strict negation of feature_aware.""" + catalog = build_model_catalog() + for model in catalog.models: + assert model.supports_auto_predict == (not model.feature_aware) + + +def test_default_candidate_model_types_are_the_default_five() -> None: + """The pre-selected defaults match the backend /run contract example.""" + catalog = build_model_catalog() + assert catalog.default_candidate_model_types == [ + "naive", + "seasonal_naive", + "moving_average", + "regression", + "prophet_like", + ] + # The exported constant and the response agree. + assert DEFAULT_CANDIDATE_MODEL_TYPES == catalog.default_candidate_model_types + # Every default is a real catalog entry. + catalog_types = {m.model_type for m in catalog.models} + assert set(catalog.default_candidate_model_types) <= catalog_types + + +def test_default_params_match_forecasting_defaults() -> None: + """default_params are pinned to the live forecasting ModelConfig defaults.""" + by_type = {m.model_type: m.default_params for m in build_model_catalog().models} + assert by_type["naive"] == {} + assert by_type["seasonal_naive"] == {"season_length": 7} + assert by_type["moving_average"] == {"window_size": 7} + assert by_type["regression"] == { + "max_iter": 200, + "learning_rate": 0.05, + "max_depth": 6, + } + # No internal/meta fields leak into the catalog. + for params in by_type.values(): + assert "schema_version" not in params + assert "feature_config_hash" not in params + + +def test_labels_and_descriptions_are_non_empty() -> None: + """Each entry carries human-facing label + description copy.""" + for model in build_model_catalog().models: + assert model.label.strip() + assert model.description.strip() diff --git a/app/features/model_selection/tests/test_routes.py b/app/features/model_selection/tests/test_routes.py index 7cfb35f5..2effbc62 100644 --- a/app/features/model_selection/tests/test_routes.py +++ b/app/features/model_selection/tests/test_routes.py @@ -178,3 +178,53 @@ async def test_availability_rejects_bad_query() -> None: ) assert response.status_code == 422 _assert_problem_detail(response.json(), 422) + + +async def test_get_models_returns_catalog_200() -> None: + """GET /model-selection/models returns the static catalog (no mock needed).""" + async with _client() as ac: + response = await ac.get("/model-selection/models") + assert response.status_code == 200 + body = response.json() + assert isinstance(body["models"], list) + assert len(body["models"]) == 11 + # Each entry carries the backend-owned capability contract. + first = body["models"][0] + for key in ( + "model_type", + "label", + "family", + "feature_aware", + "requires_extra", + "default_params", + "supports_auto_predict", + "description", + ): + assert key in first, f"missing catalog field: {key}" + assert body["default_candidate_model_types"] == [ + "naive", + "seasonal_naive", + "moving_average", + "regression", + "prophet_like", + ] + + +async def test_models_route_not_captured_by_selection_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Literal /models must NOT be matched as GET /{selection_id}. + + If route ordering regressed, the request would hit ``get_selection`` (here + forced to 404) instead of the catalog handler. We assert the catalog shape + comes back, proving the literal-before-path-param ordering holds. + """ + monkeypatch.setattr( + ModelSelectionService, + "get_selection", + AsyncMock(side_effect=NotFoundError(message="selection run models not found")), + ) + async with _client() as ac: + response = await ac.get("/model-selection/models") + assert response.status_code == 200 + assert "models" in response.json() diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 1ef34bf1..2dc4042f 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -26,6 +26,7 @@ const BacktestPage = lazy(() => import('@/pages/visualize/backtest')) const DemandPlannerPage = lazy(() => import('@/pages/visualize/demand')) const WhatIfPlannerPage = lazy(() => import('@/pages/visualize/planner')) const BatchRunnerPage = lazy(() => import('@/pages/visualize/batch')) +const ChampionSelectorPage = lazy(() => import('@/pages/visualize/champion')) const ChatPage = lazy(() => import('@/pages/chat')) const KnowledgePage = lazy(() => import('@/pages/knowledge')) const GuidePage = lazy(() => import('@/pages/guide')) @@ -186,6 +187,14 @@ function App() { } /> + }> + + + } + /> = {}): PairAvailability { + return { + store_id: 7, + product_id: 12, + first_sales_date: '2026-01-01', + last_sales_date: '2026-05-31', + observed_days: 150, + expected_calendar_days: 151, + coverage_ratio: 0.99, + missing_days: 1, + zero_sale_days: 4, + promotion_days: 3, + average_daily_demand: 9.2, + status: 'ready', + recommended_split_config: { + strategy: 'expanding', + n_splits: 5, + min_train_size: 30, + gap: 0, + horizon: 14, + }, + warnings: [], + ...overrides, + } +} + +describe('AvailabilityPanel', () => { + it('renders status badge + metric tiles for a ready pair', () => { + render( + , + ) + expect(screen.getByTestId('availability-panel')).toBeTruthy() + expect(screen.getByTestId('availability-status-badge').textContent).toContain('Ready') + expect(screen.getByText('Observed days')).toBeTruthy() + expect(screen.getByText('Avg daily demand')).toBeTruthy() + }) + + it('renders the not-enough-data empty state for an unusable pair', () => { + render( + , + ) + expect(screen.queryByTestId('availability-panel')).toBeNull() + expect(screen.getByText('Not enough data to model this pair')).toBeTruthy() + }) + + it('renders an em dash when promotion_days is null', () => { + render( + , + ) + expect(screen.getByText('—')).toBeTruthy() + }) + + it('shows a loading state while assessing', () => { + render() + expect(screen.getByText('Assessing data availability…')).toBeTruthy() + }) +}) diff --git a/frontend/src/components/champion-selector/availability-panel.tsx b/frontend/src/components/champion-selector/availability-panel.tsx new file mode 100644 index 00000000..3dfa7370 --- /dev/null +++ b/frontend/src/components/champion-selector/availability-panel.tsx @@ -0,0 +1,146 @@ +import { AlertTriangle, DatabaseZap } from 'lucide-react' +import { EmptyState } from '@/components/common/error-display' +import { LoadingState } from '@/components/common/loading-state' +import { Badge } from '@/components/ui/badge' +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card' +import { formatNumber, formatPercent } from '@/lib/api' +import type { AvailabilityStatus, PairAvailability } from '@/types/api' + +interface AvailabilityPanelProps { + availability?: PairAvailability + isLoading: boolean + isError: boolean +} + +const STATUS_VARIANT: Record< + AvailabilityStatus, + 'default' | 'secondary' | 'destructive' +> = { + ready: 'default', + limited: 'secondary', + unusable: 'destructive', +} + +const STATUS_LABEL: Record = { + ready: 'Ready', + limited: 'Limited', + unusable: 'Unusable', +} + +function Metric({ label, value }: { label: string; value: string }) { + return ( +
+

{label}

+

{value}

+
+ ) +} + +/** + * Renders the (store, product) data-availability triage for the Champion + * Selector. Slice A surfaces the backend assessment only — no run, no charts. + */ +export function AvailabilityPanel({ + availability, + isLoading, + isError, +}: AvailabilityPanelProps) { + if (isLoading) { + return + } + + if (isError) { + return ( + } + /> + ) + } + + if (!availability) { + return ( + } + /> + ) + } + + // Not-enough-data state: an unusable pair or one with zero observed history. + if (availability.status === 'unusable' || availability.observed_days === 0) { + return ( + } + /> + ) + } + + const split = availability.recommended_split_config + + return ( + + +
+ Data availability + + {STATUS_LABEL[availability.status]} + +
+
+ +
+ + + + + +
+ +
+

+ Recommended split +

+

+ {split.strategy} · {split.n_splits} splits · min train{' '} + {split.min_train_size}d · gap {split.gap}d · horizon {split.horizon}d +

+
+ + {availability.warnings.length > 0 && ( +
    + {availability.warnings.map((warning, index) => ( +
  • + + {warning} +
  • + ))} +
+ )} +
+
+ ) +} diff --git a/frontend/src/components/champion-selector/backtest-settings-form.test.tsx b/frontend/src/components/champion-selector/backtest-settings-form.test.tsx new file mode 100644 index 00000000..b9df7a2b --- /dev/null +++ b/frontend/src/components/champion-selector/backtest-settings-form.test.tsx @@ -0,0 +1,120 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { BacktestSettingsForm } from './backtest-settings-form' +import { splitConfigErrors } from './split-config' +import type { SplitConfig } from '@/types/api' + +// Radix Collapsible/Select need a couple of layout APIs jsdom lacks. +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) + if (!Element.prototype.hasPointerCapture) { + Element.prototype.hasPointerCapture = () => false + } + if (!Element.prototype.scrollIntoView) { + Element.prototype.scrollIntoView = () => {} + } +}) + +afterEach(cleanup) + +const VALID: SplitConfig = { + strategy: 'expanding', + n_splits: 5, + min_train_size: 30, + gap: 0, + horizon: 14, +} + +describe('splitConfigErrors', () => { + it('accepts a valid config', () => { + expect(splitConfigErrors(VALID)).toEqual([]) + }) + + it('flags out-of-range n_splits and gap >= horizon', () => { + const errors = splitConfigErrors({ ...VALID, n_splits: 1, gap: 14 }) + expect(errors.some((e) => e.includes('Splits'))).toBe(true) + expect(errors.some((e) => e.includes('Gap must be smaller'))).toBe(true) + }) +}) + +describe('BacktestSettingsForm', () => { + it('reveals the advanced split inputs when toggled', () => { + render( + {}} + onRankingMetricChange={() => {}} + />, + ) + // Hidden until the collapsible opens. + expect(screen.queryByTestId('settings-n-splits')).toBeNull() + fireEvent.click(screen.getByTestId('advanced-toggle')) + expect(screen.getByTestId('settings-n-splits')).toBeTruthy() + expect(screen.getByTestId('settings-gap')).toBeTruthy() + }) + + it('renders validation errors for an invalid config', () => { + render( + {}} + onRankingMetricChange={() => {}} + />, + ) + expect(screen.getByTestId('settings-errors')).toBeTruthy() + expect(screen.getByText(/Splits must be between 2 and 20/)).toBeTruthy() + }) + + it('"Use recommended split" emits the recommended config (horizon synced)', () => { + const onChange = vi.fn() + const recommended: SplitConfig = { + strategy: 'sliding', + n_splits: 8, + min_train_size: 45, + gap: 1, + horizon: 7, // intentionally different — must be overridden to forecastHorizon + } + render( + {}} + recommended={recommended} + />, + ) + fireEvent.click(screen.getByTestId('use-recommended-split')) + expect(onChange).toHaveBeenCalledWith({ + strategy: 'sliding', + n_splits: 8, + min_train_size: 45, + gap: 1, + horizon: 14, // synced to forecastHorizon + }) + }) + + it('keeps the horizon input read-only and equal to the forecast horizon', () => { + render( + {}} + onRankingMetricChange={() => {}} + />, + ) + const horizon = screen.getByTestId('settings-horizon') as HTMLInputElement + expect(horizon.value).toBe('21') + expect(horizon.readOnly).toBe(true) + }) +}) diff --git a/frontend/src/components/champion-selector/backtest-settings-form.tsx b/frontend/src/components/champion-selector/backtest-settings-form.tsx new file mode 100644 index 00000000..fdaca7f3 --- /dev/null +++ b/frontend/src/components/champion-selector/backtest-settings-form.tsx @@ -0,0 +1,206 @@ +import { useState } from 'react' +import { ChevronDown, Settings2, Wand2 } from 'lucide-react' +import { Button } from '@/components/ui/button' +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from '@/components/ui/collapsible' +import { Input } from '@/components/ui/input' +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select' +import { cn } from '@/lib/utils' +import type { RankingMetric, SplitConfig, SplitStrategy } from '@/types/api' +import { BIAS_EXPLANATION, RANKING_TIE_BREAK } from './copy' +import { splitConfigErrors } from './split-config' + +interface BacktestSettingsFormProps { + value: SplitConfig + rankingMetric: RankingMetric + forecastHorizon: number + onChange: (next: SplitConfig) => void + onRankingMetricChange: (metric: RankingMetric) => void + recommended?: SplitConfig +} + +const RANKING_METRICS: { value: RankingMetric; label: string }[] = [ + { value: 'wape', label: 'WAPE (default)' }, + { value: 'smape', label: 'sMAPE' }, + { value: 'mae', label: 'MAE' }, + { value: 'bias', label: 'Bias' }, +] + +function Field({ + label, + children, + hint, +}: { + label: string + children: React.ReactNode + hint?: string +}) { + return ( +
+ {label} + {children} + {hint &&

{hint}

} +
+ ) +} + +/** + * Simple/advanced backtest-settings form. The horizon is DERIVED from + * `forecastHorizon` (kept equal so the assembled run request is always valid) + * and shown read-only. The advanced toggle reveals the split-CV knobs. + */ +export function BacktestSettingsForm({ + value, + rankingMetric, + forecastHorizon, + onChange, + onRankingMetricChange, + recommended, +}: BacktestSettingsFormProps) { + const [advancedOpen, setAdvancedOpen] = useState(false) + const errors = splitConfigErrors(value) + + function patch(partial: Partial) { + onChange({ ...value, ...partial, horizon: forecastHorizon }) + } + + return ( +
+
+ + + + + + +
+ + {recommended && ( + + )} + + + + + + +
+ + + + + + patch({ n_splits: Number(event.target.value) || 0 }) + } + /> + + + + patch({ min_train_size: Number(event.target.value) || 0 }) + } + /> + + + + patch({ gap: Number(event.target.value) || 0 }) + } + /> + +
+
+
+ + {errors.length > 0 && ( +
    + {errors.map((error) => ( +
  • + {error} +
  • + ))} +
+ )} +
+ ) +} diff --git a/frontend/src/components/champion-selector/candidate-model-picker.test.tsx b/frontend/src/components/champion-selector/candidate-model-picker.test.tsx new file mode 100644 index 00000000..8c7d171d --- /dev/null +++ b/frontend/src/components/champion-selector/candidate-model-picker.test.tsx @@ -0,0 +1,99 @@ +import { afterEach, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { CandidateModelPicker, MAX_CANDIDATES } from './candidate-model-picker' +import type { CandidateModelInfo, ModelCatalogResponse } from '@/types/api' + +afterEach(cleanup) + +function model( + model_type: string, + overrides: Partial = {}, +): CandidateModelInfo { + return { + model_type, + label: model_type, + family: 'baseline', + feature_aware: false, + requires_extra: false, + default_params: {}, + supports_auto_predict: true, + description: `desc ${model_type}`, + ...overrides, + } +} + +const CATALOG: ModelCatalogResponse = { + models: [ + model('naive'), + model('regression', { family: 'tree', feature_aware: true }), + model('lightgbm', { family: 'tree', feature_aware: true, requires_extra: true }), + model('xgboost', { family: 'tree', feature_aware: true, requires_extra: true }), + ], + default_candidate_model_types: ['naive', 'regression'], +} + +describe('CandidateModelPicker', () => { + it('toggling a model calls onChange with the new selection', () => { + const onChange = vi.fn() + render( + , + ) + fireEvent.click(screen.getByTestId('candidate-checkbox-regression')) + expect(onChange).toHaveBeenCalledWith(['naive', 'regression']) + }) + + it('deselects an already-selected model', () => { + const onChange = vi.fn() + render( + , + ) + fireEvent.click(screen.getByTestId('candidate-checkbox-naive')) + expect(onChange).toHaveBeenCalledWith(['regression']) + }) + + it('flags opt-in-extra models with an "extra" badge', () => { + render( + {}} + isLoading={false} + />, + ) + expect(screen.getByTestId('candidate-extra-badge-lightgbm')).toBeTruthy() + expect(screen.getByTestId('candidate-extra-badge-xgboost')).toBeTruthy() + // A baseline model carries no extra badge. + expect(screen.queryByTestId('candidate-extra-badge-naive')).toBeNull() + }) + + it('caps the selection at MAX_CANDIDATES and disables unselected models', () => { + const many = Array.from({ length: MAX_CANDIDATES }, (_, i) => `m${i}`) + const onChange = vi.fn() + const bigCatalog: ModelCatalogResponse = { + models: [...many.map((m) => model(m)), model('extra_model')], + default_candidate_model_types: [], + } + render( + , + ) + expect(screen.getByTestId('candidate-cap-badge')).toBeTruthy() + // Clicking an unselected model at the cap must NOT add it. + fireEvent.click(screen.getByTestId('candidate-checkbox-extra_model')) + expect(onChange).not.toHaveBeenCalled() + }) +}) diff --git a/frontend/src/components/champion-selector/candidate-model-picker.tsx b/frontend/src/components/champion-selector/candidate-model-picker.tsx new file mode 100644 index 00000000..6a3b4366 --- /dev/null +++ b/frontend/src/components/champion-selector/candidate-model-picker.tsx @@ -0,0 +1,129 @@ +import { LoadingState } from '@/components/common/loading-state' +import { Badge } from '@/components/ui/badge' +import { Checkbox } from '@/components/ui/checkbox' +import { cn } from '@/lib/utils' +import type { CandidateModelInfo, ModelCatalogResponse, ModelFamily } from '@/types/api' + +/** Backend caps `candidate_models` at 10 (ModelSelectionRunRequest.max_length). */ +export const MAX_CANDIDATES = 10 + +interface CandidateModelPickerProps { + catalog?: ModelCatalogResponse + selected: string[] + onChange: (types: string[]) => void + isLoading: boolean +} + +const FAMILY_ORDER: ModelFamily[] = ['baseline', 'additive', 'tree'] +const FAMILY_LABEL: Record = { + baseline: 'Baseline', + additive: 'Additive', + tree: 'Tree-based', +} + +/** + * Candidate-model multi-select fed by the BACKEND catalog (never the hardcoded + * `model-type-utils`). Mirrors the batch-matrix-picker conventions: a checkbox + * per model grouped by family, opt-in-extra + feature-aware badges, and a + * selection cap of 10. + */ +export function CandidateModelPicker({ + catalog, + selected, + onChange, + isLoading, +}: CandidateModelPickerProps) { + if (isLoading) { + return + } + if (!catalog || catalog.models.length === 0) { + return ( +

No models available.

+ ) + } + + const selectedSet = new Set(selected) + const atCap = selected.length >= MAX_CANDIDATES + + function toggle(modelType: string) { + if (selectedSet.has(modelType)) { + onChange(selected.filter((type) => type !== modelType)) + } else if (!atCap) { + onChange([...selected, modelType]) + } + } + + const byFamily = new Map() + for (const model of catalog.models) { + const list = byFamily.get(model.family) ?? [] + list.push(model) + byFamily.set(model.family, list) + } + + return ( +
+
+ + {selected.length} of {MAX_CANDIDATES} selected + + {atCap && ( + + Max {MAX_CANDIDATES} reached + + )} +
+ + {FAMILY_ORDER.filter((family) => byFamily.has(family)).map((family) => ( +
+

+ {FAMILY_LABEL[family]} +

+
+ {(byFamily.get(family) ?? []).map((model) => { + const isSelected = selectedSet.has(model.model_type) + const disabled = !isSelected && atCap + return ( + + ) + })} +
+
+ ))} +
+ ) +} diff --git a/frontend/src/components/champion-selector/copy.ts b/frontend/src/components/champion-selector/copy.ts new file mode 100644 index 00000000..bafbfd53 --- /dev/null +++ b/frontend/src/components/champion-selector/copy.ts @@ -0,0 +1,20 @@ +/** + * Shared, LOCKED copy for the Champion Selector workflow (Slices A/B/C). + * + * Kept in a `.ts` (not `.tsx`) module so the `react-refresh/only-export-components` + * lint rule never trips on these non-component exports. Slices B and C import + * the SAME constants so the bias wording / tie-break explanation never drift. + */ + +/** LOCKED #7 — the canonical bias explanation reused everywhere bias is shown. */ +export const BIAS_EXPLANATION = + 'Positive bias means the model under-forecasts (risk of stockouts); ' + + 'negative bias means it over-forecasts (risk of overstock).' + +/** LOCKED #8 — the deterministic ranking tie-break chain. */ +export const RANKING_TIE_BREAK = + 'Ranked by WAPE, then sMAPE, then |bias|, then MAE.' + +/** Copy for the disabled Slice-A "Run comparison" CTA. */ +export const RUN_COMPARISON_PENDING = + 'Model comparison runs in the next update.' diff --git a/frontend/src/components/champion-selector/run-request.test.ts b/frontend/src/components/champion-selector/run-request.test.ts new file mode 100644 index 00000000..59f4ad0e --- /dev/null +++ b/frontend/src/components/champion-selector/run-request.test.ts @@ -0,0 +1,63 @@ +import { describe, expect, it } from 'vitest' +import { assembleRunRequest } from './run-request' +import type { SplitConfig } from '@/types/api' + +const SPLIT: SplitConfig = { + strategy: 'expanding', + n_splits: 5, + min_train_size: 30, + gap: 0, + horizon: 7, // intentionally stale — must be overridden to forecastHorizon +} + +describe('assembleRunRequest', () => { + it('pins auto_train_winner and auto_predict to false (Slice A invariant)', () => { + const req = assembleRunRequest({ + storeId: 7, + productId: 12, + startDate: '2026-01-01', + endDate: '2026-05-31', + forecastHorizon: 14, + rankingMetric: 'wape', + splitConfig: SPLIT, + selectedModels: ['naive', 'regression'], + }) + expect(req.auto_train_winner).toBe(false) + expect(req.auto_predict).toBe(false) + }) + + it('forces split_config.horizon === forecast_horizon', () => { + const req = assembleRunRequest({ + storeId: 1, + productId: 2, + startDate: '2026-01-01', + endDate: '2026-03-31', + forecastHorizon: 21, + rankingMetric: 'wape', + splitConfig: SPLIT, + selectedModels: ['naive'], + }) + expect(req.forecast_horizon).toBe(21) + expect(req.split_config.horizon).toBe(21) + }) + + it('maps selected model types into flat candidate configs and stays V1', () => { + const req = assembleRunRequest({ + storeId: 1, + productId: 2, + startDate: '2026-01-01', + endDate: '2026-03-31', + forecastHorizon: 14, + rankingMetric: 'smape', + splitConfig: SPLIT, + selectedModels: ['naive', 'seasonal_naive'], + }) + expect(req.candidate_models).toEqual([ + { model_type: 'naive', params: {} }, + { model_type: 'seasonal_naive', params: {} }, + ]) + expect(req.feature_frame_version).toBe(1) + expect(req.feature_groups).toBeNull() + expect(req.ranking_metric).toBe('smape') + }) +}) diff --git a/frontend/src/components/champion-selector/run-request.ts b/frontend/src/components/champion-selector/run-request.ts new file mode 100644 index 00000000..253da365 --- /dev/null +++ b/frontend/src/components/champion-selector/run-request.ts @@ -0,0 +1,50 @@ +import type { + ModelSelectionRunRequest, + RankingMetric, + SplitConfig, +} from '@/types/api' + +export interface AssembleRunRequestInput { + storeId: number + productId: number + startDate: string // YYYY-MM-DD + endDate: string // YYYY-MM-DD + forecastHorizon: number + rankingMetric: RankingMetric + splitConfig: SplitConfig + selectedModels: string[] +} + +/** + * Assemble the typed `ModelSelectionRunRequest` from the Champion Selector + * form state. Pure + side-effect-free so it can be unit-tested. + * + * Slice A pins `auto_train_winner` and `auto_predict` to `false`: the async run + * path (Slice B) treats both as NO-OPS, and Slice C owns explicit + * train/predict. `split_config.horizon` is forced equal to `forecast_horizon` + * (the backend `ModelSelectionRunRequest` validator requires it). The request + * is assembled but NOT sent in Slice A — the "Run comparison" CTA is disabled. + */ +export function assembleRunRequest( + input: AssembleRunRequestInput, +): ModelSelectionRunRequest { + return { + store_id: input.storeId, + product_id: input.productId, + selection_window: { + start_date: input.startDate, + end_date: input.endDate, + }, + forecast_horizon: input.forecastHorizon, + ranking_metric: input.rankingMetric, + split_config: { ...input.splitConfig, horizon: input.forecastHorizon }, + candidate_models: input.selectedModels.map((model_type) => ({ + model_type, + params: {}, + })), + feature_frame_version: 1, + feature_groups: null, + auto_train_winner: false, + auto_predict: false, + } +} diff --git a/frontend/src/components/champion-selector/searchable-entity-select.test.tsx b/frontend/src/components/champion-selector/searchable-entity-select.test.tsx new file mode 100644 index 00000000..99b476a7 --- /dev/null +++ b/frontend/src/components/champion-selector/searchable-entity-select.test.tsx @@ -0,0 +1,78 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, fireEvent, render, screen } from '@testing-library/react' +import { SearchableEntitySelect, type SearchableEntityItem } from './searchable-entity-select' + +// Radix Popover positions its content with Popper, which needs ResizeObserver +// + a couple of layout APIs jsdom lacks. Polyfill them locally (the repo has no +// vitest setup file) so the popover can open in the test environment. +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) + if (!Element.prototype.hasPointerCapture) { + Element.prototype.hasPointerCapture = () => false + } + if (!Element.prototype.scrollIntoView) { + Element.prototype.scrollIntoView = () => {} + } +}) + +afterEach(cleanup) + +const ITEMS: SearchableEntityItem[] = [ + { id: 7, primary: 'S001 · Downtown', secondary: 'North' }, + { id: 12, primary: 'S002 · Airport', secondary: 'West' }, + { id: 99, primary: 'S003 · Suburb', secondary: 'East' }, +] + +describe('SearchableEntitySelect', () => { + it('shows the placeholder when nothing is selected', () => { + render( + {}} + placeholder="Pick a store…" + />, + ) + expect(screen.getByText('Pick a store…')).toBeTruthy() + }) + + it('filters the list client-side and selects an option on click', () => { + const onChange = vi.fn() + render( + , + ) + fireEvent.click(screen.getByTestId('searchable-entity-select')) + + // All three options visible before filtering. + expect(screen.getByTestId('searchable-entity-select-option-7')).toBeTruthy() + expect(screen.getByTestId('searchable-entity-select-option-12')).toBeTruthy() + expect(screen.getByTestId('searchable-entity-select-option-99')).toBeTruthy() + + // Filter narrows to the Airport row (matches the primary text). + fireEvent.change(screen.getByTestId('searchable-entity-select-filter'), { + target: { value: 'airport' }, + }) + expect(screen.queryByTestId('searchable-entity-select-option-7')).toBeNull() + expect(screen.getByTestId('searchable-entity-select-option-12')).toBeTruthy() + + fireEvent.click(screen.getByTestId('searchable-entity-select-option-12')) + expect(onChange).toHaveBeenCalledWith(12) + }) + + it('filters on the secondary descriptor too', () => { + render( + {}} />, + ) + fireEvent.click(screen.getByTestId('searchable-entity-select')) + fireEvent.change(screen.getByTestId('searchable-entity-select-filter'), { + target: { value: 'east' }, + }) + expect(screen.getByTestId('searchable-entity-select-option-99')).toBeTruthy() + expect(screen.queryByTestId('searchable-entity-select-option-7')).toBeNull() + }) +}) diff --git a/frontend/src/components/champion-selector/searchable-entity-select.tsx b/frontend/src/components/champion-selector/searchable-entity-select.tsx new file mode 100644 index 00000000..f4dcf51b --- /dev/null +++ b/frontend/src/components/champion-selector/searchable-entity-select.tsx @@ -0,0 +1,144 @@ +import { useState } from 'react' +import { Check, ChevronsUpDown, Search } from 'lucide-react' +import { cn } from '@/lib/utils' +import { Button } from '@/components/ui/button' +import { Input } from '@/components/ui/input' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@/components/ui/popover' + +export interface SearchableEntityItem { + id: number + primary: string + secondary?: string +} + +interface SearchableEntitySelectProps { + items: SearchableEntityItem[] + value: number | null + onChange: (id: number) => void + placeholder?: string + loading?: boolean + emptyLabel?: string + /** Forwarded to the trigger button + filter input for scoped test queries. */ + testId?: string +} + +/** + * A combobox built from existing primitives (Popover + Input + a filtered + * ` + + +
+ + setFilter(event.target.value)} + placeholder="Filter…" + data-testid={`${testId}-filter`} + className="h-8 border-0 px-0 shadow-none focus-visible:ring-0" + /> +
+
+ {filtered.length === 0 ? ( +

+ {emptyLabel} +

+ ) : ( + filtered.map((item) => ( + + )) + )} +
+
+ + ) +} diff --git a/frontend/src/components/champion-selector/split-config.ts b/frontend/src/components/champion-selector/split-config.ts new file mode 100644 index 00000000..ecc98f35 --- /dev/null +++ b/frontend/src/components/champion-selector/split-config.ts @@ -0,0 +1,24 @@ +import type { SplitConfig } from '@/types/api' + +/** + * Inline-validate a `SplitConfig` against the backend SplitConfig bounds + * (`app/features/backtesting/schemas.py`). Kept in a `.ts` module (not the + * form `.tsx`) so the `react-refresh/only-export-components` lint rule stays + * happy. Returns a list of human-facing error strings (empty = valid). + */ +export function splitConfigErrors(config: SplitConfig): string[] { + const errors: string[] = [] + if (config.n_splits < 2 || config.n_splits > 20) { + errors.push('Splits must be between 2 and 20.') + } + if (config.min_train_size < 7) { + errors.push('Minimum train size must be at least 7 days.') + } + if (config.gap < 0 || config.gap > 30) { + errors.push('Gap must be between 0 and 30 days.') + } + if (config.gap >= config.horizon) { + errors.push('Gap must be smaller than the horizon.') + } + return errors +} diff --git a/frontend/src/hooks/index.ts b/frontend/src/hooks/index.ts index 1c47074d..eebde40d 100644 --- a/frontend/src/hooks/index.ts +++ b/frontend/src/hooks/index.ts @@ -7,6 +7,7 @@ export * from './use-inventory' export * from './use-lifecycle-curve' export * from './use-runs' export * from './use-jobs' +export * from './use-model-selection' export * from './use-ops' export * from './use-scenarios' export * from './use-rag-sources' diff --git a/frontend/src/hooks/use-model-selection.test.ts b/frontend/src/hooks/use-model-selection.test.ts new file mode 100644 index 00000000..a1187321 --- /dev/null +++ b/frontend/src/hooks/use-model-selection.test.ts @@ -0,0 +1,126 @@ +/** + * Unit tests for the model-selection query hooks (Champion Selector, Slice A). + * + * Stubs `fetch` to assert the catalog + availability GET URLs and the + * availability `enabled` gating. No real backend is exercised. + */ +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { renderHook, waitFor } from '@testing-library/react' +import { afterEach, describe, expect, it, vi } from 'vitest' +import { createElement, type ReactNode } from 'react' + +import { useModelCatalog, usePairAvailability } from './use-model-selection' +import type { ModelCatalogResponse, PairAvailability } from '@/types/api' + +function makeWrapper(client: QueryClient) { + return function Wrapper({ children }: { children: ReactNode }) { + return createElement(QueryClientProvider, { client }, children) + } +} + +function makeClient() { + return new QueryClient({ defaultOptions: { queries: { retry: false } } }) +} + +const CATALOG: ModelCatalogResponse = { + models: [ + { + model_type: 'naive', + label: 'Naive', + family: 'baseline', + feature_aware: false, + requires_extra: false, + default_params: {}, + supports_auto_predict: true, + description: 'Repeats the last observed value.', + }, + ], + default_candidate_model_types: ['naive', 'seasonal_naive', 'moving_average'], +} + +const AVAILABILITY: PairAvailability = { + store_id: 7, + product_id: 12, + first_sales_date: '2026-01-01', + last_sales_date: '2026-05-31', + observed_days: 150, + expected_calendar_days: 151, + coverage_ratio: 0.99, + missing_days: 1, + zero_sale_days: 4, + promotion_days: 3, + average_daily_demand: 9.2, + status: 'ready', + recommended_split_config: { + strategy: 'expanding', + n_splits: 5, + min_train_size: 30, + gap: 0, + horizon: 14, + }, + warnings: [], +} + +afterEach(() => { + vi.unstubAllGlobals() +}) + +describe('useModelCatalog', () => { + it('GETs /model-selection/models and returns the parsed catalog', async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify(CATALOG), { + status: 200, + headers: { 'content-type': 'application/json' }, + }), + ) + vi.stubGlobal('fetch', fetchMock) + + const { result } = renderHook(() => useModelCatalog(), { + wrapper: makeWrapper(makeClient()), + }) + + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + expect(fetchMock).toHaveBeenCalledTimes(1) + expect(fetchMock.mock.calls[0]![0]).toContain('/model-selection/models') + expect(result.current.data?.models[0]?.model_type).toBe('naive') + }) +}) + +describe('usePairAvailability', () => { + it('GETs /model-selection/availability with the three query params', async () => { + const fetchMock = vi.fn().mockResolvedValue( + new Response(JSON.stringify(AVAILABILITY), { + status: 200, + headers: { 'content-type': 'application/json' }, + }), + ) + vi.stubGlobal('fetch', fetchMock) + + const { result } = renderHook( + () => usePairAvailability({ storeId: 7, productId: 12, forecastHorizon: 14 }), + { wrapper: makeWrapper(makeClient()) }, + ) + + await waitFor(() => expect(result.current.isSuccess).toBe(true)) + const url = String(fetchMock.mock.calls[0]![0]) + expect(url).toContain('/model-selection/availability') + expect(url).toContain('store_id=7') + expect(url).toContain('product_id=12') + expect(url).toContain('forecast_horizon=14') + expect(result.current.data?.status).toBe('ready') + }) + + it('does NOT fetch while the pair is incomplete (enabled gating)', async () => { + const fetchMock = vi.fn() + vi.stubGlobal('fetch', fetchMock) + + renderHook( + () => usePairAvailability({ storeId: null, productId: 12, forecastHorizon: 14 }), + { wrapper: makeWrapper(makeClient()) }, + ) + + // Give TanStack a tick; the disabled query must never call fetch. + await new Promise((resolve) => setTimeout(resolve, 20)) + expect(fetchMock).not.toHaveBeenCalled() + }) +}) diff --git a/frontend/src/hooks/use-model-selection.ts b/frontend/src/hooks/use-model-selection.ts new file mode 100644 index 00000000..726f8072 --- /dev/null +++ b/frontend/src/hooks/use-model-selection.ts @@ -0,0 +1,57 @@ +import { useQuery } from '@tanstack/react-query' +import { api } from '@/lib/api' +import type { ModelCatalogResponse, PairAvailability } from '@/types/api' + +/** + * Model-selection query hooks (Champion Selector, Slice A). + * + * Read-only: the catalog and pair-availability GETs. The run mutation, + * progress, and results hooks are owned by Slice B; train/predict by Slice C. + */ + +/** + * Fetch the backend-owned candidate-model capability catalog. + * + * The catalog is static, so it is cached aggressively (no refetch churn). + */ +export function useModelCatalog() { + return useQuery({ + queryKey: ['model-selection', 'models'], + queryFn: () => api('/model-selection/models'), + staleTime: 1000 * 60 * 60, // 1h — the catalog rarely changes within a session + }) +} + +interface UsePairAvailabilityParams { + storeId: number | null + productId: number | null + forecastHorizon: number + enabled?: boolean +} + +/** + * Assess data availability for a (store, product) pair at a given horizon. + * + * Gated like `useStore`: only fires once a real pair is chosen. `storeId` / + * `productId` are nullable so the page can pass its raw selection state without + * coercing un-selected values to a bogus `0`/`1`. + */ +export function usePairAvailability({ + storeId, + productId, + forecastHorizon, + enabled = true, +}: UsePairAvailabilityParams) { + return useQuery({ + queryKey: ['model-selection', 'availability', storeId, productId, forecastHorizon], + queryFn: () => + api('/model-selection/availability', { + params: { + store_id: storeId, + product_id: productId, + forecast_horizon: forecastHorizon, + }, + }), + enabled: enabled && !!storeId && storeId > 0 && !!productId && productId > 0, + }) +} diff --git a/frontend/src/lib/constants.ts b/frontend/src/lib/constants.ts index 6a6de39f..95cb28b8 100644 --- a/frontend/src/lib/constants.ts +++ b/frontend/src/lib/constants.ts @@ -25,6 +25,7 @@ export const ROUTES = { DEMAND: '/visualize/demand', PLANNER: '/visualize/planner', BATCH: '/visualize/batch', + CHAMPION: '/visualize/champion', }, KNOWLEDGE: '/knowledge', CHAT: '/chat', @@ -55,6 +56,7 @@ export const NAV_ITEMS = [ { label: 'Forecast', href: ROUTES.VISUALIZE.FORECAST }, { label: 'Backtest Results', href: ROUTES.VISUALIZE.BACKTEST }, { label: 'Batch Runner', href: ROUTES.VISUALIZE.BATCH }, + { label: 'Champion Selector', href: ROUTES.VISUALIZE.CHAMPION }, ], }, { label: 'Knowledge', href: ROUTES.KNOWLEDGE }, diff --git a/frontend/src/pages/visualize/champion.test.tsx b/frontend/src/pages/visualize/champion.test.tsx new file mode 100644 index 00000000..123d4862 --- /dev/null +++ b/frontend/src/pages/visualize/champion.test.tsx @@ -0,0 +1,118 @@ +import { afterEach, beforeAll, describe, expect, it, vi } from 'vitest' +import { cleanup, render, screen, waitFor } from '@testing-library/react' +import type { ModelCatalogResponse } from '@/types/api' + +// Radix primitives need a couple of layout APIs jsdom lacks. +beforeAll(() => { + class ResizeObserverStub { + observe() {} + unobserve() {} + disconnect() {} + } + vi.stubGlobal('ResizeObserver', ResizeObserverStub) + if (!Element.prototype.hasPointerCapture) { + Element.prototype.hasPointerCapture = () => false + } + if (!Element.prototype.scrollIntoView) { + Element.prototype.scrollIntoView = () => {} + } +}) + +const CATALOG: ModelCatalogResponse = { + models: [ + { + model_type: 'naive', + label: 'Naive', + family: 'baseline', + feature_aware: false, + requires_extra: false, + default_params: {}, + supports_auto_predict: true, + description: 'Repeats the last observed value.', + }, + { + model_type: 'regression', + label: 'Gradient Boosting Regression', + family: 'tree', + feature_aware: true, + requires_extra: false, + default_params: {}, + supports_auto_predict: false, + description: 'Histogram gradient boosting.', + }, + ], + default_candidate_model_types: ['naive', 'regression'], +} + +vi.mock('@/hooks/use-stores', () => ({ + useStores: () => ({ + data: { stores: [{ id: 7, code: 'S001', name: 'Downtown', region: 'North', store_type: 'flagship' }] }, + isLoading: false, + }), +})) +vi.mock('@/hooks/use-products', () => ({ + useProducts: () => ({ + data: { products: [{ id: 12, sku: 'SKU1', name: 'Widget', category: 'tools' }] }, + isLoading: false, + }), +})) +vi.mock('@/hooks/use-model-selection', () => ({ + useModelCatalog: () => ({ + data: CATALOG, + isLoading: false, + isError: false, + error: null, + refetch: () => {}, + }), + usePairAvailability: () => ({ + data: undefined, + isLoading: false, + isError: false, + }), +})) + +import ChampionSelectorPage from './champion' + +afterEach(cleanup) + +describe('ChampionSelectorPage', () => { + it('renders the selection shell', () => { + render() + expect(screen.getByText('Champion Selector')).toBeTruthy() + expect(screen.getByText('1 · Pick a store & product')).toBeTruthy() + expect(screen.getByText('2 · Data availability')).toBeTruthy() + expect(screen.getByText('3 · Candidate models')).toBeTruthy() + expect(screen.getByText('4 · Backtest settings')).toBeTruthy() + }) + + it('drives candidate cards from the backend catalog', () => { + render() + expect(screen.getByTestId('candidate-model-naive')).toBeTruthy() + expect(screen.getByTestId('candidate-model-regression')).toBeTruthy() + }) + + it('pre-selects the catalog default candidate models', async () => { + render() + // The seeding effect selects the default two models. + await waitFor(() => + expect(screen.getByText('2 of 10 selected')).toBeTruthy(), + ) + }) + + it('renders the availability empty state until a pair is chosen', () => { + render() + expect(screen.getByText('Pick a store and product')).toBeTruthy() + }) + + it('keeps the Run comparison CTA disabled and issues no POST', () => { + const fetchMock = vi.fn() + vi.stubGlobal('fetch', fetchMock) + render() + const cta = screen.getByTestId('run-comparison-cta') as HTMLButtonElement + expect(cta.disabled).toBe(true) + // The page itself issues no network calls (the hooks are mocked); in + // particular it never POSTs to /model-selection/run. + expect(fetchMock).not.toHaveBeenCalled() + vi.unstubAllGlobals() + }) +}) diff --git a/frontend/src/pages/visualize/champion.tsx b/frontend/src/pages/visualize/champion.tsx new file mode 100644 index 00000000..d3e3106f --- /dev/null +++ b/frontend/src/pages/visualize/champion.tsx @@ -0,0 +1,294 @@ +import { useMemo, useState } from 'react' +import { format } from 'date-fns' +import { DateRange } from 'react-day-picker' +import { Trophy } from 'lucide-react' +import { useStores } from '@/hooks/use-stores' +import { useProducts } from '@/hooks/use-products' +import { useModelCatalog, usePairAvailability } from '@/hooks/use-model-selection' +import { DateRangePicker } from '@/components/common/date-range-picker' +import { ErrorDisplay } from '@/components/common/error-display' +import { AvailabilityPanel } from '@/components/champion-selector/availability-panel' +import { BacktestSettingsForm } from '@/components/champion-selector/backtest-settings-form' +import { splitConfigErrors } from '@/components/champion-selector/split-config' +import { CandidateModelPicker } from '@/components/champion-selector/candidate-model-picker' +import { SearchableEntitySelect } from '@/components/champion-selector/searchable-entity-select' +import { RUN_COMPARISON_PENDING } from '@/components/champion-selector/copy' +import { assembleRunRequest } from '@/components/champion-selector/run-request' +import { Button } from '@/components/ui/button' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Input } from '@/components/ui/input' +import type { + ModelSelectionRunRequest, + SplitConfig, +} from '@/types/api' + +const DEFAULT_HORIZON = 14 + +const DEFAULT_SPLIT: SplitConfig = { + strategy: 'expanding', + n_splits: 5, + min_train_size: 30, + gap: 0, + horizon: DEFAULT_HORIZON, +} + +/** + * Forecast Champion Selector — Slice A. + * + * Configuration + availability triage only. It assembles a typed + * `ModelSelectionRunRequest` in component state and surfaces a DISABLED + * "Run comparison" CTA — the comparison RUN itself (and all results/training) + * lands in Slices B/C. This page calls only the two read GETs (catalog + + * availability); it never POSTs. + */ +export default function ChampionSelectorPage() { + const [storeId, setStoreId] = useState(null) + const [productId, setProductId] = useState(null) + const [dateRange, setDateRange] = useState() + const [forecastHorizon, setForecastHorizon] = useState(DEFAULT_HORIZON) + const [splitConfig, setSplitConfig] = useState(DEFAULT_SPLIT) + const [rankingMetric, setRankingMetric] = useState< + ModelSelectionRunRequest['ranking_metric'] + >('wape') + // `null` means "the user hasn't edited the selection yet" — fall back to the + // catalog's default candidate set (derived below, no effect needed). + const [editedModels, setEditedModels] = useState(null) + + // /dimensions/{stores,products} both cap page_size at 100 (client-filtered). + const storesQuery = useStores({ page: 1, pageSize: 100 }) + const productsQuery = useProducts({ page: 1, pageSize: 100 }) + const catalogQuery = useModelCatalog() + + const validPair = !!storeId && !!productId + const availabilityQuery = usePairAvailability({ + storeId, + productId, + forecastHorizon, + enabled: validPair, + }) + + // Pre-select the backend default candidate set until the user edits it — + // derived during render rather than seeded via an effect. + const selectedModels = + editedModels ?? catalogQuery.data?.default_candidate_model_types ?? [] + + // split_config.horizon must equal forecast_horizon (the backend validator). + // Force it during render so no effect is needed to keep them in sync. + const effectiveSplit: SplitConfig = useMemo( + () => ({ ...splitConfig, horizon: forecastHorizon }), + [splitConfig, forecastHorizon], + ) + + const storeItems = useMemo( + () => + (storesQuery.data?.stores ?? []).map((store) => ({ + id: store.id, + primary: `${store.code} · ${store.name}`, + secondary: [store.region, store.store_type].filter(Boolean).join(' · '), + })), + [storesQuery.data], + ) + const productItems = useMemo( + () => + (productsQuery.data?.products ?? []).map((product) => ({ + id: product.id, + primary: `${product.sku} · ${product.name}`, + secondary: product.category ?? undefined, + })), + [productsQuery.data], + ) + + const formReady = + validPair && + !!dateRange?.from && + !!dateRange?.to && + forecastHorizon >= 1 && + forecastHorizon <= 90 && + selectedModels.length >= 1 && + splitConfigErrors(effectiveSplit).length === 0 + + // The assembled request — typed but NOT sent in Slice A (the CTA is disabled). + // `auto_train_winner`/`auto_predict` are pinned false by `assembleRunRequest`. + // Built defensively so it is valid the moment Slice B wires the mutation. + const runRequest: ModelSelectionRunRequest | null = + formReady && dateRange?.from && dateRange?.to + ? assembleRunRequest({ + storeId: storeId!, + productId: productId!, + startDate: format(dateRange.from, 'yyyy-MM-dd'), + endDate: format(dateRange.to, 'yyyy-MM-dd'), + forecastHorizon, + rankingMetric, + splitConfig: effectiveSplit, + selectedModels, + }) + : null + + return ( +
+
+

+ + Champion Selector +

+

+ Configure a store, product, time period, horizon and candidate models, + and check whether the pair has enough history to model. Running the + comparison arrives in a later update. +

+
+ + {/* Selection */} + + + 1 · Pick a store & product + + Search by code/SKU or name. The availability check runs automatically + once a valid pair and horizon are chosen. + + + +
+
+ Store + +
+
+ Product + +
+
+ Time period + +
+
+ + Forecast horizon (days) + + + setForecastHorizon(Number(event.target.value) || 0) + } + /> +
+
+
+
+ + {/* Availability */} + + + 2 · Data availability + + Whether this pair has enough observed history for a reliable + comparison, plus the recommended split. + + + + + + + + {/* Candidate models */} + + + 3 · Candidate models + + Pick the models to compare (up to 10). The default five are + pre-selected; opt-in extras are flagged. + + + + {catalogQuery.isError ? ( + catalogQuery.refetch()} + /> + ) : ( + + )} + + + + {/* Backtest settings */} + + + 4 · Backtest settings + + The ranking metric and cross-validation split. Start with the + recommended split or fine-tune under Advanced. + + + + + + + + {/* Run CTA (disabled until Slice B) */} + + +
+ {formReady + ? `Ready to compare ${selectedModels.length} model${ + selectedModels.length === 1 ? '' : 's' + }. ${RUN_COMPARISON_PENDING}` + : 'Pick a store, product, time period, horizon and at least one model to continue.'} +
+ +
+
+ + {/* Dev-only assurance that a valid request is assembled (not sent). */} + {runRequest && ( +

+ {JSON.stringify(runRequest)} +

+ )} +
+ ) +} diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index df2289f4..d6e0584f 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -1188,3 +1188,161 @@ export interface ForecastExplanation { as_of_date: string // ISO date generated_at: string // ISO datetime } + +// ============================================================================= +// Model Selection (Champion Selector) — backend slice app/features/model_selection +// ============================================================================= +// +// The FULL workflow contract is declared here so Slices B/C add BEHAVIOR, not +// type definitions. Slice A CONSUMES only `ModelCatalogResponse`, +// `PairAvailability`, and `SplitConfig` (read-only). Everything tagged +// DECLARED-FOR-LATER is wired by Slice B (async run + results) and Slice C +// (train / predict / business summary / override / promotion). + +export type ModelSelectionStatus = + | 'pending' + | 'running' + | 'completed' + | 'partial' + | 'failed' +export type RankingMetric = 'wape' | 'smape' | 'mae' | 'bias' +export type AvailabilityStatus = 'ready' | 'limited' | 'unusable' +// `ConfidenceLevel` ('high' | 'medium' | 'low') is reused from the +// Explainability section above — the backend uses the same enum. + +// Backtest split config — mirrors `app/features/backtesting/schemas.py` +// `SplitConfig` EXACTLY (bounds enforced client-side so the assembled run +// request is always valid for Slice B). +export type SplitStrategy = 'expanding' | 'sliding' +export interface SplitConfig { + strategy: SplitStrategy // def 'expanding' + n_splits: number // 2..20, def 5 + min_train_size: number // >= 7, def 30 + gap: number // 0..30, def 0 + horizon: number // 1..90, def 14; must be > gap; kept === forecast_horizon +} + +// --- CONSUMED in Slice A --------------------------------------------------- + +export interface CandidateModelInfo { + model_type: string + label: string + family: ModelFamily + feature_aware: boolean + /** lightgbm/xgboost — opt-in extra may be absent at runtime. */ + requires_extra: boolean + default_params: Record + /** false for feature-aware models (the predict path rejects them). */ + supports_auto_predict: boolean + description: string +} + +export interface ModelCatalogResponse { + models: CandidateModelInfo[] + default_candidate_model_types: string[] +} + +export interface PairAvailability { + store_id: number + product_id: number + first_sales_date: string | null + last_sales_date: string | null + observed_days: number + expected_calendar_days: number + coverage_ratio: number + missing_days: number + zero_sale_days: number + promotion_days: number | null + average_daily_demand: number + status: AvailabilityStatus + recommended_split_config: SplitConfig + warnings: string[] +} + +// --- DECLARED-FOR-LATER (Slices B/C wire behavior on these) ---------------- + +export interface SelectionWindow { + start_date: string // ISO date (inclusive) + end_date: string // ISO date (inclusive) +} + +export interface CandidateModelConfig { + model_type: string + params: Record +} + +export interface RankingPolicy { + minimum_sample_size: number + high_confidence_rel_improvement: number + max_acceptable_abs_bias: number +} + +export interface ModelSelectionRunRequest { + store_id: number + product_id: number + selection_window: SelectionWindow + forecast_horizon: number + ranking_metric: RankingMetric + split_config: SplitConfig + candidate_models: CandidateModelConfig[] + feature_frame_version: number // 1 | 2 (Slice A always 1) + feature_groups: string[] | null // only valid when feature_frame_version === 2 + ranking_policy?: RankingPolicy + // Slice A sets BOTH false. The async run path (Slice B `POST /runs`) treats + // them as NO-OPS, and Slice C owns explicit train/predict — so these two + // fields stay false throughout the UI flow and are never surfaced as toggles. + auto_train_winner: boolean + auto_predict: boolean +} + +export interface ModelRankEntry { + rank: number | null + model_type: string + params: Record + included: boolean + exclusion_reason: string | null + metrics: Record | null +} + +export interface WinnerSummary { + model_type: string + params: Record + metrics: Record + rank: number +} + +export interface ModelSelectionChartData { + wape_by_model: Record + bias_by_model: Record + fold_stability: Record + winner_actual_vs_predicted: unknown[] +} + +export interface ModelSelectionForecastSummary { + points: Record[] + total_demand: number + average_demand: number + horizon: number +} + +export interface ModelSelectionRunResponse { + selection_id: string + store_id: number + product_id: number + status: ModelSelectionStatus + selection_window: SelectionWindow + forecast_horizon: number + ranking_metric: string + availability: PairAvailability | null + ranking: ModelRankEntry[] + winner: WinnerSummary | null + recommendation_confidence: ConfidenceLevel | null + confidence_reasons: string[] + chart_data: ModelSelectionChartData | null + final_model: Record | null + forecast: ModelSelectionForecastSummary | null + business_summary: Record | null + error_message: string | null + created_at: string // ISO datetime + completed_at: string | null +}