[Common, PyTorch] Improve mHC to match DeepSeek's implementation#2978
[Common, PyTorch] Improve mHC to match DeepSeek's implementation#2978kainzhong wants to merge 7 commits into
Conversation
af685d7 to
eccba0c
Compare
|
/te-ci |
timmoon10
left a comment
There was a problem hiding this comment.
LGTM, I don't see anything particularly suspicious and the API changes are backward-compatible.
| fused_grad_x_acc_buffer : Optional[torch.Tensor] | ||
| A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. | ||
| If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. | ||
| This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. |
There was a problem hiding this comment.
Flagging that the ordering requirement is fragile and unintuitive. That said, fused_grad_x_acc_buffer is an optional advanced optimization and the requirement is well documented.
There was a problem hiding this comment.
I can ask user to pass a zero buffer instead of an uninitialized one so in this case the order will not matter (the reason why it has to be this specific order is because in backward mhc_fused_expand_combine will overwrite the value into the buffer instead of a read+write accumulate, so it has to be the last operation in forward and the first operation in backward), but an uninitialized buffer should avoid the cost to reset the memory to zero.
Which approach is better?
There was a problem hiding this comment.
I think we also need to have this specific order so that mhc_fused_projection can cast the buffer to BF16 and attach it to x.grad. I don't have a strong opinion since all options I see involve doing something non-standard with PyTorch autograd. Probably best to keep it as it is right now, and generalize in the future if there's a need.
|
|
||
| @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) | ||
| @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) | ||
| @pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"]) |
There was a problem hiding this comment.
Why are we removing this test case? It doesn't seem that we have touched recompute in mhc_fused_sinkhorn.
There was a problem hiding this comment.
Ah it's a mistake. Previously this PR includes a gluon kernel which always recomputes but later I decided to make that a separate PR. I must have forget to revert this line of change. Fixed now.
| # If upcasting from bf16 to fp32 takes place inside the triton kernel, triton will ignore "ieee" precision and use tf32 anyway | ||
| # See https://github.com/triton-lang/triton/issues/10176 for detail. | ||
| # Therefore, we need to use tf32x3 instead which at least has better accuracy than tf32 just to make the tests pass. In production | ||
| # precision should be tf32 so it's not affected. |
There was a problem hiding this comment.
If we advertise a feature ("ieee" precision) and it does not work for whatever reason, we should make sure that this caveat is visible in the documentation. tf32x3 should be quite good in emulating full FP32 precision, but still we should have that listed to avoid any confusion from the users.
There was a problem hiding this comment.
Changed comments so use_tf32 will make it clear that if activation is bf16 and weight is fp32 then it will use tf32x3 now.
| """ | ||
| assert n == 4, "Only n=4 is supported in this implementation" | ||
| check_deterministic("mhc_fused_scale") | ||
| out = mHCScaleFusedOp.apply(H, alpha, beta, ms, n) |
There was a problem hiding this comment.
_mhc_scale_bwd_fused uses atomic_add, but this function now advertises itself as deterministic. Is the atomic add not a problem for that?
There was a problem hiding this comment.
Ah I deleted this by mistake. Previously this PR made this kernel deterministic but later I thought it should be a separate PR but I forgot to revert this line of change. Fixed now.
|
|
||
|
|
||
| def _support_tma(): | ||
| return torch.cuda.get_device_capability()[0] >= 9 |
There was a problem hiding this comment.
This should use the device actually used (so device of the input) rather than the first device.
There was a problem hiding this comment.
Now _support_tma will decide based on input's device instead.
| ``` | ||
| layer_input, H_post, H_res = mhc_generate_mix_and_aggregate(x, phi, alpha, beta) | ||
| layer_output = layer(layer_input) # Attn / FFN layer | ||
| x = mhc_fused_expand_combine(layer_input, bias, H_post, x, H_res) |
There was a problem hiding this comment.
Shouldn't the combine also take the layer_output?
There was a problem hiding this comment.
Ah typo here. It should be layer_output instead of layer_input. Fixed now.
f969407 to
1240a6f
Compare
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
ee83f1e to
6a9bc2d
Compare
|
/te-ci |
Description
Some enhancement for mHC to better align with DeepSeek's tilelang implementation: https://github.com/deepseek-ai/TileKernels/tree/main/tile_kernels/mhc
Fixes # (issue)
Type of change
Changes
mhc_generate_mix_and_aggregateAPI that does projection, scale, sinkhorn and aggregate togethermhc_fused_projectionto accept arguments with mixed dtype: x.dtype=bf16, phi.dtype=fp32, which matches DeepSeek's implementationmhc_fused_projectionnow outputs fp32 regardless of the input dtype, matching DeepSeek's implementationfuse_grad_x_accoptimization (default to False), which will reuse the same grad_x buffer to accumulate the initial mHC input x's gradient formhc_fused_expand_combine,mhc_fused_aggregateandmhc_fused_projectionnorm_weightformhc_fused_projection, which would be equivalent to apply RMSNorm in the unfused manner withelementwise_affine=True, which would be the learnable per-element affine parameters for RMSNormChecklist: