Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support boxdot with n neighboring indices #22

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 60 additions & 40 deletions src/TensorCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
return dest
end

export boxdot, ⊡, boxdot!
export boxdot, ⊡, ⊡₂, boxdot!

"""
boxdot(A,B) = A ⊡ B # \\boxdot
Expand Down Expand Up @@ -177,40 +177,55 @@ Float64
```
See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
"""
function boxdot(A::AbstractArray, B::AbstractArray)
Amat = _squash_left(A)
Bmat = _squash_right(B)
function boxdot(A::AbstractArray, B::AbstractArray, nth::Val)
_check_boxdot_axes(A, B, nth)
Amat = _squash_left(A, nth)
Bmat = _squash_right(B, nth)

axA, axB = axes(Amat,2), axes(Bmat,1)
axA == axB || _throw_dmm(axA, axB)

return _boxdot_reshape(Amat * Bmat, A, B)
return _boxdot_reshape(Amat * Bmat, A, B, nth)
end

boxdot(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(1))
boxdot2(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(2))

const ⊡ = boxdot
const ⊡₂ = boxdot2

@noinline _throw_dmm(axA, axB) = throw(DimensionMismatch("neighbouring axes of `A` and `B` must match, got $axA and $axB"))
@noinline _throw_boxdot_nth(n) = throw(ArgumentError("boxdot order should be ≥ 1, got $n"))

function _check_boxdot_axes(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,M}, ::Val{K}) where {N,M,K}
K::Int
(K >= 1) || _throw_boxdot_nth(K)
for i in 1:K
axA, axB = axes(A)[N-K+i], axes(B)[i]
axA == axB || _throw_dmm(axA, axB)
end
end

_squash_left(A::AbstractArray) = reshape(A, :,size(A,ndims(A)))
_squash_left(A::AbstractMatrix) = A
_squash_left(A::AbstractArray, ::Val{N}) where {N} = reshape(A, prod(size(A)[1:end-N]),:)
_squash_left(A::AbstractMatrix, ::Val{1}) = A

_squash_right(B::AbstractArray) = reshape(B, size(B,1),:)
_squash_right(B::AbstractVecOrMat) = B
_squash_right(B::AbstractArray, ::Val{N}) where {N} = reshape(B, :,prod(size(B)[1+N:end]))
_squash_right(B::AbstractVecOrMat, ::Val{1}) = B

function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}) where {T,N,S,M}
ax = ntuple(i -> i<N ? axes(A, i) : axes(B, i-N+2), Val(N+M-2))
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}, ::Val{K}) where {T,N,S,M,K}
N == M == K+1 && return AB # These can skip final reshape
ax = ntuple(i -> i≤N-K ? axes(A, i) : axes(B, i-N+2K), Val(N+M-2K))
reshape(AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
end

# These can skip final reshape:
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat) = AB
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val{1}) = AB

# These produce scalar output:
function boxdot(A::AbstractVector, B::AbstractVector)
axA, axB = axes(A,1), axes(B,1)
axA == axB || _throw_dmm(axA, axB)
function boxdot(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N}, ::Val{N}) where {N}
_check_boxdot_axes(A, B, Val(N))
if eltype(A) <: Number
return transpose(A)*B
return transpose(vec(A))*vec(B)
else
return sum(a*b for (a,b) in zip(A,B))
end
Expand All @@ -224,30 +239,30 @@ boxdot(a::Number, b::Number) = a*b
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

# Adjont and Transpose, vectors or almost (returning a scalar)
boxdot(A::AdjointAbsVec, B::AbstractVector) = A * B
boxdot(A::TransposeAbsVec, B::AbstractVector) = A * B
boxdot(A::AdjointAbsVec, B::AbstractVector, ::Val{1}) = A * B
boxdot(A::TransposeAbsVec, B::AbstractVector, ::Val{1}) = A * B

boxdot(A::AbstractVector, B::AdjointAbsVec) = A ⊡ vec(B)
boxdot(A::AbstractVector, B::TransposeAbsVec) = A ⊡ vec(B)
boxdot(A::AbstractVector, B::AdjointAbsVec, ::Val{1}) = A ⊡ vec(B)
boxdot(A::AbstractVector, B::TransposeAbsVec, ::Val{1}) = A ⊡ vec(B)

boxdot(A::AdjointAbsVec, B::AdjointAbsVec) = adjoint(adjoint(B) ⊡ adjoint(A))
boxdot(A::AdjointAbsVec, B::TransposeAbsVec) = vec(A) ⊡ vec(B)
boxdot(A::TransposeAbsVec, B::AdjointAbsVec) = vec(A) ⊡ vec(B)
boxdot(A::TransposeAbsVec, B::TransposeAbsVec) = transpose(transpose(B) ⊡ transpose(A))
boxdot(A::AdjointAbsVec, B::AdjointAbsVec, ::Val{1}) = adjoint(adjoint(B) ⊡ adjoint(A))
boxdot(A::AdjointAbsVec, B::TransposeAbsVec, ::Val{1}) = vec(A) ⊡ vec(B)
boxdot(A::TransposeAbsVec, B::AdjointAbsVec, ::Val{1}) = vec(A) ⊡ vec(B)
boxdot(A::TransposeAbsVec, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) ⊡ transpose(A))

# ... with a matrix (returning another such)
boxdot(A::AdjointAbsVec, B::AbstractMatrix) = A * B
boxdot(A::TransposeAbsVec, B::AbstractMatrix) = A * B
boxdot(A::AdjointAbsVec, B::AbstractMatrix, ::Val{1}) = A * B
boxdot(A::TransposeAbsVec, B::AbstractMatrix, ::Val{1}) = A * B

boxdot(A::AbstractMatrix, B::AdjointAbsVec) = (B' ⊡ A')'
boxdot(A::AbstractMatrix, B::TransposeAbsVec) = transpose(transpose(B) ⊡ transpose(A))
boxdot(A::AbstractMatrix, B::AdjointAbsVec, ::Val{1}) = (B' ⊡ A')'
boxdot(A::AbstractMatrix, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) ⊡ transpose(A))

# ... and with higher-dim (returning a plain array)
boxdot(A::AdjointAbsVec, B::AbstractArray) = vec(A) ⊡ B
boxdot(A::TransposeAbsVec, B::AbstractArray) = vec(A) ⊡ B
boxdot(A::AdjointAbsVec, B::AbstractArray, ::Val{1}) = vec(A) ⊡ B
boxdot(A::TransposeAbsVec, B::AbstractArray, ::Val{1}) = vec(A) ⊡ B

boxdot(A::AbstractArray, B::AdjointAbsVec) = A ⊡ vec(B)
boxdot(A::AbstractArray, B::TransposeAbsVec) = A ⊡ vec(B)
boxdot(A::AbstractArray, B::AdjointAbsVec, ::Val{1}) = A ⊡ vec(B)
boxdot(A::AbstractArray, B::TransposeAbsVec, ::Val{1}) = A ⊡ vec(B)


"""
Expand All @@ -260,25 +275,30 @@ function boxdot! end

if VERSION < v"1.3" # Then 5-arg mul! isn't defined

function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray)
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B))
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
_check_boxdot_axes(A, B, Val(N))
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)))
Y
end

boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B))
boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray) = boxdot!(Y, A, B, Val(1))
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B), Val(1))

else

function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false)
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B), α, β)
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}, α::Number=true, β::Number=false) where {N}
_check_boxdot_axes(A, B, Val(N))
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)), α, β)
Y
end

boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false) = boxdot!(Y, A, B, Val(1), α, β)

# For boxdot!, only where mul! behaves differently:
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec,
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), α, β)
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), Val(1), α, β)

end

Expand Down
53 changes: 53 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,59 @@ end
@test boxdot!(similar(c,1,2), c', A) == c' * A

@test boxdot!(similar(c,1), c', d) == [dot(c, d)]

@testset "higher-order boxdot" begin
@test A ⊡₂ A isa Complex
@test boxdot(E3, E3, Val(3)) isa Complex
@test boxdot(F4, F4, Val(4)) isa Complex
@test A ⊡₂ A == sum(A .* A)
@test boxdot(E3, E3, Val(3)) == sum(E3 .* E3)
@test boxdot(F4, F4, Val(4)) == sum(F4 .* F4)

@test size(A ⊡₂ E3) == (2,)
@test A ⊡₂ E3 == vec(reshape(A, 1,:) * reshape(E3, :,2))
@test A ⊡₂ E3lazy == A ⊡₂ E3
@test E3 ⊡₂ A' == vec((A ⊡₂ E3adjoint)')
@test E3 ⊡₂ transpose(A) == A ⊡₂ conj(E3adjoint)

@test size(A ⊡₂ F4) == (2,2)
@test A ⊡₂ F4 == reshape(reshape(A, 1,:) * reshape(F4, :,4), 2,2)
@test A ⊡₂ F4lazy == A ⊡₂ F4
@test F4lazy ⊡₂ A == F4 ⊡₂ A

@test size(F4 ⊡₂ E3) == (2,2,2)
@test F4 ⊡₂ E3 == reshape(reshape(F4, 4,:) * reshape(E3, :,2), 2,2,2)
@test F4 ⊡₂ E3 == F4lazy ⊡₂ E3lazy

# In-place
@test boxdot!(similar(c), A, E3, Val(2)) == A ⊡₂ E3
if VERSION >= v"1.3"
@test boxdot!(similar(c), A, E3, Val(2), 100) == A ⊡₂ E3 * 100
@test boxdot!(copy(c), B, E3, Val(2), 100, -5) == B ⊡₂ E3 * 100 .- 5 .* c
end

@test boxdot!(similar(c,1), A, A, Val(2)) == [A ⊡₂ A]
@test boxdot!(similar(c,2,2), A, F4, Val(2)) == A ⊡₂ F4
@test boxdot!(similar(c,2,2,2), F4, E3, Val(2)) == F4 ⊡₂ E3

# Errors
@test_throws DimensionMismatch ones(2,2) ⊡₂ ones(3,2)
@test_throws DimensionMismatch ones(2,2) ⊡₂ ones(2,3)
@test_throws DimensionMismatch ones(2,2,2) ⊡₂ ones(2,3,2)
@test_throws BoundsError ones(2,2) ⊡₂ ones(2)
@test_throws BoundsError ones(2) ⊡₂ ones(2,2)
@test_throws ArgumentError boxdot(ones(2), ones(2), Val(-1))
@test_throws TypeError boxdot(ones(2), ones(2), Val(UInt(1)))

@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(3,2), Val(2))
@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(2,3), Val(2))
@test_throws DimensionMismatch boxdot!(similar(c,2,2), ones(2,2,2), ones(2,3,2), Val(2))
@test_throws BoundsError boxdot!(similar(c,1), ones(2,2), ones(2), Val(2))
@test_throws BoundsError boxdot!(similar(c,1), ones(2), ones(2,2), Val(2))
@test_throws DimensionMismatch boxdot!(similar(c,2,3), ones(2,2,3), ones(2,3,2), Val(2))
@test_throws ArgumentError boxdot!(similar(c,1), ones(2), ones(2), Val(-1))
@test_throws TypeError boxdot!(similar(c,1), ones(2), ones(2), Val(UInt(1)))
end
end

@testset "_adjoint" begin
Expand Down
Loading