Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .agent-plan.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ Documentation + CI:
- [x] `tests/scripts/test_build_v6_snapshot.py` — `seed=` kwarg throughout
- [x] All 740 tests pass; lint + format clean

### Sampler RNG refactor (PR #35, closes #9)

- [x] `leadforge/structure/sampler.py` — `sample_hidden_graph()` accepts `RNGRoot` instead of bare `seed: int`; strict type check (no backward-compat shim — internal function only)
- [x] `leadforge/api/generator.py` — creates `RNGRoot(config.seed)` once, passes to `sample_hidden_graph()`
- [x] `scripts/spike_category_signal.py` — same
- [x] All test files updated: `test_sampler.py`, `test_population.py`, `test_engine.py`, `test_render.py`, `test_snapshot_windowed.py`
- [x] Input validation tests: `test_int_seed_raises_type_error`, `test_none_raises_type_error`
- [x] All 740 tests pass; lint + format clean

---

## Deferred Items
Expand Down
3 changes: 2 additions & 1 deletion leadforge/api/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ def generate(
"Generator.from_recipe() to resolve the narrative."
)

world_graph = sample_hidden_graph(config.seed)
rng_root = RNGRoot(config.seed)
world_graph = sample_hidden_graph(rng_root)

# Load category-latent correlations from difficulty profile if available.
from leadforge.api.recipes import Recipe
Expand Down
26 changes: 15 additions & 11 deletions leadforge/structure/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@


def sample_hidden_graph(
seed: int,
rng_root: RNGRoot,
motif_family_name: str | None = None,
) -> WorldGraph:
"""Draw a validated hidden world graph.

The function is fully deterministic given ``(seed, motif_family_name)``.
The function is fully deterministic given ``(rng_root, motif_family_name)``.

Args:
seed: Integer seed passed to :class:`~leadforge.core.rng.RNGRoot`.
All stochastic choices (motif selection if *motif_family_name*
is ``None``, rewiring decisions, weight jitter) derive from a
named child stream of this root so the sampler integrates with
the repo's RNG convention.
rng_root: An :class:`~leadforge.core.rng.RNGRoot` instance. All
stochastic choices (motif selection if *motif_family_name* is
``None``, rewiring decisions, weight jitter) derive from a named
child stream of this root so the sampler integrates with the
repo's RNG convention.
motif_family_name: If provided, pin the motif family by name
(must be one of :data:`~leadforge.structure.motifs.MOTIF_FAMILY_NAMES`).
If ``None``, a family is chosen uniformly at random from the
Expand All @@ -46,15 +46,19 @@ def sample_hidden_graph(
A validated :class:`~leadforge.structure.graph.WorldGraph`.

Raises:
ValueError: If *seed* is a ``bool`` or a negative integer.
TypeError: If *rng_root* is not an :class:`RNGRoot` instance.
KeyError: If *motif_family_name* is not a known motif family name.
RuntimeError: If :data:`_MAX_ATTEMPTS` rewiring attempts all
produce graphs that fail structural validation (should not
happen in practice with well-formed motifs).
"""
if isinstance(seed, bool) or not isinstance(seed, int) or seed < 0:
raise ValueError(f"seed must be a non-negative int, got {seed!r}")
np_seed = RNGRoot(seed).child("hidden_graph").getrandbits(64)
if not isinstance(rng_root, RNGRoot):
raise TypeError(
f"sample_hidden_graph() requires an RNGRoot instance as the first "
f"argument, got {type(rng_root).__name__!r}"
)

np_seed = rng_root.child("hidden_graph").getrandbits(64)
rng = np.random.default_rng(np_seed)

motif = _select_motif(motif_family_name, rng)
Expand Down
4 changes: 3 additions & 1 deletion scripts/spike_category_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from leadforge.api.generator import Generator
from leadforge.core.rng import RNGRoot
from leadforge.render.snapshots import build_snapshot
from leadforge.simulation.engine import simulate_world
from leadforge.simulation.population import PopulationResult, build_population
Expand Down Expand Up @@ -162,7 +163,8 @@ def run_pipeline(label: str, gen: Generator, scale: float | None = None) -> None
if narrative is None:
raise RuntimeError("No narrative loaded")

world_graph = sample_hidden_graph(config.seed)
rng_root = RNGRoot(config.seed)
world_graph = sample_hidden_graph(rng_root)
print(f" Motif family: {world_graph.motif_family}")

pop = build_population(config, narrative, world_graph)
Expand Down
7 changes: 4 additions & 3 deletions tests/render/test_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from leadforge.core.models import GenerationConfig
from leadforge.core.rng import RNGRoot
from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES
from leadforge.simulation.engine import simulate_world
from leadforge.simulation.population import build_population
Expand Down Expand Up @@ -38,7 +39,7 @@ def sim_outputs():
"""Run a small simulation once; share across all tests in this module."""
config = _make_config()
narrative = _make_narrative(config.seed)
graph = sample_hidden_graph(42)
graph = sample_hidden_graph(RNGRoot(42))
population = build_population(config, narrative, graph)
result = simulate_world(config, population, graph)
return config, population, result, graph
Expand Down Expand Up @@ -125,7 +126,7 @@ def test_deterministic_under_same_seed(self):
def _run(seed):
cfg = _make_config(seed=seed)
narr = _make_narrative(seed)
g = sample_hidden_graph(seed)
g = sample_hidden_graph(RNGRoot(seed))
pop = build_population(cfg, narr, g)
res = simulate_world(cfg, pop, g)
return to_dataframes(res, pop)
Expand Down Expand Up @@ -227,7 +228,7 @@ def test_deterministic_under_same_seed(self):
def _snap(seed):
cfg = _make_config(seed=seed)
narr = _make_narrative(seed)
g = sample_hidden_graph(seed)
g = sample_hidden_graph(RNGRoot(seed))
pop = build_population(cfg, narr, g)
res = simulate_world(cfg, pop, g)
return build_snapshot(res, pop, horizon_days=cfg.horizon_days)
Expand Down
5 changes: 3 additions & 2 deletions tests/render/test_snapshot_windowed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

from leadforge.core.models import GenerationConfig
from leadforge.core.rng import RNGRoot
from leadforge.render.snapshots import build_snapshot
from leadforge.simulation.engine import simulate_world
from leadforge.simulation.population import build_population
Expand All @@ -29,7 +30,7 @@ def sim_data():
"""Run a small simulation once; share across all tests in this module."""
config = GenerationConfig(seed=42, n_accounts=30, n_contacts=90, n_leads=80)
narrative = _make_narrative(config.seed)
graph = sample_hidden_graph(42)
graph = sample_hidden_graph(RNGRoot(42))
population = build_population(config, narrative, graph)
result = simulate_world(config, population, graph)
return config, population, result
Expand Down Expand Up @@ -180,7 +181,7 @@ def test_same_seed_same_output(self):
def _snap(seed):
cfg = GenerationConfig(seed=seed, n_accounts=15, n_contacts=45, n_leads=40)
narr = _make_narrative(seed)
g = sample_hidden_graph(seed)
g = sample_hidden_graph(RNGRoot(seed))
pop = build_population(cfg, narr, g)
res = simulate_world(cfg, pop, g)
return build_snapshot(res, pop, snapshot_day=14)
Expand Down
9 changes: 5 additions & 4 deletions tests/simulation/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from leadforge.core.models import GenerationConfig
from leadforge.core.rng import RNGRoot
from leadforge.schema.entities import (
CustomerRow,
LeadRow,
Expand Down Expand Up @@ -41,7 +42,7 @@ def _make_narrative():
def _run_sim(seed: int = 42, n_leads: int = 50, motif: str | None = None) -> SimulationResult:
config = _make_config(seed=seed, n_leads=n_leads)
narrative = _make_narrative()
graph = sample_hidden_graph(seed, motif_family_name=motif)
graph = sample_hidden_graph(RNGRoot(seed), motif_family_name=motif)
pop = build_population(config, narrative, graph)
return simulate_world(config, pop, graph)

Expand Down Expand Up @@ -272,7 +273,7 @@ def test_subscription_per_customer(self) -> None:
def test_customer_account_fk(self) -> None:
config = _make_config(n_leads=50)
narrative = _make_narrative()
graph = sample_hidden_graph(42)
graph = sample_hidden_graph(RNGRoot(42))
pop = build_population(config, narrative, graph)
result = simulate_world(config, pop, graph)
acct_ids = {a.account_id for a in pop.accounts}
Expand All @@ -298,7 +299,7 @@ class TestEventIntegrity:
def test_touch_lead_fk(self) -> None:
config = _make_config(n_leads=50)
narrative = _make_narrative()
graph = sample_hidden_graph(42)
graph = sample_hidden_graph(RNGRoot(42))
pop = build_population(config, narrative, graph)
result = simulate_world(config, pop, graph)
lead_ids = {row.lead_id for row in result.leads}
Expand All @@ -308,7 +309,7 @@ def test_touch_lead_fk(self) -> None:
def test_session_lead_fk(self) -> None:
config = _make_config(n_leads=50)
narrative = _make_narrative()
graph = sample_hidden_graph(42)
graph = sample_hidden_graph(RNGRoot(42))
pop = build_population(config, narrative, graph)
result = simulate_world(config, pop, graph)
lead_ids = {row.lead_id for row in result.leads}
Expand Down
21 changes: 11 additions & 10 deletions tests/simulation/test_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from leadforge.core.exceptions import InvalidConfigError
from leadforge.core.ids import ID_PREFIXES, make_id
from leadforge.core.models import GenerationConfig
from leadforge.core.rng import RNGRoot
from leadforge.narrative.spec import NarrativeSpec
from leadforge.simulation.population import (
_N_REPS,
Expand Down Expand Up @@ -37,7 +38,7 @@ def _make_result(seed: int = _SEED, motif: str | None = None) -> PopulationResul
gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=seed)
narrative = gen.world_spec.narrative
assert narrative is not None
graph = sample_hidden_graph(seed=seed, motif_family_name=motif)
graph = sample_hidden_graph(RNGRoot(seed), motif_family_name=motif)
return build_population(config, narrative, graph)


Expand Down Expand Up @@ -281,14 +282,14 @@ def test_fit_dominant_raises_account_fit_mean() -> None:
narrative = gen.world_spec.narrative
assert narrative is not None

g_fit = sample_hidden_graph(seed=seed, motif_family_name="fit_dominant")
g_fit = sample_hidden_graph(RNGRoot(seed), motif_family_name="fit_dominant")
r_fit = build_population(config, narrative, g_fit)
fit_means.append(
sum(t["latent_account_fit"] for t in r_fit.latent_state.account_latents.values())
/ config.n_accounts
)

g_fric = sample_hidden_graph(seed=seed, motif_family_name="buying_committee_friction")
g_fric = sample_hidden_graph(RNGRoot(seed), motif_family_name="buying_committee_friction")
r_fric = build_population(config, narrative, g_fric)
friction_means.append(
sum(t["latent_account_fit"] for t in r_fric.latent_state.account_latents.values())
Expand All @@ -312,14 +313,14 @@ def test_buying_committee_friction_lowers_contact_authority() -> None:
narrative = gen.world_spec.narrative
assert narrative is not None

g_bc = sample_hidden_graph(seed=seed, motif_family_name="buying_committee_friction")
g_bc = sample_hidden_graph(RNGRoot(seed), motif_family_name="buying_committee_friction")
r_bc = build_population(config, narrative, g_bc)
bc_means.append(
sum(t["latent_contact_authority"] for t in r_bc.latent_state.contact_latents.values())
/ config.n_contacts
)

g_fd = sample_hidden_graph(seed=seed, motif_family_name="fit_dominant")
g_fd = sample_hidden_graph(RNGRoot(seed), motif_family_name="fit_dominant")
r_fd = build_population(config, narrative, g_fd)
fd_means.append(
sum(t["latent_contact_authority"] for t in r_fd.latent_state.contact_latents.values())
Expand Down Expand Up @@ -371,7 +372,7 @@ def _base_narrative() -> NarrativeSpec:

def _build_with_narrative(narrative: NarrativeSpec) -> PopulationResult:
config = GenerationConfig(seed=0, n_accounts=10, n_contacts=20, n_leads=30)
graph = sample_hidden_graph(seed=0)
graph = sample_hidden_graph(RNGRoot(0))
return build_population(config, narrative, graph)


Expand Down Expand Up @@ -425,7 +426,7 @@ def test_category_latent_correlations_shift_latents() -> None:
gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42)
narrative = gen.world_spec.narrative
assert narrative is not None
graph = sample_hidden_graph(seed=42)
graph = sample_hidden_graph(RNGRoot(42))

# Build without correlations.
baseline = build_population(config, narrative, graph)
Expand Down Expand Up @@ -463,7 +464,7 @@ def test_category_latent_correlations_clamped() -> None:
gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42)
narrative = gen.world_spec.narrative
assert narrative is not None
graph = sample_hidden_graph(seed=42)
graph = sample_hidden_graph(RNGRoot(42))

correlations = {
"seniority": {
Expand All @@ -488,7 +489,7 @@ def test_category_latent_correlations_deterministic() -> None:
gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=77)
narrative = gen.world_spec.narrative
assert narrative is not None
graph = sample_hidden_graph(seed=77)
graph = sample_hidden_graph(RNGRoot(77))

correlations = {
"estimated_revenue_band": {
Expand Down Expand Up @@ -516,7 +517,7 @@ def test_lead_source_boost_not_stacked_per_contact() -> None:
gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42)
narrative = gen.world_spec.narrative
assert narrative is not None
graph = sample_hidden_graph(seed=42)
graph = sample_hidden_graph(RNGRoot(42))
pop = build_population(config, narrative, graph)

# Find a contact with multiple leads of the same source.
Expand Down
Loading
Loading