Unroll _transform_tuple to fix Enzyme.autodiff on tuples of length ≥ 33 #170

Merged: tpapp merged 2 commits into tpapp:master from jlperla:enzyme-3104-unroll-transform-tuple on Jun 23, 2026

jlperla commented May 15, 2026

Contributor

Replace the Base.tail-recursive _transform_tuple with a @generated straight-line unroll — same outputs bit-for-bit, but the typed IR no longer contains a self-invoke, which is what Enzyme.autodiff (Forward and Reverse) trips on at tuple length ≥ 33 with AssertionError("conv == 37") (EnzymeAD/Enzyme.jl#3104).

The recursive Base.tail fold in _transform_tuple makes Enzyme.autodiff
(Forward and Reverse) throw `AssertionError("conv == 37")` from
Enzyme/src/rules/jitrules.jl:2073 once the tuple has ≥ 33 entries
(EnzymeAD/Enzyme.jl#3104). Replace it with a @generated straight-line
unroll that produces the same outputs bit-for-bit while emitting no
self-invoke in the typed IR — which is what Enzyme trips on.

Verified against the full Pkg.test() suite (all Pass = Total) and a
35-entry SW07-Pfeifer-style NamedTuple prior (fwd + rev both succeed).
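
The difference between the two shapes of code can be sketched as follows. This is a minimal illustration, not the TransformVariables.jl source: `apply_one`, `recursive_fold`, and `unrolled_fold` are hypothetical names, and each toy "transformation" simply consumes one element of `x` and contributes `abs(x[index])` to an accumulated sum that stands in for the log-Jacobian term.

```julia
# Hypothetical stand-in for transforming one element: returns the value,
# a term for the accumulated sum, and the next index.
apply_one(f, x, index) = (f(x[index]), abs(x[index]), index + 1)

# Recursive shape: handle the first element, then recurse on Base.tail.
recursive_fold(::Tuple{}, x, index) = ((), 0.0, index)
function recursive_fold(fs::Tuple, x, index)
    y, l, i1 = apply_one(first(fs), x, index)
    ys, l_rest, i2 = recursive_fold(Base.tail(fs), x, i1)
    return (y, ys...), l + l_rest, i2
end

# Unrolled shape: a @generated function emits one statement per element,
# so the resulting body contains no call to itself.
@generated function unrolled_fold(fs::NTuple{N,Any}, x, index) where {N}
    N == 0 && return :(((), 0.0, index))
    ys = [Symbol(:y, i) for i in 1:N]
    ls = [Symbol(:l, i) for i in 1:N]
    steps = [:(($(ys[i]), $(ls[i]), index) = apply_one(fs[$i], x, index))
             for i in 1:N]
    l_sum = foldl((a, b) -> :($a + $b), ls)
    return quote
        $(steps...)
        return ($(ys...),), $l_sum, index
    end
end

fs = ntuple(_ -> exp, 40)
x = collect(range(-1.0, 1.0; length = 40))
r = recursive_fold(fs, x, 1)
u = unrolled_fold(fs, x, 1)
```

In this sketch the transformed values and the final index agree exactly, while the two sums are accumulated in a different order and are therefore only guaranteed to agree up to floating-point rounding.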

scheidan commented Jun 15, 2026

Contributor

Any chance that this will be merged here? I understand that the real fix should be on Enzyme's side, but that may be much harder.

Thanks!

PS: This is my real-world MWE that finally led me to this PR; maybe it is useful to someone.

```julia
using Distributions
using Enzyme
using TransformVariables

N = 33
dists = ntuple(i -> LogNormal(0.0, 1.0), N)
dists = NamedTuple{ntuple(i -> Symbol("x", i), N)}(dists)

function prior_transform(priors)
    transforms = map(priors) do prior
        left, right = extrema(support(prior))
        left = isinf(left) ? -TransformVariables.∞ : left
        right = isinf(right) ? TransformVariables.∞ : right
        TransformVariables.as(Real, left, right)
    end
    TransformVariables.as(transforms)
end

trans = prior_transform(dists)
q = fill(-0.1, TransformVariables.dimension(trans))

foo(q) = sum(values(TransformVariables.transform(trans, q)))

Enzyme.gradient(Enzyme.Reverse, foo, q) # AssertionError: conv == 37
```

tpapp commented Jun 15, 2026

Owner

@jlperla, thanks for this, @scheidan, thanks for the ping. I apologize for the delay in reviewing this.

It is not strictly equivalent as, AFAIK, built-ins do not necessarily unroll above a certain tuple length. But given that the intention of using a tuple is to get type-stable code, I don't see a problem with this here. Also, EnzymeAD/Enzyme.jl#3104 indicates that this is an issue on the Julia side, so fixing it on our end may be the best option for now.

@devmotion, this is fine with me, do you have any comments?

tpapp commented Jun 15, 2026

Owner

(closing and reopening to make CI run)

@tpapp tpapp closed this Jun 15, 2026
@tpapp tpapp reopened this Jun 15, 2026
Comment thread src/aggregation.jl
Comment thread src/aggregation.jl Outdated
for i in 1:N]
ℓ_sum = foldl((a, b) -> :($a + $b), ℓs)
return quote
idx = index

Collaborator

Why is a separate idx variable needed? Couldn't we just operate with index?

Comment thread src/aggregation.jl Outdated
Co-authored-by: Tamas K. Papp <tkpapp@gmail.com>

@tpapp tpapp left a comment

Owner

LGTM, thanks!

@tpapp tpapp merged commit b36d918 into tpapp:master Jun 23, 2026
7 checks passed
@devmotion

Collaborator

> same outputs bit-for-bit

Just wanted to mention: this is wrong; I actually saw downstream test failures due to this PR. Previously, summation was performed using (basically) foldr, whereas this PR uses foldl, which of course does not generally give exactly the same results in floating-point arithmetic.
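
A small illustration of this point, using plain `Float64` addition rather than anything from TransformVariables.jl: floating-point addition is not associative, so folding the same terms from the left and from the right can differ in the last bit.

```julia
xs = (0.1, 0.2, 0.3)

left = foldl(+, xs)   # (0.1 + 0.2) + 0.3
right = foldr(+, xs)  # 0.1 + (0.2 + 0.3)

println(left)   # 0.6000000000000001
println(right)  # 0.6
println(left == right)  # false
```

The two results differ by one unit in the last place, which is enough to break any "bit-for-bit" comparison.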

@jlperla

jlperla commented Jun 24, 2026

Contributor Author

I was waiting for CI to run. Did it?

@devmotion

Collaborator

In this PR? Sure, it ran before it was merged. My point was merely that it broke my downstream CI due to not yielding "same outputs bit-for-bit", and I wanted to point out for any future reader of this PR that this was an incorrect claim in the initial comment above.

@tpapp

tpapp commented Jun 24, 2026

Owner

> same outputs bit-for-bit

I just want to clarify that this is not something this package ever promised. AFAIK very few packages in the Julia ecosystem have that kind of commitment. @devmotion, thanks for pointing this out though.

> downstream test failures due to this PR

I am sorry to hear this, but if they were comparing exact output, those were the wrong kind of tests.
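
As a rough sketch of the kind of test that tolerates a change in summation order, a comparison with an explicit relative tolerance can be used instead of `==`. The values and the tolerance below are illustrative assumptions, not taken from any package's test suite.

```julia
using Test

reference = 0.1 + (0.2 + 0.3)  # one summation order
candidate = (0.1 + 0.2) + 0.3  # the same terms, summed in another order

exact_match = candidate == reference                       # false here
close_match = isapprox(candidate, reference; rtol = 1e-12) # true here

@test close_match
```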

@devmotion

Collaborator

No, they were not comparing to exact TransformVariables output. The test failure was caused by different likelihood values of MLE estimates; apparently the tiny difference in the transform was sufficient to cause slightly different optimization trajectories. Not a big problem, of course, but that made me realize the incorrect claim above.
