[Common, PyTorch] Improve mHC to match DeepSeek's implementation#2978

Open
kainzhong wants to merge 7 commits into
NVIDIA:mainfrom
kainzhong:feat/mhc_optimization1

Conversation

@kainzhong

Collaborator

Description

Several enhancements to mHC to better align it with DeepSeek's TileLang implementation: https://github.com/deepseek-ai/TileKernels/tree/main/tile_kernels/mhc

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add mhc_generate_mix_and_aggregate API that does projection, scale, sinkhorn and aggregate together
  • Allow mhc_fused_projection to accept arguments with mixed dtype: x.dtype=bf16, phi.dtype=fp32, which matches DeepSeek's implementation
  • mhc_fused_projection now outputs fp32 regardless of the input dtype, matching DeepSeek's implementation
  • Add a fuse_grad_x_acc optimization (defaults to False) that reuses a single grad_x buffer to accumulate the gradient of the initial mHC input x across mhc_fused_expand_combine, mhc_fused_aggregate and mhc_fused_projection
  • Support norm_weight in mhc_fused_projection. This is equivalent to applying RMSNorm with elementwise_affine=True in the unfused path, where norm_weight holds RMSNorm's learnable per-element affine parameters
  • Refactor some kernel code to avoid duplication. A Triton kernel argument declared constexpr can act like a macro in if branches, because Triton does not compile a branch that it knows at compile time will not be taken
  • Fix a bug where the launch grid exceeds CUDA's grid-size limit when M is very large and the autotune candidate has BLOCK_SIZE_M=1. Such invalid configs are now pruned.
  • Improve the projection op by using TMA on Hopper and newer GPUs
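
The norm_weight item can be checked numerically. The following is a dependency-free sketch in plain Python (not the Triton kernel) of why RMSNorm with an affine weight followed by a projection can be computed as one product against a weight-scaled phi plus a per-row rescale. The function names, the eps value and the tiny shapes are assumptions made for illustration; the PR's kernels may organize the computation differently.

```python
import math

def rmsnorm_then_project(x, w, phi, eps=1e-6):
    """Unfused reference: y = RMSNorm(x) * w, then y @ phi."""
    out = []
    for row in x:
        ms = sum(v * v for v in row) / len(row)  # mean of squares
        inv_rms = 1.0 / math.sqrt(ms + eps)
        y = [v * inv_rms * wk for v, wk in zip(row, w)]
        out.append([sum(yk * phi[k][j] for k, yk in enumerate(y))
                    for j in range(len(phi[0]))])
    return out

def fused_project(x, w, phi, eps=1e-6):
    """Fused form: fold w into phi once, project raw x, rescale each row."""
    phi_w = [[w[k] * phi[k][j] for j in range(len(phi[0]))]
             for k in range(len(phi))]
    out = []
    for row in x:
        ms = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(ms + eps)
        h = [sum(row[k] * phi_w[k][j] for k in range(len(row)))
             for j in range(len(phi_w[0]))]
        out.append([hj * inv_rms for hj in h])
    return out

# Made-up inputs: 2 rows, 3 features, 2 output columns.
x = [[1.0, -2.0, 3.0], [0.5, 0.25, -1.5]]
w = [1.5, 0.5, 2.0]
phi = [[0.1, 0.2], [0.3, -0.4], [-0.5, 0.6]]
ref = rmsnorm_then_project(x, w, phi)
fused = fused_project(x, w, phi)
max_diff = max(abs(a - b) for r, f in zip(ref, fused) for a, b in zip(r, f))
```

The two paths agree up to floating-point rounding, which is what lets a fused kernel read x only once and defer the per-row scale to a later step.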

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@kainzhong kainzhong force-pushed the feat/mhc_optimization1 branch from af685d7 to eccba0c Compare May 12, 2026 00:57
@kainzhong kainzhong marked this pull request as ready for review May 13, 2026 17:26
@kainzhong kainzhong requested a review from ksivaman as a code owner May 13, 2026 17:26
@greptile-apps

greptile-apps Bot commented May 13, 2026

Contributor

Greptile Summary

This PR extends the mHC (manifold Hyper-Connection) implementation to better align with DeepSeek's TileKernels reference. It adds a high-level mhc_generate_mix_and_aggregate API, supports mixed-dtype inputs and norm_weight for RMSNorm in projection, introduces a fuse_grad_x_acc optimization for shared gradient accumulation, adds TMA support on Hopper+, and refactors backward kernels to use STEP_SIZE_C loops and workspace-based deterministic reductions.

  • New _mhc_projection_bwd_fused_dphi kernel separates phi/norm_weight gradients from the dx kernel; _mhc_projection_fwd_fused gains TMA and split-K paths, and bias/norm_weight handling in expand-combine is unified via HAS_BIAS/HAS_NORM_WEIGHT constexpr guards.
  • All backward ops now have pruner-based grid-dim guards and NVTE_DISABLE_TRITON_AUTOTUNING moved inside the pruner; fused_grad_x_acc_buffer lets expand-combine, aggregate, and projection kernels share one fp32 gradient buffer, with only projection returning the accumulated result.
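
As a rough illustration of the grid-dim guard, the sketch below prunes tile-size candidates whose launch grid would be too large. It uses plain dicts as hypothetical stand-ins for autotune configs, and it assumes the M tiles map to a grid axis capped at 65535 (CUDA caps gridDim.x at 2**31 - 1 and gridDim.y / gridDim.z at 65535); which axis the PR's kernels actually use is not stated here. The `disable_autotuning` flag is likewise an assumed stand-in for the NVTE_DISABLE_TRITON_AUTOTUNING handling.

```python
def cdiv(a, b):
    """Ceiling division, as used to size a launch grid."""
    return (a + b - 1) // b

# CUDA grid-size limits per axis.
MAX_GRID_X = 2**31 - 1
MAX_GRID_YZ = 65535

def prune_configs(configs, M, max_grid_dim=MAX_GRID_YZ, disable_autotuning=False):
    """Keep only configs whose M-axis grid size fits within max_grid_dim.

    When disable_autotuning is True, return a single valid config so that
    no benchmarking between candidates is needed (assumed behavior).
    """
    valid = [c for c in configs if cdiv(M, c["BLOCK_SIZE_M"]) <= max_grid_dim]
    if not valid:
        raise ValueError(f"no config covers M={M} within {max_grid_dim} blocks")
    return valid[:1] if disable_autotuning else valid

candidates = [{"BLOCK_SIZE_M": b} for b in (1, 2, 4, 8)]
kept = prune_configs(candidates, M=200_000)
single = prune_configs(candidates, M=200_000, disable_autotuning=True)
```

Doing the validity filter before the single-config shortcut means the one config that survives is always launchable, which matches the motivation for moving the autotuning switch inside the pruner.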

Confidence Score: 4/5

Safe to merge; no correctness regressions found in the autograd plumbing, kernel math, or determinism paths.

The backward return arities, grid-dim pruning, and NVTE_DISABLE_TRITON_AUTOTUNING placement are all addressed. The fused_grad_x_acc_buffer accumulation order is logically sound and covered by test_mhc_fuse_grad_acc. The TMA-alignment check slightly misidentifies which phi tensor to evaluate when norm_weight is present, but PyTorch's allocator always produces sufficiently aligned tensors so it will not fire in practice.

transformer_engine/pytorch/triton/mhc.py — TMA alignment guard and the forward save/dispatch ordering around norm_weight multiplication.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/triton/mhc.py | Large refactor adding norm_weight, fused_grad_x_acc_buffer, TMA path, and deterministic workspace reductions; backward return arities now match forward signatures. TMA-alignment check references pre-multiplication phi when norm_weight is provided. |
| transformer_engine/common/triton/mhc.py | New _mhc_projection_bwd_fused_dphi and _mhc_projection_bwd_fused_dx kernels; forward kernel gains TMA/split-K; aggregate/expand-combine backward unified into STEP_SIZE_C loop; all pruners moved inside callbacks for correct single-config NVTE_DISABLE_TRITON_AUTOTUNING handling. Dead stride_norm_weight constexpr present in forward kernel. |
| tests/pytorch/test_mhc.py | Extended with mixed-dtype projection tests, norm_weight coverage, use_split_k parameter, test_mhc_rmsnorm, test_mhc_fuse_grad_acc, and tolerances adjusted per ENFORCE_DETERMINISTIC. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User
    participant genMixAgg as mhc_generate_mix_and_aggregate
    participant proj as mhc_fused_projection
    participant scale as mhc_fused_scale
    participant sink as mhc_fused_sinkhorn
    participant agg as mhc_fused_aggregate
    participant expC as mhc_fused_expand_combine
    participant buf as fused_grad_x_acc_buffer (fp32)

    User->>genMixAgg: x(bf16/fp32), phi, alpha, beta, norm_weight, fused_grad_x_acc_buffer
    genMixAgg->>proj: x.view(M,nC), phi [+norm_weight absorbed]
    proj-->>genMixAgg: H(fp32, M×32), ms(fp32)
    genMixAgg->>scale: H, alpha, beta, ms
    scale-->>genMixAgg: H_pre, H_post, H_res
    genMixAgg->>sink: H_res.view(s,b,n,n)
    sink-->>genMixAgg: H_res (doubly stochastic)
    genMixAgg->>agg: x(s,b,C,n), H_pre
    agg-->>genMixAgg: out(s,b,C)
    genMixAgg-->>User: out, H_post, H_res

    Note over User: User runs sub-layer (Attn/FFN) on out
    User->>expC: f(sub-layer output), H_post, x, H_res

    rect rgb(240,240,255)
        Note over expC,buf: BACKWARD (reverse order)
        expC-->>buf: WRITE grad_x contribution (fp32)
        agg-->>buf: READ + ADD grad_x contribution
        proj-->>buf: READ + ADD, cast to x.dtype, return
    end

Reviews (10): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/triton/mhc.py Outdated
Comment thread transformer_engine/pytorch/triton/mhc.py Outdated
Comment thread transformer_engine/pytorch/triton/mhc.py Outdated
Comment thread transformer_engine/common/triton/mhc.py
Comment thread transformer_engine/common/triton/mhc.py Outdated
Comment thread transformer_engine/common/triton/mhc.py
@timmoon10

Member

/te-ci

timmoon10
timmoon10 previously approved these changes Jun 29, 2026

@timmoon10 timmoon10 left a comment

Member

LGTM, I don't see anything particularly suspicious and the API changes are backward-compatible.

fused_grad_x_acc_buffer : Optional[torch.Tensor]
A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead.
If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch.
This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine.

@timmoon10 timmoon10 Jun 29, 2026

Member

Flagging that the ordering requirement is fragile and unintuitive. That said, fused_grad_x_acc_buffer is an optional advanced optimization and the requirement is well documented.

Collaborator Author

I could ask the user to pass a zeroed buffer instead of an uninitialized one, in which case the order would not matter. The order is required today because, in backward, mhc_fused_expand_combine overwrites the buffer instead of doing a read+write accumulate, so it has to be the last operation in forward and the first operation in backward. An uninitialized buffer, however, avoids the cost of resetting the memory to zero.
Which approach is better?

Member

I think we also need to have this specific order so that mhc_fused_projection can cast the buffer to BF16 and attach it to x.grad. I don't have a strong opinion since all options I see involve doing something non-standard with PyTorch autograd. Probably best to keep it as it is right now, and generalize in the future if there's a need.
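
The ordering constraint discussed in this thread can be sketched with a toy model. The code below is plain Python with hypothetical function names standing in for the three backward steps; it only illustrates the overwrite-then-accumulate pattern described above and is not the PR's kernel code.

```python
import random

def expand_combine_bwd(buf, grad):
    # First backward step: plain store, so stale buffer contents are discarded.
    for i, g in enumerate(grad):
        buf[i] = g

def aggregate_bwd(buf, grad):
    # Middle step: read + add into the shared buffer.
    for i, g in enumerate(grad):
        buf[i] += g

def projection_bwd(buf, grad):
    # Last step: read + add and return the total
    # (the real op is described as also casting to x.dtype).
    return [b + g for b, g in zip(buf, grad)]

g_expand, g_agg, g_proj = [1.0, 2.0], [0.5, 0.5], [0.25, -1.0]
expected = [1.75, 1.5]

# Uninitialized buffer, modelled here as garbage values.
rng = random.Random(0)
buf = [rng.uniform(-1e6, 1e6) for _ in range(2)]
expand_combine_bwd(buf, g_expand)
aggregate_bwd(buf, g_agg)
grad_x = projection_bwd(buf, g_proj)

# Wrong order: the overwrite discards what aggregate already accumulated.
bad = [0.0, 0.0]
aggregate_bwd(bad, g_agg)
expand_combine_bwd(bad, g_expand)
grad_x_bad = projection_bwd(bad, g_proj)
```

With the required order the garbage in the buffer never reaches the result, so no zeroing pass is needed; with the wrong order one contribution is silently lost.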

Comment thread tests/pytorch/test_mhc.py

@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"])

Member

Why are we removing this test case? It doesn't seem that we have touched recompute in mhc_fused_sinkhorn.

@kainzhong kainzhong Jun 30, 2026

Collaborator Author

Ah, that was a mistake. An earlier version of this PR included a Gluon kernel that always recomputes, but I later decided to move it to a separate PR. I must have forgotten to revert this line. Fixed now.

Comment thread tests/pytorch/test_mhc.py
Comment on lines +504 to +507
# If upcasting from bf16 to fp32 takes place inside the triton kernel, triton will ignore "ieee" precision and use tf32 anyway
# See https://github.com/triton-lang/triton/issues/10176 for detail.
# Therefore, we need to use tf32x3 instead which at least has better accuracy than tf32 just to make the tests pass. In production
# precision should be tf32 so it's not affected.

Member

If we advertise a feature ("ieee" precision) and it does not work for whatever reason, we should make sure that this caveat is visible in the documentation. tf32x3 should be quite good at emulating full FP32 precision, but we should still list it to avoid confusing users.

Collaborator Author

Updated the comments on use_tf32 to make clear that when the activation is bf16 and the weight is fp32, it now uses tf32x3.
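
For readers unfamiliar with tf32x3, the sketch below shows the general three-product idea in scalar form: each operand is split into a TF32 "big" part and a TF32 residual, and three TF32 products are summed. This is a simplified model, not Triton's implementation. It truncates rather than rounds when converting to TF32 and accumulates in Python's double precision, and the helper names are made up for illustration.

```python
import struct

def to_fp32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_tf32(x):
    """Truncate a float32 value to TF32 precision (10 explicit mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 of float32's 23 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def split_tf32(x):
    """Split x into a TF32 'big' part and a TF32 'small' residual."""
    big = to_tf32(x)
    small = to_tf32(to_fp32(x - big))
    return big, small

def mul_tf32(a, b):
    """Single TF32 product: both inputs lose their low mantissa bits."""
    return to_tf32(a) * to_tf32(b)

def mul_tf32x3(a, b):
    """Three TF32 products; the small*small term is dropped as negligible."""
    a_big, a_small = split_tf32(a)
    b_big, b_small = split_tf32(b)
    return a_small * b_big + a_big * b_small + a_big * b_big

a, b = to_fp32(1.2345678), to_fp32(7.6543210)
exact = a * b
err_tf32 = abs(mul_tf32(a, b) - exact)
err_tf32x3 = abs(mul_tf32x3(a, b) - exact)
```

For these inputs the three-product form recovers most of the precision that a single TF32 product loses, which is why it is a reasonable stand-in when "ieee" precision is not honored.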

"""
assert n == 4, "Only n=4 is supported in this implementation"
check_deterministic("mhc_fused_scale")
out = mHCScaleFusedOp.apply(H, alpha, beta, ms, n)

Member

_mhc_scale_bwd_fused uses atomic_add, but this function now advertises itself as deterministic. Is the atomic add not a problem for that?

Collaborator Author

Ah, I deleted this by mistake. An earlier version of this PR made this kernel deterministic, but I later decided that should be a separate PR and forgot to revert this line. Fixed now.
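
The determinism concern comes from floating-point addition not being associative: an atomic add accumulates partial sums in whatever order the blocks finish. The toy model below contrasts that with a workspace-style reduction, where every block writes to its own slot and a single pass sums the slots in index order. It is plain Python with made-up values and function names, not the kernel code.

```python
def atomic_style_sum(arrivals):
    """Accumulate partials in whatever order the blocks happen to finish."""
    total = 0.0
    for _block_id, value in arrivals:
        total += value
    return total

def workspace_sum(arrivals, num_blocks):
    """Each block stores into its own slot; one pass reduces in index order."""
    workspace = [0.0] * num_blocks
    for block_id, value in arrivals:
        workspace[block_id] = value
    total = 0.0
    for value in workspace:
        total += value
    return total

# (block_id, partial) pairs; the same partials arriving in two different orders.
partials = [(0, 1e16), (1, 1.0), (2, -1e16), (3, 1.0)]
order_a = partials
order_b = [partials[0], partials[2], partials[1], partials[3]]

atomic_a, atomic_b = atomic_style_sum(order_a), atomic_style_sum(order_b)
fixed_a, fixed_b = workspace_sum(order_a, 4), workspace_sum(order_b, 4)
```

The atomic-style totals differ between the two arrival orders, while the workspace totals are identical, at the cost of extra memory and a second reduction pass.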



def _support_tma():
    return torch.cuda.get_device_capability()[0] >= 9

Member

This should use the device actually used (so device of the input) rather than the first device.

Collaborator Author

Now _support_tma decides based on the input's device instead.
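
One possible shape for a device-aware check is sketched below. The function name, its `get_capability` parameter and the capability table are hypothetical and exist only so the logic can be exercised without a GPU; the default path relies on torch.cuda.get_device_capability, which accepts a device argument. This is not the code merged in the PR.

```python
def supports_tma(device, get_capability=None):
    """Return True when `device` has compute capability 9.x (Hopper) or newer."""
    if get_capability is None:
        import torch  # imported lazily so the sketch loads without PyTorch
        get_capability = torch.cuda.get_device_capability
    major, _minor = get_capability(device)
    return major >= 9

# Stand-in capability table for a mixed-GPU host (made-up values).
fake_caps = {"cuda:0": (8, 0), "cuda:1": (9, 0)}
on_ampere = supports_tma("cuda:0", fake_caps.get)
on_hopper = supports_tma("cuda:1", fake_caps.get)
```

In a PyTorch wrapper the call would typically take the input tensor's device, for example supports_tma(x.device), so that the answer reflects the GPU the kernel will actually run on.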

```
layer_input, H_post, H_res = mhc_generate_mix_and_aggregate(x, phi, alpha, beta)
layer_output = layer(layer_input)  # Attn / FFN layer
x = mhc_fused_expand_combine(layer_input, bias, H_post, x, H_res)
```

Member

Shouldn't the combine also take the layer_output?

Collaborator Author

Ah, that's a typo. It should be layer_output instead of layer_input. Fixed now.

@kainzhong kainzhong force-pushed the feat/mhc_optimization1 branch 3 times, most recently from f969407 to 1240a6f Compare June 30, 2026 20:48
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong kainzhong force-pushed the feat/mhc_optimization1 branch from ee83f1e to 6a9bc2d Compare June 30, 2026 22:16
@kainzhong

Collaborator Author

/te-ci


3 participants