Skip to content

Commit

Permalink
[xgb] Add GPU support for XGBoost
Browse files Browse the repository at this point in the history
Upgrade XGBoost to 1.6.1

Change-Id: Id6a6cc911e4a382ccaa94a962d89fad6208c48fc
  • Loading branch information
frankfliu committed May 28, 2022
1 parent 75e6c25 commit bbc6991
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 3 deletions.
10 changes: 9 additions & 1 deletion engines/ml/xgboost/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import groovy.xml.QName

group "ai.djl.ml.xgboost"
boolean isGpu = project.hasProperty("gpu")
def XGB_FLAVOR = isGpu ? "-gpu" : ""

configurations {
exclusion
Expand All @@ -9,7 +11,7 @@ configurations {
dependencies {
api project(":api")
api "commons-logging:commons-logging:${commons_logging_version}"
api("ml.dmlc:xgboost4j_2.12:${xgboost_version}") {
api("ml.dmlc:xgboost4j${XGB_FLAVOR}_2.12:${xgboost_version}") {
// get rid of the unused XGBoost Dependencies
exclude group: "org.apache.hadoop", module: "hadoop-hdfs"
exclude group: "org.apache.hadoop", module: "hadoop-common"
Expand All @@ -23,9 +25,13 @@ dependencies {
exclude group: "org.scala-lang", module: "scala-reflect"
exclude group: "org.scala-lang", module: "scala-library"
}
if (isGpu) {
api "ai.rapids:cudf:${rapis_version}:cuda11"
}

exclusion project(":api")
exclusion "commons-logging:commons-logging:${commons_logging_version}"
exclusion "ai.rapids:cudf:${rapis_version}:cuda11"
testImplementation(project(":testing"))
testImplementation("org.testng:testng:${testng_version}") {
exclude group: "junit", module: "junit"
Expand All @@ -43,6 +49,7 @@ jar {
"ml/dmlc/xgboost4j/java/NativeLibLoader*",
"ml/dmlc/xgboost4j/java/XGBoost*",
"ml/dmlc/xgboost4j/java/util/*",
"ml/dmlc/xgboost4j/gpu/java/*",
"ml/dmlc/xgboost4j/LabeledPoint.*",
"xgboost4j-version.properties"
}
Expand All @@ -53,6 +60,7 @@ jar {
publishing {
publications {
maven(MavenPublication) {
artifactId "${project.name}${XGB_FLAVOR}"
pom {
name = "DJL Engine Adapter for XGBoost"
description = "Deep Java Library (DJL) Engine Adapter for XGBoost"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.StandardCapabilities;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;
import ml.dmlc.xgboost4j.java.JniUtils;

/**
Expand Down Expand Up @@ -71,12 +75,27 @@ public int getRank() {
/** {@inheritDoc} */
@Override
public String getVersion() {
return "1.3.1";
try (InputStream is =
XgbEngine.class.getResourceAsStream("/xgboost4j-version.properties")) {
Properties prop = new Properties();
prop.load(is);
return prop.getProperty("version");
} catch (IOException e) {
throw new AssertionError("Failed to load xgboost4j-version.properties", e);
}
}

/** {@inheritDoc} */
@Override
public boolean hasCapability(String capability) {
if (StandardCapabilities.CUDA.equals(capability)) {
try {
Class.forName("ml.dmlc.xgboost4j.gpu.java.CudfColumn");
return true;
} catch (ClassNotFoundException ignore) {
return false;
}
}
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.JsonUtils;
import com.google.gson.JsonArray;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.JniUtils;

/** {@code XgbNDManager} is the XGBoost implementation of {@link NDManager}. */
Expand Down Expand Up @@ -80,6 +83,30 @@ public Engine getEngine() {
return Engine.getEngine(XgbEngine.ENGINE_NAME);
}

/**
* Creates {@link XgbNDArray} from column array interface.
*
* @param columnBatch – the XGBoost ColumnBatch to provide the cuda array interface of feature
* columns
* @param missing – missing value
* @param nthread – threads number
* @return a new instance of {@link NDArray}
*/
public NDArray create(ColumnBatch columnBatch, float missing, int nthread) {
columnBatch.getFeatureArrayInterface();
String json = columnBatch.getFeatureArrayInterface();
JsonArray array = JsonUtils.GSON.fromJson(json, JsonArray.class);
JsonArray shapeJson = array.get(0).getAsJsonObject().get("shape").getAsJsonArray();
long[] shapes = new long[shapeJson.size()];
for (int i = 0; i < shapes.length; ++i) {
shapes[i] = shapeJson.get(i).getAsLong();
}

Shape shape = new Shape(shapes);
long handle = JniUtils.createDMatrix(columnBatch, missing, nthread);
return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
}

/** {@inheritDoc} */
@Override
public NDArray create(Buffer data, Shape shape, DataType dataType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ public static long createDMatrix(Buffer buf, Shape shape, float missing) {
return handles[0];
}

public static long createDMatrix(ColumnBatch columnBatch, float missing, int nthread) {
long[] handles = new long[1];
String json = columnBatch.getFeatureArrayInterface();
if (json == null || json.isEmpty()) {
throw new IllegalArgumentException(
"Expecting non-empty feature columns' array interface");
}
checkCall(
XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
json, missing, nthread, handles));
return handles[0];
}

public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) {
long[] handles = new long[1];
checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
Expand Down Expand Up @@ -45,6 +46,30 @@ public void downloadXGBoostModel() throws IOException {
modelDir.resolve("regression.json").toString());
}

@Test
public void testVersion() {
Engine engine = Engine.getEngine("XGBoost");
Assert.assertEquals("1.6.1", engine.getVersion());
}

/*
This test depends on rapis, which doesn't work for CPU
@Test
public void testCudf() {
Engine engine = Engine.getEngine("XGBoost");
if (engine.hasCapability(StandardCapabilities.CUDA)) {
Float[] data = new Float[]{1.f, null, 5.f, 7.f, 9.f};
Table x = new Table.TestBuilder().column(data).build();
ColumnBatch columnBatch = new CudfColumnBatch(x, null, null, null);
try (XgbNDManager manager = (XgbNDManager)engine.newBaseManager()) {
NDArray array = manager.create(columnBatch,0, 1);
Assert.assertEquals(array.getShape().get(0), 5);
}
}
}
*/

@Test
public void testLoad() throws MalformedModelException, IOException, TranslateException {
// TODO: skip for Windows since it is not supported by XGBoost
Expand Down
3 changes: 2 additions & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ paddlepaddle_version=2.2.2
sentencepiece_version=0.1.95
tokenizers_version=0.11.0
fasttext_version=0.9.2
xgboost_version=1.6.0
xgboost_version=1.6.1
rapis_version=22.04.0

commons_cli_version=1.5.0
commons_compress_version=1.21
Expand Down

0 comments on commit bbc6991

Please sign in to comment.