Add attention rollout interpretability method (Abnar & Zuidema 2020)#1158
Add attention rollout interpretability method (Abnar & Zuidema 2020)#1158fbonc wants to merge 7 commits into
Conversation
…_rollout.rst added
jhnwu3
left a comment
There was a problem hiding this comment.
Thanks for this — the implementation, docs, and test suite are in good shape. The rationale in the description is clear and the correctness invariants in the tests (row-stochastic check, identity-attention test) are exactly what we want to see. A few minor things (header comments, trailing newline in attention_rollout.py, the unreachable ValueError branch in _fuse_heads) are non-blocking; feel free to clean them up or leave them.
Two things before merge:
1. One doc spot still lists only Chefer
docs/why_pyhealth.rst line 178 currently reads:
Attention-based: Chefer relevance propagation for transformers
Since rollout is the canonical gradient-free attention baseline that Chefer-style methods are measured against, please add it here — e.g.:
Attention-based: Chefer relevance propagation and attention rollout for transformers
Everything else checks out: the __init__ export, the toctree entry, the API rst page, and the six MIMIC comparison scripts are all wired up correctly.
2. A runnable benchmark example
Every existing script that calls evaluate_attribution depends on credentialed full MIMIC-IV plus pre-trained checkpoints and hardcoded /shared/eng/... paths — which is exactly why you noted you couldn't verify end-to-end faithfulness. Could you add a self-contained example under examples/interpretability/ that uses the MIMIC-IV demo dataset (the ~100-patient open-access subset on PhysioNet, freely downloadable with no credentialing: https://physionet.org/content/mimic-iv-demo/)? It should train a small Transformer from scratch and run evaluate_attribution comparing AttentionRollout, CheferRelevance, and RandomBaseline on Comprehensiveness/Sufficiency.
Here's a complete starting point you can drop in as examples/interpretability/attention_rollout_benchmark_mimic4_demo.py:
"""Benchmark AttentionRollout on the MIMIC-IV demo dataset (no credentialing).
Download once from PhysioNet (open access, no credentialing required):
https://physionet.org/content/mimic-iv-demo/
Extract and point --ehr_root at the resulting directory.
Run:
python examples/interpretability/attention_rollout_benchmark_mimic4_demo.py \
--ehr_root /path/to/mimic-iv-clinical-database-demo
"""
import argparse
import torch
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
from pyhealth.interpret.methods import AttentionRollout, CheferRelevance, RandomBaseline
from pyhealth.metrics.interpretability import evaluate_attribution
from pyhealth.models import Transformer
from pyhealth.tasks import MortalityPredictionMIMIC4
from pyhealth.trainer import Trainer
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--ehr_root",
required=True,
help="Path to the extracted mimic-iv-clinical-database-demo directory",
)
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
torch.manual_seed(0)
# 1. Load the open-access demo dataset (no credentialing required).
base_dataset = MIMIC4Dataset(
ehr_root=args.ehr_root,
ehr_tables=[
"patients",
"admissions",
"diagnoses_icd",
"procedures_icd",
"prescriptions",
],
)
# 2. Apply the plain-Transformer mortality task; train fresh (no checkpoints).
sample_dataset = base_dataset.set_task(MortalityPredictionMIMIC4())
print(f"Loaded {len(sample_dataset)} samples")
# 3. Patient-level split and loaders.
train_ds, val_ds, test_ds = split_by_patient(
sample_dataset, [0.7, 0.1, 0.2], seed=42
)
train_loader = get_dataloader(train_ds, batch_size=16, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=16, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=16, shuffle=False)
# 4. Small Transformer — already exposes the attention-readout methods
# both AttentionRollout and CheferRelevance rely on.
model = Transformer(
dataset=sample_dataset,
embedding_dim=64,
heads=2,
num_layers=2,
dropout=0.1,
)
# 5. Train briefly on the demo data.
trainer = Trainer(model=model, device=args.device, metrics=["roc_auc"])
trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=5,
monitor="roc_auc",
monitor_criterion="max",
)
model.eval()
# 6. Compare attention interpreters against the random floor.
methods = {
"random": RandomBaseline(model),
"chefer": CheferRelevance(model),
"rollout": AttentionRollout(model),
}
print(f"\n{'method':<10}{'comprehensiveness':>20}{'sufficiency':>16}")
print("-" * 46)
for name, method in methods.items():
scores = evaluate_attribution(
model,
test_loader,
method,
metrics=["comprehensiveness", "sufficiency"],
percentages=[25, 50, 99],
)
print(
f"{name:<10}"
f"{scores['comprehensiveness']:>20.4f}"
f"{scores['sufficiency']:>16.4f}"
)
if __name__ == "__main__":
main()Both rollout and chefer should land above random on comprehensiveness — that's the sanity check that confirms the metric path works end-to-end for AttentionRollout.
Happy to review once those two are in.
Generated by Claude Code
jhnwu3
left a comment
There was a problem hiding this comment.
Check comments. But you're almost done.
Contributor: Felipe Amaral Bonchristiano (felipea5@illinois.edu)
Contribution Type: New interpretability method
Description:
Adds vanilla attention rollout (Abnar & Zuidema, "Quantifying Attention Flow
in Transformers," 2020, arXiv:2005.00928) as a new interpretability module,
AttentionRollout. Rollout is the canonical forward-only, gradient-free,class-agnostic attention-flow baseline: it accounts for residual connections
(
 = 0.5·(A + I)), fuses heads by mean, and composes per-layer attention bymatrix product to produce per-token relevance. It complements the existing
CheferRelevance(gradient-weighted, class-specific) by providing the standardbaseline that gradient-based attention methods are measured against, which the
interpretability suite currently lacks.
The implementation reuses the existing attention-readout methods already on
PyHealth's attention models (
set_attention_hooks,get_attention_layers,get_relevance_tensor), so it requires no model-side changes. It is asingle new method file plus its export, tests, docs, and example registrations.
Files to Review:
pyhealth/interpret/methods/attention_rollout.py: core implementation (AttentionRollout)pyhealth/interpret/methods/__init__.py: exportsAttentionRollouttests/core/test_attention_rollout.py: synthetic-data unit tests (see Testing below)docs/api/interpret/pyhealth.interpret.methods.attention_rollout.rst: API documentationdocs/api/interpret.rst: added to the Attribution Methods toctreeexamples/interpretability/{mp,los,dka}_{transformer,stageattn}_mimic4_interpret.py:AttentionRolloutregistered in the method comparison dicts alongsideCheferRelevanceQuick note:
The actual bounty on the doc lists "Rollout Attention" and links arXiv:2012.09838,
which is Chefer et al., Transformer Interpretability Beyond Attention
Visualization (CVPR 2021), a gradient/LRP relevance method, not the rollout
paper. (The existing
CheferRelevanceimplements the related Chefer et al. ICCV2021 method, arXiv:2103.15679.) I read the bounty's intent from its name and from
the actual gap in the suite, as there was no gradient-free, class-agnostic baseline,
and implemented canonical rollout (Abnar & Zuidema 2020) rather than more
Chefer-style work. If the literal citation was intended, happy to redirect.
Key design decisions:
0.5·(A + I); alternative fusions and residual schemes are deferred to optionalkwargs. Again, this module's value is fidelity to the baseline, not improving on it.
isinstance(CheferInterpretable).The three readout methods are general attention readout, not Chefer-specific;
__init__checkshasattrand raisesTypeErrornaming the missing methods.This keeps the PR to one new file with zero edits to the shared interface.
target_class_idxaccepted but ignored, documented as a no-op, so rollout isdrop-in swappable with class-specific interpreters in existing pipelines.
_map_to_input_shapesduplicated fromCheferRelevance(rather than factoredto a shared util) so attributions match the raw-input granularity the
comprehensiveness/sufficiency metrics expect, while keeping this PR free of edits
to
chefer.py.Proposed follow-up: extract a general
AttentionInterpretableinterface and a shared shape-mapping helper that both
AttentionRolloutandCheferRelevancedepend on, removing the duck-typing and the duplicated_map_to_input_shapes. Kept separate to avoid bundling a refactor of shared codeinto a feature PR.
Testing: Unit tests use small synthetic data (
create_sample_dataset, tinyconfig, seeded) and run in well under a second with no network or credentials.
Beyond shape and dict-key checks, they assert the two correctness invariants:
(1) per-token relevance sums to 1 before input-shape expansion (the product of
row-stochastic matrices is row-stochastic), and (2) identity attention at every
layer yields an identity rollout. Construction-time errors (incompatible model,
unsupported
head_fusion) are covered.Note on verification: I am not yet MIMIC-credentialed, so end-to-end correctness
is established via the synthetic unit tests above;
AttentionRolloutis registeredin the MIMIC-IV comparison scripts for parity with the other methods but I have not
run those end-to-end myself.