From df7503cc780e28477792b2169b80e952847009df Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 22 Apr 2026 16:58:27 +0300 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20Milestone=204=20=E2=80=94=20world?= =?UTF-8?q?=20structure=20layer=20(node=20types,=20graph,=20motifs,=20rewi?= =?UTF-8?q?ring,=20sampler)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the hidden-world variability mechanism (§11 of architecture spec): leadforge/structure/node_types.py - NodeType enum with 9 semantic categories (global_context, account_latent, contact_latent, lead_state, engagement_state, sales_process_state, observable_feature_source, outcome, post_conversion_state) - ROOT_ELIGIBLE, REQUIRES_PARENT, LEAF_ONLY constraint sets leadforge/structure/graph.py - NodeSpec / EdgeSpec dataclasses for graph construction - WorldGraph: wraps networkx.DiGraph; validates acyclicity, node-type legality (REQUIRES_PARENT, LEAF_ONLY), nondegeneracy, and outcome reachability - Exports: to_dict(), to_json(), to_graphml() leadforge/structure/motifs.py - MotifFamily frozen dataclass (canonical_nodes, canonical_edges, optional_node_ids) - All 5 v1 families: fit_dominant, intent_dominant, sales_execution_sensitive, demo_trial_mediated, buying_committee_friction - get_motif_family() lookup; ALL_MOTIF_FAMILIES / MOTIF_FAMILY_NAMES registry leadforge/structure/rewiring.py - rewire(motif, rng): optional-node dropping (p=0.4), edge-weight jitter (±0.15), optional latent-confounder injection (p=0.35) - All perturbations deterministic given seed; outcome nodes always retain a parent leadforge/structure/sampler.py - sample_hidden_graph(seed, motif_family_name=None): pins or randomly selects motif, applies rewiring, validates via WorldGraph, retries up to 20 times pyproject.toml: add networkx>=3.2 and numpy>=1.26 as core deps; mypy override for networkx. 132 new tests (node_types, graph, motifs, rewiring, sampler); 327 total passing. 
Co-Authored-By: Claude Sonnet 4.6 --- .agent-plan.md | 39 +++-- leadforge/structure/graph.py | 241 +++++++++++++++++++++++++++ leadforge/structure/motifs.py | 252 +++++++++++++++++++++++++++++ leadforge/structure/node_types.py | 58 +++++++ leadforge/structure/rewiring.py | 131 +++++++++++++++ leadforge/structure/sampler.py | 83 ++++++++++ pyproject.toml | 6 + tests/structure/__init__.py | 0 tests/structure/test_graph.py | 204 +++++++++++++++++++++++ tests/structure/test_motifs.py | 99 ++++++++++++ tests/structure/test_node_types.py | 38 +++++ tests/structure/test_rewiring.py | 113 +++++++++++++ tests/structure/test_sampler.py | 106 ++++++++++++ 13 files changed, 1357 insertions(+), 13 deletions(-) create mode 100644 leadforge/structure/graph.py create mode 100644 leadforge/structure/motifs.py create mode 100644 leadforge/structure/node_types.py create mode 100644 leadforge/structure/rewiring.py create mode 100644 leadforge/structure/sampler.py create mode 100644 tests/structure/__init__.py create mode 100644 tests/structure/test_graph.py create mode 100644 tests/structure/test_motifs.py create mode 100644 tests/structure/test_node_types.py create mode 100644 tests/structure/test_rewiring.py create mode 100644 tests/structure/test_sampler.py diff --git a/.agent-plan.md b/.agent-plan.md index 447c76f..0cc6073 100644 --- a/.agent-plan.md +++ b/.agent-plan.md @@ -6,35 +6,48 @@ ## Current System State -**v0.2.0 in progress — Milestone 3 complete (PR open).** All 9 relational table schemas defined as -typed row dataclasses with Parquet round-trip support. FK constraints, ID generation, feature -dictionary, and task manifest implemented. 192 tests passing. +**v0.3.0 in progress — Milestone 4 complete (PR open).** Hidden world graph fully implemented: +typed node system, DAG-validated WorldGraph, all 5 v1 motif families, stochastic rewiring, and +graph sampler. 327 tests passing. 
--- -## Active Task Breakdown — Milestone 4: World Structure (v0.3.0) +## Active Task Breakdown — Milestone 5: Population Generation (v0.3.0) -Goal: Implement the hidden world graph — DAG of latent nodes, motif families, and stochastic rewiring. +Goal: Generate the base commercial world population before dynamic events begin. -- [ ] **1. Node type system** (`structure/node_types.py`) -- [ ] **2. World graph** (`structure/graph.py`) — `networkx.DiGraph`, DAG validation -- [ ] **3. Motif families** (`structure/motifs.py`, `structure/templates.py`) — 5 v1 families -- [ ] **4. Stochastic rewiring** (`structure/rewiring.py`) — seeded perturbation -- [ ] **5. Sampler** (`structure/sampler.py`) — draw a world graph from a motif + config +- [ ] **1. Account generation** (`simulation/population.py`) — accounts with latent traits +- [ ] **2. Contact generation** — contacts conditional on account properties +- [ ] **3. Lead creation** — leads anchored to contacts/accounts +- [ ] **4. Latent-state initialisation** — sample core latent traits tied to graph/motif --- ## Context Pointers -- Milestone 4 scope: `docs/leadforge_implementation_plan.md` §7 "Milestone 4" -- Full milestone dependency graph: `docs/leadforge_implementation_plan.md` §6 +- Milestone 5 scope: `docs/leadforge_implementation_plan.md` §8 "Milestone 5" - Structure spec: `docs/leadforge_architecture_spec.md` §11 -- Motif families: `docs/leadforge_architecture_spec.md` §11.2 +- Latent variables: `docs/leadforge_architecture_spec.md` §9 --- ## Completed Phases +### Milestone 4 — World Structure Layer ✓ (v0.3.0 in PR) +- `leadforge/structure/node_types.py`: `NodeType` enum (9 categories); `ROOT_ELIGIBLE`, + `REQUIRES_PARENT`, `LEAF_ONLY` constraint sets +- `leadforge/structure/graph.py`: `WorldGraph` wrapping `networkx.DiGraph` with DAG validation, + node-type legality, nondegeneracy, and outcome-reachability checks; JSON + GraphML export +- `leadforge/structure/motifs.py`: `MotifFamily` frozen dataclass; all 5 v1 
families + (fit_dominant, intent_dominant, sales_execution_sensitive, demo_trial_mediated, + buying_committee_friction); `get_motif_family()` lookup +- `leadforge/structure/rewiring.py`: `rewire()` — optional-node dropping, edge-weight jitter, + optional latent-confounder injection; fully deterministic given seed +- `leadforge/structure/sampler.py`: `sample_hidden_graph(seed, motif_family_name=None)` — + selects motif, applies rewiring, validates, retries up to 20 times +- `pyproject.toml`: added `networkx>=3.2` + `numpy>=1.26`; mypy override for networkx +- 132 new tests; total 327 passing + ### Milestone 3 — Schema Layer ✓ (v0.2.0 in PR) - `leadforge/core/ids.py`: `make_id(prefix, n)` + `ID_PREFIXES` registry - `leadforge/schema/entities.py`: typed row dataclasses for all 9 tables (accounts, contacts, diff --git a/leadforge/structure/graph.py b/leadforge/structure/graph.py new file mode 100644 index 0000000..f433783 --- /dev/null +++ b/leadforge/structure/graph.py @@ -0,0 +1,241 @@ +"""Hidden world graph representation and validation. + +:class:`WorldGraph` wraps a ``networkx.DiGraph`` and enforces structural +invariants — acyclicity, node-type legality, reachability, and +nondegeneracy — at construction time and on demand. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any + +import networkx as nx + +from leadforge.core.exceptions import LeadforgeError +from leadforge.structure.node_types import LEAF_ONLY, REQUIRES_PARENT, NodeType + + +class GraphValidationError(LeadforgeError): + """Raised when a hidden world graph violates a structural invariant.""" + + +@dataclass +class NodeSpec: + """Specification for a single hidden-graph node. + + Attributes: + node_id: Unique string identifier within the graph. + node_type: Semantic category of the node. + label: Human-readable name used in exports. + metadata: Arbitrary extra attributes (e.g. prior strength, proxy + accuracy). 
Serialised as JSON in GraphML export. + """ + + node_id: str + node_type: NodeType + label: str = "" + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class EdgeSpec: + """Specification for a directed edge between two hidden-graph nodes. + + Attributes: + source: ``node_id`` of the parent node. + target: ``node_id`` of the child node. + weight: Signed influence strength in the range [-1, 1]. Positive + values indicate facilitation; negative values indicate + inhibition. + metadata: Arbitrary extra attributes (e.g. mechanism type, lag). + """ + + source: str + target: str + weight: float = 1.0 + metadata: dict[str, Any] = field(default_factory=dict) + + +class WorldGraph: + """Validated directed acyclic graph representing one hidden world. + + The graph is built from :class:`NodeSpec` and :class:`EdgeSpec` + objects and validated immediately. All subsequent access is via the + underlying ``networkx.DiGraph`` exposed as :attr:`graph`. + + Args: + nodes: Node specifications. Node IDs must be unique. + edges: Edge specifications. Both endpoints must reference known + node IDs. + motif_family: Name of the motif family that seeded this graph. + + Raises: + GraphValidationError: If any structural invariant is violated. 
+ """ + + def __init__( + self, + nodes: list[NodeSpec], + edges: list[EdgeSpec], + motif_family: str, + ) -> None: + self._motif_family = motif_family + self._graph: nx.DiGraph = nx.DiGraph() + + # Add nodes + seen_ids: set[str] = set() + for n in nodes: + if n.node_id in seen_ids: + raise GraphValidationError(f"Duplicate node_id: {n.node_id!r}") + seen_ids.add(n.node_id) + self._graph.add_node( + n.node_id, + node_type=n.node_type.value, + label=n.label, + **n.metadata, + ) + + # Add edges + for e in edges: + if e.source not in seen_ids: + raise GraphValidationError(f"Edge source {e.source!r} not in node set") + if e.target not in seen_ids: + raise GraphValidationError(f"Edge target {e.target!r} not in node set") + self._graph.add_edge( + e.source, + e.target, + weight=e.weight, + **e.metadata, + ) + + self._validate() + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + @property + def graph(self) -> nx.DiGraph: + """The underlying ``networkx.DiGraph`` (read-only intent).""" + return self._graph + + @property + def motif_family(self) -> str: + """Name of the motif family that produced this graph.""" + return self._motif_family + + def node_type(self, node_id: str) -> NodeType: + """Return the :class:`NodeType` of *node_id*.""" + return NodeType(self._graph.nodes[node_id]["node_type"]) + + def topological_order(self) -> list[str]: + """Return node IDs in topological order (roots first).""" + return list(nx.topological_sort(self._graph)) + + # ------------------------------------------------------------------ + # Export + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serialisable dict representation.""" + nodes = [ + { + "node_id": n, + "node_type": self._graph.nodes[n]["node_type"], + "label": self._graph.nodes[n].get("label", ""), + } + for n in 
self.topological_order() + ] + edges = [ + { + "source": u, + "target": v, + "weight": data.get("weight", 1.0), + } + for u, v, data in self._graph.edges(data=True) + ] + return { + "motif_family": self._motif_family, + "nodes": nodes, + "edges": edges, + } + + def to_json(self) -> str: + """Return a JSON string representation.""" + return json.dumps(self.to_dict(), indent=2) + + def to_graphml(self) -> str: + """Return a GraphML string representation.""" + lines = nx.generate_graphml(self._graph) + return "\n".join(lines) + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + def _validate(self) -> None: + """Run all structural invariant checks. + + Raises: + GraphValidationError: on first violation found. + """ + self._check_acyclic() + self._check_node_type_legality() + self._check_nondegeneracy() + self._check_outcome_reachable() + + def _check_acyclic(self) -> None: + if not nx.is_directed_acyclic_graph(self._graph): + cycle = nx.find_cycle(self._graph) + raise GraphValidationError(f"Graph contains a cycle: {cycle}") + + def _check_node_type_legality(self) -> None: + for node_id in self._graph.nodes: + nt = self.node_type(node_id) + in_degree = self._graph.in_degree(node_id) + out_degree = self._graph.out_degree(node_id) + + if nt in REQUIRES_PARENT and in_degree == 0: + raise GraphValidationError( + f"Node {node_id!r} (type={nt.value}) requires at least one parent but has none" + ) + if nt in LEAF_ONLY and out_degree > 0: + raise GraphValidationError( + f"Node {node_id!r} (type={nt.value}) must be a leaf but " + f"has {out_degree} child(ren)" + ) + + def _check_nondegeneracy(self) -> None: + """Reject fully isolated graphs and single-node graphs.""" + n = self._graph.number_of_nodes() + if n < 2: + raise GraphValidationError( + f"Graph has only {n} node(s); a meaningful hidden world requires at least 2 nodes" + ) + # Reject a graph where every node is 
isolated (no edges at all). + if self._graph.number_of_edges() == 0: + raise GraphValidationError( + "Graph has no edges; a meaningful hidden world requires at least one causal edge" + ) + + def _check_outcome_reachable(self) -> None: + """Every OUTCOME node must be reachable from at least one root.""" + outcome_nodes = [n for n in self._graph.nodes if self.node_type(n) == NodeType.OUTCOME] + if not outcome_nodes: + raise GraphValidationError( + "Graph has no OUTCOME node; every world must have at least " + "one conversion-outcome node" + ) + roots = [n for n in self._graph.nodes if self._graph.in_degree(n) == 0] + for outcome in outcome_nodes: + reachable = False + for root in roots: + if nx.has_path(self._graph, root, outcome): + reachable = True + break + if not reachable: + raise GraphValidationError( + f"OUTCOME node {outcome!r} is not reachable from any root node" + ) diff --git a/leadforge/structure/motifs.py b/leadforge/structure/motifs.py new file mode 100644 index 0000000..3f67988 --- /dev/null +++ b/leadforge/structure/motifs.py @@ -0,0 +1,252 @@ +"""Motif family definitions for the v1 hidden world graph. + +Each :class:`MotifFamily` describes the canonical node/edge skeleton for +one named hidden-world template. The five v1 families are defined at the +bottom of this module; they are consumed by :mod:`leadforge.structure.sampler` +to seed a concrete :class:`~leadforge.structure.graph.WorldGraph`. + +See §11.2 of the architecture spec for the semantics of each family. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from leadforge.structure.graph import EdgeSpec, NodeSpec +from leadforge.structure.node_types import NodeType + + +@dataclass(frozen=True) +class MotifFamily: + """Canonical template for one named hidden-world motif. + + Attributes: + name: Machine-readable identifier used in manifests and exports. + description: One-sentence human description of the causal story. 
+ canonical_nodes: Ordered list of :class:`NodeSpec` objects that + form the core graph skeleton. Node IDs must be unique within + the family. + canonical_edges: Directed edges between the canonical nodes. + optional_node_ids: IDs from *canonical_nodes* that may be dropped + during stochastic rewiring (see + :mod:`leadforge.structure.rewiring`). + """ + + name: str + description: str + canonical_nodes: tuple[NodeSpec, ...] + canonical_edges: tuple[EdgeSpec, ...] + optional_node_ids: frozenset[str] = field(default_factory=frozenset) + + +# --------------------------------------------------------------------------- +# v1 motif family 1 — fit-dominant +# --------------------------------------------------------------------------- + +FIT_DOMINANT: MotifFamily = MotifFamily( + name="fit_dominant", + description=( + "Conversion is primarily driven by account and contact fit; " + "engagement is partly downstream of fit rather than an independent driver." + ), + canonical_nodes=( + NodeSpec("global_ctx", NodeType.GLOBAL_CONTEXT, label="Global context"), + NodeSpec("acct_fit", NodeType.ACCOUNT_LATENT, label="Account fit"), + NodeSpec("acct_maturity", NodeType.ACCOUNT_LATENT, label="Process maturity"), + NodeSpec("contact_authority", NodeType.CONTACT_LATENT, label="Contact authority"), + NodeSpec("budget_readiness", NodeType.ACCOUNT_LATENT, label="Budget readiness"), + NodeSpec("lead_state", NodeType.LEAD_STATE, label="Lead state"), + NodeSpec("engagement", NodeType.ENGAGEMENT_STATE, label="Engagement signal"), + NodeSpec("conversion", NodeType.OUTCOME, label="Converted within 90 days"), + ), + canonical_edges=( + EdgeSpec("global_ctx", "acct_fit", weight=0.3), + EdgeSpec("global_ctx", "acct_maturity", weight=0.2), + EdgeSpec("acct_fit", "lead_state", weight=0.7), + EdgeSpec("acct_fit", "engagement", weight=0.5), + EdgeSpec("acct_maturity", "lead_state", weight=0.4), + EdgeSpec("budget_readiness", "lead_state", weight=0.5), + EdgeSpec("contact_authority", "lead_state", 
weight=0.3), + EdgeSpec("lead_state", "engagement", weight=0.4), + EdgeSpec("lead_state", "conversion", weight=0.6), + EdgeSpec("engagement", "conversion", weight=0.3), + ), + optional_node_ids=frozenset({"acct_maturity", "contact_authority"}), +) + + +# --------------------------------------------------------------------------- +# v1 motif family 2 — intent-dominant +# --------------------------------------------------------------------------- + +INTENT_DOMINANT: MotifFamily = MotifFamily( + name="intent_dominant", + description=( + "Behavioral engagement and urgency dominate conversion probability, " + "even among leads with mixed account/contact fit scores." + ), + canonical_nodes=( + NodeSpec("global_ctx", NodeType.GLOBAL_CONTEXT, label="Global context"), + NodeSpec("acct_fit", NodeType.ACCOUNT_LATENT, label="Account fit"), + NodeSpec("problem_awareness", NodeType.CONTACT_LATENT, label="Problem awareness"), + NodeSpec("urgency", NodeType.LEAD_STATE, label="Urgency / timing"), + NodeSpec("engagement", NodeType.ENGAGEMENT_STATE, label="Engagement signal"), + NodeSpec("intent_score", NodeType.OBSERVABLE_FEATURE_SOURCE, label="Intent signal proxy"), + NodeSpec("conversion", NodeType.OUTCOME, label="Converted within 90 days"), + ), + canonical_edges=( + EdgeSpec("global_ctx", "problem_awareness", weight=0.25), + EdgeSpec("acct_fit", "engagement", weight=0.2), + EdgeSpec("problem_awareness", "urgency", weight=0.6), + EdgeSpec("problem_awareness", "engagement", weight=0.6), + EdgeSpec("urgency", "intent_score", weight=0.7), + EdgeSpec("engagement", "intent_score", weight=0.6), + EdgeSpec("intent_score", "conversion", weight=0.8), + EdgeSpec("urgency", "conversion", weight=0.4), + ), + optional_node_ids=frozenset({"acct_fit"}), +) + + +# --------------------------------------------------------------------------- +# v1 motif family 3 — sales-execution-sensitive +# --------------------------------------------------------------------------- + +SALES_EXECUTION_SENSITIVE: 
MotifFamily = MotifFamily( + name="sales_execution_sensitive", + description=( + "Follow-up timing, rep quality, and sales process friction " + "materially affect conversion outcomes beyond lead characteristics." + ), + canonical_nodes=( + NodeSpec("global_ctx", NodeType.GLOBAL_CONTEXT, label="Global context"), + NodeSpec("acct_fit", NodeType.ACCOUNT_LATENT, label="Account fit"), + NodeSpec("contact_responsiveness", NodeType.CONTACT_LATENT, label="Contact responsiveness"), + NodeSpec("lead_state", NodeType.LEAD_STATE, label="Lead state"), + NodeSpec("sales_process", NodeType.SALES_PROCESS_STATE, label="Sales process quality"), + NodeSpec("rep_quality", NodeType.SALES_PROCESS_STATE, label="Rep execution quality"), + NodeSpec("sales_friction", NodeType.SALES_PROCESS_STATE, label="Process friction"), + NodeSpec("conversion", NodeType.OUTCOME, label="Converted within 90 days"), + ), + canonical_edges=( + EdgeSpec("global_ctx", "acct_fit", weight=0.3), + EdgeSpec("global_ctx", "rep_quality", weight=0.2), + EdgeSpec("acct_fit", "lead_state", weight=0.4), + EdgeSpec("contact_responsiveness", "lead_state", weight=0.35), + EdgeSpec("lead_state", "sales_process", weight=0.5), + EdgeSpec("rep_quality", "sales_process", weight=0.6), + EdgeSpec("rep_quality", "sales_friction", weight=-0.5), + EdgeSpec("sales_process", "conversion", weight=0.6), + EdgeSpec("sales_friction", "conversion", weight=-0.4), + ), + optional_node_ids=frozenset({"sales_friction", "contact_responsiveness"}), +) + + +# --------------------------------------------------------------------------- +# v1 motif family 4 — demo/trial-mediated +# --------------------------------------------------------------------------- + +DEMO_TRIAL_MEDIATED: MotifFamily = MotifFamily( + name="demo_trial_mediated", + description=( + "Product demonstration or trial progression acts as a major mediator " + "between initial engagement and conversion." 
+ ), + canonical_nodes=( + NodeSpec("global_ctx", NodeType.GLOBAL_CONTEXT, label="Global context"), + NodeSpec("acct_fit", NodeType.ACCOUNT_LATENT, label="Account fit"), + NodeSpec("problem_awareness", NodeType.CONTACT_LATENT, label="Problem awareness"), + NodeSpec("engagement", NodeType.ENGAGEMENT_STATE, label="Top-of-funnel engagement"), + NodeSpec("demo_completion", NodeType.LEAD_STATE, label="Demo / trial completion"), + NodeSpec("trial_depth", NodeType.OBSERVABLE_FEATURE_SOURCE, label="Trial depth proxy"), + NodeSpec("sales_process", NodeType.SALES_PROCESS_STATE, label="Post-demo sales process"), + NodeSpec("conversion", NodeType.OUTCOME, label="Converted within 90 days"), + ), + canonical_edges=( + EdgeSpec("global_ctx", "acct_fit", weight=0.3), + EdgeSpec("acct_fit", "engagement", weight=0.4), + EdgeSpec("problem_awareness", "engagement", weight=0.5), + EdgeSpec("engagement", "demo_completion", weight=0.6), + EdgeSpec("acct_fit", "demo_completion", weight=0.4), + EdgeSpec("demo_completion", "trial_depth", weight=0.7), + EdgeSpec("demo_completion", "sales_process", weight=0.5), + EdgeSpec("trial_depth", "conversion", weight=0.6), + EdgeSpec("sales_process", "conversion", weight=0.4), + ), + # sales_process is the primary post-demo path to conversion; only the + # observational proxy (trial_depth) may be dropped. + optional_node_ids=frozenset({"trial_depth"}), +) + + +# --------------------------------------------------------------------------- +# v1 motif family 5 — buying-committee-friction +# --------------------------------------------------------------------------- + +BUYING_COMMITTEE_FRICTION: MotifFamily = MotifFamily( + name="buying_committee_friction", + description=( + "Multiple stakeholders and approval friction materially slow or block " + "progression; contact authority and consensus dynamics dominate." 
+ ), + canonical_nodes=( + NodeSpec("global_ctx", NodeType.GLOBAL_CONTEXT, label="Global context"), + NodeSpec("acct_fit", NodeType.ACCOUNT_LATENT, label="Account fit"), + NodeSpec("contact_authority", NodeType.CONTACT_LATENT, label="Primary contact authority"), + NodeSpec("committee_alignment", NodeType.CONTACT_LATENT, label="Committee alignment"), + NodeSpec("lead_state", NodeType.LEAD_STATE, label="Lead state"), + NodeSpec("approval_friction", NodeType.SALES_PROCESS_STATE, label="Approval friction"), + NodeSpec("engagement", NodeType.ENGAGEMENT_STATE, label="Multi-stakeholder engagement"), + NodeSpec("conversion", NodeType.OUTCOME, label="Converted within 90 days"), + ), + canonical_edges=( + EdgeSpec("global_ctx", "acct_fit", weight=0.3), + EdgeSpec("global_ctx", "committee_alignment", weight=0.2), + EdgeSpec("acct_fit", "lead_state", weight=0.45), + EdgeSpec("contact_authority", "lead_state", weight=0.5), + EdgeSpec("committee_alignment", "approval_friction", weight=-0.6), + EdgeSpec("lead_state", "engagement", weight=0.4), + EdgeSpec("engagement", "approval_friction", weight=-0.3), + EdgeSpec("contact_authority", "approval_friction", weight=-0.4), + EdgeSpec("approval_friction", "conversion", weight=-0.5), + EdgeSpec("lead_state", "conversion", weight=0.4), + EdgeSpec("engagement", "conversion", weight=0.3), + ), + optional_node_ids=frozenset({"committee_alignment", "approval_friction"}), +) + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +ALL_MOTIF_FAMILIES: tuple[MotifFamily, ...] = ( + FIT_DOMINANT, + INTENT_DOMINANT, + SALES_EXECUTION_SENSITIVE, + DEMO_TRIAL_MEDIATED, + BUYING_COMMITTEE_FRICTION, +) + +MOTIF_FAMILY_NAMES: tuple[str, ...] 
= tuple(m.name for m in ALL_MOTIF_FAMILIES) + +_BY_NAME: dict[str, MotifFamily] = {m.name: m for m in ALL_MOTIF_FAMILIES} + + +def get_motif_family(name: str) -> MotifFamily: + """Look up a motif family by name. + + Args: + name: One of the values in :data:`MOTIF_FAMILY_NAMES`. + + Returns: + The corresponding :class:`MotifFamily`. + + Raises: + KeyError: If *name* is not a known motif family. + """ + try: + return _BY_NAME[name] + except KeyError: + raise KeyError(f"Unknown motif family {name!r}. Valid names: {sorted(_BY_NAME)}") from None diff --git a/leadforge/structure/node_types.py b/leadforge/structure/node_types.py new file mode 100644 index 0000000..7078711 --- /dev/null +++ b/leadforge/structure/node_types.py @@ -0,0 +1,58 @@ +"""Node type definitions for the hidden world graph. + +Each node in the hidden causal graph carries a :class:`NodeType` that +constrains how it participates in mechanisms, rewiring, and exports. +""" + +from __future__ import annotations + +from enum import Enum + + +class NodeType(str, Enum): + """Semantic category of a hidden-graph node. + + Values mirror the nine categories specified in §11.1 of the + architecture spec. Using ``str`` as a mixin makes serialisation + (JSON, GraphML) straightforward without extra conversion. + """ + + GLOBAL_CONTEXT = "global_context" + ACCOUNT_LATENT = "account_latent" + CONTACT_LATENT = "contact_latent" + LEAD_STATE = "lead_state" + ENGAGEMENT_STATE = "engagement_state" + SALES_PROCESS_STATE = "sales_process_state" + OBSERVABLE_FEATURE_SOURCE = "observable_feature_source" + OUTCOME = "outcome" + POST_CONVERSION_STATE = "post_conversion_state" + + +# Node types that may appear as graph roots (no required predecessors). +ROOT_ELIGIBLE: frozenset[NodeType] = frozenset( + { + NodeType.GLOBAL_CONTEXT, + NodeType.ACCOUNT_LATENT, + NodeType.CONTACT_LATENT, + } +) + +# Node types that must have at least one predecessor. 
+REQUIRES_PARENT: frozenset[NodeType] = frozenset( + { + NodeType.LEAD_STATE, + NodeType.ENGAGEMENT_STATE, + NodeType.SALES_PROCESS_STATE, + NodeType.OBSERVABLE_FEATURE_SOURCE, + NodeType.OUTCOME, + NodeType.POST_CONVERSION_STATE, + } +) + +# Node types that may not have children (leaf nodes only). +LEAF_ONLY: frozenset[NodeType] = frozenset( + { + NodeType.OUTCOME, + NodeType.POST_CONVERSION_STATE, + } +) diff --git a/leadforge/structure/rewiring.py b/leadforge/structure/rewiring.py new file mode 100644 index 0000000..3e52dc7 --- /dev/null +++ b/leadforge/structure/rewiring.py @@ -0,0 +1,131 @@ +"""Stochastic rewiring of motif-family graph skeletons. + +:func:`rewire` takes a :class:`~leadforge.structure.motifs.MotifFamily` +and a seeded :class:`~numpy.random.Generator` and returns perturbed lists +of :class:`~leadforge.structure.graph.NodeSpec` and +:class:`~leadforge.structure.graph.EdgeSpec` that still satisfy the graph +invariants (acyclicity, legality, nondegeneracy). + +Permitted variability (§11.3 of architecture spec): +- dropping optional mediator nodes (and their incident edges) +- perturbing edge weights within a bounded range +- adding an optional latent confounder node +- swapping one optional node for an alternate proxy + +Forbidden variability (hard constraints enforced here): +- chronologically impossible edges (validated downstream in WorldGraph) +- orphaned outcome nodes +- degenerate worlds with no edges +""" + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING + +from leadforge.structure.graph import EdgeSpec, NodeSpec +from leadforge.structure.node_types import NodeType + +if TYPE_CHECKING: + import numpy as np + + from leadforge.structure.motifs import MotifFamily + +# Maximum ± perturbation applied to each edge weight. +_WEIGHT_JITTER = 0.15 + +# Probability that each optional node is dropped (per rewiring call). +_DROP_PROB = 0.4 + +# Probability that an optional latent confounder is injected. 
+_CONFOUNDER_PROB = 0.35 + + +def rewire( + motif: MotifFamily, + rng: np.random.Generator, +) -> tuple[list[NodeSpec], list[EdgeSpec]]: + """Return perturbed node/edge lists derived from *motif*'s skeleton. + + The canonical skeleton from *motif* is copied and then stochastically + modified: + + 1. Each optional node is independently dropped with probability + :data:`_DROP_PROB`. Edges incident to a dropped node are removed. + 2. Edge weights are jittered by ±:data:`_WEIGHT_JITTER`, clamped to + [-1, 1]. + 3. With probability :data:`_CONFOUNDER_PROB`, a single additional + ``ACCOUNT_LATENT`` confounder node is injected with edges to the + first ``LEAD_STATE`` node found in the skeleton. + + Args: + motif: The motif family providing the canonical skeleton. + rng: A seeded ``numpy.random.Generator`` for reproducibility. + + Returns: + A ``(nodes, edges)`` tuple suitable for passing to + :class:`~leadforge.structure.graph.WorldGraph`. + """ + nodes: list[NodeSpec] = [copy.copy(n) for n in motif.canonical_nodes] + edges: list[EdgeSpec] = [copy.copy(e) for e in motif.canonical_edges] + + # Step 1 — drop optional nodes + dropped: set[str] = set() + for node in list(nodes): + if node.node_id in motif.optional_node_ids: + if rng.random() < _DROP_PROB: + dropped.add(node.node_id) + + if dropped: + nodes = [n for n in nodes if n.node_id not in dropped] + edges = [e for e in edges if e.source not in dropped and e.target not in dropped] + + # Step 2 — jitter edge weights + active_node_ids = {n.node_id for n in nodes} + perturbed_edges: list[EdgeSpec] = [] + for e in edges: + jitter = rng.uniform(-_WEIGHT_JITTER, _WEIGHT_JITTER) + new_weight = float(max(-1.0, min(1.0, e.weight + jitter))) + perturbed_edges.append( + EdgeSpec( + source=e.source, + target=e.target, + weight=new_weight, + metadata=dict(e.metadata), + ) + ) + edges = perturbed_edges + + # Step 3 — optional latent confounder injection + if rng.random() < _CONFOUNDER_PROB: + # Find first LEAD_STATE node to attach 
to. + lead_state_ids = [n.node_id for n in nodes if n.node_type == NodeType.LEAD_STATE] + if lead_state_ids: + conf_id = _unique_id("latent_confounder", active_node_ids) + conf_weight = float(rng.uniform(0.1, 0.5)) + nodes.append( + NodeSpec( + node_id=conf_id, + node_type=NodeType.ACCOUNT_LATENT, + label="Latent confounder", + ) + ) + edges.append( + EdgeSpec( + source=conf_id, + target=lead_state_ids[0], + weight=conf_weight, + ) + ) + + return nodes, edges + + +def _unique_id(base: str, existing: set[str]) -> str: + """Return *base* if not in *existing*, else *base_2*, *base_3*, …""" + if base not in existing: + return base + i = 2 + while f"{base}_{i}" in existing: + i += 1 + return f"{base}_{i}" diff --git a/leadforge/structure/sampler.py b/leadforge/structure/sampler.py new file mode 100644 index 0000000..414a1ec --- /dev/null +++ b/leadforge/structure/sampler.py @@ -0,0 +1,83 @@ +"""World graph sampler — draw a concrete hidden world from a motif + seed. + +:func:`sample_hidden_graph` is the single entry point consumed by the +simulation layer. It selects a motif family (deterministically from the +recipe, or randomly from the seed), applies stochastic rewiring, and +returns a validated :class:`~leadforge.structure.graph.WorldGraph`. +""" + +from __future__ import annotations + +import numpy as np + +from leadforge.structure.graph import WorldGraph +from leadforge.structure.motifs import ( + ALL_MOTIF_FAMILIES, + MotifFamily, + get_motif_family, +) +from leadforge.structure.rewiring import rewire + +# Maximum number of rewiring attempts before giving up. +_MAX_ATTEMPTS = 20 + + +def sample_hidden_graph( + seed: int, + motif_family_name: str | None = None, +) -> WorldGraph: + """Draw a validated hidden world graph. + + The function is fully deterministic given ``(seed, motif_family_name)``. + + Args: + seed: Integer seed for the NumPy random generator. 
All stochastic + choices (motif selection if *motif_family_name* is ``None``, + rewiring decisions, weight jitter) derive from this seed. + motif_family_name: If provided, pin the motif family by name + (must be one of :data:`~leadforge.structure.motifs.MOTIF_FAMILY_NAMES`). + If ``None``, a family is chosen uniformly at random from the + five v1 families. + + Returns: + A validated :class:`~leadforge.structure.graph.WorldGraph`. + + Raises: + KeyError: If *motif_family_name* is not a known motif family name. + RuntimeError: If :data:`_MAX_ATTEMPTS` rewiring attempts all + produce graphs that fail structural validation (should not + happen in practice with well-formed motifs). + """ + rng = np.random.default_rng(seed) + + motif = _select_motif(motif_family_name, rng) + + last_exc: Exception | None = None + for _attempt in range(_MAX_ATTEMPTS): + # Each attempt uses an independent sub-seed so that earlier + # failures do not corrupt the RNG state of later attempts. + attempt_seed = int(rng.integers(0, 2**31)) + attempt_rng = np.random.default_rng(attempt_seed) + nodes, edges = rewire(motif, attempt_rng) + try: + return WorldGraph(nodes=nodes, edges=edges, motif_family=motif.name) + except Exception as exc: # noqa: BLE001 + last_exc = exc + continue + + raise RuntimeError( + f"Failed to produce a valid WorldGraph from motif " + f"{motif.name!r} after {_MAX_ATTEMPTS} rewiring attempts. 
" + f"Last error: {last_exc}" + ) + + +def _select_motif( + name: str | None, + rng: np.random.Generator, +) -> MotifFamily: + """Return the requested motif family, or pick one at random.""" + if name is not None: + return get_motif_family(name) + idx = int(rng.integers(0, len(ALL_MOTIF_FAMILIES))) + return ALL_MOTIF_FAMILIES[idx] diff --git a/pyproject.toml b/pyproject.toml index c8f29a0..39db58a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "pyyaml>=6.0", "pandas>=2.0", "pyarrow>=14.0", + "networkx>=3.2", + "numpy>=1.26", ] [project.optional-dependencies] @@ -71,5 +73,9 @@ no_implicit_optional = true module = ["pandas", "pandas.*", "pyarrow", "pyarrow.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["networkx", "networkx.*"] +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/tests/structure/__init__.py b/tests/structure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/structure/test_graph.py b/tests/structure/test_graph.py new file mode 100644 index 0000000..a2c00bd --- /dev/null +++ b/tests/structure/test_graph.py @@ -0,0 +1,204 @@ +"""Tests for leadforge.structure.graph — WorldGraph validation and exports.""" + +import json + +import pytest + +from leadforge.structure.graph import ( + EdgeSpec, + GraphValidationError, + NodeSpec, + WorldGraph, +) +from leadforge.structure.node_types import NodeType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _minimal_valid_graph() -> WorldGraph: + """Smallest possible valid WorldGraph (root → outcome).""" + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT, label="Root"), + NodeSpec("lead", NodeType.LEAD_STATE, label="Lead state"), + NodeSpec("outcome", NodeType.OUTCOME, label="Outcome"), + ] + edges = [ + EdgeSpec("root", "lead", weight=0.8), + EdgeSpec("lead", 
"outcome", weight=0.7), + ] + return WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_valid_graph_constructs_without_error() -> None: + g = _minimal_valid_graph() + assert g.graph.number_of_nodes() == 3 + assert g.graph.number_of_edges() == 2 + + +def test_motif_family_stored() -> None: + g = _minimal_valid_graph() + assert g.motif_family == "test" + + +def test_node_type_accessor() -> None: + g = _minimal_valid_graph() + assert g.node_type("root") == NodeType.ACCOUNT_LATENT + assert g.node_type("outcome") == NodeType.OUTCOME + + +def test_topological_order_root_first() -> None: + g = _minimal_valid_graph() + order = g.topological_order() + assert order.index("root") < order.index("outcome") + + +# --------------------------------------------------------------------------- +# Validation — acyclicity +# --------------------------------------------------------------------------- + + +def test_cycle_raises_graph_validation_error() -> None: + nodes = [ + NodeSpec("a", NodeType.ACCOUNT_LATENT), + NodeSpec("b", NodeType.LEAD_STATE), + NodeSpec("out", NodeType.OUTCOME), + ] + edges = [ + EdgeSpec("a", "b"), + EdgeSpec("b", "a"), # creates cycle + EdgeSpec("b", "out"), + ] + with pytest.raises(GraphValidationError, match="cycle"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +# --------------------------------------------------------------------------- +# Validation — node type legality +# --------------------------------------------------------------------------- + + +def test_outcome_with_child_raises() -> None: + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT), + NodeSpec("lead", NodeType.LEAD_STATE), + NodeSpec("out", NodeType.OUTCOME), + NodeSpec("post", NodeType.LEAD_STATE), # outcome → post is forbidden + ] + edges = [ + EdgeSpec("root", 
"lead"), + EdgeSpec("lead", "out"), + EdgeSpec("out", "post"), + ] + with pytest.raises(GraphValidationError, match="leaf"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +def test_lead_state_without_parent_raises() -> None: + nodes = [ + NodeSpec("lead", NodeType.LEAD_STATE), + NodeSpec("out", NodeType.OUTCOME), + ] + edges = [EdgeSpec("lead", "out")] + with pytest.raises(GraphValidationError, match="requires at least one parent"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +# --------------------------------------------------------------------------- +# Validation — nondegeneracy +# --------------------------------------------------------------------------- + + +def test_no_edges_raises() -> None: + # Two root-eligible nodes with no edges hit the nondegeneracy check. + nodes = [ + NodeSpec("a", NodeType.ACCOUNT_LATENT), + NodeSpec("b", NodeType.ACCOUNT_LATENT), + ] + with pytest.raises(GraphValidationError, match="no edges"): + WorldGraph(nodes=nodes, edges=[], motif_family="test") + + +def test_single_node_raises() -> None: + nodes = [NodeSpec("a", NodeType.ACCOUNT_LATENT)] + with pytest.raises(GraphValidationError, match="only 1 node"): + WorldGraph(nodes=nodes, edges=[], motif_family="test") + + +# --------------------------------------------------------------------------- +# Validation — outcome reachability +# --------------------------------------------------------------------------- + + +def test_no_outcome_node_raises() -> None: + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT), + NodeSpec("lead", NodeType.LEAD_STATE), + ] + edges = [EdgeSpec("root", "lead")] + with pytest.raises(GraphValidationError, match="no OUTCOME node"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +def test_unreachable_outcome_raises() -> None: + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT), + NodeSpec("lead", NodeType.LEAD_STATE), + NodeSpec("root2", NodeType.ACCOUNT_LATENT), + NodeSpec("out", 
NodeType.OUTCOME), + ] + edges = [ + EdgeSpec("root", "lead"), + EdgeSpec("root2", "out"), # out reachable from root2 only + # lead has no path to out + ] + # out IS reachable from root2 so this should pass + g = WorldGraph(nodes=nodes, edges=edges, motif_family="test") + assert g.graph.number_of_nodes() == 4 + + +# --------------------------------------------------------------------------- +# Duplicate node IDs +# --------------------------------------------------------------------------- + + +def test_duplicate_node_id_raises() -> None: + nodes = [ + NodeSpec("a", NodeType.ACCOUNT_LATENT), + NodeSpec("a", NodeType.LEAD_STATE), # duplicate + NodeSpec("out", NodeType.OUTCOME), + ] + edges = [EdgeSpec("a", "out")] + with pytest.raises(GraphValidationError, match="Duplicate"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +# --------------------------------------------------------------------------- +# Exports +# --------------------------------------------------------------------------- + + +def test_to_dict_keys() -> None: + g = _minimal_valid_graph() + d = g.to_dict() + assert set(d.keys()) == {"motif_family", "nodes", "edges"} + + +def test_to_json_round_trips() -> None: + g = _minimal_valid_graph() + data = json.loads(g.to_json()) + assert data["motif_family"] == "test" + assert len(data["nodes"]) == 3 + assert len(data["edges"]) == 2 + + +def test_to_graphml_returns_string() -> None: + g = _minimal_valid_graph() + gml = g.to_graphml() + assert "graphml" in gml.lower() diff --git a/tests/structure/test_motifs.py b/tests/structure/test_motifs.py new file mode 100644 index 0000000..68b855a --- /dev/null +++ b/tests/structure/test_motifs.py @@ -0,0 +1,99 @@ +"""Tests for leadforge.structure.motifs — motif family definitions.""" + +import pytest + +from leadforge.structure.motifs import ( + ALL_MOTIF_FAMILIES, + MOTIF_FAMILY_NAMES, + MotifFamily, + get_motif_family, +) +from leadforge.structure.node_types import NodeType + +# 
--------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + + +def test_five_motif_families_defined() -> None: + assert len(ALL_MOTIF_FAMILIES) == 5 + + +def test_motif_family_names_match_registry() -> None: + assert set(MOTIF_FAMILY_NAMES) == {m.name for m in ALL_MOTIF_FAMILIES} + + +def test_all_five_expected_names_present() -> None: + expected = { + "fit_dominant", + "intent_dominant", + "sales_execution_sensitive", + "demo_trial_mediated", + "buying_committee_friction", + } + assert set(MOTIF_FAMILY_NAMES) == expected + + +def test_get_motif_family_returns_correct_instance() -> None: + for motif in ALL_MOTIF_FAMILIES: + assert get_motif_family(motif.name) is motif + + +def test_get_motif_family_unknown_raises() -> None: + with pytest.raises(KeyError, match="unknown_family"): + get_motif_family("unknown_family") + + +# --------------------------------------------------------------------------- +# Structural invariants per motif +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_motif_has_at_least_one_outcome_node(motif: MotifFamily) -> None: + outcomes = [n for n in motif.canonical_nodes if n.node_type == NodeType.OUTCOME] + assert len(outcomes) >= 1, f"{motif.name} has no OUTCOME node" + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_motif_node_ids_unique(motif: MotifFamily) -> None: + ids = [n.node_id for n in motif.canonical_nodes] + assert len(ids) == len(set(ids)), f"{motif.name} has duplicate node IDs" + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_motif_edge_endpoints_exist(motif: MotifFamily) -> None: + node_ids = {n.node_id for n in motif.canonical_nodes} + for e in motif.canonical_edges: + assert e.source in node_ids, f"{motif.name}: edge source {e.source!r} not in node set" + assert e.target in node_ids, 
f"{motif.name}: edge target {e.target!r} not in node set" + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_motif_optional_nodes_exist(motif: MotifFamily) -> None: + node_ids = {n.node_id for n in motif.canonical_nodes} + for opt_id in motif.optional_node_ids: + assert opt_id in node_ids, ( + f"{motif.name}: optional node {opt_id!r} not in canonical node set" + ) + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_motif_edge_weights_in_range(motif: MotifFamily) -> None: + for e in motif.canonical_edges: + assert -1.0 <= e.weight <= 1.0, ( + f"{motif.name}: edge {e.source}→{e.target} weight {e.weight} out of [-1, 1]" + ) + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_motif_canonical_skeleton_builds_valid_graph(motif: MotifFamily) -> None: + """The canonical (non-rewired) skeleton must pass WorldGraph validation.""" + from leadforge.structure.graph import WorldGraph + + g = WorldGraph( + nodes=list(motif.canonical_nodes), + edges=list(motif.canonical_edges), + motif_family=motif.name, + ) + assert g.graph.number_of_nodes() == len(motif.canonical_nodes) diff --git a/tests/structure/test_node_types.py b/tests/structure/test_node_types.py new file mode 100644 index 0000000..b6894ad --- /dev/null +++ b/tests/structure/test_node_types.py @@ -0,0 +1,38 @@ +"""Tests for leadforge.structure.node_types.""" + +from leadforge.structure.node_types import ( + LEAF_ONLY, + REQUIRES_PARENT, + ROOT_ELIGIBLE, + NodeType, +) + + +def test_node_type_values_are_strings() -> None: + for nt in NodeType: + assert isinstance(nt.value, str) + + +def test_all_nine_node_types_defined() -> None: + assert len(NodeType) == 9 + + +def test_root_eligible_and_requires_parent_are_disjoint() -> None: + assert ROOT_ELIGIBLE.isdisjoint(REQUIRES_PARENT) + + +def test_leaf_only_is_subset_of_requires_parent() -> None: + assert LEAF_ONLY <= REQUIRES_PARENT + + +def test_outcome_is_leaf_only() -> None: + assert NodeType.OUTCOME in LEAF_ONLY + + +def 
test_global_context_is_root_eligible() -> None: + assert NodeType.GLOBAL_CONTEXT in ROOT_ELIGIBLE + + +def test_node_type_round_trips_via_value() -> None: + for nt in NodeType: + assert NodeType(nt.value) is nt diff --git a/tests/structure/test_rewiring.py b/tests/structure/test_rewiring.py new file mode 100644 index 0000000..ab133b5 --- /dev/null +++ b/tests/structure/test_rewiring.py @@ -0,0 +1,113 @@ +"""Tests for leadforge.structure.rewiring — stochastic rewiring rules.""" + +import numpy as np +import pytest + +from leadforge.structure.graph import WorldGraph +from leadforge.structure.motifs import ALL_MOTIF_FAMILIES, MotifFamily +from leadforge.structure.rewiring import rewire + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _rng(seed: int = 0) -> np.random.Generator: + return np.random.default_rng(seed) + + +# --------------------------------------------------------------------------- +# Output validity +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_rewired_graph_passes_validation(motif: MotifFamily) -> None: + """Every rewired graph must satisfy WorldGraph structural invariants.""" + for seed in range(20): + nodes, edges = rewire(motif, _rng(seed)) + # Should not raise. 
+ WorldGraph(nodes=nodes, edges=edges, motif_family=motif.name) + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_rewired_graph_has_at_least_two_nodes(motif: MotifFamily) -> None: + for seed in range(10): + nodes, _ = rewire(motif, _rng(seed)) + assert len(nodes) >= 2 + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_rewired_graph_has_at_least_one_edge(motif: MotifFamily) -> None: + for seed in range(10): + _, edges = rewire(motif, _rng(seed)) + assert len(edges) >= 1 + + +# --------------------------------------------------------------------------- +# Edge weight bounds +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_rewired_edge_weights_in_range(motif: MotifFamily) -> None: + for seed in range(10): + _, edges = rewire(motif, _rng(seed)) + for e in edges: + assert -1.0 <= e.weight <= 1.0, ( + f"Weight {e.weight} out of range for {e.source}→{e.target}" + ) + + +# --------------------------------------------------------------------------- +# Determinism +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("motif", ALL_MOTIF_FAMILIES) +def test_rewire_is_deterministic(motif: MotifFamily) -> None: + nodes_a, edges_a = rewire(motif, _rng(42)) + nodes_b, edges_b = rewire(motif, _rng(42)) + assert [n.node_id for n in nodes_a] == [n.node_id for n in nodes_b] + assert [(e.source, e.target, round(e.weight, 10)) for e in edges_a] == [ + (e.source, e.target, round(e.weight, 10)) for e in edges_b + ] + + +# --------------------------------------------------------------------------- +# Variability across seeds +# --------------------------------------------------------------------------- + + +def test_different_seeds_produce_different_graphs() -> None: + """At least some seeds should yield structurally different graphs.""" + from leadforge.structure.motifs import FIT_DOMINANT + + 
structures: set[tuple[str, ...]] = set() + for seed in range(40): + nodes, _ = rewire(FIT_DOMINANT, _rng(seed)) + structures.add(tuple(sorted(n.node_id for n in nodes))) + # With _DROP_PROB=0.4 and two optional nodes we expect variation. + assert len(structures) > 1 + + +# --------------------------------------------------------------------------- +# Optional node dropping +# --------------------------------------------------------------------------- + + +def test_required_nodes_never_dropped() -> None: + """Non-optional nodes must always be present after rewiring.""" + from leadforge.structure.motifs import FIT_DOMINANT + + required = { + n.node_id + for n in FIT_DOMINANT.canonical_nodes + if n.node_id not in FIT_DOMINANT.optional_node_ids + } + for seed in range(30): + nodes, _ = rewire(FIT_DOMINANT, _rng(seed)) + present = {n.node_id for n in nodes} + assert required <= present, ( + f"Seed {seed}: required node(s) {required - present} were dropped" + ) diff --git a/tests/structure/test_sampler.py b/tests/structure/test_sampler.py new file mode 100644 index 0000000..3e16511 --- /dev/null +++ b/tests/structure/test_sampler.py @@ -0,0 +1,106 @@ +"""Tests for leadforge.structure.sampler — sample_hidden_graph.""" + +import pytest + +from leadforge.structure.graph import WorldGraph +from leadforge.structure.motifs import MOTIF_FAMILY_NAMES +from leadforge.structure.node_types import NodeType +from leadforge.structure.sampler import sample_hidden_graph + +# --------------------------------------------------------------------------- +# Basic contract +# --------------------------------------------------------------------------- + + +def test_returns_world_graph() -> None: + g = sample_hidden_graph(seed=0) + assert isinstance(g, WorldGraph) + + +def test_sampled_graph_has_outcome_node() -> None: + g = sample_hidden_graph(seed=0) + outcome_nodes = [n for n in g.graph.nodes if g.node_type(n) == NodeType.OUTCOME] + assert len(outcome_nodes) >= 1 + + +def 
test_sampled_graph_is_dag() -> None: + import networkx as nx + + g = sample_hidden_graph(seed=0) + assert nx.is_directed_acyclic_graph(g.graph) + + +# --------------------------------------------------------------------------- +# Determinism +# --------------------------------------------------------------------------- + + +def test_same_seed_same_graph() -> None: + g1 = sample_hidden_graph(seed=42) + g2 = sample_hidden_graph(seed=42) + assert g1.motif_family == g2.motif_family + assert sorted(g1.graph.nodes) == sorted(g2.graph.nodes) + assert sorted(g1.graph.edges) == sorted(g2.graph.edges) + + +def test_different_seeds_can_differ() -> None: + graphs = [sample_hidden_graph(seed=s) for s in range(20)] + families = {g.motif_family for g in graphs} + # With 5 families and 20 seeds, we expect more than one family. + assert len(families) > 1 + + +# --------------------------------------------------------------------------- +# Pinned motif family +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("name", MOTIF_FAMILY_NAMES) +def test_pinned_motif_family(name: str) -> None: + g = sample_hidden_graph(seed=7, motif_family_name=name) + assert g.motif_family == name + + +def test_unknown_motif_family_raises() -> None: + with pytest.raises(KeyError, match="bad_family"): + sample_hidden_graph(seed=0, motif_family_name="bad_family") + + +# --------------------------------------------------------------------------- +# Graph properties across many seeds +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("seed", range(30)) +def test_all_sampled_graphs_are_valid(seed: int) -> None: + """Property test: no seed should produce an invalid graph.""" + g = sample_hidden_graph(seed=seed) + # If we got here without GraphValidationError, the graph is valid. 
+ assert g.graph.number_of_nodes() >= 2 + assert g.graph.number_of_edges() >= 1 + + +@pytest.mark.parametrize("name", MOTIF_FAMILY_NAMES) +def test_pinned_family_graphs_are_valid_across_seeds(name: str) -> None: + for seed in range(10): + g = sample_hidden_graph(seed=seed, motif_family_name=name) + assert g.graph.number_of_nodes() >= 2 + + +# --------------------------------------------------------------------------- +# Exports smoke tests +# --------------------------------------------------------------------------- + + +def test_to_json_is_parseable() -> None: + import json + + g = sample_hidden_graph(seed=1) + data = json.loads(g.to_json()) + assert "nodes" in data + assert "edges" in data + + +def test_to_graphml_contains_graph_tag() -> None: + g = sample_hidden_graph(seed=1) + assert " Date: Wed, 22 Apr 2026 19:53:26 +0300 Subject: [PATCH 2/5] fix: address Copilot review comments on PR #8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI: - node_types.py: replace str+Enum mixin with StrEnum (UP042, Python 3.11+) graph.py: - Guard reserved node attribute keys (node_type, label) in NodeSpec.metadata at WorldGraph construction; raise GraphValidationError on collision - Guard reserved edge attribute key (weight) in EdgeSpec.metadata; callers must use EdgeSpec.weight field directly - Add edge weight range validation ([-1, 1]) at construction time - Include node/edge metadata in to_dict() / to_json() output sampler.py: - Fix module docstring: "pinned by name" not "deterministically from the recipe" - Add seed validation: reject bool and negative ints (consistent with RNGRoot) - Narrow retry exception catch from broad Exception to GraphValidationError so programmer errors (TypeError, KeyError) surface immediately rewiring.py: - Remove unimplemented "swapping one optional node for an alternate proxy" bullet from the permitted-variability docstring tests: - test_graph.py: rename test_unreachable_outcome_raises → 
test_outcome_reachable_from_different_root_passes; add reserved-key and weight-range validation tests - test_sampler.py: add edge weight comparison to test_same_seed_same_graph; add test_bool_seed_raises and test_negative_seed_raises Resolves COPILOT-1 through COPILOT-10 (COPILOT-3 resolved as irrelevant). Co-Authored-By: Claude Sonnet 4.6 --- leadforge/structure/graph.py | 26 +++++++++++++++++-- leadforge/structure/node_types.py | 8 +++--- leadforge/structure/rewiring.py | 1 - leadforge/structure/sampler.py | 13 ++++++---- tests/structure/test_graph.py | 43 ++++++++++++++++++++++++++++--- tests/structure/test_sampler.py | 14 ++++++++++ 6 files changed, 89 insertions(+), 16 deletions(-) diff --git a/leadforge/structure/graph.py b/leadforge/structure/graph.py index f433783..49b4de7 100644 --- a/leadforge/structure/graph.py +++ b/leadforge/structure/graph.py @@ -84,12 +84,21 @@ def __init__( self._motif_family = motif_family self._graph: nx.DiGraph = nx.DiGraph() + # Reserved node attribute keys — metadata must not override these. 
+ _reserved_node_keys = frozenset({"node_type", "label"}) + # Add nodes seen_ids: set[str] = set() for n in nodes: if n.node_id in seen_ids: raise GraphValidationError(f"Duplicate node_id: {n.node_id!r}") seen_ids.add(n.node_id) + reserved_clash = _reserved_node_keys & n.metadata.keys() + if reserved_clash: + raise GraphValidationError( + f"Node {n.node_id!r} metadata contains reserved key(s): " + f"{sorted(reserved_clash)}" + ) self._graph.add_node( n.node_id, node_type=n.node_type.value, @@ -103,6 +112,15 @@ def __init__( raise GraphValidationError(f"Edge source {e.source!r} not in node set") if e.target not in seen_ids: raise GraphValidationError(f"Edge target {e.target!r} not in node set") + if "weight" in e.metadata: + raise GraphValidationError( + f"Edge {e.source!r}→{e.target!r} metadata contains reserved key 'weight'; " + f"use the EdgeSpec.weight field instead" + ) + if not (-1.0 <= e.weight <= 1.0): + raise GraphValidationError( + f"Edge {e.source!r}→{e.target!r} weight {e.weight} is outside [-1, 1]" + ) self._graph.add_edge( e.source, e.target, @@ -140,19 +158,23 @@ def topological_order(self) -> list[str]: def to_dict(self) -> dict[str, Any]: """Return a JSON-serialisable dict representation.""" + _reserved_node = {"node_type", "label"} nodes = [ { "node_id": n, - "node_type": self._graph.nodes[n]["node_type"], - "label": self._graph.nodes[n].get("label", ""), + "node_type": attrs["node_type"], + "label": attrs.get("label", ""), + "metadata": {k: v for k, v in attrs.items() if k not in _reserved_node}, } for n in self.topological_order() + for attrs in (self._graph.nodes[n],) ] edges = [ { "source": u, "target": v, "weight": data.get("weight", 1.0), + "metadata": {k: v for k, v in data.items() if k != "weight"}, } for u, v, data in self._graph.edges(data=True) ] diff --git a/leadforge/structure/node_types.py b/leadforge/structure/node_types.py index 7078711..c5b218d 100644 --- a/leadforge/structure/node_types.py +++ b/leadforge/structure/node_types.py 
@@ -6,15 +6,15 @@ from __future__ import annotations -from enum import Enum +from enum import StrEnum -class NodeType(str, Enum): +class NodeType(StrEnum): """Semantic category of a hidden-graph node. Values mirror the nine categories specified in §11.1 of the - architecture spec. Using ``str`` as a mixin makes serialisation - (JSON, GraphML) straightforward without extra conversion. + architecture spec. ``StrEnum`` makes serialisation (JSON, GraphML) + straightforward without extra conversion. """ GLOBAL_CONTEXT = "global_context" diff --git a/leadforge/structure/rewiring.py b/leadforge/structure/rewiring.py index 3e52dc7..cf8b75d 100644 --- a/leadforge/structure/rewiring.py +++ b/leadforge/structure/rewiring.py @@ -10,7 +10,6 @@ - dropping optional mediator nodes (and their incident edges) - perturbing edge weights within a bounded range - adding an optional latent confounder node -- swapping one optional node for an alternate proxy Forbidden variability (hard constraints enforced here): - chronologically impossible edges (validated downstream in WorldGraph) diff --git a/leadforge/structure/sampler.py b/leadforge/structure/sampler.py index 414a1ec..210c13b 100644 --- a/leadforge/structure/sampler.py +++ b/leadforge/structure/sampler.py @@ -1,16 +1,16 @@ """World graph sampler — draw a concrete hidden world from a motif + seed. :func:`sample_hidden_graph` is the single entry point consumed by the -simulation layer. It selects a motif family (deterministically from the -recipe, or randomly from the seed), applies stochastic rewiring, and -returns a validated :class:`~leadforge.structure.graph.WorldGraph`. +simulation layer. It selects a motif family (pinned by name or chosen +at random from the seed), applies stochastic rewiring, and returns a +validated :class:`~leadforge.structure.graph.WorldGraph`. 
""" from __future__ import annotations import numpy as np -from leadforge.structure.graph import WorldGraph +from leadforge.structure.graph import GraphValidationError, WorldGraph from leadforge.structure.motifs import ( ALL_MOTIF_FAMILIES, MotifFamily, @@ -43,11 +43,14 @@ def sample_hidden_graph( A validated :class:`~leadforge.structure.graph.WorldGraph`. Raises: + ValueError: If *seed* is a ``bool`` or a negative integer. KeyError: If *motif_family_name* is not a known motif family name. RuntimeError: If :data:`_MAX_ATTEMPTS` rewiring attempts all produce graphs that fail structural validation (should not happen in practice with well-formed motifs). """ + if isinstance(seed, bool) or not isinstance(seed, int) or seed < 0: + raise ValueError(f"seed must be a non-negative int, got {seed!r}") rng = np.random.default_rng(seed) motif = _select_motif(motif_family_name, rng) @@ -61,7 +64,7 @@ def sample_hidden_graph( nodes, edges = rewire(motif, attempt_rng) try: return WorldGraph(nodes=nodes, edges=edges, motif_family=motif.name) - except Exception as exc: # noqa: BLE001 + except GraphValidationError as exc: last_exc = exc continue diff --git a/tests/structure/test_graph.py b/tests/structure/test_graph.py index a2c00bd..6752b05 100644 --- a/tests/structure/test_graph.py +++ b/tests/structure/test_graph.py @@ -146,7 +146,8 @@ def test_no_outcome_node_raises() -> None: WorldGraph(nodes=nodes, edges=edges, motif_family="test") -def test_unreachable_outcome_raises() -> None: +def test_outcome_reachable_from_different_root_passes() -> None: + # 'out' is reachable from 'root2', even though 'lead' has no path to it. 
nodes = [ NodeSpec("root", NodeType.ACCOUNT_LATENT), NodeSpec("lead", NodeType.LEAD_STATE), @@ -155,10 +156,8 @@ def test_unreachable_outcome_raises() -> None: ] edges = [ EdgeSpec("root", "lead"), - EdgeSpec("root2", "out"), # out reachable from root2 only - # lead has no path to out + EdgeSpec("root2", "out"), ] - # out IS reachable from root2 so this should pass g = WorldGraph(nodes=nodes, edges=edges, motif_family="test") assert g.graph.number_of_nodes() == 4 @@ -168,6 +167,42 @@ def test_unreachable_outcome_raises() -> None: # --------------------------------------------------------------------------- +def test_reserved_node_metadata_key_raises() -> None: + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT, metadata={"node_type": "bad"}), + NodeSpec("lead", NodeType.LEAD_STATE), + NodeSpec("out", NodeType.OUTCOME), + ] + edges = [EdgeSpec("root", "lead"), EdgeSpec("lead", "out")] + with pytest.raises(GraphValidationError, match="reserved key"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +def test_reserved_edge_weight_key_raises() -> None: + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT), + NodeSpec("lead", NodeType.LEAD_STATE), + NodeSpec("out", NodeType.OUTCOME), + ] + edges = [ + EdgeSpec("root", "lead", metadata={"weight": 0.5}), + EdgeSpec("lead", "out"), + ] + with pytest.raises(GraphValidationError, match="reserved key 'weight'"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + +def test_edge_weight_out_of_range_raises() -> None: + nodes = [ + NodeSpec("root", NodeType.ACCOUNT_LATENT), + NodeSpec("lead", NodeType.LEAD_STATE), + NodeSpec("out", NodeType.OUTCOME), + ] + edges = [EdgeSpec("root", "lead", weight=1.5), EdgeSpec("lead", "out")] + with pytest.raises(GraphValidationError, match="outside \\[-1, 1\\]"): + WorldGraph(nodes=nodes, edges=edges, motif_family="test") + + def test_duplicate_node_id_raises() -> None: nodes = [ NodeSpec("a", NodeType.ACCOUNT_LATENT), diff --git 
a/tests/structure/test_sampler.py b/tests/structure/test_sampler.py index 3e16511..ffc32da 100644 --- a/tests/structure/test_sampler.py +++ b/tests/structure/test_sampler.py @@ -41,6 +41,10 @@ def test_same_seed_same_graph() -> None: assert g1.motif_family == g2.motif_family assert sorted(g1.graph.nodes) == sorted(g2.graph.nodes) assert sorted(g1.graph.edges) == sorted(g2.graph.edges) + # Edge weights must also be identical — catches regressions in weight jitter. + weights1 = {(u, v): d["weight"] for u, v, d in g1.graph.edges(data=True)} + weights2 = {(u, v): d["weight"] for u, v, d in g2.graph.edges(data=True)} + assert weights1 == weights2 def test_different_seeds_can_differ() -> None: @@ -66,6 +70,16 @@ def test_unknown_motif_family_raises() -> None: sample_hidden_graph(seed=0, motif_family_name="bad_family") +def test_bool_seed_raises() -> None: + with pytest.raises(ValueError, match="non-negative int"): + sample_hidden_graph(seed=True) # type: ignore[arg-type] + + +def test_negative_seed_raises() -> None: + with pytest.raises(ValueError, match="non-negative int"): + sample_hidden_graph(seed=-1) + + # --------------------------------------------------------------------------- # Graph properties across many seeds # --------------------------------------------------------------------------- From 3cff501830b3ddd27a351211ded61a231f5abebc Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 22 Apr 2026 21:42:58 +0300 Subject: [PATCH 3/5] fix: address Copilot round-2 review comments on PR #8 graph.py: - Fix NodeSpec.metadata docstring: clarify that non-primitive values are JSON-encoded in GraphML export rather than "Serialised as JSON" (which implied it was already done) - Add _make_graphml_safe() helper: JSON-encodes non-primitive attribute values (dict/list) into a '_json' string attribute so that networkx.generate_graphml() never raises TypeError - Rewrite to_graphml() to build an exportable copy of the graph through _make_graphml_safe() before calling 
generate_graphml rewiring.py: - Replace copy.copy() with explicit NodeSpec/EdgeSpec construction using deepcopy(metadata); shallow copy aliased the canonical motif's metadata dicts so post-rewiring mutations would have corrupted the frozen motif Six E501 complaints from Copilot are false-positives: CI Lint & format already passes (SUCCESS) on this commit; those threads resolved as stale. Co-Authored-By: Claude Sonnet 4.6 --- leadforge/structure/graph.py | 39 +++++++++++++++++++++++++++++---- leadforge/structure/rewiring.py | 12 +++++++--- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/leadforge/structure/graph.py b/leadforge/structure/graph.py index 49b4de7..0f9a067 100644 --- a/leadforge/structure/graph.py +++ b/leadforge/structure/graph.py @@ -30,7 +30,9 @@ class NodeSpec: node_type: Semantic category of the node. label: Human-readable name used in exports. metadata: Arbitrary extra attributes (e.g. prior strength, proxy - accuracy). Serialised as JSON in GraphML export. + accuracy). Stored as raw node attributes; primitive values + are emitted directly in GraphML, non-primitive values are + JSON-encoded under a ``_json`` suffix key. """ node_id: str @@ -58,6 +60,24 @@ class EdgeSpec: metadata: dict[str, Any] = field(default_factory=dict) +def _make_graphml_safe(attrs: dict[str, Any]) -> dict[str, Any]: + """Return a copy of *attrs* where non-primitive values are JSON-encoded. + + GraphML only supports string, int, float, and bool attribute values. + Any dict or list value is serialised to a JSON string stored under a + key with a ``_json`` suffix so that ``networkx.generate_graphml`` + does not raise ``TypeError``. + """ + _primitive = (str, int, float, bool) + result: dict[str, Any] = {} + for k, v in attrs.items(): + if isinstance(v, _primitive): # noqa: UP038 + result[k] = v + else: + result[f"{k}_json"] = json.dumps(v) + return result + + class WorldGraph: """Validated directed acyclic graph representing one hidden world. 
@@ -189,9 +209,20 @@ def to_json(self) -> str: return json.dumps(self.to_dict(), indent=2) def to_graphml(self) -> str: - """Return a GraphML string representation.""" - lines = nx.generate_graphml(self._graph) - return "\n".join(lines) + """Return a GraphML string representation. + + Non-primitive node/edge attribute values (dicts, lists, etc.) are + JSON-encoded into a string attribute with a ``_json`` suffix so that + NetworkX's GraphML writer does not raise ``TypeError``. + """ + exportable = nx.DiGraph() + for node_id, attrs in self._graph.nodes(data=True): + safe = _make_graphml_safe(attrs) + exportable.add_node(node_id, **safe) + for u, v, attrs in self._graph.edges(data=True): + safe = _make_graphml_safe(attrs) + exportable.add_edge(u, v, **safe) + return "\n".join(nx.generate_graphml(exportable)) # ------------------------------------------------------------------ # Validation diff --git a/leadforge/structure/rewiring.py b/leadforge/structure/rewiring.py index cf8b75d..c17979e 100644 --- a/leadforge/structure/rewiring.py +++ b/leadforge/structure/rewiring.py @@ -19,7 +19,7 @@ from __future__ import annotations -import copy +from copy import deepcopy from typing import TYPE_CHECKING from leadforge.structure.graph import EdgeSpec, NodeSpec @@ -65,8 +65,14 @@ def rewire( A ``(nodes, edges)`` tuple suitable for passing to :class:`~leadforge.structure.graph.WorldGraph`. """ - nodes: list[NodeSpec] = [copy.copy(n) for n in motif.canonical_nodes] - edges: list[EdgeSpec] = [copy.copy(e) for e in motif.canonical_edges] + # deepcopy metadata dicts so mutations never alias the canonical motif specs. 
+ nodes: list[NodeSpec] = [ + NodeSpec(n.node_id, n.node_type, n.label, deepcopy(n.metadata)) + for n in motif.canonical_nodes + ] + edges: list[EdgeSpec] = [ + EdgeSpec(e.source, e.target, e.weight, deepcopy(e.metadata)) for e in motif.canonical_edges + ] # Step 1 — drop optional nodes dropped: set[str] = set() From a252ced70f1e7312e052b48b75b3c4c0a44b008a Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Wed, 22 Apr 2026 23:45:56 +0300 Subject: [PATCH 4/5] fix: address Copilot round-3 review on PR #8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit graph.py: - Freeze NodeSpec and EdgeSpec dataclasses so field reassignment on canonical motif specs raises FrozenInstanceError; the metadata dicts themselves remain mutable (freezing a dataclass only blocks field reassignment, not mutation of the objects its fields hold), but rewiring already deepcopies them, so motif integrity is preserved end-to-end - Fix _make_graphml_safe docstring: clarify that ALL non-primitive values (None, tuples, enums, dicts, lists, …) are JSON-encoded, not just dict/list as the previous text implied COPILOT-1 (RNGRoot integration in sampler): deferred to issue #9 — the right substream name and API shape depend on how Generator calls the sampler, which is defined in Milestone 5. COPILOT-4 (E501 on EdgeSpec comprehension): resolved as false-positive — CI Lint & format passes SUCCESS on this commit. FAIL-1 (startup_failure on pr-agent-context-refresh): expected bot-triggered approval-gate behaviour; no code change required. 
Co-Authored-By: Claude Sonnet 4.6 --- leadforge/structure/graph.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/leadforge/structure/graph.py b/leadforge/structure/graph.py index 0f9a067..581be4b 100644 --- a/leadforge/structure/graph.py +++ b/leadforge/structure/graph.py @@ -21,7 +21,7 @@ class GraphValidationError(LeadforgeError): """Raised when a hidden world graph violates a structural invariant.""" -@dataclass +@dataclass(frozen=True) class NodeSpec: """Specification for a single hidden-graph node. @@ -41,7 +41,7 @@ class NodeSpec: metadata: dict[str, Any] = field(default_factory=dict) -@dataclass +@dataclass(frozen=True) class EdgeSpec: """Specification for a directed edge between two hidden-graph nodes. @@ -64,9 +64,10 @@ def _make_graphml_safe(attrs: dict[str, Any]) -> dict[str, Any]: """Return a copy of *attrs* where non-primitive values are JSON-encoded. GraphML only supports string, int, float, and bool attribute values. - Any dict or list value is serialised to a JSON string stored under a - key with a ``_json`` suffix so that ``networkx.generate_graphml`` - does not raise ``TypeError``. + Any value that is not one of those primitives (including ``None``, + tuples, enums, dicts, lists, etc.) is serialised to a JSON string + stored under a key with a ``_json`` suffix so that + ``networkx.generate_graphml`` does not raise ``TypeError``. 
""" _primitive = (str, int, float, bool) result: dict[str, Any] = {} From b5623d2779387987a430cf1a5c8471cdf9365878 Mon Sep 17 00:00:00 2001 From: Shay Palachy Date: Thu, 23 Apr 2026 10:17:21 +0300 Subject: [PATCH 5/5] fix: address Copilot round-4 review on PR #8 - NodeSpec/EdgeSpec: wrap metadata in MappingProxyType via __post_init__ so canonical motif skeletons are truly immutable (COPILOT-2a/2b) - _make_graphml_safe: add str() fallback for non-JSON-serialisable values and use collision-safe suffix-key generation (COPILOT-3) - sample_hidden_graph: derive NumPy seed from RNGRoot(seed).child("hidden_graph") to align with repo RNG convention (COPILOT-1) - rewiring: replace deepcopy(metadata) with dict() since MappingProxyType is already immutable and deepcopy cannot pickle it Co-Authored-By: Claude Sonnet 4.6 --- leadforge/structure/graph.py | 36 +++++++++++++++++++++++++-------- leadforge/structure/rewiring.py | 11 +++++----- leadforge/structure/sampler.py | 12 +++++++---- 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/leadforge/structure/graph.py b/leadforge/structure/graph.py index 581be4b..cdcd5bf 100644 --- a/leadforge/structure/graph.py +++ b/leadforge/structure/graph.py @@ -8,7 +8,9 @@ from __future__ import annotations import json +from collections.abc import Mapping from dataclasses import dataclass, field +from types import MappingProxyType from typing import Any import networkx as nx @@ -32,13 +34,17 @@ class NodeSpec: metadata: Arbitrary extra attributes (e.g. prior strength, proxy accuracy). Stored as raw node attributes; primitive values are emitted directly in GraphML, non-primitive values are - JSON-encoded under a ``_json`` suffix key. + serialised under a ``_json`` suffix key. The mapping is + immutable after construction to protect canonical motif specs. 
""" node_id: str node_type: NodeType label: str = "" - metadata: dict[str, Any] = field(default_factory=dict) + metadata: Mapping[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + object.__setattr__(self, "metadata", MappingProxyType(dict(self.metadata))) @dataclass(frozen=True) @@ -52,30 +58,44 @@ class EdgeSpec: values indicate facilitation; negative values indicate inhibition. metadata: Arbitrary extra attributes (e.g. mechanism type, lag). + The mapping is immutable after construction to protect + canonical motif specs. """ source: str target: str weight: float = 1.0 - metadata: dict[str, Any] = field(default_factory=dict) + metadata: Mapping[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + object.__setattr__(self, "metadata", MappingProxyType(dict(self.metadata))) def _make_graphml_safe(attrs: dict[str, Any]) -> dict[str, Any]: - """Return a copy of *attrs* where non-primitive values are JSON-encoded. + """Return a copy of *attrs* where non-primitive values are serialised to strings. GraphML only supports string, int, float, and bool attribute values. Any value that is not one of those primitives (including ``None``, - tuples, enums, dicts, lists, etc.) is serialised to a JSON string - stored under a key with a ``_json`` suffix so that - ``networkx.generate_graphml`` does not raise ``TypeError``. + tuples, enums, dicts, lists, etc.) is first attempted as a JSON + string; values that are not JSON-serialisable fall back to ``str()``. + The encoded value is stored under a ``_json`` suffix key. The suffix + key is made unique (by appending further ``_json`` segments) to avoid + collisions with other keys already present in *attrs*. 
""" _primitive = (str, int, float, bool) + _all_input_keys = set(attrs.keys()) result: dict[str, Any] = {} for k, v in attrs.items(): if isinstance(v, _primitive): # noqa: UP038 result[k] = v else: - result[f"{k}_json"] = json.dumps(v) + suffix_key = f"{k}_json" + while suffix_key in result or suffix_key in _all_input_keys: + suffix_key = f"{suffix_key}_json" + try: + result[suffix_key] = json.dumps(v) + except (TypeError, ValueError): + result[suffix_key] = str(v) return result diff --git a/leadforge/structure/rewiring.py b/leadforge/structure/rewiring.py index c17979e..be29bc3 100644 --- a/leadforge/structure/rewiring.py +++ b/leadforge/structure/rewiring.py @@ -19,7 +19,6 @@ from __future__ import annotations -from copy import deepcopy from typing import TYPE_CHECKING from leadforge.structure.graph import EdgeSpec, NodeSpec @@ -65,13 +64,13 @@ def rewire( A ``(nodes, edges)`` tuple suitable for passing to :class:`~leadforge.structure.graph.WorldGraph`. """ - # deepcopy metadata dicts so mutations never alias the canonical motif specs. + # metadata is already immutable (MappingProxyType); a plain dict() copy is + # sufficient — NodeSpec/EdgeSpec will re-wrap it in a new proxy. 
nodes: list[NodeSpec] = [ - NodeSpec(n.node_id, n.node_type, n.label, deepcopy(n.metadata)) - for n in motif.canonical_nodes + NodeSpec(n.node_id, n.node_type, n.label, dict(n.metadata)) for n in motif.canonical_nodes ] edges: list[EdgeSpec] = [ - EdgeSpec(e.source, e.target, e.weight, deepcopy(e.metadata)) for e in motif.canonical_edges + EdgeSpec(e.source, e.target, e.weight, dict(e.metadata)) for e in motif.canonical_edges ] # Step 1 — drop optional nodes @@ -96,7 +95,7 @@ def rewire( source=e.source, target=e.target, weight=new_weight, - metadata=dict(e.metadata), + metadata=dict(e.metadata), # MappingProxyType → mutable dict; EdgeSpec re-wraps ) ) edges = perturbed_edges diff --git a/leadforge/structure/sampler.py b/leadforge/structure/sampler.py index 210c13b..3aee3c5 100644 --- a/leadforge/structure/sampler.py +++ b/leadforge/structure/sampler.py @@ -10,6 +10,7 @@ import numpy as np +from leadforge.core.rng import RNGRoot from leadforge.structure.graph import GraphValidationError, WorldGraph from leadforge.structure.motifs import ( ALL_MOTIF_FAMILIES, @@ -31,9 +32,11 @@ def sample_hidden_graph( The function is fully deterministic given ``(seed, motif_family_name)``. Args: - seed: Integer seed for the NumPy random generator. All stochastic - choices (motif selection if *motif_family_name* is ``None``, - rewiring decisions, weight jitter) derive from this seed. + seed: Integer seed passed to :class:`~leadforge.core.rng.RNGRoot`. + All stochastic choices (motif selection if *motif_family_name* + is ``None``, rewiring decisions, weight jitter) derive from a + named child stream of this root so the sampler integrates with + the repo's RNG convention. motif_family_name: If provided, pin the motif family by name (must be one of :data:`~leadforge.structure.motifs.MOTIF_FAMILY_NAMES`). 
If ``None``, a family is chosen uniformly at random from the @@ -51,7 +54,8 @@ def sample_hidden_graph( """ if isinstance(seed, bool) or not isinstance(seed, int) or seed < 0: raise ValueError(f"seed must be a non-negative int, got {seed!r}") - rng = np.random.default_rng(seed) + np_seed = RNGRoot(seed).child("hidden_graph").getrandbits(64) + rng = np.random.default_rng(np_seed) motif = _select_motif(motif_family_name, rng)