diff --git a/.agent-plan.md b/.agent-plan.md index ffbc469..88e8379 100644 --- a/.agent-plan.md +++ b/.agent-plan.md @@ -6,38 +6,47 @@ ## Current System State -**v0.4.0 in progress — Milestones 7–9 complete (PR open).** Full simulation engine + render/bundle -layer + exposure filtering implemented. 545 tests passing. +**v0.4.0 complete — Milestones 7–10 done.** Full simulation engine + render/bundle layer + exposure filtering + CLI commands implemented. 562 tests passing. --- -## Next Up — Milestone 10: CLI `generate` command + `inspect` / `validate` stubs (v0.4.0) +## Next Up — Milestone 11: Validation harness (v0.5.0) -Goal: Wire `leadforge generate` CLI command end-to-end; implement `inspect` and `validate` output. +Goal: Implement comprehensive bundle validation — invariant checks, realism heuristics, difficulty drift detection. -- [ ] `cli/commands/generate.py` — parse flags, call `Generator.from_recipe().generate()`, call `.save()` -- [ ] `cli/commands/inspect.py` — print manifest summary for a written bundle -- [ ] `cli/commands/validate.py` — basic schema / FK / leakage checks on a written bundle -- [ ] Tests for each command +- [ ] `validation/invariants.py` — DAG acyclicity, FK integrity, determinism, exposure monotonicity +- [ ] `validation/artifact_checks.py` — file presence, hash verification, schema conformance +- [ ] `validation/realism.py` — distributional sanity checks (conversion rates, feature ranges) +- [ ] `validation/difficulty.py` — difficulty profile adherence checks +- [ ] `validation/drift.py` — cross-seed stability / drift detection +- [ ] Wire into `cli/commands/validate.py` with richer output +- [ ] Tests for each validation module --- ## Context Pointers -- Milestone 8 scope: `docs/leadforge_implementation_plan.md` §10 "Milestone 8" -- Render layer: `leadforge/render/` (snapshots, relational, tasks, manifests) -- Bundle writer: `leadforge/api/bundle.py` +- Milestone 11 scope: `docs/leadforge_implementation_plan.md` §10 "Milestone 11" +- Current validate CLI: `leadforge/cli/commands/validate.py` (basic checks implemented in M10) +- FK constraints: `leadforge/schema/relationships.py` +- Feature spec: `leadforge/schema/features.py` --- ## Completed Phases -### Milestone 9 — Exposure Filtering ✓ (v0.4.0 in PR) -- `exposure/filters.py`: `BundleFilter` frozen dataclass; `FILTERS` dict keyed by `ExposureMode`; `get_filter()` -- `exposure/redaction.py`: `write_metadata_dir()` — writes `metadata/` with `graph.json`, `graph.graphml`, `world_spec.json`, `latent_registry.json`, `mechanism_summary.json` -- `exposure/modes.py`: `apply_exposure(bundle, root, mode)` — dispatch; skips `metadata/` for `student_public` +### Milestone 10 — CLI Commands ✓ (v0.4.0) +- `cli/commands/generate.py`: fully wired — parses all flags, calls `Generator.from_recipe().generate()`, writes bundle via `.save()` +- `cli/commands/inspect.py`: reads `manifest.json` and prints summary (recipe, seed, mode, tables with row counts, task splits, metadata presence) +- `cli/commands/validate.py`: checks manifest presence, required files, table row counts, SHA-256 hashes, task split integrity, FK constraints, leakage (unexpected columns) +- 22 CLI tests (smoke, generate integration, inspect output, validate pass/fail/corrupt/missing); total 562 passing + +### Milestone 9 — Exposure Filtering ✓ (v0.4.0) +- `exposure/filters.py`: `BundleFilter` frozen dataclass; `FILTERS` dict keyed by `ExposureMode`; `get_filter()` accepts `str | ExposureMode` +- `exposure/metadata.py`: `write_metadata_dir()` — writes `metadata/` with `graph.json`, `graph.graphml`, `world_spec.json`, `latent_registry.json`, `mechanism_summary.json` +- `exposure/modes.py`: `apply_exposure(bundle, root, mode)` — dispatch; removes stale `metadata/` for `student_public` - Wired into `api/bundle.py` between dataset card and manifest steps -- 24 new tests; total 545 passing +- 22 exposure tests; total 547 passing ### Milestone 8 — Render / Bundle Layer ✓ (v0.4.0 in PR) - `render/relational.py`: `to_dataframes()` — 9-table dict of typed DataFrames from SimulationResult + PopulationResult diff --git a/leadforge/cli/commands/generate.py b/leadforge/cli/commands/generate.py index 18bdddd..cd5eeb1 100644 --- a/leadforge/cli/commands/generate.py +++ b/leadforge/cli/commands/generate.py @@ -1,7 +1,13 @@ """leadforge generate command.""" +from __future__ import annotations + +from pathlib import Path + import typer +from leadforge.core.exceptions import LeadforgeError + def generate( recipe: str = typer.Option(..., "--recipe", "-r", help="Recipe ID to use."), @@ -28,8 +34,53 @@ def generate( ), ) -> None: """Generate a synthetic CRM dataset bundle from a recipe.""" - typer.echo( - "The 'generate' command is not yet implemented. Coming in v0.2.0.", - err=True, - ) - raise typer.Exit(1) + from leadforge.api.generator import Generator + from leadforge.core.serialization import load_yaml + + override_dict: dict | None = None + if override is not None: + override_path = Path(override) + if not override_path.exists(): + typer.echo(f"Error: override file not found: {override_path}", err=True) + raise typer.Exit(1) + try: + loaded = load_yaml(override_path) + except LeadforgeError as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from None + if loaded is not None and not isinstance(loaded, dict): + typer.echo( + "Error: override file must contain a YAML mapping at the top level.", + err=True, + ) + raise typer.Exit(1) + override_dict = loaded + + try: + gen = Generator.from_recipe( + recipe, + seed=seed, + exposure_mode=mode, + difficulty=difficulty, + n_accounts=n_accounts, + n_contacts=n_contacts, + n_leads=n_leads, + horizon_days=horizon_days, + override=override_dict, + ) + except (LeadforgeError, ValueError) as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from None + + typer.echo(f"Generating bundle with recipe '{recipe}', seed={seed}, mode={mode} ...") + + try: + bundle = gen.generate() + except (LeadforgeError, RuntimeError) as exc: + typer.echo(f"Error during generation: {exc}", err=True) + raise typer.Exit(1) from None + + typer.echo(f"Writing bundle to {out} ...") + bundle.save(out) + + typer.echo(f"Done. Bundle written to {out}") diff --git a/leadforge/cli/commands/inspect.py b/leadforge/cli/commands/inspect.py index 36d6bf9..9ad5a5c 100644 --- a/leadforge/cli/commands/inspect.py +++ b/leadforge/cli/commands/inspect.py @@ -1,14 +1,81 @@ """leadforge inspect command.""" +from __future__ import annotations + +from pathlib import Path +from typing import Any + import typer +from leadforge.core.exceptions import LeadforgeError +from leadforge.core.serialization import load_json + def inspect( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), ) -> None: """Inspect a generated dataset bundle and print a summary.""" - typer.echo( - "The 'inspect' command is not yet implemented. Coming in v0.4.0.", - err=True, - ) - raise typer.Exit(1) + root = Path(bundle_path) + + if not root.exists(): + typer.echo(f"Error: path does not exist: {root}", err=True) + raise typer.Exit(1) + if not root.is_dir(): + typer.echo(f"Error: not a directory (expected a bundle dir): {root}", err=True) + raise typer.Exit(1) + + manifest_path = root / "manifest.json" + if not manifest_path.exists(): + typer.echo(f"Error: no manifest.json found in {root}", err=True) + raise typer.Exit(1) + + try: + manifest = load_json(manifest_path) + except LeadforgeError as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from None + + if not isinstance(manifest, dict): + typer.echo("Error: manifest.json is not a JSON object", err=True) + raise typer.Exit(1) + + typer.echo(f"Bundle: {root}") + typer.echo(f" Recipe: {manifest.get('recipe_id', '?')}") + typer.echo(f" Seed: {manifest.get('seed', '?')}") + typer.echo(f" Mode: {manifest.get('exposure_mode', '?')}") + typer.echo(f" Difficulty: {manifest.get('difficulty', '?')}") + typer.echo(f" Horizon days: {manifest.get('horizon_days', '?')}") + typer.echo(f" Generated at: {manifest.get('generation_timestamp', '?')}") + typer.echo(f" Package: leadforge {manifest.get('package_version', '?')}") + typer.echo(f" Schema ver: {manifest.get('bundle_schema_version', '?')}") + typer.echo(f" Motif family: {manifest.get('motif_family', '?')}") + + typer.echo("") + typer.echo("Tables:") + tables = manifest.get("tables", {}) + if isinstance(tables, dict): + for name, info in tables.items(): + row_count = _safe_get(info, "row_count", "?") + typer.echo(f" {name:25s} {row_count:>8} rows") + + tasks = manifest.get("tasks", {}) + if isinstance(tasks, dict) and tasks: + typer.echo("") + typer.echo("Tasks:") + for task_id, info in tasks.items(): + train = _safe_get(info, "train_rows", "?") + valid = _safe_get(info, "valid_rows", "?") + test = _safe_get(info, "test_rows", "?") + typer.echo(f" {task_id}") + typer.echo(f" train={train} valid={valid} test={test}") + + has_metadata = (root / "metadata").is_dir() + typer.echo("") + typer.echo(f"Metadata dir: {'present' if has_metadata else 'absent'}") + + +def _safe_get(obj: Any, key: str, default: str = "?") -> Any: + """Get a key from *obj* if it's a dict, else return *default*.""" + if isinstance(obj, dict): + return obj.get(key, default) + return default diff --git a/leadforge/cli/commands/validate.py b/leadforge/cli/commands/validate.py index 461db4a..6912807 100644 --- a/leadforge/cli/commands/validate.py +++ b/leadforge/cli/commands/validate.py @@ -1,14 +1,42 @@ """leadforge validate command.""" +from __future__ import annotations + +from pathlib import Path + import typer +from leadforge.core.exceptions import LeadforgeError + def validate( bundle_path: str = typer.Argument(..., help="Path to a generated bundle directory."), ) -> None: """Run schema and artifact validation on a generated bundle.""" - typer.echo( - "The 'validate' command is not yet implemented. Coming in v0.5.0.", - err=True, - ) - raise typer.Exit(1) + from leadforge.validation.bundle_checks import validate_bundle + + root = Path(bundle_path) + + if not root.exists(): + typer.echo(f"FAIL: path does not exist: {root}", err=True) + raise typer.Exit(1) + if not root.is_dir(): + typer.echo(f"FAIL: not a directory: {root}", err=True) + raise typer.Exit(1) + if not (root / "manifest.json").exists(): + typer.echo(f"FAIL: no manifest.json in {root}", err=True) + raise typer.Exit(1) + + try: + errors = validate_bundle(root) + except LeadforgeError as exc: + typer.echo(f"FAIL: {exc}", err=True) + raise typer.Exit(1) from None + + if errors: + typer.echo(f"FAIL: {len(errors)} validation error(s):", err=True) + for e in errors: + typer.echo(f" - {e}", err=True) + raise typer.Exit(1) + + typer.echo(f"OK: bundle at {root} passed all checks.") diff --git a/leadforge/core/hashing.py b/leadforge/core/hashing.py index 680a782..ed465af 100644 --- a/leadforge/core/hashing.py +++ b/leadforge/core/hashing.py @@ -1,13 +1,17 @@ -"""Deterministic config hashing for manifest identity. +"""Deterministic config hashing and file digest helpers. A config hash uniquely identifies a (recipe, config, seed, version) tuple and is embedded in every generated manifest so that bundles can be traced back to the exact parameters that produced them. + +:func:`file_sha256` provides a reusable SHA-256 file digest used by the +manifest builder and the bundle validator. """ import hashlib import json from dataclasses import asdict +from pathlib import Path from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -26,6 +30,15 @@ def _canonical(obj: Any) -> Any: return obj +def file_sha256(path: Path) -> str: + """Return the hex-encoded SHA-256 digest of the file at *path*.""" + h = hashlib.sha256() + with path.open("rb") as fh: + for chunk in iter(lambda: fh.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + def hash_config(config: "GenerationConfig") -> str: """Return a stable hex-encoded SHA-256 digest of *config*. diff --git a/leadforge/render/manifests.py b/leadforge/render/manifests.py index fa8dc31..6b12723 100644 --- a/leadforge/render/manifests.py +++ b/leadforge/render/manifests.py @@ -8,12 +8,13 @@ from __future__ import annotations -import hashlib import json from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any +from leadforge.core.hashing import file_sha256 + if TYPE_CHECKING: from leadforge.core.models import GenerationConfig from leadforge.structure.graph import WorldGraph @@ -55,7 +56,7 @@ def build_manifest( for table_name, row_count in table_row_counts.items(): rel_path = f"tables/{table_name}.parquet" abs_path = bundle_root / rel_path - sha = _sha256(abs_path) + sha = file_sha256(abs_path) tables[table_name] = {"row_count": row_count, "file": rel_path, "sha256": sha} # Build task entries. @@ -65,7 +66,7 @@ def build_manifest( for split_name, row_count in split_counts.items(): rel_path = f"tasks/{task_id}/{split_name}.parquet" abs_path = bundle_root / rel_path - sha = _sha256(abs_path) + sha = file_sha256(abs_path) entry[f"{split_name}_rows"] = row_count entry[f"{split_name}_sha256"] = sha tasks[task_id] = entry @@ -93,12 +94,3 @@ def write_manifest(manifest: dict[str, Any], bundle_root: Path) -> Path: path = bundle_root / "manifest.json" path.write_text(json.dumps(manifest, indent=2)) return path - - -def _sha256(path: Path) -> str: - """Return the hex-encoded SHA-256 digest of *path*.""" - h = hashlib.sha256() - with path.open("rb") as fh: - for chunk in iter(lambda: fh.read(65536), b""): - h.update(chunk) - return h.hexdigest() diff --git a/leadforge/validation/bundle_checks.py b/leadforge/validation/bundle_checks.py new file mode 100644 index 0000000..b22d927 --- /dev/null +++ b/leadforge/validation/bundle_checks.py @@ -0,0 +1,168 @@ +"""Bundle validation logic. + +:func:`validate_bundle` performs all structural, integrity, FK, and leakage +checks on a written bundle directory. It returns a list of human-readable +error strings (empty = pass). The CLI ``validate`` command is a thin shell +around this function. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pandas as pd + +from leadforge.core.hashing import file_sha256 +from leadforge.core.serialization import load_json +from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES +from leadforge.schema.relationships import ALL_CONSTRAINTS + + +def validate_bundle(bundle_root: Path) -> list[str]: + """Run all validation checks on the bundle at *bundle_root*. + + Returns: + A list of error strings. An empty list means the bundle is valid. + + Raises: + FileNotFoundError: if ``manifest.json`` does not exist. + ``LeadforgeError``: if ``manifest.json`` is corrupt / unparseable. + """ + manifest = load_json(bundle_root / "manifest.json") + errors: list[str] = [] + errors.extend(_check_required_files(bundle_root)) + tables, table_errors = _check_tables(bundle_root, manifest) + errors.extend(table_errors) + errors.extend(_check_task_splits(bundle_root, manifest)) + errors.extend(_check_fk_integrity(tables)) + errors.extend(_check_leakage(bundle_root, manifest)) + return errors + + +# ------------------------------------------------------------------ +# Internal check functions +# ------------------------------------------------------------------ + + +def _check_required_files(root: Path) -> list[str]: + errors: list[str] = [] + for fname in ("dataset_card.md", "feature_dictionary.csv"): + if not (root / fname).exists(): + errors.append(f"Missing required file: {fname}") + return errors + + +def _check_tables( + root: Path, manifest: dict[str, Any] +) -> tuple[dict[str, pd.DataFrame], list[str]]: + """Validate table files. Returns loaded DataFrames and errors.""" + errors: list[str] = [] + tables: dict[str, pd.DataFrame] = {} + raw_tables = manifest.get("tables", {}) + if not isinstance(raw_tables, dict): + errors.append("Malformed manifest: 'tables' is not a JSON object") + return tables, errors + for table_name, info in raw_tables.items(): + rel_path = info.get("file", f"tables/{table_name}.parquet") + abs_path = root / rel_path + if not abs_path.exists(): + errors.append(f"Missing table file: {rel_path}") + continue + + df = pd.read_parquet(abs_path) + tables[table_name] = df + + expected_rows = info.get("row_count") + if expected_rows is not None and len(df) != expected_rows: + errors.append(f"Table {table_name}: expected {expected_rows} rows, got {len(df)}") + + expected_sha = info.get("sha256") + if expected_sha is not None: + actual_sha = file_sha256(abs_path) + if actual_sha != expected_sha: + errors.append(f"Table {table_name}: SHA-256 mismatch") + + return tables, errors + + +def _check_task_splits(root: Path, manifest: dict[str, Any]) -> list[str]: + errors: list[str] = [] + raw_tasks = manifest.get("tasks", {}) + if not isinstance(raw_tasks, dict): + errors.append("Malformed manifest: 'tasks' is not a JSON object") + return errors + for task_id, task_info in raw_tasks.items(): + for split in ("train", "valid", "test"): + rel_path = f"tasks/{task_id}/{split}.parquet" + abs_path = root / rel_path + if not abs_path.exists(): + errors.append(f"Missing task file: {rel_path}") + continue + + df = pd.read_parquet(abs_path) + expected_rows = task_info.get(f"{split}_rows") + if expected_rows is not None and len(df) != expected_rows: + errors.append( + f"Task {task_id}/{split}: expected {expected_rows} rows, got {len(df)}" + ) + + expected_sha = task_info.get(f"{split}_sha256") + if expected_sha is not None: + actual_sha = file_sha256(abs_path) + if actual_sha != expected_sha: + errors.append(f"Task {task_id}/{split}: SHA-256 mismatch") + + return errors + + +def _check_fk_integrity(tables: dict[str, pd.DataFrame]) -> list[str]: + errors: list[str] = [] + for fk in ALL_CONSTRAINTS: + child_df = tables.get(fk.child_table) + parent_df = tables.get(fk.parent_table) + if child_df is None or parent_df is None: + missing = fk.child_table if child_df is None else fk.parent_table + errors.append( + f"FK check skipped: {fk.child_table}.{fk.child_column} → " + f"{fk.parent_table}.{fk.parent_column} " + f"(table '{missing}' not loaded)" + ) + continue + if fk.child_column not in child_df.columns: + continue + if fk.parent_column not in parent_df.columns: + continue + + child_vals = set(child_df[fk.child_column].dropna()) + parent_vals = set(parent_df[fk.parent_column].dropna()) + orphans = child_vals - parent_vals + if orphans: + sample = list(orphans)[:3] + errors.append( + f"FK violation: {fk.child_table}.{fk.child_column} → " + f"{fk.parent_table}.{fk.parent_column}: " + f"{len(orphans)} orphan(s), e.g. {sample}" + ) + + return errors + + +def _check_leakage(root: Path, manifest: dict[str, Any]) -> list[str]: + """Check all task splits for unexpected columns.""" + errors: list[str] = [] + raw_tasks = manifest.get("tasks", {}) + if not isinstance(raw_tasks, dict): + return errors + expected_columns = {f.name for f in LEAD_SNAPSHOT_FEATURES} + for task_id in raw_tasks: + for split in ("train", "valid", "test"): + split_path = root / f"tasks/{task_id}/{split}.parquet" + if split_path.exists(): + actual_columns = set(pd.read_parquet(split_path, columns=[]).columns) + extra = actual_columns - expected_columns + if extra: + errors.append( + f"Task {task_id}/{split}: unexpected columns (possible leakage): {extra}" + ) + return errors diff --git a/tests/test_cli.py b/tests/test_cli.py index b97c095..b8438cc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,11 @@ -"""CLI smoke tests.""" +"""CLI tests — smoke tests + generate/inspect/validate integration.""" +from __future__ import annotations + +import json +from pathlib import Path + +import pytest from typer.testing import CliRunner from leadforge.cli.main import app @@ -7,6 +13,11 @@ runner = CliRunner() +# --------------------------------------------------------------------------- +# Smoke tests +# --------------------------------------------------------------------------- + + def test_help_exits_clean() -> None: result = runner.invoke(app, ["--help"]) assert result.exit_code == 0 @@ -28,18 +39,289 @@ def test_list_recipes_shows_v1_recipe() -> None: assert "b2b_saas_procurement_v1" in result.output -def test_generate_stub_exits_nonzero() -> None: - result = runner.invoke( - app, ["generate", "--recipe", "x", "--seed", "1", "--mode", "y", "--out", "/tmp"] - ) - assert result.exit_code != 0 +# --------------------------------------------------------------------------- +# Helper — generate a small bundle to a temp dir +# --------------------------------------------------------------------------- + +_GENERATE_ARGS = [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "42", + "--mode", + "student_public", + "--n-leads", + "30", + "--n-accounts", + "15", + "--n-contacts", + "45", +] + + +@pytest.fixture(scope="module") +def bundle_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Generate a small bundle once and reuse across tests in this module. + + WARNING: this fixture is module-scoped for performance (avoids re-running + the full generate pipeline per test). Tests MUST NOT mutate this directory. + Tests that need to tamper with bundle contents should ``shutil.copytree`` + into their own ``tmp_path`` first (see ``TestValidateCommand``). + """ + out = tmp_path_factory.mktemp("bundle") + result = runner.invoke(app, [*_GENERATE_ARGS, "--out", str(out)]) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + return out + + +# --------------------------------------------------------------------------- +# generate command +# --------------------------------------------------------------------------- + + +class TestGenerateCommand: + def test_exits_zero(self, bundle_dir: Path) -> None: + # bundle_dir fixture already asserts exit_code == 0 + assert (bundle_dir / "manifest.json").exists() + + def test_writes_core_files(self, bundle_dir: Path) -> None: + assert (bundle_dir / "dataset_card.md").exists() + assert (bundle_dir / "feature_dictionary.csv").exists() + assert (bundle_dir / "tables").is_dir() + assert (bundle_dir / "tasks").is_dir() + + def test_manifest_has_expected_keys(self, bundle_dir: Path) -> None: + manifest = json.loads((bundle_dir / "manifest.json").read_text()) + assert manifest["recipe_id"] == "b2b_saas_procurement_v1" + assert manifest["seed"] == 42 + assert manifest["exposure_mode"] == "student_public" + assert "tables" in manifest + assert "tasks" in manifest + + def test_no_metadata_in_student_public(self, bundle_dir: Path) -> None: + assert not (bundle_dir / "metadata").exists() + + def test_invalid_recipe_fails(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "nonexistent_recipe", + "--seed", + "1", + "--mode", + "student_public", + "--out", + str(tmp_path / "bad"), + ], + ) + assert result.exit_code != 0 + assert "Error" in result.output + + def test_invalid_mode_fails(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "invalid_mode", + "--out", + str(tmp_path / "bad"), + ], + ) + assert result.exit_code != 0 + assert "Error" in result.output + + def test_research_instructor_mode_has_metadata(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "7", + "--mode", + "research_instructor", + "--n-leads", + "20", + "--n-accounts", + "10", + "--n-contacts", + "30", + "--out", + str(tmp_path), + ], + ) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + assert (tmp_path / "metadata").is_dir() + + def test_output_message(self, tmp_path: Path) -> None: + out = tmp_path / "msg_test" + result = runner.invoke(app, [*_GENERATE_ARGS, "--out", str(out)]) + assert result.exit_code == 0 + assert "Generating bundle" in result.output + assert "Done" in result.output + + def test_override_flag(self, tmp_path: Path) -> None: + """--override with a valid YAML file should work.""" + override_file = tmp_path / "override.yaml" + override_file.write_text("n_leads: 25\n") + out = tmp_path / "override_out" + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "student_public", + "--override", + str(override_file), + "--out", + str(out), + ], + ) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + assert (out / "manifest.json").exists() + + def test_override_missing_file_fails(self, tmp_path: Path) -> None: + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "student_public", + "--override", + str(tmp_path / "nope.yaml"), + "--out", + str(tmp_path / "out"), + ], + ) + assert result.exit_code != 0 + assert "not found" in result.output + + def test_difficulty_flag(self, tmp_path: Path) -> None: + out = tmp_path / "diff_out" + result = runner.invoke( + app, + [ + "generate", + "--recipe", + "b2b_saas_procurement_v1", + "--seed", + "1", + "--mode", + "student_public", + "--difficulty", + "intro", + "--n-leads", + "20", + "--n-accounts", + "10", + "--n-contacts", + "30", + "--out", + str(out), + ], + ) + assert result.exit_code == 0, f"generate failed:\n{result.output}" + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["difficulty"] == "intro" + + +# --------------------------------------------------------------------------- +# inspect command +# --------------------------------------------------------------------------- + + +class TestInspectCommand: + def test_inspect_output(self, bundle_dir: Path) -> None: + """Single invocation, multiple assertions.""" + result = runner.invoke(app, ["inspect", str(bundle_dir)]) + assert result.exit_code == 0 + output = result.output + assert "b2b_saas_procurement_v1" in output + assert "42" in output + assert "accounts" in output + assert "leads" in output + assert "converted_within_90_days" in output + assert "Metadata dir:" in output + + def test_missing_bundle_fails(self, tmp_path: Path) -> None: + result = runner.invoke(app, ["inspect", str(tmp_path / "nonexistent")]) + assert result.exit_code != 0 + + def test_file_instead_of_dir_fails(self, bundle_dir: Path) -> None: + """Passing a file path instead of a directory should error clearly.""" + result = runner.invoke(app, ["inspect", str(bundle_dir / "manifest.json")]) + assert result.exit_code != 0 + assert "not a directory" in result.output + + +# --------------------------------------------------------------------------- +# validate command +# --------------------------------------------------------------------------- + + +class TestValidateCommand: + def test_valid_bundle_passes(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["validate", str(bundle_dir)]) + assert result.exit_code == 0 + assert "OK" in result.output + + def test_missing_bundle_fails(self, tmp_path: Path) -> None: + result = runner.invoke(app, ["validate", str(tmp_path / "nonexistent")]) + assert result.exit_code != 0 + + def test_file_instead_of_dir_fails(self, bundle_dir: Path) -> None: + result = runner.invoke(app, ["validate", str(bundle_dir / "manifest.json")]) + assert result.exit_code != 0 + assert "not a directory" in result.output + + def test_corrupt_manifest_fails(self, tmp_path: Path, bundle_dir: Path) -> None: + """A bundle with a tampered row count should fail validation.""" + import shutil + + corrupt = tmp_path / "corrupt_bundle" + shutil.copytree(bundle_dir, corrupt) + + manifest = json.loads((corrupt / "manifest.json").read_text()) + # Tamper with a table row count + first_table = next(iter(manifest["tables"])) + manifest["tables"][first_table]["row_count"] = 999999 + (corrupt / "manifest.json").write_text(json.dumps(manifest, indent=2)) + + result = runner.invoke(app, ["validate", str(corrupt)]) + assert result.exit_code != 0 + assert "FAIL" in result.output + def test_missing_table_file_fails(self, tmp_path: Path, bundle_dir: Path) -> None: + """Removing a table Parquet file should cause validation failure.""" + import shutil -def test_inspect_stub_exits_nonzero() -> None: - result = runner.invoke(app, ["inspect", "/nonexistent"]) - assert result.exit_code != 0 + corrupt = tmp_path / "missing_table_bundle" + shutil.copytree(bundle_dir, corrupt) + # Remove one table file + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + (corrupt / f"tables/{first_table}.parquet").unlink() -def test_validate_stub_exits_nonzero() -> None: - result = runner.invoke(app, ["validate", "/nonexistent"]) - assert result.exit_code != 0 + result = runner.invoke(app, ["validate", str(corrupt)]) + assert result.exit_code != 0 + assert "FAIL" in result.output + # Should also report skipped FK checks for the missing table + assert "FK check skipped" in result.output diff --git a/tests/validation/__init__.py b/tests/validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/validation/test_bundle_checks.py b/tests/validation/test_bundle_checks.py new file mode 100644 index 0000000..27c820f --- /dev/null +++ b/tests/validation/test_bundle_checks.py @@ -0,0 +1,80 @@ +"""Tests for leadforge.validation.bundle_checks.""" + +from __future__ import annotations + +import json +import shutil +from pathlib import Path + +import pytest + +from leadforge.api.generator import Generator +from leadforge.validation.bundle_checks import validate_bundle + +# --------------------------------------------------------------------------- +# Fixture — generate a small bundle once +# --------------------------------------------------------------------------- + +_SMALL = {"n_leads": 20, "n_accounts": 10, "n_contacts": 30} + + +@pytest.fixture(scope="module") +def valid_bundle(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Generate a valid bundle for reuse. Do not mutate.""" + out = tmp_path_factory.mktemp("valid_bundle") + gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=99, exposure_mode="student_public") + gen.generate(**_SMALL).save(str(out)) + return out + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestValidBundle: + def test_passes(self, valid_bundle: Path) -> None: + assert validate_bundle(valid_bundle) == [] + + +class TestCorruptBundle: + def test_row_count_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "bad" + shutil.copytree(valid_bundle, corrupt) + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + manifest["tables"][first_table]["row_count"] = 999999 + (corrupt / "manifest.json").write_text(json.dumps(manifest, indent=2)) + + errors = validate_bundle(corrupt) + assert any("expected 999999 rows" in e for e in errors) + + def test_missing_table_reports_fk_skip(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "missing" + shutil.copytree(valid_bundle, corrupt) + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + (corrupt / f"tables/{first_table}.parquet").unlink() + + errors = validate_bundle(corrupt) + assert any("Missing table file" in e for e in errors) + assert any("FK check skipped" in e for e in errors) + + def test_sha256_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "sha" + shutil.copytree(valid_bundle, corrupt) + manifest = json.loads((corrupt / "manifest.json").read_text()) + first_table = next(iter(manifest["tables"])) + manifest["tables"][first_table]["sha256"] = "0" * 64 + (corrupt / "manifest.json").write_text(json.dumps(manifest, indent=2)) + + errors = validate_bundle(corrupt) + assert any("SHA-256 mismatch" in e for e in errors) + + def test_missing_required_file(self, tmp_path: Path, valid_bundle: Path) -> None: + corrupt = tmp_path / "nocard" + shutil.copytree(valid_bundle, corrupt) + (corrupt / "dataset_card.md").unlink() + + errors = validate_bundle(corrupt) + assert any("dataset_card.md" in e for e in errors)