Skip to content

Commit b0b116e

Browse files
committed
Merge pull request dmlc#25 from slundberg/pull-request/0287171a
Add support for transposed sparse matrices.
2 parents 1e82de5 + 0287171 commit b0b116e

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

src/xgboost_lib.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ type DMatrix
3131
finalizer(sp, JLFree)
3232
sp
3333
end
34-
function DMatrix{K<:Real, V<:Integer}(data::SparseMatrixCSC{K, V}; kwargs...)
35-
handle = XGDMatrixCreateFromCSC(convert(SparseMatrixCSC{Float32, Int64}, data))
36-
sp = new(handle, _setinfo)
34+
function DMatrix{K<:Real, V<:Integer}(data::SparseMatrixCSC{K, V}, transposed::Bool=false; kwargs...)
35+
handle = (transposed ? XGDMatrixCreateFromCSCT(data) : XGDMatrixCreateFromCSC(data))
3736
for itm in kwargs
3837
_setinfo(handle, string(itm[1]), itm[2])
3938
end
39+
sp = new(handle, _setinfo)
4040
finalizer(sp, JLFree)
4141
sp
4242
end

src/xgboost_wrapper_h.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,27 @@ function XGDMatrixCreateFromFile(fname::ASCIIString, slient::Int32)
2424
return handle[]
2525
end
2626

27-
function XGDMatrixCreateFromCSC(data::SparseMatrixCSC{Float32, Int64})
27+
function XGDMatrixCreateFromCSC(data::SparseMatrixCSC)
2828
handle = Ref{Ptr{Void}}()
2929
@xgboost_ccall(
3030
:XGDMatrixCreateFromCSC,
3131
(Ptr{UInt64}, Ptr{UInt32}, Ptr{Float32}, UInt64, UInt64, Ref{Ptr{Void}}),
3232
convert(Array{UInt64, 1}, data.colptr - 1),
33-
convert(Array{UInt32, 1}, data.rowval - 1), data.nzval,
33+
convert(Array{UInt32, 1}, data.rowval - 1), convert(Array{Float32, 1}, data.nzval),
34+
convert(UInt64, size(data.colptr)[1]),
35+
convert(UInt64, nnz(data)),
36+
handle
37+
)
38+
return handle[]
39+
end
40+
41+
function XGDMatrixCreateFromCSCT(data::SparseMatrixCSC)
42+
handle = Ref{Ptr{Void}}()
43+
@xgboost_ccall(
44+
:XGDMatrixCreateFromCSR,
45+
(Ptr{UInt64}, Ptr{UInt32}, Ptr{Float32}, UInt64, UInt64, Ref{Ptr{Void}}),
46+
convert(Array{UInt64, 1}, data.colptr - 1),
47+
convert(Array{UInt32, 1}, data.rowval - 1), convert(Array{Float32, 1}, data.nzval),
3448
convert(UInt64, size(data.colptr)[1]),
3549
convert(UInt64, nnz(data)),
3650
handle

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@ using FactCheck
33

44
include("utils.jl")
55

6+
facts("Sparse matrices") do
7+
X = sparse(randn(100,10) .* bitrand(100,10))
8+
y = randn(100)
9+
DMatrix(X, label=y)
10+
11+
X = sparse(convert(Array{Float32,2}, randn(10,100) .* bitrand(10,100)))
12+
y = randn(100)
13+
DMatrix(X, true)
14+
15+
X = sparse(randn(100,10) .* bitrand(100,10))
16+
y = randn(100)
17+
DMatrix(X)
18+
end
19+
620
facts("DMatrix loading") do
721
dtrain = DMatrix("../data/agaricus.txt.train")
822
train_X, train_Y = readlibsvm("../data/agaricus.txt.train", (6513, 126))

0 commit comments

Comments
 (0)