[Pytorch][Bug] Requires Grad doesnt flow through Autograd boundaries for QuantizedTensor#3172
vthumbe1503 wants to merge 3 commits into
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
… QuantizedTensor Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch
Greptile Summary
This PR fixes two related issues with
Confidence Score: 4/5
Safe to merge; the changes are narrowly scoped, well-motivated, and consistent with the patterns already established by the other three quantized tensor types. The core fixes, removing the stale
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant U as User code
participant QT as QuantizedTensor (base)
participant FB as Float8BlockwiseQTensor
participant AF as _FromFloat8BlockwiseFunc
participant AG as PyTorch Autograd
U->>QT: "x (requires_grad=True)"
U->>FB: quantizer(x) → x_q (has grad_fn via quantizer)
Note over QT: No _requires_grad cache<br/>requires_grad reads from C++ directly
U->>FB: x_q.requires_grad → True (native C++)
U->>FB: x_q.dequantize()
alt torch.is_grad_enabled()
FB->>AF: _FromFloat8BlockwiseFunc.apply(self, dtype)
AF->>AG: Record forward pass
AG-->>U: "dequantized tensor (requires_grad=True, grad_fn set)"
else no_grad context
FB->>AF: _FromFloat8BlockwiseFunc.forward(None, self, dtype)
AF-->>U: dequantized tensor (no grad_fn)
end
U->>AG: loss.backward()
AG->>AF: backward(ctx, grad) → (grad, None)
AG->>FB: propagate grad through x_q.grad_fn
AG-->>U: x.grad populated
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
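The grad-enabled versus `no_grad` branch in the diagram can be illustrated with a minimal sketch. The class `_ToyDequantizeFunc` and the helper `toy_dequantize` are hypothetical stand-ins, not the PR's `_FromFloat8BlockwiseFunc`; the "dequantization" here is just a dtype copy, assumed only to keep the example runnable on CPU.

```python
from typing import Optional, Tuple

import torch


class _ToyDequantizeFunc(torch.autograd.Function):
    """Hypothetical stand-in for a dequantize autograd function."""

    @staticmethod
    def forward(
        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
        tensor: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # A real implementation would decode the quantized payload here.
        # copy=True guarantees a fresh output tensor even if dtype matches.
        return tensor.to(dtype, copy=True)

    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # One entry per forward input: gradient for `tensor`, None for `dtype`.
        return grad, None


def toy_dequantize(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    if torch.is_grad_enabled():
        # Goes through autograd, so the result carries a grad_fn.
        return _ToyDequantizeFunc.apply(tensor, dtype)
    # Direct call with ctx=None: no autograd bookkeeping at all.
    return _ToyDequantizeFunc.forward(None, tensor, dtype)
```

Under these assumptions, calling `toy_dequantize` on a tensor with `requires_grad=True` yields an output connected to the graph, while the same call inside `torch.no_grad()` yields a plain tensor without a `grad_fn`.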
    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
The _ctx annotation is Optional[FunctionCtx] in forward (to allow direct calls with None) but FunctionCtx (non-optional) in backward. backward is only ever invoked by PyTorch's autograd engine, which always passes a real FunctionCtx, so the annotation is factually correct there, but the inconsistency with forward can confuse readers and type-checkers. Aligning both to Optional makes the public contract clear.
-    @staticmethod
-    def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
-    ) -> Tuple[Optional[torch.Tensor], ...]:
+    @staticmethod
+    def backward(
+        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
+        grad: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], ...]:
    def test_quantize_dequantize_autograd(
        self,
        *,
        quantization: str,
        shape: Iterable[int] = (128, 128),
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
    ) -> None:
        """Autograd must survive a quantize -> dequantize round trip."""

        quantizer = make_quantizer(quantization, device=device)
        x = torch.randn(list(shape), dtype=dtype, device=device, requires_grad=True)
        # Quantize with autograd enabled: a grad_fn is attached to the output.
        x_q = quantizer(x)
        assert isinstance(x_q, QuantizedTensor)
        assert x_q.grad_fn is not None, "quantized tensor is missing its grad_fn"
        # requires_grad must reflect the attached grad_fn, not a stale cache.
        assert x_q.requires_grad, (
            "quantized tensor reports requires_grad=False despite having a "
            "grad_fn (stale requires_grad cache)"
        )

        # Dequantize and take a loss; the gradient must reach the input.
        (x_q.dequantize().float() ** 2).sum().backward()
        assert (
            x.grad is not None and x.grad.norm().item() > 0
        ), "Gradient did not flow back to the input through quantize -> dequantize"


@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
class TestMXFP8Tensor:
Test silently skips the key fp8_blockwise fix on machines without block-scaling hardware
_quantization_list is populated conditionally on fp8_block_scaling_available. On a runner where that flag is False, test_quantize_dequantize_autograd is never parametrized with the fp8_blockwise case, the very quantization path this PR actually changes. If CI machines don't universally support block-scaling, the regression guard for Float8BlockwiseQTensor.dequantize may not fire even if that code is broken later. Consider adding an explicit skip/xfail marker or a dedicated non-parametric test so the gap is visible rather than silent.
Description
Bugs
Alternative way of reducing CPU overheads for attribute access:
The main root cause of slow requires_grad attribute access on a QuantizedTensor was our custom torch function implementation. The implementation does nothing beyond delegating to _disabled_torch_function_impl, but the mere presence of a Python implementation meant we crossed the C++ to Python boundary twice.
When requires_grad is accessed on a custom torch tensor, the access goes through this function in PyTorch C++, which checks for the presence of a torch function and goes back to Python to execute it.
This can be avoided by assigning the sentinel value directly instead of defining a no-op torch function:
__torch_function__ = torch._C._disabled_torch_function_impl
Torchao does the same for their custom MXFP8Tensor here
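The two ways of disabling torch function handling can be contrasted in a short sketch. `WrapperTensor` and `SentinelTensor` are hypothetical subclasses written for illustration, not classes from this PR; only `torch._C._disabled_torch_function_impl` comes from the description above.

```python
import torch


class WrapperTensor(torch.Tensor):
    """No-op __torch_function__ written in Python (the slower pattern)."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Delegates straight to the built-in, but only after entering Python.
        return torch._C._disabled_torch_function_impl(func, types, args, kwargs)


class SentinelTensor(torch.Tensor):
    """Sentinel assigned directly, so no Python-level override is defined."""

    __torch_function__ = torch._C._disabled_torch_function_impl


wrapped = torch.zeros(2).as_subclass(WrapperTensor)
sentinel = torch.zeros(2).as_subclass(SentinelTensor)

# Both report the same value; only the dispatch path differs.
print(wrapped.requires_grad, sentinel.requires_grad)
```

Timing the attribute access on each subclass (for example with `timeit`) is one way to check whether the sentinel form reduces overhead in a given environment; no numbers are claimed here.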
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: