Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .agent-plan.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ Documentation + CI:
- [x] `leadforge/narrative/dataset_card.py` — renders task name and label window from `world_spec.config` instead of hard-coded literals
- [x] 10 new tests (3 dataset card + 7 config resolution); total 750 passing

### Parquet metadata row counts (PR #37, closes #17)

- [x] `leadforge/validation/bundle_checks.py` — `_check_task_splits()` uses `pq.read_metadata().num_rows` instead of `pd.read_parquet()`; `_check_leakage()` uses `pq.read_schema().names` instead of `pd.read_parquet(columns=[])`
- [x] 3 new tests: `_check_task_splits()` does not call `pd.read_parquet`, task split row count mismatch detection, leakage column detection via schema
- [x] All 757 tests pass; lint clean

### Pipeline refactors (PR #34, closes #31 + #32)

- [x] `leadforge/core/rng.py` — `numpy_child()` method on `RNGRoot` returning `np.random.RandomState`
Expand Down
15 changes: 9 additions & 6 deletions leadforge/validation/bundle_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Any

import pandas as pd
import pyarrow.parquet as pq

from leadforge.core.hashing import file_sha256
from leadforge.core.serialization import load_json
Expand Down Expand Up @@ -112,12 +113,14 @@ def _check_task_splits(root: Path, manifest: dict[str, Any]) -> list[str]:
errors.append(f"Missing task file: {rel_path}")
continue

df = pd.read_parquet(abs_path)
expected_rows = task_info.get(f"{split}_rows")
if expected_rows is not None and len(df) != expected_rows:
errors.append(
f"Task {task_id}/{split}: expected {expected_rows} rows, got {len(df)}"
)
if expected_rows is not None:
meta = pq.read_metadata(abs_path)
if meta.num_rows != expected_rows:
errors.append(
f"Task {task_id}/{split}: expected"
f" {expected_rows} rows, got {meta.num_rows}"
)

expected_sha = task_info.get(f"{split}_sha256")
if expected_sha is not None:
Expand Down Expand Up @@ -171,7 +174,7 @@ def _check_leakage(root: Path, manifest: dict[str, Any]) -> list[str]:
for split in ("train", "valid", "test"):
split_path = root / f"tasks/{task_id}/{split}.parquet"
if split_path.exists():
actual_columns = set(pd.read_parquet(split_path, columns=[]).columns)
actual_columns = set(pq.read_schema(split_path).names)
extra = actual_columns - expected_columns
if extra:
errors.append(
Expand Down
53 changes: 53 additions & 0 deletions tests/validation/test_bundle_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,33 @@ def test_passes(self, valid_bundle: Path) -> None:
assert validate_bundle(valid_bundle) == []


class TestMetadataRowCounts:
    """Checks that task split validation does not rely on ``pd.read_parquet``."""

    def test_task_splits_does_not_call_read_parquet(
        self, valid_bundle: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Run ``_check_task_splits`` with ``pd.read_parquet`` rigged to fail.

        The ``pd`` name inside ``bundle_checks`` is swapped for a stand-in whose
        ``read_parquet`` raises, so any call to it aborts the test; the check
        must still return no errors for the valid bundle.
        """
        import pandas as _pd

        from leadforge.validation import bundle_checks

        def _boom(*args: object, **kwargs: object) -> None:
            raise AssertionError("pd.read_parquet should not be called")

        # Minimal replacement for the module-level ``pd`` binding: only the
        # trapped reader and the real DataFrame class are exposed.
        class _FakePd:
            read_parquet = staticmethod(_boom)
            DataFrame = _pd.DataFrame

        monkeypatch.setattr(bundle_checks, "pd", _FakePd)

        manifest_text = (valid_bundle / "manifest.json").read_text()
        manifest = json.loads(manifest_text)

        assert bundle_checks._check_task_splits(valid_bundle, manifest) == []


class TestCorruptBundle:
def test_row_count_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None:
corrupt = tmp_path / "bad"
Expand Down Expand Up @@ -71,6 +98,32 @@ def test_sha256_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None:
errors = validate_bundle(corrupt)
assert any("SHA-256 mismatch" in e for e in errors)

def test_task_split_row_count_mismatch(self, tmp_path: Path, valid_bundle: Path) -> None:
    """Set the first task's ``train_rows`` to 999999 and expect a row-count error."""
    bundle_copy = tmp_path / "bad_task"
    shutil.copytree(valid_bundle, bundle_copy)
    manifest_file = bundle_copy / "manifest.json"

    # Rewrite only the declared row count; the parquet files stay untouched.
    data = json.loads(manifest_file.read_text())
    task_id = next(iter(data["tasks"]))
    data["tasks"][task_id]["train_rows"] = 999999
    manifest_file.write_text(json.dumps(data, indent=2))

    messages = validate_bundle(bundle_copy)
    assert any("expected 999999 rows" in msg for msg in messages)

def test_leakage_detects_extra_columns(self, tmp_path: Path, valid_bundle: Path) -> None:
    """Add a ``__sneaky_leak__`` column to a train split and expect it to be reported."""
    import pandas as pd

    bundle_copy = tmp_path / "leak"
    shutil.copytree(valid_bundle, bundle_copy)

    manifest = json.loads((bundle_copy / "manifest.json").read_text())
    task_id = next(iter(manifest["tasks"]))
    split_file = bundle_copy / f"tasks/{task_id}/train.parquet"

    # Rewrite the first task's train split in place with one extra column.
    leaky = pd.read_parquet(split_file).assign(__sneaky_leak__=1)
    leaky.to_parquet(split_file, index=False)

    messages = validate_bundle(bundle_copy, include_realism=False)
    assert any("__sneaky_leak__" in msg for msg in messages)

def test_missing_required_file(self, tmp_path: Path, valid_bundle: Path) -> None:
corrupt = tmp_path / "nocard"
shutil.copytree(valid_bundle, corrupt)
Expand Down
Loading