Skip to content

Add attention rollout interpretability method (Abnar & Zuidema 2020)#1158

Open
fbonc wants to merge 7 commits into
sunlabuiuc:masterfrom
fbonc:attention-rollout
Open

Add attention rollout interpretability method (Abnar & Zuidema 2020)#1158
fbonc wants to merge 7 commits into
sunlabuiuc:masterfrom
fbonc:attention-rollout

Conversation

@fbonc

@fbonc fbonc commented Jun 8, 2026

Copy link
Copy Markdown

Contributor: Felipe Amaral Bonchristiano (felipea5@illinois.edu)

Contribution Type: New interpretability method

Description:
Adds vanilla attention rollout (Abnar & Zuidema, "Quantifying Attention Flow
in Transformers," 2020, arXiv:2005.00928) as a new interpretability module,
AttentionRollout. Rollout is the canonical forward-only, gradient-free,
class-agnostic attention-flow baseline: it accounts for residual connections
(Â = 0.5·(A + I)), fuses heads by mean, and composes per-layer attention by
matrix product to produce per-token relevance. It complements the existing
CheferRelevance (gradient-weighted, class-specific) by providing the standard
baseline that gradient-based attention methods are measured against, which the
interpretability suite currently lacks.

The implementation reuses the existing attention-readout methods already on
PyHealth's attention models (set_attention_hooks, get_attention_layers,
get_relevance_tensor), so it requires no model-side changes. It is a
single new method file plus its export, tests, docs, and example registrations.

Files to Review:

  • pyhealth/interpret/methods/attention_rollout.py: core implementation (AttentionRollout)
  • pyhealth/interpret/methods/__init__.py: exports AttentionRollout
  • tests/core/test_attention_rollout.py: synthetic-data unit tests (see Testing below)
  • docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst: API documentation
  • docs/api/interpret.rst: added to the Attribution Methods toctree
  • examples/interpretability/{mp,los,dka}_{transformer,stageattn}_mimic4_interpret.py: AttentionRollout registered in the method comparison dicts alongside CheferRelevance

Quick note:
The actual bounty on the doc lists "Rollout Attention" and links arXiv:2012.09838,
which is Chefer et al., Transformer Interpretability Beyond Attention
Visualization
(CVPR 2021), a gradient/LRP relevance method, not the rollout
paper. (The existing CheferRelevance implements the related Chefer et al. ICCV
2021 method, arXiv:2103.15679.) I read the bounty's intent from its name and from
the actual gap in the suite, as there was no gradient-free, class-agnostic baseline,
and implemented canonical rollout (Abnar & Zuidema 2020) rather than more
Chefer-style work. If the literal citation was intended, happy to redirect.

Key design decisions:

  • Canonical rollout, not an enhanced variant. Default is mean head fusion +
    0.5·(A + I); alternative fusions and residual schemes are deferred to optional
    kwargs. Again, this module's value is fidelity to the baseline, not improving on it.
  • Model compatibility via duck-typing, not isinstance(CheferInterpretable).
    The three readout methods are general attention readout, not Chefer-specific;
    __init__ checks hasattr and raises TypeError naming the missing methods.
    This keeps the PR to one new file with zero edits to the shared interface.
  • target_class_idx accepted but ignored, documented as a no-op, so rollout is
    drop-in swappable with class-specific interpreters in existing pipelines.
  • _map_to_input_shapes duplicated from CheferRelevance (rather than factored
    to a shared util) so attributions match the raw-input granularity the
    comprehensiveness/sufficiency metrics expect, while keeping this PR free of edits
    to chefer.py.

Proposed follow-up: extract a general AttentionInterpretable
interface and a shared shape-mapping helper that both AttentionRollout and
CheferRelevance depend on, removing the duck-typing and the duplicated
_map_to_input_shapes. Kept separate to avoid bundling a refactor of shared code
into a feature PR.

Testing: Unit tests use small synthetic data (create_sample_dataset, tiny
config, seeded) and run in well under a second with no network or credentials.
Beyond shape and dict-key checks, they assert the two correctness invariants:
(1) per-token relevance sums to 1 before input-shape expansion (the product of
row-stochastic matrices is row-stochastic), and (2) identity attention at every
layer yields an identity rollout. Construction-time errors (incompatible model,
unsupported head_fusion) are covered.

Note on verification: I am not yet MIMIC-credentialed, so end-to-end correctness
is established via the synthetic unit tests above; AttentionRollout is registered
in the MIMIC-IV comparison scripts for parity with the other methods but I have not
run those end-to-end myself.

@jhnwu3 jhnwu3 requested a review from Copilot June 21, 2026 22:46

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot was unable to review this pull request because the user who requested the review has reached their quota limit.

@jhnwu3 jhnwu3 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this — the implementation, docs, and test suite are in good shape. The rationale in the description is clear and the correctness invariants in the tests (row-stochastic check, identity-attention test) are exactly what we want to see. A few minor things (header comments, trailing newline in attention_rollout.py, the unreachable ValueError branch in _fuse_heads) are non-blocking; feel free to clean them up or leave them.

Two things before merge:


1. One doc spot still lists only Chefer

docs/why_pyhealth.rst line 178 currently reads:

Attention-based: Chefer relevance propagation for transformers

Since rollout is the canonical gradient-free attention baseline that Chefer-style methods are measured against, please add it here — e.g.:

Attention-based: Chefer relevance propagation and attention rollout for transformers

Everything else checks out: the __init__ export, the toctree entry, the API rst page, and the six MIMIC comparison scripts are all wired up correctly.


2. A runnable benchmark example

Every existing script that calls evaluate_attribution depends on credentialed full MIMIC-IV plus pre-trained checkpoints and hardcoded /shared/eng/... paths — which is exactly why you noted you couldn't verify end-to-end faithfulness. Could you add a self-contained example under examples/interpretability/ that uses the MIMIC-IV demo dataset (the ~100-patient open-access subset on PhysioNet, freely downloadable with no credentialing: https://physionet.org/content/mimic-iv-demo/)? It should train a small Transformer from scratch and run evaluate_attribution comparing AttentionRollout, CheferRelevance, and RandomBaseline on Comprehensiveness/Sufficiency.

Here's a complete starting point you can drop in as examples/interpretability/attention_rollout_benchmark_mimic4_demo.py:

"""Benchmark AttentionRollout on the MIMIC-IV demo dataset (no credentialing).

Download once from PhysioNet (open access, no credentialing required):
  https://physionet.org/content/mimic-iv-demo/
Extract and point --ehr_root at the resulting directory.

Run:
    python examples/interpretability/attention_rollout_benchmark_mimic4_demo.py \
        --ehr_root /path/to/mimic-iv-clinical-database-demo
"""
import argparse

import torch

from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
from pyhealth.interpret.methods import AttentionRollout, CheferRelevance, RandomBaseline
from pyhealth.metrics.interpretability import evaluate_attribution
from pyhealth.models import Transformer
from pyhealth.tasks import MortalityPredictionMIMIC4
from pyhealth.trainer import Trainer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ehr_root",
        required=True,
        help="Path to the extracted mimic-iv-clinical-database-demo directory",
    )
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    torch.manual_seed(0)

    # 1. Load the open-access demo dataset (no credentialing required).
    base_dataset = MIMIC4Dataset(
        ehr_root=args.ehr_root,
        ehr_tables=[
            "patients",
            "admissions",
            "diagnoses_icd",
            "procedures_icd",
            "prescriptions",
        ],
    )

    # 2. Apply the plain-Transformer mortality task; train fresh (no checkpoints).
    sample_dataset = base_dataset.set_task(MortalityPredictionMIMIC4())
    print(f"Loaded {len(sample_dataset)} samples")

    # 3. Patient-level split and loaders.
    train_ds, val_ds, test_ds = split_by_patient(
        sample_dataset, [0.7, 0.1, 0.2], seed=42
    )
    train_loader = get_dataloader(train_ds, batch_size=16, shuffle=True)
    val_loader = get_dataloader(val_ds, batch_size=16, shuffle=False)
    test_loader = get_dataloader(test_ds, batch_size=16, shuffle=False)

    # 4. Small Transformer — already exposes the attention-readout methods
    #    both AttentionRollout and CheferRelevance rely on.
    model = Transformer(
        dataset=sample_dataset,
        embedding_dim=64,
        heads=2,
        num_layers=2,
        dropout=0.1,
    )

    # 5. Train briefly on the demo data.
    trainer = Trainer(model=model, device=args.device, metrics=["roc_auc"])
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=5,
        monitor="roc_auc",
        monitor_criterion="max",
    )
    model.eval()

    # 6. Compare attention interpreters against the random floor.
    methods = {
        "random": RandomBaseline(model),
        "chefer": CheferRelevance(model),
        "rollout": AttentionRollout(model),
    }

    print(f"\n{'method':<10}{'comprehensiveness':>20}{'sufficiency':>16}")
    print("-" * 46)
    for name, method in methods.items():
        scores = evaluate_attribution(
            model,
            test_loader,
            method,
            metrics=["comprehensiveness", "sufficiency"],
            percentages=[25, 50, 99],
        )
        print(
            f"{name:<10}"
            f"{scores['comprehensiveness']:>20.4f}"
            f"{scores['sufficiency']:>16.4f}"
        )


if __name__ == "__main__":
    main()

Both rollout and chefer should land above random on comprehensiveness — that's the sanity check that confirms the metric path works end-to-end for AttentionRollout.

Happy to review once those two are in.


Generated by Claude Code

@jhnwu3 jhnwu3 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check comments. But you're almost done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants