Skip to content

[Common] Support scaled & clamped swiglu, srelu for BF16 #3132

Open
zhongbozhu wants to merge 7 commits into
NVIDIA:mainfrom
zhongbozhu:add_support_fused_swiglu
Open

[Common] Support scaled & clamped swiglu, srelu for BF16 #3132
zhongbozhu wants to merge 7 commits into
NVIDIA:mainfrom
zhongbozhu:add_support_fused_swiglu

Conversation

@zhongbozhu

Copy link
Copy Markdown
Collaborator

Description

Support Mega-C++ with Cublas BF16 Grouped GEMM backend: #3099

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 16, 2026
@zhongbozhu zhongbozhu marked this pull request as ready for review June 16, 2026 07:32
@zhongbozhu zhongbozhu requested a review from ptrendx as a code owner June 16, 2026 07:32
@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds six new CUDA kernels implementing scaled and clamped SwiGLU and SReLU activations for use with the Mega-C++ BF16 grouped GEMM backend, along with a comprehensive C++ test suite.

  • New kernels (scaled_activation.cu): forward and backward passes for SwiGLU, ClampedSwiGLU, and SReLU, each with a "no scale-grad" variant (flat element-wise grid) and a "with scale-grad" variant (one block per row for the per-row reduction). All math is promoted to float32 internally and cast once at store.
  • New API (activation.h): six public functions (nvte_scaled_swiglu, nvte_scaled_dswiglu, nvte_scaled_clamped_swiglu, nvte_scaled_clamped_dswiglu, nvte_scaled_srelu, nvte_scaled_dsrelu) covering forward and backward paths.
  • New test (test_scaled_activation.cu): parameterised over activation type, data dtype, scale dtype, tensor shape, GLU interleave factor, and whether to compute the scale gradient; validates forward output, grad_input, and grad_scales against a CPU reference.

Confidence Score: 5/5

Safe to merge. The kernel math, reduction, and index arithmetic are all correct; the CPU reference in the test is consistent with the CUDA implementation.

The forward and backward passes for all three activation types (SwiGLU, ClampedSwiGLU, SReLU) are mathematically correct and verified against math.h primitives. The per-row block reduction in the with-scale-grad kernels is implemented properly with correct shared-memory synchronization. Vectorization dispatch, GLU interleave index computation, and input validation checks are all sound. No logic errors were found.

No files require special attention. All five changed files are straightforward additions.

Important Files Changed

Filename Overview
transformer_engine/common/activation/scaled_activation.cu New file implementing six CUDA kernels for scaled activations. Forward/backward math for SwiGLU, ClampedSwiGLU, and SReLU is correct and consistent with math.h. Vectorization dispatch, per-row block-reduction, and GLU interleave index arithmetic are all sound.
transformer_engine/common/include/transformer_engine/activation.h Adds six well-documented public C API declarations for scaled/clamped SwiGLU and SReLU forward/backward. Documentation and signatures are consistent with the implementation.
tests/cpp/operator/test_scaled_activation.cu New test file with a parametrised suite covering forward and backward for all three activation types. CPU reference matches kernel math.
transformer_engine/common/CMakeLists.txt Adds scaled_activation.cu to both the standard and fast-math build source lists; correct and complete.
tests/cpp/operator/CMakeLists.txt Adds test_scaled_activation.cu to the test_operator executable; trivially correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_scaled_swiglu / nvte_scaled_clamped_swiglu / nvte_scaled_srelu"] --> B["launch_scaled_gated_forward / launch_scaled_srelu_forward"]
    B --> C{Type switches GradT / InputT / ScaleT / OutputT}
    C --> D{row_vector_alignment use_vector?}
    D -- "yes, nvec>1" --> E1["scaled_*_forward_kernel<nvec>"]
    D -- "no, scalar" --> E2["scaled_*_forward_kernel<1>"]
    E1 & E2 --> F["output = activation(input) * act_scales[row]"]

    G["nvte_scaled_dswiglu / nvte_scaled_clamped_dswiglu / nvte_scaled_dsrelu"] --> H["launch_scaled_gated_backward / launch_scaled_srelu_backward"]
    H --> I{grad_act_scales null?}
    I -- "yes" --> J{use_vector?}
    J -- "yes" --> K1["scaled_*_backward_kernel<nvec> flat grid"]
    J -- "no" --> K2["scaled_*_backward_kernel<1> flat grid"]
    I -- "no" --> L{use_vector?}
    L -- "yes" --> M1["scaled_*_backward_with_scale_grad_kernel<nvec> 1 block per row"]
    L -- "no" --> M2["scaled_*_backward_with_scale_grad_kernel<1> 1 block per row"]
    K1 & K2 --> N["grad_input = scale * dActivation(input) * grad_output"]
    M1 & M2 --> O["grad_input + grad_act_scales[row] = sum_j(grad_j * unscaled_j)"]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["nvte_scaled_swiglu / nvte_scaled_clamped_swiglu / nvte_scaled_srelu"] --> B["launch_scaled_gated_forward / launch_scaled_srelu_forward"]
    B --> C{Type switches GradT / InputT / ScaleT / OutputT}
    C --> D{row_vector_alignment use_vector?}
    D -- "yes, nvec>1" --> E1["scaled_*_forward_kernel<nvec>"]
    D -- "no, scalar" --> E2["scaled_*_forward_kernel<1>"]
    E1 & E2 --> F["output = activation(input) * act_scales[row]"]

    G["nvte_scaled_dswiglu / nvte_scaled_clamped_dswiglu / nvte_scaled_dsrelu"] --> H["launch_scaled_gated_backward / launch_scaled_srelu_backward"]
    H --> I{grad_act_scales null?}
    I -- "yes" --> J{use_vector?}
    J -- "yes" --> K1["scaled_*_backward_kernel<nvec> flat grid"]
    J -- "no" --> K2["scaled_*_backward_kernel<1> flat grid"]
    I -- "no" --> L{use_vector?}
    L -- "yes" --> M1["scaled_*_backward_with_scale_grad_kernel<nvec> 1 block per row"]
    L -- "no" --> M2["scaled_*_backward_with_scale_grad_kernel<1> 1 block per row"]
    K1 & K2 --> N["grad_input = scale * dActivation(input) * grad_output"]
    M1 & M2 --> O["grad_input + grad_act_scales[row] = sum_j(grad_j * unscaled_j)"]
Loading

Reviews (2): Last reviewed commit: "Update tests/cpp/operator/test_scaled_ac..." | Re-trigger Greptile

Comment on lines +170 to +171
}
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead gated_unscaled call

gated_unscaled computes unscaled on line 170, but gated_grads unconditionally writes *unscaled on line 171, overwriting it. The first call is dead code — every gated_grads case sets *unscaled before returning, so the result of gated_unscaled is never observed. This should simply be removed.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either out-dated or not precise comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right that "dead code" was imprecise. The more accurate characterization: gated_unscaled is a redundant computation — its return value is assigned to unscaled, but gated_grads immediately overwrites *unscaled in all three switch cases before returning, so the first assignment is discarded. The function itself is reachable and executes, but its result is never observed.

The substance still holds: the call to gated_unscaled on line 170 can be removed without any behavioral change, since gated_grads on line 171 already computes and sets unscaled as a side-output needed for the scale gradient. Keeping both creates a mild readability issue — a reader might assume the two calls serve different purposes.

A cleaner form would be:

Suggested change
}
}
gated_grads(activation, act_in, linear_in, &dact, &dlinear, &unscaled);

Tip: You can customize Greptile's behavior for this repo with .greptile/rules.md and .greptile/config.json.

@zhongbozhu

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Left minor comments

Comment on lines +121 to +133
__device__ __forceinline__ float warp_reduce_sum(float value) {
#pragma unroll
for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) {
value += __shfl_down_sync(0xffffffff, value, offset);
}
return value;
}

__device__ __forceinline__ float block_reduce_sum(float value, float *smem) {
const int lane = threadIdx.x % THREADS_PER_WARP;
const int warp = threadIdx.x / THREADS_PER_WARP;

value = warp_reduce_sum(value);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I beleive we can reuse this from utils.cuh

inline __device__ T reduce(T data, const Op &op) {
// only lane 0 holds the result!

Comment on lines +722 to +781
void nvte_scaled_swiglu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output,
int64_t glu_interleave_size, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_swiglu);
using namespace transformer_engine;
Empty empty = {};
(void)empty;
ClampedSwiGLUParam param = {};
launch_scaled_gated_forward<ScaledActivation::kSwiGLU>(
input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_swiglu");
}

void nvte_scaled_dswiglu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales,
NVTETensor grad_input, NVTETensor grad_act_scales,
int64_t glu_interleave_size, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {};
launch_scaled_gated_backward<ScaledActivation::kSwiGLU>(grad, input, act_scales, grad_input,
grad_act_scales, glu_interleave_size,
param, stream, "nvte_scaled_dswiglu");
}

void nvte_scaled_clamped_swiglu(const NVTETensor input, const NVTETensor act_scales,
NVTETensor output, float limit, float alpha,
float glu_linear_offset, int64_t glu_interleave_size,
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_clamped_swiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
launch_scaled_gated_forward<ScaledActivation::kClampedSwiGLU>(
input, act_scales, output, glu_interleave_size, param, stream, "nvte_scaled_clamped_swiglu");
}

void nvte_scaled_clamped_dswiglu(const NVTETensor grad, const NVTETensor input,
const NVTETensor act_scales, NVTETensor grad_input,
NVTETensor grad_act_scales, float limit, float alpha,
float glu_linear_offset, int64_t glu_interleave_size,
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_clamped_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha, glu_linear_offset};
launch_scaled_gated_backward<ScaledActivation::kClampedSwiGLU>(
grad, input, act_scales, grad_input, grad_act_scales, glu_interleave_size, param, stream,
"nvte_scaled_clamped_dswiglu");
}

void nvte_scaled_srelu(const NVTETensor input, const NVTETensor act_scales, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_srelu);
using namespace transformer_engine;
launch_scaled_srelu_forward(input, act_scales, output, stream, "nvte_scaled_srelu");
}

void nvte_scaled_dsrelu(const NVTETensor grad, const NVTETensor input, const NVTETensor act_scales,
NVTETensor grad_input, NVTETensor grad_act_scales, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_dsrelu);
using namespace transformer_engine;
launch_scaled_srelu_backward(grad, input, act_scales, grad_input, grad_act_scales, stream,
"nvte_scaled_dsrelu");
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to move these NVTE API definitions into new files scaled_swiglu.cu and scaled_srelu.cu, following the footsteps of other activation definitions.

Comment thread tests/cpp/operator/test_scaled_activation.cu Outdated
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants