Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
665 changes: 665 additions & 0 deletions PRPs/PRP-reliability-E4-shared-model-taxonomy.md

Large diffs are not rendered by default.

19 changes: 7 additions & 12 deletions app/features/batch/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

Submits one ``batch_job`` and N ``batch_job_item`` rows in one transaction,
then loops a partial-index-backed picker (``FOR UPDATE SKIP LOCKED``) and
delegates each item to ``JobService.create_job`` via a lazy in-method
import. The metrics JSONB is pinned to the exact five-key shape
delegates each item to ``JobService.create_job``. The metrics JSONB is
pinned to the exact five-key shape
``{wape, smape, mae, bias, sample_size}`` — every downstream PRP
(parallel-execution, priority-queue, export-and-retry,
champion-and-heatmap) consumes this shape directly. ``sample_size`` is
Expand Down Expand Up @@ -48,8 +48,11 @@

# data_platform is the de-facto shared ORM layer (see the
# data-platform-shared-orm-layer memory) — module-scope import for scope
# expansion is permitted; cross-slice *service* calls stay lazy.
# expansion is permitted.
from app.features.data_platform.models import Product, SalesDaily, Store
from app.features.jobs.models import JobStatus
from app.features.jobs.schemas import JobCreate
from app.features.jobs.service import JobService

if TYPE_CHECKING:
from app.features.jobs.schemas import JobResponse
Expand Down Expand Up @@ -284,15 +287,7 @@ async def _pick_next(self, db: AsyncSession, batch_id: str) -> BatchJobItem | No
return (await db.execute(stmt)).scalar_one_or_none()

async def _execute_item(self, db: AsyncSession, item: BatchJobItem) -> None:
"""Run one item: delegate to ``JobService.create_job`` and capture metrics.

Lazy cross-slice imports break the alembic cold-boot cycle
(precedent: ``app/features/forecasting/service.py:786-787``).
"""
from app.features.jobs.models import JobStatus
from app.features.jobs.schemas import JobCreate
from app.features.jobs.service import JobService

"""Run one item: delegate to ``JobService.create_job`` and capture metrics."""
item.status = BatchItemStatus.RUNNING.value
item.started_at = datetime.now(UTC)
await db.commit()
Expand Down
49 changes: 10 additions & 39 deletions app/features/forecasting/feature_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,52 +27,23 @@
RegressionForecaster,
XGBoostForecaster,
)
from app.features.forecasting.schemas import FeatureImportanceItem, ModelFamily
from app.features.forecasting.schemas import FeatureImportanceItem

# Back-compat re-export (#268): forecasting/service.py and tests import these
# from here. ``_MODEL_FAMILY_MAP`` is re-exported because the drift-lock test
# (test_feature_metadata.py) reads it to compare against the ModelType Literal.
from app.shared.model_taxonomy import (
_MODEL_FAMILY_MAP as _MODEL_FAMILY_MAP, # pyright: ignore[reportPrivateUsage]
)
from app.shared.model_taxonomy import ModelFamily as ModelFamily
from app.shared.model_taxonomy import model_family_for as model_family_for

if TYPE_CHECKING:
from app.features.forecasting.models import BaseForecaster

logger = structlog.get_logger(__name__)


# Canonical map: model_type string → ModelFamily. Lookups go through
# ``model_family_for``, which logs a warning for unknown types and classifies
# them as BASELINE (forward-compatible for new families before this map is
# updated). Keep in sync with the ``ModelType`` Literal in
# ``forecasting/models.py`` (lines 1133-1135).
_MODEL_FAMILY_MAP: dict[str, ModelFamily] = {
    # Baseline family.
    "naive": ModelFamily.BASELINE,
    "seasonal_naive": ModelFamily.BASELINE,
    "moving_average": ModelFamily.BASELINE,
    "weighted_moving_average": ModelFamily.BASELINE,
    "seasonal_average": ModelFamily.BASELINE,
    # NOTE: classified as ADDITIVE, not BASELINE, despite the "_baseline" suffix.
    "trend_regression_baseline": ModelFamily.ADDITIVE,
    # Tree family.
    "random_forest": ModelFamily.TREE,
    "regression": ModelFamily.TREE,
    "lightgbm": ModelFamily.TREE,
    "xgboost": ModelFamily.TREE,
    # Additive family.
    "prophet_like": ModelFamily.ADDITIVE,
}


def model_family_for(model_type: str) -> ModelFamily:
    """Look up the :class:`ModelFamily` registered for ``model_type``.

    A ``model_type`` that is missing from ``_MODEL_FAMILY_MAP`` never raises:
    a ``forecasting.unknown_model_family`` warning is logged and
    :attr:`ModelFamily.BASELINE` is returned instead, so a model added to
    :mod:`forecasting.models` before the map is updated is simply reported as
    a baseline until the map catches up.

    Args:
        model_type: The model-type string to classify.

    Returns:
        The mapped family, or :attr:`ModelFamily.BASELINE` for unknown types.
    """
    known = _MODEL_FAMILY_MAP.get(model_type)
    if known is not None:
        return known

    # Unknown type: warn once per call and degrade to the baseline family.
    fallback = ModelFamily.BASELINE
    logger.warning(
        "forecasting.unknown_model_family",
        model_type=model_type,
        fallback=fallback.value,
    )
    return fallback


class FeatureImportanceUnavailableError(ValueError):
"""The estimator does not expose a usable feature-importance vector.

Expand Down
24 changes: 7 additions & 17 deletions app/features/forecasting/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@

import hashlib
from datetime import date as date_type
from enum import Enum
from typing import Literal

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from app.shared.feature_frames import FeatureGroup

# Back-compat re-export (#268): downstream modules and tests import ModelFamily
# from this module; the redundant alias makes the re-export explicit for
# mypy/pyright/ruff.
from app.shared.model_taxonomy import ModelFamily as ModelFamily

# =============================================================================
# Model Configuration Schemas
# =============================================================================
Expand Down Expand Up @@ -610,25 +614,11 @@ class PredictResponse(BaseModel):


# =============================================================================
# Model Family + Feature Metadata Schemas (MLZOO-D / PRP-31)
# Feature Metadata Schemas (MLZOO-D / PRP-31; ModelFamily moved to
# app/shared/model_taxonomy — #268)
# =============================================================================


class ModelFamily(str, Enum):
    """Classifier for advanced-model UI surfacing.

    Derived from ``model_type``; not persisted in the DB. Surfaced on
    ``RunResponse`` via a computed field and consumed by the dashboard for the
    family Badge and the feature-importance panel routing. Unknown model types
    classify as ``BASELINE`` (forward-compatible for new families before the
    map in ``feature_metadata.py`` is updated).

    The per-member comments list the ``model_type`` strings that the canonical
    map (``_MODEL_FAMILY_MAP``) assigns to each family; that map, not these
    comments, is the source of truth.
    """

    # naive, seasonal_naive, moving_average, weighted_moving_average,
    # seasonal_average
    BASELINE = "baseline"
    # random_forest, regression (HistGBR), lightgbm, xgboost
    TREE = "tree"
    # trend_regression_baseline, prophet_like (Ridge pipeline)
    ADDITIVE = "additive"


class FeatureImportanceItem(BaseModel):
"""One row of model-derived feature importance, ready for the dashboard."""

Expand Down
24 changes: 11 additions & 13 deletions app/features/forecasting/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@
PredictResponse,
TrainResponse,
)

# NOTE: ``RegistryService`` / ``JobService`` and their status enums are imported
# LAZILY inside the feature-metadata methods below. Importing them at module
# scope would close a cycle with ``registry.schemas`` (which eagerly imports
# ``ModelFamily`` from the forecasting slice for the ``model_family`` computed
# field on ``RunResponse``). The explainability slice avoids the same trap by
# importing only ``registry.models`` (a read-only ORM contract); we keep the
# import-graph one-way by deferring our service-level imports.
from app.features.forecasting.v2_loaders import (
assemble_v2_historical_sidecar,
load_exogenous_history,
Expand All @@ -68,6 +60,8 @@
load_replenishment_history,
load_returns_history,
)
from app.features.registry.schemas import RunStatus
from app.features.registry.service import RegistryService
from app.shared.feature_frames import (
DEFAULT_V2_GROUPS,
HISTORY_TAIL_DAYS,
Expand All @@ -82,6 +76,14 @@
v2_pinned_constants,
)

# NOTE: ``JobService`` and the job enums are imported LAZILY inside
# ``get_feature_metadata_for_job``: ``jobs/service.py`` lazily imports
# ``ForecastingService`` back at call time (lines ~435, ~545), so the pair is
# mutually dependent and at least one side must stay call-time-lazy to keep
# cold-boot clean. The registry imports above are EAGER since #268 moved
# ``ModelFamily`` to ``app/shared/model_taxonomy`` — registry no longer
# imports this slice.

if TYPE_CHECKING:
pass

Expand Down Expand Up @@ -963,10 +965,6 @@ async def get_feature_metadata_for_run(
estimator does not expose
``feature_importances_`` (``HistGradientBoostingRegressor``).
"""
# Lazy cross-slice imports — see module-level NOTE.
from app.features.registry.schemas import RunStatus
from app.features.registry.service import RegistryService

run = await RegistryService().get_run(db, run_id)
if run is None:
raise NotFoundError(message=f"Model run not found: {run_id}")
Expand Down Expand Up @@ -1049,7 +1047,7 @@ async def get_feature_metadata_for_job(
``load_model_bundle`` can no longer find, or when the
``ml-*`` extra is missing at unpickle time.
"""
# Lazy cross-slice imports — see module-level NOTE.
# Lazy by design — see the jobs↔forecasting NOTE below the module imports.
from app.features.jobs.models import JobStatus, JobType
from app.features.jobs.service import JobService

Expand Down
9 changes: 3 additions & 6 deletions app/features/model_selection/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
``MODEL_FAMILY_MAP`` / labels never drift from the Python authority.

Capability provenance (BACKEND-OWNED, verified 2026-06-01):
- ``family`` — ``forecasting.feature_metadata.model_family_for`` (lazy
cross-slice import inside the builder, per the slice's import discipline).
- ``family`` — ``app.shared.model_taxonomy.model_family_for`` (shared
taxonomy, #268).
- ``feature_aware`` — the set whose forecasters set ``requires_features=True``
(RandomForest/Regression/LightGBM/XGBoost/ProphetLike), i.e. exactly the set
``ForecastingService.predict()`` rejects (``forecasting/service.py``).
Expand All @@ -28,6 +28,7 @@
CandidateModelInfo,
ModelCatalogResponse,
)
from app.shared.model_taxonomy import model_family_for

# Models gated behind the matching opt-in extra (may be absent at runtime).
_REQUIRES_EXTRA: frozenset[str] = frozenset({"lightgbm", "xgboost"})
Expand Down Expand Up @@ -130,10 +131,6 @@ def build_model_catalog() -> ModelCatalogResponse:
from the module-level sets. Returns the full catalog plus the documented
default candidate set.
"""
# Lazy cross-slice import (mirror service.py) — avoids closing an alembic
# cold-boot import cycle through the forecasting slice.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
from app.features.forecasting.feature_metadata import model_family_for

models: list[CandidateModelInfo] = []
for model_type, meta in _CATALOG.items():
feature_aware = model_type in _FEATURE_AWARE
Expand Down
15 changes: 7 additions & 8 deletions app/features/model_selection/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@
TrainWinnerResponse,
WinnerSummary,
)
from app.features.registry.schemas import (
AliasCreate,
RunCreate,
RunStatus,
RunUpdate,
)
from app.features.registry.service import RegistryService

if TYPE_CHECKING:
from app.features.backtesting.schemas import BacktestResponse
Expand Down Expand Up @@ -1074,14 +1081,6 @@ async def promote(
audit record on ``model_selection_run``. Promotion is NEVER automatic and
performs NO comparison.
"""
from app.features.registry.schemas import ( # lazy
AliasCreate,
RunCreate,
RunStatus,
RunUpdate,
)
from app.features.registry.service import RegistryService # lazy

row = await self._load(db, selection_id)
if not row.final_model_path or not row.trained_model_type:
raise UnprocessableEntityError(message="Train the model before promoting.")
Expand Down
4 changes: 3 additions & 1 deletion app/features/model_selection/tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,10 @@ def _patch_registry(monkeypatch: pytest.MonkeyPatch) -> dict[str, AsyncMock]:
create_run = AsyncMock(return_value=run_resp)
update_run = AsyncMock(return_value=run_resp)
create_alias = AsyncMock(return_value=alias_resp)
# Patch the binding promote() actually uses — module-scope since #268
# promoted the registry imports out of the method body.
monkeypatch.setattr(
"app.features.registry.service.RegistryService",
"app.features.model_selection.service.RegistryService",
lambda: SimpleNamespace(
create_run=create_run, update_run=update_run, create_alias=create_alias
),
Expand Down
27 changes: 8 additions & 19 deletions app/features/registry/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@

from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator

# Pydantic v2 resolves a ``@computed_field``'s return-type annotation at
# validation time, so ``ModelFamily`` must be a real runtime import here.
# To avoid the cycle this introduces with the forecasting slice (whose
# ``service.py`` imports ``RegistryService``), the forecasting slice's
# cross-slice imports of ``RegistryService`` / ``JobService`` / status enums
# are LAZY (inside the methods that use them). See
# ``app/features/forecasting/service.py`` for the matching contract.
from app.features.forecasting.schemas import ModelFamily
# ``ModelFamily`` / ``model_family_for`` live in ``app/shared/model_taxonomy``
# (#268) so this module never imports from another feature slice. Pydantic v2
# resolves a ``@computed_field``'s return-type annotation at schema-build time,
# so ``ModelFamily`` must be a real runtime import (never TYPE_CHECKING-gated).
from app.shared.model_taxonomy import ModelFamily, model_family_for


class RunStatus(str, Enum):
Expand Down Expand Up @@ -131,9 +128,8 @@ class RunResponse(BaseModel):

``model_family`` is a computed field derived from ``model_type`` at
serialization time — no DB column, no Alembic migration, no backfill.
See ``app/features/forecasting/feature_metadata.py:model_family_for`` for
the canonical map. Unknown model types log a warning and return
``ModelFamily.BASELINE``.
See ``app/shared/model_taxonomy.py`` for the canonical map. Unknown model
types log a warning and return ``ModelFamily.BASELINE``.
"""

model_config = ConfigDict(from_attributes=True, populate_by_name=True)
Expand Down Expand Up @@ -166,14 +162,7 @@ class RunResponse(BaseModel):
@computed_field # type: ignore[prop-decorator]
@property
def model_family(self) -> ModelFamily:
"""Computed family label derived from ``model_type``.

Imported lazily to avoid a hard cycle between
``registry.schemas`` and ``forecasting.feature_metadata`` at module
import time.
"""
from app.features.forecasting.feature_metadata import model_family_for

"""Computed family label derived from ``model_type``."""
return model_family_for(self.model_type)

@computed_field # type: ignore[prop-decorator]
Expand Down
Loading