Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] Adding inference ingest processor #47859

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
public class InferModelAction extends ActionType<InferModelAction.Response> {

public static final InferModelAction INSTANCE = new InferModelAction();
public static final String NAME = "cluster:admin/xpack/ml/infer";
public static final String NAME = "cluster:admin/xpack/ml/inference/infer";

private InferModelAction() {
super(NAME, Response::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,38 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class ClassificationConfig implements InferenceConfig {

public static final String NAME = "classification";

public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0;

public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);

private final int numTopClasses;

public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses);
}

public ClassificationConfig(Integer numTopClasses) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
}
Expand Down Expand Up @@ -78,4 +92,9 @@ public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.CLASSIFICATION.equals(targetType);
}

@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

Expand All @@ -13,4 +14,5 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {

boolean isTargetTypeSupported(TargetType targetType);

Version getMinimalSupportedVersion();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining the need for this?

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,27 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

public class RegressionConfig implements InferenceConfig {

public static final String NAME = "regression";
private static final Version MIN_SUPPORTED_VERSION = Version.V_8_0_0;

public static RegressionConfig fromMap(Map<String, Object> map) {
if (map.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
}
return new RegressionConfig();
}

public RegressionConfig() {
}
Expand Down Expand Up @@ -61,4 +72,9 @@ public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.REGRESSION.equals(targetType);
}

@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
Expand All @@ -26,6 +27,11 @@ public boolean isTargetTypeSupported(TargetType targetType) {
return true;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.CURRENT;
}

@Override
public String getWriteableName() {
return "null";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ public final class Messages {
public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL =
"Failed to serialize the trained model [{0}] for storage";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";

public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,35 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.Collections;

import static org.hamcrest.Matchers.equalTo;

public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {

public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
}

public void testFromMap() {
ClassificationConfig expected = new ClassificationConfig(0);
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));

expected = new ClassificationConfig(3);
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
equalTo(expected));
}

public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}

@Override
protected ClassificationConfig createTestInstance() {
return randomClassificationConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.Collections;

import static org.hamcrest.Matchers.equalTo;

public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {

public static RegressionConfig randomRegressionConfig() {
return new RegressionConfig();
}

public void testFromMap() {
RegressionConfig expected = new RegressionConfig();
assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected));
}

public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}

@Override
protected RegressionConfig createTestInstance() {
return randomRegressionConfig();
Expand Down
Loading