Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions ext/ArrayDiffONNXExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ function _attr_int(node, name; default::Int)
return a === nothing ? default : Int(a.i)
end

function _attr_ints(node, name; default::Union{Nothing,Vector{Int}} = nothing)
a = _find_attr(node, name)
return a === nothing ? default : Int[Int(x) for x in a.ints]
end

function _attr_float(::Type{T}, node, name; default) where {T<:Real}
a = _find_attr(node, name)
return a === nothing ? T(default) : T(a.f)
Expand Down Expand Up @@ -183,6 +188,9 @@ function _convert_node(
x, sx = env[node.input[1]]
return (_bcall(:-, Any[zero(T), x], sx), sx)

elseif op == "Transpose"
return _convert_transpose(node, env)

elseif op == "MatMul"
return _convert_matmul(node, env)

Expand Down Expand Up @@ -222,6 +230,34 @@ function _binop_broadcast(op::Symbol, node, env)
return (_bcall(op, Any[a, b], s), s)
end

function _convert_transpose(node, env)
x, sx = env[node.input[1]]
# ONNX `perm` is 0-indexed and defaults to reversing all axes.
perm = _attr_ints(node, "perm")
if length(sx) == 0
return (x, sx)
elseif length(sx) == 1
# 1-D tensor: transpose is a no-op on shape and on data.
return (x, sx)
elseif length(sx) == 2
# Only the two 2-D permutations are valid.
if perm === nothing || perm == [1, 0]
if x isa AbstractMatrix{<:Real}
# Constant input: just permute the values at conversion time.
return (collect(permutedims(x)), (sx[2], sx[1]))
end
new_shape = (sx[2], sx[1])
return (_call(:transpose, Any[x], new_shape), new_shape)
elseif perm == [0, 1]
return (x, sx)
else
error("Transpose: unsupported perm $perm for 2-D input")
end
else
error("Transpose: tensors with ndim > 2 are not supported (got $sx)")
end
end

function _convert_matmul(node, env)
a, sa = env[node.input[1]]
b, sb = env[node.input[2]]
Expand Down
2 changes: 1 addition & 1 deletion src/Coloring/Coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ IndexedSet(n::Integer) = IndexedSet(zeros(Int, n), trues(n), 0)

function Base.push!(v::IndexedSet, i::Integer)
if v.empty[i] # new index
v.nzidx[v.nnz += 1] = i
v.nzidx[v.nnz+=1] = i
v.empty[i] = false
end
return
Expand Down
17 changes: 17 additions & 0 deletions src/JuMP/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,23 @@ function Base.broadcasted(
return Base.broadcasted(^, x, y)
end

function _transpose(x::AbstractJuMPArray{T,N}) where {T,N}
V = JuMP.variable_ref_type(x)
if N == 1
return GenericArrayExpr{V,2}(:transpose, Any[x], (1, size(x, 1)), false)
end
@assert N == 2 "`transpose` only supports 1-D and 2-D arrays"
return GenericArrayExpr{V,2}(
:transpose,
Any[x],
(size(x, 2), size(x, 1)),
false,
)
end

LinearAlgebra.transpose(x::AbstractJuMPArray) = _transpose(x)
LinearAlgebra.adjoint(x::AbstractJuMPArray) = _transpose(x)

function Base.sum(x::AbstractJuMPArray; dims = Colon())
V = JuMP.variable_ref_type(x)
if dims === Colon()
Expand Down
1 change: 1 addition & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
:sum,
:row,
:sum_dims,
:transpose,
]

function _validate_register_assumptions(
Expand Down
55 changes: 55 additions & 0 deletions src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,16 @@ function _forward_eval(
sum!,
tuple(),
)
elseif node.index == 18 # transpose
@assert N == 1 "`transpose` expects a single child"
arr_ix = children_arr[first(children_indices)]
_reshape_call(
f.forward_storage,
f.sizes,
(k, arr_ix),
_forward_transpose!,
tuple(),
)
elseif node.index <= length(operators.multivariate_operators) &&
haskey(
operators.chainrules_operators,
Expand Down Expand Up @@ -813,6 +823,40 @@ function _reverse_sum_dims!(rev_arr, rev_parent)
return
end

# Forward for `:transpose`. The child can be 1-D (treated as a column,
# producing a row matrix `(1, n)`) or 2-D `(m, n)` producing `(n, m)`.
# Hand-rolled loops because `permutedims!` allocates on the `ReshapedArray`
# views returned by `_view_matrix`.
function _forward_transpose!(out, x)
if ndims(x) == 1
for j in eachindex(x)
out[1, j] = x[j]
end
else
m, n = size(x)
for j in 1:n, i in 1:m
out[j, i] = x[i, j]
end
end
return
end

# Reverse for `:transpose`. `y = xᵀ`, so ∂L/∂x[i,j] = ∂L/∂y[j,i]; the inverse
# permutation lifts the parent adjoint back to the child's shape.
function _reverse_transpose!(rev_arr, rev_parent)
if ndims(rev_arr) == 1
for j in eachindex(rev_arr)
rev_arr[j] = rev_parent[1, j]
end
else
m, n = size(rev_arr)
for j in 1:n, i in 1:m
rev_arr[i, j] = rev_parent[j, i]
end
end
return
end

"""
_reverse_eval(f::_SubexpressionStorage)

Expand Down Expand Up @@ -1063,6 +1107,17 @@ function _reverse_eval(
tuple(),
)
continue
elseif op == :transpose
@assert length(children_indices) == 1 "`transpose` expects a single child"
arr_ix = children_arr[first(children_indices)]
_reshape_call(
f.reverse_storage,
f.sizes,
(arr_ix, k),
_reverse_transpose!,
tuple(),
)
continue
end
elseif node.type == NODE_CALL_MULTIVARIATE &&
haskey(f.chainrules_pullbacks, k)
Expand Down
9 changes: 9 additions & 0 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,15 @@ function infer_sizes(::typeof(_row_op), shapes...)
return (1, length(shapes))
end

# transpose: vector (n,) → row matrix (1, n); matrix (m, n) → (n, m).
function infer_sizes(::typeof(LinearAlgebra.transpose), shape)
if length(shape) == 1
return (1, shape[1])
end
@assert length(shape) == 2 "`transpose` only supports 1-D and 2-D inputs"
return (shape[2], shape[1])
end

# Map a built-in operator symbol to its Julia function so `infer_sizes` can
# dispatch on `typeof(fn)`. Returns `nothing` for `:sum_dims`, whose shape
# depends on the constant dims vector and is handled inline by `_infer_sizes`.
Expand Down
69 changes: 69 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,75 @@ function test_sum_dims_along_cols()
return
end

function test_transpose_matrix()
rows, cols = 3, 2
model = Model()
@variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables)
expr = transpose(W)
@test expr isa ArrayDiff.MatrixExpr
@test expr.head == :transpose
@test size(expr) == (cols, rows)
@test expr.broadcasted == false
x = Float64.(collect(1:(rows*cols)))
W_val = reshape(x, rows, cols)
# f(W) = ‖Wᵀ‖_F = ‖W‖_F; gradient is W ./ ‖W‖_F in column-major order.
sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x)
@test val ≈ LinearAlgebra.norm(W_val)
@test g ≈ x ./ LinearAlgebra.norm(W_val)
# Tape: norm (k=1, scalar) → transpose (k=2, (cols, rows)).
@test sizes.ndims[1] == 0
@test sizes.ndims[2] == 2
t_off = sizes.size_offset[2]
@test sizes.size[t_off+1] == cols
@test sizes.size[t_off+2] == rows
return
end

function test_transpose_matrix_inner_product()
# f(W) = sum(Wᵀ .* C) where C is a constant of shape (cols, rows).
# ∂f/∂W[i,j] = C[j,i].
rows, cols = 2, 3
model = Model()
@variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables)
C = reshape(Float64.(collect(1:(rows*cols))) .+ 0.5, cols, rows)
expr = sum(transpose(W) .* C)
x = Float64.(collect(1:(rows*cols)))
W_val = reshape(x, rows, cols)
_, val, g = _eval(model, expr, x)
@test val ≈ sum(transpose(W_val) .* C)
# Gradient column-major: g[(j-1)*rows + i] = C[j, i].
expected = vec([C[j, i] for i in 1:rows, j in 1:cols])
@test g ≈ expected
return
end

function test_transpose_vector()
n = 4
model = Model()
@variable(model, x[1:n], container = ArrayDiff.ArrayOfVariables)
expr = transpose(x)
@test expr isa ArrayDiff.MatrixExpr
@test expr.head == :transpose
@test size(expr) == (1, n)
xv = Float64.(collect(1:n))
# f(x) = ‖xᵀ‖ = ‖x‖; ∂f/∂x[i] = x[i] / ‖x‖.
_, val, g = _eval(model, LinearAlgebra.norm(expr), xv)
@test val ≈ LinearAlgebra.norm(xv)
@test g ≈ xv ./ LinearAlgebra.norm(xv)
return
end

function test_transpose_adjoint_alias()
rows, cols = 2, 3
model = Model()
@variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables)
expr = adjoint(W)
@test expr isa ArrayDiff.MatrixExpr
@test expr.head == :transpose
@test size(expr) == (cols, rows)
return
end

function test_broadcast_nonsquare_matrix()
model = Model()
@variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables)
Expand Down
113 changes: 113 additions & 0 deletions test/ONNXExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ function _attr_int(name::String, i::Integer)
return ONNX.AttributeProto(name = name, i = Int64(i), var"#type" = AT.INT)
end

function _attr_ints(name::String, ints::AbstractVector{<:Integer})
return ONNX.AttributeProto(
name = name,
ints = Int64[Int64(i) for i in ints],
var"#type" = AT.INTS,
)
end

function _attr_tensor(name::String, t::ONNX.TensorProto)
return ONNX.AttributeProto(name = name, t = t, var"#type" = AT.TENSOR)
end
Expand Down Expand Up @@ -782,6 +790,111 @@ function test_float32_sigmoid_and_gemm()
@test val ≈ fjulia(xv) rtol = 1.0f-5
end

# Transpose on a variable matrix input followed by matmul against itself:
# y = Wᵀ * v, where W is the (2,3) variable matrix and v is a constant.
# f(W) = ‖y‖² = ‖Wᵀ v‖² ⇒ ∂f/∂W = 2 v (Wᵀ v)ᵀ = 2 v vᵀ W.
function test_transpose_variable_matrix()
vars = [MOI.VariableIndex(i) for i in 1:6]
var_mat = collect(reshape(vars, 2, 3))
v = [0.4, -1.0]
init = _make_tensor("v", v)
n1 = _make_node(
"Transpose",
["x"],
["xT"];
attrs = [_attr_ints("perm", [1, 0])],
)
n2 = _make_node("MatMul", ["xT", "v"], ["y"])
proto = _build_model([n1, n2], ["x"], ["y"]; initializers = [init])
xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5]
val, g = _eval_with_gradient(proto, vars, xv; input = var_mat)
fjulia(x) = sum((reshape(x, 2, 3)' * v) .^ 2)
@test val ≈ fjulia(xv)
@test g ≈ ForwardDiff.gradient(fjulia, xv)
end

# Transpose with no `perm` attribute (defaults to reverse-all): for 2-D this
# is the (1,0) swap, same as explicit perm=[1,0].
function test_transpose_default_perm()
vars = [MOI.VariableIndex(i) for i in 1:6]
var_mat = collect(reshape(vars, 2, 3))
v = [0.4, -1.0]
init = _make_tensor("v", v)
n1 = _make_node("Transpose", ["x"], ["xT"]) # no perm attr
n2 = _make_node("MatMul", ["xT", "v"], ["y"])
proto = _build_model([n1, n2], ["x"], ["y"]; initializers = [init])
xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5]
val, g = _eval_with_gradient(proto, vars, xv; input = var_mat)
fjulia(x) = sum((reshape(x, 2, 3)' * v) .^ 2)
@test val ≈ fjulia(xv)
@test g ≈ ForwardDiff.gradient(fjulia, xv)
end

# Transpose with perm=[0,1] is a no-op; the result is `x` unchanged.
function test_transpose_identity_perm()
vars = [MOI.VariableIndex(i) for i in 1:6]
var_mat = collect(reshape(vars, 2, 3))
n1 = _make_node(
"Transpose",
["x"],
["y"];
attrs = [_attr_ints("perm", [0, 1])],
)
proto = _build_model([n1], ["x"], ["y"])
xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5]
val, g = _eval_with_gradient(proto, vars, xv; input = var_mat)
fjulia(x) = sum(x .^ 2)
@test val ≈ fjulia(xv)
@test g ≈ ForwardDiff.gradient(fjulia, xv)
end

# Transpose of a constant initializer: folded at conversion time. Verify the
# graph still produces the correct value/gradient.
function test_transpose_constant_initializer()
vars = [MOI.VariableIndex(i) for i in 1:3]
W = [
0.4 -0.1 0.5;
1.2 -0.3 0.7
] # (2, 3)
init = _make_tensor("W", W)
n1 = _make_node(
"Transpose",
["W"],
["Wt"];
attrs = [_attr_ints("perm", [1, 0])],
)
n2 = _make_node("MatMul", ["Wt", "x"], ["y"]) # (3,2) * (2,) = (3,)
proto = _build_model([n1, n2], ["x"], ["y"]; initializers = [init])
xv = [0.6, -0.4]
# Vec × Mat path: ArrayDiff routes Vec × Mat through a transpose trick that
# requires the matrix to be constant — `Wt` here is constant after folding.
# But MatMul Mat × Vec needs Wt as the LHS and x as the RHS. Since `x` is
# the variable vector, `Wt * x` is (3,2) × (2,) — Mat × Vec, which is fine.
val, g = _eval_with_gradient(proto, vars[1:2], xv)
fjulia(x) = sum((W' * x) .^ 2)
@test val ≈ fjulia(xv)
@test g ≈ ForwardDiff.gradient(fjulia, xv)
end

# Higher-dim Transpose isn't supported.
function test_transpose_3d_errors()
init = _make_tensor("t", reshape(collect(1.0:6.0), 2, 3))
# Pretend the test input is 3-D by lying about it. Easier: try transposing
# a graph input declared as a 3-D Matrix-of-vars-shaped wrapping — but the
# wrapper itself rejects 3-D inputs. So instead exercise the path via the
# error inside `_convert_transpose` directly using a fake env entry.
ext = Base.get_extension(ArrayDiff, :ArrayDiffONNXExt)
fake_env = Dict{String,Tuple{Any,Tuple{Vararg{Int}}}}()
fake_env["x"] = (zeros(2, 2, 2), (2, 2, 2))
node = _make_node(
"Transpose",
["x"],
["y"];
attrs = [_attr_ints("perm", [2, 1, 0])],
)
@test_throws ErrorException ext._convert_transpose(node, fake_env)
end

# Multi-output graph: result is keyed by output name.
function test_multi_output()
vars = [MOI.VariableIndex(i) for i in 1:3]
Expand Down
Loading