Revision 5c7875626d22f737161a44a4fba6c0a00a62c698 authored by Stefan Karpinski on 03 July 2013, 03:39:57 UTC, committed by Stefan Karpinski on 03 July 2013, 03:39:57 UTC
2 parent s 3126fe2 + ec89a87
Raw File
blas.jl
typealias BlasFloat Union(Float64,Float32,Complex128,Complex64)
typealias BlasChar Char

if USE_LIB64
    typealias BlasInt Int64
    blas_int(x) = int64(x)
else
    typealias BlasInt Int32
    blas_int(x) = int32(x)
end

module BLAS

export copy!,
       scal!,
       scal,
       dot,
       nrm2,
       axpy!,
       syrk!,
       syrk,
       herk!,
       herk,
       gbmv!,
       gbmv,
       sbmv!,
       sbmv,
       gemm!,
       gemm,
       symm!,
       symm,
       symv!,
       symv

const libblas = Base.libblas_name

import Base.BlasFloat
import Base.BlasChar
import Base.BlasInt
import Base.blas_int

# SUBROUTINE DCOPY(N,DX,INCX,DY,INCY)
for (fname, elty) in ((:dcopy_,:Float64), (:scopy_,:Float32),
                      (:zcopy_,:Complex128), (:ccopy_,:Complex64))
    @eval begin
        function copy!(n::Integer, DX::Union(Ptr{$elty},Array{$elty}), incx::Integer, DY::Union(Ptr{$elty},Array{$elty}), incy::Integer)
            ccall(($(string(fname)),libblas), Void,
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, DX, &incx, DY, &incy)
            DY
        end
    end
end

# SUBROUTINE DSCAL(N,DA,DX,INCX)
for (fname, elty) in ((:dscal_,:Float64),    (:sscal_,:Float32),
                      (:zscal_,:Complex128), (:cscal_,:Complex64))
    @eval begin
        function scal!(n::Integer, DA::$elty, DX::Union(Ptr{$elty},Array{$elty}), incx::Integer)
            ccall(($(string(fname)),libblas), Void,
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, &DA, DX, &incx)
            DX
        end
        function scal(n::Integer, DA::$elty, DX_orig::Union(Ptr{$elty},Array{$elty}), incx::Integer)
            DX = copy(DX_orig)
            ccall(($(string(fname)),libblas), Void,
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, &DA, DX, &incx)
            DX
        end
    end
end

# dot
for (fname, elty) in ((:ddot_,:Float64), (:sdot_,:Float32))
    @eval begin
#       DOUBLE PRECISION FUNCTION DDOT(N,DX,INCX,DY,INCY)
# *     .. Scalar Arguments ..
#       INTEGER INCX,INCY,N
# *     ..
# *     .. Array Arguments ..
#       DOUBLE PRECISION DX(*),DY(*)
        function dot(n::Integer, DX::Union(Ptr{$elty},Array{$elty}), incx::Integer, DY::Union(Ptr{$elty},Array{$elty}), incy::Integer)
            ccall(($(string(fname)),libblas), $elty,
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, DX, &incx, DY, &incy)
        end
    end
end
for (fname, elty, relty) in ((:zdotc_,:Complex128,:Float64), (:cdotc_,:Complex64,:Float32))
    @eval begin
#       DOUBLE COMPLEX FUNCTION ZDOTC(N,ZX,INCX,ZY,INCY)
# *     .. Scalar Arguments ..
#       INTEGER INCX,INCY,N
# *     ..
# *     .. Array Arguments ..
#       DOUBLE COMPLEX ZX(*),ZY(*)
        function dot(n::Integer, DX::Union(Ptr{$elty},Array{$elty}), incx::Integer, DY::Union(Ptr{$elty},Array{$elty}), incy::Integer)
            convert($elty, ccall(($(string(fname)),libblas), ComplexPair{$relty},
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, DX, &incx, DY, &incy))
        end
    end
end
function dot{T<:BlasFloat}(DX::Array{T}, DY::Array{T})
    n = length(DX)
    if n != length(DY) throw(BlasDimMisMatch) end
    return dot(n, DX, 1, DY, 1)
end

# DOUBLE PRECISION FUNCTION DNRM2(N,X,INCX)
for (fname, elty, ret_type) in ((:dnrm2_,:Float64,:Float64),
                                (:snrm2_,:Float32,:Float32),
                                (:dznrm2_,:Complex128,:Float64),
                                (:scnrm2_,:Complex64,:Float32))
    @eval begin
        function nrm2(n::Integer, X::Union(Ptr{$elty},Array{$elty}), incx::Integer)
            ccall(($(string(fname)),libblas), $ret_type,
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, X, &incx)
        end
    end
end

nrm2(A::Array) = nrm2(length(A), A, 1)

# SUBROUTINE DAXPY(N,DA,DX,INCX,DY,INCY)
# DY <- DA*DX + DY
#*     .. Scalar Arguments ..
#      DOUBLE PRECISION DA
#      INTEGER INCX,INCY,N
#*     .. Array Arguments ..
#      DOUBLE PRECISION DX(*),DY(*)
for (fname, elty) in ((:daxpy_,:Float64), (:saxpy_,:Float32),
                      (:zaxpy_,:Complex128), (:caxpy_,:Complex64))
    @eval begin
        function axpy!(n::Integer, alpha::($elty),
                       dx::Union(Ptr{$elty},Array{$elty}), incx::Integer,
                       dy::Union(Ptr{$elty},Array{$elty}), incy::Integer)
            ccall(($(string(fname)),libblas), Void,
                  (Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
                  &n, &alpha, dx, &incx, dy, &incy)
            return dy
        end
    end
end

function axpy!{T,Ta<:Number}(alpha::Ta, x::Array{T}, y::Array{T})
    if length(x) != length(y)
        error("Inputs should be of the same length")
    end
    return axpy!(length(x), convert(T, alpha), x, 1, y, 1)
end

function axpy!{T,Ta<:Number,Ti<:Integer}(alpha::Ta, x::Array{T}, rx::Union(Range1{Ti},Range{Ti}), y::Array{T}, ry::Union(Range1{Ti},Range{Ti}))

    if length(rx) != length(ry)
        error("Ranges should be of the same length")
    end

    if min(rx) < 1 || max(rx) > length(x) || min(ry) < 1 || max(ry) > length(y)
        throw(BoundsError())
    end

    return axpy!(length(rx), convert(T, alpha), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
end

for (fname, elty) in ((:dsyrk_,:Float64), (:ssyrk_,:Float32),
                      (:zsyrk_,:Complex128), (:csyrk_,:Complex64))
   @eval begin
       # SUBROUTINE DSYRK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
       # *     .. Scalar Arguments ..
       #       REAL ALPHA,BETA
       #       INTEGER K,LDA,LDC,N
       #       CHARACTER TRANS,UPLO
       # *     .. Array Arguments ..
       #       REAL A(LDA,*),C(LDC,*)
       function syrk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
                      beta::($elty), C::StridedMatrix{$elty})
           m, n = size(C)
           if m != n error("syrk!: matrix C must be square") end
           nn = size(A, trans == 'N' ? 1 : 2)
           if nn != n error("syrk!: dimension mismatch") end
           k  = size(A, trans == 'N' ? 2 : 1)
           ccall(($(string(fname)),libblas), Void,
                 (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
                  Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &beta, C, &stride(C,2))
           C
       end
       function syrk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
           n = size(A, trans == 'N' ? 1 : 2)
           syrk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n, n)))
       end
       syrk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = syrk(uplo, trans, one($elty), A)
   end
end

# SUBROUTINE CHERK(UPLO,TRANS,N,K,ALPHA,A,LDA,BETA,C,LDC)
# *     .. Scalar Arguments ..
#       REAL ALPHA,BETA
#       INTEGER K,LDA,LDC,N
#       CHARACTER TRANS,UPLO
# *     ..
# *     .. Array Arguments ..
#       COMPLEX A(LDA,*),C(LDC,*)
for (fname, elty) in ((:zherk_,:Complex128), (:cherk_,:Complex64))
   @eval begin
       function herk!(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty},
                      beta::($elty), C::StridedMatrix{$elty})
           m, n = size(C)
           if m != n error("syrk!: matrix C must be square") end
           nn = size(A, trans == 'N' ? 1 : 2)
           if nn != n error("syrk!: dimension mismatch") end
           k  = size(A, trans == 'N' ? 2 : 1)
           ccall(($(string(fname)),libblas), Void,
                 (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
                  Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &uplo, &trans, &n, &k, &alpha, A, &stride(A,2), &beta, C, &stride(C,2))
           C
       end
       function herk(uplo::BlasChar, trans::BlasChar, alpha::($elty), A::StridedVecOrMat{$elty})
           n = size(A, trans == 'N' ? 1 : 2)
           herk!(uplo, trans, alpha, A, zero($elty), Array($elty, (n,n)))
       end
       herk(uplo::BlasChar, trans::BlasChar, A::StridedVecOrMat{$elty}) = herk(uplo, trans, one($elty), A)
   end
end

# (GB) general banded matrix-vector multiplication
# SUBROUTINE DGBMV(TRANS,M,N,KL,KU,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# *     .. Scalar Arguments ..
#       DOUBLE PRECISION ALPHA,BETA
#       INTEGER INCX,INCY,KL,KU,LDA,M,N
#       CHARACTER TRANS
# *     ..
# *     .. Array Arguments ..
#       DOUBLE PRECISION A(LDA,*),X(*),Y(*)
for (fname, elty) in ((:dgbmv_,:Float64), (:sgbmv_,:Float32),
                      (:zgbmv_,:Complex128), (:cgbmv_,:Complex64))
   @eval begin
       function gbmv!(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
                      alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty},
                      beta::($elty), y::StridedVector{$elty})
           ccall(($(string(fname)),libblas), Void,
                 (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &trans, &m, &size(A,2), &kl, &ku, &alpha, A, &stride(A,2),
                 x, &stride(x,1), &beta, y, &stride(y,1))
           y
       end
       function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
                     alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
           n = stride(A,2)
           gbmv!(trans, m, kl, ku, alpha, A, x, zero($elty), Array($elty, n))
       end
       function gbmv(trans::BlasChar, m::Integer, kl::Integer, ku::Integer,
                     A::StridedMatrix{$elty}, x::StridedVector{$elty})
           gbmv(trans, m, kl, ku, one($elty), A, x)
       end
   end
end

# (SB) symmetric banded matrix-vector multiplication
#       SUBROUTINE DSBMV(UPLO,N,K,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
# *     .. Scalar Arguments ..
#       DOUBLE PRECISION ALPHA,BETA
#       INTEGER INCX,INCY,K,LDA,N
#       CHARACTER UPLO
# *     ..
# *     .. Array Arguments ..
#       DOUBLE PRECISION A(LDA,*),X(*),Y(*)
for (fname, elty) in ((:dsbmv_,:Float64), (:ssbmv_,:Float32),
                      (:zsbmv_,:Complex128), (:csbmv_,:Complex64))
   @eval begin
       function sbmv!(uplo::BlasChar, k::Integer,
                      alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty}, 
                      beta::($elty), y::StridedVector{$elty})
           ccall(($(string(fname)),libblas), Void,
                 (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
                 Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &uplo, &size(A,2), &k, &alpha, A, &stride(A,2), x, &stride(x,1), &beta, y, &stride(y,1))
           y
       end
       function sbmv(uplo::BlasChar, k::Integer, alpha::($elty), A::StridedMatrix{$elty},
                     x::StridedVector{$elty})
           n = size(A,2)
           sbmv!(uplo, k, alpha, A, x, zero($elty), Array($elty, n))
       end
       function sbmv(uplo::BlasChar, k::Integer, A::StridedMatrix{$elty}, x::StridedVector{$elty})
           sbmv(uplo, k, one($elty), A, x)
       end
   end
end

# (GE) general matrix-matrix and matrix-vector multiplication
for (gemm, gemv, elty) in
    ((:dgemm_,:dgemv_,:Float64),
     (:sgemm_,:sgemv_,:Float32),
     (:zgemm_,:zgemv_,:Complex128),
     (:cgemm_,:cgemv_,:Complex64))
   @eval begin
       # SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
       # *     .. Scalar Arguments ..
       #       DOUBLE PRECISION ALPHA,BETA
       #       INTEGER K,LDA,LDB,LDC,M,N
       #       CHARACTER TRANSA,TRANSB
       # *     .. Array Arguments ..
       #       DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
       function gemm!(transA::BlasChar, transB::BlasChar,
                      alpha::($elty), A::StridedMatrix{$elty},
                      B::StridedMatrix{$elty},
                      beta::($elty), C::StridedMatrix{$elty})
#           if any([stride(A,1), stride(B,1), stride(C,1)] .!= 1)
#               error("gemm!: BLAS module requires contiguous matrix columns")
#           end  # should this be checked on every call?
           m = size(A, transA == 'N' ? 1 : 2)
           k = size(A, transA == 'N' ? 2 : 1)
           n = size(B, transB == 'N' ? 2 : 1)
           if m != size(C,1) || n != size(C,2)
               error("gemm!: mismatched dimensions")
           end
           ccall(($(string(gemm)),libblas), Void,
                 (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &transA, &transB, &m, &n, &k, &alpha, A, &stride(A,2),
                 B, &stride(B,2), &beta, C, &stride(C,2))
           C
       end
       function gemm(transA::BlasChar, transB::BlasChar,
                     alpha::($elty), A::StridedMatrix{$elty},
                     B::StridedMatrix{$elty})
           gemm!(transA, transB, alpha, A, B, zero($elty),
                 Array($elty, (size(A, transA == 'N' ? 1 : 2),
                               size(B, transB == 'N' ? 2 : 1))))
       end
       function gemm(transA::BlasChar, transB::BlasChar,
                     A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
           gemm(transA, transB, one($elty), A, B)
       end
       #SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
       #*     .. Scalar Arguments ..
       #      DOUBLE PRECISION ALPHA,BETA
       #      INTEGER INCX,INCY,LDA,M,N
       #      CHARACTER TRANS
       #*     .. Array Arguments ..
       #      DOUBLE PRECISION A(LDA,*),X(*),Y(*)
       function gemv!(trans::BlasChar,
                      alpha::($elty), A::StridedMatrix{$elty},
                      X::StridedVector{$elty},
                      beta::($elty), Y::StridedVector{$elty})
           ccall(($(string(gemv)),libblas), Void,
                 (Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{BlasInt},
                  Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &trans, &size(A,1), &size(A,2), &alpha, A, &stride(A,2),
                 X, &stride(X,1), &beta, Y, &stride(Y,1))
           Y
       end
       function gemv(trans::BlasChar,
                     alpha::($elty), A::StridedMatrix{$elty},
                     X::StridedVector{$elty})
           gemv!(trans, alpha, A, X, zero($elty),
                 Array($elty, size(A, (trans == 'N' ? 1 : 2))))
       end
       function gemv(trans::BlasChar, A::StridedMatrix{$elty}, X::StridedVector{$elty})
           gemv!(trans, one($elty), A, X, zero($elty),
                 Array($elty, size(A, (trans == 'N' ? 1 : 2))))
       end
   end
end

# (SY) symmetric matrix-matrix and matrix-vector multiplication
for (mfname, vfname, elty) in
    ((:dsymm_,:dsymv_,:Float64),
     (:ssymm_,:ssymv_,:Float32),
     (:zsymm_,:zsymv_,:Complex128),
     (:csymm_,:csymv_,:Complex64))
   @eval begin
       #     SUBROUTINE DSYMM(SIDE,UPLO,M,N,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
       #     .. Scalar Arguments ..
       #     DOUBLE PRECISION ALPHA,BETA
       #     INTEGER LDA,LDB,LDC,M,N
       #     CHARACTER SIDE,UPLO
       #     .. Array Arguments ..
       #     DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)
       function symm!(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
                      B::StridedMatrix{$elty}, beta::($elty), C::StridedMatrix{$elty})
           m, n = size(C)
           k, j = size(A)
           if k != j error("symm!: matrix A is $k by $j but must be square") end
           if j != (side == 'L' ? m : n) || size(B,2) != n error("symm!: Dimension mismatch") end
           ccall(($(string(mfname)),libblas), Void,
                 (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
                 Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &side, &uplo, &m, &n, &alpha, A, &stride(A,2), B, &stride(B,2),
                 &beta, C, &stride(C,2))
           C
       end
       function symm(side::BlasChar, uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty},
                     B::StridedMatrix{$elty})
           symm!(side, uplo, alpha, A, B, zero($elty), similar(B))
       end
       function symm(side::BlasChar, uplo::BlasChar, A::StridedMatrix{$elty}, B::StridedMatrix{$elty})
           symm(side, uplo, one($elty), A, B)
       end
       #      SUBROUTINE DSYMV(UPLO,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
       #     .. Scalar Arguments ..
       #      DOUBLE PRECISION ALPHA,BETA
       #      INTEGER INCX,INCY,LDA,N
       #      CHARACTER UPLO
       #     .. Array Arguments ..
       #      DOUBLE PRECISION A(LDA,*),X(*),Y(*)
       function symv!(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty},
                      beta::($elty), y::StridedVector{$elty})
           m, n = size(A)
           if m != n error("symm!: matrix A is $m by $n but must be square") end
           if m != length(x) || m != length(y) error("symm!: dimension mismatch") end
           ccall(($(string(vfname)),libblas), Void,
                 (Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
                 Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &uplo, &n, &alpha, A, &stride(A,2), x, &stride(x,1), &beta, y, &stride(y,1))
           Y
       end
       function symv(uplo::BlasChar, alpha::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
           symv!(uplo, alpha, A, x, zero($elty), similar(x))
       end
       function symv(uplo::BlasChar, A::StridedMatrix{$elty}, x::StridedVector{$elty})
           symv(uplo, one($elty), A, x)
       end
   end
end

end # module

# Use BLAS copy for small arrays where it is faster than memcpy, and for strided copying

function copy!{T<:BlasFloat}(dest::Ptr{T}, src::Ptr{T}, n::Integer)
    if n < 200
        BLAS.copy!(n, src, 1, dest, 1)
    else
        ccall(:memcpy, Ptr{Void}, (Ptr{Void}, Ptr{Void}, Uint), dest, src, n*sizeof(T))
    end
    return dest
end

function copy!{T<:BlasFloat}(dest::Array{T}, src::Array{T})
    n = length(src)
    if n > length(dest); throw(BoundsError()); end
    copy!(pointer(dest), pointer(src), n)
    return dest
end

function copy!{T<:BlasFloat,Ti<:Integer}(dest::Array{T}, rdest::Union(Range1{Ti},Range{Ti}), 
                                         src::Array{T}, rsrc::Union(Range1{Ti},Range{Ti}))
    if min(rdest) < 1 || max(rdest) > length(dest) || min(rsrc) < 1 || max(rsrc) > length(src)
        throw(BoundsError())
    end
    if length(rdest) != length(rsrc)
        error("Ranges must be of the same length")
    end
    BLAS.copy!(length(rsrc), pointer(src)+(first(rsrc)-1)*sizeof(T), step(rsrc),
              pointer(dest)+(first(rdest)-1)*sizeof(T), step(rdest))
    return dest
end
back to top