diff --git a/server/src/internalClusterTest/java/org/opensearch/script/ScriptCacheIT.java b/server/src/internalClusterTest/java/org/opensearch/script/ScriptCacheIT.java new file mode 100644 index 0000000000000..71a932cdd9149 --- /dev/null +++ b/server/src/internalClusterTest/java/org/opensearch/script/ScriptCacheIT.java @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.script; + +import org.opensearch.OpenSearchException; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.index.MockEngineFactoryPlugin; +import org.opensearch.index.mapper.MockFieldFilterPlugin; +import org.opensearch.node.NodeMocksPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.MockSearchService; +import org.opensearch.test.MockHttpTransport; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.test.TestGeoShapeFieldMapperPlugin; +import org.opensearch.test.store.MockFSIndexStore; +import org.opensearch.test.transport.MockTransportService; + + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; + +import static org.apache.logging.log4j.core.util.Throwables.getRootCause; + + +public class ScriptCacheIT extends OpenSearchIntegTestCase { + + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder builder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(ScriptService.SCRIPT_GENERAL_MAX_COMPILATIONS_RATE_SETTING.getKey(), ScriptService.USE_CONTEXT_RATE_KEY); + // Putting max_compilation_rate for each context to 0 per minute + for (String s:ScriptModule.CORE_CONTEXTS.keySet()) { + builder.put("script.context."+s+".max_compilations_rate", "0/1m"); + } + return builder.build(); + } + + // Overriding to remove MockScriptService.TestPlugin from the list of plugins + @Override + protected Collection> getMockPlugins(){ + final ArrayList> mocks = new ArrayList<>(); + if (randomBoolean()) { + if (randomBoolean() && addMockTransportService()) { + mocks.add(MockTransportService.TestPlugin.class); + } + if (randomBoolean()) { + mocks.add(MockFSIndexStore.TestPlugin.class); + } + if (randomBoolean()) { + mocks.add(NodeMocksPlugin.class); + } + if (addMockInternalEngine() && randomBoolean()) { + mocks.add(MockEngineFactoryPlugin.class); + } + if (randomBoolean()) { + mocks.add(MockSearchService.TestPlugin.class); + } + if (randomBoolean()) { + mocks.add(MockFieldFilterPlugin.class); + } + } + if (addMockTransportService()) { + mocks.add(getTestTransportPlugin()); + } + if (addMockHttpTransport()) { + mocks.add(MockHttpTransport.TestPlugin.class); + } + mocks.add(TestSeedPlugin.class); + mocks.add(AssertActionNamePlugin.class); + if (addMockGeoShapeFieldMapper()) { + mocks.add(TestGeoShapeFieldMapperPlugin.class); + } + return Collections.unmodifiableList(mocks); + } + + @Override + protected Collection> nodePlugins() { + return Collections.singleton(CustomScriptPlugin.class); + } + + public void testPainlessCompilationLimit429Error() throws Exception { + client().prepareIndex("test", "1").setId("1") + .setSource(XContentFactory.jsonBuilder().startObject().field("field", 1).endObject()).get(); + ensureGreen(); + Map params = new HashMap<>(); + params.put("field", "field"); + Script script = new Script(ScriptType.INLINE, "mockscript", "increase_field", params); + ExecutionException exception = expectThrows(ExecutionException.class, () -> + client().prepareUpdate("test", "1", "1").setScript(script).execute().get()); + Throwable rootCause = getRootCause(exception); + assertTrue(rootCause instanceof OpenSearchException); + assertEquals(RestStatus.TOO_MANY_REQUESTS, ((OpenSearchException) rootCause).status()); + } + + public static class CustomScriptPlugin extends MockScriptPlugin { + @Override + protected Map, Object>> pluginScripts() { + Map, Object>> scripts = new HashMap<>(); + scripts.put("increase_field", vars -> { + Map params = (Map) vars.get("params"); + String fieldname = (String) vars.get("field"); + Map ctx = (Map) vars.get("ctx"); + assertNotNull(ctx); + Map source = (Map) ctx.get("_source"); + Number currentValue = (Number) source.get(fieldname); + Number inc = (Number) params.getOrDefault("inc", 1); + source.put(fieldname, currentValue.longValue() + inc.longValue()); + return ctx; + }); + return scripts; + } + } +} diff --git a/server/src/main/java/org/opensearch/script/GeneralScriptException.java b/server/src/main/java/org/opensearch/script/GeneralScriptException.java index 8e2575379bddf..239bd0594effc 100644 --- a/server/src/main/java/org/opensearch/script/GeneralScriptException.java +++ b/server/src/main/java/org/opensearch/script/GeneralScriptException.java @@ -33,6 +33,7 @@ package org.opensearch.script; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchWrapperException; import org.opensearch.common.io.stream.StreamInput; import java.io.IOException; @@ -48,7 +49,7 @@ * from various abstractions) */ @Deprecated -public class GeneralScriptException extends OpenSearchException { +public class GeneralScriptException extends OpenSearchException implements OpenSearchWrapperException { public GeneralScriptException(String msg) { super(msg); diff --git a/server/src/test/java/org/opensearch/script/ScriptCacheTests.java b/server/src/test/java/org/opensearch/script/ScriptCacheTests.java index 440d19d9eeceb..cb25441320efc 100644 --- a/server/src/test/java/org/opensearch/script/ScriptCacheTests.java +++ b/server/src/test/java/org/opensearch/script/ScriptCacheTests.java @@ -34,9 +34,29 @@ import org.opensearch.common.breaker.CircuitBreakingException; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.rest.RestStatus; import org.opensearch.test.OpenSearchTestCase; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + public class ScriptCacheTests extends OpenSearchTestCase { + public void testCompileStatusOnLimitExceeded() { + final TimeValue expire = ScriptService.SCRIPT_GENERAL_CACHE_EXPIRE_SETTING.get(Settings.EMPTY); + final Integer size = ScriptService.SCRIPT_GENERAL_CACHE_SIZE_SETTING.get(Settings.EMPTY); + String settingName = ScriptService.SCRIPT_GENERAL_MAX_COMPILATIONS_RATE_SETTING.getKey(); + ScriptCache cache = new ScriptCache(size, expire, new ScriptCache.CompilationRate(0, TimeValue.timeValueMinutes(1)), settingName); + ScriptContext context = randomFrom(ScriptModule.CORE_CONTEXTS.values()); + Map, Object>> scripts = new HashMap<>(); + scripts.put("1+1", p -> null); // only care about compilation, not execution + ScriptEngine engine = new MockScriptEngine(Script.DEFAULT_SCRIPT_LANG, scripts, Collections.emptyMap()); + GeneralScriptException ex = expectThrows(GeneralScriptException.class, () -> + cache.compile(context, engine, "1+1", "1+1", ScriptType.INLINE, Collections.emptyMap())); + assertEquals(RestStatus.TOO_MANY_REQUESTS, ex.status()); + } + // even though circuit breaking is allowed to be configured per minute, we actually weigh this over five minutes // simply by multiplying by five, so even setting it to one, requires five compilations to break public void testCompilationCircuitBreaking() throws Exception {