diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java new file mode 100644 index 00000000000..eb7d3e699e0 --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java @@ -0,0 +1,64 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util.passthrough; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrayAdapter; +import java.nio.ByteBuffer; + +/** + * An {@link NDArray} that stores an arbitrary Java object. + * + *

This class is mainly for use in extensions and hybrid engines. Despite its name, it will + * often not contain actual {@link NDArray}s but just any object necessary to conform to the DJL + * predictor API. + */ +public class PassthroughNDArray extends NDArrayAdapter { + + private Object object; + + /** + * Constructs a {@link PassthroughNDArray} storing an object. + * + * @param object the object to store + */ + public PassthroughNDArray(Object object) { + super(null, null, null, null, null); + this.object = object; + } + + /** + * Returns the object stored. + * + * @return the object stored + */ + public Object getObject() { + return object; + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + throw new UnsupportedOperationException("Operation not supported for FastText"); + } + + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + throw new UnsupportedOperationException("Operation not supported for FastText"); + } + + /** {@inheritDoc} */ + @Override + public void detach() {} +} diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java new file mode 100644 index 00000000000..493b2fa0e9a --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -0,0 +1,209 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.util.passthrough; + +import ai.djl.Device; +import ai.djl.engine.Engine; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDResource; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.util.PairList; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.file.Path; + +/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */ +public final class PassthroughNDManager implements NDManager { + + private static final String UNSUPPORTED = "Not supported by PassthroughNDManager"; + public static final PassthroughNDManager INSTANCE = new PassthroughNDManager(); + + private PassthroughNDManager() {} + + @Override + public Device defaultDevice() { + return Device.cpu(); + } + + @Override + public ByteBuffer allocateDirect(int capacity) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray from(NDArray array) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray create(String[] data, Charset charset, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray create(Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDList load(Path path) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + 
public void setName(String name) {} + + @Override + public String getName() { + return "PassthroughNDManager"; + } + + @Override + public NDArray zeros(Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray ones(Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray full(Shape shape, float value, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray arange(float start, float stop, float step, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray eye(int rows, int cols, int k, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray linspace(float start, float stop, int num, boolean endpoint) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomMultinomial(int n, NDArray pValues) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public NDManager 
getParentManager() { + return this; + } + + @Override + public NDManager newSubManager() { + return this; + } + + @Override + public NDManager newSubManager(Device device) { + return this; + } + + @Override + public Device getDevice() { + return Device.cpu(); + } + + @Override + public void attachInternal(String resourceId, AutoCloseable resource) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void tempAttachInternal( + NDManager originalManager, String resourceId, NDResource resource) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void detachInternal(String resourceId) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void invoke( + String operation, NDArray[] src, NDArray[] dest, PairList params) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDList invoke(String operation, NDList src, PairList params) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public Engine getEngine() { + return null; + } + + @Override + public void close() {} +} diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java new file mode 100644 index 00000000000..55f6692a2b9 --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util.passthrough; + +import ai.djl.ndarray.NDList; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; + +/** + * A translator that stores and removes data from a {@link PassthroughNDArray}. + * + * @param <I> translator input type + * @param <O> translator output type + */ +public class PassthroughTranslator implements NoBatchifyTranslator { + + @Override + public NDList processInput(TranslatorContext ctx, I input) throws Exception { + return new NDList(new PassthroughNDArray(input)); + } + + @Override + @SuppressWarnings("unchecked") + public O processOutput(TranslatorContext ctx, NDList list) { + PassthroughNDArray wrapper = (PassthroughNDArray) list.singletonOrThrow(); + return (O) wrapper.getObject(); + } +} diff --git a/api/src/main/java/ai/djl/util/passthrough/package-info.java b/api/src/main/java/ai/djl/util/passthrough/package-info.java new file mode 100644 index 00000000000..62a0fd37ce9 --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains passthrough DJL classes for use in extensions and hybrid engines. 
*/ +package ai.djl.util.passthrough; diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md index 64f582f8b07..f39e64b5730 100644 --- a/extensions/fasttext/README.md +++ b/extensions/fasttext/README.md @@ -5,8 +5,10 @@ This module contains the NLP support with fastText implementation. fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor. -The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html) -class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). +Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html). +This produces a special block which can perform inference on its own or by using a model and predictor. +Pre-trained FastText models can also be loaded by using the standard DJL criteria. +You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). Current implementation has the following limitations: diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java new file mode 100644 index 00000000000..c7ffec04a95 --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java @@ -0,0 +1,63 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext; + +import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.nn.AbstractSymbolBlock; +import java.nio.file.Path; + +/** + * A parent class containing shared behavior for {@link ai.djl.nn.SymbolBlock}s based on fasttext + * models. + */ +public abstract class FtAbstractBlock extends AbstractSymbolBlock implements AutoCloseable { + + protected FtWrapper fta; + + protected Path modelFile; + + /** + * Constructs a {@link FtAbstractBlock}. + * + * @param fta the {@link FtWrapper} containing the "fasttext model" + */ + public FtAbstractBlock(FtWrapper fta) { + this.fta = fta; + } + + /** + * Returns the fasttext model file for the block. + * + * @return the fasttext model file for the block + */ + public Path getModelFile() { + return modelFile; + } + + /** + * Embeds a word using fasttext. 
+ * + * @param word the word to embed + * @return the embedding + * @see ai.djl.modality.nlp.embedding.WordEmbedding + */ + public float[] embedWord(String word) { + return fta.getWordVector(word); + } + + @Override + public void close() { + fta.unloadModel(); + fta.close(); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java index fed7a38b4d8..d7b4451b739 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java @@ -15,19 +15,19 @@ import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.Model; -import ai.djl.basicdataset.RawDataset; import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock; import ai.djl.inference.Predictor; -import ai.djl.modality.Classifications; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; -import ai.djl.training.TrainingResult; import ai.djl.translate.Translator; import ai.djl.util.PairList; +import ai.djl.util.passthrough.PassthroughNDManager; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; @@ -41,11 +41,12 @@ /** * {@code FtModel} is the fastText implementation of {@link Model}. * - *

FtModel contains all the methods in Model to load and process a model. + *

FtModel contains all the methods in Model to load and process a model. However, it only + * supports training by using {@link TrainFastText}. */ public class FtModel implements Model { - FtWrapper fta; + FtAbstractBlock block; private Path modelDir; private String modelName; @@ -58,7 +59,6 @@ public class FtModel implements Model { */ public FtModel(String name) { this.modelName = name; - fta = FtWrapper.newInstance(); properties = new ConcurrentHashMap<>(); } @@ -80,6 +80,7 @@ public void load(Path modelPath, String prefix, Map options) } String modelFilePath = modelFile.toString(); + FtWrapper fta = FtWrapper.newInstance(); if (!fta.checkModel(modelFilePath)) { throw new MalformedModelException("Malformed FastText model file:" + modelFilePath); } @@ -90,7 +91,21 @@ public void load(Path modelPath, String prefix, Map options) properties.put(entry.getKey(), entry.getValue().toString()); } } - properties.put("model-type", fta.getModelType()); + String modelType = fta.getModelType(); + properties.put("model-type", modelType); + + if ("sup".equals(modelType)) { + String labelPrefix = + properties.getOrDefault( + "label-prefix", FtTextClassification.DEFAULT_LABEL_PREFIX); + block = new FtTextClassification(fta, labelPrefix); + modelDir = block.getModelFile(); + } else if ("cbow".equals(modelType) || "sg".equals(modelType)) { + block = new FtWordEmbeddingBlock(fta); + modelDir = block.getModelFile(); + } else { + throw new MalformedModelException("Unexpected FastText model type: " + modelType); + } } /** {@inheritDoc} */ @@ -130,49 +145,6 @@ private Path findModelFile(String prefix) { return modelFile; } - /** - * Returns top K number of classifications of the input text. 
- * - * @param text the input text to be classified - * @param topK the value of K - * @return classifications of the input text - */ - public Classifications classify(String text, int topK) { - String labelPrefix = properties.getOrDefault("label-prefix", "__label__"); - return fta.predictProba(text, topK, labelPrefix); - } - - /** - * Train the fastText model. - * - * @param config the training configuration to use - * @param dataset the training dataset - * @return the result of the training - * @throws IOException when IO operation fails in loading a resource - */ - public TrainingResult fit(FtTrainingConfig config, RawDataset dataset) - throws IOException { - Path outputDir = config.getOutputDir(); - if (Files.notExists(outputDir)) { - Files.createDirectory(outputDir); - } - String fitModelName = config.getModelName(); - Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath(); - - String[] args = config.toCommand(dataset.getData().toString()); - - fta.runCmd(args); - setModelFile(modelFile); - - TrainingResult result = new TrainingResult(); - int epoch = config.getEpoch(); - if (epoch <= 0) { - epoch = 5; - } - result.setEpoch(epoch); - return result; - } - /** {@inheritDoc} */ @Override public void save(Path modelDir, String newModelName) {} @@ -185,14 +157,17 @@ public Path getModelPath() { /** {@inheritDoc} */ @Override - public Block getBlock() { - throw new UnsupportedOperationException("Fasttext doesn't support Block."); + public FtAbstractBlock getBlock() { + return block; } /** {@inheritDoc} */ @Override public void setBlock(Block block) { - throw new UnsupportedOperationException("Fasttext doesn't support setting the Block."); + if (!(block instanceof FtAbstractBlock)) { + throw new IllegalArgumentException("Expected a FtAbstractBlock Block"); + } + this.block = (FtAbstractBlock) block; } /** {@inheritDoc} */ @@ -205,7 +180,7 @@ public String getName() { @Override public Trainer newTrainer(TrainingConfig trainingConfig) { throw new 
UnsupportedOperationException( - "FastText only supports training using FtModel.fit"); + "FastText only supports training using the FtAbstractBlocks"); } /** {@inheritDoc} */ @@ -263,7 +238,7 @@ public InputStream getArtifactAsStream(String name) { /** {@inheritDoc} */ @Override public NDManager getNDManager() { - return null; + return PassthroughNDManager.INSTANCE; } /** {@inheritDoc} */ @@ -278,15 +253,10 @@ public String getProperty(String key) { return properties.get(key); } - void setModelFile(Path modelFile) { - this.modelDir = modelFile; - } - /** {@inheritDoc} */ @Override public void close() { - fta.unloadModel(); - fta.close(); + block.close(); } /** {@inheritDoc} */ diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java new file mode 100644 index 00000000000..6250063932a --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext; + +import ai.djl.basicdataset.RawDataset; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import java.io.IOException; +import java.nio.file.Path; + +/** A utility to aggregate options for training with fasttext. 
*/ +public final class TrainFastText { + + private TrainFastText() {} + + /** + * Trains a fastText {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION} model. + * + * @param config the training configuration to use + * @param dataset the training dataset + * @return the trained {@link FtTextClassification} block + * @throws IOException when IO operation fails in loading a resource + * @see FtTextClassification#fit(FtTrainingConfig, RawDataset) + */ + public static FtTextClassification textClassification( + FtTrainingConfig config, RawDataset dataset) throws IOException { + return FtTextClassification.fit(config, dataset); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java index 42fecd3e219..d874d3ad1aa 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java @@ -12,6 +12,8 @@ */ package ai.djl.fasttext.jni; +import java.util.ArrayList; + /** A class containing utilities to interact with the SentencePiece Engine's JNI layer. 
*/ @SuppressWarnings("MissingJavadocMethod") final class FastTextLibrary { @@ -32,8 +34,13 @@ private FastTextLibrary() {} native String getModelType(long handle); + @SuppressWarnings("PMD.LooseCoupling") native int predictProba( - long handle, String text, int topK, String[] classes, float[] probabilities); + long handle, + String text, + int topK, + ArrayList classes, + ArrayList probabilities); native float[] getWordVector(long handle, String word); diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java index 7bbfcb1fe58..03d9499698d 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java @@ -59,25 +59,26 @@ public String getModelType() { } public Classifications predictProba(String text, int topK, String labelPrefix) { - String[] labels = new String[topK]; - float[] probs = new float[topK]; + int cap = topK != -1 ? 
topK : 10; + ArrayList labels = new ArrayList<>(cap); + ArrayList probs = new ArrayList<>(cap); int size = FastTextLibrary.LIB.predictProba(getHandle(), text, topK, labels, probs); List classes = new ArrayList<>(size); List probabilities = new ArrayList<>(size); for (int i = 0; i < size; ++i) { - String label = labels[i]; + String label = labels.get(i); if (label.startsWith(labelPrefix)) { label = label.substring(labelPrefix.length()); } classes.add(label); - probabilities.add((double) probs[i]); + probabilities.add((double) probs.get(i)); } return new Classifications(classes, probabilities); } - public float[] getDataVector(String word) { + public float[] getWordVector(String word) { return FastTextLibrary.LIB.getWordVector(getHandle(), word); } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java new file mode 100644 index 00000000000..c19b1370bfb --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java @@ -0,0 +1,144 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.fasttext.zoo.nlp.textclassification; + +import ai.djl.basicdataset.RawDataset; +import ai.djl.fasttext.FtAbstractBlock; +import ai.djl.fasttext.FtTrainingConfig; +import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock; +import ai.djl.modality.Classifications; +import ai.djl.ndarray.NDList; +import ai.djl.training.ParameterStore; +import ai.djl.training.TrainingResult; +import ai.djl.util.PairList; +import ai.djl.util.passthrough.PassthroughNDArray; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +/** A {@link FtAbstractBlock} for {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION}. */ +public class FtTextClassification extends FtAbstractBlock { + + public static final String DEFAULT_LABEL_PREFIX = "__label__"; + + private String labelPrefix; + + private TrainingResult trainingResult; + + /** + * Constructs a {@link FtTextClassification}. + * + * @param fta the {@link FtWrapper} containing the "fasttext model" + * @param labelPrefix the prefix to use for labels + */ + public FtTextClassification(FtWrapper fta, String labelPrefix) { + super(fta); + this.labelPrefix = labelPrefix; + } + + /** + * Trains the fastText model. 
+ * + * @param config the training configuration to use + * @param dataset the training dataset + * @return the trained block; see {@link #getTrainingResult()} for the result of the training + * @throws IOException when IO operation fails in loading a resource + */ + public static FtTextClassification fit(FtTrainingConfig config, RawDataset dataset) + throws IOException { + Path outputDir = config.getOutputDir(); + if (Files.notExists(outputDir)) { + Files.createDirectory(outputDir); + } + String fitModelName = config.getModelName(); + FtWrapper fta = FtWrapper.newInstance(); + Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath(); + + String[] args = config.toCommand(dataset.getData().toString()); + + fta.runCmd(args); + + TrainingResult result = new TrainingResult(); + int epoch = config.getEpoch(); + if (epoch <= 0) { + epoch = 5; + } + result.setEpoch(epoch); + + FtTextClassification block = new FtTextClassification(fta, config.getLabelPrefix()); + block.modelFile = modelFile; + block.trainingResult = result; + return block; + } + + /** + * Returns the fasttext label prefix. + * + * @return the fasttext label prefix + */ + public String getLabelPrefix() { + return labelPrefix; + } + + /** + * Returns the results of training, or null if not trained. + * + * @return the results of training, or null if not trained + */ + public TrainingResult getTrainingResult() { + return trainingResult; + } + + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + PassthroughNDArray inputWrapper = (PassthroughNDArray) inputs.singletonOrThrow(); + String input = (String) inputWrapper.getObject(); + Classifications result = fta.predictProba(input, -1, labelPrefix); + return new NDList(new PassthroughNDArray(result)); + } + + /** + * Converts the block into the equivalent {@link FtWordEmbeddingBlock}. 
+ * + * @return the equivalent {@link FtWordEmbeddingBlock} + */ + public FtWordEmbeddingBlock toWordEmbedding() { + return new FtWordEmbeddingBlock(fta); + } + + /** + * Returns the classifications of the input text. + * + * @param text the input text to be classified + * @return classifications of the input text + */ + public Classifications classify(String text) { + return classify(text, -1); + } + + /** + * Returns top K classifications of the input text. + * + * @param text the input text to be classified + * @param topK the value of K + * @return classifications of the input text + */ + public Classifications classify(String text, int topK) { + return fta.predictProba(text, topK, labelPrefix); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java index 06c863301bd..4219420c585 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java @@ -24,6 +24,7 @@ import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; import ai.djl.util.Progress; +import ai.djl.util.passthrough.PassthroughTranslator; import java.io.IOException; import java.nio.file.Path; @@ -66,6 +67,6 @@ public ZooModel loadModel(Criteria criteria) Model model = new FtModel(modelName); Path modelPath = mrl.getRepository().getResourceDirectory(artifact); model.load(modelPath, modelName, criteria.getOptions()); - return new ZooModel<>(model, null); + return new ZooModel<>(model, new PassthroughTranslator<>()); } } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java 
b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java similarity index 64% rename from extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java rename to extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java index 18e765092f2..230079d79cd 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java @@ -10,27 +10,50 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.fasttext; +package ai.djl.fasttext.zoo.nlp.word_embedding; +import ai.djl.Model; +import ai.djl.fasttext.FtAbstractBlock; +import ai.djl.fasttext.FtModel; import ai.djl.modality.nlp.Vocabulary; import ai.djl.modality.nlp.embedding.WordEmbedding; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.ZooModel; /** An implementation of {@link WordEmbedding} for FastText word embeddings. */ public class FtWord2VecWordEmbedding implements WordEmbedding { - private FtModel model; + private FtAbstractBlock embedding; private Vocabulary vocabulary; /** * Constructs a {@link FtWord2VecWordEmbedding}. 
* - * @param model a loaded FastText model + * @param model a loaded FastText wordEmbedding model or a ZooModel containing one * @param vocabulary the {@link Vocabulary} to get indices from */ - public FtWord2VecWordEmbedding(FtModel model, Vocabulary vocabulary) { - this.model = model; + public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) { + if (model instanceof ZooModel) { + model = ((ZooModel) model).getWrappedModel(); + } + + if (!(model instanceof FtModel)) { + throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel"); + } + + this.embedding = (FtAbstractBlock) model.getBlock(); + this.vocabulary = vocabulary; + } + + /** + * Constructs a {@link FtWord2VecWordEmbedding}. + * + * @param embedding the word embedding + * @param vocabulary the {@link Vocabulary} to get indices from + */ + public FtWord2VecWordEmbedding(FtAbstractBlock embedding, Vocabulary vocabulary) { + this.embedding = embedding; this.vocabulary = vocabulary; } @@ -56,7 +79,7 @@ public NDArray embedWord(NDArray index) { @Override public NDArray embedWord(NDManager manager, long index) { String word = vocabulary.getToken(index); - float[] buf = model.fta.getDataVector(word); + float[] buf = embedding.embedWord(word); return manager.create(buf); } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java new file mode 100644 index 00000000000..8f18558858d --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java @@ -0,0 +1,45 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext.zoo.nlp.word_embedding; + +import ai.djl.fasttext.FtAbstractBlock; +import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.ndarray.NDList; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; +import ai.djl.util.passthrough.PassthroughNDArray; + +/** A {@link FtAbstractBlock} for {@link ai.djl.Application.NLP#WORD_EMBEDDING}. */ +public class FtWordEmbeddingBlock extends FtAbstractBlock { + + /** + * Constructs a {@link FtWordEmbeddingBlock}. + * + * @param fta the {@link FtWrapper} for the "fasttext model". + */ + public FtWordEmbeddingBlock(FtWrapper fta) { + super(fta); + } + + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + PassthroughNDArray inputWrapper = (PassthroughNDArray) inputs.singletonOrThrow(); + String input = (String) inputWrapper.getObject(); + float[] result = embedWord(input); + return new NDList(new PassthroughNDArray(result)); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java new file mode 100644 index 00000000000..4bbd761e559 --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** + * Contains classes for the {@link ai.djl.Application.NLP#WORD_EMBEDDING} models in the {@link + * ai.djl.fasttext.zoo.FtModelZoo}. + */ +package ai.djl.fasttext.zoo.nlp.word_embedding; diff --git a/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc b/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc index c4cb4354cee..8a14c1078b2 100644 --- a/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc +++ b/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc @@ -97,9 +97,9 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType( if (modelName == model_name::cbow) { return env->NewStringUTF("cbow"); } else if (modelName == model_name::sg) { - return env->NewStringUTF("cbow"); + return env->NewStringUTF("sg"); } else if (modelName == model_name::sup) { - return env->NewStringUTF("cbow"); + return env->NewStringUTF("sup"); } else { jclass jexception = env->FindClass("ai/djl/engine/EngineException"); env->ThrowNew(jexception, "Unrecognized model type"); @@ -108,7 +108,7 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType( } JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba( - JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobjectArray jclasses, jfloatArray jprob) { + JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobject jclasses, jobject jprob) { auto* fasttext_ptr = reinterpret_cast(jhandle); std::string text = djl::utils::jni::GetStringFromJString(env, jtext); std::istringstream 
in(text); @@ -116,13 +116,15 @@ JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba( fasttext_ptr->predictLine(in, predictions, top_k, 0.0); int size = predictions.size(); - std::vector prob; + jclass java_lang_Float = static_cast(env->NewGlobalRef(env->FindClass("java/lang/Float"))); + jmethodID java_lang_Float_ = env->GetMethodID(java_lang_Float, "", "(F)V"); + jclass java_util_ArrayList = static_cast(env->NewGlobalRef(env->FindClass("java/util/ArrayList"))); + jmethodID java_util_ArrayList_add = env->GetMethodID(java_util_ArrayList, "add", "(Ljava/lang/Object;)Z"); for (int i = 0; i < size; ++i) { std::pair pair = predictions[i]; - env->SetObjectArrayElement(jclasses, i, env->NewStringUTF(pair.second.c_str())); - prob.push_back(pair.first); + env->CallBooleanMethod(jclasses, java_util_ArrayList_add, env->NewStringUTF(pair.second.c_str())); + env->CallBooleanMethod(jprob, java_util_ArrayList_add, env->NewObject(java_lang_Float, java_lang_Float_, pair.first)); } - env->SetFloatArrayRegion(jprob, 0, size, prob.data()); return size; } diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 0ab6d1180d4..c0bb2fc3485 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -12,19 +12,26 @@ */ package ai.djl.fasttext; +import ai.djl.Application; import ai.djl.MalformedModelException; import ai.djl.ModelException; import ai.djl.basicdataset.nlp.CookingStackExchange; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import ai.djl.fasttext.zoo.nlp.word_embedding.FtWord2VecWordEmbedding; +import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; import ai.djl.modality.nlp.DefaultVocabulary; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import 
ai.djl.ndarray.types.Shape; +import ai.djl.repository.Artifact; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.testing.TestRequirements; import ai.djl.training.TrainingResult; +import ai.djl.translate.TranslateException; import java.io.IOException; import java.io.InputStream; import java.net.URL; @@ -33,6 +40,8 @@ import java.nio.file.Paths; import java.nio.file.StandardCopyOption; import java.util.Collections; +import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -43,30 +52,30 @@ public class CookingStackExchangeTest { private static final Logger logger = LoggerFactory.getLogger(CookingStackExchangeTest.class); @Test - public void testTrainTextClassification() throws IOException { + public void testTrainTextClassification() throws IOException, TranslateException { TestRequirements.notWindows(); // fastText is not supported on windows - try (FtModel model = new FtModel("cooking")) { - CookingStackExchange dataset = CookingStackExchange.builder().build(); - - // setup training configuration - FtTrainingConfig config = - FtTrainingConfig.builder() - .setOutputDir(Paths.get("build")) - .setModelName("cooking") - .optEpoch(5) - .optLoss(FtTrainingConfig.FtLoss.HS) - .build(); - - TrainingResult result = model.fit(config, dataset); - Assert.assertEquals(result.getEpoch(), 5); - Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin"))); - } + CookingStackExchange dataset = CookingStackExchange.builder().build(); + + // setup training configuration + FtTrainingConfig config = + FtTrainingConfig.builder() + .setOutputDir(Paths.get("build")) + .setModelName("cooking") + .optEpoch(5) + .optLoss(FtTrainingConfig.FtLoss.HS) + .build(); + + FtTextClassification block = TrainFastText.textClassification(config, dataset); + TrainingResult result = 
block.getTrainingResult(); + Assert.assertEquals(result.getEpoch(), 5); + Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin"))); } @Test public void testTextClassification() - throws IOException, MalformedModelException, ModelNotFoundException { + throws IOException, MalformedModelException, ModelNotFoundException, + TranslateException { TestRequirements.notWindows(); // fastText is not supported on windows Criteria criteria = @@ -75,11 +84,18 @@ public void testTextClassification() .optArtifactId("ai.djl.fasttext:cooking_stackexchange") .optOption("label-prefix", "__label") .build(); + Map> models = ModelZoo.listModels(criteria); + models.forEach( + (app, list) -> { + String appName = app.toString(); + list.forEach(artifact -> logger.info("{} {}", appName, artifact)); + }); try (ZooModel model = criteria.loadModel()) { String input = "Which baking dish is best to bake a banana bread ?"; - FtModel ftModel = (FtModel) model.getWrappedModel(); - Classifications result = ftModel.classify(input, 8); - Assert.assertEquals(result.item(0).getClassName(), "__bread"); + try (Predictor predictor = model.newPredictor()) { + Classifications result = predictor.predict(input); + Assert.assertEquals(result.item(0).getClassName(), "__bread"); + } } } @@ -95,10 +111,9 @@ public void testWord2Vec() throws IOException, MalformedModelException, ModelNot try (ZooModel model = criteria.loadModel(); NDManager manager = NDManager.newBaseManager()) { - FtModel ftModel = (FtModel) model.getWrappedModel(); FtWord2VecWordEmbedding fasttextWord2VecWordEmbedding = new FtWord2VecWordEmbedding( - ftModel, new DefaultVocabulary(Collections.singletonList("bread"))); + model, new DefaultVocabulary(Collections.singletonList("bread"))); long index = fasttextWord2VecWordEmbedding.preprocessWordToEmbed("bread"); NDArray embedding = fasttextWord2VecWordEmbedding.embedWord(manager, index); Assert.assertEquals(embedding.getShape(), new Shape(100)); @@ -125,7 +140,7 @@ public void 
testBlazingText() throws IOException, ModelException { model.load(modelFile); String text = "Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft ."; - Classifications result = model.classify(text, 5); + Classifications result = ((FtTextClassification) model.getBlock()).classify(text, 5); logger.info("{}", result); Assert.assertEquals(result.item(0).getClassName(), "Company");