From 77c11a872feab6cc81cf3ab8eb90db2675524d62 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Fri, 1 May 2026 09:22:02 +0300 Subject: [PATCH 1/2] refactor: wire sample_hidden_graph RNG through RNGRoot Change sample_hidden_graph() to accept an RNGRoot instance instead of a bare int seed, aligning it with the cross-component RNG convention established in PR #34. The old int-seed interface is preserved with a deprecation warning for backward compatibility. Closes #9 Co-Authored-By: Claude Opus 4.6 --- .agent-plan.md | 9 +++++ leadforge/api/generator.py | 2 +- leadforge/structure/sampler.py | 46 ++++++++++++++++++++------ scripts/spike_category_signal.py | 3 +- tests/render/test_render.py | 7 ++-- tests/render/test_snapshot_windowed.py | 5 +-- tests/simulation/test_engine.py | 9 ++--- tests/simulation/test_population.py | 21 ++++++------ tests/structure/test_sampler.py | 39 ++++++++++++---------- 9 files changed, 91 insertions(+), 50 deletions(-) diff --git a/.agent-plan.md b/.agent-plan.md index f8d147d..751966c 100644 --- a/.agent-plan.md +++ b/.agent-plan.md @@ -132,6 +132,15 @@ Documentation + CI: - [x] `tests/scripts/test_build_v6_snapshot.py` — `seed=` kwarg throughout - [x] All 740 tests pass; lint + format clean +### Sampler RNG refactor (PR #35, closes #9) + +- [x] `leadforge/structure/sampler.py` — `sample_hidden_graph()` accepts `RNGRoot` instead of bare `seed: int`; deprecated `seed` parameter with backward-compat warnings +- [x] `leadforge/api/generator.py` — passes `RNGRoot(config.seed)` to `sample_hidden_graph()` +- [x] `scripts/spike_category_signal.py` — same +- [x] All test files updated: `test_sampler.py`, `test_population.py`, `test_engine.py`, `test_render.py`, `test_snapshot_windowed.py` +- [x] 2 new deprecation warning tests in `test_sampler.py` +- [x] All 740 tests pass; lint + format clean + --- ## Deferred Items diff --git a/leadforge/api/generator.py b/leadforge/api/generator.py index e30bb99..8e78e44 100644 --- a/leadforge/api/generator.py +++ b/leadforge/api/generator.py @@ -165,7 +165,7 @@ def generate( "Generator.from_recipe() to resolve the narrative." ) - world_graph = sample_hidden_graph(config.seed) + world_graph = sample_hidden_graph(RNGRoot(config.seed)) # Load category-latent correlations from difficulty profile if available. from leadforge.api.recipes import Recipe diff --git a/leadforge/structure/sampler.py b/leadforge/structure/sampler.py index 3aee3c5..a2c7a92 100644 --- a/leadforge/structure/sampler.py +++ b/leadforge/structure/sampler.py @@ -8,6 +8,8 @@ from __future__ import annotations +import warnings + import numpy as np from leadforge.core.rng import RNGRoot @@ -24,37 +26,59 @@ def sample_hidden_graph( - seed: int, + rng_root: RNGRoot | int | None = None, motif_family_name: str | None = None, + *, + seed: int | None = None, ) -> WorldGraph: """Draw a validated hidden world graph. - The function is fully deterministic given ``(seed, motif_family_name)``. + The function is fully deterministic given ``(rng_root, motif_family_name)``. Args: - seed: Integer seed passed to :class:`~leadforge.core.rng.RNGRoot`. - All stochastic choices (motif selection if *motif_family_name* - is ``None``, rewiring decisions, weight jitter) derive from a - named child stream of this root so the sampler integrates with - the repo's RNG convention. + rng_root: An :class:`~leadforge.core.rng.RNGRoot` instance. All + stochastic choices (motif selection if *motif_family_name* is + ``None``, rewiring decisions, weight jitter) derive from a named + child stream of this root so the sampler integrates with the + repo's RNG convention. motif_family_name: If provided, pin the motif family by name (must be one of :data:`~leadforge.structure.motifs.MOTIF_FAMILY_NAMES`). If ``None``, a family is chosen uniformly at random from the five v1 families. + seed: **Deprecated.** Pass an ``RNGRoot`` as the first argument + instead. When *seed* is given and *rng_root* is not, an + ``RNGRoot(seed)`` is constructed automatically. Returns: A validated :class:`~leadforge.structure.graph.WorldGraph`. Raises: - ValueError: If *seed* is a ``bool`` or a negative integer. KeyError: If *motif_family_name* is not a known motif family name. RuntimeError: If :data:`_MAX_ATTEMPTS` rewiring attempts all produce graphs that fail structural validation (should not happen in practice with well-formed motifs). """ - if isinstance(seed, bool) or not isinstance(seed, int) or seed < 0: - raise ValueError(f"seed must be a non-negative int, got {seed!r}") - np_seed = RNGRoot(seed).child("hidden_graph").getrandbits(64) + # ---- backward-compat: accept bare int seed ---- + if isinstance(rng_root, int): + warnings.warn( + "Passing an int seed as the first argument to sample_hidden_graph() " + "is deprecated. Pass an RNGRoot instance instead.", + DeprecationWarning, + stacklevel=2, + ) + rng_root = RNGRoot(rng_root) + elif seed is not None: + warnings.warn( + "The 'seed' keyword argument to sample_hidden_graph() is deprecated. " + "Pass an RNGRoot instance as the first argument instead.", + DeprecationWarning, + stacklevel=2, + ) + rng_root = RNGRoot(seed) + elif rng_root is None: + raise TypeError("sample_hidden_graph() requires an RNGRoot instance as the first argument") + + np_seed = rng_root.child("hidden_graph").getrandbits(64) rng = np.random.default_rng(np_seed) motif = _select_motif(motif_family_name, rng) diff --git a/scripts/spike_category_signal.py b/scripts/spike_category_signal.py index 27ba972..07c3d43 100644 --- a/scripts/spike_category_signal.py +++ b/scripts/spike_category_signal.py @@ -25,6 +25,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from leadforge.api.generator import Generator +from leadforge.core.rng import RNGRoot from leadforge.render.snapshots import build_snapshot from leadforge.simulation.engine import simulate_world from leadforge.simulation.population import PopulationResult, build_population @@ -162,7 +163,7 @@ def run_pipeline(label: str, gen: Generator, scale: float | None = None) -> None if narrative is None: raise RuntimeError("No narrative loaded") - world_graph = sample_hidden_graph(config.seed) + world_graph = sample_hidden_graph(RNGRoot(config.seed)) print(f" Motif family: {world_graph.motif_family}") pop = build_population(config, narrative, world_graph) diff --git a/tests/render/test_render.py b/tests/render/test_render.py index 18df3cd..fa4da40 100644 --- a/tests/render/test_render.py +++ b/tests/render/test_render.py @@ -8,6 +8,7 @@ import pytest from leadforge.core.models import GenerationConfig +from leadforge.core.rng import RNGRoot from leadforge.schema.features import LEAD_SNAPSHOT_FEATURES from leadforge.simulation.engine import simulate_world from leadforge.simulation.population import build_population @@ -38,7 +39,7 @@ def sim_outputs(): """Run a small simulation once; share across all tests in this module.""" config = _make_config() narrative = _make_narrative(config.seed) - graph = sample_hidden_graph(42) + graph = sample_hidden_graph(RNGRoot(42)) population = build_population(config, narrative, graph) result = simulate_world(config, population, graph) return config, population, result, graph @@ -125,7 +126,7 @@ def test_deterministic_under_same_seed(self): def _run(seed): cfg = _make_config(seed=seed) narr = _make_narrative(seed) - g = sample_hidden_graph(seed) + g = sample_hidden_graph(RNGRoot(seed)) pop = build_population(cfg, narr, g) res = simulate_world(cfg, pop, g) return to_dataframes(res, pop) @@ -227,7 +228,7 @@ def test_deterministic_under_same_seed(self): def _snap(seed): cfg = _make_config(seed=seed) narr = _make_narrative(seed) - g = sample_hidden_graph(seed) + g = sample_hidden_graph(RNGRoot(seed)) pop = build_population(cfg, narr, g) res = simulate_world(cfg, pop, g) return build_snapshot(res, pop, horizon_days=cfg.horizon_days) diff --git a/tests/render/test_snapshot_windowed.py b/tests/render/test_snapshot_windowed.py index f2a15f6..e0b2654 100644 --- a/tests/render/test_snapshot_windowed.py +++ b/tests/render/test_snapshot_windowed.py @@ -10,6 +10,7 @@ import pytest from leadforge.core.models import GenerationConfig +from leadforge.core.rng import RNGRoot from leadforge.render.snapshots import build_snapshot from leadforge.simulation.engine import simulate_world from leadforge.simulation.population import build_population @@ -29,7 +30,7 @@ def sim_data(): """Run a small simulation once; share across all tests in this module.""" config = GenerationConfig(seed=42, n_accounts=30, n_contacts=90, n_leads=80) narrative = _make_narrative(config.seed) - graph = sample_hidden_graph(42) + graph = sample_hidden_graph(RNGRoot(42)) population = build_population(config, narrative, graph) result = simulate_world(config, population, graph) return config, population, result @@ -180,7 +181,7 @@ def test_same_seed_same_output(self): def _snap(seed): cfg = GenerationConfig(seed=seed, n_accounts=15, n_contacts=45, n_leads=40) narr = _make_narrative(seed) - g = sample_hidden_graph(seed) + g = sample_hidden_graph(RNGRoot(seed)) pop = build_population(cfg, narr, g) res = simulate_world(cfg, pop, g) return build_snapshot(res, pop, snapshot_day=14) diff --git a/tests/simulation/test_engine.py b/tests/simulation/test_engine.py index cc26231..4825169 100644 --- a/tests/simulation/test_engine.py +++ b/tests/simulation/test_engine.py @@ -5,6 +5,7 @@ import pytest from leadforge.core.models import GenerationConfig +from leadforge.core.rng import RNGRoot from leadforge.schema.entities import ( CustomerRow, LeadRow, @@ -41,7 +42,7 @@ def _make_narrative(): def _run_sim(seed: int = 42, n_leads: int = 50, motif: str | None = None) -> SimulationResult: config = _make_config(seed=seed, n_leads=n_leads) narrative = _make_narrative() - graph = sample_hidden_graph(seed, motif_family_name=motif) + graph = sample_hidden_graph(RNGRoot(seed), motif_family_name=motif) pop = build_population(config, narrative, graph) return simulate_world(config, pop, graph) @@ -272,7 +273,7 @@ def test_subscription_per_customer(self) -> None: def test_customer_account_fk(self) -> None: config = _make_config(n_leads=50) narrative = _make_narrative() - graph = sample_hidden_graph(42) + graph = sample_hidden_graph(RNGRoot(42)) pop = build_population(config, narrative, graph) result = simulate_world(config, pop, graph) acct_ids = {a.account_id for a in pop.accounts} @@ -298,7 +299,7 @@ class TestEventIntegrity: def test_touch_lead_fk(self) -> None: config = _make_config(n_leads=50) narrative = _make_narrative() - graph = sample_hidden_graph(42) + graph = sample_hidden_graph(RNGRoot(42)) pop = build_population(config, narrative, graph) result = simulate_world(config, pop, graph) lead_ids = {row.lead_id for row in result.leads} @@ -308,7 +309,7 @@ def test_touch_lead_fk(self) -> None: def test_session_lead_fk(self) -> None: config = _make_config(n_leads=50) narrative = _make_narrative() - graph = sample_hidden_graph(42) + graph = sample_hidden_graph(RNGRoot(42)) pop = build_population(config, narrative, graph) result = simulate_world(config, pop, graph) lead_ids = {row.lead_id for row in result.leads} diff --git a/tests/simulation/test_population.py b/tests/simulation/test_population.py index a4d5d0c..9fe017a 100644 --- a/tests/simulation/test_population.py +++ b/tests/simulation/test_population.py @@ -8,6 +8,7 @@ from leadforge.core.exceptions import InvalidConfigError from leadforge.core.ids import ID_PREFIXES, make_id from leadforge.core.models import GenerationConfig +from leadforge.core.rng import RNGRoot from leadforge.narrative.spec import NarrativeSpec from leadforge.simulation.population import ( _N_REPS, @@ -37,7 +38,7 @@ def _make_result(seed: int = _SEED, motif: str | None = None) -> PopulationResul gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=seed) narrative = gen.world_spec.narrative assert narrative is not None - graph = sample_hidden_graph(seed=seed, motif_family_name=motif) + graph = sample_hidden_graph(RNGRoot(seed), motif_family_name=motif) return build_population(config, narrative, graph) @@ -281,14 +282,14 @@ def test_fit_dominant_raises_account_fit_mean() -> None: narrative = gen.world_spec.narrative assert narrative is not None - g_fit = sample_hidden_graph(seed=seed, motif_family_name="fit_dominant") + g_fit = sample_hidden_graph(RNGRoot(seed), motif_family_name="fit_dominant") r_fit = build_population(config, narrative, g_fit) fit_means.append( sum(t["latent_account_fit"] for t in r_fit.latent_state.account_latents.values()) / config.n_accounts ) - g_fric = sample_hidden_graph(seed=seed, motif_family_name="buying_committee_friction") + g_fric = sample_hidden_graph(RNGRoot(seed), motif_family_name="buying_committee_friction") r_fric = build_population(config, narrative, g_fric) friction_means.append( sum(t["latent_account_fit"] for t in r_fric.latent_state.account_latents.values()) @@ -312,14 +313,14 @@ def test_buying_committee_friction_lowers_contact_authority() -> None: narrative = gen.world_spec.narrative assert narrative is not None - g_bc = sample_hidden_graph(seed=seed, motif_family_name="buying_committee_friction") + g_bc = sample_hidden_graph(RNGRoot(seed), motif_family_name="buying_committee_friction") r_bc = build_population(config, narrative, g_bc) bc_means.append( sum(t["latent_contact_authority"] for t in r_bc.latent_state.contact_latents.values()) / config.n_contacts ) - g_fd = sample_hidden_graph(seed=seed, motif_family_name="fit_dominant") + g_fd = sample_hidden_graph(RNGRoot(seed), motif_family_name="fit_dominant") r_fd = build_population(config, narrative, g_fd) fd_means.append( sum(t["latent_contact_authority"] for t in r_fd.latent_state.contact_latents.values()) @@ -371,7 +372,7 @@ def _base_narrative() -> NarrativeSpec: def _build_with_narrative(narrative: NarrativeSpec) -> PopulationResult: config = GenerationConfig(seed=0, n_accounts=10, n_contacts=20, n_leads=30) - graph = sample_hidden_graph(seed=0) + graph = sample_hidden_graph(RNGRoot(0)) return build_population(config, narrative, graph) @@ -425,7 +426,7 @@ def test_category_latent_correlations_shift_latents() -> None: gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42) narrative = gen.world_spec.narrative assert narrative is not None - graph = sample_hidden_graph(seed=42) + graph = sample_hidden_graph(RNGRoot(42)) # Build without correlations. baseline = build_population(config, narrative, graph) @@ -463,7 +464,7 @@ def test_category_latent_correlations_clamped() -> None: gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42) narrative = gen.world_spec.narrative assert narrative is not None - graph = sample_hidden_graph(seed=42) + graph = sample_hidden_graph(RNGRoot(42)) correlations = { "seniority": { @@ -488,7 +489,7 @@ def test_category_latent_correlations_deterministic() -> None: gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=77) narrative = gen.world_spec.narrative assert narrative is not None - graph = sample_hidden_graph(seed=77) + graph = sample_hidden_graph(RNGRoot(77)) correlations = { "estimated_revenue_band": { @@ -516,7 +517,7 @@ def test_lead_source_boost_not_stacked_per_contact() -> None: gen = Generator.from_recipe("b2b_saas_procurement_v1", seed=42) narrative = gen.world_spec.narrative assert narrative is not None - graph = sample_hidden_graph(seed=42) + graph = sample_hidden_graph(RNGRoot(42)) pop = build_population(config, narrative, graph) # Find a contact with multiple leads of the same source. diff --git a/tests/structure/test_sampler.py b/tests/structure/test_sampler.py index ffc32da..70ffdba 100644 --- a/tests/structure/test_sampler.py +++ b/tests/structure/test_sampler.py @@ -2,6 +2,7 @@ import pytest +from leadforge.core.rng import RNGRoot from leadforge.structure.graph import WorldGraph from leadforge.structure.motifs import MOTIF_FAMILY_NAMES from leadforge.structure.node_types import NodeType @@ -13,12 +14,12 @@ def test_returns_world_graph() -> None: - g = sample_hidden_graph(seed=0) + g = sample_hidden_graph(RNGRoot(0)) assert isinstance(g, WorldGraph) def test_sampled_graph_has_outcome_node() -> None: - g = sample_hidden_graph(seed=0) + g = sample_hidden_graph(RNGRoot(0)) outcome_nodes = [n for n in g.graph.nodes if g.node_type(n) == NodeType.OUTCOME] assert len(outcome_nodes) >= 1 @@ -26,7 +27,7 @@ def test_sampled_graph_has_outcome_node() -> None: def test_sampled_graph_is_dag() -> None: import networkx as nx - g = sample_hidden_graph(seed=0) + g = sample_hidden_graph(RNGRoot(0)) assert nx.is_directed_acyclic_graph(g.graph) @@ -36,8 +37,8 @@ def test_sampled_graph_is_dag() -> None: def test_same_seed_same_graph() -> None: - g1 = sample_hidden_graph(seed=42) - g2 = sample_hidden_graph(seed=42) + g1 = sample_hidden_graph(RNGRoot(42)) + g2 = sample_hidden_graph(RNGRoot(42)) assert g1.motif_family == g2.motif_family assert sorted(g1.graph.nodes) == sorted(g2.graph.nodes) assert sorted(g1.graph.edges) == sorted(g2.graph.edges) @@ -48,7 +49,7 @@ def test_same_seed_same_graph() -> None: def test_different_seeds_can_differ() -> None: - graphs = [sample_hidden_graph(seed=s) for s in range(20)] + graphs = [sample_hidden_graph(RNGRoot(s)) for s in range(20)] families = {g.motif_family for g in graphs} # With 5 families and 20 seeds, we expect more than one family. assert len(families) > 1 @@ -61,23 +62,25 @@ def test_different_seeds_can_differ() -> None: @pytest.mark.parametrize("name", MOTIF_FAMILY_NAMES) def test_pinned_motif_family(name: str) -> None: - g = sample_hidden_graph(seed=7, motif_family_name=name) + g = sample_hidden_graph(RNGRoot(7), motif_family_name=name) assert g.motif_family == name def test_unknown_motif_family_raises() -> None: with pytest.raises(KeyError, match="bad_family"): - sample_hidden_graph(seed=0, motif_family_name="bad_family") + sample_hidden_graph(RNGRoot(0), motif_family_name="bad_family") -def test_bool_seed_raises() -> None: - with pytest.raises(ValueError, match="non-negative int"): - sample_hidden_graph(seed=True) # type: ignore[arg-type] +def test_deprecated_int_seed_warns() -> None: + with pytest.warns(DeprecationWarning, match="RNGRoot"): + g = sample_hidden_graph(0) # type: ignore[arg-type] + assert isinstance(g, WorldGraph) -def test_negative_seed_raises() -> None: - with pytest.raises(ValueError, match="non-negative int"): - sample_hidden_graph(seed=-1) +def test_deprecated_seed_kwarg_warns() -> None: + with pytest.warns(DeprecationWarning, match="RNGRoot"): + g = sample_hidden_graph(seed=0) # type: ignore[arg-type] + assert isinstance(g, WorldGraph) # --------------------------------------------------------------------------- @@ -88,7 +91,7 @@ def test_negative_seed_raises() -> None: @pytest.mark.parametrize("seed", range(30)) def test_all_sampled_graphs_are_valid(seed: int) -> None: """Property test: no seed should produce an invalid graph.""" - g = sample_hidden_graph(seed=seed) + g = sample_hidden_graph(RNGRoot(seed)) # If we got here without GraphValidationError, the graph is valid. assert g.graph.number_of_nodes() >= 2 assert g.graph.number_of_edges() >= 1 @@ -97,7 +100,7 @@ def test_all_sampled_graphs_are_valid(seed: int) -> None: @pytest.mark.parametrize("name", MOTIF_FAMILY_NAMES) def test_pinned_family_graphs_are_valid_across_seeds(name: str) -> None: for seed in range(10): - g = sample_hidden_graph(seed=seed, motif_family_name=name) + g = sample_hidden_graph(RNGRoot(seed), motif_family_name=name) assert g.graph.number_of_nodes() >= 2 @@ -109,12 +112,12 @@ def test_pinned_family_graphs_are_valid_across_seeds(name: str) -> None: def test_to_json_is_parseable() -> None: import json - g = sample_hidden_graph(seed=1) + g = sample_hidden_graph(RNGRoot(1)) data = json.loads(g.to_json()) assert "nodes" in data assert "edges" in data def test_to_graphml_contains_graph_tag() -> None: - g = sample_hidden_graph(seed=1) + g = sample_hidden_graph(RNGRoot(1)) assert " Date: Fri, 1 May 2026 09:27:23 +0300 Subject: [PATCH 2/2] refactor: drop backward-compat shim, clean up signature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address self-review feedback: - Remove unnecessary deprecation warnings — sample_hidden_graph is internal, no external consumers need backward compat - Make rng_root a required RNGRoot parameter (no union type) - Add strict TypeError for non-RNGRoot input - Create RNGRoot once in generator.py rather than inline - Replace deprecation tests with proper type-error tests Co-Authored-By: Claude Opus 4.6 --- .agent-plan.md | 6 +++--- leadforge/api/generator.py | 3 ++- leadforge/structure/sampler.py | 32 ++++++-------------------------- scripts/spike_category_signal.py | 3 ++- tests/structure/test_sampler.py | 19 +++++++++++-------- 5 files changed, 24 insertions(+), 39 deletions(-) diff --git a/.agent-plan.md b/.agent-plan.md index 751966c..1c039b8 100644 --- a/.agent-plan.md +++ b/.agent-plan.md @@ -134,11 +134,11 @@ Documentation + CI: ### Sampler RNG refactor (PR #35, closes #9) -- [x] `leadforge/structure/sampler.py` — `sample_hidden_graph()` accepts `RNGRoot` instead of bare `seed: int`; deprecated `seed` parameter with backward-compat warnings -- [x] `leadforge/api/generator.py` — passes `RNGRoot(config.seed)` to `sample_hidden_graph()` +- [x] `leadforge/structure/sampler.py` — `sample_hidden_graph()` accepts `RNGRoot` instead of bare `seed: int`; strict type check (no backward-compat shim — internal function only) +- [x] `leadforge/api/generator.py` — creates `RNGRoot(config.seed)` once, passes to `sample_hidden_graph()` - [x] `scripts/spike_category_signal.py` — same - [x] All test files updated: `test_sampler.py`, `test_population.py`, `test_engine.py`, `test_render.py`, `test_snapshot_windowed.py` -- [x] 2 new deprecation warning tests in `test_sampler.py` +- [x] Input validation tests: `test_int_seed_raises_type_error`, `test_none_raises_type_error` - [x] All 740 tests pass; lint + format clean --- diff --git a/leadforge/api/generator.py b/leadforge/api/generator.py index 8e78e44..9ef87a4 100644 --- a/leadforge/api/generator.py +++ b/leadforge/api/generator.py @@ -165,7 +165,8 @@ def generate( "Generator.from_recipe() to resolve the narrative." ) - world_graph = sample_hidden_graph(RNGRoot(config.seed)) + rng_root = RNGRoot(config.seed) + world_graph = sample_hidden_graph(rng_root) # Load category-latent correlations from difficulty profile if available. from leadforge.api.recipes import Recipe diff --git a/leadforge/structure/sampler.py b/leadforge/structure/sampler.py index a2c7a92..d964b9a 100644 --- a/leadforge/structure/sampler.py +++ b/leadforge/structure/sampler.py @@ -8,8 +8,6 @@ from __future__ import annotations -import warnings - import numpy as np from leadforge.core.rng import RNGRoot @@ -26,10 +24,8 @@ def sample_hidden_graph( - rng_root: RNGRoot | int | None = None, + rng_root: RNGRoot, motif_family_name: str | None = None, - *, - seed: int | None = None, ) -> WorldGraph: """Draw a validated hidden world graph. @@ -45,38 +41,22 @@ def sample_hidden_graph( (must be one of :data:`~leadforge.structure.motifs.MOTIF_FAMILY_NAMES`). If ``None``, a family is chosen uniformly at random from the five v1 families. - seed: **Deprecated.** Pass an ``RNGRoot`` as the first argument - instead. When *seed* is given and *rng_root* is not, an - ``RNGRoot(seed)`` is constructed automatically. Returns: A validated :class:`~leadforge.structure.graph.WorldGraph`. Raises: + TypeError: If *rng_root* is not an :class:`RNGRoot` instance. KeyError: If *motif_family_name* is not a known motif family name. RuntimeError: If :data:`_MAX_ATTEMPTS` rewiring attempts all produce graphs that fail structural validation (should not happen in practice with well-formed motifs). """ - # ---- backward-compat: accept bare int seed ---- - if isinstance(rng_root, int): - warnings.warn( - "Passing an int seed as the first argument to sample_hidden_graph() " - "is deprecated. Pass an RNGRoot instance instead.", - DeprecationWarning, - stacklevel=2, - ) - rng_root = RNGRoot(rng_root) - elif seed is not None: - warnings.warn( - "The 'seed' keyword argument to sample_hidden_graph() is deprecated. " - "Pass an RNGRoot instance as the first argument instead.", - DeprecationWarning, - stacklevel=2, + if not isinstance(rng_root, RNGRoot): + raise TypeError( + f"sample_hidden_graph() requires an RNGRoot instance as the first " + f"argument, got {type(rng_root).__name__!r}" ) - rng_root = RNGRoot(seed) - elif rng_root is None: - raise TypeError("sample_hidden_graph() requires an RNGRoot instance as the first argument") np_seed = rng_root.child("hidden_graph").getrandbits(64) rng = np.random.default_rng(np_seed) diff --git a/scripts/spike_category_signal.py b/scripts/spike_category_signal.py index 07c3d43..58b677d 100644 --- a/scripts/spike_category_signal.py +++ b/scripts/spike_category_signal.py @@ -163,7 +163,8 @@ def run_pipeline(label: str, gen: Generator, scale: float | None = None) -> None if narrative is None: raise RuntimeError("No narrative loaded") - world_graph = sample_hidden_graph(RNGRoot(config.seed)) + rng_root = RNGRoot(config.seed) + world_graph = sample_hidden_graph(rng_root) print(f" Motif family: {world_graph.motif_family}") pop = build_population(config, narrative, world_graph) diff --git a/tests/structure/test_sampler.py b/tests/structure/test_sampler.py index 70ffdba..098c106 100644 --- a/tests/structure/test_sampler.py +++ b/tests/structure/test_sampler.py @@ -71,16 +71,19 @@ def test_unknown_motif_family_raises() -> None: sample_hidden_graph(RNGRoot(0), motif_family_name="bad_family") -def test_deprecated_int_seed_warns() -> None: - with pytest.warns(DeprecationWarning, match="RNGRoot"): - g = sample_hidden_graph(0) # type: ignore[arg-type] - assert isinstance(g, WorldGraph) +# --------------------------------------------------------------------------- +# Input validation +# --------------------------------------------------------------------------- -def test_deprecated_seed_kwarg_warns() -> None: - with pytest.warns(DeprecationWarning, match="RNGRoot"): - g = sample_hidden_graph(seed=0) # type: ignore[arg-type] - assert isinstance(g, WorldGraph) +def test_int_seed_raises_type_error() -> None: + with pytest.raises(TypeError, match="RNGRoot"): + sample_hidden_graph(42) # type: ignore[arg-type] + + +def test_none_raises_type_error() -> None: + with pytest.raises(TypeError, match="RNGRoot"): + sample_hidden_graph(None) # type: ignore[arg-type] # ---------------------------------------------------------------------------