
[Pytorch][Bug] Requires Grad doesnt flow through Autograd boundaries for QuantizedTensor#3172

Open
vthumbe1503 wants to merge 3 commits into NVIDIA:main from vthumbe1503:requires_grad_bug

Conversation

@vthumbe1503

@vthumbe1503 vthumbe1503 commented Jul 2, 2026

Collaborator

Description

Bugs

  • A requires_grad caching scheme was introduced in an earlier PR to reduce CPU overhead in TE. It turns out that when a QuantizedTensor is passed across an autograd boundary, its requires_grad value does not propagate. The reason is that requires_grad is set through grad_fn directly in C++, which bypasses our Python cache.
  • While debugging this, I also found that our Float8Blockwise dequantize is not autograd-aware, unlike the other QuantizedTensor types. This PR fixes that as well.
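The first bug can be pictured with a torch-free toy model. This is only a sketch of the mechanism, not TE or PyTorch code: `NativeTensor`, `CachedTensor`, and `attach_grad_fn` are hypothetical names, and the `_native_requires_grad` field merely plays the role of the flag that PyTorch keeps on the C++ side.

```python
class NativeTensor:
    """Stand-in for the C++ tensor object that owns the real flag."""

    def __init__(self):
        # Plays the role of the C++-side requires_grad state (assumption).
        self._native_requires_grad = False


class CachedTensor(NativeTensor):
    """Stand-in for a subclass that mirrors the flag in a Python cache."""

    def __init__(self):
        super().__init__()
        self._requires_grad = False  # Python-side cache

    @property
    def requires_grad(self):
        # Fast path: served from the cache, only as fresh as the last setter call.
        return self._requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        # The cache is updated only when Python code uses this setter.
        self._requires_grad = value
        self._native_requires_grad = value


def attach_grad_fn(tensor):
    """Mimics the autograd engine: writes the native flag, never the setter."""
    tensor._native_requires_grad = True


x_q = CachedTensor()
attach_grad_fn(x_q)

stale = x_q.requires_grad            # what the Python cache reports
actual = x_q._native_requires_grad   # what the "engine" actually recorded
print(stale, actual)
```

In this model the two values disagree after `attach_grad_fn`, which is the shape of the failure described above: any writer that skips the Python setter leaves the cache stale. Removing the cache makes the native flag the only source of truth.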

Alternative way of reducing CPU overheads for attribute access:

  • The root cause of slow requires_grad attribute access on a QuantizedTensor was our custom __torch_function__ implementation. That implementation does nothing beyond delegating to _disabled_torch_function_impl, but the mere presence of a Python implementation meant crossing the C++/Python boundary twice on each access.

  • When requires_grad is accessed on a custom tensor subclass, the access goes through a function in PyTorch's C++ code that checks for the presence of __torch_function__ and, if one is defined, calls back into Python to execute it.

  • This can be avoided by assigning the sentinel value directly instead of defining a no-op __torch_function__:
    __torch_function__ = torch._C._disabled_torch_function_impl
    torchao does the same for its custom MXFP8Tensor.
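As a minimal sketch of the sentinel assignment described above, assuming a plain `torch.Tensor` subclass (the class name `WrapperTensor` is hypothetical and is not the TE `QuantizedTensor`):

```python
import torch


class WrapperTensor(torch.Tensor):
    """Hypothetical subclass that opts out of __torch_function__ dispatch."""

    # Assign the sentinel directly; no Python classmethod is defined.
    __torch_function__ = torch._C._disabled_torch_function_impl


t = torch.ones(2, 2).as_subclass(WrapperTensor)

# The class attribute is the sentinel object itself, not a Python wrapper.
uses_sentinel = (
    WrapperTensor.__torch_function__ is torch._C._disabled_torch_function_impl
)

flag_before = t.requires_grad   # read straight from the native attribute
t.requires_grad_(True)
flag_after = t.requires_grad
print(uses_sentinel, flag_before, flag_after)
```

`torch.nn.Parameter` uses the same class-level assignment, so the pattern is not specific to quantized tensors. The sketch only shows that attribute reads stay correct; it makes no claim about the size of the overhead reduction.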

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
… QuantizedTensor

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 requested a review from ksivaman as a code owner July 2, 2026 20:57
@vthumbe1503

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Jul 2, 2026

Contributor

Greptile Summary

This PR fixes two related issues with QuantizedTensor: a stale requires_grad cache that broke gradient flow across autograd boundaries, and a redundant Python-layer __torch_function__ wrapper that incurred unnecessary C++↔Python overhead. It also fills a gap for Float8BlockwiseQTensor, which was the only quantized tensor type whose dequantize lacked an autograd-function wrapper.

  • Removes the _requires_grad cache and its property/setter overrides from QuantizedTensor, letting PyTorch's native C++ attribute serve as the single source of truth.
  • Assigns __torch_function__ = torch._C._disabled_torch_function_impl as a class-level sentinel (matching what torchao's MXFP8Tensor does) to skip the Python dispatch boundary on attribute access.
  • Introduces _FromFloat8BlockwiseFunc(torch.autograd.Function) in the storage module and wires Float8BlockwiseQTensor.dequantize to use .apply() when grad is enabled, completing the pattern already used by Float8Tensor, MXFP8Tensor, and NVFP4Tensor.
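The autograd-aware dequantize pattern summarized above can be sketched on ordinary tensors. This is an illustration under stated assumptions, not the PR's code: `_ToyDequantizeFunc` and `toy_dequantize` are hypothetical names, and the "quantized" input is just a float tensor with a scalar scale rather than a `Float8BlockwiseQTensor`.

```python
import torch


class _ToyDequantizeFunc(torch.autograd.Function):
    """Hypothetical stand-in for an autograd-aware dequantize."""

    @staticmethod
    def forward(_ctx, data, scale, dtype):
        # _ctx may be None when forward is called directly (no-grad path).
        return (data * scale).to(dtype)

    @staticmethod
    def backward(_ctx, grad):
        # Pass the gradient through unchanged to the tensor input;
        # the non-tensor arguments (scale, dtype) get None.
        return grad, None, None


def toy_dequantize(data, scale, dtype=torch.float32):
    """Route through autograd only when gradients are being recorded."""
    if torch.is_grad_enabled():
        return _ToyDequantizeFunc.apply(data, scale, dtype)
    return _ToyDequantizeFunc.forward(None, data, scale, dtype)


x = torch.randn(4, 4, requires_grad=True)

out = toy_dequantize(x, 0.5)
out.sum().backward()

with torch.no_grad():
    out_no_grad = toy_dequantize(x, 0.5)

print(out.grad_fn is not None, out_no_grad.grad_fn is None)
```

Because `backward` returns the incoming gradient unchanged, `x.grad` here is all ones regardless of the scale. That pass-through choice mirrors the `(grad, None)` return shown in the sequence diagram; whether it is the right gradient for a given quantization scheme is a design decision this sketch does not settle.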

Confidence Score: 4/5

Safe to merge; the changes are narrowly scoped, well-motivated, and consistent with the patterns already established by the other three quantized tensor types.

The core fixes — removing the stale requires_grad cache and wiring Float8BlockwiseQTensor.dequantize through an autograd function — are both correct and match the existing patterns in Float8Tensor, MXFP8Tensor, and NVFP4Tensor. The __torch_function__ sentinel assignment is the canonical approach documented in PyTorch and used by torchao. The only non-trivial concern is that the new regression test is parametrized against _quantization_list, which is populated conditionally on hardware availability; on runners without block-scaling support, the fp8_blockwise case — the central path changed by this PR — would be silently absent from the test run.

tests/pytorch/test_quantized_tensor.py — the new autograd test may not exercise the fp8_blockwise path on all CI runners due to hardware-gated parametrization.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/pytorch/quantized_tensor.py | Removes the stale _requires_grad cache and the noop classmethod __torch_function__; replaces the latter with a direct sentinel assignment. Both changes are correct and targeted. |
| transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py | Adds _FromFloat8BlockwiseFunc autograd function (logic moved from old dequantize); minor _ctx annotation inconsistency between forward (Optional) and backward (non-Optional), no runtime impact. |
| transformer_engine/pytorch/tensor/float8_blockwise_tensor.py | Wires dequantize to call _FromFloat8BlockwiseFunc.apply when grad is enabled, matching the pattern in Float8Tensor/MXFP8Tensor/NVFP4Tensor exactly. |
| tests/pytorch/test_quantized_tensor.py | Extracts make_quantizer helper (good refactor) and adds a new test_quantize_dequantize_autograd covering the fixed bug; test will only run for quantization types whose hardware is available on the runner. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant U as User code
    participant QT as QuantizedTensor (base)
    participant FB as Float8BlockwiseQTensor
    participant AF as _FromFloat8BlockwiseFunc
    participant AG as PyTorch Autograd

    U->>QT: "x (requires_grad=True)"
    U->>FB: quantizer(x) → x_q (has grad_fn via quantizer)
    Note over QT: No _requires_grad cache<br/>requires_grad reads from C++ directly
    U->>FB: x_q.requires_grad → True (native C++)
    U->>FB: x_q.dequantize()
    alt torch.is_grad_enabled()
        FB->>AF: _FromFloat8BlockwiseFunc.apply(self, dtype)
        AF->>AG: Record forward pass
        AG-->>U: "dequantized tensor (requires_grad=True, grad_fn set)"
    else no_grad context
        FB->>AF: _FromFloat8BlockwiseFunc.forward(None, self, dtype)
        AF-->>U: dequantized tensor (no grad_fn)
    end
    U->>AG: loss.backward()
    AG->>AF: backward(ctx, grad) → (grad, None)
    AG->>FB: propagate grad through x_q.grad_fn
    AG-->>U: x.grad populated

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +102 to +106
    @staticmethod
    def backward(
        _ctx: torch.autograd.function.FunctionCtx,  # unused
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:

P2 The _ctx annotation is Optional[FunctionCtx] in forward (to allow direct calls with None) but is FunctionCtx (non-optional) in backward. Backward is only ever invoked by PyTorch's autograd engine, which always passes a real FunctionCtx, so the annotation is factually correct there — but the inconsistency with forward can confuse readers and type-checkers. Aligning both to Optional makes the public contract clear.

Suggested change
-    @staticmethod
-    def backward(
-        _ctx: torch.autograd.function.FunctionCtx,  # unused
-        grad: torch.Tensor,
-    ) -> Tuple[Optional[torch.Tensor], ...]:
+    @staticmethod
+    def backward(
+        _ctx: Optional[torch.autograd.function.FunctionCtx],  # unused
+        grad: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], ...]:


Comment on lines +821 to 851
    def test_quantize_dequantize_autograd(
        self,
        *,
        quantization: str,
        shape: Iterable[int] = (128, 128),
        dtype: torch.dtype = torch.bfloat16,
        device: torch.device = "cuda",
    ) -> None:
        """Autograd must survive a quantize -> dequantize round trip."""

        quantizer = make_quantizer(quantization, device=device)
        x = torch.randn(list(shape), dtype=dtype, device=device, requires_grad=True)
        # Quantize with autograd enabled: a grad_fn is attached to the output.
        x_q = quantizer(x)
        assert isinstance(x_q, QuantizedTensor)
        assert x_q.grad_fn is not None, "quantized tensor is missing its grad_fn"
        # requires_grad must reflect the attached grad_fn, not a stale cache.
        assert x_q.requires_grad, (
            "quantized tensor reports requires_grad=False despite having a "
            "grad_fn (stale requires_grad cache)"
        )

        # Dequantize and take a loss; the gradient must reach the input.
        (x_q.dequantize().float() ** 2).sum().backward()
        assert (
            x.grad is not None and x.grad.norm().item() > 0
        ), "Gradient did not flow back to the input through quantize -> dequantize"


@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
class TestMXFP8Tensor:

P2 Test silently skips the key fp8_blockwise fix on machines without block-scaling hardware

_quantization_list is populated conditionally on fp8_block_scaling_available. On a runner where that flag is False, test_quantize_dequantize_autograd is collected without any fp8_blockwise parameter, so the very quantization path this PR changes is never exercised and nothing in the test report shows that it was left out. If CI machines don't universally support block-scaling, the regression guard for Float8BlockwiseQTensor.dequantize may not fire even if that code is broken later. Consider adding an explicit skip/xfail marker or a dedicated non-parametric test so the gap is visible rather than silent.
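One way to make the gap visible, sketched under assumptions: the availability flags below are hard-coded placeholders (a real suite would probe the hardware), and `gated_list`, `explicit_params`, and `test_roundtrip_is_collected` are hypothetical names rather than anything in the TE test file. The idea is to always list every case and attach a `skipif` mark per case, so an unsupported case is reported as skipped instead of disappearing.

```python
import pytest

# Placeholder availability flags (assumption); real code would query the device.
fp8_available = True
fp8_block_scaling_available = False

# Hardware-gated list: an unsupported case simply vanishes from collection.
gated_list = []
if fp8_available:
    gated_list.append("fp8")
if fp8_block_scaling_available:
    gated_list.append("fp8_blockwise")

# Explicit list: every case is always collected; unsupported ones show as SKIPPED.
explicit_params = [
    pytest.param(
        "fp8",
        marks=pytest.mark.skipif(not fp8_available, reason="FP8 unavailable"),
    ),
    pytest.param(
        "fp8_blockwise",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available,
            reason="FP8 block scaling unavailable",
        ),
    ),
]


@pytest.mark.parametrize("quantization", explicit_params)
def test_roundtrip_is_collected(quantization):
    # Placeholder body; the real test would quantize, dequantize, and backprop.
    assert isinstance(quantization, str)


print(gated_list, [p.values[0] for p in explicit_params])
```

With the placeholder flags above, the gated list contains only `fp8`, while the explicit list still names `fp8_blockwise` and a pytest run reports it as skipped with its reason.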

@pggPL pggPL self-requested a review July 2, 2026 21:26