Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jun 21, 2023
1 parent 3e0ebc3 commit e772d4c
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 54 deletions.
56 changes: 46 additions & 10 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,28 @@ using LinearAlgebra
using UUIDs: uuid4

# TODO test array container typevar on output
for op in [:+, :-, :*, :/, :\, :^, :÷, :fld, :cld, :mod, :%, :fldmod, :fld1, :mod1, :fldmod1, ://, :gcd, :lcm, :gcdx, :widemul]
for op in [
:+,
:-,
:*,
:/,
:\,
:^,
:÷,
:fld,
:cld,
:mod,
:%,
:fldmod,
:fld1,
:mod1,
:fldmod1,
://,
:gcd,
:lcm,
:gcdx,
:widemul,
]
@eval Base.$op(a::Tensor{A,0}, b::Tensor{B,0}) where {A,B} = broadcast($op, a, b)
end

Expand All @@ -12,7 +33,7 @@ end
Perform tensor contraction operation.
"""
function contract(a::Tensor, b::Tensor; dims=((labels(a), labels(b))))
function contract(a::Tensor, b::Tensor; dims = ((labels(a), labels(b))))
ia = labels(a)
ib = labels(b)
i = (dims, ia, ib)
Expand All @@ -25,7 +46,7 @@ function contract(a::Tensor, b::Tensor; dims=(∩(labels(a), labels(b))))
return Tensor(data, ic)
end

function contract(a::Tensor; dims=nonunique(labels(a)))
function contract(a::Tensor; dims = nonunique(labels(a)))
ia = labels(a)
i = (dims, ia)

Expand All @@ -39,7 +60,8 @@ end

contract(a::Union{T,AbstractArray{T,0}}, b::Tensor{T}) where {T} = contract(Tensor(a), b)
contract(a::Tensor{T}, b::Union{T,AbstractArray{T,0}}) where {T} = contract(a, Tensor(b))
contract(a::AbstractArray{<:Any,0}, b::AbstractArray{<:Any,0}) = contract(Tensor(a), Tensor(b)) |> only
contract(a::AbstractArray{<:Any,0}, b::AbstractArray{<:Any,0}) =
contract(Tensor(a), Tensor(b)) |> only
contract(a::Number, b::Number) = contract(fill(a), fill(b))

"""
Expand All @@ -51,7 +73,7 @@ Base.:*(a::Tensor, b::Tensor) = contract(a, b)
Base.:*(a::Tensor, b) = contract(a, b)
Base.:*(a, b::Tensor) = contract(a, b)

LinearAlgebra.svd(t::Tensor; left_inds=(), kwargs...) = svd(t, left_inds; kwargs...)
LinearAlgebra.svd(t::Tensor; left_inds = (), kwargs...) = svd(t, left_inds; kwargs...)

function LinearAlgebra.svd(t::Tensor, left_inds; kwargs...)
if isempty(left_inds)
Expand All @@ -69,7 +91,11 @@ function LinearAlgebra.svd(t::Tensor, left_inds; kwargs...)

# permute array
tensor = permutedims(t, (left_inds..., right_inds...))
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))
data = reshape(
parent(tensor),
prod(i -> size(t, i), left_inds),
prod(i -> size(t, i), right_inds),
)

# compute SVD
U, s, V = svd(data; kwargs...)
Expand All @@ -89,19 +115,29 @@ function LinearAlgebra.svd(t::Tensor, left_inds; kwargs...)
return U, s, Vt
end

LinearAlgebra.qr(t::Tensor; left_inds=(), kwargs...) = qr(t, left_inds; kwargs...)
LinearAlgebra.qr(t::Tensor; left_inds = (), kwargs...) = qr(t, left_inds; kwargs...)

function LinearAlgebra.qr(t::Tensor, left_inds; virtualind::Symbol=Symbol(uuid4()), kwargs...)
function LinearAlgebra.qr(
t::Tensor,
left_inds;
virtualind::Symbol = Symbol(uuid4()),
kwargs...,
)
# TODO better error exception and checks
isempty(left_inds) && throw(ErrorException("no left-indices in QR factorization"))
left_inds labels(t) || throw(ErrorException("all left-indices must be in $(labels(t))"))
left_inds labels(t) ||
throw(ErrorException("all left-indices must be in $(labels(t))"))

right_inds = setdiff(labels(t), left_inds)
isempty(right_inds) && throw(ErrorException("no right-indices in QR factorization"))

# permute array
tensor = permutedims(t, (left_inds..., right_inds...))
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))
data = reshape(
parent(tensor),
prod(i -> size(t, i), left_inds),
prod(i -> size(t, i), right_inds),
)

# compute QR
Q, R = qr(data; kwargs...)
Expand Down
53 changes: 37 additions & 16 deletions src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ struct Tensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
labels::NTuple{N,Symbol}
meta::Dict{Symbol,Any}

function Tensor{T,N,A}(data::A, labels::NTuple{N,Symbol}; meta...) where {T,N,A<:AbstractArray{T,N}}
function Tensor{T,N,A}(
data::A,
labels::NTuple{N,Symbol};
meta...,
) where {T,N,A<:AbstractArray{T,N}}
meta = Dict{Symbol,Any}(meta...)
haskey(meta, :tags) || (meta[:tags] = Set{String}())
all(i -> allequal(Iterators.map(dim -> size(data, dim), findall(==(i), labels))), nonunique(collect(labels))) ||
throw(DimensionMismatch("nonuniform size of repeated indices"))
all(
i -> allequal(Iterators.map(dim -> size(data, dim), findall(==(i), labels))),
nonunique(collect(labels)),
) || throw(DimensionMismatch("nonuniform size of repeated indices"))

new{T,N,A}(data, labels, meta)
end
Expand All @@ -43,7 +49,13 @@ function Base.similar(t::Tensor{_,N}, ::Type{T}; kwargs...) where {_,T,N}
end
end
# TODO fix this
function Base.similar(t::Tensor, ::Type{T}, dims::Int64...; labels=labels(t), meta...) where {T}
function Base.similar(
t::Tensor,
::Type{T},
dims::Int64...;
labels = labels(t),
meta...,
) where {T}
data = similar(parent(t), T, dims)

# copy metadata
Expand Down Expand Up @@ -124,21 +136,26 @@ Base.axes(t::Tensor, d) = axes(parent(t), dim(t, d))
Base.strides(t::Tensor) = strides(parent(t))
Base.stride(t::Tensor, i::Symbol) = stride(parent(t), dim(t, i))

Base.unsafe_convert(::Type{Ptr{T}}, t::Tensor{T}) where {T} = Base.unsafe_convert(Ptr{T}, parent(t))
Base.unsafe_convert(::Type{Ptr{T}}, t::Tensor{T}) where {T} =
Base.unsafe_convert(Ptr{T}, parent(t))

Base.elsize(T::Type{<:Tensor}) = elsize(parenttype(T))

# Broadcasting
Base.BroadcastStyle(::Type{T}) where {T<:Tensor} = ArrayStyle{T}()

function Base.similar(bc::Broadcasted{ArrayStyle{Tensor{T,N,A}}}, ::Type{ElType}) where {T,N,A,ElType}
function Base.similar(
bc::Broadcasted{ArrayStyle{Tensor{T,N,A}}},
::Type{ElType},
) where {T,N,A,ElType}
# NOTE already checked if dimension mismatch
# TODO throw on label mismatch?
tensor = first(arg for arg in bc.args if arg isa Tensor{T,N,A})
similar(tensor, ElType)
end

Base.selectdim(t::Tensor, d::Integer, i) = Tensor(selectdim(parent(t), d, i), labels(t); t.meta...)
Base.selectdim(t::Tensor, d::Integer, i) =
Tensor(selectdim(parent(t), d, i), labels(t); t.meta...)
function Base.selectdim(t::Tensor, d::Integer, i::Integer)
data = selectdim(parent(t), d, i)
indices = [label for (i, label) in enumerate(labels(t)) if i != d]
Expand All @@ -147,16 +164,21 @@ end

Base.selectdim(t::Tensor, d::Symbol, i) = selectdim(t, dim(t, d), i)

Base.permutedims(t::Tensor, perm) = Tensor(permutedims(parent(t), perm), getindex.((labels(t),), perm); t.meta...)
Base.permutedims!(dest::Tensor, src::Tensor, perm) = permutedims!(parent(dest), parent(src), perm)
Base.permutedims(t::Tensor, perm) =
Tensor(permutedims(parent(t), perm), getindex.((labels(t),), perm); t.meta...)
Base.permutedims!(dest::Tensor, src::Tensor, perm) =
permutedims!(parent(dest), parent(src), perm)

function Base.permutedims(t::Tensor{T,N}, perm::NTuple{N,Symbol}) where {T,N}
perm = map(i -> findfirst(==(i), labels(t)), perm)
permutedims(t, perm)
end

Base.view(t::Tensor, i...) =
Tensor(view(parent(t), i...), [label for (label, j) in zip(labels(t), i) if !(j isa Integer)]; t.meta...)
Base.view(t::Tensor, i...) = Tensor(
view(parent(t), i...),
[label for (label, j) in zip(labels(t), i) if !(j isa Integer)];
t.meta...,
)

function Base.view(t::Tensor, inds::Pair{Symbol,<:Any}...)
indices = map(labels(t)) do ind
Expand All @@ -165,7 +187,8 @@ function Base.view(t::Tensor, inds::Pair{Symbol,<:Any}...)
end

let data = view(parent(t), indices...),
labels = [label for (index, label) in zip(indices, labels(t)) if !(index isa Integer)]
labels =
[label for (index, label) in zip(indices, labels(t)) if !(index isa Integer)]

Tensor(data, labels; t.meta...)
end
Expand All @@ -174,8 +197,6 @@ end
Base.adjoint(t::Tensor) = Tensor(conj(parent(t)), labels(t); t.meta...)

# NOTE: Maybe use transpose for lazy transposition ?
Base.transpose(t::Tensor{T,1,A}) where {T,A<:AbstractArray{T,1}} =
permutedims(t, (1,))
Base.transpose(t::Tensor{T,1,A}) where {T,A<:AbstractArray{T,1}} = permutedims(t, (1,))

Base.transpose(t::Tensor{T,2,A}) where {T,A<:AbstractArray{T,2}} =
permutedims(t, (2, 1))
Base.transpose(t::Tensor{T,2,A}) where {T,A<:AbstractArray{T,2}} = permutedims(t, (2, 1))
8 changes: 6 additions & 2 deletions src/Tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ end

@static if !isdefined(Base, :get_extension)
function __init__()
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("../ext/TensorsChainRulesCoreExt.jl")
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/TensorsFiniteDifferencesExt.jl")
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/TensorsChainRulesCoreExt.jl",
)
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include(
"../ext/TensorsFiniteDifferencesExt.jl",
)
end
end

Expand Down
5 changes: 3 additions & 2 deletions test/Metadata_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testset "Metadata" begin
@testset "tags" begin
tensor = Tensor(zeros(2, 2, 2), (:i, :j, :k), tags=Set{String}(["TAG_A", "TAG_B"]))
tensor =
Tensor(zeros(2, 2, 2), (:i, :j, :k), tags = Set{String}(["TAG_A", "TAG_B"]))

@test issetequal(tags(tensor), ["TAG_A", "TAG_B"])

Expand All @@ -12,4 +13,4 @@

@test untag!(tensor, "TAG_UNEXISTANT") == tags(tensor)
end
end
end
14 changes: 7 additions & 7 deletions test/Numerics_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
@testset "axis sum" begin
A = Tensor(rand(2, 3, 4), (:i, :j, :k))

C = contract(A, dims=(:i,))
C = contract(A, dims = (:i,))
C_ein = ein"ijk -> jk"(A)
@test labels(C) == (:j, :k)
@test size(C) == size(C_ein) == (3, 4)
Expand All @@ -62,7 +62,7 @@
@testset "diagonal" begin
A = Tensor(rand(2, 3, 2), (:i, :j, :i))

C = contract(A, dims=())
C = contract(A, dims = ())
C_ein = ein"iji -> ij"(A)
@test labels(C) == (:i, :j)
@test size(C) == size(C_ein) == (2, 3)
Expand All @@ -72,7 +72,7 @@
@testset "trace" begin
A = Tensor(rand(2, 3, 2), (:i, :j, :i))

C = contract(A, dims=(:i,))
C = contract(A, dims = (:i,))
C_ein = ein"iji -> j"(A)
@test labels(C) == (:j,)
@test size(C) == size(C_ein) == (3,)
Expand Down Expand Up @@ -132,14 +132,14 @@
B = Tensor(rand(4, 5, 3), (:k, :l, :j))

# Contraction of all common indices
C = contract(A, B, dims=(:j, :k))
C = contract(A, B, dims = (:j, :k))
C_ein = ein"ijk, klj -> il"(A, B)
@test labels(C) == (:i, :l)
@test size(C) == (2, 5) == size(C_ein)
@test C C_ein

# Contraction of not all common indices
C = contract(A, B, dims=(:j,))
C = contract(A, B, dims = (:j,))
C_ein = ein"ijk, klj -> ikl"(A, B)
@test labels(C) == (:i, :k, :l)
@test size(C) == (2, 4, 5) == size(C_ein)
Expand All @@ -149,7 +149,7 @@
A = Tensor(rand(Complex{Float64}, 2, 3, 4), (:i, :j, :k))
B = Tensor(rand(Complex{Float64}, 4, 5, 3), (:k, :l, :j))

C = contract(A, B, dims=(:j, :k))
C = contract(A, B, dims = (:j, :k))
C_ein = ein"ijk, klj -> il"(A, B)
@test labels(C) == (:i, :l)
@test size(C) == (2, 5) == size(C_ein)
Expand All @@ -168,7 +168,7 @@
# Throw exception if left_inds ∉ labels(tensor)
@test_throws ErrorException qr(tensor, (:l,))
# throw exception if no right-inds
@test_throws ErrorException qr(tensor, (:i,:j,:k))
@test_throws ErrorException qr(tensor, (:i, :j, :k))
end

@testset "labels" begin
Expand Down
20 changes: 12 additions & 8 deletions test/Tensor_test.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@testset "Tensor" begin
@testset "Constructors" begin
@testset "Number" begin
tensor = Tensor(1.0, tags=Set(["TEST"]))
tensor = Tensor(1.0, tags = Set(["TEST"]))
@test labels(tensor) == ()
@test parent(tensor) == fill(1.0)
@test hastag(tensor, "TEST")
Expand Down Expand Up @@ -53,7 +53,11 @@
@test parent(replace(tensor, :a => :u, :b => :v, :c => :w)) === parent(tensor)

# :alias in meta
tensor = Tensor(zeros(2, 2, 2), (:i, :j, :k); alias=Dict(:left => :i, :right => :j, :up => :k))
tensor = Tensor(
zeros(2, 2, 2),
(:i, :j, :k);
alias = Dict(:left => :i, :right => :j, :up => :k),
)

replaced_tensor = replace(tensor, :i => :u, :j => :v, :k => :w)
@test labels(replaced_tensor) == (:u, :v, :w)
Expand Down Expand Up @@ -134,7 +138,7 @@
@test first(tensor) == first(data)
@test last(tensor) == last(data)
@test tensor[1, :, 2] == data[1, :, 2]
@test tensor[i=1, k=2] == data[1, :, 2]
@test tensor[i = 1, k = 2] == data[1, :, 2]

tensor[1] = 0
@test tensor[1] == data[1]
Expand All @@ -153,7 +157,7 @@
@testset "adjoint" begin
@testset "Vector" begin
data = rand(Complex{Float64}, 2)
tensor = Tensor(data, (:i,); test="TEST")
tensor = Tensor(data, (:i,); test = "TEST")

@test adjoint(tensor) |> labels == labels(tensor)
@test adjoint(tensor) |> ndims == 1
Expand All @@ -166,7 +170,7 @@
using LinearAlgebra: tr

data = rand(Complex{Float64}, 2, 2)
tensor = Tensor(data, (:i, :j); test="TEST")
tensor = Tensor(data, (:i, :j); test = "TEST")

@test adjoint(tensor) |> labels == labels(tensor)
@test adjoint(tensor) |> ndims == 2
Expand All @@ -179,7 +183,7 @@
@testset "transpose" begin
@testset "Vector" begin
data = rand(Complex{Float64}, 2)
tensor = Tensor(data, (:i,); test="TEST")
tensor = Tensor(data, (:i,); test = "TEST")

@test transpose(tensor) |> labels == labels(tensor)
@test transpose(tensor) |> ndims == 1
Expand All @@ -192,7 +196,7 @@
using LinearAlgebra: tr

data = rand(Complex{Float64}, 2, 2)
tensor = Tensor(data, (:i, :j); test="TEST")
tensor = Tensor(data, (:i, :j); test = "TEST")

@test transpose(tensor) |> labels == (:j, :i)
@test transpose(tensor) |> ndims == 2
Expand All @@ -201,4 +205,4 @@
@test isapprox(only(transpose(tensor) * tensor), tr(transpose(data) * data))
end
end
end
end
Loading

0 comments on commit e772d4c

Please sign in to comment.