Skip to content

Migrate norms and softmax kernels to NVRTC#3156

Open
CarlosGomes98 wants to merge 9 commits into
NVIDIA:mainfrom
CarlosGomes98:cgomes/nvrtc-phase0
Open

Migrate norms and softmax kernels to NVRTC#3156
CarlosGomes98 wants to merge 9 commits into
NVIDIA:mainfrom
CarlosGomes98:cgomes/nvrtc-phase0

Conversation

@CarlosGomes98

@CarlosGomes98 CarlosGomes98 commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Description

Enables JIT compilation through NVRTC for Norm and Softmax kernels.
Reduces TE binary size by 36%, sequential build time by 5% (measured in cpu_user total time, hard to measure real impact due to parallelization, machine specs)
This is the first chunk of work related to #3054 .

The softmax kernels were chosen as they seemed like one of the simplest to migrate, for my understanding of the system.
The norm kernels include normalization/layernorm/ln_fwd_cuda_kernel.cu, which is one of the heaviest kernel compilations in the build.

It is still possible to enable nvcc static compilation through NVTE_BUILD_LEGACY_STATIC_FUSED_SOFTMAX and NVTE_BUILD_LEGACY_STATIC_NORM, which then allow for NVTE_DISABLE_NVRTC=1 to be used during runtime.

Build time results:

Measured on RTX 6000 Ada, CUDA 12.8, single arch sm_89, 32-core host. AOT = -DNVTE_BUILD_LEGACY_STATIC_{FUSED_SOFTMAX,NORM}=ON (old behavior); NVRTC = default.

Per TU build time

translation unit compile AOT (s) compile NVRTC (s) Δ time (s) obj AOT (KB) obj NVRTC (KB) Δ size (KB)
scaled_masked_softmax.cu 10.80 2.64 −8.2 2348 274 −2074
scaled_upper_triang_masked_softmax.cu 11.75 2.51 −9.2 1891 238 −1653
scaled_aligned_causal_masked_softmax.cu 10.24 2.52 −7.7 2097 233 −1864
ln_fwd_cuda_kernel.cu 63.90 28.47 −35.4 9092 131 −8961
ln_bwd_semi_cuda_kernel.cu 43.80 28.45 −15.4 5295 121 −5174
rmsnorm_fwd_cuda_kernel.cu 38.28 28.39 −9.9 2499 114 −2385
rmsnorm_bwd_semi_cuda_kernel.cu 37.34 28.55 −8.8 2566 115 −2451
total 225.0 130.5 −94.5 (−42%) 26060 1498 −24562 (−94%)

Binary size

target AOT (MB) NVRTC (MB) Δ (MB)
libtransformer_engine.so 64.4 41.2 −23.2 (−36%)

Total build time

metric AOT NVRTC Δ
wall (s) 283.6 210.6 −72.9 (−26%)
cpu_user (s) 2584.6 2455.0 −129.6 (−5%)
max_rss (MB) 4037 4037 0

JIT compilation cost

kernel NVRTC cold (ms) static (ms)
layernorm_fwd 97.1 3.5
layernorm_bwd 133.1 2.7
rmsnorm_fwd 83.6 1.0
rmsnorm_bwd 106.4 1.1
scaled_masked_softmax_fwd 52.5 0.7
scaled_masked_softmax_bwd 38.8
scaled_upper_triang_softmax_fwd 48.3 0.4
scaled_upper_triang_softmax_bwd 42.3
scaled_aligned_causal_softmax_fwd 48.5 0.3

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Functionality to pass build options to the NVRTC compile manager
  • Softmax kernels through NVRTC
  • rtc_dispatch.cpp to allow NVRTC to work with the registry used by norms. This is the largest chunk of new code.
  • Norm kernels through NVRTC

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 30, 2026
CarlosGomes98 and others added 3 commits June 30, 2026 14:56
Move the fused-softmax and LayerNorm/RMSNorm kernels from build-time template
instantiation to runtime NVRTC compilation, with full coverage of the existing
kernel set so the NVRTC path is the default.

Fused softmax:
- RTC compile/launch path for scaled / scaled-masked / scaled-upper-triangular /
  scaled-aligned-causal softmax, keyed by dtype, shape and mask/causal mode.
- NVTE_BUILD_LEGACY_STATIC_FUSED_SOFTMAX (default OFF) restores the static
  template dispatch.

Normalization (LayerNorm + RMSNorm, forward + backward):
- Replace the static REGISTER_NORM_LAUNCHER template fanout with an NVRTC
  registry that compiles the selected (norm type, direction, dtypes, hidden size,
  CTA config) kernel on first use and caches it.
- NVTE_BUILD_LEGACY_STATIC_NORM (default OFF) restores the static launchers.
- NVRTC-safe kernel sources: kernel sources/headers avoid common.h under
  __CUDACC_RTC__; add the dtype aliases and a minimal std::is_same/conditional_t
  in the RTC build, and replace a zero-length padding array (a GNU extension nvcc
  accepts but NVRTC rejects) with a no-padding union specialization.

KernelManager (util/rtc.{h,cpp}) gains occupancy / function-attribute /
cooperative-launch helpers needed by the norm launchers.

Validated on sm_89 (RTX 6000 Ada): full normalization operator suite 192/192,
softmax + NVRTC unit tests pass; libtransformer_engine.so shrinks ~72 MB -> ~65 MB.
On sm_100a the NVRTC norm forward kernel builds where the static instantiation
crashed the compiler.

Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
@CarlosGomes98

Copy link
Copy Markdown
Contributor Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR migrates norm (LayerNorm, RMSNorm) and fused-softmax kernels from ahead-of-time (AOT) nvcc compilation to JIT compilation via NVRTC, reducing the shared library size by 36% and sequential build time by ~42% for the affected translation units. A static AOT fallback is preserved under NVTE_BUILD_LEGACY_STATIC_FUSED_SOFTMAX / NVTE_BUILD_LEGACY_STATIC_NORM build flags, and the existing NVTE_DISABLE_NVRTC=1 runtime flag selects it.

  • rtc_dispatch.cpp / rtc_dispatch.h: New 750-line file that registers RTC-backed closures for all four norm kernel families (LN/RMSNorm fwd/bwd, tuned/general paths) using a lazy-compile pattern; workspace/smem sizes are recomputed in host code to match the static kernel formulas.
  • util/rtc.cpp / util/rtc.h: Extended KernelManager with extra_headers, extra_options, cooperative-launch, set_function_attribute, and occupancy-query APIs; upgraded internal locking from std::mutex to std::shared_mutex so concurrent reads no longer block each other; the compile-time TOCTOU is fixed by moving the duplicate-check inside the exclusive lock.
  • utils.cuh: Adds project-owned rtc_detail::is_same / conditional_t to replace <type_traits> for NVRTC compilation, avoiding undefined behavior from adding to namespace std.

Confidence Score: 4/5

The change is safe to merge. The new NVRTC path is guarded by rtc::is_enabled() with a clean static fallback, and all workspace/smem calculations in rtc_dispatch.cpp were verified to reproduce the original static formulas exactly. The shared-mutex upgrade is correctly implemented.

The core correctness concern was whether the host-side workspace/smem computations in rtc_dispatch.cpp match what the original template-instantiated kernels computed. Verification against the pre-PR static launch functions shows the new formulas are identical. The TOCTOU race and namespace std issues raised in earlier review rounds are addressed. The JIT compilation cost (~50–130 ms cold) is clearly documented and acceptable for a one-time-per-process cost.

rtc_dispatch.cpp carries the most risk — it is a dense 742-line file that hand-replicates the smem/workspace logic of six kernel families. Reviewers should pay special attention to the bwd_smem_bytes helper and the finalize_ctas calculation for the tuned backward paths.

Important Files Changed

Filename Overview
transformer_engine/common/normalization/rtc_dispatch.cpp New 742-line file that registers RTC-backed closures for all four norm kernel families. Workspace/smem calculations correctly replicate the static kernel formulas. The is_compiled + compile pattern is safe because compile now rechecks under an exclusive lock.
transformer_engine/common/util/rtc.cpp Upgrades KernelManager to shared_mutex, adds extra_options/extra_headers to compile, fixes TOCTOU by moving the duplicate-check inside the exclusive write lock, and adds cooperative launch, set_function_attribute, and occupancy query APIs.
transformer_engine/common/util/rtc.h Header updated to expose Header struct, extended compile signature, cooperative launch template, set_function_attribute, and occupancy query. lock_ correctly declared mutable for use in const methods with shared locks.
transformer_engine/common/fused_softmax/scaled_masked_softmax.cu Largest softmax file, now dual-mode (NVRTC / static AOT). The forward path explicitly casts scale to acc_t before the RTC launch. std::exp replaced with expf and std::numeric_limits::infinity() with neg_infinity<T>() for NVRTC compatibility.
transformer_engine/common/utils.cuh Adds rtc_detail::is_same, is_same_v, conditional, conditional_t in a project-owned namespace (addressing the prior namespace std UB concern), and updates all usages in abs_val / max_val.
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu Adds NVRTC registration macros for all four RMSNorm backward variants. ADD_FLAG is now correctly forwarded to register_rmsnorm_bwd_tuned/general and guarded by static_assert(ADD_FLAG) on BackwardAdd paths.
transformer_engine/common/normalization/kernel_params.h New file extracting kernel param structs from common.h as a lean, host-independent header that NVRTC can include.
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh Replaces #include <type_traits> with NVRTC-compatible rtc_detail::conditional_t. The dx_add_t union is split to avoid zero-length array rejection by NVRTC.
transformer_engine/common/normalization/rtc_dispatch.h Clean new header exposing the eight register_* functions and the StaticFallback type alias with ergonomic default-nullptr fallback parameters.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Host as Host (CPU)
    participant Reg as TeNormalizationRegistry
    participant Mgr as rtc::KernelManager
    participant NVRTC as NVRTC Driver
    participant GPU

    Note over Host,Reg: Library init (static initializers in ln_fwd_cuda_kernel.cu etc.)
    Host->>Reg: register_ln_fwd_tuned(wt, it, ot, ct, hidden, …)
    Reg->>Reg: Store closure (label, kernel_expr, source, smem_bytes, …)

    Note over Host,GPU: First kernel call (configure_params phase)
    Host->>Reg: "getKernel(key)(launch_params, configure_params=true)"
    Reg->>Mgr: is_compiled(label)?  [shared_lock]
    alt Not compiled
        Mgr-->>Reg: false
        Reg->>Mgr: compile(label, kernel_expr, source, extra_headers) [unique_lock]
        Mgr->>NVRTC: nvrtcCreateProgram + nvrtcCompileProgram
        NVRTC-->>Mgr: PTX / cubin
        Mgr->>Mgr: cache Kernel object
    end
    Mgr->>Mgr: occupancy_max_active_blocks_per_sm [shared_lock]
    Mgr-->>Reg: ctas_per_sm
    Reg->>Host: set ctas_per_col, workspace_bytes, barrier_bytes

    Note over Host,GPU: Second kernel call (launch phase)
    Host->>Reg: "getKernel(key)(launch_params, configure_params=false)"
    alt "smem_bytes >= 48 KB"
        Reg->>Mgr: set_function_attribute(MAX_DYNAMIC_SHARED) [shared_lock]
    end
    alt "ctas_per_row == 1"
        Reg->>Mgr: launch(label, grid, block, smem, stream, params) [shared_lock]
        Mgr->>GPU: cuLaunchKernel
    else "ctas_per_row > 1"
        Reg->>Mgr: launch_cooperative(label, grid, block, smem, stream, params) [shared_lock]
        Mgr->>GPU: cuLaunchCooperativeKernel
    end
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Host as Host (CPU)
    participant Reg as TeNormalizationRegistry
    participant Mgr as rtc::KernelManager
    participant NVRTC as NVRTC Driver
    participant GPU

    Note over Host,Reg: Library init (static initializers in ln_fwd_cuda_kernel.cu etc.)
    Host->>Reg: register_ln_fwd_tuned(wt, it, ot, ct, hidden, …)
    Reg->>Reg: Store closure (label, kernel_expr, source, smem_bytes, …)

    Note over Host,GPU: First kernel call (configure_params phase)
    Host->>Reg: "getKernel(key)(launch_params, configure_params=true)"
    Reg->>Mgr: is_compiled(label)?  [shared_lock]
    alt Not compiled
        Mgr-->>Reg: false
        Reg->>Mgr: compile(label, kernel_expr, source, extra_headers) [unique_lock]
        Mgr->>NVRTC: nvrtcCreateProgram + nvrtcCompileProgram
        NVRTC-->>Mgr: PTX / cubin
        Mgr->>Mgr: cache Kernel object
    end
    Mgr->>Mgr: occupancy_max_active_blocks_per_sm [shared_lock]
    Mgr-->>Reg: ctas_per_sm
    Reg->>Host: set ctas_per_col, workspace_bytes, barrier_bytes

    Note over Host,GPU: Second kernel call (launch phase)
    Host->>Reg: "getKernel(key)(launch_params, configure_params=false)"
    alt "smem_bytes >= 48 KB"
        Reg->>Mgr: set_function_attribute(MAX_DYNAMIC_SHARED) [shared_lock]
    end
    alt "ctas_per_row == 1"
        Reg->>Mgr: launch(label, grid, block, smem, stream, params) [shared_lock]
        Mgr->>GPU: cuLaunchKernel
    else "ctas_per_row > 1"
        Reg->>Mgr: launch_cooperative(label, grid, block, smem, stream, params) [shared_lock]
        Mgr->>GPU: cuLaunchCooperativeKernel
    end
Loading

Reviews (4): Last reviewed commit: "import cleanup" | Re-trigger Greptile

Comment thread transformer_engine/common/util/rtc.cpp Outdated
Comment thread transformer_engine/common/normalization/rtc_dispatch.cpp
Comment thread transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu Outdated
Comment thread transformer_engine/common/utils.cuh Outdated
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
@CarlosGomes98 CarlosGomes98 force-pushed the cgomes/nvrtc-phase0 branch from 700a1eb to b909589 Compare July 1, 2026 14:01
CarlosGomes98 and others added 2 commits July 1, 2026 16:24
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
@ptrendx

ptrendx commented Jul 1, 2026

Copy link
Copy Markdown
Member

/te-ci pytorch

Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants