From d5adc5b5b4d4a78d41c9e1a4908918a3b1c12465 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Tue, 28 Apr 2026 11:09:09 +0300 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20Milestone=208=20=E2=80=94=20render/?= =?UTF-8?q?bundle=20layer=20and=20end-to-end=20Generator.generate()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the full observation model that transforms hidden simulation output into the canonical output bundle structure. - render/relational.py: to_dataframes() — 9 typed DataFrames from SimulationResult + PopulationResult - render/snapshots.py: build_snapshot() — 30-column leakage-free lead snapshot; touch/session/ activity aggregates, account/contact field joins, vectorised pandas operations - render/tasks.py: write_task_splits() — deterministic 70/15/15 train/valid/test Parquet split + task_manifest.json; seeded random shuffle guarantees reproducibility - render/manifests.py: build_manifest() / write_manifest() — manifest.json with full provenance, row counts, and SHA-256 file hashes - api/bundle.py: write_bundle() — orchestrates all render steps in order - core/models.py: WorldBundle enriched with population/simulation_result/world_graph fields; WorldBundle.save(path) delegates to write_bundle() via lazy import (avoids circular import) - api/generator.py: Generator.generate() fully implemented end-to-end - tests/render/test_render.py: 31 tests covering all render modules and full bundle write smoke test - tests/api/test_generator.py: replaced stale NotImplementedError test with WorldBundle assertion 521 tests passing; ruff + mypy clean. 
Co-Authored-By: Claude Sonnet 4.6 --- .agent-plan.md | 35 ++- leadforge/api/bundle.py | 86 +++++++ leadforge/api/generator.py | 59 ++++- leadforge/core/models.py | 39 ++- leadforge/render/manifests.py | 104 ++++++++ leadforge/render/relational.py | 77 ++++++ leadforge/render/snapshots.py | 210 +++++++++++++++ leadforge/render/tasks.py | 78 ++++++ tests/api/test_generator.py | 12 +- tests/render/__init__.py | 0 tests/render/test_render.py | 450 +++++++++++++++++++++++++++++++++ 11 files changed, 1128 insertions(+), 22 deletions(-) create mode 100644 leadforge/api/bundle.py create mode 100644 leadforge/render/manifests.py create mode 100644 leadforge/render/relational.py create mode 100644 leadforge/render/snapshots.py create mode 100644 leadforge/render/tasks.py create mode 100644 tests/render/__init__.py create mode 100644 tests/render/test_render.py diff --git a/.agent-plan.md b/.agent-plan.md index 4ed3a55..04dc996 100644 --- a/.agent-plan.md +++ b/.agent-plan.md @@ -6,35 +6,42 @@ ## Current System State -**v0.4.0 in progress — Milestone 7 complete (PR open).** Full simulation engine implemented: -per-lead mutable state, 90-day daily-step loop, touch/session/sales-activity emission, -HazardTransition stage advancement, ConversionHazard final-close, and post-conversion -opportunity/customer/subscription creation. 490 tests passing. +**v0.4.0 in progress — Milestones 7–8 complete (PRs open).** Full simulation engine + render/bundle +layer implemented. 521 tests passing. --- -## Active Task Breakdown — Milestone 8: Observation Model (v0.4.0) +## Next Up — Milestone 9: Exposure Filtering (v0.4.0) -Goal: Transform the hidden simulated world into realistic CRM-like observations. +Goal: Apply `student_public` / `research_instructor` exposure-mode filtering during bundle write. -- [ ] **1. Snapshot builder** (`render/snapshots.py`) — lead-anchored flat feature snapshot -- [ ] **2. 
Relational export** (`render/relational.py`) — DataFrame per table from SimulationResult -- [ ] **3. Task export** (`render/tasks.py`) — train/valid/test Parquet split for `converted_within_90_days` -- [ ] **4. Manifest builder** (`render/manifests.py`) — bundle manifest.json -- [ ] **5. Bundle writer** (`api/bundle.py`) — `WorldBundle.save(path)` +- [ ] `exposure/modes.py` — `ExposureMode`-aware filter dispatch +- [ ] `exposure/filters.py` — column/table redaction rules per mode +- [ ] `exposure/redaction.py` — latent-column scrubbing for `student_public` +- [ ] Wire into `api/bundle.py` write pipeline --- ## Context Pointers -- Milestone 7 scope: `docs/leadforge_implementation_plan.md` §10 "Milestone 7" -- Simulation spec: `docs/leadforge_architecture_spec.md` §11 "Simulation engine" -- Mechanism layer: `leadforge/mechanisms/` (all M6 files) +- Milestone 8 scope: `docs/leadforge_implementation_plan.md` §10 "Milestone 8" +- Render layer: `leadforge/render/` (snapshots, relational, tasks, manifests) +- Bundle writer: `leadforge/api/bundle.py` --- ## Completed Phases +### Milestone 8 — Render / Bundle Layer ✓ (v0.4.0 in PR) +- `render/relational.py`: `to_dataframes()` — 9-table dict of typed DataFrames from SimulationResult + PopulationResult +- `render/snapshots.py`: `build_snapshot()` — 30-column leakage-free lead snapshot with touch/session/activity aggregates, account/contact field joins +- `render/tasks.py`: `write_task_splits()` — deterministic 70/15/15 train/valid/test Parquet split + `task_manifest.json` +- `render/manifests.py`: `build_manifest()` / `write_manifest()` — manifest.json with provenance, row counts, SHA-256 hashes +- `api/bundle.py`: `write_bundle()` — orchestrates all render steps; writes full bundle to disk +- `core/models.py`: `WorldBundle.save(path)` — delegates to `write_bundle()` via lazy import +- `api/generator.py`: `Generator.generate()` — fully implemented end-to-end flow +- 31 new render tests; total 521 passing + ### Milestone 7 — 
Simulation Engine ✓ (v0.4.0 in PR) - `simulation/state.py`: `LeadSimState` — per-lead mutable state (stage, dwell, converted, churned, sql_day) - `simulation/engine.py`: `simulate_world()` — 90-day daily-step loop; `SimulationResult` output type diff --git a/leadforge/api/bundle.py b/leadforge/api/bundle.py new file mode 100644 index 0000000..753e29d --- /dev/null +++ b/leadforge/api/bundle.py @@ -0,0 +1,86 @@ +"""Bundle writer — assembles and serialises the full output bundle. + +:func:`write_bundle` is called by :meth:`WorldBundle.save` and orchestrates +all rendering steps: + +1. Write relational Parquet tables (``tables/``). +2. Build the lead snapshot and write task splits (``tasks/``). +3. Write ``dataset_card.md`` and ``feature_dictionary.csv``. +4. Build and write ``manifest.json``. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from leadforge.render.manifests import build_manifest, write_manifest +from leadforge.render.relational import to_dataframes +from leadforge.render.snapshots import build_snapshot +from leadforge.render.tasks import write_task_splits +from leadforge.schema.dictionaries import write_feature_dictionary +from leadforge.schema.tables import write_parquet + +if TYPE_CHECKING: + from leadforge.core.models import WorldBundle + + +def write_bundle(bundle: WorldBundle, path: str) -> None: + """Write *bundle* to disk at *path*. + + Args: + bundle: Fully populated :class:`~leadforge.core.models.WorldBundle`. + path: Destination directory (created if absent). + + Raises: + RuntimeError: if any of ``bundle.simulation_result``, + ``bundle.population``, or ``bundle.world_graph`` are ``None``. + """ + if bundle.simulation_result is None or bundle.population is None or bundle.world_graph is None: + raise RuntimeError("WorldBundle is not fully populated. 
Call Generator.generate() first.") + + root = Path(path) + root.mkdir(parents=True, exist_ok=True) + + config = bundle.spec.config + result = bundle.simulation_result + population = bundle.population + world_graph = bundle.world_graph + + # ------------------------------------------------------------------ + # 1. Relational tables → tables/ + # ------------------------------------------------------------------ + tables_dir = root / "tables" + tables_dir.mkdir(exist_ok=True) + + dfs = to_dataframes(result, population) + table_row_counts: dict[str, int] = {} + for table_name, df in dfs.items(): + write_parquet(df, tables_dir / f"{table_name}.parquet") + table_row_counts[table_name] = len(df) + + # ------------------------------------------------------------------ + # 2. Snapshot + task splits → tasks/ + # ------------------------------------------------------------------ + snapshot = build_snapshot(result, population, horizon_days=config.horizon_days) + task_row_counts = write_task_splits(snapshot, root / "tasks", seed=config.seed) + + # ------------------------------------------------------------------ + # 3. Dataset card and feature dictionary + # ------------------------------------------------------------------ + from leadforge.narrative.dataset_card import render_dataset_card + + (root / "dataset_card.md").write_text(render_dataset_card(bundle.spec)) + write_feature_dictionary(root / "feature_dictionary.csv") + + # ------------------------------------------------------------------ + # 4. 
Manifest + # ------------------------------------------------------------------ + manifest = build_manifest( + config=config, + world_graph=world_graph, + table_row_counts=table_row_counts, + task_row_counts={"converted_within_90_days": task_row_counts}, + bundle_root=root, + ) + write_manifest(manifest, root) diff --git a/leadforge/api/generator.py b/leadforge/api/generator.py index 0b1cfc8..5d499de 100644 --- a/leadforge/api/generator.py +++ b/leadforge/api/generator.py @@ -115,8 +115,61 @@ def generate( difficulty: str | DifficultyProfile = DifficultyProfile.intermediate, **kwargs: Any, ) -> WorldBundle: - """Run the world simulation and return a bundle. + """Run the full world simulation and return an in-memory bundle. - Not yet implemented — available in v0.3.0+. + Overrides in *n_accounts*, *n_contacts*, *n_leads*, and *difficulty* + take effect for this call only — they do not mutate the Generator. + + Args: + n_accounts: Override account count. + n_contacts: Override contact count. + n_leads: Override lead count. + difficulty: Difficulty profile name or enum value. + **kwargs: Reserved for future use. + + Returns: + A fully populated :class:`~leadforge.core.models.WorldBundle`. + Call :meth:`~leadforge.core.models.WorldBundle.save` to write it + to disk. """ - raise NotImplementedError("Generator.generate() is not yet implemented. Coming in v0.3.0.") + import dataclasses + + from leadforge.simulation.engine import simulate_world + from leadforge.simulation.population import build_population + from leadforge.structure.sampler import sample_hidden_graph + + config = self._world_spec.config + + # Apply per-call overrides without mutating the shared config. 
+ overrides: dict[str, Any] = {} + if n_accounts is not None: + overrides["n_accounts"] = n_accounts + if n_contacts is not None: + overrides["n_contacts"] = n_contacts + if n_leads is not None: + overrides["n_leads"] = n_leads + if not isinstance(difficulty, DifficultyProfile): + difficulty = DifficultyProfile(difficulty) + if difficulty != config.difficulty: + overrides["difficulty"] = difficulty + if overrides: + config = dataclasses.replace(config, **overrides) + + narrative = self._world_spec.narrative + if narrative is None: + raise RuntimeError( + "No narrative loaded. Initialise the Generator via " + "Generator.from_recipe() to resolve the narrative." + ) + + world_graph = sample_hidden_graph(config.seed) + population = build_population(config, narrative, world_graph) + result = simulate_world(config, population, world_graph) + + spec = WorldSpec(config=config, narrative=narrative) + return WorldBundle( + spec=spec, + population=population, + simulation_result=result, + world_graph=world_graph, + ) diff --git a/leadforge/core/models.py b/leadforge/core/models.py index 1aec86f..9d355eb 100644 --- a/leadforge/core/models.py +++ b/leadforge/core/models.py @@ -90,7 +90,44 @@ class WorldSpec: class WorldBundle: """In-memory result of one complete generation run. - Populated in Milestone 7+ (simulation and rendering). + Holds all generated artefacts and provides :meth:`save` to write the + full output bundle to disk. + + Attributes: + spec: Fully resolved world specification (config + narrative). + population: Generated accounts, contacts, leads, and latent state. + simulation_result: Simulated event tables and final lead outcomes. + world_graph: Sampled hidden world graph used during simulation. 
""" spec: WorldSpec = field(default_factory=WorldSpec) + population: Any = None # PopulationResult | None + simulation_result: Any = None # SimulationResult | None + world_graph: Any = None # WorldGraph | None + + def save(self, path: str) -> None: + """Write the full output bundle to *path*. + + Creates the directory if it does not exist. The bundle layout + matches the canonical structure defined in ``CLAUDE.md``:: + + path/ + manifest.json + dataset_card.md + feature_dictionary.csv + tables/ # one .parquet per relational table + tasks/converted_within_90_days/{train,valid,test}.parquet + tasks/converted_within_90_days/task_manifest.json + + Args: + path: Destination directory (created if absent). + + Raises: + RuntimeError: if :attr:`simulation_result`, :attr:`population`, + or :attr:`world_graph` have not been populated (i.e. if + :meth:`~leadforge.api.generator.Generator.generate` was not + called). + """ + from leadforge.api.bundle import write_bundle + + write_bundle(self, path) diff --git a/leadforge/render/manifests.py b/leadforge/render/manifests.py new file mode 100644 index 0000000..d4bea48 --- /dev/null +++ b/leadforge/render/manifests.py @@ -0,0 +1,104 @@ +"""Bundle manifest builder. + +:func:`build_manifest` constructs the ``manifest.json`` dict that is written +at the root of every output bundle. The manifest is the authoritative record +of provenance: it identifies the recipe, seed, version, and every file in the +bundle along with its SHA-256 hash and row count. +""" + +from __future__ import annotations + +import hashlib +import json +from datetime import UTC, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from leadforge.core.models import GenerationConfig + from leadforge.structure.graph import WorldGraph + +# Bump this whenever the bundle layout or manifest schema changes. 
+BUNDLE_SCHEMA_VERSION = "1" + + +def build_manifest( + config: GenerationConfig, + world_graph: WorldGraph, + table_row_counts: dict[str, int], + task_row_counts: dict[str, dict[str, int]], + bundle_root: Path, + generation_timestamp: str | None = None, +) -> dict[str, Any]: + """Build the bundle manifest dict. + + SHA-256 hashes are computed by reading the written Parquet files from + *bundle_root*, so all table and task files must already exist on disk + before calling this function. + + Args: + config: The resolved generation configuration. + world_graph: The sampled hidden world graph (provides motif_family). + table_row_counts: Mapping of table name → row count. + task_row_counts: Mapping of task_id → {split_name → row count}. + bundle_root: Root directory of the written bundle. + generation_timestamp: ISO-8601 UTC timestamp string. Defaults to now. + + Returns: + A JSON-serialisable dict ready to be written as ``manifest.json``. + """ + if generation_timestamp is None: + generation_timestamp = datetime.now(UTC).isoformat(timespec="seconds") + + # Build table entries with row counts and file hashes. + tables: dict[str, Any] = {} + for table_name, row_count in table_row_counts.items(): + rel_path = f"tables/{table_name}.parquet" + abs_path = bundle_root / rel_path + sha = _sha256(abs_path) if abs_path.exists() else "" + tables[table_name] = {"row_count": row_count, "file": rel_path, "sha256": sha} + + # Build task entries. 
+ tasks: dict[str, Any] = {} + for task_id, split_counts in task_row_counts.items(): + entry: dict[str, Any] = {} + for split_name, row_count in split_counts.items(): + rel_path = f"tasks/{task_id}/{split_name}.parquet" + abs_path = bundle_root / rel_path + sha = _sha256(abs_path) if abs_path.exists() else "" + entry[f"{split_name}_rows"] = row_count + entry[f"{split_name}_sha256"] = sha + tasks[task_id] = entry + + return { + "bundle_schema_version": BUNDLE_SCHEMA_VERSION, + "package_version": config.package_version, + "recipe_id": config.recipe_id, + "seed": config.seed, + "generation_timestamp": generation_timestamp, + "exposure_mode": config.exposure_mode.value, + "difficulty": config.difficulty.value, + "n_accounts": config.n_accounts, + "n_contacts": config.n_contacts, + "n_leads": config.n_leads, + "horizon_days": config.horizon_days, + "motif_family": world_graph.motif_family, + "tables": tables, + "tasks": tasks, + } + + +def write_manifest(manifest: dict[str, Any], bundle_root: Path) -> Path: + """Serialise *manifest* to ``bundle_root/manifest.json`` and return the path.""" + path = bundle_root / "manifest.json" + path.write_text(json.dumps(manifest, indent=2)) + return path + + +def _sha256(path: Path) -> str: + """Return the hex-encoded SHA-256 digest of *path*.""" + h = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(65536), b""): + h.update(chunk) + return h.hexdigest() diff --git a/leadforge/render/relational.py b/leadforge/render/relational.py new file mode 100644 index 0000000..3604dce --- /dev/null +++ b/leadforge/render/relational.py @@ -0,0 +1,77 @@ +"""Relational export — convert SimulationResult to typed DataFrames. + +:func:`to_dataframes` is the single entry point. It produces one +``pd.DataFrame`` per relational table, with dtypes matching the +:attr:`~leadforge.schema.entities.AccountRow.DTYPE_MAP` of each entity +class. 
The resulting dict is consumed by the bundle writer to produce +the ``tables/`` directory in the output bundle. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pandas as pd + +from leadforge.schema.entities import ( + AccountRow, + ContactRow, + CustomerRow, + LeadRow, + OpportunityRow, + SalesActivityRow, + SessionRow, + SubscriptionRow, + TouchRow, +) + +if TYPE_CHECKING: + from leadforge.simulation.engine import SimulationResult + from leadforge.simulation.population import PopulationResult + +# Mapping from table name to (entity_class, source, attr_name_on_that_source). +# Population tables come from PopulationResult; event tables from SimulationResult. +_TABLE_SOURCES: dict[str, tuple[type, str, str]] = { + # (entity_class, source: "population"|"simulation", attr_name) + AccountRow.TABLE_NAME: (AccountRow, "population", "accounts"), + ContactRow.TABLE_NAME: (ContactRow, "population", "contacts"), + LeadRow.TABLE_NAME: (LeadRow, "simulation", "leads"), + TouchRow.TABLE_NAME: (TouchRow, "simulation", "touches"), + SessionRow.TABLE_NAME: (SessionRow, "simulation", "sessions"), + SalesActivityRow.TABLE_NAME: (SalesActivityRow, "simulation", "sales_activities"), + OpportunityRow.TABLE_NAME: (OpportunityRow, "simulation", "opportunities"), + CustomerRow.TABLE_NAME: (CustomerRow, "simulation", "customers"), + SubscriptionRow.TABLE_NAME: (SubscriptionRow, "simulation", "subscriptions"), +} + + +def to_dataframes( + result: SimulationResult, + population: PopulationResult, +) -> dict[str, pd.DataFrame]: + """Convert simulation output to one typed DataFrame per relational table. + + Args: + result: Output of :func:`~leadforge.simulation.engine.simulate_world`. + population: Output of + :func:`~leadforge.simulation.population.build_population`. + + Returns: + Dict mapping table name → ``pd.DataFrame`` with dtypes matching the + entity class's ``DTYPE_MAP``.
Empty tables are returned as zero-row + DataFrames with the correct schema. + """ + dfs: dict[str, pd.DataFrame] = {} + for table_name, (cls, source, attr) in _TABLE_SOURCES.items(): + obj = population if source == "population" else result + rows = getattr(obj, attr, []) + if rows: + df = pd.DataFrame([row.to_dict() for row in rows]) + # Apply canonical dtypes — use nullable pandas types where possible. + for col, dtype in cls.DTYPE_MAP.items(): # type: ignore[attr-defined] + if col in df.columns: + df[col] = df[col].astype(dtype) + else: + df = cls.empty_dataframe() # type: ignore[attr-defined] + dfs[table_name] = df + return dfs diff --git a/leadforge/render/snapshots.py b/leadforge/render/snapshots.py new file mode 100644 index 0000000..4c33968 --- /dev/null +++ b/leadforge/render/snapshots.py @@ -0,0 +1,210 @@ +"""Lead snapshot builder — flatten the simulated world into an ML-ready table. + +:func:`build_snapshot` produces one row per lead, containing the features +defined in :data:`~leadforge.schema.features.LEAD_SNAPSHOT_FEATURES`. All +columns are anchored at or before the snapshot date (lead creation + horizon), +preserving the leakage-free guarantee. + +The snapshot is the source table for the primary task export +(``converted_within_90_days``). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pandas as pd + +from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES + +if TYPE_CHECKING: + from leadforge.simulation.engine import SimulationResult + from leadforge.simulation.population import PopulationResult + +# Ordered column list derived from the canonical feature spec. +_SNAPSHOT_COLUMNS = [f.name for f in LEAD_SNAPSHOT_FEATURES] +_SNAPSHOT_DTYPES = {f.name: f.dtype for f in LEAD_SNAPSHOT_FEATURES} + + +def build_snapshot( + result: SimulationResult, + population: PopulationResult, + horizon_days: int = 90, +) -> pd.DataFrame: + """Build the lead snapshot DataFrame from simulation output. 
+ + One row is produced per lead. Features are computed by aggregating + touches, sessions, and sales activities that occurred during the + simulation horizon. The snapshot anchor date is + ``lead_created_at + timedelta(days=horizon_days)``. + + Args: + result: Output of :func:`~leadforge.simulation.engine.simulate_world`. + population: Output of + :func:`~leadforge.simulation.population.build_population`. + horizon_days: Simulation horizon length. Defaults to 90. + + Returns: + A ``pd.DataFrame`` with the columns specified in + :data:`~leadforge.schema.features.LEAD_SNAPSHOT_FEATURES` and dtypes + matching the feature spec. Row order matches ``result.leads``. + """ + account_by_id = {a.account_id: a for a in population.accounts} + contact_by_id = {c.contact_id: c for c in population.contacts} + + # ------------------------------------------------------------------- + # Aggregate event tables by lead_id using pandas for efficiency. + # ------------------------------------------------------------------- + + # Touch aggregates + if result.touches: + td = pd.DataFrame([t.to_dict() for t in result.touches]) + touch_agg = ( + td.groupby("lead_id") + .agg( + touch_count=("touch_id", "count"), + inbound_touch_count=( + "touch_direction", + lambda s: int((s == "inbound").sum()), + ), + outbound_touch_count=( + "touch_direction", + lambda s: int((s == "outbound").sum()), + ), + last_touch_timestamp=("touch_timestamp", "max"), + ) + .reset_index() + ) + else: + touch_agg = pd.DataFrame( + columns=[ + "lead_id", + "touch_count", + "inbound_touch_count", + "outbound_touch_count", + "last_touch_timestamp", + ] + ) + + # Session aggregates + if result.sessions: + sd = pd.DataFrame([s.to_dict() for s in result.sessions]) + sess_agg = ( + sd.groupby("lead_id") + .agg( + session_count=("session_id", "count"), + pricing_page_views=("pricing_page_views", "sum"), + demo_page_views=("demo_page_views", "sum"), + total_session_duration_seconds=("session_duration_seconds", "sum"), + ) + 
.reset_index() + ) + else: + sess_agg = pd.DataFrame( + columns=[ + "lead_id", + "session_count", + "pricing_page_views", + "demo_page_views", + "total_session_duration_seconds", + ] + ) + + # Sales activity aggregates + if result.sales_activities: + ad = pd.DataFrame([a.to_dict() for a in result.sales_activities]) + act_agg = ad.groupby("lead_id").agg(activity_count=("activity_id", "count")).reset_index() + else: + act_agg = pd.DataFrame(columns=["lead_id", "activity_count"]) + + # Opportunity join: find open (unclosed) opportunity per lead. + if result.opportunities: + od = pd.DataFrame([o.to_dict() for o in result.opportunities]) + open_opps = od[od["close_outcome"].isna()][["lead_id", "estimated_acv"]] + # One open opp per lead (first if multiple, which shouldn't happen in v1). + open_opps = open_opps.groupby("lead_id").first().reset_index() + open_opps = open_opps.rename(columns={"estimated_acv": "opportunity_estimated_acv"}) + open_opps["has_open_opportunity"] = True + else: + open_opps = pd.DataFrame( + columns=["lead_id", "has_open_opportunity", "opportunity_estimated_acv"] + ) + + # ------------------------------------------------------------------- + # Build base lead DataFrame and join aggregates. + # ------------------------------------------------------------------- + lead_df = pd.DataFrame([lead.to_dict() for lead in result.leads]) + + # Compute snapshot anchor date (per lead, vectorised). + lead_df["anchor_date"] = pd.to_datetime(lead_df["lead_created_at"]) + pd.Timedelta( + days=horizon_days + ) + + # Join aggregates (left join preserves all leads). + lead_df = lead_df.merge(touch_agg, on="lead_id", how="left") + lead_df = lead_df.merge(sess_agg, on="lead_id", how="left") + lead_df = lead_df.merge(act_agg, on="lead_id", how="left") + lead_df = lead_df.merge(open_opps, on="lead_id", how="left") + + # Fill missing aggregates with zero / False. 
+ lead_df["touch_count"] = lead_df["touch_count"].fillna(0).astype("Int64") + lead_df["inbound_touch_count"] = lead_df["inbound_touch_count"].fillna(0).astype("Int64") + lead_df["outbound_touch_count"] = lead_df["outbound_touch_count"].fillna(0).astype("Int64") + lead_df["session_count"] = lead_df["session_count"].fillna(0).astype("Int64") + lead_df["pricing_page_views"] = lead_df["pricing_page_views"].fillna(0).astype("Int64") + lead_df["demo_page_views"] = lead_df["demo_page_views"].fillna(0).astype("Int64") + lead_df["total_session_duration_seconds"] = ( + lead_df["total_session_duration_seconds"].fillna(0).astype("Int64") + ) + lead_df["activity_count"] = lead_df["activity_count"].fillna(0).astype("Int64") + mask = lead_df["has_open_opportunity"].notna() + lead_df["has_open_opportunity"] = ( + lead_df["has_open_opportunity"].where(mask, other=False).astype("boolean") + ) + lead_df["opportunity_estimated_acv"] = lead_df["opportunity_estimated_acv"].astype("Float64") + + # Compute days_since_last_touch (Float64, NaN when no touches). + has_touch = lead_df["last_touch_timestamp"].notna() + lead_df["days_since_last_touch"] = pd.NA + if has_touch.any(): + last_ts = pd.to_datetime(lead_df.loc[has_touch, "last_touch_timestamp"]) + lead_df.loc[has_touch, "days_since_last_touch"] = ( + lead_df.loc[has_touch, "anchor_date"] - last_ts + ).dt.days + lead_df["days_since_last_touch"] = lead_df["days_since_last_touch"].astype("Float64") + + # ------------------------------------------------------------------- + # Join account and contact features. 
+ # ------------------------------------------------------------------- + def _account_field(row: pd.Series, field: str) -> object: + acct = account_by_id.get(row["account_id"]) + return getattr(acct, field, pd.NA) if acct else pd.NA + + def _contact_field(row: pd.Series, field: str) -> object: + cont = contact_by_id.get(row["contact_id"]) + return getattr(cont, field, pd.NA) if cont else pd.NA + + for field in ( + "industry", + "region", + "employee_band", + "estimated_revenue_band", + "process_maturity_band", + ): + lead_df[field] = lead_df.apply(_account_field, axis=1, field=field) + + for field in ("role_function", "seniority", "buyer_role"): + lead_df[field] = lead_df.apply(_contact_field, axis=1, field=field) + + # ------------------------------------------------------------------- + # Select and order columns per canonical feature spec; apply dtypes. + # ------------------------------------------------------------------- + snapshot = lead_df[_SNAPSHOT_COLUMNS].copy() + for col, dtype in _SNAPSHOT_DTYPES.items(): + if col in snapshot.columns: + try: + snapshot[col] = snapshot[col].astype(dtype) + except (ValueError, TypeError): + pass # column already has compatible dtype + + return snapshot diff --git a/leadforge/render/tasks.py b/leadforge/render/tasks.py new file mode 100644 index 0000000..01b89b9 --- /dev/null +++ b/leadforge/render/tasks.py @@ -0,0 +1,78 @@ +"""Task export — deterministic train/valid/test split and Parquet output. + +:func:`write_task_splits` takes the lead snapshot DataFrame, shuffles it +deterministically, splits it according to the task manifest ratios, and +writes the three Parquet files plus a ``task_manifest.json`` into the +tasks directory. 
+""" + +from __future__ import annotations + +import json +import random +from pathlib import Path + +import pandas as pd + +from leadforge.schema.tasks import CONVERTED_WITHIN_90_DAYS, TaskManifest + + +def write_task_splits( + snapshot: pd.DataFrame, + out_dir: Path, + *, + seed: int, + task: TaskManifest = CONVERTED_WITHIN_90_DAYS, +) -> dict[str, int]: + """Shuffle, split, and write snapshot Parquet files for *task*. + + Files written under ``out_dir / task.task_id /``:: + + train.parquet + valid.parquet + test.parquet + task_manifest.json + + Args: + snapshot: Lead snapshot DataFrame from + :func:`~leadforge.render.snapshots.build_snapshot`. + out_dir: Parent directory for task outputs (typically + ``bundle_root / "tasks"``). + seed: Seed used for deterministic row shuffle. + task: Task manifest describing the split ratios and label column. + + Returns: + Dict mapping split name (``"train"``, ``"valid"``, ``"test"``) to + the number of rows written. + """ + task_dir = out_dir / task.task_id + task_dir.mkdir(parents=True, exist_ok=True) + + # Deterministic shuffle via seeded RNG (index permutation). + rng = random.Random(seed) # noqa: S311 + indices = list(range(len(snapshot))) + rng.shuffle(indices) + shuffled = snapshot.iloc[indices].reset_index(drop=True) + + n = len(shuffled) + n_train = int(n * task.split.train) + n_valid = int(n * task.split.valid) + # Test gets the remainder to avoid off-by-one from integer rounding. + + splits: dict[str, pd.DataFrame] = { + "train": shuffled.iloc[:n_train], + "valid": shuffled.iloc[n_train : n_train + n_valid], + "test": shuffled.iloc[n_train + n_valid :], + } + + row_counts: dict[str, int] = {} + for split_name, df in splits.items(): + path = task_dir / f"{split_name}.parquet" + df.to_parquet(path, index=False, engine="pyarrow") + row_counts[split_name] = len(df) + + # Write task_manifest.json alongside the Parquet files. 
+ manifest_path = task_dir / "task_manifest.json" + manifest_path.write_text(json.dumps(task.to_dict(), indent=2)) + + return row_counts diff --git a/tests/api/test_generator.py b/tests/api/test_generator.py index b7c2b1b..74caa9c 100644 --- a/tests/api/test_generator.py +++ b/tests/api/test_generator.py @@ -60,10 +60,14 @@ def test_from_recipe_invalid_id_raises() -> None: Generator.from_recipe("does_not_exist") -def test_generate_not_implemented() -> None: - gen = Generator.from_recipe("b2b_saas_procurement_v1") - with pytest.raises(NotImplementedError): - gen.generate() +def test_generate_returns_world_bundle() -> None: + from leadforge.core.models import WorldBundle + + gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42) + bundle = gen.generate(n_leads=30, n_accounts=15, n_contacts=45) + assert isinstance(bundle, WorldBundle) + assert bundle.simulation_result is not None + assert bundle.population is not None def test_from_recipe_config_has_package_version() -> None: diff --git a/tests/render/__init__.py b/tests/render/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/render/test_render.py b/tests/render/test_render.py new file mode 100644 index 0000000..fcd9f9f --- /dev/null +++ b/tests/render/test_render.py @@ -0,0 +1,450 @@ +"""Tests for the render layer: relational.py, snapshots.py, tasks.py, manifests.py.""" + +from __future__ import annotations + +import json + +import pandas as pd +import pytest + +from leadforge.core.models import GenerationConfig +from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES +from leadforge.simulation.engine import simulate_world +from leadforge.simulation.population import build_population +from leadforge.structure.sampler import sample_hidden_graph + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +_SNAPSHOT_COLUMNS = [f.name for f in 
LEAD_SNAPSHOT_FEATURES] +_SNAPSHOT_DTYPES = {f.name: f.dtype for f in LEAD_SNAPSHOT_FEATURES} + + +def _make_config(seed: int = 42, n_leads: int = 80) -> GenerationConfig: + return GenerationConfig(seed=seed, n_accounts=30, n_contacts=90, n_leads=n_leads) + + +def _make_narrative(): + from leadforge.api.generator import Generator + + gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42) + assert gen.world_spec.narrative is not None + return gen.world_spec.narrative + + +@pytest.fixture(scope="module") +def sim_outputs(): + """Run a small simulation once; share across all tests in this module.""" + config = _make_config() + narrative = _make_narrative() + graph = sample_hidden_graph(42) + population = build_population(config, narrative, graph) + result = simulate_world(config, population, graph) + return config, population, result, graph + + +# --------------------------------------------------------------------------- +# render/relational.py +# --------------------------------------------------------------------------- + + +class TestToDataframes: + def test_returns_all_table_names(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.relational import to_dataframes + + dfs = to_dataframes(result, population) + expected = { + "accounts", + "contacts", + "leads", + "touches", + "sessions", + "sales_activities", + "opportunities", + "customers", + "subscriptions", + } + assert set(dfs.keys()) == expected + + def test_lead_count_matches(self, sim_outputs): + config, population, result, _ = sim_outputs + from leadforge.render.relational import to_dataframes + + dfs = to_dataframes(result, population) + assert len(dfs["leads"]) == config.n_leads + + def test_account_and_contact_counts(self, sim_outputs): + config, population, result, _ = sim_outputs + from leadforge.render.relational import to_dataframes + + dfs = to_dataframes(result, population) + assert len(dfs["accounts"]) == config.n_accounts + assert len(dfs["contacts"]) == 
config.n_contacts + + def test_dataframes_are_dataframes(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.relational import to_dataframes + + dfs = to_dataframes(result, population) + for name, df in dfs.items(): + assert isinstance(df, pd.DataFrame), f"{name} is not a DataFrame" + + def test_empty_tables_have_schema(self, sim_outputs): + """Tables with zero rows must still expose the correct column names.""" + _, population, result, _ = sim_outputs + from leadforge.render.relational import to_dataframes + from leadforge.schema.entities import CustomerRow + + dfs = to_dataframes(result, population) + # customers may or may not be empty, but its columns must be a superset + # of the entity's DTYPE_MAP keys. + assert set(CustomerRow.DTYPE_MAP.keys()).issubset(set(dfs["customers"].columns)) + + def test_deterministic_under_same_seed(self): + """Same seed → identical relational DataFrames.""" + from leadforge.render.relational import to_dataframes + + def _run(seed): + cfg = _make_config(seed=seed) + narr = _make_narrative() + g = sample_hidden_graph(seed) + pop = build_population(cfg, narr, g) + res = simulate_world(cfg, pop, g) + return to_dataframes(res, pop) + + dfs1 = _run(77) + dfs2 = _run(77) + for tbl in ("leads", "accounts", "touches"): + pd.testing.assert_frame_equal(dfs1[tbl], dfs2[tbl], check_like=False) + + +# --------------------------------------------------------------------------- +# render/snapshots.py +# --------------------------------------------------------------------------- + + +class TestBuildSnapshot: + def test_row_count_equals_lead_count(self, sim_outputs): + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + assert len(snap) == config.n_leads + + def test_all_snapshot_columns_present(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.snapshots 
import build_snapshot + + snap = build_snapshot(result, population) + for col in _SNAPSHOT_COLUMNS: + assert col in snap.columns, f"Missing column: {col}" + + def test_no_extra_columns(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + assert set(snap.columns) == set(_SNAPSHOT_COLUMNS) + + def test_target_column_is_boolean(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + assert snap["converted_within_90_days"].dtype.name == "boolean" + + def test_touch_counts_non_negative(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + assert (snap["touch_count"].dropna() >= 0).all() + assert (snap["inbound_touch_count"].dropna() >= 0).all() + assert (snap["outbound_touch_count"].dropna() >= 0).all() + + def test_inbound_plus_outbound_le_total(self, sim_outputs): + """inbound + outbound ≤ touch_count (can be less if other directions exist).""" + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + valid = snap[["touch_count", "inbound_touch_count", "outbound_touch_count"]].dropna() + combined = valid["inbound_touch_count"] + valid["outbound_touch_count"] + assert (combined <= valid["touch_count"]).all() + + def test_days_since_last_touch_finite_when_touches_exist(self, sim_outputs): + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + has_touch = snap["touch_count"] > 0 + if has_touch.any(): + assert snap.loc[has_touch, "days_since_last_touch"].notna().all() + + def test_no_leakage_target_not_derived_from_future(self, sim_outputs): + """converted_within_90_days 
must match SimulationResult's own flag.""" + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + lead_flags = {row.lead_id: row.converted_within_90_days for row in result.leads} + # Map lead_id → snapshot label + snap_flags = dict(zip(snap["lead_id"], snap["converted_within_90_days"], strict=False)) + for lid, flag in lead_flags.items(): + assert snap_flags[lid] == flag, f"Mismatch on {lid}" + + def test_deterministic_under_same_seed(self): + """Same seed → identical snapshots.""" + from leadforge.render.snapshots import build_snapshot + + def _snap(seed): + cfg = _make_config(seed=seed) + narr = _make_narrative() + g = sample_hidden_graph(seed) + pop = build_population(cfg, narr, g) + res = simulate_world(cfg, pop, g) + return build_snapshot(res, pop, horizon_days=cfg.horizon_days) + + s1 = _snap(13) + s2 = _snap(13) + pd.testing.assert_frame_equal(s1, s2, check_like=False) + + +# --------------------------------------------------------------------------- +# render/tasks.py +# --------------------------------------------------------------------------- + + +class TestWriteTaskSplits: + def test_three_files_written(self, sim_outputs, tmp_path): + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + write_task_splits(snap, tmp_path, seed=config.seed) + + task_dir = tmp_path / "converted_within_90_days" + for split in ("train", "valid", "test"): + assert (task_dir / f"{split}.parquet").exists(), f"{split}.parquet missing" + + def test_task_manifest_written(self, sim_outputs, tmp_path): + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + + snap = build_snapshot(result, population, 
horizon_days=config.horizon_days) + write_task_splits(snap, tmp_path, seed=config.seed) + + manifest_path = tmp_path / "converted_within_90_days" / "task_manifest.json" + assert manifest_path.exists() + data = json.loads(manifest_path.read_text()) + assert "task_id" in data + + def test_row_counts_sum_to_total(self, sim_outputs, tmp_path): + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + counts = write_task_splits(snap, tmp_path, seed=config.seed) + + assert counts["train"] + counts["valid"] + counts["test"] == len(snap) + + def test_split_ratios_approx(self, sim_outputs, tmp_path): + """Train ≈ 70%, valid ≈ 15%, test ≈ 15% (±5% tolerance for small samples).""" + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + counts = write_task_splits(snap, tmp_path, seed=config.seed) + n = len(snap) + assert counts["train"] / n == pytest.approx(0.70, abs=0.05) + assert counts["valid"] / n == pytest.approx(0.15, abs=0.05) + assert counts["test"] / n == pytest.approx(0.15, abs=0.05) + + def test_splits_are_disjoint(self, sim_outputs, tmp_path): + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + write_task_splits(snap, tmp_path, seed=config.seed) + + task_dir = tmp_path / "converted_within_90_days" + dfs = {s: pd.read_parquet(task_dir / f"{s}.parquet") for s in ("train", "valid", "test")} + ids = {s: set(dfs[s]["lead_id"]) for s in dfs} + assert ids["train"].isdisjoint(ids["valid"]) + assert 
ids["train"].isdisjoint(ids["test"]) + assert ids["valid"].isdisjoint(ids["test"]) + + def test_deterministic_under_same_seed(self, sim_outputs, tmp_path): + config, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + + p1 = tmp_path / "run1" + p2 = tmp_path / "run2" + c1 = write_task_splits(snap, p1, seed=config.seed) + c2 = write_task_splits(snap, p2, seed=config.seed) + assert c1 == c2 + + task_id = "converted_within_90_days" + for split in ("train", "valid", "test"): + df1 = pd.read_parquet(p1 / task_id / f"{split}.parquet") + df2 = pd.read_parquet(p2 / task_id / f"{split}.parquet") + pd.testing.assert_frame_equal(df1, df2) + + +# --------------------------------------------------------------------------- +# render/manifests.py +# --------------------------------------------------------------------------- + + +class TestBuildManifest: + def _make_manifest(self, sim_outputs, tmp_path): + config, population, result, world_graph = sim_outputs + from leadforge.render.manifests import build_manifest + from leadforge.render.relational import to_dataframes + from leadforge.render.snapshots import build_snapshot + from leadforge.render.tasks import write_task_splits + from leadforge.schema.tables import write_parquet + + tables_dir = tmp_path / "tables" + tables_dir.mkdir() + dfs = to_dataframes(result, population) + table_row_counts = {} + for name, df in dfs.items(): + write_parquet(df, tables_dir / f"{name}.parquet") + table_row_counts[name] = len(df) + + snap = build_snapshot(result, population, horizon_days=config.horizon_days) + task_counts = write_task_splits(snap, tmp_path / "tasks", seed=config.seed) + + manifest = build_manifest( + config=config, + world_graph=world_graph, + table_row_counts=table_row_counts, + task_row_counts={"converted_within_90_days": task_counts}, + 
bundle_root=tmp_path, + ) + return manifest + + def test_required_top_level_keys(self, sim_outputs, tmp_path): + manifest = self._make_manifest(sim_outputs, tmp_path) + required = { + "bundle_schema_version", + "package_version", + "recipe_id", + "seed", + "generation_timestamp", + "exposure_mode", + "difficulty", + "n_accounts", + "n_contacts", + "n_leads", + "horizon_days", + "motif_family", + "tables", + "tasks", + } + assert required.issubset(set(manifest.keys())) + + def test_table_row_counts_match(self, sim_outputs, tmp_path): + config, _, _, _ = sim_outputs + manifest = self._make_manifest(sim_outputs, tmp_path) + assert manifest["tables"]["leads"]["row_count"] == config.n_leads + assert manifest["tables"]["accounts"]["row_count"] == config.n_accounts + assert manifest["tables"]["contacts"]["row_count"] == config.n_contacts + + def test_sha256_populated(self, sim_outputs, tmp_path): + manifest = self._make_manifest(sim_outputs, tmp_path) + for tbl, entry in manifest["tables"].items(): + assert isinstance(entry["sha256"], str), f"{tbl} sha256 is not a string" + assert len(entry["sha256"]) == 64, f"{tbl} sha256 has wrong length" + + def test_task_split_counts_present(self, sim_outputs, tmp_path): + manifest = self._make_manifest(sim_outputs, tmp_path) + task = manifest["tasks"]["converted_within_90_days"] + assert "train_rows" in task + assert "valid_rows" in task + assert "test_rows" in task + + def test_seed_and_recipe_recorded(self, sim_outputs, tmp_path): + config, _, _, _ = sim_outputs + manifest = self._make_manifest(sim_outputs, tmp_path) + assert manifest["seed"] == config.seed + assert manifest["recipe_id"] == config.recipe_id + + def test_manifest_is_json_serialisable(self, sim_outputs, tmp_path): + manifest = self._make_manifest(sim_outputs, tmp_path) + dumped = json.dumps(manifest) + reloaded = json.loads(dumped) + assert reloaded["seed"] == manifest["seed"] + + +# --------------------------------------------------------------------------- +# 
api/bundle.py — integration smoke test +# --------------------------------------------------------------------------- + + +class TestWriteBundle: + def test_full_bundle_written(self, sim_outputs, tmp_path): + config, population, result, world_graph = sim_outputs + from leadforge.api.bundle import write_bundle + from leadforge.core.models import WorldBundle, WorldSpec + + bundle = WorldBundle( + spec=WorldSpec(config=config), + population=population, + simulation_result=result, + world_graph=world_graph, + ) + write_bundle(bundle, str(tmp_path)) + + assert (tmp_path / "manifest.json").exists() + assert (tmp_path / "dataset_card.md").exists() + assert (tmp_path / "feature_dictionary.csv").exists() + assert (tmp_path / "tables").is_dir() + assert (tmp_path / "tasks" / "converted_within_90_days").is_dir() + + def test_manifest_is_valid_json(self, sim_outputs, tmp_path): + config, population, result, world_graph = sim_outputs + from leadforge.api.bundle import write_bundle + from leadforge.core.models import WorldBundle, WorldSpec + + bundle = WorldBundle( + spec=WorldSpec(config=config), + population=population, + simulation_result=result, + world_graph=world_graph, + ) + write_bundle(bundle, str(tmp_path)) + + data = json.loads((tmp_path / "manifest.json").read_text()) + assert data["seed"] == config.seed + + def test_unpopulated_bundle_raises(self): + from leadforge.api.bundle import write_bundle + from leadforge.core.models import WorldBundle + + with pytest.raises(RuntimeError, match="not fully populated"): + write_bundle(WorldBundle(), "/tmp/leadforge_test_empty") + + def test_generator_generate_and_save(self, tmp_path): + """End-to-end: Generator.from_recipe → generate → save.""" + from leadforge.api.generator import Generator + + gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=7) + bundle = gen.generate(n_leads=60, n_accounts=20, n_contacts=60) + bundle.save(str(tmp_path)) + + assert (tmp_path / "manifest.json").exists() + manifest = 
json.loads((tmp_path / "manifest.json").read_text()) + assert manifest["n_leads"] == 60 + assert manifest["seed"] == 7 From e750de534793991751bfa69929b6ea17b25515c9 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Tue, 28 Apr 2026 17:51:37 +0300 Subject: [PATCH 2/5] fix: address self-review on PR #13 render/bundle layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ten issues identified and fixed: snapshots.py: - Replace row-wise apply() for account/contact joins with vectorised merge; eliminates ~40k Python function calls for a 5k-lead dataset - Remove silent except (ValueError, TypeError): pass on dtype casts; data-incompatible casts must fail loudly - Eliminate duplicate hardcoded fallback column lists in if/else empty-list guards; use entity empty_dataframe() instead so column names are always authoritative relational.py: - Add EntityRowProtocol to schema/entities.py and use it as the type for _TABLE_SOURCES values; removes both # type: ignore[attr-defined] comments - Replace magic "population"/"simulation" strings with Literal["population", "simulation"] alias; typos are now caught at type-check time tasks.py: - Replace raw random.Random(seed) with RNGRoot(seed).child("task_split_shuffle") to honour the project's single-seeded-root design - Remove dangling comment left after previous ruff fix bundle.py: - Move render_dataset_card import to module level (was buried inside function body with no circular-import justification) tests/render/test_render.py: - test_unpopulated_bundle_raises: use tmp_path fixture instead of hardcoded /tmp - test_inbound_plus_outbound_le_total → _equals_total: assert == not <= (only two directions exist in v1; < would indicate miscategorised touches) - Replace vague "no leakage" test with test_no_post_anchor_columns_in_snapshot: asserts that conversion_timestamp, closed_at, close_outcome are absent - Add test_fk_integrity: calls validate_fk() on all ALL_CONSTRAINTS against the produced 
DataFrames — the core correctness property of relational export 523 tests passing; ruff + mypy clean. Co-Authored-By: Claude Sonnet 4.6 --- leadforge/api/bundle.py | 3 +- leadforge/render/relational.py | 16 ++-- leadforge/render/snapshots.py | 165 +++++++++++++++------------------ leadforge/render/tasks.py | 8 +- leadforge/schema/entities.py | 21 ++++- tests/render/test_render.py | 45 +++++++-- 6 files changed, 140 insertions(+), 118 deletions(-) diff --git a/leadforge/api/bundle.py b/leadforge/api/bundle.py index 753e29d..164781f 100644 --- a/leadforge/api/bundle.py +++ b/leadforge/api/bundle.py @@ -14,6 +14,7 @@ from pathlib import Path from typing import TYPE_CHECKING +from leadforge.narrative.dataset_card import render_dataset_card from leadforge.render.manifests import build_manifest, write_manifest from leadforge.render.relational import to_dataframes from leadforge.render.snapshots import build_snapshot @@ -68,8 +69,6 @@ def write_bundle(bundle: WorldBundle, path: str) -> None: # ------------------------------------------------------------------ # 3. 
Dataset card and feature dictionary # ------------------------------------------------------------------ - from leadforge.narrative.dataset_card import render_dataset_card - (root / "dataset_card.md").write_text(render_dataset_card(bundle.spec)) write_feature_dictionary(root / "feature_dictionary.csv") diff --git a/leadforge/render/relational.py b/leadforge/render/relational.py index 3604dce..c7fb328 100644 --- a/leadforge/render/relational.py +++ b/leadforge/render/relational.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import pandas as pd @@ -17,6 +17,7 @@ AccountRow, ContactRow, CustomerRow, + EntityRowProtocol, LeadRow, OpportunityRow, SalesActivityRow, @@ -29,10 +30,10 @@ from leadforge.simulation.engine import SimulationResult from leadforge.simulation.population import PopulationResult -# Mapping from table name to (entity_class, attribute_on_SimulationResult_or_population) -# Population tables come from PopulationResult; event tables from SimulationResult. -_TABLE_SOURCES: dict[str, tuple[type, str, str]] = { - # (entity_class, source: "population"|"simulation", attr_name) +_Source = Literal["population", "simulation"] + +# Maps table name → (entity class, data source, attribute name on source object). +_TABLE_SOURCES: dict[str, tuple[type[EntityRowProtocol], _Source, str]] = { AccountRow.TABLE_NAME: (AccountRow, "population", "accounts"), ContactRow.TABLE_NAME: (ContactRow, "population", "contacts"), LeadRow.TABLE_NAME: (LeadRow, "simulation", "leads"), @@ -67,11 +68,10 @@ def to_dataframes( rows = getattr(obj, attr, []) if rows: df = pd.DataFrame([row.to_dict() for row in rows]) - # Apply canonical dtypes — use nullable pandas types where possible. 
- for col, dtype in cls.DTYPE_MAP.items(): # type: ignore[attr-defined] + for col, dtype in cls.DTYPE_MAP.items(): if col in df.columns: df[col] = df[col].astype(dtype) else: - df = cls.empty_dataframe() # type: ignore[attr-defined] + df = cls.empty_dataframe() dfs[table_name] = df return dfs diff --git a/leadforge/render/snapshots.py b/leadforge/render/snapshots.py index 4c33968..cae25f5 100644 --- a/leadforge/render/snapshots.py +++ b/leadforge/render/snapshots.py @@ -15,6 +15,12 @@ import pandas as pd +from leadforge.schema.entities import ( + OpportunityRow, + SalesActivityRow, + SessionRow, + TouchRow, +) from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES if TYPE_CHECKING: @@ -25,6 +31,17 @@ _SNAPSHOT_COLUMNS = [f.name for f in LEAD_SNAPSHOT_FEATURES] _SNAPSHOT_DTYPES = {f.name: f.dtype for f in LEAD_SNAPSHOT_FEATURES} +# Account and contact columns needed in the snapshot (subset of their full DTYPE_MAP). +_ACCOUNT_JOIN_COLS = [ + "account_id", + "industry", + "region", + "employee_band", + "estimated_revenue_band", + "process_maturity_band", +] +_CONTACT_JOIN_COLS = ["contact_id", "role_function", "seniority", "buyer_role"] + def build_snapshot( result: SimulationResult, @@ -49,86 +66,70 @@ def build_snapshot( :data:`~leadforge.schema.features.LEAD_SNAPSHOT_FEATURES` and dtypes matching the feature spec. Row order matches ``result.leads``. """ - account_by_id = {a.account_id: a for a in population.accounts} - contact_by_id = {c.contact_id: c for c in population.contacts} - # ------------------------------------------------------------------- # Aggregate event tables by lead_id using pandas for efficiency. + # Empty event lists fall back to the entity's canonical empty DataFrame + # so groupby always produces the correct output column names. 
# ------------------------------------------------------------------- # Touch aggregates - if result.touches: - td = pd.DataFrame([t.to_dict() for t in result.touches]) - touch_agg = ( - td.groupby("lead_id") - .agg( - touch_count=("touch_id", "count"), - inbound_touch_count=( - "touch_direction", - lambda s: int((s == "inbound").sum()), - ), - outbound_touch_count=( - "touch_direction", - lambda s: int((s == "outbound").sum()), - ), - last_touch_timestamp=("touch_timestamp", "max"), - ) - .reset_index() - ) - else: - touch_agg = pd.DataFrame( - columns=[ - "lead_id", - "touch_count", - "inbound_touch_count", - "outbound_touch_count", - "last_touch_timestamp", - ] + td = ( + pd.DataFrame([t.to_dict() for t in result.touches]) + if result.touches + else TouchRow.empty_dataframe() + ) + touch_agg = ( + td.groupby("lead_id") + .agg( + touch_count=("touch_id", "count"), + inbound_touch_count=( + "touch_direction", + lambda s: int((s == "inbound").sum()), + ), + outbound_touch_count=( + "touch_direction", + lambda s: int((s == "outbound").sum()), + ), + last_touch_timestamp=("touch_timestamp", "max"), ) + .reset_index() + ) # Session aggregates - if result.sessions: - sd = pd.DataFrame([s.to_dict() for s in result.sessions]) - sess_agg = ( - sd.groupby("lead_id") - .agg( - session_count=("session_id", "count"), - pricing_page_views=("pricing_page_views", "sum"), - demo_page_views=("demo_page_views", "sum"), - total_session_duration_seconds=("session_duration_seconds", "sum"), - ) - .reset_index() - ) - else: - sess_agg = pd.DataFrame( - columns=[ - "lead_id", - "session_count", - "pricing_page_views", - "demo_page_views", - "total_session_duration_seconds", - ] + sd = ( + pd.DataFrame([s.to_dict() for s in result.sessions]) + if result.sessions + else SessionRow.empty_dataframe() + ) + sess_agg = ( + sd.groupby("lead_id") + .agg( + session_count=("session_id", "count"), + pricing_page_views=("pricing_page_views", "sum"), + demo_page_views=("demo_page_views", "sum"), + 
total_session_duration_seconds=("session_duration_seconds", "sum"), ) + .reset_index() + ) # Sales activity aggregates - if result.sales_activities: - ad = pd.DataFrame([a.to_dict() for a in result.sales_activities]) - act_agg = ad.groupby("lead_id").agg(activity_count=("activity_id", "count")).reset_index() - else: - act_agg = pd.DataFrame(columns=["lead_id", "activity_count"]) + ad = ( + pd.DataFrame([a.to_dict() for a in result.sales_activities]) + if result.sales_activities + else SalesActivityRow.empty_dataframe() + ) + act_agg = ad.groupby("lead_id").agg(activity_count=("activity_id", "count")).reset_index() # Opportunity join: find open (unclosed) opportunity per lead. - if result.opportunities: - od = pd.DataFrame([o.to_dict() for o in result.opportunities]) - open_opps = od[od["close_outcome"].isna()][["lead_id", "estimated_acv"]] - # One open opp per lead (first if multiple, which shouldn't happen in v1). - open_opps = open_opps.groupby("lead_id").first().reset_index() - open_opps = open_opps.rename(columns={"estimated_acv": "opportunity_estimated_acv"}) - open_opps["has_open_opportunity"] = True - else: - open_opps = pd.DataFrame( - columns=["lead_id", "has_open_opportunity", "opportunity_estimated_acv"] - ) + od = ( + pd.DataFrame([o.to_dict() for o in result.opportunities]) + if result.opportunities + else OpportunityRow.empty_dataframe() + ) + open_opps = od[od["close_outcome"].isna()][["lead_id", "estimated_acv"]] + open_opps = open_opps.groupby("lead_id").first().reset_index() + open_opps = open_opps.rename(columns={"estimated_acv": "opportunity_estimated_acv"}) + open_opps["has_open_opportunity"] = True # ------------------------------------------------------------------- # Build base lead DataFrame and join aggregates. 
@@ -174,27 +175,12 @@ def build_snapshot( lead_df["days_since_last_touch"] = lead_df["days_since_last_touch"].astype("Float64") # ------------------------------------------------------------------- - # Join account and contact features. + # Join account and contact features via vectorised merge (not apply). # ------------------------------------------------------------------- - def _account_field(row: pd.Series, field: str) -> object: - acct = account_by_id.get(row["account_id"]) - return getattr(acct, field, pd.NA) if acct else pd.NA - - def _contact_field(row: pd.Series, field: str) -> object: - cont = contact_by_id.get(row["contact_id"]) - return getattr(cont, field, pd.NA) if cont else pd.NA - - for field in ( - "industry", - "region", - "employee_band", - "estimated_revenue_band", - "process_maturity_band", - ): - lead_df[field] = lead_df.apply(_account_field, axis=1, field=field) - - for field in ("role_function", "seniority", "buyer_role"): - lead_df[field] = lead_df.apply(_contact_field, axis=1, field=field) + acct_df = pd.DataFrame([a.to_dict() for a in population.accounts])[_ACCOUNT_JOIN_COLS] + cont_df = pd.DataFrame([c.to_dict() for c in population.contacts])[_CONTACT_JOIN_COLS] + lead_df = lead_df.merge(acct_df, on="account_id", how="left") + lead_df = lead_df.merge(cont_df, on="contact_id", how="left") # ------------------------------------------------------------------- # Select and order columns per canonical feature spec; apply dtypes. 
@@ -202,9 +188,6 @@ def _contact_field(row: pd.Series, field: str) -> object: snapshot = lead_df[_SNAPSHOT_COLUMNS].copy() for col, dtype in _SNAPSHOT_DTYPES.items(): if col in snapshot.columns: - try: - snapshot[col] = snapshot[col].astype(dtype) - except (ValueError, TypeError): - pass # column already has compatible dtype + snapshot[col] = snapshot[col].astype(dtype) return snapshot diff --git a/leadforge/render/tasks.py b/leadforge/render/tasks.py index 01b89b9..145005e 100644 --- a/leadforge/render/tasks.py +++ b/leadforge/render/tasks.py @@ -9,11 +9,11 @@ from __future__ import annotations import json -import random from pathlib import Path import pandas as pd +from leadforge.core.rng import RNGRoot from leadforge.schema.tasks import CONVERTED_WITHIN_90_DAYS, TaskManifest @@ -48,8 +48,8 @@ def write_task_splits( task_dir = out_dir / task.task_id task_dir.mkdir(parents=True, exist_ok=True) - # Deterministic shuffle via seeded RNG (index permutation). - rng = random.Random(seed) # noqa: S311 + # Deterministic shuffle via the project's RNG substream system. + rng = RNGRoot(seed).child("task_split_shuffle") indices = list(range(len(snapshot))) rng.shuffle(indices) shuffled = snapshot.iloc[indices].reset_index(drop=True) @@ -57,7 +57,7 @@ def write_task_splits( n = len(shuffled) n_train = int(n * task.split.train) n_valid = int(n * task.split.valid) - # Test gets the remainder to avoid off-by-one from integer rounding. + # test split gets the remainder to avoid off-by-one from integer rounding. 
splits: dict[str, pd.DataFrame] = { "train": shuffled.iloc[:n_train], diff --git a/leadforge/schema/entities.py b/leadforge/schema/entities.py index 2a60a7c..d97230a 100644 --- a/leadforge/schema/entities.py +++ b/leadforge/schema/entities.py @@ -18,11 +18,26 @@ from __future__ import annotations from dataclasses import dataclass, fields -from typing import Any, ClassVar +from typing import Any, ClassVar, Protocol import pandas as pd +class EntityRowProtocol(Protocol): + """Structural protocol shared by all entity row dataclasses. + + Allows typed dispatch in render code without coupling to concrete classes. + """ + + TABLE_NAME: ClassVar[str] + DTYPE_MAP: ClassVar[dict[str, str]] + + def to_dict(self) -> dict[str, Any]: ... + + @classmethod + def empty_dataframe(cls) -> pd.DataFrame: ... + + def _empty_df(dtype_map: dict[str, str]) -> pd.DataFrame: """Return a zero-row DataFrame with columns ordered as *dtype_map*.""" return pd.DataFrame({col: pd.array([], dtype=dtype) for col, dtype in dtype_map.items()}) @@ -360,7 +375,7 @@ def empty_dataframe(cls) -> pd.DataFrame: # Registry # --------------------------------------------------------------------------- -ALL_ROW_TYPES: tuple[type, ...] = ( +ALL_ROW_TYPES: tuple[type[EntityRowProtocol], ...] = ( AccountRow, ContactRow, LeadRow, @@ -372,4 +387,4 @@ def empty_dataframe(cls) -> pd.DataFrame: SubscriptionRow, ) -TABLE_NAMES: tuple[str, ...] = tuple(cls.TABLE_NAME for cls in ALL_ROW_TYPES) # type: ignore[attr-defined] +TABLE_NAMES: tuple[str, ...] 
= tuple(cls.TABLE_NAME for cls in ALL_ROW_TYPES) diff --git a/tests/render/test_render.py b/tests/render/test_render.py index fcd9f9f..f1c2abc 100644 --- a/tests/render/test_render.py +++ b/tests/render/test_render.py @@ -98,10 +98,26 @@ def test_empty_tables_have_schema(self, sim_outputs): from leadforge.schema.entities import CustomerRow dfs = to_dataframes(result, population) - # customers may or may not be empty, but its columns must be a superset - # of the entity's DTYPE_MAP keys. assert set(CustomerRow.DTYPE_MAP.keys()).issubset(set(dfs["customers"].columns)) + def test_fk_integrity(self, sim_outputs): + """All FK constraints must hold on the produced DataFrames.""" + _, population, result, _ = sim_outputs + from leadforge.render.relational import to_dataframes + from leadforge.schema.relationships import ALL_CONSTRAINTS, validate_fk + + dfs = to_dataframes(result, population) + for constraint in ALL_CONSTRAINTS: + child_df = dfs.get(constraint.child_table) + parent_df = dfs.get(constraint.parent_table) + if child_df is None or parent_df is None or child_df.empty: + continue + validate_fk( + child_values=child_df[constraint.child_column].dropna().tolist(), + parent_values=set(parent_df[constraint.parent_column].tolist()), + constraint=constraint, + ) + def test_deterministic_under_same_seed(self): """Same seed → identical relational DataFrames.""" from leadforge.render.relational import to_dataframes @@ -164,15 +180,15 @@ def test_touch_counts_non_negative(self, sim_outputs): assert (snap["inbound_touch_count"].dropna() >= 0).all() assert (snap["outbound_touch_count"].dropna() >= 0).all() - def test_inbound_plus_outbound_le_total(self, sim_outputs): - """inbound + outbound ≤ touch_count (can be less if other directions exist).""" + def test_inbound_plus_outbound_equals_total(self, sim_outputs): + """inbound + outbound must equal touch_count exactly (only two directions in v1).""" _, population, result, _ = sim_outputs from leadforge.render.snapshots import 
build_snapshot snap = build_snapshot(result, population) valid = snap[["touch_count", "inbound_touch_count", "outbound_touch_count"]].dropna() combined = valid["inbound_touch_count"] + valid["outbound_touch_count"] - assert (combined <= valid["touch_count"]).all() + assert (combined == valid["touch_count"]).all() def test_days_since_last_touch_finite_when_touches_exist(self, sim_outputs): _, population, result, _ = sim_outputs @@ -183,14 +199,23 @@ def test_days_since_last_touch_finite_when_touches_exist(self, sim_outputs): if has_touch.any(): assert snap.loc[has_touch, "days_since_last_touch"].notna().all() - def test_no_leakage_target_not_derived_from_future(self, sim_outputs): - """converted_within_90_days must match SimulationResult's own flag.""" + def test_no_post_anchor_columns_in_snapshot(self, sim_outputs): + """Columns that represent post-anchor truth must not appear in the snapshot.""" + _, population, result, _ = sim_outputs + from leadforge.render.snapshots import build_snapshot + + snap = build_snapshot(result, population) + # These exist in LeadRow / OpportunityRow but must be excluded (leakage rule). 
+ forbidden = {"conversion_timestamp", "closed_at", "close_outcome"} + assert forbidden.isdisjoint(set(snap.columns)) + + def test_target_matches_simulation_result(self, sim_outputs): + """converted_within_90_days in snapshot must match SimulationResult's flag.""" _, population, result, _ = sim_outputs from leadforge.render.snapshots import build_snapshot snap = build_snapshot(result, population) lead_flags = {row.lead_id: row.converted_within_90_days for row in result.leads} - # Map lead_id → snapshot label snap_flags = dict(zip(snap["lead_id"], snap["converted_within_90_days"], strict=False)) for lid, flag in lead_flags.items(): assert snap_flags[lid] == flag, f"Mismatch on {lid}" @@ -429,12 +454,12 @@ def test_manifest_is_valid_json(self, sim_outputs, tmp_path): data = json.loads((tmp_path / "manifest.json").read_text()) assert data["seed"] == config.seed - def test_unpopulated_bundle_raises(self): + def test_unpopulated_bundle_raises(self, tmp_path): from leadforge.api.bundle import write_bundle from leadforge.core.models import WorldBundle with pytest.raises(RuntimeError, match="not fully populated"): - write_bundle(WorldBundle(), "/tmp/leadforge_test_empty") + write_bundle(WorldBundle(), str(tmp_path)) def test_generator_generate_and_save(self, tmp_path): """End-to-end: Generator.from_recipe → generate → save.""" From 47dc5e23b1c11ebae5b039af6e3e0be7aa030e1c Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Tue, 28 Apr 2026 18:40:47 +0300 Subject: [PATCH 3/5] fix: second self-review pass on PR #13 Nine issues fixed: snapshots.py: - Derive _ACCOUNT_JOIN_COLS / _CONTACT_JOIN_COLS from LEAD_SNAPSHOT_FEATURES categories instead of maintaining a separate hand-written list; new account or contact features in the spec are now included automatically - Consolidate dtype application: replace eight individual fillna+astype lines with a single bulk fillna block, then rely on the final dtype loop as the sole cast; removes duplicate astype calls on the same columns - 
Vectorise days_since_last_touch: pd.to_datetime returns NaT for nulls and (Timestamp - NaT) yields NaN, so the has_touch.any() branch is unnecessary; one expression replaces six lines relational.py: - Replace opaque 3-tuple with _TableSource NamedTuple (cls, origin, attr); each field is named at the call site - Drop getattr default []: a typo in _TABLE_SOURCES now raises AttributeError immediately instead of silently returning empty DataFrames manifests.py: - Remove `if abs_path.exists() else ""` guard on _sha256 calls; a missing Parquet file is a bug that should raise, not silently produce a corrupt empty-hash manifest bundle.py: - Replace hardcoded "converted_within_90_days" string with CONVERTED_WITHIN_90_DAYS.task_id; one source of truth for the task ID tasks.py: - Move detached comment inline next to the "test" slice it describes tests/render/test_render.py: - _make_narrative() now accepts and forwards seed so determinism test helpers (_snap, _run) no longer lie about which seed they're using 523 tests passing; ruff + mypy clean. 
Co-Authored-By: Claude Sonnet 4.6 --- leadforge/api/bundle.py | 3 +- leadforge/render/manifests.py | 4 +- leadforge/render/relational.py | 41 ++++++++++++--------- leadforge/render/snapshots.py | 67 +++++++++++++++------------------- leadforge/render/tasks.py | 3 +- tests/render/test_render.py | 10 ++--- 6 files changed, 64 insertions(+), 64 deletions(-) diff --git a/leadforge/api/bundle.py b/leadforge/api/bundle.py index 164781f..2cfdb4c 100644 --- a/leadforge/api/bundle.py +++ b/leadforge/api/bundle.py @@ -21,6 +21,7 @@ from leadforge.render.tasks import write_task_splits from leadforge.schema.dictionaries import write_feature_dictionary from leadforge.schema.tables import write_parquet +from leadforge.schema.tasks import CONVERTED_WITHIN_90_DAYS if TYPE_CHECKING: from leadforge.core.models import WorldBundle @@ -79,7 +80,7 @@ def write_bundle(bundle: WorldBundle, path: str) -> None: config=config, world_graph=world_graph, table_row_counts=table_row_counts, - task_row_counts={"converted_within_90_days": task_row_counts}, + task_row_counts={CONVERTED_WITHIN_90_DAYS.task_id: task_row_counts}, bundle_root=root, ) write_manifest(manifest, root) diff --git a/leadforge/render/manifests.py b/leadforge/render/manifests.py index d4bea48..dc49b23 100644 --- a/leadforge/render/manifests.py +++ b/leadforge/render/manifests.py @@ -55,7 +55,7 @@ def build_manifest( for table_name, row_count in table_row_counts.items(): rel_path = f"tables/{table_name}.parquet" abs_path = bundle_root / rel_path - sha = _sha256(abs_path) if abs_path.exists() else "" + sha = _sha256(abs_path) tables[table_name] = {"row_count": row_count, "file": rel_path, "sha256": sha} # Build task entries. 
@@ -65,7 +65,7 @@ def build_manifest( for split_name, row_count in split_counts.items(): rel_path = f"tasks/{task_id}/{split_name}.parquet" abs_path = bundle_root / rel_path - sha = _sha256(abs_path) if abs_path.exists() else "" + sha = _sha256(abs_path) entry[f"{split_name}_rows"] = row_count entry[f"{split_name}_sha256"] = sha tasks[task_id] = entry diff --git a/leadforge/render/relational.py b/leadforge/render/relational.py index c7fb328..fb21578 100644 --- a/leadforge/render/relational.py +++ b/leadforge/render/relational.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, NamedTuple import pandas as pd @@ -32,17 +32,24 @@ _Source = Literal["population", "simulation"] -# Maps table name → (entity class, data source, attribute name on source object). -_TABLE_SOURCES: dict[str, tuple[type[EntityRowProtocol], _Source, str]] = { - AccountRow.TABLE_NAME: (AccountRow, "population", "accounts"), - ContactRow.TABLE_NAME: (ContactRow, "population", "contacts"), - LeadRow.TABLE_NAME: (LeadRow, "simulation", "leads"), - TouchRow.TABLE_NAME: (TouchRow, "simulation", "touches"), - SessionRow.TABLE_NAME: (SessionRow, "simulation", "sessions"), - SalesActivityRow.TABLE_NAME: (SalesActivityRow, "simulation", "sales_activities"), - OpportunityRow.TABLE_NAME: (OpportunityRow, "simulation", "opportunities"), - CustomerRow.TABLE_NAME: (CustomerRow, "simulation", "customers"), - SubscriptionRow.TABLE_NAME: (SubscriptionRow, "simulation", "subscriptions"), + +class _TableSource(NamedTuple): + cls: type[EntityRowProtocol] + origin: _Source # which object holds the rows + attr: str # attribute name on that object + + +# Maps table name → source descriptor. 
+_TABLE_SOURCES: dict[str, _TableSource] = { + AccountRow.TABLE_NAME: _TableSource(AccountRow, "population", "accounts"), + ContactRow.TABLE_NAME: _TableSource(ContactRow, "population", "contacts"), + LeadRow.TABLE_NAME: _TableSource(LeadRow, "simulation", "leads"), + TouchRow.TABLE_NAME: _TableSource(TouchRow, "simulation", "touches"), + SessionRow.TABLE_NAME: _TableSource(SessionRow, "simulation", "sessions"), + SalesActivityRow.TABLE_NAME: _TableSource(SalesActivityRow, "simulation", "sales_activities"), + OpportunityRow.TABLE_NAME: _TableSource(OpportunityRow, "simulation", "opportunities"), + CustomerRow.TABLE_NAME: _TableSource(CustomerRow, "simulation", "customers"), + SubscriptionRow.TABLE_NAME: _TableSource(SubscriptionRow, "simulation", "subscriptions"), } @@ -63,15 +70,15 @@ def to_dataframes( DataFrames with the correct schema. """ dfs: dict[str, pd.DataFrame] = {} - for table_name, (cls, source, attr) in _TABLE_SOURCES.items(): - obj = population if source == "population" else result - rows = getattr(obj, attr, []) + for table_name, src in _TABLE_SOURCES.items(): + obj = population if src.origin == "population" else result + rows = getattr(obj, src.attr) # AttributeError surfaces missing attrs immediately if rows: df = pd.DataFrame([row.to_dict() for row in rows]) - for col, dtype in cls.DTYPE_MAP.items(): + for col, dtype in src.cls.DTYPE_MAP.items(): if col in df.columns: df[col] = df[col].astype(dtype) else: - df = cls.empty_dataframe() + df = src.cls.empty_dataframe() dfs[table_name] = df return dfs diff --git a/leadforge/render/snapshots.py b/leadforge/render/snapshots.py index cae25f5..f43fbd6 100644 --- a/leadforge/render/snapshots.py +++ b/leadforge/render/snapshots.py @@ -27,20 +27,27 @@ from leadforge.simulation.engine import SimulationResult from leadforge.simulation.population import PopulationResult -# Ordered column list derived from the canonical feature spec. +# Ordered column list and dtypes derived from the canonical feature spec. 
_SNAPSHOT_COLUMNS = [f.name for f in LEAD_SNAPSHOT_FEATURES] _SNAPSHOT_DTYPES = {f.name: f.dtype for f in LEAD_SNAPSHOT_FEATURES} -# Account and contact columns needed in the snapshot (subset of their full DTYPE_MAP). -_ACCOUNT_JOIN_COLS = [ - "account_id", - "industry", - "region", - "employee_band", - "estimated_revenue_band", - "process_maturity_band", +# Join columns derived from the feature spec — single source of truth. +# Adding a new account/contact feature to LEAD_SNAPSHOT_FEATURES automatically +# includes it here without any manual list maintenance. +_ACCOUNT_JOIN_COLS = [f.name for f in LEAD_SNAPSHOT_FEATURES if f.category == "account"] +_CONTACT_JOIN_COLS = [f.name for f in LEAD_SNAPSHOT_FEATURES if f.category == "contact"] + +# Aggregated count columns that need zero-filling after left-merge. +_INT_AGG_COLS = [ + "touch_count", + "inbound_touch_count", + "outbound_touch_count", + "session_count", + "pricing_page_views", + "demo_page_views", + "total_session_duration_seconds", + "activity_count", ] -_CONTACT_JOIN_COLS = ["contact_id", "role_function", "seniority", "buyer_role"] def build_snapshot( @@ -147,35 +154,21 @@ def build_snapshot( lead_df = lead_df.merge(act_agg, on="lead_id", how="left") lead_df = lead_df.merge(open_opps, on="lead_id", how="left") - # Fill missing aggregates with zero / False. 
- lead_df["touch_count"] = lead_df["touch_count"].fillna(0).astype("Int64") - lead_df["inbound_touch_count"] = lead_df["inbound_touch_count"].fillna(0).astype("Int64") - lead_df["outbound_touch_count"] = lead_df["outbound_touch_count"].fillna(0).astype("Int64") - lead_df["session_count"] = lead_df["session_count"].fillna(0).astype("Int64") - lead_df["pricing_page_views"] = lead_df["pricing_page_views"].fillna(0).astype("Int64") - lead_df["demo_page_views"] = lead_df["demo_page_views"].fillna(0).astype("Int64") - lead_df["total_session_duration_seconds"] = ( - lead_df["total_session_duration_seconds"].fillna(0).astype("Int64") - ) - lead_df["activity_count"] = lead_df["activity_count"].fillna(0).astype("Int64") - mask = lead_df["has_open_opportunity"].notna() - lead_df["has_open_opportunity"] = ( - lead_df["has_open_opportunity"].where(mask, other=False).astype("boolean") - ) - lead_df["opportunity_estimated_acv"] = lead_df["opportunity_estimated_acv"].astype("Float64") - - # Compute days_since_last_touch (Float64, NaN when no touches). - has_touch = lead_df["last_touch_timestamp"].notna() - lead_df["days_since_last_touch"] = pd.NA - if has_touch.any(): - last_ts = pd.to_datetime(lead_df.loc[has_touch, "last_touch_timestamp"]) - lead_df.loc[has_touch, "days_since_last_touch"] = ( - lead_df.loc[has_touch, "anchor_date"] - last_ts - ).dt.days - lead_df["days_since_last_touch"] = lead_df["days_since_last_touch"].astype("Float64") + # Fill missing event aggregate counts with zero; has_open_opportunity with False. + # opportunity_estimated_acv and days_since_last_touch intentionally stay NaN. + lead_df[_INT_AGG_COLS] = lead_df[_INT_AGG_COLS].fillna(0) + opp_mask = lead_df["has_open_opportunity"].notna() + lead_df["has_open_opportunity"] = lead_df["has_open_opportunity"].where(opp_mask, other=False) + + # Compute days_since_last_touch fully vectorised. + # pd.to_datetime returns NaT for nulls; (Timestamp - NaT) yields NaN naturally. 
+ last_ts = pd.to_datetime(lead_df["last_touch_timestamp"]) + lead_df["days_since_last_touch"] = (lead_df["anchor_date"] - last_ts).dt.days # ------------------------------------------------------------------- # Join account and contact features via vectorised merge (not apply). + # Columns are derived from LEAD_SNAPSHOT_FEATURES categories so this + # list stays in sync automatically when the feature spec changes. # ------------------------------------------------------------------- acct_df = pd.DataFrame([a.to_dict() for a in population.accounts])[_ACCOUNT_JOIN_COLS] cont_df = pd.DataFrame([c.to_dict() for c in population.contacts])[_CONTACT_JOIN_COLS] @@ -183,7 +176,7 @@ def build_snapshot( lead_df = lead_df.merge(cont_df, on="contact_id", how="left") # ------------------------------------------------------------------- - # Select and order columns per canonical feature spec; apply dtypes. + # Select, order, and cast columns — single authoritative dtype pass. # ------------------------------------------------------------------- snapshot = lead_df[_SNAPSHOT_COLUMNS].copy() for col, dtype in _SNAPSHOT_DTYPES.items(): diff --git a/leadforge/render/tasks.py b/leadforge/render/tasks.py index 145005e..7d6013c 100644 --- a/leadforge/render/tasks.py +++ b/leadforge/render/tasks.py @@ -57,12 +57,11 @@ def write_task_splits( n = len(shuffled) n_train = int(n * task.split.train) n_valid = int(n * task.split.valid) - # test split gets the remainder to avoid off-by-one from integer rounding. 
splits: dict[str, pd.DataFrame] = { "train": shuffled.iloc[:n_train], "valid": shuffled.iloc[n_train : n_train + n_valid], - "test": shuffled.iloc[n_train + n_valid :], + "test": shuffled.iloc[n_train + n_valid :], # remainder avoids rounding off-by-one } row_counts: dict[str, int] = {} diff --git a/tests/render/test_render.py b/tests/render/test_render.py index f1c2abc..18df3cd 100644 --- a/tests/render/test_render.py +++ b/tests/render/test_render.py @@ -25,10 +25,10 @@ def _make_config(seed: int = 42, n_leads: int = 80) -> GenerationConfig: return GenerationConfig(seed=seed, n_accounts=30, n_contacts=90, n_leads=n_leads) -def _make_narrative(): +def _make_narrative(seed: int = 42): from leadforge.api.generator import Generator - gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42) + gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=seed) assert gen.world_spec.narrative is not None return gen.world_spec.narrative @@ -37,7 +37,7 @@ def _make_narrative(): def sim_outputs(): """Run a small simulation once; share across all tests in this module.""" config = _make_config() - narrative = _make_narrative() + narrative = _make_narrative(config.seed) graph = sample_hidden_graph(42) population = build_population(config, narrative, graph) result = simulate_world(config, population, graph) @@ -124,7 +124,7 @@ def test_deterministic_under_same_seed(self): def _run(seed): cfg = _make_config(seed=seed) - narr = _make_narrative() + narr = _make_narrative(seed) g = sample_hidden_graph(seed) pop = build_population(cfg, narr, g) res = simulate_world(cfg, pop, g) @@ -226,7 +226,7 @@ def test_deterministic_under_same_seed(self): def _snap(seed): cfg = _make_config(seed=seed) - narr = _make_narrative() + narr = _make_narrative(seed) g = sample_hidden_graph(seed) pop = build_population(cfg, narr, g) res = simulate_world(cfg, pop, g) From 42e4ae1038b9a3dcc3333a5ef99ec8c50e335bd3 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Tue, 28 Apr 2026 19:30:26 +0300 
Subject: [PATCH 4/5] fix: address Copilot review comments on PR #13 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit COPILOT-1 (tasks.py — random.Random): already fixed in prior commit; now uses RNGRoot(seed).child("task_split_shuffle"). COPILOT-2 (manifests.py — empty SHA on missing table file): already fixed; _sha256() is called unconditionally, raises FileNotFoundError if absent. COPILOT-3 (manifests.py — same for task split files, outdated): same fix covers both table and task SHA calls. COPILOT-4 (manifests.py — misleading module docstring): fixed by narrowing the docstring to accurately describe what the manifest covers (Parquet data files only — relational tables and task splits). Expanding to hash dataset_card.md / feature_dictionary.csv / task_manifest.json is not required by the architecture spec and is deferred. Co-Authored-By: Claude Sonnet 4.6 --- leadforge/render/manifests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/leadforge/render/manifests.py b/leadforge/render/manifests.py index dc49b23..fa8dc31 100644 --- a/leadforge/render/manifests.py +++ b/leadforge/render/manifests.py @@ -1,9 +1,9 @@ """Bundle manifest builder. :func:`build_manifest` constructs the ``manifest.json`` dict that is written -at the root of every output bundle. The manifest is the authoritative record -of provenance: it identifies the recipe, seed, version, and every file in the -bundle along with its SHA-256 hash and row count. +at the root of every output bundle. The manifest records provenance (recipe, +seed, version, generation timestamp) and integrity metadata (row counts and +SHA-256 hashes) for the Parquet data files: relational tables and task splits. 
""" from __future__ import annotations From 5629f12f55731fa00f49dfa64bc0ef172d65d5d0 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Tue, 28 Apr 2026 23:00:33 +0300 Subject: [PATCH 5/5] fix: address Copilot round-2 review on PR #13 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit COPILOT-1 (generator.py — difficulty default silently overrides recipe): Default difficulty to _MISSING instead of DifficultyProfile.intermediate; the override is only applied when the caller explicitly passes a value. Generator.from_recipe(..., difficulty="advanced").generate() now keeps the advanced profile rather than silently reverting to intermediate. Added two regression tests. COPILOT-2 (models.py — WorldBundle fields typed Any): Import SimulationResult, PopulationResult, WorldGraph under TYPE_CHECKING and annotate the three fields with their concrete union types. With from __future__ import annotations the imports are lazy at runtime so there is no circular import risk. FAIL-1 (PR agent context refresh startup_failure): Infrastructure-level runner failure unrelated to code; resolved as irrelevant. Co-Authored-By: Claude Sonnet 4.6 --- leadforge/api/generator.py | 15 +++++++++------ leadforge/core/models.py | 9 ++++++--- tests/api/test_generator.py | 21 +++++++++++++++++++++ 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/leadforge/api/generator.py b/leadforge/api/generator.py index 5d499de..39a719c 100644 --- a/leadforge/api/generator.py +++ b/leadforge/api/generator.py @@ -112,19 +112,21 @@ def generate( n_accounts: int | None = None, n_contacts: int | None = None, n_leads: int | None = None, - difficulty: str | DifficultyProfile = DifficultyProfile.intermediate, + difficulty: str | DifficultyProfile = _MISSING, # type: ignore[assignment] **kwargs: Any, ) -> WorldBundle: """Run the full world simulation and return an in-memory bundle. 
Overrides in *n_accounts*, *n_contacts*, *n_leads*, and *difficulty* take effect for this call only — they do not mutate the Generator. + When *difficulty* is omitted the Generator's configured difficulty is used. Args: n_accounts: Override account count. n_contacts: Override contact count. n_leads: Override lead count. - difficulty: Difficulty profile name or enum value. + difficulty: Difficulty profile name or enum value. Defaults to + the difficulty set on the Generator (i.e. from the recipe). **kwargs: Reserved for future use. Returns: @@ -148,10 +150,11 @@ def generate( overrides["n_contacts"] = n_contacts if n_leads is not None: overrides["n_leads"] = n_leads - if not isinstance(difficulty, DifficultyProfile): - difficulty = DifficultyProfile(difficulty) - if difficulty != config.difficulty: - overrides["difficulty"] = difficulty + if difficulty is not _MISSING: + if not isinstance(difficulty, DifficultyProfile): + difficulty = DifficultyProfile(difficulty) # type: ignore[arg-type] + if difficulty != config.difficulty: + overrides["difficulty"] = difficulty if overrides: config = dataclasses.replace(config, **overrides) diff --git a/leadforge/core/models.py b/leadforge/core/models.py index 9d355eb..d4fe66b 100644 --- a/leadforge/core/models.py +++ b/leadforge/core/models.py @@ -11,6 +11,9 @@ if TYPE_CHECKING: from leadforge.narrative.spec import NarrativeSpec + from leadforge.simulation.engine import SimulationResult + from leadforge.simulation.population import PopulationResult + from leadforge.structure.graph import WorldGraph def _require_positive_int(value: Any, name: str) -> None: @@ -101,9 +104,9 @@ class WorldBundle: """ spec: WorldSpec = field(default_factory=WorldSpec) - population: Any = None # PopulationResult | None - simulation_result: Any = None # SimulationResult | None - world_graph: Any = None # WorldGraph | None + population: PopulationResult | None = None + simulation_result: SimulationResult | None = None + world_graph: WorldGraph | None = 
None def save(self, path: str) -> None: """Write the full output bundle to *path*. diff --git a/tests/api/test_generator.py b/tests/api/test_generator.py index 74caa9c..05e896e 100644 --- a/tests/api/test_generator.py +++ b/tests/api/test_generator.py @@ -70,6 +70,27 @@ def test_generate_returns_world_bundle() -> None: assert bundle.population is not None +def test_generate_respects_recipe_difficulty_when_not_overridden() -> None: + """Calling generate() without difficulty must not silently override the recipe's setting.""" + from leadforge.core.enums import DifficultyProfile + + gen = Generator.from_recipe("b2b_saas_procurement_v1", difficulty="advanced") + assert gen.config.difficulty == DifficultyProfile.advanced + bundle = gen.generate(n_leads=20, n_accounts=10, n_contacts=30) + assert bundle.spec.config.difficulty == DifficultyProfile.advanced + + +def test_generate_explicit_difficulty_overrides_recipe() -> None: + """An explicit difficulty kwarg must override the recipe setting for that call only.""" + from leadforge.core.enums import DifficultyProfile + + gen = Generator.from_recipe("b2b_saas_procurement_v1", difficulty="advanced") + bundle = gen.generate(n_leads=20, n_accounts=10, n_contacts=30, difficulty="intro") + assert bundle.spec.config.difficulty == DifficultyProfile.intro + # Generator itself is unchanged. + assert gen.config.difficulty == DifficultyProfile.advanced + + def test_from_recipe_config_has_package_version() -> None: gen = Generator.from_recipe("b2b_saas_procurement_v1") assert gen.config.package_version # non-empty string