diff --git a/ext/ArrayDiffONNXExt.jl b/ext/ArrayDiffONNXExt.jl index 7c0b4c8..a5548bd 100644 --- a/ext/ArrayDiffONNXExt.jl +++ b/ext/ArrayDiffONNXExt.jl @@ -28,6 +28,11 @@ function _attr_int(node, name; default::Int) return a === nothing ? default : Int(a.i) end +function _attr_ints(node, name; default::Union{Nothing,Vector{Int}} = nothing) + a = _find_attr(node, name) + return a === nothing ? default : Int[Int(x) for x in a.ints] +end + function _attr_float(::Type{T}, node, name; default) where {T<:Real} a = _find_attr(node, name) return a === nothing ? T(default) : T(a.f) @@ -183,6 +188,9 @@ function _convert_node( x, sx = env[node.input[1]] return (_bcall(:-, Any[zero(T), x], sx), sx) + elseif op == "Transpose" + return _convert_transpose(node, env) + elseif op == "MatMul" return _convert_matmul(node, env) @@ -222,6 +230,34 @@ function _binop_broadcast(op::Symbol, node, env) return (_bcall(op, Any[a, b], s), s) end +function _convert_transpose(node, env) + x, sx = env[node.input[1]] + # ONNX `perm` is 0-indexed and defaults to reversing all axes. + perm = _attr_ints(node, "perm") + if length(sx) == 0 + return (x, sx) + elseif length(sx) == 1 + # 1-D tensor: transpose is a no-op on shape and on data. + return (x, sx) + elseif length(sx) == 2 + # Only the two 2-D permutations are valid. + if perm === nothing || perm == [1, 0] + if x isa AbstractMatrix{<:Real} + # Constant input: just permute the values at conversion time. + return (collect(permutedims(x)), (sx[2], sx[1])) + end + new_shape = (sx[2], sx[1]) + return (_call(:transpose, Any[x], new_shape), new_shape) + elseif perm == [0, 1] + return (x, sx) + else + error("Transpose: unsupported perm $perm for 2-D input") + end + else + error("Transpose: tensors with ndim > 2 are not supported (got $sx)") + end +end + function _convert_matmul(node, env) a, sa = env[node.input[1]] b, sb = env[node.input[2]] diff --git a/src/Coloring/Coloring.jl b/src/Coloring/Coloring.jl index c97a7f4..c3315dd 100644 --- a/src/Coloring/Coloring.jl +++ b/src/Coloring/Coloring.jl @@ -30,7 +30,7 @@ IndexedSet(n::Integer) = IndexedSet(zeros(Int, n), trues(n), 0) function Base.push!(v::IndexedSet, i::Integer) if v.empty[i] # new index - v.nzidx[v.nnz += 1] = i + v.nzidx[v.nnz+=1] = i v.empty[i] = false end return diff --git a/src/JuMP/operators.jl b/src/JuMP/operators.jl index 04b5a13..b0555bb 100644 --- a/src/JuMP/operators.jl +++ b/src/JuMP/operators.jl @@ -80,6 +80,23 @@ function Base.broadcasted( return Base.broadcasted(^, x, y) end +function _transpose(x::AbstractJuMPArray{T,N}) where {T,N} + V = JuMP.variable_ref_type(x) + if N == 1 + return GenericArrayExpr{V,2}(:transpose, Any[x], (1, size(x, 1)), false) + end + @assert N == 2 "`transpose` only supports 1-D and 2-D arrays" + return GenericArrayExpr{V,2}( + :transpose, + Any[x], + (size(x, 2), size(x, 1)), + false, + ) +end + +LinearAlgebra.transpose(x::AbstractJuMPArray) = _transpose(x) +LinearAlgebra.adjoint(x::AbstractJuMPArray) = _transpose(x) + function Base.sum(x::AbstractJuMPArray; dims = Colon()) V = JuMP.variable_ref_type(x) if dims === Colon() diff --git a/src/operators.jl b/src/operators.jl index b962e5d..e476499 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -19,6 +19,7 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [ :sum, :row, :sum_dims, + :transpose, ] function _validate_register_assumptions( diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 7dd620e..f83ef6a 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -439,6 +439,16 @@ function _forward_eval( sum!, tuple(), ) + elseif node.index == 18 # transpose + @assert N == 1 "`transpose` expects a single child" + arr_ix = children_arr[first(children_indices)] + _reshape_call( + f.forward_storage, + f.sizes, + (k, arr_ix), + _forward_transpose!, + tuple(), + ) elseif node.index <= length(operators.multivariate_operators) && haskey( operators.chainrules_operators, @@ -813,6 +823,40 @@ function _reverse_sum_dims!(rev_arr, rev_parent) return end +# Forward for `:transpose`. The child can be 1-D (treated as a column, +# producing a row matrix `(1, n)`) or 2-D `(m, n)` producing `(n, m)`. +# Hand-rolled loops because `permutedims!` allocates on the `ReshapedArray` +# views returned by `_view_matrix`. +function _forward_transpose!(out, x) + if ndims(x) == 1 + for j in eachindex(x) + out[1, j] = x[j] + end + else + m, n = size(x) + for j in 1:n, i in 1:m + out[j, i] = x[i, j] + end + end + return +end + +# Reverse for `:transpose`. `y = xᵀ`, so ∂L/∂x[i,j] = ∂L/∂y[j,i]; the inverse +# permutation lifts the parent adjoint back to the child's shape. +function _reverse_transpose!(rev_arr, rev_parent) + if ndims(rev_arr) == 1 + for j in eachindex(rev_arr) + rev_arr[j] = rev_parent[1, j] + end + else + m, n = size(rev_arr) + for j in 1:n, i in 1:m + rev_arr[i, j] = rev_parent[j, i] + end + end + return +end + """ _reverse_eval(f::_SubexpressionStorage) @@ -1063,6 +1107,17 @@ function _reverse_eval( tuple(), ) continue + elseif op == :transpose + @assert length(children_indices) == 1 "`transpose` expects a single child" + arr_ix = children_arr[first(children_indices)] + _reshape_call( + f.reverse_storage, + f.sizes, + (arr_ix, k), + _reverse_transpose!, + tuple(), + ) + continue end elseif node.type == NODE_CALL_MULTIVARIATE && haskey(f.chainrules_pullbacks, k) diff --git a/src/sizes.jl b/src/sizes.jl index d168638..dfa04ee 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -497,6 +497,15 @@ function infer_sizes(::typeof(_row_op), shapes...) return (1, length(shapes)) end +# transpose: vector (n,) → row matrix (1, n); matrix (m, n) → (n, m). +function infer_sizes(::typeof(LinearAlgebra.transpose), shape) + if length(shape) == 1 + return (1, shape[1]) + end + @assert length(shape) == 2 "`transpose` only supports 1-D and 2-D inputs" + return (shape[2], shape[1]) +end + # Map a built-in operator symbol to its Julia function so `infer_sizes` can # dispatch on `typeof(fn)`. Returns `nothing` for `:sum_dims`, whose shape # depends on the constant dims vector and is handled inline by `_infer_sizes`. diff --git a/test/JuMP.jl b/test/JuMP.jl index 8ab0620..a5d9991 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -555,6 +555,75 @@ function test_sum_dims_along_cols() return end +function test_transpose_matrix() + rows, cols = 3, 2 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + expr = transpose(W) + @test expr isa ArrayDiff.MatrixExpr + @test expr.head == :transpose + @test size(expr) == (cols, rows) + @test expr.broadcasted == false + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + # f(W) = ‖Wᵀ‖_F = ‖W‖_F; gradient is W ./ ‖W‖_F in column-major order. + sizes, val, g = _eval(model, LinearAlgebra.norm(expr), x) + @test val ≈ LinearAlgebra.norm(W_val) + @test g ≈ x ./ LinearAlgebra.norm(W_val) + # Tape: norm (k=1, scalar) → transpose (k=2, (cols, rows)). + @test sizes.ndims[1] == 0 + @test sizes.ndims[2] == 2 + t_off = sizes.size_offset[2] + @test sizes.size[t_off+1] == cols + @test sizes.size[t_off+2] == rows + return +end + +function test_transpose_matrix_inner_product() + # f(W) = sum(Wᵀ .* C) where C is a constant of shape (cols, rows). + # ∂f/∂W[i,j] = C[j,i]. + rows, cols = 2, 3 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + C = reshape(Float64.(collect(1:(rows*cols))) .+ 0.5, cols, rows) + expr = sum(transpose(W) .* C) + x = Float64.(collect(1:(rows*cols))) + W_val = reshape(x, rows, cols) + _, val, g = _eval(model, expr, x) + @test val ≈ sum(transpose(W_val) .* C) + # Gradient column-major: g[(j-1)*rows + i] = C[j, i]. + expected = vec([C[j, i] for i in 1:rows, j in 1:cols]) + @test g ≈ expected + return +end + +function test_transpose_vector() + n = 4 + model = Model() + @variable(model, x[1:n], container = ArrayDiff.ArrayOfVariables) + expr = transpose(x) + @test expr isa ArrayDiff.MatrixExpr + @test expr.head == :transpose + @test size(expr) == (1, n) + xv = Float64.(collect(1:n)) + # f(x) = ‖xᵀ‖ = ‖x‖; ∂f/∂x[i] = x[i] / ‖x‖. + _, val, g = _eval(model, LinearAlgebra.norm(expr), xv) + @test val ≈ LinearAlgebra.norm(xv) + @test g ≈ xv ./ LinearAlgebra.norm(xv) + return +end + +function test_transpose_adjoint_alias() + rows, cols = 2, 3 + model = Model() + @variable(model, W[1:rows, 1:cols], container = ArrayDiff.ArrayOfVariables) + expr = adjoint(W) + @test expr isa ArrayDiff.MatrixExpr + @test expr.head == :transpose + @test size(expr) == (cols, rows) + return +end + function test_broadcast_nonsquare_matrix() model = Model() @variable(model, W[1:2, 1:3], container = ArrayDiff.ArrayOfVariables) diff --git a/test/ONNXExt.jl b/test/ONNXExt.jl index ada591c..d046506 100644 --- a/test/ONNXExt.jl +++ b/test/ONNXExt.jl @@ -55,6 +55,14 @@ function _attr_int(name::String, i::Integer) return ONNX.AttributeProto(name = name, i = Int64(i), var"#type" = AT.INT) end +function _attr_ints(name::String, ints::AbstractVector{<:Integer}) + return ONNX.AttributeProto( + name = name, + ints = Int64[Int64(i) for i in ints], + var"#type" = AT.INTS, + ) +end + function _attr_tensor(name::String, t::ONNX.TensorProto) return ONNX.AttributeProto(name = name, t = t, var"#type" = AT.TENSOR) end @@ -782,6 +790,111 @@ function test_float32_sigmoid_and_gemm() @test val ≈ fjulia(xv) rtol = 1.0f-5 end +# Transpose on a variable matrix input followed by matmul against itself: +# y = Wᵀ * v, where W is the (2,3) variable matrix and v is a constant. +# f(W) = ‖y‖² = ‖Wᵀ v‖² ⇒ ∂f/∂W = 2 v (Wᵀ v)ᵀ = 2 v vᵀ W. +function test_transpose_variable_matrix() + vars = [MOI.VariableIndex(i) for i in 1:6] + var_mat = collect(reshape(vars, 2, 3)) + v = [0.4, -1.0] + init = _make_tensor("v", v) + n1 = _make_node( + "Transpose", + ["x"], + ["xT"]; + attrs = [_attr_ints("perm", [1, 0])], + ) + n2 = _make_node("MatMul", ["xT", "v"], ["y"]) + proto = _build_model([n1, n2], ["x"], ["y"]; initializers = [init]) + xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 2, 3)' * v) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Transpose with no `perm` attribute (defaults to reverse-all): for 2-D this +# is the (1,0) swap, same as explicit perm=[1,0]. +function test_transpose_default_perm() + vars = [MOI.VariableIndex(i) for i in 1:6] + var_mat = collect(reshape(vars, 2, 3)) + v = [0.4, -1.0] + init = _make_tensor("v", v) + n1 = _make_node("Transpose", ["x"], ["xT"]) # no perm attr + n2 = _make_node("MatMul", ["xT", "v"], ["y"]) + proto = _build_model([n1, n2], ["x"], ["y"]; initializers = [init]) + xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum((reshape(x, 2, 3)' * v) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Transpose with perm=[0,1] is a no-op; the result is `x` unchanged. +function test_transpose_identity_perm() + vars = [MOI.VariableIndex(i) for i in 1:6] + var_mat = collect(reshape(vars, 2, 3)) + n1 = _make_node( + "Transpose", + ["x"], + ["y"]; + attrs = [_attr_ints("perm", [0, 1])], + ) + proto = _build_model([n1], ["x"], ["y"]) + xv = [0.3, -0.7, 1.1, 2.0, 0.5, -1.5] + val, g = _eval_with_gradient(proto, vars, xv; input = var_mat) + fjulia(x) = sum(x .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Transpose of a constant initializer: folded at conversion time. Verify the +# graph still produces the correct value/gradient. +function test_transpose_constant_initializer() + vars = [MOI.VariableIndex(i) for i in 1:3] + W = [ + 0.4 -0.1 0.5; + 1.2 -0.3 0.7 + ] # (2, 3) + init = _make_tensor("W", W) + n1 = _make_node( + "Transpose", + ["W"], + ["Wt"]; + attrs = [_attr_ints("perm", [1, 0])], + ) + n2 = _make_node("MatMul", ["Wt", "x"], ["y"]) # (3,2) * (2,) = (3,) + proto = _build_model([n1, n2], ["x"], ["y"]; initializers = [init]) + xv = [0.6, -0.4] + # Vec × Mat path: ArrayDiff routes Vec × Mat through a transpose trick that + # requires the matrix to be constant — `Wt` here is constant after folding. + # But MatMul Mat × Vec needs Wt as the LHS and x as the RHS. Since `x` is + # the variable vector, `Wt * x` is (3,2) × (2,) — Mat × Vec, which is fine. + val, g = _eval_with_gradient(proto, vars[1:2], xv) + fjulia(x) = sum((W' * x) .^ 2) + @test val ≈ fjulia(xv) + @test g ≈ ForwardDiff.gradient(fjulia, xv) +end + +# Higher-dim Transpose isn't supported. +function test_transpose_3d_errors() + init = _make_tensor("t", reshape(collect(1.0:6.0), 2, 3)) + # Pretend the test input is 3-D by lying about it. Easier: try transposing + # a graph input declared as a 3-D Matrix-of-vars-shaped wrapping — but the + # wrapper itself rejects 3-D inputs. So instead exercise the path via the + # error inside `_convert_transpose` directly using a fake env entry. + ext = Base.get_extension(ArrayDiff, :ArrayDiffONNXExt) + fake_env = Dict{String,Tuple{Any,Tuple{Vararg{Int}}}}() + fake_env["x"] = (zeros(2, 2, 2), (2, 2, 2)) + node = _make_node( + "Transpose", + ["x"], + ["y"]; + attrs = [_attr_ints("perm", [2, 1, 0])], + ) + @test_throws ErrorException ext._convert_transpose(node, fake_env) +end + # Multi-output graph: result is keyed by output name. function test_multi_output() vars = [MOI.VariableIndex(i) for i in 1:3]