diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index aacf4ed61..9b6c659f3 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -365,6 +365,17 @@ def cascade(cls, table_expr, part_integrity="enforce"): # Propagate downstream result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity) + # part_integrity="cascade" may pull in nodes that aren't descendants of + # the seed (e.g. the master of a seed Part, plus the master's other + # Parts). Expand nodes_to_show to include any restricted node and the + # descendants of any newly-restricted ancestor. See #1429. + restricted_nodes = set(result._cascade_restrictions) + expanded = set(result.nodes_to_show) | restricted_nodes + for n in restricted_nodes - result.nodes_to_show: + expanded.update(nx.descendants(result, n)) + result.nodes_to_show = expanded & set(result.nodes()) + result._expanded_nodes = set(result.nodes_to_show) + # Trim graph to cascade subgraph: only restricted tables # (seed + descendants) plus alias nodes connecting them. keep = set(result._cascade_restrictions) @@ -443,7 +454,6 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): propagation rules at each edge. Only processes descendants of start_node to avoid duplicate propagation when chaining. """ - from .table import FreeTable sorted_nodes = topo_sort(self) # Only propagate through descendants of start_node @@ -453,6 +463,18 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): restrictions = self._cascade_restrictions if mode == "cascade" else self._restrict_conditions + # Seed-is-Part case: when the seed itself is a Part and part_integrity="cascade", + # the main loop's part_integrity block (which fires inside `out_edges`) + # cannot trigger from the seed because a leaf Part has no out-edges. + # Trigger the upward propagation explicitly for the seed. See #1429. 
+ if part_integrity == "cascade" and mode == "cascade": + seed_master = extract_master(start_node) + if seed_master and seed_master in self.nodes() and seed_master not in visited_masters: + visited_masters.add(seed_master) + if self._propagate_part_to_master(start_node, seed_master, mode, restrictions): + allowed_nodes.add(seed_master) + allowed_nodes.update(nx.descendants(self, seed_master)) + # Multiple passes to handle part_integrity="cascade" upward propagation. # When a part table triggers its master to join the cascade, the master's # other descendants need processing in a subsequent pass. The loop @@ -512,29 +534,19 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"): any_new = True # part_integrity="cascade": propagate up from part to master + # via the actual FK graph path, applying upward propagation + # rules at each edge. Handles Part-of-Part chains and + # renamed FKs (via .proj()), unlike the prior implementation + # which assumed shared PK attribute names. See #1429. 
if part_integrity == "cascade" and mode == "cascade": master_name = extract_master(target) - if ( - master_name - and master_name in self.nodes() - and master_name not in restrictions - and master_name not in visited_masters - ): + if master_name and master_name in self.nodes() and master_name not in visited_masters: visited_masters.add(master_name) - child_ft = self._restricted_table(target) - master_ft = FreeTable(self._connection, master_name) - from .condition import make_condition - - master_restr = make_condition( - master_ft, - (master_ft.proj() & child_ft.proj()).to_arrays(), - master_ft.restriction_attributes, - ) - restrictions[master_name] = [master_restr] - self._restriction_attrs[master_name] = set() - allowed_nodes.add(master_name) - allowed_nodes.update(nx.descendants(self, master_name)) - any_new = True + propagated = self._propagate_part_to_master(target, master_name, mode, restrictions) + if propagated: + allowed_nodes.add(master_name) + allowed_nodes.update(nx.descendants(self, master_name)) + any_new = True def _apply_propagation_rule( self, @@ -590,6 +602,178 @@ def _apply_propagation_rule( self._restriction_attrs.setdefault(child_node, set()).update(child_attrs) + def _apply_propagation_rule_upward(self, child_ft, child_attrs, parent_node, attr_map, aliased, mode, restrictions): + """ + Apply the symmetric (upward) propagation rule to a parent←child edge. + + Inverts `_apply_propagation_rule`: derives a restriction on the parent + from a restriction on the child, following the FK chain in reverse. + Used by part_integrity="cascade" to propagate a Part's restriction up + to its Master, transparently handling renamed FKs (via .proj()) and + Part-of-Part chains. See #1429. + + Edge metadata convention (matches `_apply_propagation_rule`): + - `attr_map`: dict mapping child column → parent (referenced) column. + - `aliased`: True iff any column was renamed across the FK. 
+ + Rules (symmetric to the forward rules in `_apply_propagation_rule`): + + 1. Non-aliased AND child restriction attrs non-empty and ⊆ parent PK: + Copy child restriction directly (attrs are shared by name). + 2. Aliased FK (attr_map renames columns): + ``child.proj(**{parent: child for child, parent in attr_map.items()})`` + — reverses the renaming so the result has parent's column names. + 3. Non-aliased AND child restriction attrs empty or ⊄ parent PK: + ``child.proj()`` — project child to parent's PK columns. + """ + parent_pk = self.nodes[parent_node].get("primary_key", set()) + + if not aliased and child_attrs and child_attrs <= parent_pk: + # Backward Rule 1: copy child restriction directly + child_restr = restrictions.get( + child_ft.full_table_name, + [] if mode == "cascade" else AndList(), + ) + if mode == "cascade": + restrictions.setdefault(parent_node, []).extend(child_restr) + else: + restrictions.setdefault(parent_node, AndList()).extend(child_restr) + parent_attrs = set(child_attrs) + elif aliased: + # Backward Rule 2: reverse rename + parent_item = child_ft.proj(**{pk: fk for fk, pk in attr_map.items()}) + if mode == "cascade": + restrictions.setdefault(parent_node, []).append(parent_item) + else: + restrictions.setdefault(parent_node, AndList()).append(parent_item) + parent_attrs = set(attr_map.values()) # parent's PK column names + else: + # Backward Rule 3: project child to parent PK + parent_item = child_ft.proj() + if mode == "cascade": + restrictions.setdefault(parent_node, []).append(parent_item) + else: + restrictions.setdefault(parent_node, AndList()).append(parent_item) + parent_attrs = set(attr_map.values()) + + self._restriction_attrs.setdefault(parent_node, set()).update(parent_attrs) + + def _propagate_part_to_master(self, part_node, master_name, mode, restrictions): + """ + Walk the FK graph from `part_node` up to `master_name`, applying + `_apply_propagation_rule_upward` at each real edge along the path. + + Returns True if any propagation occurred. 
Handles Part-of-Part chains + by walking the full path (intermediate Parts get restricted too) and + renamed FKs via the upward rules. + + Alias nodes (integer-named graph nodes inserted for aliased edges) + are transparent — both half-edges carry the same `attr_map` props, + so we read props from one and skip the alias node when walking. + + After the walk, the master's restriction is **materialized** to a + literal value tuple via ``to_arrays()`` (``mode="cascade"`` only). Without materialization, a + subsequent forward cascade from the master back down to its parts + would produce a self-referential subquery (MySQL error 1093, since + the master's restriction depends on the same Part being deleted). + Materializing converts the restriction into a static value set, so + the forward cascade generates ``WHERE ... IN (literal-list)`` rather + than ``WHERE ... IN (SELECT ... FROM part_table)``. + + Limitations + ----------- + - **Single FK path**: ``nx.shortest_path`` returns *one* path from + ``master_name`` to ``part_node``. If a Part is reachable from its + Master through multiple distinct FK chains (e.g. references two + different intermediate Parts), restrictions through the + non-shortest paths are not applied. This pattern is unusual; if a + schema hits it, the user is responsible for restricting the + additional paths explicitly via ``part_integrity="ignore"`` plus + manual ``delete()`` calls. + - **Memory cost of materialization**: ``master_ft.proj().to_arrays()`` + pulls the matching master primary keys into Python memory. Cost is + bounded by the count of *distinct* master rows referenced by the + matching parts — typically small for surgical cascades, but can + grow with bulk cascades on tables with many master rows. Cascade + *preview* (``Diagram.cascade(...).counts()``) pays the same cost. + """ + try: + path = nx.shortest_path(self, master_name, part_node) + except (nx.NetworkXNoPath, nx.NodeNotFound): + return False + + # Strip alias nodes; what remains is the sequence of real tables. 
+ real_path = [n for n in path if not (isinstance(n, str) and n.isdigit())] + if len(real_path) < 2 or real_path[-1] != part_node or real_path[0] != master_name: + return False + + # Walk real_path in reverse (child → parent direction). For each + # adjacent (parent, child) pair, look up the edge props — direct + # edge if non-aliased, via alias node if aliased. + any_propagated = False + for i in range(len(real_path) - 1, 0, -1): + child = real_path[i] + parent = real_path[i - 1] + edge_props = self._find_real_edge_props(parent, child) + if edge_props is None: + return any_propagated # Path broken (shouldn't happen if shortest_path succeeded) + + attr_map = edge_props.get("attr_map", {}) + aliased = edge_props.get("aliased", False) + child_ft = self._restricted_table(child) + child_attrs = self._restriction_attrs.get(child, set()) + + self._apply_propagation_rule_upward( + child_ft, + child_attrs, + parent, + attr_map, + aliased, + mode, + restrictions, + ) + any_propagated = True + + # Materialize the master's restriction so subsequent forward cascade + # doesn't produce self-referential subqueries. Replace the master's + # accumulated query restrictions with a literal value tuple. + if any_propagated and master_name in restrictions: + from .condition import make_condition + from .table import FreeTable + + master_ft = self._restricted_table(master_name) + master_pk_values = master_ft.proj().to_arrays() + if mode == "cascade": + bare_master = FreeTable(self._connection, master_name) + if len(master_pk_values) > 0: + materialized = make_condition( + bare_master, + master_pk_values, + bare_master.restriction_attributes, + ) + restrictions[master_name] = [materialized] + else: + # No matching master rows — false restriction so master is + # included with zero matches in counts/iter. 
+ restrictions[master_name] = [False] + self._restriction_attrs.setdefault(master_name, set()) + + return any_propagated + + def _find_real_edge_props(self, parent, child): + """ + Return edge props for parent → child, transparently traversing the + integer-named alias node that the graph inserts for aliased FKs. + Returns None if no such edge or alias-mediated edge exists. + """ + if self.has_edge(parent, child): + return self.edges[parent, child] + for _, mid, _ in self.out_edges(parent, data=True): + if isinstance(mid, str) and mid.isdigit() and self.has_edge(mid, child): + # Both half-edges carry the same attr_map / aliased props + return self.edges[parent, mid] + return None + def counts(self): """ Return affected row counts per table without modifying data. diff --git a/tests/integration/test_cascade_delete.py b/tests/integration/test_cascade_delete.py index 3bc3dc73b..607669124 100644 --- a/tests/integration/test_cascade_delete.py +++ b/tests/integration/test_cascade_delete.py @@ -292,3 +292,190 @@ class Child(dj.Manual): connection_by_backend.query(f"DROP DATABASE IF EXISTS {qi(name)}") except Exception: pass + + +# ========================================================================= +# Issue #1429: cascade with part_integrity="cascade" must traverse the FK +# chain through intermediate Parts (and renamed FKs), not assume that the +# Part shares PK attribute names with its Master. +# ========================================================================= + + +def test_cascade_part_of_part_no_master_reference(schema_by_backend): + """ + Case 2 from #1429: PartB references PartA directly (no -> Master). + Restricting PartB with part_integrity="cascade" must restrict both + PartA and Master (PartA via the direct FK, Master via the master-part + FK chained through PartA). 
+ """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA + part_b_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (1, 11), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (1, 10, 101), (1, 11, 110), (2, 20, 200)]) + + # Cascade preview: deleting one PartB row must propagate up to PartA and Master. + counts = dj.Diagram.cascade( + Master.PartB & {"master_id": 1, "part_a_id": 10, "part_b_id": 100}, + part_integrity="cascade", + ).counts() + + # Master row (1,) is the originating Part's master — must appear with count 1 + assert counts.get(Master.full_table_name, 0) == 1, ( + f"Master restricted by 1 row; got {counts.get(Master.full_table_name)}. " + "Indicates the Part→Master upward propagation did not reach the Master " + "through the intermediate PartA." + ) + # Master cascades back down to ALL of master_id=1's Parts + assert counts.get(Master.PartA.full_table_name, 0) == 2 # rows 10, 11 + assert counts.get(Master.PartB.full_table_name, 0) == 3 # rows under master_id=1 + + +def test_cascade_part_of_part_renamed_fk(schema_by_backend): + """ + Case 1 from #1429: PartB references PartA via a renamed FK (`.proj()`). + PartB has no attribute named `master_id` (renamed to `src_master`). The + upward propagation must use the FK metadata, not assume shared attribute + names. 
+ """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA.proj(src_master='master_id', src_part='part_a_id') + part_b_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (2, 20, 200)]) + + # PartB has columns: src_master, src_part, part_b_id — NOT master_id. + counts = dj.Diagram.cascade( + Master.PartB & {"src_master": 1, "src_part": 10, "part_b_id": 100}, + part_integrity="cascade", + ).counts() + + assert counts.get(Master.full_table_name, 0) == 1, ( + f"Master restricted by 1 row; got {counts.get(Master.full_table_name)}. " + "Renamed FK was not reversed when propagating up to Master." + ) + assert counts.get(Master.PartA.full_table_name, 0) == 1 + assert counts.get(Master.PartB.full_table_name, 0) == 1 + + +def test_cascade_three_level_part_chain(schema_by_backend): + """ + Three-hop chain (#1429 follow-up review): PartC → PartB → PartA → Master. + Verify intermediate Parts (PartA, PartB) are restricted at every hop, not + just the first, and the master cascades back down to all siblings. 
+ """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA + part_b_id : int32 + """ + + class PartC(dj.Part): + definition = """ + -> Master.PartB + part_c_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (1, 11), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (1, 11, 110), (2, 20, 200)]) + Master.PartC.insert([(1, 10, 100, 1000), (1, 11, 110, 1100), (2, 20, 200, 2000)]) + + counts = dj.Diagram.cascade( + Master.PartC & {"master_id": 1, "part_a_id": 10, "part_b_id": 100, "part_c_id": 1000}, + part_integrity="cascade", + ).counts() + + # Master pulled in via the 3-hop upward walk + assert counts.get(Master.full_table_name, 0) == 1, ( + "Master restriction lost across 3-hop chain — the per-edge upward walk " "did not reach Master through PartA + PartB." + ) + # Master forward-cascades back down to all rows under master_id=1 + assert counts.get(Master.PartA.full_table_name, 0) == 2 # both PartA rows under master 1 + assert counts.get(Master.PartB.full_table_name, 0) == 2 # both PartB rows under master 1 + assert counts.get(Master.PartC.full_table_name, 0) == 2 # both PartC rows under master 1 + + +def test_cascade_part_of_part_actual_delete(schema_by_backend): + """ + End-to-end: actually run delete() with part_integrity="cascade" through + a Part-of-Part chain. Verifies the upward propagation produces SQL that + executes (no MySQL 1093 self-reference; correct row removal). 
+ """ + + @schema_by_backend + class Master(dj.Manual): + definition = """ + master_id : int32 + """ + + class PartA(dj.Part): + definition = """ + -> master + part_a_id : int32 + """ + + class PartB(dj.Part): + definition = """ + -> Master.PartA + part_b_id : int32 + """ + + Master.insert([(1,), (2,)]) + Master.PartA.insert([(1, 10), (2, 20)]) + Master.PartB.insert([(1, 10, 100), (2, 20, 200)]) + + (Master.PartB & {"master_id": 1}).delete(part_integrity="cascade") + + # master_id=1 chain is entirely gone; master_id=2 chain intact. + assert len(Master()) == 1 + assert Master().fetch1("master_id") == 2 + assert len(Master.PartA()) == 1 + assert len(Master.PartB()) == 1