From a27d9a4b415cf7119377040f5ee20df25cf0722e Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 29 Apr 2026 08:52:31 +0300 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20Milestone=2010=20=E2=80=94=20CLI=20?= =?UTF-8?q?generate/inspect/validate=20commands?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - generate: parses all flags, calls Generator.from_recipe().generate(), writes bundle via .save(); supports --override, --difficulty, population size overrides - inspect: reads manifest.json, prints summary (recipe, seed, mode, tables with row counts, task splits, metadata presence) - validate: checks manifest presence, required files, table row counts, SHA-256 hashes, task split integrity, FK constraints, leakage detection (unexpected columns in task splits) - 22 CLI tests (smoke, generate integration for both modes, inspect output assertions, validate pass/fail/corrupt/missing-file scenarios) - 562 total tests passing Co-Authored-By: Claude Opus 4.6 --- .agent-plan.md | 41 +++-- leadforge/cli/commands/generate.py | 41 ++++- leadforge/cli/commands/inspect.py | 49 +++++- leadforge/cli/commands/validate.py | 140 +++++++++++++++++- tests/test_cli.py | 230 +++++++++++++++++++++++++++-- 5 files changed, 455 insertions(+), 46 deletions(-) diff --git a/.agent-plan.md b/.agent-plan.md index ffbc469..88e8379 100644 --- a/.agent-plan.md +++ b/.agent-plan.md @@ -6,38 +6,47 @@ ## Current System State -**v0.4.0 in progress — Milestones 7–9 complete (PR open).** Full simulation engine + render/bundle -layer + exposure filtering implemented. 545 tests passing. +**v0.4.0 complete — Milestones 7–10 done.** Full simulation engine + render/bundle layer + exposure filtering + CLI commands implemented. 562 tests passing. --- -## Next Up — Milestone 10: CLI `generate` command + `inspect` / `validate` stubs (v0.4.0) +## Next Up — Milestone 11: Validation harness (v0.5.0) -Goal: Wire `leadforge generate` CLI command end-to-end; implement `inspect` and `validate` output. +Goal: Implement comprehensive bundle validation — invariant checks, realism heuristics, difficulty drift detection. -- [ ] `cli/commands/generate.py` — parse flags, call `Generator.from_recipe().generate()`, call `.save()` -- [ ] `cli/commands/inspect.py` — print manifest summary for a written bundle -- [ ] `cli/commands/validate.py` — basic schema / FK / leakage checks on a written bundle -- [ ] Tests for each command +- [ ] `validation/invariants.py` — DAG acyclicity, FK integrity, determinism, exposure monotonicity +- [ ] `validation/artifact_checks.py` — file presence, hash verification, schema conformance +- [ ] `validation/realism.py` — distributional sanity checks (conversion rates, feature ranges) +- [ ] `validation/difficulty.py` — difficulty profile adherence checks +- [ ] `validation/drift.py` — cross-seed stability / drift detection +- [ ] Wire into `cli/commands/validate.py` with richer output +- [ ] Tests for each validation module --- ## Context Pointers -- Milestone 8 scope: `docs/leadforge_implementation_plan.md` §10 "Milestone 8" -- Render layer: `leadforge/render/` (snapshots, relational, tasks, manifests) -- Bundle writer: `leadforge/api/bundle.py` +- Milestone 11 scope: `docs/leadforge_implementation_plan.md` §10 "Milestone 11" +- Current validate CLI: `leadforge/cli/commands/validate.py` (basic checks implemented in M10) +- FK constraints: `leadforge/schema/relationships.py` +- Feature spec: `leadforge/schema/features.py` --- ## Completed Phases -### Milestone 9 — Exposure Filtering ✓ (v0.4.0 in PR) -- `exposure/filters.py`: `BundleFilter` frozen dataclass; `FILTERS` dict keyed by `ExposureMode`; `get_filter()` -- `exposure/redaction.py`: `write_metadata_dir()` — writes `metadata/` with `graph.json`, `graph.graphml`, `world_spec.json`, `latent_registry.json`, `mechanism_summary.json` -- `exposure/modes.py`: `apply_exposure(bundle, root, mode)` — dispatch; skips `metadata/` for `student_public` +### Milestone 10 — CLI Commands ✓ (v0.4.0) +- `cli/commands/generate.py`: fully wired — parses all flags, calls `Generator.from_recipe().generate()`, writes bundle via `.save()` +- `cli/commands/inspect.py`: reads `manifest.json` and prints summary (recipe, seed, mode, tables with row counts, task splits, metadata presence) +- `cli/commands/validate.py`: checks manifest presence, required files, table row counts, SHA-256 hashes, task split integrity, FK constraints, leakage (unexpected columns) +- 22 CLI tests (smoke, generate integration, inspect output, validate pass/fail/corrupt/missing); total 562 passing + +### Milestone 9 — Exposure Filtering ✓ (v0.4.0) +- `exposure/filters.py`: `BundleFilter` frozen dataclass; `FILTERS` dict keyed by `ExposureMode`; `get_filter()` accepts `str | ExposureMode` +- `exposure/metadata.py`: `write_metadata_dir()` — writes `metadata/` with `graph.json`, `graph.graphml`, `world_spec.json`, `latent_registry.json`, `mechanism_summary.json` +- `exposure/modes.py`: `apply_exposure(bundle, root, mode)` — dispatch; removes stale `metadata/` for `student_public` - Wired into `api/bundle.py` between dataset card and manifest steps -- 24 new tests; total 545 passing +- 22 exposure tests; total 547 passing ### Milestone 8 — Render / Bundle Layer ✓ (v0.4.0 in PR) - `render/relational.py`: `to_dataframes()` — 9-table dict of typed DataFrames from SimulationResult + PopulationResult diff --git a/leadforge/cli/commands/generate.py b/leadforge/cli/commands/generate.py index 18bdddd..99eb5de 100644 --- a/leadforge/cli/commands/generate.py +++ b/leadforge/cli/commands/generate.py @@ -1,5 +1,9 @@ """leadforge generate command.""" +from __future__ import annotations + +from pathlib import Path + import typer @@ -28,8 +32,37 @@ def generate( ), ) -> None: """Generate a synthetic CRM dataset bundle from a recipe.""" - typer.echo( - "The 'generate' command is not yet implemented. Coming in v0.2.0.", - err=True, + from leadforge.api.generator import Generator + from leadforge.core.serialization import load_yaml + + override_dict: dict | None = None + if override is not None: + override_dict = load_yaml(Path(override)) + + gen = Generator.from_recipe( + recipe, + seed=seed, + exposure_mode=mode, + difficulty=difficulty, + n_accounts=n_accounts, + n_contacts=n_contacts, + n_leads=n_leads, + horizon_days=horizon_days, + override=override_dict, ) - raise typer.Exit(1) + + generate_kwargs: dict[str, int] = {} + if n_accounts is not None: + generate_kwargs["n_accounts"] = n_accounts + if n_contacts is not None: + generate_kwargs["n_contacts"] = n_contacts + if n_leads is not None: + generate_kwargs["n_leads"] = n_leads + + typer.echo(f"Generating bundle with recipe '{recipe}', seed={seed}, mode={mode} ...") + bundle = gen.generate(**generate_kwargs) + + typer.echo(f"Writing bundle to {out} ...") + bundle.save(out) + + typer.echo(f"Done. Bundle written to {out}") diff --git a/leadforge/cli/commands/inspect.py b/leadforge/cli/commands/inspect.py index 36d6bf9..a631c41 100644 --- a/leadforge/cli/commands/inspect.py +++ b/leadforge/cli/commands/inspect.py @@ -1,5 +1,10 @@ """leadforge inspect command.""" +from __future__ import annotations + +import json +from pathlib import Path + import typer @@ -7,8 +12,42 @@ def inspect( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), ) -> None: """Inspect a generated dataset bundle and print a summary.""" - typer.echo( - "The 'inspect' command is not yet implemented. Coming in v0.4.0.", - err=True, - ) - raise typer.Exit(1) + root = Path(bundle_path) + manifest_path = root / "manifest.json" + if not manifest_path.exists(): + typer.echo(f"Error: no manifest.json found in {root}", err=True) + raise typer.Exit(1) + + manifest = json.loads(manifest_path.read_text()) + + typer.echo(f"Bundle: {root}") + typer.echo(f" Recipe: {manifest.get('recipe_id', '?')}") + typer.echo(f" Seed: {manifest.get('seed', '?')}") + typer.echo(f" Mode: {manifest.get('exposure_mode', '?')}") + typer.echo(f" Difficulty: {manifest.get('difficulty', '?')}") + typer.echo(f" Horizon days: {manifest.get('horizon_days', '?')}") + typer.echo(f" Generated at: {manifest.get('generation_timestamp', '?')}") + typer.echo(f" Package: leadforge {manifest.get('package_version', '?')}") + typer.echo(f" Schema ver: {manifest.get('bundle_schema_version', '?')}") + typer.echo(f" Motif family: {manifest.get('motif_family', '?')}") + + typer.echo("") + typer.echo("Tables:") + tables = manifest.get("tables", {}) + for name, info in tables.items(): + typer.echo(f" {name:25s} {info.get('row_count', '?'):>8} rows") + + tasks = manifest.get("tasks", {}) + if tasks: + typer.echo("") + typer.echo("Tasks:") + for task_id, info in tasks.items(): + train = info.get("train_rows", "?") + valid = info.get("valid_rows", "?") + test = info.get("test_rows", "?") + typer.echo(f" {task_id}") + typer.echo(f" train={train} valid={valid} test={test}") + + has_metadata = (root / "metadata").is_dir() + typer.echo("") + typer.echo(f"Metadata dir: {'present' if has_metadata else 'absent'}") diff --git a/leadforge/cli/commands/validate.py b/leadforge/cli/commands/validate.py index 461db4a..8bf0cf1 100644 --- a/leadforge/cli/commands/validate.py +++ b/leadforge/cli/commands/validate.py @@ -1,5 +1,11 @@ """leadforge validate command.""" +from __future__ import annotations + +import hashlib +import json +from pathlib import Path + import typer @@ -7,8 +13,132 @@ def validate( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), ) -> None: """Run schema and artifact validation on a generated bundle.""" - typer.echo( - "The 'validate' command is not yet implemented. Coming in v0.5.0.", - err=True, - ) - raise typer.Exit(1) + root = Path(bundle_path) + errors: list[str] = [] + + # ------------------------------------------------------------------ + # 1. Manifest presence and parse + # ------------------------------------------------------------------ + manifest_path = root / "manifest.json" + if not manifest_path.exists(): + typer.echo(f"FAIL: no manifest.json in {root}", err=True) + raise typer.Exit(1) + + manifest = json.loads(manifest_path.read_text()) + + # ------------------------------------------------------------------ + # 2. Required top-level files + # ------------------------------------------------------------------ + for fname in ("dataset_card.md", "feature_dictionary.csv"): + if not (root / fname).exists(): + errors.append(f"Missing required file: {fname}") + + # ------------------------------------------------------------------ + # 3. Table files exist + row counts + SHA-256 hashes + # ------------------------------------------------------------------ + import pandas as pd + + tables: dict[str, pd.DataFrame] = {} + for table_name, info in manifest.get("tables", {}).items(): + rel_path = info.get("file", f"tables/{table_name}.parquet") + abs_path = root / rel_path + if not abs_path.exists(): + errors.append(f"Missing table file: {rel_path}") + continue + + df = pd.read_parquet(abs_path) + tables[table_name] = df + + expected_rows = info.get("row_count") + if expected_rows is not None and len(df) != expected_rows: + errors.append(f"Table {table_name}: expected {expected_rows} rows, got {len(df)}") + + expected_sha = info.get("sha256") + if expected_sha is not None: + actual_sha = _sha256(abs_path) + if actual_sha != expected_sha: + errors.append(f"Table {table_name}: SHA-256 mismatch") + + # ------------------------------------------------------------------ + # 4. Task split files exist + row counts + hashes + # ------------------------------------------------------------------ + for task_id, task_info in manifest.get("tasks", {}).items(): + for split in ("train", "valid", "test"): + rel_path = f"tasks/{task_id}/{split}.parquet" + abs_path = root / rel_path + if not abs_path.exists(): + errors.append(f"Missing task file: {rel_path}") + continue + + df = pd.read_parquet(abs_path) + expected_rows = task_info.get(f"{split}_rows") + if expected_rows is not None and len(df) != expected_rows: + errors.append( + f"Task {task_id}/{split}: expected {expected_rows} rows, got {len(df)}" + ) + + expected_sha = task_info.get(f"{split}_sha256") + if expected_sha is not None: + actual_sha = _sha256(abs_path) + if actual_sha != expected_sha: + errors.append(f"Task {task_id}/{split}: SHA-256 mismatch") + + # ------------------------------------------------------------------ + # 5. FK integrity + # ------------------------------------------------------------------ + from leadforge.schema.relationships import ALL_CONSTRAINTS + + for fk in ALL_CONSTRAINTS: + child_df = tables.get(fk.child_table) + parent_df = tables.get(fk.parent_table) + if child_df is None or parent_df is None: + continue + if fk.child_column not in child_df.columns: + continue + if fk.parent_column not in parent_df.columns: + continue + + child_vals = set(child_df[fk.child_column].dropna()) + parent_vals = set(parent_df[fk.parent_column].dropna()) + orphans = child_vals - parent_vals + if orphans: + sample = list(orphans)[:3] + errors.append( + f"FK violation: {fk.child_table}.{fk.child_column} → " + f"{fk.parent_table}.{fk.parent_column}: " + f"{len(orphans)} orphan(s), e.g. {sample}" + ) + + # ------------------------------------------------------------------ + # 6. Leakage check — no post-anchor features in task splits + # ------------------------------------------------------------------ + from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES + + expected_columns = {f.name for f in LEAD_SNAPSHOT_FEATURES} + for task_id in manifest.get("tasks", {}): + train_path = root / f"tasks/{task_id}/train.parquet" + if train_path.exists(): + actual_columns = set(pd.read_parquet(train_path).columns) + extra = actual_columns - expected_columns + if extra: + errors.append(f"Task {task_id}: unexpected columns (possible leakage): {extra}") + + # ------------------------------------------------------------------ + # Report + # ------------------------------------------------------------------ + if errors: + typer.echo(f"FAIL: {len(errors)} validation error(s):", err=True) + for e in errors: + typer.echo(f" - {e}", err=True) + raise typer.Exit(1) + + typer.echo(f"OK: bundle at {root} passed all checks.") + + +def _sha256(path: Path) -> str: + """Return hex-encoded SHA-256 digest of *path*.""" + h = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(65536), b""): + h.update(chunk) + return h.hexdigest() diff --git a/tests/test_cli.py b/tests/test_cli.py index b97c095..c64859f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,11 @@ -"""CLI smoke tests.""" +"""CLI tests — smoke tests + generate/inspect/validate integration.""" +from __future__ import annotations + +import json +from pathlib import Path + +import pytest from typer.testing import CliRunner from leadforge.cli.main import app @@ -7,6 +13,11 @@ runner = CliRunner() +# --------------------------------------------------------------------------- +# Smoke tests +# --------------------------------------------------------------------------- + + def test_help_exits_clean() -> None: result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 @@ -28,18 +39,205 @@ def test_list_recipes_shows_v1_recipe() -> None: assert "b2b_saas_procurement_v1" in result.output -def test_generate_stub_exits_nonzero() -> None: - result = runner.invoke( - app, ["generate", "--recipe", "x", "--seed", "1", "--mode", "y", "--out", "/tmp"] - ) - assert result.exit_code != 0 - - -def test_inspect_stub_exits_nonzero() -> None: - result = runner.invoke(app, ["inspect", "/nonexistent"]) - assert result.exit_code != 0 - - -def test_validate_stub_exits_nonzero() -> None: - result = runner.invoke(app, ["validate", "/nonexistent"]) - assert result.exit_code != 0 +# --------------------------------------------------------------------------- +# Helper — generate a small bundle to a temp dir +# --------------------------------------------------------------------------- + +_GENERATE_ARGS = [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "42", + "--mode", + "student_public", + "--n-leads", + "30", + "--n-accounts", + "15", + "--n-contacts", + "45", +] + + +@pytest.fixture(scope="module") +def bundle_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Generate a small bundle once and reuse across tests in this module.""" + out = tmp_path_factory.mktemp("bundle") + result = runner.invoke(app, [*_GENERATE_ARGS, "--out", str(out)]) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + return out + + +# --------------------------------------------------------------------------- +# generate command +# --------------------------------------------------------------------------- + + +class TestGenerateCommand: + def test_exits_zero(self, bundle_dir: Path) -> None: + # bundle_dir fixture already asserts exit_code == 0 + assert (bundle_dir / "manifest.json").exists() + + def test_writes_core_files(self, bundle_dir: Path) -> None: + assert (bundle_dir / "dataset_card.md").exists() + assert (bundle_dir / "feature_dictionary.csv").exists() + assert (bundle_dir / "tables").is_dir() + assert (bundle_dir / "tasks").is_dir() + + def test_manifest_has_expected_keys(self, bundle_dir: Path) -> None: + manifest = json.loads((bundle_dir / "manifest.json").read_text()) + assert manifest["recipe_id"] == "b2b_saas_procurement_v1" + assert manifest["seed"] == 42 + assert manifest["exposure_mode"] == "student_public" + assert "tables" in manifest + assert "tasks" in manifest + + def test_no_metadata_in_student_public(self, bundle_dir: Path) -> None: + assert not (bundle_dir / "metadata").exists() + + def test_invalid_recipe_fails(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "nonexistent_recipe", + "--seed", + "1", + "--mode", + "student_public", + "--out", + str(tmp_path / "bad"), + ], + ) + assert result.exit_code != 0 + + def test_invalid_mode_fails(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "invalid_mode", + "--out", + str(tmp_path / "bad"), + ], + ) + assert result.exit_code != 0 + + def test_research_instructor_mode_has_metadata(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "7", + "--mode", + "research_instructor", + "--n-leads", + "20", + "--n-accounts", + "10", + "--n-contacts", + "30", + "--out", + str(tmp_path), + ], + ) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + assert (tmp_path / "metadata").is_dir() + + def test_output_message(self, tmp_path: Path) -> None: + out = tmp_path / "msg_test" + result = runner.invoke(app, [*_GENERATE_ARGS, "--out", str(out)]) + assert result.exit_code == 0 + assert "Generating bundle" in result.output + assert "Done" in result.output + + +# --------------------------------------------------------------------------- +# inspect command +# --------------------------------------------------------------------------- + + +class TestInspectCommand: + def test_exits_zero(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["inspect", str(bundle_dir)]) + assert result.exit_code == 0 + + def test_shows_recipe(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["inspect", str(bundle_dir)]) + assert "b2b_saas_procurement_v1" in result.output + + def test_shows_seed(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["inspect", str(bundle_dir)]) + assert "42" in result.output + + def test_shows_tables(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["inspect", str(bundle_dir)]) + assert "accounts" in result.output + assert "leads" in result.output + + def test_shows_tasks(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["inspect", str(bundle_dir)]) + assert "converted_within_90_days" in result.output + + def test_missing_bundle_fails(self, tmp_path: Path) -> None: + result = runner.invoke(app, ["inspect", str(tmp_path / "nonexistent")]) + assert result.exit_code != 0 + + +# --------------------------------------------------------------------------- +# validate command +# --------------------------------------------------------------------------- + + +class TestValidateCommand: + def test_valid_bundle_passes(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["validate", str(bundle_dir)]) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_missing_bundle_fails(self, tmp_path: Path) -> None: + result = runner.invoke(app, ["validate", str(tmp_path / "nonexistent")]) + assert result.exit_code != 0 + + def test_corrupt_manifest_fails(self, tmp_path: Path, bundle_dir: Path) -> None: + """A bundle with a tampered row count should fail validation.""" + import shutil + + corrupt = tmp_path / "corrupt_bundle" + shutil.copytree(bundle_dir, corrupt) + + manifest = json.loads((corrupt / "manifest.json").read_text()) + # Tamper with a table row count + first_table = next(iter(manifest["tables"])) + manifest["tables"][first_table]["row_count"] = 999999 + (corrupt / "manifest.json").write_text(json.dumps(manifest, indent=2)) + + result = runner.invoke(app, ["validate", str(corrupt)]) + assert result.exit_code != 0 + assert "FAIL" in result.output + + def test_missing_table_file_fails(self, tmp_path: Path, bundle_dir: Path) -> None: + """Removing a table Parquet file should cause validation failure.""" + import shutil + + corrupt = tmp_path / "missing_table_bundle" + shutil.copytree(bundle_dir, corrupt) + + # Remove one table file + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + (corrupt / f"tables/{first_table}.parquet").unlink() + + result = runner.invoke(app, ["validate", str(corrupt)]) + assert result.exit_code != 0 + assert "FAIL" in result.output From 4fe7c7ece8419ceff303c72d0ea569bebd5f2f75 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 29 Apr 2026 09:03:08 +0300 Subject: [PATCH 2/5] fix: resolve mypy arg-type error in generate command Pass n_accounts/n_contacts/n_leads as explicit keyword args to gen.generate() instead of **dict unpacking, which mypy could not reconcile with the typed signature. Co-Authored-By: Claude Opus 4.6 --- leadforge/cli/commands/generate.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/leadforge/cli/commands/generate.py b/leadforge/cli/commands/generate.py index 99eb5de..53bce2f 100644 --- a/leadforge/cli/commands/generate.py +++ b/leadforge/cli/commands/generate.py @@ -51,16 +51,12 @@ def generate( override=override_dict, ) - generate_kwargs: dict[str, int] = {} - if n_accounts is not None: - generate_kwargs["n_accounts"] = n_accounts - if n_contacts is not None: - generate_kwargs["n_contacts"] = n_contacts - if n_leads is not None: - generate_kwargs["n_leads"] = n_leads - typer.echo(f"Generating bundle with recipe '{recipe}', seed={seed}, mode={mode} ...") - bundle = gen.generate(**generate_kwargs) + bundle = gen.generate( + n_accounts=n_accounts, + n_contacts=n_contacts, + n_leads=n_leads, + ) typer.echo(f"Writing bundle to {out} ...") bundle.save(out) From cb0d87b1bf5e951cb439a175b4256c31084fd330 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 29 Apr 2026 10:01:14 +0300 Subject: [PATCH 3/5] =?UTF-8?q?fix:=20address=20PR=20#16=20review=20feedba?= =?UTF-8?q?ck=20=E2=80=94=20use=20load=5Fjson,=20centralize=20file=5Fsha25?= =?UTF-8?q?6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - COPILOT-1/4: inspect.py and validate.py now use load_json() from core.serialization with LeadforgeError → clean typer.Exit(1) - COPILOT-2/3: confirmed Typer CliRunner always mixes stderr into result.output; assertions are already correct (resolved as-is) - COPILOT-5: opened #17 for Parquet metadata optimization (out-of-scope for v1 bundle sizes) - COPILOT-6: moved _sha256 to core.hashing.file_sha256(); both render/manifests.py and cli/commands/validate.py now import from there Co-Authored-By: Claude Opus 4.6 --- leadforge/cli/commands/inspect.py | 10 ++++++++-- leadforge/cli/commands/validate.py | 25 +++++++++++-------------- leadforge/core/hashing.py | 15 ++++++++++++++- leadforge/render/manifests.py | 16 ++++------------ 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/leadforge/cli/commands/inspect.py b/leadforge/cli/commands/inspect.py index a631c41..7341e65 100644 --- a/leadforge/cli/commands/inspect.py +++ b/leadforge/cli/commands/inspect.py @@ -2,11 +2,13 @@ from __future__ import annotations -import json from pathlib import Path import typer +from leadforge.core.exceptions import LeadforgeError +from leadforge.core.serialization import load_json + def inspect( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), @@ -18,7 +20,11 @@ def inspect( typer.echo(f"Error: no manifest.json found in {root}", err=True) raise typer.Exit(1) - manifest = json.loads(manifest_path.read_text()) + try: + manifest = load_json(manifest_path) + except LeadforgeError as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from None typer.echo(f"Bundle: {root}") typer.echo(f" Recipe: {manifest.get('recipe_id', '?')}") diff --git a/leadforge/cli/commands/validate.py b/leadforge/cli/commands/validate.py index 8bf0cf1..e20a506 100644 --- a/leadforge/cli/commands/validate.py +++ b/leadforge/cli/commands/validate.py @@ -2,12 +2,14 @@ from __future__ import annotations -import hashlib -import json from pathlib import Path import typer +from leadforge.core.exceptions import LeadforgeError +from leadforge.core.hashing import file_sha256 +from leadforge.core.serialization import load_json + def validate( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), @@ -24,7 +26,11 @@ def validate( typer.echo(f"FAIL: no manifest.json in {root}", err=True) raise typer.Exit(1) - manifest = json.loads(manifest_path.read_text()) + try: + manifest = load_json(manifest_path) + except LeadforgeError as exc: + typer.echo(f"FAIL: {exc}", err=True) + raise typer.Exit(1) from None # ------------------------------------------------------------------ # 2. Required top-level files @@ -55,7 +61,7 @@ def validate( expected_sha = info.get("sha256") if expected_sha is not None: - actual_sha = _sha256(abs_path) + actual_sha = file_sha256(abs_path) if actual_sha != expected_sha: errors.append(f"Table {table_name}: SHA-256 mismatch") @@ -79,7 +85,7 @@ def validate( expected_sha = task_info.get(f"{split}_sha256") if expected_sha is not None: - actual_sha = _sha256(abs_path) + actual_sha = file_sha256(abs_path) if actual_sha != expected_sha: errors.append(f"Task {task_id}/{split}: SHA-256 mismatch") @@ -133,12 +139,3 @@ def validate( raise typer.Exit(1) typer.echo(f"OK: bundle at {root} passed all checks.") - - -def _sha256(path: Path) -> str: - """Return hex-encoded SHA-256 digest of *path*.""" - h = hashlib.sha256() - with path.open("rb") as fh: - for chunk in iter(lambda: fh.read(65536), b""): - h.update(chunk) - return h.hexdigest() diff --git a/leadforge/core/hashing.py b/leadforge/core/hashing.py index 680a782..ed465af 100644 --- a/leadforge/core/hashing.py +++ b/leadforge/core/hashing.py @@ -1,13 +1,17 @@ -"""Deterministic config hashing for manifest identity. +"""Deterministic config hashing and file digest helpers. A config hash uniquely identifies a (recipe, config, seed, version) tuple and is embedded in every generated manifest so that bundles can be traced back to the exact parameters that produced them. + +:func:`file_sha256` provides a reusable SHA-256 file digest used by the +manifest builder and the bundle validator. """ import hashlib import json from dataclasses import asdict +from pathlib import Path from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -26,6 +30,15 @@ def _canonical(obj: Any) -> Any: return obj +def file_sha256(path: Path) -> str: + """Return the hex-encoded SHA-256 digest of the file at *path*.""" + h = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + def hash_config(config: "GenerationConfig") -> str: """Return a stable hex-encoded SHA-256 digest of *config*. diff --git a/leadforge/render/manifests.py b/leadforge/render/manifests.py index fa8dc31..6b12723 100644 --- a/leadforge/render/manifests.py +++ b/leadforge/render/manifests.py @@ -8,12 +8,13 @@ from __future__ import annotations -import hashlib import json from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any +from leadforge.core.hashing import file_sha256 + if TYPE_CHECKING: from leadforge.core.models import GenerationConfig from leadforge.structure.graph import WorldGraph @@ -55,7 +56,7 @@ def build_manifest( for table_name, row_count in table_row_counts.items(): rel_path = f"tables/{table_name}.parquet" abs_path = bundle_root / rel_path - sha = _sha256(abs_path) + sha = file_sha256(abs_path) tables[table_name] = {"row_count": row_count, "file": rel_path, "sha256": sha} # Build task entries. @@ -65,7 +66,7 @@ def build_manifest( for split_name, row_count in split_counts.items(): rel_path = f"tasks/{task_id}/{split_name}.parquet" abs_path = bundle_root / rel_path - sha = _sha256(abs_path) + sha = file_sha256(abs_path) entry[f"{split_name}_rows"] = row_count entry[f"{split_name}_sha256"] = sha tasks[task_id] = entry @@ -93,12 +94,3 @@ def write_manifest(manifest: dict[str, Any], bundle_root: Path) -> Path: path = bundle_root / "manifest.json" path.write_text(json.dumps(manifest, indent=2)) return path - - -def _sha256(path: Path) -> str: - """Return the hex-encoded SHA-256 digest of *path*.""" - h = hashlib.sha256() - with path.open("rb") as fh: - for chunk in iter(lambda: fh.read(65536), b""): - h.update(chunk) - return h.hexdigest() From c0e97c08596e49186bbff38670665a51aefb0e52 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 29 Apr 2026 10:11:30 +0300 Subject: [PATCH 4/5] =?UTF-8?q?fix:=20harden=20CLI=20commands=20=E2=80=94?= =?UTF-8?q?=20extract=20validation,=20add=20error=20handling,=20expand=20t?= =?UTF-8?q?ests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses all 13 items from self-review: generate.py: 1. Remove double population-count pass (was in both from_recipe and generate) 2. Wrap from_recipe/generate/load_yaml with try/except → clean CLI errors 3. Validate --override path exists before loading inspect.py: 4. Defensive manifest parsing — handle non-dict tables/tasks entries 5. Check bundle_path is a directory, not a file validate.py → validation/bundle_checks.py: 6. FK checks now report "FK check skipped" when a table is missing instead of silently passing 7. Leakage check no longer re-reads train.parquet (reads schema only) 8. Leakage check now covers all splits (train, valid, test) 13. Extract all validation logic to validation/bundle_checks.py; CLI command is now a thin shell tests: 9. Add warning comment to module-scoped bundle_dir fixture 10. Add tests for --override (valid file + missing file) 11. Add test for --difficulty flag 12. Consolidate 5 inspect invocations into 1 test with multiple asserts + Add file-instead-of-dir tests for inspect and validate + Assert FK-skip message in missing-table validate test + New test_bundle_checks.py with 5 unit tests for validation module 568 tests passing; ruff + mypy clean. Co-Authored-By: Claude Opus 4.6 --- leadforge/cli/commands/generate.py | 49 +++++--- leadforge/cli/commands/inspect.py | 34 +++++- leadforge/cli/commands/validate.py | 119 ++----------------- leadforge/validation/bundle_checks.py | 157 +++++++++++++++++++++++++ tests/test_cli.py | 122 ++++++++++++++++--- tests/validation/__init__.py | 0 tests/validation/test_bundle_checks.py | 80 +++++++++++++ 7 files changed, 410 insertions(+), 151 deletions(-) create mode 100644 leadforge/validation/bundle_checks.py create mode 100644 tests/validation/__init__.py create mode 100644 tests/validation/test_bundle_checks.py diff --git a/leadforge/cli/commands/generate.py b/leadforge/cli/commands/generate.py index 53bce2f..e70c769 100644 --- a/leadforge/cli/commands/generate.py +++ b/leadforge/cli/commands/generate.py @@ -6,6 +6,8 @@ import typer +from leadforge.core.exceptions import LeadforgeError + def generate( recipe: str = typer.Option(..., "--recipe", "-r", help="Recipe ID to use."), @@ -37,26 +39,39 @@ def generate( override_dict: dict | None = None if override is not None: - override_dict = load_yaml(Path(override)) + override_path = Path(override) + if not override_path.exists(): + typer.echo(f"Error: override file not found: {override_path}", err=True) + raise typer.Exit(1) + try: + override_dict = load_yaml(override_path) + except LeadforgeError as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from None - gen = Generator.from_recipe( - recipe, - seed=seed, - exposure_mode=mode, - difficulty=difficulty, - n_accounts=n_accounts, - n_contacts=n_contacts, - n_leads=n_leads, - horizon_days=horizon_days, - override=override_dict, - ) + try: + gen = Generator.from_recipe( + recipe, + seed=seed, + exposure_mode=mode, + difficulty=difficulty, + n_accounts=n_accounts, + n_contacts=n_contacts, + n_leads=n_leads, + horizon_days=horizon_days, + override=override_dict, + ) + except (LeadforgeError, ValueError) as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from None typer.echo(f"Generating bundle with recipe '{recipe}', seed={seed}, mode={mode} ...") - bundle = gen.generate( - n_accounts=n_accounts, - n_contacts=n_contacts, - n_leads=n_leads, - ) + + try: + bundle = gen.generate() + except (LeadforgeError, RuntimeError) as exc: + typer.echo(f"Error during generation: {exc}", err=True) + raise typer.Exit(1) from None typer.echo(f"Writing bundle to {out} ...") bundle.save(out) diff --git a/leadforge/cli/commands/inspect.py b/leadforge/cli/commands/inspect.py index 7341e65..9ad5a5c 100644 --- a/leadforge/cli/commands/inspect.py +++ b/leadforge/cli/commands/inspect.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path +from typing import Any import typer @@ -15,6 +16,14 @@ def inspect( ) -> None: """Inspect a generated dataset bundle and print a summary.""" root = Path(bundle_path) + + if not root.exists(): + typer.echo(f"Error: path does not exist: {root}", err=True) + raise typer.Exit(1) + if not root.is_dir(): + typer.echo(f"Error: not a directory (expected a bundle dir): {root}", err=True) + raise typer.Exit(1) + manifest_path = root / "manifest.json" if not manifest_path.exists(): typer.echo(f"Error: no manifest.json found in {root}", err=True) @@ -26,6 +35,10 @@ def inspect( typer.echo(f"Error: {exc}", err=True) raise typer.Exit(1) from None + if not isinstance(manifest, dict): + typer.echo("Error: manifest.json is not a JSON object", err=True) + raise typer.Exit(1) + typer.echo(f"Bundle: {root}") typer.echo(f" Recipe: {manifest.get('recipe_id', '?')}") typer.echo(f" Seed: {manifest.get('seed', '?')}") @@ -40,20 +53,29 @@ def inspect( typer.echo("") typer.echo("Tables:") tables = manifest.get("tables", {}) - for name, info in tables.items(): - typer.echo(f" {name:25s} {info.get('row_count', '?'):>8} rows") + if isinstance(tables, dict): + for name, info in tables.items(): + row_count = _safe_get(info, "row_count", "?") + typer.echo(f" {name:25s} {row_count:>8} rows") tasks = manifest.get("tasks", {}) - if tasks: + if isinstance(tasks, dict) and tasks: typer.echo("") typer.echo("Tasks:") for task_id, info in tasks.items(): - train = info.get("train_rows", "?") - valid = info.get("valid_rows", "?") - test = info.get("test_rows", "?") + train = _safe_get(info, "train_rows", "?") + valid = _safe_get(info, "valid_rows", "?") + test = _safe_get(info, "test_rows", "?") typer.echo(f" {task_id}") typer.echo(f" train={train} valid={valid} test={test}") has_metadata = (root / "metadata").is_dir() typer.echo("") typer.echo(f"Metadata dir: {'present' if has_metadata else 'absent'}") + + +def _safe_get(obj: Any, key: str, default: str = "?") -> Any: + """Get a key from *obj* if it's a dict, else return *default*.""" + if isinstance(obj, dict): + return obj.get(key, default) + return default diff --git a/leadforge/cli/commands/validate.py b/leadforge/cli/commands/validate.py index e20a506..6912807 100644 --- a/leadforge/cli/commands/validate.py +++ b/leadforge/cli/commands/validate.py @@ -7,131 +7,32 @@ import typer from leadforge.core.exceptions import LeadforgeError -from leadforge.core.hashing import file_sha256 -from leadforge.core.serialization import load_json def validate( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), ) -> None: """Run schema and artifact validation on a generated bundle.""" + from leadforge.validation.bundle_checks import validate_bundle + root = Path(bundle_path) - errors: list[str] = [] - # ------------------------------------------------------------------ - # 1. Manifest presence and parse - # ------------------------------------------------------------------ - manifest_path = root / "manifest.json" - if not manifest_path.exists(): + if not root.exists(): + typer.echo(f"FAIL: path does not exist: {root}", err=True) + raise typer.Exit(1) + if not root.is_dir(): + typer.echo(f"FAIL: not a directory: {root}", err=True) + raise typer.Exit(1) + if not (root / "manifest.json").exists(): typer.echo(f"FAIL: no manifest.json in {root}", err=True) raise typer.Exit(1) try: - manifest = load_json(manifest_path) + errors = validate_bundle(root) except LeadforgeError as exc: typer.echo(f"FAIL: {exc}", err=True) raise typer.Exit(1) from None - # ------------------------------------------------------------------ - # 2. Required top-level files - # ------------------------------------------------------------------ - for fname in ("dataset_card.md", "feature_dictionary.csv"): - if not (root / fname).exists(): - errors.append(f"Missing required file: {fname}") - - # ------------------------------------------------------------------ - # 3. Table files exist + row counts + SHA-256 hashes - # ------------------------------------------------------------------ - import pandas as pd - - tables: dict[str, pd.DataFrame] = {} - for table_name, info in manifest.get("tables", {}).items(): - rel_path = info.get("file", f"tables/{table_name}.parquet") - abs_path = root / rel_path - if not abs_path.exists(): - errors.append(f"Missing table file: {rel_path}") - continue - - df = pd.read_parquet(abs_path) - tables[table_name] = df - - expected_rows = info.get("row_count") - if expected_rows is not None and len(df) != expected_rows: - errors.append(f"Table {table_name}: expected {expected_rows} rows, got {len(df)}") - - expected_sha = info.get("sha256") - if expected_sha is not None: - actual_sha = file_sha256(abs_path) - if actual_sha != expected_sha: - errors.append(f"Table {table_name}: SHA-256 mismatch") - - # ------------------------------------------------------------------ - # 4. Task split files exist + row counts + hashes - # ------------------------------------------------------------------ - for task_id, task_info in manifest.get("tasks", {}).items(): - for split in ("train", "valid", "test"): - rel_path = f"tasks/{task_id}/{split}.parquet" - abs_path = root / rel_path - if not abs_path.exists(): - errors.append(f"Missing task file: {rel_path}") - continue - - df = pd.read_parquet(abs_path) - expected_rows = task_info.get(f"{split}_rows") - if expected_rows is not None and len(df) != expected_rows: - errors.append( - f"Task {task_id}/{split}: expected {expected_rows} rows, got {len(df)}" - ) - - expected_sha = task_info.get(f"{split}_sha256") - if expected_sha is not None: - actual_sha = file_sha256(abs_path) - if actual_sha != expected_sha: - errors.append(f"Task {task_id}/{split}: SHA-256 mismatch") - - # ------------------------------------------------------------------ - # 5. FK integrity - # ------------------------------------------------------------------ - from leadforge.schema.relationships import ALL_CONSTRAINTS - - for fk in ALL_CONSTRAINTS: - child_df = tables.get(fk.child_table) - parent_df = tables.get(fk.parent_table) - if child_df is None or parent_df is None: - continue - if fk.child_column not in child_df.columns: - continue - if fk.parent_column not in parent_df.columns: - continue - - child_vals = set(child_df[fk.child_column].dropna()) - parent_vals = set(parent_df[fk.parent_column].dropna()) - orphans = child_vals - parent_vals - if orphans: - sample = list(orphans)[:3] - errors.append( - f"FK violation: {fk.child_table}.{fk.child_column} → " - f"{fk.parent_table}.{fk.parent_column}: " - f"{len(orphans)} orphan(s), e.g. {sample}" - ) - - # ------------------------------------------------------------------ - # 6. Leakage check — no post-anchor features in task splits - # ------------------------------------------------------------------ - from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES - - expected_columns = {f.name for f in LEAD_SNAPSHOT_FEATURES} - for task_id in manifest.get("tasks", {}): - train_path = root / f"tasks/{task_id}/train.parquet" - if train_path.exists(): - actual_columns = set(pd.read_parquet(train_path).columns) - extra = actual_columns - expected_columns - if extra: - errors.append(f"Task {task_id}: unexpected columns (possible leakage): {extra}") - - # ------------------------------------------------------------------ - # Report - # ------------------------------------------------------------------ if errors: typer.echo(f"FAIL: {len(errors)} validation error(s):", err=True) for e in errors: diff --git a/leadforge/validation/bundle_checks.py b/leadforge/validation/bundle_checks.py new file mode 100644 index 0000000..aa458ca --- /dev/null +++ b/leadforge/validation/bundle_checks.py @@ -0,0 +1,157 @@ +"""Bundle validation logic. + +:func:`validate_bundle` performs all structural, integrity, FK, and leakage +checks on a written bundle directory. It returns a list of human-readable +error strings (empty = pass). The CLI ``validate`` command is a thin shell +around this function. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pandas as pd + +from leadforge.core.hashing import file_sha256 +from leadforge.core.serialization import load_json +from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES +from leadforge.schema.relationships import ALL_CONSTRAINTS + + +def validate_bundle(bundle_root: Path) -> list[str]: + """Run all validation checks on the bundle at *bundle_root*. + + Returns: + A list of error strings. An empty list means the bundle is valid. + + Raises: + FileNotFoundError: if ``manifest.json`` does not exist. + ``LeadforgeError``: if ``manifest.json`` is corrupt / unparseable. + """ + manifest = load_json(bundle_root / "manifest.json") + errors: list[str] = [] + errors.extend(_check_required_files(bundle_root)) + tables, table_errors = _check_tables(bundle_root, manifest) + errors.extend(table_errors) + errors.extend(_check_task_splits(bundle_root, manifest)) + errors.extend(_check_fk_integrity(tables)) + errors.extend(_check_leakage(bundle_root, manifest)) + return errors + + +# ------------------------------------------------------------------ +# Internal check functions +# ------------------------------------------------------------------ + + +def _check_required_files(root: Path) -> list[str]: + errors: list[str] = [] + for fname in ("dataset_card.md", "feature_dictionary.csv"): + if not (root / fname).exists(): + errors.append(f"Missing required file: {fname}") + return errors + + +def _check_tables( + root: Path, manifest: dict[str, Any] +) -> tuple[dict[str, pd.DataFrame], list[str]]: + """Validate table files. Returns loaded DataFrames and errors.""" + errors: list[str] = [] + tables: dict[str, pd.DataFrame] = {} + for table_name, info in manifest.get("tables", {}).items(): + rel_path = info.get("file", f"tables/{table_name}.parquet") + abs_path = root / rel_path + if not abs_path.exists(): + errors.append(f"Missing table file: {rel_path}") + continue + + df = pd.read_parquet(abs_path) + tables[table_name] = df + + expected_rows = info.get("row_count") + if expected_rows is not None and len(df) != expected_rows: + errors.append(f"Table {table_name}: expected {expected_rows} rows, got {len(df)}") + + expected_sha = info.get("sha256") + if expected_sha is not None: + actual_sha = file_sha256(abs_path) + if actual_sha != expected_sha: + errors.append(f"Table {table_name}: SHA-256 mismatch") + + return tables, errors + + +def _check_task_splits(root: Path, manifest: dict[str, Any]) -> list[str]: + errors: list[str] = [] + for task_id, task_info in manifest.get("tasks", {}).items(): + for split in ("train", "valid", "test"): + rel_path = f"tasks/{task_id}/{split}.parquet" + abs_path = root / rel_path + if not abs_path.exists(): + errors.append(f"Missing task file: {rel_path}") + continue + + df = pd.read_parquet(abs_path) + expected_rows = task_info.get(f"{split}_rows") + if expected_rows is not None and len(df) != expected_rows: + errors.append( + f"Task {task_id}/{split}: expected {expected_rows} rows, got {len(df)}" + ) + + expected_sha = task_info.get(f"{split}_sha256") + if expected_sha is not None: + actual_sha = file_sha256(abs_path) + if actual_sha != expected_sha: + errors.append(f"Task {task_id}/{split}: SHA-256 mismatch") + + return errors + + +def _check_fk_integrity(tables: dict[str, pd.DataFrame]) -> list[str]: + errors: list[str] = [] + for fk in ALL_CONSTRAINTS: + child_df = tables.get(fk.child_table) + parent_df = tables.get(fk.parent_table) + if child_df is None or parent_df is None: + missing = fk.child_table if child_df is None else fk.parent_table + errors.append( + f"FK check skipped: {fk.child_table}.{fk.child_column} → " + f"{fk.parent_table}.{fk.parent_column} " + f"(table '{missing}' not loaded)" + ) + continue + if fk.child_column not in child_df.columns: + continue + if fk.parent_column not in parent_df.columns: + continue + + child_vals = set(child_df[fk.child_column].dropna()) + parent_vals = set(parent_df[fk.parent_column].dropna()) + orphans = child_vals - parent_vals + if orphans: + sample = list(orphans)[:3] + errors.append( + f"FK violation: {fk.child_table}.{fk.child_column} → " + f"{fk.parent_table}.{fk.parent_column}: " + f"{len(orphans)} orphan(s), e.g. {sample}" + ) + + return errors + + +def _check_leakage(root: Path, manifest: dict[str, Any]) -> list[str]: + """Check all task splits for unexpected columns.""" + errors: list[str] = [] + expected_columns = {f.name for f in LEAD_SNAPSHOT_FEATURES} + for task_id in manifest.get("tasks", {}): + for split in ("train", "valid", "test"): + split_path = root / f"tasks/{task_id}/{split}.parquet" + if split_path.exists(): + actual_columns = set(pd.read_parquet(split_path, columns=[]).columns) + extra = actual_columns - expected_columns + if extra: + errors.append( + f"Task {task_id}/{split}: unexpected columns (possible leakage): {extra}" + ) + return errors diff --git a/tests/test_cli.py b/tests/test_cli.py index c64859f..b8438cc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -62,7 +62,13 @@ def test_list_recipes_shows_v1_recipe() -> None: @pytest.fixture(scope="module") def bundle_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: - """Generate a small bundle once and reuse across tests in this module.""" + """Generate a small bundle once and reuse across tests in this module. + + WARNING: this fixture is module-scoped for performance (avoids re-running + the full generate pipeline per test). Tests MUST NOT mutate this directory. + Tests that need to tamper with bundle contents should ``shutil.copytree`` + into their own ``tmp_path`` first (see ``TestValidateCommand``). + """ out = tmp_path_factory.mktemp("bundle") result = runner.invoke(app, [*_GENERATE_ARGS, "--out", str(out)]) assert result.exit_code == 0, f"generate failed:\n{result.output}" @@ -112,6 +118,7 @@ def test_invalid_recipe_fails(self, tmp_path: Path) -> None: ], ) assert result.exit_code != 0 + assert "Error" in result.output def test_invalid_mode_fails(self, tmp_path: Path) -> None: result = runner.invoke( @@ -129,6 +136,7 @@ def test_invalid_mode_fails(self, tmp_path: Path) -> None: ], ) assert result.exit_code != 0 + assert "Error" in result.output def test_research_instructor_mode_has_metadata(self, tmp_path: Path) -> None: result = runner.invoke( @@ -161,6 +169,78 @@ def test_output_message(self, tmp_path: Path) -> None: assert "Generating bundle" in result.output assert "Done" in result.output + def test_override_flag(self, tmp_path: Path) -> None: + """--override with a valid YAML file should work.""" + override_file = tmp_path / "override.yaml" + override_file.write_text("n_leads: 25\n") + out = tmp_path / "override_out" + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "student_public", + "--override", + str(override_file), + "--out", + str(out), + ], + ) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + assert (out / "manifest.json").exists() + + def test_override_missing_file_fails(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "student_public", + "--override", + str(tmp_path / "nope.yaml"), + "--out", + str(tmp_path / "out"), + ], + ) + assert result.exit_code != 0 + assert "not found" in result.output + + def test_difficulty_flag(self, tmp_path: Path) -> None: + out = tmp_path / "diff_out" + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "student_public", + "--difficulty", + "intro", + "--n-leads", + "20", + "--n-accounts", + "10", + "--n-contacts", + "30", + "--out", + str(out), + ], + ) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["difficulty"] == "intro" + # --------------------------------------------------------------------------- # inspect command @@ -168,31 +248,28 @@ def test_output_message(self, tmp_path: Path) -> None: class TestInspectCommand: - def test_exits_zero(self, bundle_dir: Path) -> None: + def test_inspect_output(self, bundle_dir: Path) -> None: + """Single invocation, multiple assertions.""" result = runner.invoke(app, ["inspect", str(bundle_dir)]) assert result.exit_code == 0 - - def test_shows_recipe(self, bundle_dir: Path) -> None: - result = runner.invoke(app, ["inspect", str(bundle_dir)]) - assert "b2b_saas_procurement_v1" in result.output - - def test_shows_seed(self, bundle_dir: Path) -> None: - result = runner.invoke(app, ["inspect", str(bundle_dir)]) - assert "42" in result.output - - def test_shows_tables(self, bundle_dir: Path) -> None: - result = runner.invoke(app, ["inspect", str(bundle_dir)]) - assert "accounts" in result.output - assert "leads" in result.output - - def test_shows_tasks(self, bundle_dir: Path) -> None: - result = runner.invoke(app, ["inspect", str(bundle_dir)]) - assert "converted_within_90_days" in result.output + output = result.output + assert "b2b_saas_procurement_v1" in output + assert "42" in output + assert "accounts" in output + assert "leads" in output + assert "converted_within_90_days" in output + assert "Metadata dir:" in output def test_missing_bundle_fails(self, tmp_path: Path) -> None: result = runner.invoke(app, ["inspect", str(tmp_path / "nonexistent")]) assert result.exit_code != 0 + def test_file_instead_of_dir_fails(self, bundle_dir: Path) -> None: + """Passing a file path instead of a directory should error clearly.""" + result = runner.invoke(app, ["inspect", str(bundle_dir / "manifest.json")]) + assert result.exit_code != 0 + assert "not a directory" in result.output + # --------------------------------------------------------------------------- # validate command @@ -209,6 +286,11 @@ def test_missing_bundle_fails(self, tmp_path: Path) -> None: result = runner.invoke(app, ["validate", str(tmp_path / "nonexistent")]) assert result.exit_code != 0 + def test_file_instead_of_dir_fails(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["validate", str(bundle_dir / "manifest.json")]) + assert result.exit_code != 0 + assert "not a directory" in result.output + def test_corrupt_manifest_fails(self, tmp_path: Path, bundle_dir: Path) -> None: """A bundle with a tampered row count should fail validation.""" import shutil @@ -241,3 +323,5 @@ def test_missing_table_file_fails(self, tmp_path: Path, bundle_dir: Path) -> Non result = runner.invoke(app, ["validate", str(corrupt)]) assert result.exit_code != 0 assert "FAIL" in result.output + # Should also report skipped FK checks for the missing table + assert "FK check skipped" in result.output diff --git a/tests/validation/__init__.py b/tests/validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/validation/test_bundle_checks.py b/tests/validation/test_bundle_checks.py new file mode 100644 index 0000000..27c820f --- /dev/null +++ b/tests/validation/test_bundle_checks.py @@ -0,0 +1,80 @@ +"""Tests for leadforge.validation.bundle_checks.""" + +from __future__ import annotations + +import json +import shutil +from pathlib import Path + +import pytest + +from leadforge.api.generator import Generator +from leadforge.validation.bundle_checks import validate_bundle + +# --------------------------------------------------------------------------- +# Fixture — generate a small bundle once +# --------------------------------------------------------------------------- + +_SMALL = {"n_leads": 20, "n_accounts": 10, "n_contacts": 30} + + +@pytest.fixture(scope="module") +def valid_bundle(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Generate a valid bundle for reuse. Do not mutate.""" + out = tmp_path_factory.mktemp("valid_bundle") + gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=99, exposure_mode="student_public") + gen.generate(**_SMALL).save(str(out)) + return out + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestValidBundle: + def test_passes(self, valid_bundle: Path) -> None: + assert validate_bundle(valid_bundle) == [] + + +class TestCorruptBundle: + def test_row_count_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "bad" + shutil.copytree(valid_bundle, corrupt) + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + manifest["tables"][first_table]["row_count"] = 999999 + (corrupt / "manifest.json").write_text(json.dumps(manifest, indent=2)) + + errors = validate_bundle(corrupt) + assert any("expected 999999 rows" in e for e in errors) + + def test_missing_table_reports_fk_skip(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "missing" + shutil.copytree(valid_bundle, corrupt) + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + (corrupt / f"tables/{first_table}.parquet").unlink() + + errors = validate_bundle(corrupt) + assert any("Missing table file" in e for e in errors) + assert any("FK check skipped" in e for e in errors) + + def test_sha256_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "sha" + shutil.copytree(valid_bundle, corrupt) + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + manifest["tables"][first_table]["sha256"] = "0" * 64 + (corrupt / "manifest.json").write_text(json.dumps(manifest, indent=2)) + + errors = validate_bundle(corrupt) + assert any("SHA-256 mismatch" in e for e in errors) + + def test_missing_required_file(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "nocard" + shutil.copytree(valid_bundle, corrupt) + (corrupt / "dataset_card.md").unlink() + + errors = validate_bundle(corrupt) + assert any("dataset_card.md" in e for e in errors) From 99ff9ecd9c4e95a911ab7b986e05e424f07d6711 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 29 Apr 2026 10:15:03 +0300 Subject: [PATCH 5/5] fix: validate override YAML type, guard malformed manifest tables/tasks - generate: reject --override files that don't contain a YAML mapping - bundle_checks: guard tables/tasks manifest fields with isinstance checks; report "Malformed manifest" instead of crashing on non-dict Co-Authored-By: Claude Opus 4.6 --- leadforge/cli/commands/generate.py | 9 ++++++++- leadforge/validation/bundle_checks.py | 17 ++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/leadforge/cli/commands/generate.py b/leadforge/cli/commands/generate.py index e70c769..cd5eeb1 100644 --- a/leadforge/cli/commands/generate.py +++ b/leadforge/cli/commands/generate.py @@ -44,10 +44,17 @@ def generate( typer.echo(f"Error: override file not found: {override_path}", err=True) raise typer.Exit(1) try: - override_dict = load_yaml(override_path) + loaded = load_yaml(override_path) except LeadforgeError as exc: typer.echo(f"Error: {exc}", err=True) raise typer.Exit(1) from None + if loaded is not None and not isinstance(loaded, dict): + typer.echo( + "Error: override file must contain a YAML mapping at the top level.", + err=True, + ) + raise typer.Exit(1) + override_dict = loaded try: gen = Generator.from_recipe( diff --git a/leadforge/validation/bundle_checks.py b/leadforge/validation/bundle_checks.py index aa458ca..b22d927 100644 --- a/leadforge/validation/bundle_checks.py +++ b/leadforge/validation/bundle_checks.py @@ -59,7 +59,11 @@ def _check_tables( """Validate table files. Returns loaded DataFrames and errors.""" errors: list[str] = [] tables: dict[str, pd.DataFrame] = {} - for table_name, info in manifest.get("tables", {}).items(): + raw_tables = manifest.get("tables", {}) + if not isinstance(raw_tables, dict): + errors.append("Malformed manifest: 'tables' is not a JSON object") + return tables, errors + for table_name, info in raw_tables.items(): rel_path = info.get("file", f"tables/{table_name}.parquet") abs_path = root / rel_path if not abs_path.exists(): @@ -84,7 +88,11 @@ def _check_tables( def _check_task_splits(root: Path, manifest: dict[str, Any]) -> list[str]: errors: list[str] = [] - for task_id, task_info in manifest.get("tasks", {}).items(): + raw_tasks = manifest.get("tasks", {}) + if not isinstance(raw_tasks, dict): + errors.append("Malformed manifest: 'tasks' is not a JSON object") + return errors + for task_id, task_info in raw_tasks.items(): for split in ("train", "valid", "test"): rel_path = f"tasks/{task_id}/{split}.parquet" abs_path = root / rel_path @@ -143,8 +151,11 @@ def _check_fk_integrity(tables: dict[str, pd.DataFrame]) -> list[str]: def _check_leakage(root: Path, manifest: dict[str, Any]) -> list[str]: """Check all task splits for unexpected columns.""" errors: list[str] = [] + raw_tasks = manifest.get("tasks", {}) + if not isinstance(raw_tasks, dict): + return errors expected_columns = {f.name for f in LEAD_SNAPSHOT_FEATURES} - for task_id in manifest.get("tasks", {}): + for task_id in raw_tasks: for split in ("train", "valid", "test"): split_path = root / f"tasks/{task_id}/{split}.parquet" if split_path.exists():