diff --git a/README.md b/README.md index fa365c5..442ca11 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,9 @@ [![codecov](https://codecov.io/gh/JuliaSparse/MKLSparse.jl/graph/badge.svg?token=j3KoKBEIt1)](https://codecov.io/gh/JuliaSparse/MKLSparse.jl) -`MKLSparse.jl` is a Julia package to seamlessly use the sparse functionality in MKL to speed up operations on sparse arrays in Julia. -In order to use `MKLSparse.jl` you do not need to install Intel's MKL library nor build Julia with MKL. `MKLSparse.jl` will automatically download and use the MKL library for you when installed. +*MKLSparse.jl* is a Julia package to seamlessly use the [sparse BLAS routines from Intel's Math Kernel Library (MKL)](https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-2/blas-and-sparse-blas-routines.html) +to speed up operations on sparse arrays in Julia. +In order to use *MKLSparse.jl* you do not need to install Intel's MKL library nor build Julia with MKL. *MKLSparse.jl* will automatically download and use the MKL library for you when installed. ### Matrix multiplications diff --git a/src/generic.jl b/src/generic.jl index 0b55408..47134fc 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -56,6 +56,203 @@ function mm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_d return C end +# C := op(A) * B, where C is sparse +function spmm(transA::Char, A::AbstractSparseMatrix{T}, B::AbstractSparseMatrix{T}) where T + check_trans(transA) + check_mat_op_sizes(nothing, A, transA, B, 'N') + Cout = Ref{sparse_matrix_t}() + hA = MKLSparseMatrix(A) + hB = MKLSparseMatrix(B) + res = mkl_call(Val{:mkl_sparse_spmmI}(), typeof(A), + transA, hA, hB, Cout) + destroy(hA) + destroy(hB) + check_status(res) + return MKLSparseMatrix(Cout[]) +end + +# C := op(A) * B, where C is dense +function spmmd!(transa::Char, A::AbstractSparseMatrix{T}, B::AbstractSparseMatrix{T}, + C::StridedMatrix{T}; + dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR +) where T + check_trans(transa) + check_mat_op_sizes(C, A, transa, B, 'N') + ldC = stride(C, 2) + hA = MKLSparseMatrix(A) + hB = MKLSparseMatrix(B) + res = mkl_call(Val{:mkl_sparse_T_spmmdI}(), typeof(A), + transa, hA, hB, dense_layout, C, ldC) + destroy(hA) + destroy(hB) + check_status(res) + return C +end + +# C := opA(A) * opB(B), where C is sparse +function sp2m(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr, + transB::Char, B::AbstractSparseMatrix{T}, descrB::matrix_descr) where T + check_trans(transA) + check_trans(transB) + check_mat_op_sizes(nothing, A, transA, B, transB) + Cout = Ref{sparse_matrix_t}() + hA = MKLSparseMatrix(A) + hB = MKLSparseMatrix(B) + res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A), + transA, descrA, hA, transB, descrB, hB, + SPARSE_STAGE_FULL_MULT, Cout) + destroy(hA) + destroy(hB) + check_status(res) + # NOTE: we are guessing what is the storage format of C + return MKLSparseMatrix{typeof(A)}(Cout[]) +end + +# C := opA(A) * opB(B), where C is sparse, in-place version +# C should have the correct size and sparsity pattern +function sp2m!(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr, + transB::Char, B::AbstractSparseMatrix{T}, descrB::matrix_descr, + C::SparseMatrixCSC{T}; + check_nzpattern::Bool = true +) where T + check_trans(transA) + check_trans(transB) + check_mat_op_sizes(C, A, transA, B, transB) + hA = MKLSparseMatrix(A) + hB = MKLSparseMatrix(B) + if check_nzpattern + # pre-multiply A * B to get the number of nonzeros per column in the result + CptnOut = Ref{sparse_matrix_t}() + res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A), + transA, descrA, hA, transB, descrB, hB, + SPARSE_STAGE_NNZ_COUNT, CptnOut) + check_status(res) + hCptn = MKLSparseMatrix{typeof(A)}(CptnOut[]) + try + # check if C has the same per-column nonzeros as the result + _C = extract_data(hCptn) + _Cnnz = _C.major_starts[end] - 1 + nnz(C) == _Cnnz || error(lazy"Number of nonzeros in the destination matrix ($(nnz(C))) does not match the result ($(_Cnnz))") + C.colptr == _C.major_starts || error(lazy"Nonzeros structure of the destination matrix does not match the result") + catch e + # destroy handles to A and B if the pattern check fails, + # otherwise reuse them at the actual multiplication + destroy(hA) + destroy(hB) + rethrow(e) + finally + destroy(hCptn) + end + # FIXME rowval not checked + end + # FIXME the optimal way would be to create the MKLSparse handle to C reusing its arrays + # and do SPARSE_STAGE_FINALIZE_MULT to directly write to the C.nzval + # but that causes segfaults when the handle is destroyed + # (also the partial mkl_sparse_copy(C) workaround to reuse the nz structure segfaults) + #hC = MKLSparseMatrix(C) + #hC_ref = Ref(hC) + #res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A), + # transA, descrA, hA, transB, descrB, hB, + # SPARSE_STAGE_FINALIZE_MULT, hC_ref) + #@assert hC_ref[] == hC + # so instead we do the full multiplication and copy the result into C nzvals + hCopy_ref = Ref{sparse_matrix_t}() + res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A), + transA, descrA, hA, transB, descrB, hB, + SPARSE_STAGE_FULL_MULT, hCopy_ref) + destroy(hA) + destroy(hB) + check_status(res) + if hCopy_ref[] != C_NULL + hCopy = MKLSparseMatrix{typeof(A)}(hCopy_ref[]) + copy!(C, hCopy; check_nzpattern) + destroy(hCopy) + end + return C +end + +# C := alpha * opA(A) * opB(B) + beta * C, where C is dense +function sp2md!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descrA::matrix_descr, + transB::Char, B::AbstractSparseMatrix{T}, descrB::matrix_descr, + beta::T, C::StridedMatrix{T}; + dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR +) where T + check_trans(transA) + check_trans(transB) + check_mat_op_sizes(C, A, transA, B, transB) + ldC = stride(C, 2) + hA = MKLSparseMatrix(A) + hB = MKLSparseMatrix(B) + res = mkl_call(Val{:mkl_sparse_T_sp2mdI}(), typeof(A), + transA, descrA, hA, transB, descrB, hB, + alpha, beta, + C, dense_layout, ldC) + if res != SPARSE_STATUS_SUCCESS + @show transA descrA transB descrB + end + destroy(hA) + destroy(hB) + check_status(res) + return C +end + +# C := A * op(A), or +# C := op(A) * A, where C is sparse +# note: only the upper triangular part of C is computed +function syrk(transA::Char, A::AbstractSparseMatrix{T}) where T + check_trans(transA) + Cout = Ref{sparse_matrix_t}() + hA = MKLSparseMatrix(A) + res = mkl_call(Val{:mkl_sparse_syrkI}(), typeof(A), + transA, hA, Cout) + destroy(hA) + check_status(res) + return MKLSparseMatrix(Cout[]) +end + +# C := A * op(A), or +# C := op(A) * A, where C is dense +# note: only the upper triangular part of C is computed +function syrkd!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, beta::T, + C::StridedMatrix{T}; + dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR +) where T + check_trans(transA) + check_mat_op_sizes(C, A, transA, A, transA == 'N' ? 'T' : 'N'; dense_layout) + ldC = stride(C, 2) + hA = MKLSparseMatrix(A) + res = mkl_call(Val{:mkl_sparse_T_syrkdI}(), typeof(A), + transA, hA, alpha, beta, C, dense_layout, ldC) + destroy(hA) + check_status(res) + return C +end + +# C := alpha * op(A) * B * A + beta * C, or +# C := alpha * A * B * op(A) + beta * C, where C is dense +# note: only the upper triangular part of C is computed +function syprd!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, + B::StridedMatrix{T}, beta::T, C::StridedMatrix{T}; + dense_layout_B::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR, + dense_layout_C::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR +) where T + check_trans(transA) + # FIXME dense_layout_B not used + check_mat_op_sizes(C, A, transA, B, 'N'; + check_result_columns = false, dense_layout = dense_layout_C) + check_mat_op_sizes(C, B, 'N', A, transA == 'N' ? 'T' : 'N'; + check_result_rows = false, dense_layout = dense_layout_C) + ldB = stride(B, 2) + ldC = stride(C, 2) + hA = MKLSparseMatrix(A) + res = mkl_call(Val{:mkl_sparse_T_syprdI}(), typeof(A), + transA, hA, B, dense_layout_B, ldB, + alpha, beta, C, dense_layout_C, ldC) + destroy(hA) + check_status(res) + return C +end + # find y: op(A) * y = alpha * x function trsv!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr, x::StridedVector{T}, y::StridedVector{T} diff --git a/src/interface.jl b/src/interface.jl index b11f16e..70174bf 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,16 +1,20 @@ import Base: \, * import LinearAlgebra: mul!, ldiv! -MKLSparseMat{T} = Union{SparseArrays.AbstractSparseMatrixCSC{T}, SparseMatrixCSR{T}, SparseMatrixCOO{T}} +SparseMat{T} = Union{SparseArrays.AbstractSparseMatrixCSC{T}, SparseMatrixCSR{T}, SparseMatrixCOO{T}} -SimpleOrAdjMat{T, M} = Union{M, Adjoint{T, <:M}, Transpose{T, <:M}} +AdjOrTranspMat{T, M} = Union{Adjoint{T, <:M}, Transpose{T,<:M}} -SpecialMat{T, M} = Union{LowerTriangular{T,<:M}, UpperTriangular{T,<:M}, - UnitLowerTriangular{T,<:M}, UnitUpperTriangular{T,<:M}, +SimpleOrAdjMat{T, M} = Union{M, Adjoint{T, <:M}, Transpose{T,<:M}} + +SpecialMat{T, M} = Union{LinearAlgebra.AbstractTriangular{T,<:M}, Symmetric{T,<:M}, Hermitian{T,<:M}} -SimpleOrSpecialMat{T, M} = Union{M, SpecialMat{T, <:M}} -SimpleOrSpecialOrAdjMat{T, M} = Union{SimpleOrAdjMat{T, <:SimpleOrSpecialMat{T, <:M}}, - SimpleOrSpecialMat{T, <:SimpleOrAdjMat{T, <:M}}} +SimpleOrSpecialMat{T, M} = Union{M, SpecialMat{T,<:M}} +SimpleOrSpecialOrAdjMat{T, M} = Union{M, + SpecialMat{T,<:M}, + AdjOrTranspMat{T,<:M}, + AdjOrTranspMat{T, <:SpecialMat{T,<:M}}, + SpecialMat{T,<:AdjOrTranspMat{T,<:M}}} # unwraps matrix A from Adjoint/Transpose transform unwrap_trans(A::AbstractMatrix) = A @@ -34,10 +38,10 @@ describe_and_unwrap(A::Symmetric{<:Any, T}) where T <: Union{Adjoint, Transpose} describe_and_unwrap(A::Hermitian{<:Any, T}) where T <: Union{Adjoint, Transpose} = (T <: Adjoint || (eltype(A) <: Real) ? 'N' : 'T', matrix_descr('H', A.uplo, 'N'), unwrap_trans(A)) -# 5-arg mul!() +# mul!(vec, sparse, vec, a, b) function mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}, alpha::Number, beta::Number -) where {T <: BlasFloat, S <: MKLSparseMat{T}} +) where {T <: BlasFloat, S <: SparseMat{T}} transA, descrA, unwrapA = describe_and_unwrap(A) # fix the strange behaviour of multipling adjoint vectors by triangular matrices # looks like wrong the triangle is being used @@ -47,13 +51,36 @@ function mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, mv!(transA, T(alpha), unwrapA, descrA, x, T(beta), y) end +# mul!(dense, sparse, dense, a, b) function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}, alpha::Number, beta::Number -) where {T <: BlasFloat, S <: MKLSparseMat{T}} +) where {T <: BlasFloat, S <: SparseMat{T}} transA, descrA, unwrapA = describe_and_unwrap(A) mm!(transA, T(alpha), unwrapA, descrA, B, T(beta), C) end +# mul!(dense, sparse, sparse, a, b) +function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, + B::SimpleOrSpecialOrAdjMat{T, S}, alpha::Number, beta::Number +) where {T <: BlasFloat, S <: SparseMat{T}} + transA, descrA, unwrapA = describe_and_unwrap(A) + transB, descrB, unwrapB = describe_and_unwrap(B) + # FIXME only general matrices are supported by sp2m in MKL SparseBLAS + # should the elements of the special matrices be fixed? + if descrA.type == SPARSE_MATRIX_TYPE_SYMMETRIC + @assert issymmetric(unwrapA) "A must be symmetric" + end + if descrB.type == SPARSE_MATRIX_TYPE_SYMMETRIC + @assert issymmetric(unwrapB) "B must be symmetric" + end + descrA = matrix_descr(descrA, type = SPARSE_MATRIX_TYPE_GENERAL, diag = SPARSE_DIAG_NON_UNIT, mode = SPARSE_FILL_MODE_FULL) + descrB = matrix_descr(descrB, type = SPARSE_MATRIX_TYPE_GENERAL, diag = SPARSE_DIAG_NON_UNIT, mode = SPARSE_FILL_MODE_FULL) + sp2md!(transA, T(alpha), unwrapA, descrA, + transB, unwrapB, descrB, + T(beta), C) +end + +# mul!(dense, dense, sparse, a, b) # ColMajorRes = ColMajorMtx*SparseMatrixCSC is implemented via # RowMajorRes = SparseMatrixCSR*RowMajorMtx Sparse MKL BLAS calls # Switching the B layout from CSC to CSR is required, because MKLSparse @@ -72,56 +99,99 @@ end # 3-arg mul!() calls 5-arg mul!() mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, - x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = + x::StridedVector{T}) where {T <: BlasFloat, S <: SparseMat{T}} = mul!(y, A, x, one(T), zero(T)) mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, - B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = + B::StridedMatrix{T}) where {T <: BlasFloat, S <: SparseMat{T}} = mul!(C, A, B, one(T), zero(T)) mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, - B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = + B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} = + mul!(C, A, B, one(T), zero(T)) + +# mul!(dense, sparse, sparse) calls sp2md!() +mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, + B::SimpleOrSpecialOrAdjMat{T, S} +) where {T <: BlasFloat, S <: SparseMat{T}} = mul!(C, A, B, one(T), zero(T)) +# mul!(sparse, sparse, sparse) +mul!(C::SparseMatrixCSC{T}, A::SimpleOrSpecialOrAdjMat{T, S}, + B::SimpleOrSpecialOrAdjMat{T, S} +) where {T <: BlasFloat, S <: SparseMat{T}} = + unsafe_mul!(C, A, B; check_nzpattern = true) + +# unsafe_mul!() allows disabling the check for the result's non-zero pattern +function unsafe_mul!(C::SparseMatrixCSC{T}, A::SimpleOrSpecialOrAdjMat{T, S}, + B::SimpleOrSpecialOrAdjMat{T, S}; + check_nzpattern::Bool = true +) where {T <: BlasFloat, S <: SparseMat{T}} + transA, descrA, unwrapA = describe_and_unwrap(A) + transB, descrB, unwrapB = describe_and_unwrap(B) + # FIXME only general matrices are supported by sp2m in MKL SparseBLAS + # should the elements of the special matrices be fixed? + descrA = matrix_descr(descrA, type = SPARSE_MATRIX_TYPE_GENERAL) + descrB = matrix_descr(descrB, type = SPARSE_MATRIX_TYPE_GENERAL) + sp2m!(transA, unwrapA, descrA, + transB, unwrapB, descrB, + parent(C); check_nzpattern) +end + # define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgrebra # redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1) function ldiv!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S}, - x::StridedVector{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}} + x::StridedVector{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: SparseMat{T}} transA, descrA, unwrapA = describe_and_unwrap(A) trsv!(transA, alpha, unwrapA, descrA, x, y) end function LinearAlgebra.ldiv!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S}, - B::StridedMatrix{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}} + B::StridedMatrix{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: SparseMat{T}} transA, descrA, unwrapA = describe_and_unwrap(A) trsm!(transA, alpha, unwrapA, descrA, B, C) end +# sparse := sparse * sparse +function (*)(A::SimpleOrSpecialOrAdjMat{T, S}, + B::SimpleOrSpecialOrAdjMat{T, S} +) where {T <: BlasFloat, S <: SparseMat{T}} + transA, descrA, unwrapA = describe_and_unwrap(A) + transB, descrB, unwrapB = describe_and_unwrap(B) + # FIXME only general matrices are supported by sp2m in MKL SparseBLAS + # should the elements of the special matrices be fixed? + descrA = matrix_descr(descrA, type = SPARSE_MATRIX_TYPE_GENERAL) + descrB = matrix_descr(descrB, type = SPARSE_MATRIX_TYPE_GENERAL) + res = sp2m(transA, unwrapA, descrA, + transB, unwrapB, descrB) + return convert(S, res) +end + if VERSION < v"1.10" # stdlib v1.9 does not provide these methods -(*)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = +(*)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: SparseMat{T}} = mul!(Vector{T}(undef, size(A, 1)), A, x) -(*)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = +(*)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: SparseMat{T}} = mul!(Matrix{T}(undef, size(A, 1), size(B, 2)), A, B) # xᵀ * B = (Bᵀ * x)ᵀ -(*)(x::Transpose{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = +(*)(x::Transpose{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} = transpose(mul!(similar(x, size(B, 2)), transpose(B), parent(x))) # xᴴ * B = (Bᴴ * x)ᴴ -(*)(x::Adjoint{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = +(*)(x::Adjoint{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} = adjoint(mul!(similar(x, size(B, 2)), adjoint(B), parent(x))) end # if VERSION < v"1.10" -(*)(A::StridedMatrix{T}, B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} = +(*)(A::StridedMatrix{T}, B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} = mul!(Matrix{T}(undef, size(A, 1), size(B, 2)), A, B) # stdlib does not provide these methods for complex types # xᴴ * Bᵀ = (Bᵀᴴ * x)ᴴ function (*)(x::Adjoint{T, <:StridedVector{T}}, B::Transpose{T, <:SimpleOrSpecialMat{T, S}} -) where {T <: Union{ComplexF32, ComplexF64}, S <: MKLSparseMat{T}} +) where {T <: Union{ComplexF32, ComplexF64}, S <: SparseMat{T}} transB, descrB, unwrapB = describe_and_unwrap(parent(B)) y = similar(x, size(B, 2)) adjoint(mv!('C', one(T), lazypermutedims(unwrapB), lazypermutedims(descrB), parent(x), @@ -130,20 +200,20 @@ end # xᵀ * Bᴴ = (Bᵀᴴ * x)ᵀ function (*)(x::Transpose{T, <:StridedVector{T}}, B::Adjoint{T, <:SimpleOrSpecialMat{T, S}} -) where {T <: Union{ComplexF32, ComplexF64}, S <: MKLSparseMat{T}} +) where {T <: Union{ComplexF32, ComplexF64}, S <: SparseMat{T}} transB, descrB, unwrapB = describe_and_unwrap(parent(B)) y = similar(x, size(B, 2)) transpose(mv!('C', one(T), lazypermutedims(unwrapB), lazypermutedims(descrB), parent(x), zero(T), y)) end -function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} +function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: SparseMat{T}} n = length(x) y = Vector{T}(undef, n) return ldiv!(y, A, x) end -function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} +function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: SparseMat{T}} m, n = size(B) C = Matrix{T}(undef, m, n) return ldiv!(C, A, B) diff --git a/src/mklsparsematrix.jl b/src/mklsparsematrix.jl index a03f5b7..3cd0659 100644 --- a/src/mklsparsematrix.jl +++ b/src/mklsparsematrix.jl @@ -201,9 +201,27 @@ Base.convert(::Type{SparseMatrixCSC}, A::MKLSparseMatrix{SparseMatrixCSC{Tv, Ti} function Base.convert(::Type{S}, A::MKLSparseMatrix{S}) where {S <: SparseMatrixCSR} _A = extract_data(A) # not converting the col indices depending on index_base - @show length(_A.nzval) return S(_A.size..., copy(_A.major_starts), copy(_A.minor_val), copy(_A.nzval)) end Base.convert(::Type{SparseMatrixCSR}, A::MKLSparseMatrix{SparseMatrixCSR{Tv, Ti}}) where {Tv, Ti} = convert(SparseMatrixCSR{Tv, Ti}, A) + +# copy the non-zero values from the MKL Sparse matrix A into the sparse matrix B +# A and B should have the same non-zero pattern +function Base.copy!(B::S, A::MKLSparseMatrix{S}; + check_nzpattern::Bool = true) where {S <: SparseMatrixCSC} + _A = extract_data(A) + Ti = eltype(B.rowval) + length(_A.nzval) == nnz(B) || error(lazy"Number of nonzeros in the source ($(length(_A.nzval))) does not match the destination matrix ($(nnz(B)))") + size(B) == _A.size || throw(DimensionMismatch(lazy"Size of the source $(_A.size) does not match the destination $(size(B))")) + if check_nzpattern + B.colptr == _A.major_starts || error("Source and destination colptr do not match") + rowval_match = _A.index_base == SPARSE_INDEX_BASE_ZERO ? + all((a, b) -> a + one(Ti) == b, zip(_A.minor_val, B.rowval)) : # convert to 1-based + _A.minor_val == B.rowval + rowval_match || error("Source and destination rowval do not match") + end + (pointer(B.nzval) != pointer(_A.nzval)) && copy!(B.nzval, _A.nzval) + return B +end diff --git a/src/types.jl b/src/types.jl index 465ee7d..5e7a962 100644 --- a/src/types.jl +++ b/src/types.jl @@ -10,7 +10,7 @@ elseif Tv == ComplexF64 'z' else - throw(ArgumentError("Unsupported sparse value type $Tv")) + throw(ArgumentError(lazy"Unsupported sparse value type $Tv")) end end @@ -20,12 +20,12 @@ end elseif Ti == Int64 "_64" else - throw(ArgumentError("Unsupported sparse index type $Ti")) + throw(ArgumentError(lazy"Unsupported sparse index type $Ti")) end end mkl_storagetype_specifier(::Type{S}) where S <: AbstractSparseMatrix = - throw(ArgumentError("Unsupported sparse matrix storage type $S")) + throw(ArgumentError(lazy"Unsupported sparse matrix storage type $S")) mkl_storagetype_specifier(::Type{<:SparseMatrixCSC}) = "csc" @@ -52,6 +52,13 @@ matrix_descr(A::SparseMatrixCSC) = matrix_descr('G', 'F', 'N') matrix_descr(A::Transpose) = matrix_descr(A.parent) matrix_descr(A::Adjoint) = matrix_descr(A.parent) +# modify the specific fields of the descriptor +matrix_descr(descr::matrix_descr; + type::sparse_matrix_type_t = descr.type, + mode::sparse_fill_mode_t = descr.mode, + diag::sparse_diag_type_t = descr.diag) = + matrix_descr(type, mode, diag) + @inline function Base.convert(::Type{sparse_operation_t}, trans::Char) if trans == 'N' SPARSE_OPERATION_NON_TRANSPOSE @@ -60,7 +67,7 @@ matrix_descr(A::Adjoint) = matrix_descr(A.parent) elseif trans == 'C' SPARSE_OPERATION_CONJUGATE_TRANSPOSE else - throw(ArgumentError("Unknown operation $trans")) + throw(ArgumentError(lazy"Unknown operation $trans")) end end @@ -82,7 +89,7 @@ end elseif mattype == 'D' SPARSE_MATRIX_TYPE_DIAGONAL else - throw(ArgumentError("Unknown matrix type $mattype")) + throw(ArgumentError(lazy"Unknown matrix type $mattype")) end end @@ -94,7 +101,7 @@ end elseif mattype == "BD" return SPARSE_MATRIX_TYPE_BLOCK_DIAGONAL else - throw(ArgumentError("Unknown matrix type $mattype")) + throw(ArgumentError(lazy"Unknown matrix type $mattype")) end end @@ -104,7 +111,7 @@ end elseif index == '1' return SPARSE_INDEX_BASE_ONE else - throw(ArgumentError("Unknown index base $index")) + throw(ArgumentError(lazy"Unknown index base $index")) end end @@ -116,7 +123,7 @@ end elseif uplo =='F' SPARSE_FILL_MODE_FULL else - throw(ArgumentError("Unknown fill mode $uplo")) + throw(ArgumentError(lazy"Unknown fill mode $uplo")) end end @@ -126,7 +133,7 @@ end elseif diag == 'N' SPARSE_DIAG_NON_UNIT else - throw(ArgumentError("Unknown diag type $diag")) + throw(ArgumentError(lazy"Unknown diag type $diag")) end end @@ -136,7 +143,7 @@ end elseif layout == 'C' SPARSE_LAYOUT_COLUMN_MAJOR else - throw(ArgumentError("Unknown layout $layout")) + throw(ArgumentError(lazy"Unknown layout $layout")) end end @@ -148,7 +155,7 @@ end elseif verbose == "extended" SPARSE_VERBOSE_EXTENDED else - throw(ArgumentError("Unknown verbose mode $verbose")) + throw(ArgumentError(lazy"Unknown verbose mode $verbose")) end end @@ -158,7 +165,7 @@ end elseif memory == "aggressive" SPARSE_MEMORY_AGGRESSIVE else - throw(ArgumentError("Unknown memory usage $memory")) + throw(ArgumentError(lazy"Unknown memory usage $memory")) end end @@ -170,11 +177,13 @@ end # check the correctness of transA (transB etc) argument of MKLSparse calls check_trans(t::Char) = (t in ('C', 'N', 'T')) || - throw(ArgumentError("trans: is '$t', must be 'N', 'T', or 'C'")) + throw(ArgumentError(lazy"trans: is '$t', must be 'N', 'T', or 'C'")) # check matrix sizes for the multiplication-like operation C <- tA[A] * tB[B] function check_mat_op_sizes(C, A, tA, B, tB; - dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR) + dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR, + check_result_rows::Bool = true, + check_result_columns::Bool = true) mklsize(M::AbstractMatrix, tM::Char) = (tM == 'N') == (dense_layout == SPARSE_LAYOUT_COLUMN_MAJOR) ? size(M) : reverse(size(M)) mklsize(M::AbstractSparseMatrix, tM::Char) = @@ -192,8 +201,8 @@ function check_mat_op_sizes(C, A, tA, B, tB; mA, nA = mklsize(A, tA) mB, nB = mklsize(B, tB) - mC, nC = mklsize(C, 'N') - if nA != mB || mC != mA || nC != nB + mC, nC = !isnothing(C) ? mklsize(C, 'N') : (mA, nB) + if nA != mB || (check_result_rows && mC != mA) || (check_result_columns && nC != nB) str = string("arrays had inconsistent dimensions for C = A", opsym(tA), " * B", opsym(tB), ": ", sizestr(C), " = ", sizestr(A), opsym(tA), " * ", sizestr(B), opsym(tB)) throw(DimensionMismatch(str))