[JAX] Expert Parallelism: JAX primitives + VJPs#3036
Conversation
Greptile SummaryThis PR lands the JAX layer for Expert Parallelism: five XLA FFI handlers (
Confidence Score: 4/5Safe to merge with the caveat that direct callers of the public ep_combine_fwd primitive outside the ep_combine custom_vjp wrapper will silently receive partially-uninitialized output in SPMD when out_partition_spec is omitted. Several previously-flagged issues (the assert on NCCL UID, the libnccl.so.2 hardcode, the stale-communicator re-init path) are still present but were already under discussion. A new finding: EpCombinePrimitive.partition's else branch, added to stop the crash the previous review found, substitutes silent data corruption by passing the global token count as the per-shard output size, causing the C++ combine op to fill only the local token slots and leave the rest of the buffer uninitialised before XLA replicates the result. transformer_engine/jax/cpp_extensions/ep.py (EpCombinePrimitive.partition else-branch) and transformer_engine/jax/ep.py (assert-based NCCL UID check, hardcoded libnccl.so.2) Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant User as Python (ep.py)
participant Prim as JAX Primitives
participant FFI as XLA FFI (ep.cpp)
participant TE as TE Common
participant NCCL as NCCL EP
User->>User: ep_bootstrap() allgather UID, set_ep_bootstrap_params
FFI->>NCCL: ncclCommInitRank
FFI->>TE: nvte_ep_initialize
Note over User,NCCL: Forward pass
User->>Prim: ep_dispatch(cfg, topk_idx, tokens, weights, cap)
Prim->>FFI: EpPrepareHandler
FFI->>TE: nvte_ep_prepare
Prim->>FFI: EpDispatchHandler
FFI->>NCCL: nvte_ep_dispatch
Prim-->>User: recv_tokens, recv_weights, handle_mem, token_counts
User->>User: expert MLP
User->>Prim: ep_combine(cfg, handle_mem, token_counts, expert_out, T)
Prim->>FFI: EpCombineHandler
FFI->>NCCL: nvte_ep_combine
Prim-->>User: result tokens
Note over User,NCCL: Backward pass
User->>Prim: _combine_bwd
FFI->>NCCL: nvte_ep_combine_bwd
User->>Prim: _dispatch_bwd
FFI->>NCCL: nvte_ep_dispatch_bwd
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant User as Python (ep.py)
participant Prim as JAX Primitives
participant FFI as XLA FFI (ep.cpp)
participant TE as TE Common
participant NCCL as NCCL EP
User->>User: ep_bootstrap() allgather UID, set_ep_bootstrap_params
FFI->>NCCL: ncclCommInitRank
FFI->>TE: nvte_ep_initialize
Note over User,NCCL: Forward pass
User->>Prim: ep_dispatch(cfg, topk_idx, tokens, weights, cap)
Prim->>FFI: EpPrepareHandler
FFI->>TE: nvte_ep_prepare
Prim->>FFI: EpDispatchHandler
FFI->>NCCL: nvte_ep_dispatch
Prim-->>User: recv_tokens, recv_weights, handle_mem, token_counts
User->>User: expert MLP
User->>Prim: ep_combine(cfg, handle_mem, token_counts, expert_out, T)
Prim->>FFI: EpCombineHandler
FFI->>NCCL: nvte_ep_combine
Prim-->>User: result tokens
Note over User,NCCL: Backward pass
User->>Prim: _combine_bwd
FFI->>NCCL: nvte_ep_combine_bwd
User->>Prim: _dispatch_bwd
FFI->>NCCL: nvte_ep_dispatch_bwd
Reviews (24): Last reviewed commit: "Merge branch 'main' into phuong/ep-3-jax" | Re-trigger Greptile |
| Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, | ||
| Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { | ||
| auto topk_dims = topk_idx.dimensions(); | ||
| NVTE_CHECK(topk_dims.size() >= 2, |
There was a problem hiding this comment.
nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?
There was a problem hiding this comment.
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
|
I would appreciate your help to review this PR @tdophung @jberchtold-nvidia! |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM pending CI
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
06f8a13 to
c34771d
Compare
PR NVIDIA#3034 commit 9b225cb added a required NVTEEpGroupConfig.max_token_dtype field. The C++ backend (ep_backend.cpp:349) enforces typeToSize(tok_dtype) <= typeToSize(max_token_dtype) at every dispatch, and the field is also used at group create to size the NCCL EP staging buffers (ep_backend.cpp:221-222). PR NVIDIA#3036's JAX bootstrap (SetEpBootstrapParams / ep_bootstrap) was written before this field existed and never set it, so any JAX EP group landed with the zero-initialized default (kByte = 1 byte). Any bf16/fp16 dispatch from JAX then failed immediately with: tokens dtype (6) wider than group max_token_dtype (0) This commit threads max_token_dtype end-to-end: - transformer_engine/jax/csrc/extensions.h update SetEpBootstrapParams declaration to match the new arity. - transformer_engine/jax/csrc/extensions/ep.cpp add max_token_dtype to EpBootstrapParams and SetEpBootstrapParams; forward it into NVTEEpGroupConfig in the EpResources ctor. - transformer_engine/jax/csrc/extensions/pybind.cpp add the matching pybind11::arg("max_token_dtype") = 0. - transformer_engine/jax/ep.py add max_token_dtype kwarg to ep_bootstrap, convert numpy dtype to NVTEDType int, forward to the C++ setter. Carried on the te-ep-fixes branch until PR NVIDIA#3036 exposes the field upstream. See PR NVIDIA#3034 (commit 9b225cb, ep.h:43) for the field definition.
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours: * transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg. * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone). * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept). Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync. EOF )
c34771d to
351b9df
Compare
|
/te-ci JAX L1 |
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…e example) jax distributed suites Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ARY_PATH for libnccl_ep.so Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ck via nccl_ep_enabled() Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
4bb76dc to
9df769a
Compare
|
/te-ci JAX L1 |
…ition methods Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
|
/te-ci JAX L1 |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM pending CI
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
|
/te-ci JAX L1 |
| mesh, | ||
| arg_infos, | ||
| result_infos, | ||
| ): | ||
| del result_infos | ||
| eo_spec = arg_infos[1].sharding.spec | ||
| if not _ep_spec_ok(eo_spec, trailing_count=2): | ||
| raise NotImplementedError( | ||
| "EpCombine: expert_out must be sharded as PartitionSpec(ep_resource," | ||
| " None, None) (or ((dp, ep), None, None) when dp/fsdp is set)" | ||
| f" over [num_procs, recv_pr, H]; got spec={eo_spec}." | ||
| ) | ||
| if out_partition_spec is not None: | ||
| per_shard_leading = _leading_per_shard(out_leading_shape, out_partition_spec[0], mesh) | ||
| out_sharding = NamedSharding(mesh, PartitionSpec(*out_partition_spec)) | ||
| else: | ||
| per_shard_leading = out_leading_shape | ||
| out_sharding = NamedSharding(mesh, PartitionSpec()) | ||
| arg_shardings = tuple(a.sharding for a in arg_infos) | ||
|
|
||
| def sharded_impl(handle_mem, expert_out): | ||
| return EpCombinePrimitive.impl( | ||
| handle_mem, | ||
| expert_out, | ||
| top_k, | ||
| dispatch_output_per_expert_alignment, | ||
| per_shard_leading, | ||
| out_partition_spec, | ||
| ) | ||
|
|
||
| return mesh, sharded_impl, out_sharding, arg_shardings | ||
|
|
||
| @staticmethod | ||
| def shardy_sharding_rule(*args): | ||
| # Signature: (*static_args, mesh, value_types, result_types). Static args: | ||
| # (top_k, dispatch_alignment, out_leading_shape, out_partition_spec). | ||
| result_types = args[-1] |
There was a problem hiding this comment.
out_partition_spec=None in SPMD produces silent data corruption instead of a crash
The else branch (reached when ep_combine_fwd is called with the default out_partition_spec=None in a multi-device SPMD context) sets per_shard_leading = out_leading_shape (the global token count, e.g. (8,)) instead of the per-shard count (e.g. (2,) with ep=4), and assigns out_sharding = PartitionSpec() (fully replicated). When sharded_impl runs, it allocates an (8, H) output buffer on each device but nvte_ep_combine writes only the local 2-token result, leaving 6 rows uninitialized, and XLA then treats that partial buffer as the authoritative replicated global result, propagating garbage values silently.
The previous review flagged the crash that existed at this code site before the else branch was added and suggested raise ValueError("out_partition_spec must be specified in SPMD mode"). That guard was not applied, so the failure mode changed from an early loud error to a silent correctness failure. The ep_combine custom_vjp path is safe because _combine_fwd always calls _default_out_partition_spec(), but the public ep_combine_fwd primitive (exported in __all__) is still broken for direct SPMD callers.
* Expert Parallelism: JAX primitives + VJPs --------- Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the
nvte_ep_*C API, a Python wrapper withcustom_vjpfor autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCLncclEpDispatch/ncclEpCombineare exposed as XLA primitives and work with CUDA-graph capture.Implementation
Public Python API (
transformer_engine/jax/ep.py)ep_dispatch/ep_combinearejax.custom_vjpfunctions: forward is the FFI primitive, backward calls the matchingnvte_ep_*_bwdFFI primitive directly (noep_preparein the bwd — routing state is already cached inhandle.mem). Note thatep_dispatchalso callsep_preparein the forward path, which all-gathers and preprocesses routing maps.XLA FFI bindings (
transformer_engine/jax/csrc/extensions/ep.cpp)Five
XLA_FFI_DEFINE_HANDLER_SYMBOLentries —EpPrepareHandler,EpDispatchHandler,EpCombineHandler,EpDispatchBwdHandler,EpCombineBwdHandler— each calling the correspondingnvte_ep_*C entry point. All markedFFI_CudaGraph_Traitsso they capture cleanly.handle_idis a static FFI attribute baked at jit trace time.Primitives + Python layer (
transformer_engine/jax/cpp_extensions/ep.py, +951 lines)Standard TE primitive plumbing:
abstract_eval(shape/dtype inference),lowering,impl,outer_primitiveregistration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).Sharding (
transformer_engine/jax/sharding.py, +12 lines)Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
Build wiring (
build_tools/jax.py, +41 lines)Threads NCCL EP linkage through the JAX
transformer_engine_jaxextension. No new top-level build flags; rides on the parent PR'sNVTE_BUILD_WITH_NCCL_EP.Tests & example
tests/jax/test_multi_process_ep.py(+690 lines): 13 tests covering bootstrap,ep_prepareshape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing),custom_vjpfwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).tests/jax/multi_process_launch_ep.sh: 4-rank launcher; setsXLA_FLAGSto keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).examples/jax/ep/ep_moe.py(+394 lines) +run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison--checkthat verifies fwd+bwd vs a single-process reference.Type of change
Checklist: