[JAX] Add attention tutorials #3162
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR adds two JAX attention tutorials: a single-GPU BSHD example (GQA + SWA + MLA-style heads) and a multi-GPU context-parallel THD example (Ring and AllGather strategies). It also adds pytest entry points for both, pre-recorded output files, and updates to the integration hub page.
Confidence Score: 4/5
Safe to merge for documentation purposes; the one actionable concern is in the test harness, not in the tutorial code itself.
The tutorial code and RST docs are clean. The only functional issue is in test_attention.py: it imports attention_context_parallel at the top level, which allocates ~16 GB of GPU tensors at import time. On a single-GPU CI node (where context-parallel tests are meant to be skipped), this can OOM before any skip guard fires, turning a clean skip into a job failure. test_dense.py in the same directory explicitly documents and avoids this exact pattern.
docs/examples/jax/test_attention.py: the import strategy should be revisited to match the deferred-import pattern in test_dense.py.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[attention.py\nmodule-level init] --> B[FlaxNativeGQAAttention\nbaseline]
A --> C[TEDotProductAttention\nGQA + SWA]
A --> D[TEDotProductAttention\nMLA-style head dims]
E[attention_context_parallel.py\nmodule-level init] --> F[create_qkv_inputs\nbatch=2, seq=65536]
E --> G[create_packed_segment_ids_and_pos\n16 segments per seq]
E --> H[SequenceDescriptor\nTHD format]
F & G & H --> I[shard_for_context_parallel\nreorder + device_put]
I --> J[fused_attn\nCPStrategy.RING\nstripe_size=1]
I --> K[fused_attn\nCPStrategy.ALL_GATHER\nstripe_size=4096]
J & K --> L[inverse_reorder_causal_load_balancing\nrestore original token order]
M[test_attention.py] -->|top-level import| A
M -->|top-level import| E
M --> N[test_bshd_gqa_swa_runs]
M --> O[test_mla_variant_runs]
M --> P[test_multi_gpu_context_parallel_ring_case\nrequires_cp skip]
M --> Q[test_multi_gpu_context_parallel_allgather_case\nrequires_cp skip]
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
def test_bshd_gqa_swa_runs():
    out = attention.te_model.apply(
        attention.te_vars,
Top-level tutorial imports contradict the established deferral pattern
test_dense.py contains an explicit comment explaining why tutorial module imports must be deferred to each test body: dense.py runs te_vars = te_model.init(...) at module scope, which raises on unsupported hardware before pytest can apply skip marks. The same risk exists here for attention_context_parallel: its module-level code allocates four (2, 65536, 128, 128) bfloat16 tensors (~4 GB each) before any @requires_cp guard has a chance to fire. On a single-GPU CI node where context-parallel tests should be skipped, this import can trigger an OOM or a JAX device-memory error, so the job fails instead of reporting a clean skip. The attention import also runs model initialization at module scope, though it is less likely to OOM in practice.
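One possible shape for the deferred import is sketched below. It is not the code from test_dense.py, which is not shown in this thread. The helper names visible_device_count and load_tutorial, and the two-device threshold, are assumptions made for illustration; the actual condition behind @requires_cp may differ.

```python
import importlib


def visible_device_count() -> int:
    """Return the number of devices JAX reports, or 0 if JAX cannot be used."""
    try:
        import jax  # imported lazily so this helper stays cheap to import

        return jax.device_count()
    except Exception:
        return 0


def load_tutorial(module_name: str, min_devices: int = 1):
    """Import a tutorial module only when enough devices are visible.

    Returns None without importing when the requirement is not met, so the
    module-level tensor allocation in the tutorial never runs.
    """
    if visible_device_count() < min_devices:
        return None
    return importlib.import_module(module_name)


def test_multi_gpu_context_parallel_ring_case():
    import pytest  # deferred so the helpers above work without pytest

    # Assumption: context parallelism needs at least 2 devices.
    cp = load_tutorial("attention_context_parallel", min_devices=2)
    if cp is None:
        pytest.skip("context parallelism needs at least 2 devices")
    # Hypothetical placeholder: real checks on the tutorial outputs go here.
    assert cp is not None
```

Because the import happens inside the test body and after the device check, a single-GPU node reports a skip instead of attempting the allocation.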
            transpose_batch_sequence=False,
            window_size=self.window_size,
        )(
            query,
            key,
            value,
            sequence_descriptor=sequence_descriptor,
            deterministic=deterministic,
        )


te_model = TEDotProductAttention(num_kv_heads=num_kv_heads, window_size=window_size)
te_vars = te_model.init(
    jax.random.PRNGKey(2026),
    qkv,
    sequence_descriptor=sequence_descriptor,
    deterministic=False,
)
# ATTENTION_TE_MODEL_END


def run_forward_backward(model, variables, input_qkv, output_grad, seq_desc=None):
num_attention_heads is taken from the module-level global num_query_heads (128) rather than from a class attribute. A reader who copies TEDotProductAttention to a different file and changes the global, or who instantiates it with a different query-head count, would get silently wrong behaviour because DotProductAttention would still receive the stale global value. Adding num_query_heads: int as a Flax dataclass field keeps the class self-contained and consistent with how num_kv_heads is already handled.
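The stale-global problem can be reproduced without Flax or TE. The sketch below uses plain standard-library dataclasses as a stand-in; GlobalBackedAttention and SelfContainedAttention are hypothetical names, and config() only mimics the keyword arguments that would be forwarded to DotProductAttention.

```python
from dataclasses import dataclass

# Module-level global, mirroring the tutorial's num_query_heads.
num_query_heads = 128


@dataclass
class GlobalBackedAttention:
    """Hypothetical stand-in: reads the query-head count from the global."""

    num_kv_heads: int

    def config(self) -> dict:
        return {
            "num_attention_heads": num_query_heads,  # ignores the instance
            "num_gqa_groups": self.num_kv_heads,
        }


@dataclass
class SelfContainedAttention:
    """Hypothetical stand-in: carries the query-head count as a field."""

    num_query_heads: int
    num_kv_heads: int

    def config(self) -> dict:
        return {
            "num_attention_heads": self.num_query_heads,
            "num_gqa_groups": self.num_kv_heads,
        }


# A caller wanting 64 query heads cannot express that with the first class.
stale = GlobalBackedAttention(num_kv_heads=8).config()
explicit = SelfContainedAttention(num_query_heads=64, num_kv_heads=8).config()
```

With the field in place, the head count travels with the instance, so copying the class elsewhere does not depend on a global of the same name.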
class TEDotProductAttention(nn.Module):
    """Thin Flax wrapper around TE's DotProductAttention."""

    num_query_heads: int
    num_kv_heads: int
    qk_head_dim: int = head_dim
    attn_mask_type: str = "causal"
    qkv_layout: str = "bshd_bshd_bshd"
    window_size: Optional[Tuple[int, int]] = None

    @nn.compact
    def __call__(
        self,
        qkv_tensors,
        sequence_descriptor: Optional[SequenceDescriptor] = None,
        *,
        deterministic: bool = False,
    ):
        query, key, value = qkv_tensors
        return DotProductAttention(
            head_dim=self.qk_head_dim,
            num_attention_heads=self.num_query_heads,
            num_gqa_groups=self.num_kv_heads,
Description
Add attention tutorial for integrating TE into an existing framework
Type of change
Changes
Add a page for the non-CP (single-GPU) attention tutorial
Add a page for the CP (context-parallel) attention tutorial