Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 51 additions & 70 deletions tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor, _STORAGE_REGISTRY
from transformer_engine.pytorch.dynamo import TensorProto, to_tensor_proto
from transformer_engine.pytorch.dynamo.traceable_utils import make_empty_traceable, _contiguous_stride
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
Expand Down Expand Up @@ -499,6 +499,7 @@ def _nvfp4(with_rht=True):
rowwise=True,
columnwise=True,
with_rht=with_rht,
with_post_rht_amax=with_rht,
)


Expand Down Expand Up @@ -640,7 +641,7 @@ def fn(inp):


# ---------------------------------------------------------------------------
# torch.compile-traceable allocation primitives + TensorProto
# torch.compile-traceable allocation primitives
# ---------------------------------------------------------------------------


Expand All @@ -665,9 +666,7 @@ def fn(inp):
def _build_from_primitives(quantizer, shape, dtype, device="cpu"):
"""Assemble a quantized tensor straight from the quantizer primitives:
``alloc_tensors`` (buffers) + ``create_metadata`` (ctx) + the storage's
``__tensor_unflatten__`` -- i.e. exactly what ``TensorProto.create_tensor``
does, but without going through :class:`TensorProto`.
"""
``__tensor_unflatten__``."""
names = tuple(quantizer._describe_buffers(shape)) # pylint: disable=protected-access
ctx = quantizer.create_metadata(shape, dtype=dtype)
buffers = quantizer.alloc_tensors(shape, device=device)
Expand Down Expand Up @@ -713,7 +712,7 @@ def _skip_if_dequantize_unsupported(q):
@pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS)
def test_primitives_unflatten_compiles(factory, shape):
"""create_metadata + alloc_tensors + __tensor_unflatten__ compose and trace
under ``fullgraph=True`` (CPU), without TensorProto."""
under ``fullgraph=True`` (CPU)."""
q = factory()
names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access

Expand Down Expand Up @@ -779,78 +778,64 @@ def test_storage_flatten_unflatten_roundtrip(factory, shape):
torch.testing.assert_close(rebuilt.dequantize(), expected, atol=0, rtol=0, equal_nan=True)


# ----- TensorProto -----
# ----- make_empty_traceable -----


@pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS)
def test_tensor_proto_matches_primitives(factory, shape):
"""TensorProto is a thin wrapper: its ``create_metadata`` /
``create_inner_tensors`` / ``create_tensor`` match building everything
directly from the quantizer primitives."""
def test_make_empty_traceable_matches_primitives(factory, shape):
"""make_empty_traceable is equivalent to building from quantizer primitives
directly (alloc_tensors + create_metadata + __tensor_unflatten__)."""
q = factory()
proto = TensorProto(shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu"))
assert proto.is_quantized

# Metadata matches the quantizer's.
assert proto.create_metadata() == q.create_metadata(shape, dtype=torch.bfloat16)
# Build via make_empty_traceable.
tensor = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu")

# inner_names follows the storage's canonical __tensor_flatten__ order (the
# order the real op flattens its outputs to), while create_inner_tensors
# matches the _describe_buffers geometry (a name->shape/dtype mapping).
bufs = q._describe_buffers(shape) # pylint: disable=protected-access
# Build directly from primitives.
direct = _build_from_primitives(q, shape, torch.bfloat16)
names = tuple(direct.__tensor_flatten__()[0])
assert set(names) == set(bufs)
assert proto.inner_names() == names
inner = proto.create_inner_tensors()
assert len(inner) == len(names)
for name, buf in zip(names, inner):
exp_shape, exp_dtype = bufs[name]
assert tuple(buf.shape) == tuple(exp_shape)
assert buf.dtype == exp_dtype

# The assembled tensor matches one built directly from the primitives.
assert _signature(proto.create_tensor(), names) == _signature(direct, names)
# Both should produce the same buffer layout.
names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access
assert _signature(tensor, names) == _signature(direct, names)


@pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS)
def test_tensor_proto_create_tensor_eager(factory, shape):
"""``create_tensor`` (no fake) yields a real quantized tensor."""
def test_make_empty_traceable_eager(factory, shape):
"""make_empty_traceable (no fake) yields a real quantized tensor."""
q = factory()
proto = TensorProto(shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu"))
out = proto.create_tensor()
out = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu")
assert isinstance(out, QuantizedTensor)
assert tuple(out.shape) == tuple(shape)
assert out.dtype == torch.bfloat16
for name in proto.inner_names():
names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access
for name in names:
assert not isinstance(getattr(out, name), FakeTensor)


@pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS)
def test_tensor_proto_create_tensor_fake(factory, shape):
"""``create_tensor`` under ``FakeTensorMode`` yields a fake-backed quantized
def test_make_empty_traceable_fake(factory, shape):
"""make_empty_traceable under FakeTensorMode yields a fake-backed quantized
tensor with the right shape/dtype and fake inner buffers."""
q = factory()
proto = TensorProto(shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu"))
with FakeTensorMode():
out = proto.create_tensor()
out = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu")
assert isinstance(out, QuantizedTensor)
assert tuple(out.shape) == tuple(shape)
assert out.dtype == torch.bfloat16
for name in proto.inner_names():
names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access
for name in names:
assert isinstance(getattr(out, name), FakeTensor)


@pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS)
def test_tensor_proto_create_tensor_compiles(factory, shape):
"""``TensorProto.create_tensor`` traces under ``fullgraph=True`` (CPU)."""
def test_make_empty_traceable_compiles(factory, shape):
"""make_empty_traceable traces under torch.compile(fullgraph=True) (CPU)."""
q = factory()
names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access

def fn(x):
proto = TensorProto(shape=tuple(x.shape), dtype=x.dtype, quantizer=q, device=x.device)
t = proto.create_tensor()
t = make_empty_traceable(q, tuple(x.shape), dtype=x.dtype, device=x.device)
acc = x.new_zeros(())
for name in proto.inner_names():
for name in names:
acc = acc + getattr(t, name).float().sum()
return acc

Expand All @@ -860,36 +845,32 @@ def fn(x):
assert out.shape == ()


def test_to_tensor_proto_plain():
"""``to_tensor_proto`` describes a plain tensor."""
t = torch.empty(2, 3, dtype=torch.float32)
proto = to_tensor_proto(t)
assert not proto.is_quantized
assert proto.shape == (2, 3)
assert proto.dtype == torch.float32
assert proto.inner_names() == ("data",)
def test_make_empty_traceable_plain_tensor():
"""For non-quantized tensors, make_empty_traceable produces a plain tensor."""
t = make_empty_traceable(None, (2, 3), dtype=torch.float32, device="cpu")
assert not isinstance(t, QuantizedTensor)
assert t.shape == (2, 3)
assert t.dtype == torch.float32


@pytest.mark.parametrize("factory, shape", _PROTO_QUANTIZERS)
def test_to_tensor_proto_quantized(factory, shape):
"""``to_tensor_proto`` round-trips a quantized tensor back into a proto."""
def test_make_empty_traceable_roundtrip(factory, shape):
"""A tensor built via make_empty_traceable can be flattened and unflattened."""
q = factory()
tensor = TensorProto(
shape=shape, dtype=torch.bfloat16, quantizer=q, device=torch.device("cpu")
).create_tensor()

proto = to_tensor_proto(tensor)
assert proto.is_quantized
assert proto.shape == tuple(shape)
assert proto.dtype == torch.bfloat16
# Same buffer layout as the original tensor.
assert proto.inner_names() == tuple(
q._describe_buffers(shape)
) # pylint: disable=protected-access
# Rebuilding from the derived proto matches the original tensor's structure.
assert _signature(proto.create_tensor(), proto.inner_names()) == _signature(
tensor, proto.inner_names()
tensor = make_empty_traceable(q, shape, dtype=torch.bfloat16, device="cpu")
names = tuple(q._describe_buffers(shape)) # pylint: disable=protected-access

# Flatten.
flat_names, flat_ctx = tensor.__tensor_flatten__()
assert set(flat_names) == set(names)
inner = {name: getattr(tensor, name) for name in flat_names}

# Unflatten.
rebuilt = type(tensor).__tensor_unflatten__(
inner, flat_ctx, tuple(tensor.shape), tensor.stride()
)
assert isinstance(rebuilt, QuantizedTensor)
assert _signature(rebuilt, flat_names) == _signature(tensor, names)


# ---------------------------------------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions transformer_engine/pytorch/dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
"""torch.compile glue for Transformer Engine."""

from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer
from .tensor_proto import TensorProto, to_tensor_proto
from .traceable_utils import make_empty_traceable
from .custom_op import register_custom_op

__all__ = [
"register_value_opaque_quantizer",
"is_value_opaque_quantizer",
"TensorProto",
"to_tensor_proto",
"make_empty_traceable",
"register_custom_op",
]
Loading