Migrate norms and softmax kernels to NVRTC#3156
Conversation
Move the fused-softmax and LayerNorm/RMSNorm kernels from build-time template
instantiation to runtime NVRTC compilation, with full coverage of the existing
kernel set so the NVRTC path is the default.
Fused softmax:
- RTC compile/launch path for scaled / scaled-masked / scaled-upper-triangular /
scaled-aligned-causal softmax, keyed by dtype, shape and mask/causal mode.
- NVTE_BUILD_LEGACY_STATIC_FUSED_SOFTMAX (default OFF) restores the static
template dispatch.
Normalization (LayerNorm + RMSNorm, forward + backward):
- Replace the static REGISTER_NORM_LAUNCHER template fanout with an NVRTC
registry that compiles the selected (norm type, direction, dtypes, hidden size,
CTA config) kernel on first use and caches it.
- NVTE_BUILD_LEGACY_STATIC_NORM (default OFF) restores the static launchers.
- NVRTC-safe kernel sources: kernel sources/headers avoid common.h under
__CUDACC_RTC__; add the dtype aliases and a minimal std::is_same/conditional_t
in the RTC build, and replace a zero-length padding array (a GNU extension nvcc
accepts but NVRTC rejects) with a no-padding union specialization.
KernelManager (util/rtc.{h,cpp}) gains occupancy / function-attribute /
cooperative-launch helpers needed by the norm launchers.
Validated on sm_89 (RTX 6000 Ada): full normalization operator suite 192/192,
softmax + NVRTC unit tests pass; libtransformer_engine.so shrinks ~72 MB -> ~65 MB.
On sm_100a the NVRTC norm forward kernel builds where the static instantiation
crashed the compiler.
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
783156c to
4e8d10a
Compare
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
4e8d10a to
5235723
Compare
|
/te-ci pytorch |
for more information, see https://pre-commit.ci
Greptile SummaryThis PR migrates norm (LayerNorm, RMSNorm) and fused-softmax kernels from ahead-of-time (AOT) nvcc compilation to JIT compilation via NVRTC, reducing the shared library size by 36% and sequential build time by ~42% for the affected translation units. A static AOT fallback is preserved under
Confidence Score: 4/5The change is safe to merge. The new NVRTC path is guarded by The core correctness concern was whether the host-side workspace/smem computations in
Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Host as Host (CPU)
participant Reg as TeNormalizationRegistry
participant Mgr as rtc::KernelManager
participant NVRTC as NVRTC Driver
participant GPU
Note over Host,Reg: Library init (static initializers in ln_fwd_cuda_kernel.cu etc.)
Host->>Reg: register_ln_fwd_tuned(wt, it, ot, ct, hidden, …)
Reg->>Reg: Store closure (label, kernel_expr, source, smem_bytes, …)
Note over Host,GPU: First kernel call (configure_params phase)
Host->>Reg: "getKernel(key)(launch_params, configure_params=true)"
Reg->>Mgr: is_compiled(label)? [shared_lock]
alt Not compiled
Mgr-->>Reg: false
Reg->>Mgr: compile(label, kernel_expr, source, extra_headers) [unique_lock]
Mgr->>NVRTC: nvrtcCreateProgram + nvrtcCompileProgram
NVRTC-->>Mgr: PTX / cubin
Mgr->>Mgr: cache Kernel object
end
Mgr->>Mgr: occupancy_max_active_blocks_per_sm [shared_lock]
Mgr-->>Reg: ctas_per_sm
Reg->>Host: set ctas_per_col, workspace_bytes, barrier_bytes
Note over Host,GPU: Second kernel call (launch phase)
Host->>Reg: "getKernel(key)(launch_params, configure_params=false)"
alt "smem_bytes >= 48 KB"
Reg->>Mgr: set_function_attribute(MAX_DYNAMIC_SHARED) [shared_lock]
end
alt "ctas_per_row == 1"
Reg->>Mgr: launch(label, grid, block, smem, stream, params) [shared_lock]
Mgr->>GPU: cuLaunchKernel
else "ctas_per_row > 1"
Reg->>Mgr: launch_cooperative(label, grid, block, smem, stream, params) [shared_lock]
Mgr->>GPU: cuLaunchCooperativeKernel
end
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant Host as Host (CPU)
participant Reg as TeNormalizationRegistry
participant Mgr as rtc::KernelManager
participant NVRTC as NVRTC Driver
participant GPU
Note over Host,Reg: Library init (static initializers in ln_fwd_cuda_kernel.cu etc.)
Host->>Reg: register_ln_fwd_tuned(wt, it, ot, ct, hidden, …)
Reg->>Reg: Store closure (label, kernel_expr, source, smem_bytes, …)
Note over Host,GPU: First kernel call (configure_params phase)
Host->>Reg: "getKernel(key)(launch_params, configure_params=true)"
Reg->>Mgr: is_compiled(label)? [shared_lock]
alt Not compiled
Mgr-->>Reg: false
Reg->>Mgr: compile(label, kernel_expr, source, extra_headers) [unique_lock]
Mgr->>NVRTC: nvrtcCreateProgram + nvrtcCompileProgram
NVRTC-->>Mgr: PTX / cubin
Mgr->>Mgr: cache Kernel object
end
Mgr->>Mgr: occupancy_max_active_blocks_per_sm [shared_lock]
Mgr-->>Reg: ctas_per_sm
Reg->>Host: set ctas_per_col, workspace_bytes, barrier_bytes
Note over Host,GPU: Second kernel call (launch phase)
Host->>Reg: "getKernel(key)(launch_params, configure_params=false)"
alt "smem_bytes >= 48 KB"
Reg->>Mgr: set_function_attribute(MAX_DYNAMIC_SHARED) [shared_lock]
end
alt "ctas_per_row == 1"
Reg->>Mgr: launch(label, grid, block, smem, stream, params) [shared_lock]
Mgr->>GPU: cuLaunchKernel
else "ctas_per_row > 1"
Reg->>Mgr: launch_cooperative(label, grid, block, smem, stream, params) [shared_lock]
Mgr->>GPU: cuLaunchCooperativeKernel
end
Reviews (4): Last reviewed commit: "import cleanup" | Re-trigger Greptile |
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
700a1eb to
b909589
Compare
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Description
Enables JIT compilation through NVRTC for Norm and Softmax kernels.
Reduces TE binary size by 36%, sequential build time by 5% (measured in cpu_user total time, hard to measure real impact due to parallelization, machine specs)
This is the first chunk of work related to #3054 .
The softmax kernels were chosen as they seemed like one of the simplest to migrate, for my understanding of the system.
The norm kernels include
normalization/layernorm/ln_fwd_cuda_kernel.cu, which is one of the heaviest kernel compilations in the build.It is still possible to enable nvcc static compilation through
NVTE_BUILD_LEGACY_STATIC_FUSED_SOFTMAXandNVTE_BUILD_LEGACY_STATIC_NORM, which then allow forNVTE_DISABLE_NVRTC=1to be used during runtime.Build time results:
Measured on RTX 6000 Ada, CUDA 12.8, single arch sm_89, 32-core host. AOT = -DNVTE_BUILD_LEGACY_STATIC_{FUSED_SOFTMAX,NORM}=ON (old behavior); NVRTC = default.
Per TU build time
scaled_masked_softmax.cuscaled_upper_triang_masked_softmax.cuscaled_aligned_causal_masked_softmax.culn_fwd_cuda_kernel.culn_bwd_semi_cuda_kernel.curmsnorm_fwd_cuda_kernel.curmsnorm_bwd_semi_cuda_kernel.cuBinary size
libtransformer_engine.soTotal build time
JIT compilation cost
layernorm_fwdlayernorm_bwdrmsnorm_fwdrmsnorm_bwdscaled_masked_softmax_fwdscaled_masked_softmax_bwdscaled_upper_triang_softmax_fwdscaled_upper_triang_softmax_bwdscaled_aligned_causal_softmax_fwdFixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
rtc_dispatch.cppto allow NVRTC to work with the registry used by norms. This is the largest chunk of new code.Checklist: