Add optional GPU (xgboost4j-gpu) support to the DJL XGBoost engine.

* build.gradle: the "gpu" project property switches the dependency to the "-gpu" flavor of
  xgboost4j, adds the ai.rapids:cudf dependency and appends "-gpu" to the published artifactId.
* XgbEngine: getVersion() reads the version from xgboost4j-version.properties and
  hasCapability() reports CUDA when ml.dmlc.xgboost4j.gpu.java.CudfColumn can be loaded.
* XgbNDManager / JniUtils: create an NDArray / DMatrix from an XGBoost ColumnBatch.
* XgbModelTest: adds a version test and a disabled (commented-out) cudf test.
* gradle.properties: xgboost_version 1.6.0 -> 1.6.1 and a new rapis_version property.

diff --git a/engines/ml/xgboost/build.gradle b/engines/ml/xgboost/build.gradle
index 19affcc3a62..9d31691b71a 100644
--- a/engines/ml/xgboost/build.gradle
+++ b/engines/ml/xgboost/build.gradle
@@ -1,6 +1,8 @@
 import groovy.xml.QName
 
 group "ai.djl.ml.xgboost"
+boolean isGpu = project.hasProperty("gpu")
+def XGB_FLAVOR = isGpu ? "-gpu" : ""
 
 configurations {
     exclusion
@@ -9,7 +11,7 @@ configurations {
 dependencies {
     api project(":api")
     api "commons-logging:commons-logging:${commons_logging_version}"
-    api("ml.dmlc:xgboost4j_2.12:${xgboost_version}") {
+    api("ml.dmlc:xgboost4j${XGB_FLAVOR}_2.12:${xgboost_version}") {
         // get rid of the unused XGBoost Dependencies
         exclude group: "org.apache.hadoop", module: "hadoop-hdfs"
         exclude group: "org.apache.hadoop", module: "hadoop-common"
@@ -23,9 +25,13 @@ dependencies {
         exclude group: "org.scala-lang", module: "scala-reflect"
         exclude group: "org.scala-lang", module: "scala-library"
     }
+    if (isGpu) {
+        api "ai.rapids:cudf:${rapis_version}:cuda11"
+    }
 
     exclusion project(":api")
     exclusion "commons-logging:commons-logging:${commons_logging_version}"
+    exclusion "ai.rapids:cudf:${rapis_version}:cuda11"
     testImplementation(project(":testing"))
     testImplementation("org.testng:testng:${testng_version}") {
         exclude group: "junit", module: "junit"
@@ -43,6 +49,7 @@ jar {
                 "ml/dmlc/xgboost4j/java/NativeLibLoader*",
                 "ml/dmlc/xgboost4j/java/XGBoost*",
                 "ml/dmlc/xgboost4j/java/util/*",
+                "ml/dmlc/xgboost4j/gpu/java/*",
                 "ml/dmlc/xgboost4j/LabeledPoint.*",
                 "xgboost4j-version.properties"
     }
@@ -53,6 +60,7 @@ jar {
 publishing {
     publications {
         maven(MavenPublication) {
+            artifactId "${project.name}${XGB_FLAVOR}"
             pom {
                 name = "DJL Engine Adapter for XGBoost"
                 description = "Deep Java Library (DJL) Engine Adapter for XGBoost"
diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java
index 9aa03932a6b..e4a4ebc7045 100644
--- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java
+++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngine.java
@@ -15,9 +15,13 @@
 import ai.djl.Device;
 import ai.djl.Model;
 import ai.djl.engine.Engine;
+import ai.djl.engine.StandardCapabilities;
 import ai.djl.ndarray.NDManager;
 import ai.djl.nn.SymbolBlock;
 import ai.djl.training.GradientCollector;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Properties;
 import ml.dmlc.xgboost4j.java.JniUtils;
 
 /**
@@ -71,12 +75,32 @@ public int getRank() {
     /** {@inheritDoc} */
     @Override
     public String getVersion() {
-        return "1.3.1";
+        try (InputStream is =
+                XgbEngine.class.getResourceAsStream("/xgboost4j-version.properties")) {
+            if (is == null) {
+                // getResourceAsStream returns null for a missing resource; report that
+                // explicitly instead of letting Properties.load fail with a NullPointerException
+                throw new AssertionError("xgboost4j-version.properties not found in classpath");
+            }
+            Properties prop = new Properties();
+            prop.load(is);
+            return prop.getProperty("version");
+        } catch (IOException e) {
+            throw new AssertionError("Failed to load xgboost4j-version.properties", e);
+        }
     }
 
     /** {@inheritDoc} */
     @Override
     public boolean hasCapability(String capability) {
+        if (StandardCapabilities.CUDA.equals(capability)) {
+            try {
+                Class.forName("ml.dmlc.xgboost4j.gpu.java.CudfColumn");
+                return true;
+            } catch (ClassNotFoundException ignore) {
+                return false;
+            }
+        }
         return false;
     }
 
diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
index 87b6af06a30..00b5133e226 100644
--- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
+++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java
@@ -20,11 +20,14 @@
 import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.ndarray.types.SparseFormat;
+import ai.djl.util.JsonUtils;
+import com.google.gson.JsonArray;
 import java.nio.Buffer;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.FloatBuffer;
 import java.util.Arrays;
+import ml.dmlc.xgboost4j.java.ColumnBatch;
 import ml.dmlc.xgboost4j.java.JniUtils;
 
 /** {@code XgbNDManager} is the XGBoost implementation of {@link NDManager}. */
@@ -80,6 +83,30 @@ public Engine getEngine() {
         return Engine.getEngine(XgbEngine.ENGINE_NAME);
     }
 
+    /**
+     * Creates {@link XgbNDArray} from column array interface.
+     *
+     * @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface of feature
+     *     columns
+     * @param missing missing value
+     * @param nthread threads number
+     * @return a new instance of {@link NDArray}
+     */
+    public NDArray create(ColumnBatch columnBatch, float missing, int nthread) {
+        String json = columnBatch.getFeatureArrayInterface();
+        JsonArray array = JsonUtils.GSON.fromJson(json, JsonArray.class);
+        // the shape is read from the "shape" entry of the first array interface element
+        JsonArray shapeJson = array.get(0).getAsJsonObject().get("shape").getAsJsonArray();
+        long[] shapes = new long[shapeJson.size()];
+        for (int i = 0; i < shapes.length; ++i) {
+            shapes[i] = shapeJson.get(i).getAsLong();
+        }
+
+        Shape shape = new Shape(shapes);
+        long handle = JniUtils.createDMatrix(columnBatch, missing, nthread);
+        return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.DENSE);
+    }
+
     /** {@inheritDoc} */
     @Override
     public NDArray create(Buffer data, Shape shape, DataType dataType) {
diff --git a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java
index ce63ad8bf82..e660150acff 100644
--- a/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java
+++ b/engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java
@@ -52,6 +52,28 @@ public static long createDMatrix(Buffer buf, Shape shape, float missing) {
         return handles[0];
     }
 
+    /**
+     * Creates a DMatrix from the feature columns' array interface of a {@code ColumnBatch}.
+     *
+     * @param columnBatch the batch that supplies the feature array interface JSON
+     * @param missing the missing value passed to XGBoost
+     * @param nthread the thread count passed to XGBoost
+     * @return the handle written by the native call
+     * @throws IllegalArgumentException if the feature array interface is null or empty
+     */
+    public static long createDMatrix(ColumnBatch columnBatch, float missing, int nthread) {
+        long[] handles = new long[1];
+        String json = columnBatch.getFeatureArrayInterface();
+        if (json == null || json.isEmpty()) {
+            throw new IllegalArgumentException(
+                    "Expecting non-empty feature columns' array interface");
+        }
+        checkCall(
+                XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
+                        json, missing, nthread, handles));
+        return handles[0];
+    }
+
     public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) {
         long[] handles = new long[1];
         checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles));
diff --git a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java
index 4310b342643..c29853e90c7 100644
--- a/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java
+++ b/engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java
@@ -15,6 +15,7 @@
 
 import ai.djl.MalformedModelException;
 import ai.djl.Model;
+import ai.djl.engine.Engine;
 import ai.djl.inference.Predictor;
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDList;
@@ -45,6 +46,30 @@ public void downloadXGBoostModel() throws IOException {
                 modelDir.resolve("regression.json").toString());
     }
 
+    @Test
+    public void testVersion() {
+        Engine engine = Engine.getEngine("XGBoost");
+        Assert.assertEquals(engine.getVersion(), "1.6.1");
+    }
+
+    /*
+    This test depends on RAPIDS (ai.rapids:cudf), which doesn't work for CPU
+    @Test
+    public void testCudf() {
+        Engine engine = Engine.getEngine("XGBoost");
+        if (engine.hasCapability(StandardCapabilities.CUDA)) {
+            Float[] data = new Float[]{1.f, null, 5.f, 7.f, 9.f};
+            Table x = new Table.TestBuilder().column(data).build();
+            ColumnBatch columnBatch = new CudfColumnBatch(x, null, null, null);
+
+            try (XgbNDManager manager = (XgbNDManager)engine.newBaseManager()) {
+                NDArray array = manager.create(columnBatch,0, 1);
+                Assert.assertEquals(array.getShape().get(0), 5);
+            }
+        }
+    }
+    */
+
     @Test
     public void testLoad() throws MalformedModelException, IOException, TranslateException {
         // TODO: skip for Windows since it is not supported by XGBoost
diff --git a/gradle.properties b/gradle.properties
index 0ff6242d621..23e2cf9d237 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -19,7 +19,8 @@ paddlepaddle_version=2.2.2
 sentencepiece_version=0.1.95
 tokenizers_version=0.11.0
 fasttext_version=0.9.2
-xgboost_version=1.6.0
+xgboost_version=1.6.1
+rapis_version=22.04.0
 
 commons_cli_version=1.5.0
 commons_compress_version=1.21