# https://github.com/JuliaLang/julia
# Raw File
# Tip revision: 1e7db5c9e16f989878a742a717679ad990e6f8e8 authored by Julia Windows Test Machine on 18 January 2014, 02:28:22 UTC
# Fix a few typos on windows
# Tip revision: 1e7db5c
# tridiag.jl
#### Specialized matrix types ####

import Base.conj, Base.transpose, Base.ctranspose, Base.convert

## Symmetric tridiagonal matrices
type SymTridiagonal{T} <: AbstractMatrix{T}
    dv::Vector{T}    # main diagonal, length n
    ev::Vector{T}    # off-diagonal, length n-1 (serves as both sub- and super-diagonal)
    # Validates only the band lengths; the vectors are stored as given, without copying.
    function SymTridiagonal(dv::Vector{T}, ev::Vector{T})
        if length(ev) != length(dv) - 1
            throw(DimensionMismatch(""))
        end
        new(dv, ev)
    end
end

# Same element type: copy the inputs so the new matrix never aliases the caller's vectors.
SymTridiagonal{T}(dv::Vector{T}, ev::Vector{T}) = SymTridiagonal{T}(copy(dv), copy(ev))

# Mixed element types: convert both bands to the promoted element type first.
# promote_type (not promote) is required here: promote(Td,Te) returns a tuple of
# values, which cannot be used as a type parameter in Vector{T}.
function SymTridiagonal{Td,Te}(dv::Vector{Td}, ev::Vector{Te})
    T = promote_type(Td, Te)
    SymTridiagonal(convert(Vector{T}, dv), convert(Vector{T}, ev))
end

# Build a SymTridiagonal from the main diagonal and first off-diagonal of A.
# Entries outside the three central diagonals are not inspected.
# Throws DimensionMismatch when the sub- and super-diagonals differ.
function SymTridiagonal(A::AbstractMatrix)
    sup = diag(A, 1)
    sup == diag(A, -1) || throw(DimensionMismatch("matrix is not symmetric; cannot convert to SymTridiagonal"))
    SymTridiagonal(diag(A), sup)
end
full{T}(M::SymTridiagonal{T}) = convert(Matrix{T}, M)
# Dense copy. SymTridiagonal is symmetric (transpose(M) == M; solvers and inv/det use
# ev for both off-diagonals), so the super-diagonal mirrors ev without conjugation.
convert{T}(::Type{Matrix{T}}, M::SymTridiagonal{T}) = diagm(M.dv) + diagm(M.ev,-1) + diagm(M.ev,1)

function show(io::IO, S::SymTridiagonal)
    # Summary header, then each band printed as a row vector after its label.
    println(io, summary(S), ":")
    for (label, band) in (("diag: ", S.dv), ("\n sub: ", S.ev))
        print(io, label)
        print_matrix(io, band')
    end
end

# The matrix is square with side length(dv).
function size(m::SymTridiagonal)
    n = length(m.dv)
    return (n, n)
end
# Trailing dimensions (d > 2) have extent 1, as for any matrix.
function size(m::SymTridiagonal, d::Integer)
    d >= 1 || error("dimension out of range")
    return d > 2 ? 1 : length(m.dv)
end

#Elementary operations
# Each operation is applied band by band and the matrix is rebuilt from the results.
function copy(S::SymTridiagonal)
    dv, ev = copy(S.dv), copy(S.ev)
    return SymTridiagonal(dv, ev)
end
function round(M::SymTridiagonal)
    dv, ev = round(M.dv), round(M.ev)
    return SymTridiagonal(dv, ev)
end
function iround(M::SymTridiagonal)
    dv, ev = iround(M.dv), iround(M.ev)
    return SymTridiagonal(dv, ev)
end

function conj(M::SymTridiagonal)
    return SymTridiagonal(conj(M.dv), conj(M.ev))
end
# Symmetric: transposing returns the very same object.
transpose(M::SymTridiagonal) = M
# Since transpose is the identity here, the conjugate transpose is just conj.
ctranspose(M::SymTridiagonal) = conj(transpose(M))

# Return the n-th diagonal. The stored bands are returned directly (not copied);
# any other in-range diagonal is a fresh zero vector. Out-of-range n throws BoundsError.
function diag{T}(M::SymTridiagonal{T}, n::Integer=0)
    k = abs(n)
    k == 0 && return M.dv
    k == 1 && return M.ev
    m = size(M, 1)
    k < m || throw(BoundsError())
    return zeros(T, m - k)
end

# Band-wise arithmetic between SymTridiagonal matrices and with scalars.
function +(A::SymTridiagonal, B::SymTridiagonal)
    return SymTridiagonal(A.dv + B.dv, A.ev + B.ev)
end
function -(A::SymTridiagonal, B::SymTridiagonal)
    return SymTridiagonal(A.dv - B.dv, A.ev - B.ev)
end
function *(A::SymTridiagonal, B::Number)
    return SymTridiagonal(A.dv*B, A.ev*B)
end
# Scalar multiplication commutes: reuse the matrix-times-scalar method.
*(B::Number, A::SymTridiagonal) = A*B
function /(A::SymTridiagonal, B::Number)
    return SymTridiagonal(A.dv/B, A.ev/B)
end
# Equal exactly when both bands are equal.
function ==(A::SymTridiagonal, B::SymTridiagonal)
    A.dv == B.dv || return false
    return A.ev == B.ev
end

## Solver
# Solve M*X = rhs. A right-hand side with unit stride along dimension 1 goes to
# LAPACK gtsv, with ev passed as both off-diagonals. Copies are passed because
# gtsv! mutates its arguments, so M and rhs are left untouched.
function \{T<:BlasFloat}(M::SymTridiagonal{T}, rhs::StridedVecOrMat{T})
    if stride(rhs, 1) == 1
        return LAPACK.gtsv!(copy(M.ev), copy(M.dv), copy(M.ev), copy(rhs))
    end
    # Non-unit stride: use the pure-Julia Tridiagonal solver. Build the Tridiagonal
    # from the bands explicitly; this file defines no Tridiagonal(::SymTridiagonal)
    # constructor, so Tridiagonal(M) raised a MethodError.
    solve(Tridiagonal(M.ev, M.dv, M.ev), rhs)
end

#Wrap LAPACK ?STEGR and ?STEV to compute eigenvalues (and, for eig, eigenvectors)
# All LAPACK calls below receive copies of the bands, because they are mutating (!) wrappers.
# Full eigendecomposition ('V' = eigenvalues and eigenvectors).
function eig{T<:BlasFloat}(m::SymTridiagonal{T})
    return LAPACK.stegr!('V', copy(m.dv), copy(m.ev))
end
# Eigenvalues with indices il:iu (range mode 'I'; the value bounds are unused).
function eigvals{T<:BlasFloat}(m::SymTridiagonal{T}, il::Int, iu::Int)
    vals = LAPACK.stegr!('N', 'I', copy(m.dv), copy(m.ev), 0.0, 0.0, il, iu)[1]
    return vals
end
# Eigenvalues lying in the interval (vl, vu] (range mode 'V'; index bounds unused).
function eigvals{T<:BlasFloat}(m::SymTridiagonal{T}, vl::Real, vu::Real)
    vals = LAPACK.stegr!('N', 'V', copy(m.dv), copy(m.ev), vl, vu, 0, 0)[1]
    return vals
end
# All eigenvalues via stev.
function eigvals{T<:BlasFloat}(m::SymTridiagonal{T})
    vals = LAPACK.stev!('N', copy(m.dv), copy(m.ev))[1]
    return vals
end

#Computes largest and smallest eigenvalue
# Largest eigenvalue: request only the eigenvalue with index n.
function eigmax(m::SymTridiagonal)
    n = size(m, 1)
    return eigvals(m, n, n)[1]
end
# Smallest eigenvalue: request only the eigenvalue with index 1.
function eigmin(m::SymTridiagonal)
    return eigvals(m, 1, 1)[1]
end

#Compute selected eigenvectors only corresponding to particular eigenvalues
# All eigenvectors: the second component of the full decomposition.
function eigvecs(m::SymTridiagonal)
    return eig(m)[2]
end
# Eigenvectors for the given (already known) eigenvalues, via LAPACK stein.
function eigvecs{T<:BlasFloat,R<:Real}(m::SymTridiagonal{T}, vals::Vector{R})
    return LAPACK.stein!(m.dv, m.ev, vals)
end

###################
# Generic methods #
###################

#Needed for inv_usmani()
# Thin wrapper that indexes a Vector from 0: a[i] reads/writes a.data[i+1].
# Parametrized on the element type so element access is type-stable, unlike an
# abstract `data::Vector` field.
type ZeroOffsetVector{T}
    data::Vector{T}
end
getindex(a::ZeroOffsetVector, i) = a.data[i+1]
setindex!(a::ZeroOffsetVector, x, i) = a.data[i+1]=x

#Implements the inverse using the recurrence relation between principal minors
# a, b, c are assumed to be the subdiagonal, diagonal, and superdiagonal of
# a tridiagonal matrix.
#Ref:
#    R. Usmani, "Inversion of a tridiagonal Jacobi matrix",
#    Linear Algebra and its Applications 212-213 (1994), pp.413-414
#    doi:10.1016/0024-3795(94)90414-6
# Dense inverse of the tridiagonal matrix with subdiagonal a, diagonal b and
# superdiagonal c (Usmani 1994). Returns an n×n Array whose element type is the
# quotient type of T, e.g. Float64 for Int input. Entries are divided by θ[n],
# the determinant, so a singular matrix yields Inf/NaN entries rather than an error.
function inv_usmani{T}(a::Vector{T}, b::Vector{T}, c::Vector{T})
    n = length(b)
    # θ[i]: leading principal minor of order i, with θ[0] = 1 (hence the 0-based wrapper).
    θ = ZeroOffsetVector(zeros(T, n+1))
    θ[0] = 1
    n>=1 && (θ[1] = b[1])
    for i=2:n
        θ[i] = b[i]*θ[i-1]-a[i-1]*c[i-1]*θ[i-2]
    end
    # φ[i]: trailing principal minor covering rows/columns i..n, with φ[n+1] = 1.
    φ = zeros(T, n+1)
    φ[n+1] = 1
    n>=1 && (φ[n] = b[n])
    for i=n-1:-1:1
        φ[i] = b[i]*φ[i+1]-a[i]*c[i]*φ[i+2]
    end
    # The entries are quotients by θ[n], so the output must hold the quotient type;
    # storing into an Array{T} raised InexactError for integer T.
    S = typeof(one(T)/one(T))
    α = Array(S, n, n)
    for i=1:n, j=1:n
        sgn = (i+j)%2==0 ? (+) : (-)   # (-1)^(i+j)
        if i<j
            α[i,j]=(sgn)(prod(c[i:j-1]))*θ[i-1]*φ[j+1]/θ[n]
        elseif i==j
            α[i,i]=                      θ[i-1]*φ[i+1]/θ[n]
        else #i>j
            α[i,j]=(sgn)(prod(a[j:i-1]))*θ[j-1]*φ[i+1]/θ[n]
        end
    end
    α
end

#Implements the determinant using principal minors
#Inputs and reference are as above for inv_usmani()
# Determinant of the tridiagonal matrix (sub a, diag b, super c) via the
# three-term recurrence θ_i = b_i θ_{i-1} - a_{i-1} c_{i-1} θ_{i-2}.
# An empty matrix has determinant one(T).
function det_usmani{T}(a::Vector{T}, b::Vector{T}, c::Vector{T})
    n = length(b)
    if n == 0
        return one(T)
    end
    older = one(T)   # θ_{i-2}
    newer = b[1]     # θ_{i-1}
    for i = 2:n
        θi = b[i]*newer - a[i-1]*c[i-1]*older
        older = newer
        newer = θi
    end
    return newer
end

# Inverse and determinant via the Usmani recurrences; the symmetric off-diagonal
# ev is passed as both the sub- and super-diagonal.
function inv(A::SymTridiagonal)
    off = A.ev
    return inv_usmani(off, A.dv, off)
end
function det(A::SymTridiagonal)
    off = A.ev
    return det_usmani(off, A.dv, off)
end

## Tridiagonal matrices ##
type Tridiagonal{T} <: AbstractMatrix{T}
    dl::Vector{T}     # sub-diagonal, length N-1
    d::Vector{T}      # diagonal, length N
    du::Vector{T}     # super-diagonal, length N-1
    dutmp::Vector{T}  # scratch for the vector solver: modified super-diagonal
    rhstmp::Vector{T} # scratch for the vector solver: modified right-hand side

    # Uninitialized N×N matrix: every band and scratch vector is freshly allocated
    # with unspecified contents.
    function Tridiagonal(N::Integer)
        offdiag = Array(T, N-1)
        ondiag = Array(T, N)
        new(offdiag, ondiag, similar(offdiag), similar(offdiag), similar(ondiag))
    end

    # Build from the three bands (copied, never aliased) after checking their lengths.
    function Tridiagonal(dl::Vector{T}, d::Vector{T}, du::Vector{T})
        N = length(d)
        lengths_ok = length(dl) == N-1 && length(du) == N-1
        if !lengths_ok
            error(string("Cannot make Tridiagonal from incompatible lengths of subdiagonal, diagonal and superdiagonal: (", length(dl), ", ", length(d), ", ", length(du),")"))
        end
        new(copy(dl), copy(d), copy(du), Array(T,N-1), Array(T,N))
    end
end

# Same element type: forward to the inner constructor (which copies the bands).
Tridiagonal{T}(dl::Vector{T}, d::Vector{T}, du::Vector{T}) = Tridiagonal{T}(dl, d, du)

# Mixed element types: convert all three bands to the promoted element type.
# promote_type (not promote) yields a type usable as the parameter of Vector{R}.
function Tridiagonal{Tl, Td, Tu}(dl::Vector{Tl}, d::Vector{Td}, du::Vector{Tu})
    R = promote_type(Tl, Td, Tu)
    Tridiagonal(convert(Vector{R}, dl), convert(Vector{R}, d), convert(Vector{R}, du))
end

# Square matrix whose side is the length of the main diagonal.
function size(M::Tridiagonal)
    n = length(M.d)
    return (n, n)
end
# Trailing dimensions (d > 2) have extent 1.
function size(M::Tridiagonal, d::Integer)
    d >= 1 || error("dimension out of range")
    return d > 2 ? 1 : length(M.d)
end

function show(io::IO, M::Tridiagonal)
    # Summary header, then each band printed as a row vector after its label.
    println(io, summary(M), ":")
    bands = ((" sub: ", M.dl), ("\ndiag: ", M.d), ("\n sup: ", M.du))
    for (label, band) in bands
        print(io, label)
        print_matrix(io, band')
    end
end
full{T}(M::Tridiagonal{T}) = convert(Matrix{T}, M)
# Dense copy: zeros everywhere except the three bands.
function convert{T}(::Type{Matrix{T}}, M::Tridiagonal{T})
    n = length(M.d)
    A = zeros(T, size(M))
    for i = 1:n
        A[i,i] = M.d[i]
        if i < n
            A[i+1,i] = M.dl[i]
            A[i,i+1] = M.du[i]
        end
    end
    return A
end
# An uninitialized Tridiagonal with element type T; only square 2-D shapes are allowed.
function similar(M::Tridiagonal, T, dims::Dims)
    square = length(dims) == 2 && dims[1] == dims[2]
    square || throw(DimensionMismatch("Tridiagonal matrices must be square"))
    return Tridiagonal{T}(dims[1])
end

# Operations on Tridiagonal matrices
# Copy the three bands of src into dest in place and return dest, per the usual
# copy! contract. Passing the bands to the Tridiagonal constructor would copy them
# into a new object instead of returning dest. The scratch vectors are not copied.
function copy!(dest::Tridiagonal, src::Tridiagonal)
    size(dest) == size(src) || throw(DimensionMismatch("dest and src must have the same size"))
    copy!(dest.dl, src.dl)
    copy!(dest.d, src.d)
    copy!(dest.du, src.du)
    dest
end
# copy(A::Tridiagonal) = Tridiagonal(copy(A.dl), copy(A.d), copy(A.du))
# Band-wise elementary operations; each returns a new Tridiagonal.
function round(M::Tridiagonal)
    return Tridiagonal(round(M.dl), round(M.d), round(M.du))
end
function iround(M::Tridiagonal)
    return Tridiagonal(iround(M.dl), iround(M.d), iround(M.du))
end
function conj(M::Tridiagonal)
    return Tridiagonal(conj(M.dl), conj(M.d), conj(M.du))
end
# Transposition swaps the sub- and super-diagonals.
function transpose(M::Tridiagonal)
    return Tridiagonal(M.du, M.d, M.dl)
end
ctranspose(M::Tridiagonal) = conj(transpose(M))

# Return the n-th diagonal of M: 0 = main, -1 = sub, 1 = super (the stored band
# vectors, not copies). Any other in-range diagonal is a fresh zero vector of
# length size(M,1)-|n|. Throws BoundsError when |n| >= size(M,1).
function diag{T}(M::Tridiagonal{T}, n::Integer=0)
    if n==0
        return M.d
    elseif n==-1
        return M.dl
    elseif n==1
        return M.du
    elseif -size(M,1)<n<size(M,1)   # was `-size(M,1)n<size(M,1)`: missing `<`
        return zeros(T,size(M,1)-abs(n))
    else
        throw(BoundsError())
    end
end

###################
# Generic methods #
###################

# Band-wise arithmetic between Tridiagonal matrices and with scalars.
+(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl+B.dl, A.d+B.d, A.du+B.du)
# Subtraction must subtract every band, including the super-diagonal.
-(A::Tridiagonal, B::Tridiagonal) = Tridiagonal(A.dl-B.dl, A.d-B.d, A.du-B.du)
*(A::Tridiagonal, B::Number) = Tridiagonal(A.dl*B, A.d*B, A.du*B)
# Scalar-times-Tridiagonal (this previously redefined the SymTridiagonal method instead).
*(B::Number, A::Tridiagonal) = A*B
/(A::Tridiagonal, B::Number) = Tridiagonal(A.dl/B, A.d/B, A.du/B)

# Exact band-wise equality.
==(A::Tridiagonal, B::Tridiagonal) = (A.dl==B.dl) && (A.d==B.d) && (A.du==B.du)
# A Tridiagonal equals a SymTridiagonal when both off-diagonals equal ev.
==(A::Tridiagonal, B::SymTridiagonal) = (A.dl==A.du==B.ev) && (A.d==B.dv)
# Mirror of the method above. The previous (SymTridiagonal, SymTridiagonal) signature
# overrode the band-wise method and recursed forever.
==(A::SymTridiagonal, B::Tridiagonal) = B==A

# Inverse and determinant through the Usmani principal-minor recurrences.
function inv(A::Tridiagonal)
    return inv_usmani(A.dl, A.d, A.du)
end
function det(A::Tridiagonal)
    return det_usmani(A.dl, A.d, A.du)
end

# Elementary operations that mix Tridiagonal and SymTridiagonal matrices
# The symmetric operand's ev plays the role of both dl and du.
convert(::Type{Tridiagonal}, A::SymTridiagonal) = Tridiagonal(A.ev, A.dv, A.ev)
function +(A::Tridiagonal, B::SymTridiagonal)
    e = B.ev
    return Tridiagonal(A.dl + e, A.d + B.dv, A.du + e)
end
function +(A::SymTridiagonal, B::Tridiagonal)
    e = A.ev
    return Tridiagonal(e + B.dl, A.dv + B.d, e + B.du)
end
function -(A::Tridiagonal, B::SymTridiagonal)
    e = B.ev
    return Tridiagonal(A.dl - e, A.d - B.dv, A.du - e)
end
function -(A::SymTridiagonal, B::Tridiagonal)
    e = A.ev
    return Tridiagonal(e - B.dl, A.dv - B.d, e - B.du)
end

# SymTridiagonal -> Tridiagonal: ev becomes both off-diagonals. Built from the bands,
# because no Tridiagonal(::SymTridiagonal) constructor is defined in this file.
convert{T}(::Type{Tridiagonal{T}}, M::SymTridiagonal{T}) = Tridiagonal(M.ev, M.dv, M.ev)
# Tridiagonal -> SymTridiagonal{T}, allowed only when dl == du. SymTridiagonal
# takes (diagonal, off-diagonal); the bands are converted to T so the result
# really is a SymTridiagonal{T}.
function convert{T}(::Type{SymTridiagonal{T}}, M::Tridiagonal)
    M.dl==M.du || error("Tridiagonal is not symmetric, cannot convert to SymTridiagonal")
    SymTridiagonal(convert(Vector{T}, M.d), convert(Vector{T}, M.dl))
end

## Solvers

#### Tridiagonal matrix routines ####
# Solve M*X = rhs. LAPACK gtsv is used only when rhs has unit stride along its
# first dimension; it receives copies because gtsv! mutates its arguments.
# Otherwise fall back to the pure-Julia solver.
function \{T<:BlasFloat}(M::Tridiagonal{T}, rhs::StridedVecOrMat{T})
    stride(rhs, 1) == 1 || return solve(M, rhs)
    return LAPACK.gtsv!(copy(M.dl), copy(M.d), copy(M.du), copy(rhs))
end

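# Disabled: LAPACK ?STEV only handles symmetric tridiagonal matrices (and this call passes the whole matrix rather than its bands), so it cannot compute eig of a general Tridiagonal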
#eig(M::Tridiagonal) = LAPACK.stev!('V', copy(M))

# Allocation-free variants
# Note that solve is non-aliasing, so you can use the same array for
# input and output
# Solve M*x = rhs for one right-hand side by forward elimination and back
# substitution, without pivoting. xrng/rhsrng select the (possibly strided) linear
# indices of x and rhs to use; both must have length size(M,1), or DimensionMismatch
# is thrown. M.dutmp and M.rhstmp are overwritten as scratch space. rhs is read in
# full before x is written, so x and rhs may be the same array.
function solve!{T<:BlasFloat}(x::AbstractArray{T}, xrng::Ranges{Int}, M::Tridiagonal{T}, rhs::AbstractArray{T}, rhsrng::Ranges{Int})
    d = M.d
    N = length(d)
    if length(xrng) != N || length(rhsrng) != N
        throw(DimensionMismatch(""))
    end
    # The general sweeps below read du[1] and dl[N-1], which do not exist when N <= 1.
    if N == 0
        return nothing
    elseif N == 1
        x[first(xrng)] = rhs[first(rhsrng)] / d[1]
        return nothing
    end
    dl = M.dl
    du = M.du
    dutmp = M.dutmp
    rhstmp = M.rhstmp
    xstart = first(xrng)
    xstride = step(xrng)
    rhsstart = first(rhsrng)
    rhsstride = step(rhsrng)
    # Forward sweep: eliminate the sub-diagonal, saving the normalized
    # super-diagonal in dutmp and the transformed right-hand side in rhstmp.
    denom = d[1]
    dulast = du[1] / denom
    dutmp[1] = dulast
    rhslast = rhs[rhsstart] / denom
    rhstmp[1] = rhslast
    irhs = rhsstart+rhsstride
    for i in 2:N-1
        dltmp = dl[i-1]
        denom = d[i] - dltmp*dulast
        dulast = du[i] / denom
        dutmp[i] = dulast
        rhslast = (rhs[irhs] - dltmp*rhslast)/denom
        rhstmp[i] = rhslast
        irhs += rhsstride
    end
    dltmp = dl[N-1]
    denom = d[N] - dltmp*dulast
    xlast = (rhs[irhs] - dltmp*rhslast)/denom
    # Backward sweep: substitute upward from the last unknown.
    ix = xstart + (N-2)*xstride
    x[ix+xstride] = xlast
    for i in N-1:-1:1
        xlast = rhstmp[i] - dutmp[i]*xlast
        x[ix] = xlast
        ix -= xstride
    end
    nothing
end

# Solve into x using every element of x and rhs, then return x.
function solve!(x::StridedVector, M::Tridiagonal, rhs::StridedVector)
    xr = 1:length(x)
    rr = 1:length(rhs)
    solve!(x, xr, M, rhs, rr)
    return x
end
# Solve M*X = B with the pure-Julia solver. The ranged solve! requires x, M and rhs
# to share one element type, so both operands are brought to the quotient type S
# first. Mixed precisions (e.g. Float64 M with Float32 B) previously had no
# matching solve! method.
function solve{TM<:BlasFloat,TB<:BlasFloat}(M::Tridiagonal{TM}, B::StridedVecOrMat{TB})
    S = typeof(one(TM)/one(TB))
    MS = TM == S ? M : Tridiagonal(convert(Vector{S}, M.dl), convert(Vector{S}, M.d), convert(Vector{S}, M.du))
    # solve! reads each right-hand side completely before writing the solution, so
    # the converted copy of B can also serve as the output array.
    X = copy!(similar(B, S), B)
    solve!(X, MS, X)
end
# Non-BLAS element types (e.g. integers): convert to floating point and retry.
solve(M::Tridiagonal, B::StridedVecOrMat) = solve(float(M), float(B))
# Column-by-column solve of M*X = B. Each column is addressed as a contiguous run
# of linear indices; X must have the same size as B.
function solve!(X::StridedMatrix, M::Tridiagonal, B::StridedMatrix)
    nrow, ncol = size(B)
    nrow == size(M, 1) || throw(DimensionMismatch(""))
    size(X) == (nrow, ncol) || throw(DimensionMismatch(""))
    for col = 1:ncol
        r = Range1((col-1)*nrow + 1, nrow)
        solve!(X, r, M, B, r)
    end
    return X
end

# User-friendly solver
function \(M::Tridiagonal, rhs::StridedVecOrMat)
    return solve(M, rhs)
end

# Tridiagonal multiplication
# Compute x = M*v for one vector. xrng/vrng select the (possibly strided) linear
# indices of x and v; both must have length size(M,1), or DimensionMismatch is
# thrown. Returns x.
function mult(x::AbstractArray, xrng::Ranges{Int}, M::Tridiagonal, v::AbstractArray, vrng::Ranges{Int})
    dl = M.dl
    d = M.d
    du = M.du
    N = length(d)
    if length(xrng) != N || length(vrng) != N
        throw(DimensionMismatch(""))
    end
    xi = first(xrng)
    xstride = step(xrng)
    vi = first(vrng)
    vstride = step(vrng)
    # The general first/last rows below read du[1] and dl[N-1], which do not exist for N <= 1.
    if N == 0
        return x
    elseif N == 1
        x[xi] = d[1]*v[vi]
        return x
    end
    # First row: no sub-diagonal term.
    x[xi] = d[1]*v[vi] + du[1]*v[vi+vstride]
    xi += xstride
    # Interior rows: vi points at v[i-1] for row i.
    for i = 2:N-1
        x[xi] = dl[i-1]*v[vi] + d[i]*v[vi+vstride] + du[i]*v[vi+2*vstride]
        xi += xstride
        vi += vstride
    end
    # Last row: no super-diagonal term.
    x[xi] = dl[N-1]*v[vi] + d[N]*v[vi+vstride]
    x
end

# x = M*v using every element of x and v.
function mult(x::StridedVector, M::Tridiagonal, v::StridedVector)
    xr = 1:length(x)
    vr = 1:length(v)
    return mult(x, xr, M, v, vr)
end

# X = M*B, one column at a time, each column addressed as a run of linear indices.
function mult(X::StridedMatrix, M::Tridiagonal, B::StridedMatrix)
    nrow, ncol = size(B)
    nrow == size(M, 1) || throw(DimensionMismatch(""))
    size(X) == (nrow, ncol) || throw(DimensionMismatch(""))
    for col = 1:ncol
        r = Range1((col-1)*nrow + 1, nrow)
        mult(X, r, M, B, r)
    end
    return X
end

mult(X::StridedMatrix, M1::Tridiagonal, M2::Tridiagonal) = mult(X, M1, full(M2))

# M*B: allocate the result with the promoted element type. similar(B) alone kept B's
# element type, so e.g. a Float64 Tridiagonal times an Int vector was truncated
# or raised InexactError.
*(M::Tridiagonal, B::Union(StridedVector,StridedMatrix)) = mult(similar(B, promote_type(eltype(M), eltype(B))), M, B)
*(A::Tridiagonal, B::Tridiagonal) = A*full(B)

#### Factorizations for Tridiagonal ####
# LDLᵀ factorization of a SymTridiagonal, holding the D and E vectors returned by
# LAPACK.pttrf! (see ldltd!). D's element type S must be the real counterpart of
# E's element type T.
type LDLTTridiagonal{T<:BlasFloat,S<:BlasFloat} <: Factorization{T}
    D::Vector{S}
    E::Vector{T}
    function LDLTTridiagonal(D::Vector{S}, E::Vector{T})
        # Check against zero(T) rather than E[1], so an empty E (1×1 matrix) is accepted.
        typeof(real(zero(T))) == S || error("element types do not match")
        new(D, E)
    end
end

# Infer the type parameters from the arguments. Note the {T,S} order: E's element
# type is the first parameter.
function LDLTTridiagonal{S<:BlasFloat,T<:BlasFloat}(D::Vector{S}, E::Vector{T})
    return LDLTTridiagonal{T,S}(D, E)
end

# In-place LDLᵀ factorization via LAPACK pttrf (A's bands may be overwritten).
function ldltd!{T<:BlasFloat}(A::SymTridiagonal{T})
    factors = LAPACK.pttrf!(real(A.dv), A.ev)
    return LDLTTridiagonal(factors...)
end
# Integer matrices are converted to floating point first.
function ldltd!{T<:Integer}(A::SymTridiagonal{T})
    Af = SymTridiagonal(float(A.dv), float(A.ev))
    return ldltd!(Af)
end
# Non-mutating variant: factor a copy.
function ldltd(A::SymTridiagonal)
    Acopy = copy(A)
    return ldltd!(Acopy)
end
factorize!(A::SymTridiagonal) = ldltd(A)

# Solve with an LDLᵀ factorization, overwriting B via LAPACK pttrs.
A_ldiv_B!{T<:BlasReal}(C::LDLTTridiagonal{T}, B::StridedVecOrMat{T}) = LAPACK.pttrs!(C.D, C.E, B)
# The complex variant passes 'L' to pttrs!, selecting which triangle the stored factors describe.
A_ldiv_B!{T<:BlasComplex}(C::LDLTTridiagonal{T}, B::StridedVecOrMat{T}) = LAPACK.pttrs!('L', C.D, C.E, B)
# Right-hand side of another element type: solve on a copy converted to the
# factorization's element type T (B itself is not modified). float(B) was not
# enough: for an already floating-point B of a different precision it returned the
# same element type and recursed forever.
A_ldiv_B!{T}(C::LDLTTridiagonal{T}, B::StridedVecOrMat) = A_ldiv_B!(C, copy!(similar(B, T), B))

# LU factorization of a Tridiagonal: the factor bands and pivot indices returned by
# LAPACK.gttrf! (built in lufact!, consumed by det and A_ldiv_B! via gttrs!).
type LUTridiagonal{T} <: Factorization{T}
    dl::Vector{T}          # from gttrf!: sub-diagonal band
    d::Vector{T}           # from gttrf!: diagonal band (its product gives |det|, see det)
    du::Vector{T}          # from gttrf!: first super-diagonal band
    du2::Vector{T}         # from gttrf!: second super-diagonal created by pivoting
    ipiv::Vector{BlasInt}  # pivot indices; ipiv[i] != i marks a row interchange
    # Disabled length-validating inner constructor (expected lengths with n = length(d):
    # dl, du: n-1; du2: n-2; ipiv: n). NOTE(review): the throw line below is missing a
    # closing parenthesis and would need fixing before this is re-enabled.
    # function LUTridiagonal(dl::Vector{T}, d::Vector{T}, du::Vector{T},
    #                        du2::Vector{T}, ipiv::Vector{BlasInt})
    #     n = length(d)
    #     if length(dl) != n - 1 || length(du) != n - 1 || length(ipiv) != n || length(du2) != n-2
    #         throw(DimensionMismatch("LUTridiagonal")
    #     end
    #     new(dl, d, du, du2, ipiv)
    # end
end
# In-place LU factorization via LAPACK gttrf (A's bands may be overwritten).
function lufact!{T<:BlasFloat}(A::Tridiagonal{T})
    factors = LAPACK.gttrf!(A.dl, A.d, A.du)
    return LUTridiagonal{T}(factors...)
end
# Rational/integer matrices are converted to floating point first.
lufact!{T<:Union(Rational,Integer)}(A::Tridiagonal{T}) = lufact!(float(A))
# Non-mutating variant: factor a copy.
function lufact(A::Tridiagonal)
    Acopy = copy(A)
    return lufact!(Acopy)
end
factorize!(A::Tridiagonal) = lufact!(A)
#show(io, lu::LUTridiagonal) = print(io, "LU decomposition of ", summary(lu.lu))

# Determinant from an LU factorization: the product of the U diagonal, negated
# when the pivoting performed an odd number of row interchanges.
function det{T}(lu::LUTridiagonal{T})
    n = length(lu.d)
    swaps = 0
    for i = 1:n
        if lu.ipiv[i] != i
            swaps += 1
        end
    end
    s = swaps % 2 == 1 ? -one(T) : one(T)
    return prod(lu.d) * s
end

# For BLAS element types, use the pivoted LU factorization rather than the Usmani recurrence.
function det{T<:BlasFloat}(A::Tridiagonal{T})
    F = lufact(A)
    return det(F)
end

# Solve with an LU factorization, overwriting B via LAPACK gttrs ('N' = no transpose).
A_ldiv_B!{T<:BlasFloat}(lu::LUTridiagonal{T}, B::StridedVecOrMat{T}) =
    LAPACK.gttrs!('N', lu.dl, lu.d, lu.du, lu.du2, lu.ipiv, B)
# Right-hand side of another element type: solve on a copy converted to T (B itself
# is not modified). float(B) recursed forever when B was already floating point
# with a different precision.
A_ldiv_B!{T<:BlasFloat}(lu::LUTridiagonal{T}, B::StridedVecOrMat) = A_ldiv_B!(lu, copy!(similar(B, T), B))
# back to top