Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Merged
phu0ngng merged 24 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Jun 27, 2026
Merged

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng merged 24 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng

@phu0ngng phu0ngng commented May 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch 
    ep_combine,      # custom_vjp-wrapped combine

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR lands the JAX layer for Expert Parallelism: five XLA FFI handlers (EpPrepare, EpDispatch, EpCombine, and their backward variants) wrapping the nvte_ep_* C API, jax.custom_vjp-wrapped Python entry points with mesh-aware SPMD sharding rules, a new nccl_ep_enabled() build helper, a 13-test multi-process suite, and an end-to-end MoE example.

  • Custom VJP wrappers (ep_dispatch, ep_combine) correctly thread handle_mem through residuals and re-pin cotangent sharding in the backward via with_sharding_constraint, avoiding XLA-transpose shard-dropping.
  • XLA FFI handlers in ep.cpp perform on-stream int32→int64 topk_idx upcasting into a pre-allocated scratch buffer and cleanly manage NCCL comm lifetime via EpInstanceState shared-pointer anchors tied to executable lifetimes.
  • nccl_ep_enabled() centralizes the Hopper+ arch gate that was previously duplicated between setup.py and build_tools/jax.py, fixing the previously flagged inconsistency between the two build paths.

Confidence Score: 4/5

Safe to merge with the caveat that direct callers of the public ep_combine_fwd primitive outside the ep_combine custom_vjp wrapper will silently receive partially-uninitialized output in SPMD when out_partition_spec is omitted.

Several previously-flagged issues (the assert on NCCL UID, the libnccl.so.2 hardcode, the stale-communicator re-init path) are still present but were already under discussion. A new finding: EpCombinePrimitive.partition's else branch, added to stop the crash the previous review found, substitutes silent data corruption by passing the global token count as the per-shard output size, causing the C++ combine op to fill only the local token slots and leave the rest of the buffer uninitialised before XLA replicates the result.

transformer_engine/jax/cpp_extensions/ep.py (EpCombinePrimitive.partition else-branch) and transformer_engine/jax/ep.py (assert-based NCCL UID check, hardcoded libnccl.so.2)

Important Files Changed

Filename Overview
transformer_engine/jax/ep.py New file: public ep_bootstrap, ep_dispatch (custom_vjp), ep_combine (custom_vjp). Validation guards and sharding helpers look correct; assert-based NCCL UID check and hardcoded libnccl.so.2 are still present (flagged in earlier reviews).
transformer_engine/jax/cpp_extensions/ep.py New file: five JAX primitives with abstract-eval, lowering, impl, batcher, partition, and shardy rules. EpCombinePrimitive.partition's None-path silently produces wrong SPMD output (P1). Remaining primitives' sharding logic is consistent with global-vs-per-shard conventions.
transformer_engine/jax/csrc/extensions/ep.cpp New file: five XLA FFI handlers wrapped around nvte_ep_* C API. Workspace sizing for int32 to int64 upcast is byte-exact. ncclCommInitRank holds g_ep_mu during construction, but the intended workflow serializes bootstrap before JIT compilation.
build_tools/utils.py New nccl_ep_enabled() helper centralizes arch-gate logic previously duplicated in setup.py and jax.py. native arch treated as Hopper+ (P2 false-positive); otherwise logic matches the original setup.py behavior.
tests/jax/test_multi_process_ep.py Comprehensive 13-test multi-process test suite covering bootstrap, prepare, dispatch/combine identity and custom_vjp fwd+bwd, and HLO collective guard. Tests use a shared class-level bootstrap to avoid repeated NCCL init.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User as Python (ep.py)
    participant Prim as JAX Primitives
    participant FFI as XLA FFI (ep.cpp)
    participant TE as TE Common
    participant NCCL as NCCL EP

    User->>User: ep_bootstrap() allgather UID, set_ep_bootstrap_params
    FFI->>NCCL: ncclCommInitRank
    FFI->>TE: nvte_ep_initialize

    Note over User,NCCL: Forward pass
    User->>Prim: ep_dispatch(cfg, topk_idx, tokens, weights, cap)
    Prim->>FFI: EpPrepareHandler
    FFI->>TE: nvte_ep_prepare
    Prim->>FFI: EpDispatchHandler
    FFI->>NCCL: nvte_ep_dispatch
    Prim-->>User: recv_tokens, recv_weights, handle_mem, token_counts

    User->>User: expert MLP
    User->>Prim: ep_combine(cfg, handle_mem, token_counts, expert_out, T)
    Prim->>FFI: EpCombineHandler
    FFI->>NCCL: nvte_ep_combine
    Prim-->>User: result tokens

    Note over User,NCCL: Backward pass
    User->>Prim: _combine_bwd
    FFI->>NCCL: nvte_ep_combine_bwd
    User->>Prim: _dispatch_bwd
    FFI->>NCCL: nvte_ep_dispatch_bwd
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant User as Python (ep.py)
    participant Prim as JAX Primitives
    participant FFI as XLA FFI (ep.cpp)
    participant TE as TE Common
    participant NCCL as NCCL EP

    User->>User: ep_bootstrap() allgather UID, set_ep_bootstrap_params
    FFI->>NCCL: ncclCommInitRank
    FFI->>TE: nvte_ep_initialize

    Note over User,NCCL: Forward pass
    User->>Prim: ep_dispatch(cfg, topk_idx, tokens, weights, cap)
    Prim->>FFI: EpPrepareHandler
    FFI->>TE: nvte_ep_prepare
    Prim->>FFI: EpDispatchHandler
    FFI->>NCCL: nvte_ep_dispatch
    Prim-->>User: recv_tokens, recv_weights, handle_mem, token_counts

    User->>User: expert MLP
    User->>Prim: ep_combine(cfg, handle_mem, token_counts, expert_out, T)
    Prim->>FFI: EpCombineHandler
    FFI->>NCCL: nvte_ep_combine
    Prim-->>User: result tokens

    Note over User,NCCL: Backward pass
    User->>Prim: _combine_bwd
    FFI->>NCCL: nvte_ep_combine_bwd
    User->>Prim: _dispatch_bwd
    FFI->>NCCL: nvte_ep_dispatch_bwd
Loading

Reviews (24): Last reviewed commit: "Merge branch 'main' into phuong/ep-3-jax" | Re-trigger Greptile

Comment thread build_tools/jax.py Outdated
Comment thread build_tools/jax.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp Outdated
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
Please focus on the changes in the JAX side, as the TE/Common ones will be discussed in #3034

Comment thread examples/jax/ep/ep_moe.py Outdated
Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread examples/jax/ep/ep_moe.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions.h Outdated
jberchtold-nvidia pushed a commit to jberchtold-nvidia/TransformerEngine that referenced this pull request Jun 5, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
Comment thread transformer_engine/jax/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-jax branch 2 times, most recently from 06f8a13 to c34771d Compare June 10, 2026 15:24
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype
field. The C++ backend (ep_backend.cpp:349) enforces
    typeToSize(tok_dtype) <= typeToSize(max_token_dtype)
at every dispatch, and the field is also used at group create to size the
NCCL EP staging buffers (ep_backend.cpp:221-222).

PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written
before this field existed and never set it, so any JAX EP group landed with
the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from
JAX then failed immediately with:
    tokens dtype (6) wider than group max_token_dtype (0)

This commit threads max_token_dtype end-to-end:

  - transformer_engine/jax/csrc/extensions.h
    update SetEpBootstrapParams declaration to match the new arity.

  - transformer_engine/jax/csrc/extensions/ep.cpp
    add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams;
    forward it into NVTEEpGroupConfig in the EpResources ctor.

  - transformer_engine/jax/csrc/extensions/pybind.cpp
    add the matching pybind11::arg("max_token_dtype") = 0.

  - transformer_engine/jax/ep.py
    add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to
    NVTEDType int, forward to the C++ setter.

Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream.
See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request Jun 10, 2026
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg.
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
EOF
)
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

phu0ngng and others added 12 commits June 24, 2026 00:22
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ck via nccl_ep_enabled()

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
@phu0ngng phu0ngng added the 2.7.0 label Jun 24, 2026
…ition methods

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending CI

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci JAX L1

@phu0ngng phu0ngng merged commit 42f37ff into NVIDIA:main Jun 27, 2026
8 of 13 checks passed
@phu0ngng phu0ngng deleted the phuong/ep-3-jax branch June 27, 2026 08:57
Comment on lines +563 to +599
mesh,
arg_infos,
result_infos,
):
del result_infos
eo_spec = arg_infos[1].sharding.spec
if not _ep_spec_ok(eo_spec, trailing_count=2):
raise NotImplementedError(
"EpCombine: expert_out must be sharded as PartitionSpec(ep_resource,"
" None, None) (or ((dp, ep), None, None) when dp/fsdp is set)"
f" over [num_procs, recv_pr, H]; got spec={eo_spec}."
)
if out_partition_spec is not None:
per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec))
else:
per_shard_leading = out_leading_shape
out_sharding = NamedSharding(mesh, PartitionSpec())
arg_shardings = tuple(a.sharding for a in arg_infos)

def sharded_impl(handle_mem, expert_out):
return EpCombinePrimitive.impl(
handle_mem,
expert_out,
top_k,
dispatch_output_per_expert_alignment,
per_shard_leading,
out_partition_spec,
)

return mesh, sharded_impl, out_sharding, arg_shardings

@staticmethod
def shardy_sharding_rule(*args):
# Signature: (*static_args, mesh, value_types, result_types). Static args:
# (top_k, dispatch_alignment, out_leading_shape, out_partition_spec).
result_types = args[-1]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 out_partition_spec=None in SPMD produces silent data corruption instead of a crash

The else branch (reached when ep_combine_fwd is called with the default out_partition_spec=None in a multi-device SPMD context) sets per_shard_leading = out_leading_shape (the global token count, e.g. (8,)) instead of the per-shard count (e.g. (2,) with ep=4), and assigns out_sharding = PartitionSpec() (fully replicated). When sharded_impl runs, it allocates an (8, H) output buffer on each device but nvte_ep_combine writes only the local 2-token result, leaving 6 rows uninitialized, and XLA then treats that partial buffer as the authoritative replicated global result, propagating garbage values silently.

The previous review flagged the crash that existed at this code site before the else branch was added and suggested raise ValueError("out_partition_spec must be specified in SPMD mode"). That guard was not applied, so the failure mode changed from an early loud error to a silent correctness failure. The ep_combine custom_vjp path is safe because _combine_fwd always calls _default_out_partition_spec(), but the public ep_combine_fwd primitive (exported in __all__) is still broken for direct SPMD callers.

KshitijLakhani pushed a commit that referenced this pull request Jun 29, 2026
* Expert Parallelism: JAX primitives + VJPs

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants