[PyTorch] Add optional caller-provided output/grad-input buffers to GroupedLinear and fused grouped MLP#3161
[PyTorch] Add optional caller-provided output/grad-input buffers to GroupedLinear and fused grouped MLP#3161phu0ngng wants to merge 6 commits into
Conversation
| output: Optional[torch.Tensor] = None, | ||
| grad_input: Optional[torch.Tensor] = None, |
There was a problem hiding this comment.
I'm skeptical of this API because it looks general, but is actually only supported with grouped linear and grouped MLP. How about something like the following:
def forward(
self,
input: torch.Tensor,
*extra_inputs: torch.Tensor,
extra_kwargs: Optional[Sequence[Optional[dict[str, Any]]]] = None,
):
...
if extra_kwargs is None:
extra_kwargs = [None] * len(self)
# Figure out what extra_kwargs go to each module group, and fail if a non-None extra_kwargs goes to a non-basic op
for module_group in self._module_groups:
xs = module_group(*xs, basic_op_kwargs=module_group_basic_op_kwargs)There was a problem hiding this comment.
Thanks, Tim. Good call. The output and grad_input kwargs overpromised generality. Building on your suggestion, I went with a generic per op kwargs pass through, but keyed by module or index instead of a positional list:
seq(x, split_sizes, probs, split_sizes,
op_kwargs={fc1: {GRAD_INPUT_BUFFER_KEY: g}, fc2: {OUTPUT_BUFFER_KEY: o}})
Sequential resolves each key to its group and basic op slot and forwards it as basic_op_kwargs, with no new dispatch path, and raises if a key targets a non-fusible op. This keeps the container kwarg-agnostic, so the buffers are now just op specific kwargs, while avoiding None padding and positional coupling.
Let me know what you think about this new design.
| # when the value above was materialized into it; otherwise it is copied. | ||
| if output_buffer is not None: | ||
| output_buffer = validate_or_alloc_output(output_buffer, fc2_out_shape, dtype, device) | ||
| if output_buffer.data_ptr() != fc2_out.data_ptr(): |
There was a problem hiding this comment.
Under what case this condition can be false? It seems to be it is always true, I do not see we feed the user provided buffer to the kernel?
There was a problem hiding this comment.
You were right that it was always true for MXFP8, which uses CuTe GEMM.
I pushed the latest commit in which I reworked it so that it does not trigger extra copies: the caller buffer is passed straight into the CuTe GEMM via d_tensor, so the kernel writes output/dgrad in place for both nvfp4 and MXFP8. This needs a cuDNN FE change I made a PR for: NVIDIA/cudnn-frontend#338 (until it lands and the bundled FE is bumped, the in-place MXFP8 path can't run). So for testing, we will need to check out this cuDNN FE manually.
…ar module and fusible ops Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
…ping Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…uffers, eliminating the D2D copy + cleanup Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Description
Let callers have an option to supply preallocated buffers for the forward output and backward input gradient of
GroupedLinearand the fused grouped MLP, instead of the module/op allocating them internally. The caller can pass a symmetric-memory buffer that is fed directly into TE EP'scombine, eliminating the D2D copy from the token buffer into TE EP's internal staging buffer.Type of change
Changes
GroupedLinear.forward: optionalout/dgrad_outbuffers, wired through the legacy and GroupedTensor paths (incl. FP8). Validated via a sharedvalidate_or_alloc_outputhelper (shape/dtype/device/contiguous/no-grad).Sequential.forwardtakes anop_kwargsmapping (keyed by module or index) that routes per-op kwargs to the target op, used to pass output/grad-input buffers to grouped linear / grouped MLP.m_splits); any padded tail is left untouched. Buffers may be supplied independently.Dependency
The in-place MXFP8 path requires a cuDNN frontend change that exposes an optional output tensor on the grouped-GEMM quant wrapper: NVIDIA/cudnn-frontend#338.
Testing
test_grouped_linear_caller_output_buffers: covers out/dgrad_out/both across the legacy and GroupedTensor paths _ buffer aliasing, bit-exact match vs. internal allocation, untouched padded tail, andValueErroron shape mismatch.test_grouped_linear_caller_buffers(ops) andtest_grouped_mlp_caller_buffers: verify the fused op runs, the output aliases the buffer, and grads match internal allocation.Checklist: