
[ML][Inference] adding license checks #49056

Merged
@@ -342,10 +342,16 @@ protected Setting<Boolean> roleSetting() {

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
if (this.enabled == false) {
return Collections.emptyMap();
Contributor (@droberts195):
This needs some due diligence. If there was an ingest pipeline in the cluster state containing an inference ingest processor and the cluster was restarted with xpack.ml.enabled: false what would happen?

The tentative plan for restoring full cluster snapshots in Cloud in the future is to disable all X-Pack plugins during the snapshot restore, which will lead to this exact situation.

If there's any doubt about what will happen it might be safer to allow the ingest processors to exist but just have them fail on every document they process (via the failure response from the infer model action) if the license is invalid.
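The fallback suggested here (processors that always load but reject each document while the license is invalid) could be sketched roughly as below. The class and method names are illustrative stand-ins, not the real Elasticsearch `Processor` interface:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.BooleanSupplier;

public class FailingInferenceSketch {

    // Stand-in for an ingest processor; illustrative only.
    static final class LenientInferenceProcessor {
        private final BooleanSupplier licenseValid;

        LenientInferenceProcessor(BooleanSupplier licenseValid) {
            this.licenseValid = licenseValid;
        }

        // The processor itself always exists; only per-document execution
        // fails when the license is invalid, so the pipeline definition in
        // the cluster state stays loadable.
        Map<String, Object> execute(Map<String, Object> doc) {
            if (licenseValid.getAsBoolean() == false) {
                throw new IllegalStateException("current license is non-compliant for [ml]");
            }
            doc.put("regression_value", 0.42); // placeholder inference result
            return doc;
        }
    }

    public static void main(String[] args) {
        LenientInferenceProcessor expired = new LenientInferenceProcessor(() -> false);
        try {
            expired.execute(new HashMap<>());
        } catch (IllegalStateException e) {
            System.out.println("document rejected: " + e.getMessage());
        }
    }
}
```

The trade-off is availability of the pipeline definition versus an immediate, clear failure at pipeline-creation time.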

Member Author (@benwtrent):

@droberts195 aren't xpack.ml.enabled and the license checks two different things? Additionally, this is exactly what the enrich project has done (if enrich is disabled, do not provide the processors). I will reach out to core features to see what they think about this.

Member Author (@benwtrent):

@droberts195 pipelines are stored as maps in the cluster state and are not fully marshaled until the pipeline is created on the ingest node. So, cluster state restoration and pipeline instantiation are two different steps.

> If there was an ingest pipeline in the cluster state containing an inference ingest processor and the cluster was restarted with xpack.ml.enabled: false what would happen?

The cluster will start up fine, the pipeline will just fail to be instantiated on the node.
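The two-step behavior described here (raw pipeline config surviving in the cluster state, with instantiation as a separate, failable step) can be sketched as below. This is an illustrative model, not the real `IngestService` API:

```java
import java.util.HashMap;
import java.util.Map;

public class LazyPipelineSketch {

    // Stored form: raw configuration maps, which is what survives a cluster
    // restart regardless of which plugins are installed.
    static final Map<String, Map<String, Object>> clusterStatePipelines = new HashMap<>();

    static void restoreClusterState() {
        clusterStatePipelines.put("infer-pipeline", Map.of("processor_type", "inference"));
    }

    // Instantiation consults the installed processor factories, and is the
    // only step that can fail when a plugin (and its factories) is disabled.
    static void instantiate(String id, Map<String, Object> installedFactories) {
        Object type = clusterStatePipelines.get(id).get("processor_type");
        if (installedFactories.containsKey(type) == false) {
            throw new IllegalArgumentException("no processor type exists with name [" + type + "]");
        }
    }

    public static void main(String[] args) {
        restoreClusterState(); // succeeds even with the plugin disabled
        try {
            instantiate("infer-pipeline", Map.of()); // ml disabled: no factories registered
        } catch (IllegalArgumentException e) {
            System.out.println("pipeline creation failed: " + e.getMessage());
        }
    }
}
```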

Contributor (@droberts195):

OK great. Sounds like it's not a problem then and the code can stay as it is.

}

InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client,
    parameters.ingestService.getClusterService(),
    this.settings,
-   parameters.ingestService);
+   parameters.ingestService,
+   getLicenseState());
getLicenseState().addListener(inferenceFactory);
parameters.ingestService.addIngestClusterStateListener(inferenceFactory);
return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory);
}
@@ -10,9 +10,12 @@
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
@@ -24,20 +27,28 @@ public class TransportInferModelAction extends HandledTransportAction<InferModel

private final ModelLoadingService modelLoadingService;
private final Client client;
private final XPackLicenseState licenseState;

@Inject
public TransportInferModelAction(TransportService transportService,
ActionFilters actionFilters,
ModelLoadingService modelLoadingService,
-   Client client) {
+   Client client,
+   XPackLicenseState licenseState) {
super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new);
this.modelLoadingService = modelLoadingService;
this.client = client;
this.licenseState = licenseState;
}

@Override
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {

if (licenseState.isMachineLearningAllowed() == false) {
Contributor:

Do other actions (get, delete) already have this license check?

Member Author (@benwtrent):

No, and I don't think they should either. They don't really provide licensed value IMO. This is similar to how we treat anomaly detection jobs as well.

listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
return;
}

ActionListener<Model> getModelListener = ActionListener.wrap(
model -> {
TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor =
@@ -26,7 +26,11 @@
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.LicenseStateListener;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -144,20 +148,28 @@ public String getType() {
return TYPE;
}

- public static final class Factory implements Processor.Factory, Consumer<ClusterState> {
+ public static final class Factory implements Processor.Factory, Consumer<ClusterState>, LicenseStateListener {

private static final Logger logger = LogManager.getLogger(Factory.class);

private final Client client;
private final IngestService ingestService;
private final XPackLicenseState licenseState;
private volatile int currentInferenceProcessors;
private volatile int maxIngestProcessors;
private volatile Version minNodeVersion = Version.CURRENT;
private volatile boolean inferenceAllowed;

- public Factory(Client client, ClusterService clusterService, Settings settings, IngestService ingestService) {
+ public Factory(Client client,
+                ClusterService clusterService,
+                Settings settings,
+                IngestService ingestService,
+                XPackLicenseState licenseState) {
this.client = client;
this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
this.ingestService = ingestService;
this.licenseState = licenseState;
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
}

@@ -199,6 +211,10 @@ int numInferenceProcessors() {
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
throws Exception {

if (inferenceAllowed == false) {
throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING);
}

if (this.maxIngestProcessors <= currentInferenceProcessors) {
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
"Adjust the setting [{}]: [{}] if a greater number is desired.",
@@ -272,5 +288,10 @@ void checkSupportedVersion(InferenceConfig config) {
minNodeVersion));
}
}

@Override
public void licenseStateChanged() {
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
Contributor:

What is the mechanism that updates licenseState?
Am I right that it is a mutable object, and that the license framework invokes the licenseStateChanged method after the internal state of the licenseState object changes?

Member Author (@benwtrent):

Yes, it is a singleton that gets mutated in place. See other examples in org.elasticsearch.xpack.ml.InvalidLicenseEnforcer.

}
}
}
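The listener mechanism discussed above (a single mutable license-state object that notifies registered listeners after it changes, each listener then re-reading the shared state) can be sketched as follows. The names mirror the PR but are illustrative, not the real XPackLicenseState API:

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

public class LicenseStateSketch {

    interface LicenseStateListener {
        void licenseStateChanged();
    }

    // Singleton-style, mutated in place; listeners are notified after updates.
    static final class MutableLicenseState {
        private final List<LicenseStateListener> listeners = new CopyOnWriteArrayList<>();
        private volatile boolean mlAllowed = true;

        boolean isMachineLearningAllowed() {
            return mlAllowed;
        }

        void addListener(LicenseStateListener listener) {
            listeners.add(listener);
        }

        void update(boolean allowed) {
            this.mlAllowed = allowed;
            listeners.forEach(LicenseStateListener::licenseStateChanged);
        }
    }

    // Mirrors the processor factory: caches the flag, refreshes on callback.
    static final class Factory implements LicenseStateListener {
        private final MutableLicenseState licenseState;
        volatile boolean inferenceAllowed;

        Factory(MutableLicenseState licenseState) {
            this.licenseState = licenseState;
            this.inferenceAllowed = licenseState.isMachineLearningAllowed();
        }

        @Override
        public void licenseStateChanged() {
            this.inferenceAllowed = licenseState.isMachineLearningAllowed();
        }
    }

    public static void main(String[] args) {
        MutableLicenseState state = new MutableLicenseState();
        Factory factory = new Factory(state);
        state.addListener(factory);
        state.update(false);
        System.out.println(factory.inferenceAllowed); // prints false
    }
}
```

Caching the boolean in a volatile field keeps the per-document hot path to a single field read instead of a call into the shared license object.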
@@ -6,9 +6,16 @@
package org.elasticsearch.license;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.license.License.OperationMode;
@@ -20,22 +27,30 @@
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;

import java.nio.charset.StandardCharsets;
import java.util.Collections;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;

public class MachineLearningLicensingTests extends BaseMlIntegTestCase {

@@ -453,6 +468,214 @@ public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Except
listener.actionGet();
}

public void testMachineLearningCreateInferenceProcessorRestricted() {
String modelId = "modelprocessorlicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);

String pipeline = "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"modelprocessorlicensetest\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
// test that license-restricted APIs work while the license is valid
PlainActionFuture<AcknowledgedResponse> putPipelineListener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
putPipelineListener);
AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet();
assertTrue(putPipelineResponse.isAcknowledged());

String simulateSource = "{\n" +
" \"pipeline\": \n" +
pipeline +
" ,\n" +
" \"docs\": [\n" +
" {\"_source\": {\n" +
" \"col1\": \"female\",\n" +
" \"col2\": \"M\",\n" +
" \"col3\": \"none\",\n" +
" \"col4\": 10\n" +
" }}]\n" +
"}";
PlainActionFuture<SimulatePipelineResponse> simulatePipelineListener = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
simulatePipelineListener);

assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty())));


// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);

// creating a new pipeline should fail
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<AcknowledgedResponse> listener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline_failure",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));

// Simulating the pipeline should fail
e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<SimulatePipelineResponse> listener = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));

// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
// test that license-restricted APIs work again now that the license is valid
PlainActionFuture<AcknowledgedResponse> putPipelineListenerNewLicense = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
putPipelineListenerNewLicense);
AcknowledgedResponse putPipelineResponseNewLicense = putPipelineListenerNewLicense.actionGet();
assertTrue(putPipelineResponseNewLicense.isAcknowledged());

PlainActionFuture<SimulatePipelineResponse> simulatePipelineListenerNewLicense = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
simulatePipelineListenerNewLicense);

assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty())));
}

public void testMachineLearningInferModelRestricted() throws Exception {
String modelId = "modelinfermodellicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);


PlainActionFuture<InferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), inferModelSuccess);
assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty())));

// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);

// inferring against a model should now fail
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));

// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);

PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), listener);
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}

private void putInferenceModel(String modelId) {
String config = "" +
"{\n" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"8.0.0\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"created_time\": 0\n" +
"}";
String definition = "" +
"{" +
" \"trained_model\": {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }," +
" \"model_id\": \"" + modelId + "\"\n" +
"}";
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
.setId(modelId)
.setSource(config, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
.setId(TrainedModelDefinition.docId(modelId))
.setSource(definition, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
}

private static OperationMode randomInvalidLicenseType() {
return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC);
}