Skip to content

[JAX] Add attention tutorials#3162

Open
KshitijLakhani wants to merge 2 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/feat/jax-attention-tutorials
Open

[JAX] Add attention tutorials#3162
KshitijLakhani wants to merge 2 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/feat/jax-attention-tutorials

Conversation

@KshitijLakhani

@KshitijLakhani KshitijLakhani commented Jul 1, 2026

Copy link
Copy Markdown
Collaborator

Description

Add attention tutorial for integrating TE into an existing framework

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Add a page for non-CP attn tutorial
Add a page for CP attn tutorial

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this Jul 1, 2026
@KshitijLakhani KshitijLakhani marked this pull request as ready for review July 1, 2026 22:56
@KshitijLakhani KshitijLakhani changed the title Add JAX attention tutorials [JAX] Add attention tutorials Jul 1, 2026
@greptile-apps

greptile-apps Bot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds two JAX attention tutorials — a single-GPU BSHD example (GQA + SWA + MLA-style heads) and a multi-GPU context-parallel THD example (Ring and AllGather strategies) — along with pytest entry points for both, pre-recorded output files, and updates to the integration hub page.

  • attention.py / attention.rst: demonstrates replacing a native JAX GQA+SWA implementation with DotProductAttention, and adds an MLA-style variant with different Q/K vs V head dimensions.
  • attention_context_parallel.py / attention_context_parallel.rst: shows striped causal load-balanced CP over a JAX mesh using fused_attn directly, with explicit sharding and Ring vs AllGather strategy comparison.
  • test_attention.py: top-level imports of both tutorial modules conflict with the deferred-import pattern established in test_dense.py; the context-parallel module allocates ~16 GB of tensors at import time, which can cause the test job to crash before skip marks take effect on low-memory or single-GPU nodes.

Confidence Score: 4/5

Safe to merge for documentation purposes; the one actionable concern is in the test harness, not in the tutorial code itself.

The tutorial code and RST docs are clean. The only functional issue is in test_attention.py: it top-level imports attention_context_parallel, which allocates ~16 GB of GPU tensors at import time. On a single-GPU CI node (where context-parallel tests are meant to be skipped), this can OOM before any skip guard fires, turning a clean skip into a job failure. test_dense.py in the same directory explicitly documents and avoids this exact pattern.

docs/examples/jax/test_attention.py — the import strategy should be revisited to match the deferred-import pattern in test_dense.py.

Important Files Changed

Filename Overview
docs/examples/jax/test_attention.py New pytest file; top-level imports of both tutorial modules contradict the deferred-import pattern documented in test_dense.py, risking OOM-on-import before skip marks can fire for context-parallel tests.
docs/examples/jax/attention.py New single-GPU BSHD tutorial; TEDotProductAttention silently uses module-level num_query_heads instead of a class attribute, but is otherwise well-structured.
docs/examples/jax/attention_context_parallel.py New context-parallel THD tutorial; mesh setup, striped reorder/inverse-reorder, and Ring/AllGather benchmarking are correctly structured. Module-level tensor allocation is large (~16 GB total) and may cause issues when imported unconditionally from tests.
docs/examples/jax/attention.rst RST documentation replacing the TODO placeholder; literalinclude markers are consistent with the Python source.
docs/examples/jax/attention_context_parallel.rst New RST documentation for CP tutorial; literalinclude markers align with the Python source and explanations are clear.
docs/examples/te_jax_integration.rst Hub page updated to mark both attention tutorials as Available and add a new CP row; toctree entry added correctly.
docs/examples/jax/attention.out Pre-recorded output from a GB200; timing values are hardware-specific and will differ on other GPUs, but the RST anchors the hardware context explicitly.
docs/examples/jax/attention_context_parallel.out Pre-recorded 4-GPU context-parallel output; timing values are hardware-specific but RST context is clear.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[attention.py\nmodule-level init] --> B[FlaxNativeGQAAttention\nbaseline]
    A --> C[TEDotProductAttention\nGQA + SWA]
    A --> D[TEDotProductAttention\nMLA-style head dims]

    E[attention_context_parallel.py\nmodule-level init] --> F[create_qkv_inputs\nbatch=2, seq=65536]
    E --> G[create_packed_segment_ids_and_pos\n16 segments per seq]
    E --> H[SequenceDescriptor\nTHD format]

    F & G & H --> I[shard_for_context_parallel\nreorder + device_put]
    I --> J[fused_attn\nCPStrategy.RING\nstripe_size=1]
    I --> K[fused_attn\nCPStrategy.ALL_GATHER\nstripe_size=4096]
    J & K --> L[inverse_reorder_causal_load_balancing\nrestore original token order]

    M[test_attention.py] -->|top-level import| A
    M -->|top-level import| E
    M --> N[test_bshd_gqa_swa_runs]
    M --> O[test_mla_variant_runs]
    M --> P[test_multi_gpu_context_parallel_ring_case\nrequires_cp skip]
    M --> Q[test_multi_gpu_context_parallel_allgather_case\nrequires_cp skip]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[attention.py\nmodule-level init] --> B[FlaxNativeGQAAttention\nbaseline]
    A --> C[TEDotProductAttention\nGQA + SWA]
    A --> D[TEDotProductAttention\nMLA-style head dims]

    E[attention_context_parallel.py\nmodule-level init] --> F[create_qkv_inputs\nbatch=2, seq=65536]
    E --> G[create_packed_segment_ids_and_pos\n16 segments per seq]
    E --> H[SequenceDescriptor\nTHD format]

    F & G & H --> I[shard_for_context_parallel\nreorder + device_put]
    I --> J[fused_attn\nCPStrategy.RING\nstripe_size=1]
    I --> K[fused_attn\nCPStrategy.ALL_GATHER\nstripe_size=4096]
    J & K --> L[inverse_reorder_causal_load_balancing\nrestore original token order]

    M[test_attention.py] -->|top-level import| A
    M -->|top-level import| E
    M --> N[test_bshd_gqa_swa_runs]
    M --> O[test_mla_variant_runs]
    M --> P[test_multi_gpu_context_parallel_ring_case\nrequires_cp skip]
    M --> Q[test_multi_gpu_context_parallel_allgather_case\nrequires_cp skip]
Loading

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +15 to +17
def test_bshd_gqa_swa_runs():
out = attention.te_model.apply(
attention.te_vars,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Top-level tutorial imports contradict the established deferral pattern

test_dense.py contains an explicit comment explaining why tutorial module imports must be deferred to each test body: dense.py runs te_vars = te_model.init(...) at module scope, which raises on unsupported hardware before pytest can apply skip marks. The same risk exists here for attention_context_parallel: its module-level code allocates four (2, 65536, 128, 128) bfloat16 tensors (~4 GB each) before any @requires_cp guard has a chance to fire. On a single-GPU CI node where context-parallel tests should be skipped, this import can trigger an OOM or a JAX device-memory error, making the job fail instead of report a clean skip. The attention import also runs model initialization at module scope, though it is less likely to OOM in practice.

Comment on lines +151 to +172
transpose_batch_sequence=False,
window_size=self.window_size,
)(
query,
key,
value,
sequence_descriptor=sequence_descriptor,
deterministic=deterministic,
)


te_model = TEDotProductAttention(num_kv_heads=num_kv_heads, window_size=window_size)
te_vars = te_model.init(
jax.random.PRNGKey(2026),
qkv,
sequence_descriptor=sequence_descriptor,
deterministic=False,
)
# ATTENTION_TE_MODEL_END


def run_forward_backward(model, variables, input_qkv, output_grad, seq_desc=None):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 num_attention_heads is taken from the module-level global num_query_heads (128) rather than a class attribute. A reader copying TEDotProductAttention to a different file and changing the global, or instantiating it with a different query-head count, would get silent wrong behaviour because DotProductAttention would still receive the stale global value. Adding num_query_heads: int as a Flax dataclass field keeps the class self-contained and consistent with how num_kv_heads is already handled.

Suggested change
transpose_batch_sequence=False,
window_size=self.window_size,
)(
query,
key,
value,
sequence_descriptor=sequence_descriptor,
deterministic=deterministic,
)
te_model = TEDotProductAttention(num_kv_heads=num_kv_heads, window_size=window_size)
te_vars = te_model.init(
jax.random.PRNGKey(2026),
qkv,
sequence_descriptor=sequence_descriptor,
deterministic=False,
)
# ATTENTION_TE_MODEL_END
def run_forward_backward(model, variables, input_qkv, output_grad, seq_desc=None):
class TEDotProductAttention(nn.Module):
"""Thin Flax wrapper around TE's DotProductAttention."""
num_query_heads: int
num_kv_heads: int
qk_head_dim: int = head_dim
attn_mask_type: str = "causal"
qkv_layout: str = "bshd_bshd_bshd"
window_size: Optional[Tuple[int, int]] = None
@nn.compact
def __call__(
self,
qkv_tensors,
sequence_descriptor: Optional[SequenceDescriptor] = None,
*,
deterministic: bool = False,
):
query, key, value = qkv_tensors
return DotProductAttention(
head_dim=self.qk_head_dim,
num_attention_heads=self.num_query_heads,
num_gqa_groups=self.num_kv_heads,

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@KshitijLakhani KshitijLakhani added the documentation Improvements or additions to documentation label Jul 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.18 documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant