diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 8e586f68ff..ca56dc054b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -66,7 +66,7 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.OpenSearchFunctions; +import org.opensearch.sql.expression.function.OpenSearchFunction; import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.expression.span.SpanExpression; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; @@ -273,8 +273,8 @@ public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context // create a new function expression with boost argument and resolve it Function updatedRelevanceQueryUnresolvedExpr = new Function(relevanceQueryUnresolvedExpr.getFuncName(), updatedFuncArgs); - OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr = - (OpenSearchFunctions.OpenSearchFunction) + OpenSearchFunction relevanceQueryExpr = + (OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr.accept(this, context); relevanceQueryExpr.setScoreTracked(true); return relevanceQueryExpr; diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index 398f848f16..93f2bd233d 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -18,7 +18,7 @@ import org.opensearch.sql.expression.conditional.cases.CaseClause; import org.opensearch.sql.expression.conditional.cases.WhenClause; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; -import org.opensearch.sql.expression.function.OpenSearchFunctions; +import org.opensearch.sql.expression.function.OpenSearchFunction; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; @@ -75,9 +75,9 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context (Expression) repository.compile(context.getFunctionProperties(), node.getFunctionName(), args); // Propagate scoreTracked for OpenSearch functions - if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) { - ((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression) - .setScoreTracked(((OpenSearchFunctions.OpenSearchFunction) node).isScoreTracked()); + if (optimizedFunctionExpression instanceof OpenSearchFunction) { + ((OpenSearchFunction) optimizedFunctionExpression) + .setScoreTracked(((OpenSearchFunction) node).isScoreTracked()); } return optimizedFunctionExpression; } diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/EmptyDataSourceService.java b/core/src/main/java/org/opensearch/sql/datasource/model/EmptyDataSourceService.java new file mode 100644 index 0000000000..1212f02ac0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/datasource/model/EmptyDataSourceService.java @@ -0,0 +1,88 @@ +package org.opensearch.sql.datasource.model; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.storage.Table; + +import java.util.Map; +import java.util.Set; + +import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +public class EmptyDataSourceService { + private static DataSourceService emptyDataSourceService = new DataSourceService() { + @Override + public DataSource getDataSource(String dataSourceName) { + return new DataSource(DEFAULT_DATASOURCE_NAME, DataSourceType.OPENSEARCH, storageEngine()); + } + + @Override + public Set getDataSourceMetadata(boolean isDefaultDataSourceRequired) { + return Set.of(); + } + + @Override + public DataSourceMetadata getDataSourceMetadata(String name) { + return null; + } + + @Override + public void createDataSource(DataSourceMetadata metadata) { + + } + + @Override + public void updateDataSource(DataSourceMetadata dataSourceMetadata) { + + } + + @Override + public void deleteDataSource(String dataSourceName) { + + } + + @Override + public Boolean dataSourceExists(String dataSourceName) { + return null; + } + }; + + private static StorageEngine storageEngine() { + Table table = + new Table() { + @Override + public boolean exists() { + return true; + } + + @Override + public void create(Map schema) { + throw new UnsupportedOperationException("Create table is not supported"); + } + + @Override + public Map getFieldTypes() { + return null; + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + throw new UnsupportedOperationException(); + } + + public Map getReservedFieldTypes() { + return ImmutableMap.of("_test", STRING); + } + }; + return (dataSourceSchemaName, tableName) -> table; + } + + public static DataSourceService getEmptyDataSourceService() { + return emptyDataSourceService; + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 4341668b69..0f2e350567 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -25,6 +25,7 @@ import org.opensearch.sql.expression.parse.RegexExpression; import org.opensearch.sql.expression.span.SpanExpression; import org.opensearch.sql.expression.window.ranking.RankingWindowFunction; +import static org.opensearch.sql.datasource.model.EmptyDataSourceService.getEmptyDataSourceService; public class DSL { @@ -119,10 +120,6 @@ public static NamedArgumentExpression namedArgument(String argName, Expression v return new NamedArgumentExpression(argName, value); } - public static NamedArgumentExpression namedArgument(String name, String value) { - return namedArgument(name, literal(value)); - } - public static GrokExpression grok( Expression sourceField, Expression pattern, Expression identifier) { return new GrokExpression(sourceField, pattern, identifier); @@ -827,54 +824,6 @@ public static FunctionExpression typeof(Expression value) { return compile(FunctionProperties.None, BuiltinFunctionName.TYPEOF, value); } - public static FunctionExpression match(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH, args); - } - - public static FunctionExpression match_phrase(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE, args); - } - - public static FunctionExpression match_phrase_prefix(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE_PREFIX, args); - } - - public static FunctionExpression multi_match(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MULTI_MATCH, args); - } - - public static FunctionExpression simple_query_string(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SIMPLE_QUERY_STRING, args); - } - - public static FunctionExpression query(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.QUERY, args); - } - - public static FunctionExpression query_string(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.QUERY_STRING, args); - } - - public static FunctionExpression match_bool_prefix(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_BOOL_PREFIX, args); - } - - public static FunctionExpression wildcard_query(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args); - } - - public static FunctionExpression score(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args); - } - - public static FunctionExpression scorequery(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args); - } - - public static FunctionExpression score_query(Expression... args) { - return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args); - } - public static FunctionExpression now(FunctionProperties functionProperties, Expression... args) { return compile(functionProperties, BuiltinFunctionName.NOW, args); } @@ -957,7 +906,7 @@ public static FunctionExpression utc_timestamp( private static T compile( FunctionProperties functionProperties, BuiltinFunctionName bfn, Expression... args) { return (T) - BuiltinFunctionRepository.getInstance() + BuiltinFunctionRepository.getInstance(getEmptyDataSourceService()) .compile(functionProperties, bfn.getName(), Arrays.asList(args)); } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 2e16d5f01f..31b1fa45d4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -16,12 +16,15 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.aggregation.AggregatorFunction; @@ -46,7 +49,7 @@ public class BuiltinFunctionRepository { private final Map functionResolverMap; /** The singleton instance. */ - private static BuiltinFunctionRepository instance; + private final static Map instance = new HashMap<>(); /** * Construct a function repository with the given function registered. This is only used in test. @@ -64,25 +67,42 @@ public class BuiltinFunctionRepository { * * @return singleton instance */ - public static synchronized BuiltinFunctionRepository getInstance() { - if (instance == null) { - instance = new BuiltinFunctionRepository(new HashMap<>()); + public static synchronized BuiltinFunctionRepository getInstance(DataSourceService dataSourceService) { + Set dataSourceMetadataSet = + dataSourceService.getDataSourceMetadata(true); + Set dataSourceServiceHashSet = + dataSourceMetadataSet.stream().map(metadata -> metadata.hashCode()).collect(Collectors.toSet()); + + // Creates new Repository for every dataSourceService + if (!dataSourceServiceHashSet.stream().anyMatch(hash -> instance.containsKey(hash))) { + BuiltinFunctionRepository repository = new BuiltinFunctionRepository(new HashMap<>()); // Register all built-in functions - ArithmeticFunction.register(instance); - BinaryPredicateOperator.register(instance); - MathematicalFunction.register(instance); - UnaryPredicateOperator.register(instance); - AggregatorFunction.register(instance); - DateTimeFunction.register(instance); - IntervalClause.register(instance); - WindowFunctions.register(instance); - TextFunction.register(instance); - TypeCastOperator.register(instance); - SystemFunctions.register(instance); - OpenSearchFunctions.register(instance); + ArithmeticFunction.register(repository); + BinaryPredicateOperator.register(repository); + MathematicalFunction.register(repository); + UnaryPredicateOperator.register(repository); + AggregatorFunction.register(repository); + DateTimeFunction.register(repository); + IntervalClause.register(repository); + WindowFunctions.register(repository); + TextFunction.register(repository); + TypeCastOperator.register(repository); + SystemFunctions.register(repository); + // Temporary as part of https://github.com/opensearch-project/sql/issues/811 + // TODO: remove this resolver when Analyzers are moved to opensearch module + repository.register(new NestedFunctionResolver()); + + for (DataSourceMetadata metadata : dataSourceMetadataSet) { + dataSourceService + .getDataSource(metadata.getName()) + .getStorageEngine().getFunctions(). + forEach(function -> repository.register(function)); + instance.put(metadata.hashCode(), repository); + } + return repository; } - return instance; + return instance.get(dataSourceServiceHashSet.iterator().next()); } /** diff --git a/core/src/main/java/org/opensearch/sql/expression/function/NestedFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/NestedFunctionResolver.java new file mode 100644 index 0000000000..152acf2e6c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/NestedFunctionResolver.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.env.Environment; + +public class NestedFunctionResolver implements FunctionResolver{ + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + return Pair.of(unresolvedSignature, + (functionProperties, arguments) -> + new FunctionExpression(BuiltinFunctionName.NESTED.getName(), arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + return valueEnv.resolve(getArguments().get(0)); + } + @Override + public ExprType type() { + return getArguments().get(0).type(); + } + }); + } + + @Override + public FunctionName getFunctionName() { + return BuiltinFunctionName.NESTED.getName(); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunction.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunction.java new file mode 100644 index 0000000000..082734256a --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunction.java @@ -0,0 +1,66 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; + +import java.util.List; +import java.util.stream.Collectors; +import lombok.Getter; +import lombok.Setter; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; + +public class OpenSearchFunction extends FunctionExpression { + private final FunctionName functionName; + private final List arguments; + + @Getter @Setter private boolean isScoreTracked; + + /** + * Required argument constructor. + * + * @param functionName name of the function + * @param arguments a list of expressions + */ + public OpenSearchFunction(FunctionName functionName, List arguments) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.isScoreTracked = false; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException( + String.format( + "OpenSearch defined function [%s] is only supported in WHERE and HAVING clause.", + functionName)); + } + + @Override + public ExprType type() { + return BOOLEAN; + } + + @Override + public String toString() { + List args = + arguments.stream() + .map( + arg -> + String.format( + "%s=%s", + ((NamedArgumentExpression) arg).getArgName(), + ((NamedArgumentExpression) arg).getValue().toString())) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java deleted file mode 100644 index 8d8928c16a..0000000000 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.function; - -import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; - -import java.util.List; -import java.util.stream.Collectors; -import lombok.Getter; -import lombok.Setter; -import lombok.experimental.UtilityClass; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -import org.opensearch.sql.expression.env.Environment; - -@UtilityClass -public class OpenSearchFunctions { - /** Add functions specific to OpenSearch to repository. */ - public void register(BuiltinFunctionRepository repository) { - repository.register(match_bool_prefix()); - repository.register(multi_match(BuiltinFunctionName.MULTI_MATCH)); - repository.register(multi_match(BuiltinFunctionName.MULTIMATCH)); - repository.register(multi_match(BuiltinFunctionName.MULTIMATCHQUERY)); - repository.register(match(BuiltinFunctionName.MATCH)); - repository.register(match(BuiltinFunctionName.MATCHQUERY)); - repository.register(match(BuiltinFunctionName.MATCH_QUERY)); - repository.register(simple_query_string()); - repository.register(query()); - repository.register(query_string()); - - // Register MATCHPHRASE as MATCH_PHRASE as well for backwards - // compatibility. - repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); - repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); - repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASEQUERY)); - repository.register(match_phrase_prefix()); - repository.register(wildcard_query(BuiltinFunctionName.WILDCARD_QUERY)); - repository.register(wildcard_query(BuiltinFunctionName.WILDCARDQUERY)); - repository.register(score(BuiltinFunctionName.SCORE)); - repository.register(score(BuiltinFunctionName.SCOREQUERY)); - repository.register(score(BuiltinFunctionName.SCORE_QUERY)); - // Functions supported in SELECT clause - repository.register(nested()); - } - - private static FunctionResolver match_bool_prefix() { - FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); - return new RelevanceFunctionResolver(name); - } - - private static FunctionResolver match(BuiltinFunctionName match) { - FunctionName funcName = match.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver match_phrase_prefix() { - FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { - FunctionName funcName = matchPhrase.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) { - return new RelevanceFunctionResolver(multiMatchName.getName()); - } - - private static FunctionResolver simple_query_string() { - FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver query() { - FunctionName funcName = BuiltinFunctionName.QUERY.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver query_string() { - FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) { - FunctionName funcName = wildcardQuery.getName(); - return new RelevanceFunctionResolver(funcName); - } - - private static FunctionResolver nested() { - return new FunctionResolver() { - @Override - public Pair resolve( - FunctionSignature unresolvedSignature) { - return Pair.of( - unresolvedSignature, - (functionProperties, arguments) -> - new FunctionExpression(BuiltinFunctionName.NESTED.getName(), arguments) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return valueEnv.resolve(getArguments().get(0)); - } - - @Override - public ExprType type() { - return getArguments().get(0).type(); - } - }); - } - - @Override - public FunctionName getFunctionName() { - return BuiltinFunctionName.NESTED.getName(); - } - }; - } - - private static FunctionResolver score(BuiltinFunctionName score) { - FunctionName funcName = score.getName(); - return new RelevanceFunctionResolver(funcName); - } - - public static class OpenSearchFunction extends FunctionExpression { - private final FunctionName functionName; - private final List arguments; - - @Getter @Setter private boolean isScoreTracked; - - /** - * Required argument constructor. - * - * @param functionName name of the function - * @param arguments a list of expressions - */ - public OpenSearchFunction(FunctionName functionName, List arguments) { - super(functionName, arguments); - this.functionName = functionName; - this.arguments = arguments; - this.isScoreTracked = false; - } - - @Override - public ExprValue valueOf(Environment valueEnv) { - throw new UnsupportedOperationException( - String.format( - "OpenSearch defined function [%s] is only supported in WHERE and HAVING clause.", - functionName)); - } - - @Override - public ExprType type() { - return BOOLEAN; - } - - @Override - public String toString() { - List args = - arguments.stream() - .map( - arg -> - String.format( - "%s=%s", - ((NamedArgumentExpression) arg).getArgName(), - ((NamedArgumentExpression) arg).getValue().toString())) - .collect(Collectors.toList()); - return String.format("%s(%s)", functionName, String.join(", ", args)); - } - } -} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 2f4d6e8ada..4c80556a02 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -41,6 +41,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; +import static org.opensearch.sql.datasource.model.EmptyDataSourceService.getEmptyDataSourceService; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.utils.MLCommonsConstants.ACTION; import static org.opensearch.sql.utils.MLCommonsConstants.ALGO; @@ -66,6 +67,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; + import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; @@ -87,13 +89,24 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.function.OpenSearchFunctions; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.FunctionImplementation; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.OpenSearchFunction; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalCloseCursor; @@ -274,104 +287,6 @@ public void filter_relation_with_multiple_tables() { AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); } - @Test - public void analyze_filter_visit_score_function() { - UnresolvedPlan unresolvedPlan = - AstDSL.filter( - AstDSL.relation("schema"), - new ScoreFunction( - AstDSL.function( - "match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("boost", stringLiteral("3"))), - AstDSL.doubleLiteral(1.0))); - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("boost", "3.0"))), - unresolvedPlan); - - LogicalPlan logicalPlan = analyze(unresolvedPlan); - OpenSearchFunctions.OpenSearchFunction relevanceQuery = - (OpenSearchFunctions.OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); - assertEquals(true, relevanceQuery.isScoreTracked()); - } - - @Test - public void analyze_filter_visit_without_score_function() { - UnresolvedPlan unresolvedPlan = - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function( - "match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("boost", stringLiteral("3")))); - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("boost", "3"))), - unresolvedPlan); - - LogicalPlan logicalPlan = analyze(unresolvedPlan); - OpenSearchFunctions.OpenSearchFunction relevanceQuery = - (OpenSearchFunctions.OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); - assertEquals(false, relevanceQuery.isScoreTracked()); - } - - @Test - public void analyze_filter_visit_score_function_with_double_boost() { - UnresolvedPlan unresolvedPlan = - AstDSL.filter( - AstDSL.relation("schema"), - new ScoreFunction( - AstDSL.function( - "match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("slop", stringLiteral("3"))), - new Literal(3.0, DataType.DOUBLE))); - - assertAnalyzeEqual( - LogicalPlanDSL.filter( - LogicalPlanDSL.relation("schema", table), - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3"), - DSL.namedArgument("boost", "3.0"))), - unresolvedPlan); - - LogicalPlan logicalPlan = analyze(unresolvedPlan); - OpenSearchFunctions.OpenSearchFunction relevanceQuery = - (OpenSearchFunctions.OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); - assertEquals(true, relevanceQuery.isScoreTracked()); - } - - @Test - public void analyze_filter_visit_score_function_with_unsupported_boost_SemanticCheckException() { - UnresolvedPlan unresolvedPlan = - AstDSL.filter( - AstDSL.relation("schema"), - new ScoreFunction( - AstDSL.function( - "match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("boost", stringLiteral("3"))), - AstDSL.stringLiteral("3.0"))); - SemanticCheckException exception = - assertThrows(SemanticCheckException.class, () -> analyze(unresolvedPlan)); - assertEquals("Expected boost type 'DOUBLE' but got 'STRING'", exception.getMessage()); - } - @Test public void head_relation() { assertAnalyzeEqual( @@ -527,22 +442,40 @@ public void project_nested_field_arg() { new NamedExpression( "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null)); + UnresolvedPlan unresolvedPlan = AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias( + "nested(message.info)", + function("nested", qualifiedName("message", "info")), + null)); + assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)))), - AstDSL.projectWithArg( - AstDSL.relation("schema"), - AstDSL.defaultFieldsArgs(), - AstDSL.alias( - "nested(message.info)", - function("nested", qualifiedName("message", "info")), - null))); + unresolvedPlan); assertTrue(isNestedFunction(DSL.nested(DSL.ref("message.info", STRING)))); assertFalse(isNestedFunction(DSL.literal("fieldA"))); - assertFalse(isNestedFunction(DSL.match(DSL.namedArgument("field", literal("message"))))); + } + + @Test + void nested_query() { + FunctionImplementation result = + BuiltinFunctionRepository.getInstance(dataSourceService) + .compile(new FunctionProperties(), FunctionName.of("nested"), List.of(DSL.ref("message.info", STRING))); + FunctionExpression expr = (FunctionExpression) result; + assertEquals( + String.format( + "FunctionExpression(functionName=%s, arguments=[message.info])", + BuiltinFunctionName.NESTED.getName()), + expr.toString()); + Environment nestedTuple = + ExprValueUtils.tupleValue(Map.of("message", Map.of("info", "result"))).bindingTuples(); + assertEquals(expr.valueOf(nestedTuple), ExprValueUtils.stringValue("result")); + assertEquals(expr.type(), STRING); } @Test @@ -1083,7 +1016,7 @@ public void select_all_from_subquery() { } /** - * Ensure Nested function falls back to legacy engine when used in GROUP BY clause. TODO Remove + * Ensure Nested function falls back to legacy engine when used in GROUP BY clause. TODO Remove * this test when support is added. */ @Test @@ -1529,21 +1462,39 @@ public void table_function() { } @Test - public void table_function_with_no_datasource() { + public void table_function_with_datasource_with_no_functions() { + DataSourceService dataSourceService = getEmptyDataSourceService(); + Analyzer analyzer = new Analyzer(super.expressionAnalyzer, dataSourceService, BuiltinFunctionRepository.getInstance(dataSourceService)); ExpressionEvaluationException exception = assertThrows( ExpressionEvaluationException.class, () -> - analyze( + analyzer.analyze( AstDSL.tableFunction( List.of("query_range"), unresolvedArg("query", stringLiteral("http_latency")), unresolvedArg("", intLiteral(12345)), unresolvedArg("", intLiteral(12345)), - unresolvedArg(null, intLiteral(14))))); + unresolvedArg(null, intLiteral(14))), new AnalysisContext())); assertEquals("unsupported function name: query_range", exception.getMessage()); } + @Test + public void unsupported_table_function() { + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> + analyze( + AstDSL.tableFunction( + List.of("unsupported"), + unresolvedArg("query", stringLiteral("http_latency")), + unresolvedArg("", intLiteral(12345)), + unresolvedArg("", intLiteral(12345)), + unresolvedArg(null, intLiteral(14))))); + assertEquals("unsupported function name: unsupported", exception.getMessage()); + } + @Test public void table_function_with_wrong_datasource() { ExpressionEvaluationException exception = diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index f09bc5d380..dbcdaddf39 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -26,6 +26,7 @@ import org.opensearch.sql.analysis.symbol.SymbolTable; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.config.TestConfig; +import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSource; @@ -33,13 +34,16 @@ import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.NestedFunctionResolver; import org.opensearch.sql.expression.function.TableFunctionImplementation; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -53,7 +57,17 @@ protected Map typeMapping() { } protected StorageEngine storageEngine() { - return (dataSourceSchemaName, tableName) -> table; + return new StorageEngine() { + @Override + public Collection getFunctions() { + return Collections.singletonList(new NestedFunctionResolver()); + } + + @Override + public Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableName) { + return table; + } + }; } protected StorageEngine prometheusStorageEngine() { @@ -159,7 +173,7 @@ protected Environment typeEnv() { protected Analyzer analyzer( ExpressionAnalyzer expressionAnalyzer, DataSourceService dataSourceService) { - BuiltinFunctionRepository functionRepository = BuiltinFunctionRepository.getInstance(); + BuiltinFunctionRepository functionRepository = BuiltinFunctionRepository.getInstance(dataSourceService); return new Analyzer(expressionAnalyzer, dataSourceService, functionRepository); } @@ -172,7 +186,7 @@ protected AnalysisContext analysisContext(TypeEnvironment typeEnvironment) { } protected ExpressionAnalyzer expressionAnalyzer() { - return new ExpressionAnalyzer(BuiltinFunctionRepository.getInstance()); + return new ExpressionAnalyzer(BuiltinFunctionRepository.getInstance(dataSourceService())); } protected void assertAnalyzeEqual(LogicalPlan expected, UnresolvedPlan unresolvedPlan) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 9d30ebeaab..66f8db72c1 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -355,30 +355,6 @@ public void named_non_parse_expression() { assertAnalyzeEqual(DSL.ref("string_field", STRING), qualifiedName("string_field")); } - @Test - void match_bool_prefix_expression() { - assertAnalyzeEqual( - DSL.match_bool_prefix( - DSL.namedArgument("field", DSL.literal("field_value1")), - DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function( - "match_bool_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("sample query")))); - } - - @Test - void match_bool_prefix_wrong_expression() { - assertThrows( - SemanticCheckException.class, - () -> - analyze( - AstDSL.function( - "match_bool_prefix", - AstDSL.unresolvedArg("field", stringLiteral("fieldA")), - AstDSL.unresolvedArg("query", floatLiteral(1.2f))))); - } - @Test void visit_span() { assertAnalyzeEqual( @@ -401,364 +377,6 @@ void visit_in() { () -> analyze(AstDSL.in(field("integer_value"), Collections.emptyList()))); } - @Test - void multi_match_expression() { - assertAnalyzeEqual( - DSL.multi_match( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), - DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function( - "multi_match", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), - AstDSL.unresolvedArg("query", stringLiteral("sample query")))); - } - - @Test - void multi_match_expression_with_params() { - assertAnalyzeEqual( - DSL.multi_match( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), - DSL.namedArgument("query", DSL.literal("sample query")), - DSL.namedArgument("analyzer", DSL.literal("keyword"))), - AstDSL.function( - "multi_match", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), - AstDSL.unresolvedArg("query", stringLiteral("sample query")), - AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); - } - - @Test - void multi_match_expression_two_fields() { - assertAnalyzeEqual( - DSL.multi_match( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F), - "field_value2", ExprValueUtils.floatValue(.3F)))))), - DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function( - "multi_match", - AstDSL.unresolvedArg( - "fields", - new RelevanceFieldList(ImmutableMap.of("field_value1", 1.F, "field_value2", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("sample query")))); - } - - @Test - void simple_query_string_expression() { - assertAnalyzeEqual( - DSL.simple_query_string( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), - DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function( - "simple_query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), - AstDSL.unresolvedArg("query", stringLiteral("sample query")))); - } - - @Test - void simple_query_string_expression_with_params() { - assertAnalyzeEqual( - DSL.simple_query_string( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), - DSL.namedArgument("query", DSL.literal("sample query")), - DSL.namedArgument("analyzer", DSL.literal("keyword"))), - AstDSL.function( - "simple_query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), - AstDSL.unresolvedArg("query", stringLiteral("sample query")), - AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); - } - - @Test - void simple_query_string_expression_two_fields() { - assertAnalyzeEqual( - DSL.simple_query_string( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F), - "field_value2", ExprValueUtils.floatValue(.3F)))))), - DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function( - "simple_query_string", - AstDSL.unresolvedArg( - "fields", - new RelevanceFieldList(ImmutableMap.of("field_value1", 1.F, "field_value2", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("sample query")))); - } - - @Test - void query_expression() { - assertAnalyzeEqual( - DSL.query(DSL.namedArgument("query", DSL.literal("field:query"))), - AstDSL.function("query", AstDSL.unresolvedArg("query", stringLiteral("field:query")))); - } - - @Test - void query_string_expression() { - assertAnalyzeEqual( - DSL.query_string( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), - DSL.namedArgument("query", DSL.literal("query_value"))), - AstDSL.function( - "query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value")))); - } - - @Test - void query_string_expression_with_params() { - assertAnalyzeEqual( - DSL.query_string( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), - DSL.namedArgument("query", DSL.literal("query_value")), - DSL.namedArgument("escape", DSL.literal("false"))), - AstDSL.function( - "query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value")), - AstDSL.unresolvedArg("escape", stringLiteral("false")))); - } - - @Test - void query_string_expression_two_fields() { - assertAnalyzeEqual( - DSL.query_string( - DSL.namedArgument( - "fields", - DSL.literal( - new ExprTupleValue( - new LinkedHashMap<>( - ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F), - "field_value2", ExprValueUtils.floatValue(.3F)))))), - DSL.namedArgument("query", DSL.literal("query_value"))), - AstDSL.function( - "query_string", - AstDSL.unresolvedArg( - "fields", - new RelevanceFieldList(ImmutableMap.of("field_value1", 1.F, "field_value2", .3F))), - AstDSL.unresolvedArg("query", stringLiteral("query_value")))); - } - - @Test - void wildcard_query_expression() { - assertAnalyzeEqual( - DSL.wildcard_query( - DSL.namedArgument("field", DSL.literal("test")), - DSL.namedArgument("query", DSL.literal("query_value*"))), - AstDSL.function( - "wildcard_query", - unresolvedArg("field", stringLiteral("test")), - unresolvedArg("query", stringLiteral("query_value*")))); - } - - @Test - void wildcard_query_expression_all_params() { - assertAnalyzeEqual( - DSL.wildcard_query( - DSL.namedArgument("field", DSL.literal("test")), - DSL.namedArgument("query", DSL.literal("query_value*")), - DSL.namedArgument("boost", DSL.literal("1.5")), - DSL.namedArgument("case_insensitive", DSL.literal("true")), - DSL.namedArgument("rewrite", DSL.literal("scoring_boolean"))), - AstDSL.function( - "wildcard_query", - unresolvedArg("field", stringLiteral("test")), - unresolvedArg("query", stringLiteral("query_value*")), - unresolvedArg("boost", stringLiteral("1.5")), - unresolvedArg("case_insensitive", stringLiteral("true")), - unresolvedArg("rewrite", stringLiteral("scoring_boolean")))); - } - - @Test - public void match_phrase_prefix_all_params() { - assertAnalyzeEqual( - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3"), - DSL.namedArgument("boost", "1.5"), - DSL.namedArgument("analyzer", "standard"), - DSL.namedArgument("max_expansions", "4"), - DSL.namedArgument("zero_terms_query", "NONE")), - AstDSL.function( - "match_phrase_prefix", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")), - unresolvedArg("boost", stringLiteral("1.5")), - unresolvedArg("analyzer", stringLiteral("standard")), - unresolvedArg("max_expansions", stringLiteral("4")), - unresolvedArg("zero_terms_query", stringLiteral("NONE")))); - } - - @Test - void score_function_expression() { - assertAnalyzeEqual( - DSL.score( - DSL.namedArgument( - "RelevanceQuery", - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3")))), - AstDSL.function( - "score", - unresolvedArg( - "RelevanceQuery", - AstDSL.function( - "match_phrase_prefix", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")))))); - } - - @Test - void score_function_with_boost() { - assertAnalyzeEqual( - DSL.score( - DSL.namedArgument( - "RelevanceQuery", - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("boost", "3.0"))), - DSL.namedArgument("boost", "2")), - AstDSL.function( - "score", - unresolvedArg( - "RelevanceQuery", - AstDSL.function( - "match_phrase_prefix", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("boost", stringLiteral("3.0")))), - unresolvedArg("boost", stringLiteral("2")))); - } - - @Test - void score_query_function_expression() { - assertAnalyzeEqual( - DSL.score_query( - DSL.namedArgument( - "RelevanceQuery", - DSL.wildcard_query( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query")))), - AstDSL.function( - "score_query", - unresolvedArg( - "RelevanceQuery", - AstDSL.function( - "wildcard_query", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")))))); - } - - @Test - void score_query_function_with_boost() { - assertAnalyzeEqual( - DSL.score_query( - DSL.namedArgument( - "RelevanceQuery", - DSL.wildcard_query( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"))), - DSL.namedArgument("boost", "2.0")), - AstDSL.function( - "score_query", - unresolvedArg( - "RelevanceQuery", - AstDSL.function( - "wildcard_query", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")))), - unresolvedArg("boost", stringLiteral("2.0")))); - } - - @Test - void scorequery_function_expression() { - assertAnalyzeEqual( - DSL.scorequery( - DSL.namedArgument( - "RelevanceQuery", - DSL.simple_query_string( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3")))), - AstDSL.function( - "scorequery", - unresolvedArg( - "RelevanceQuery", - AstDSL.function( - "simple_query_string", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")))))); - } - - @Test - void scorequery_function_with_boost() { - assertAnalyzeEqual( - DSL.scorequery( - DSL.namedArgument( - "RelevanceQuery", - DSL.simple_query_string( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3"))), - DSL.namedArgument("boost", "2.0")), - AstDSL.function( - "scorequery", - unresolvedArg( - "RelevanceQuery", - AstDSL.function( - "simple_query_string", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")))), - unresolvedArg("boost", stringLiteral("2.0")))); - } - @Test public void function_isnt_calculated_on_analyze() { assertTrue(analyze(function("now")) instanceof FunctionExpression); diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java index 28bcb8793f..7e68bafd27 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java @@ -10,6 +10,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.datasource.model.EmptyDataSourceService.getEmptyDataSourceService; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; @@ -132,7 +133,7 @@ Expression optimize(Expression expression) { } Expression optimize(Expression expression, LogicalPlan logicalPlan) { - BuiltinFunctionRepository functionRepository = BuiltinFunctionRepository.getInstance(); + BuiltinFunctionRepository functionRepository = BuiltinFunctionRepository.getInstance(getEmptyDataSourceService()); final ExpressionReferenceOptimizer optimizer = new ExpressionReferenceOptimizer(functionRepository, logicalPlan); return optimizer.optimize(DSL.named(expression), new AnalysisContext()); diff --git a/core/src/test/java/org/opensearch/sql/analysis/OpenSearchAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/OpenSearchAnalyzerTest.java new file mode 100644 index 0000000000..225d7371e9 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/analysis/OpenSearchAnalyzerTest.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.analysis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.and; +import static org.opensearch.sql.ast.dsl.AstDSL.filter; +import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.highlight; +import static org.opensearch.sql.ast.dsl.AstDSL.project; +import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; +import static org.opensearch.sql.ast.dsl.AstDSL.relation; +import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.RelevanceFieldList; +import org.opensearch.sql.ast.expression.ScoreFunction; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.OpenSearchFunction; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalPlan; + +@ExtendWith(MockitoExtension.class) +class OpenSearchAnalyzerTest extends AnalyzerTestBase { + + @Mock + private BuiltinFunctionRepository builtinFunctionRepository; + + @Override + protected ExpressionAnalyzer expressionAnalyzer() { + return new ExpressionAnalyzer(builtinFunctionRepository); + } + + @BeforeEach + private void setup() { + this.expressionAnalyzer = expressionAnalyzer(); + this.analyzer = analyzer(this.expressionAnalyzer, dataSourceService); + } + + @Test + public void analyze_filter_visit_score_function() { + + // setup + OpenSearchFunction scoreFunction = new OpenSearchFunction( + new FunctionName("match_phrase_prefix"), List.of()); + when(builtinFunctionRepository.compile(any(), any(), any())).thenReturn(scoreFunction); + + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3"))), + AstDSL.doubleLiteral(1.0))); + + // test + LogicalPlan logicalPlan = analyze(unresolvedPlan); + OpenSearchFunction relevanceQuery = + (OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); + + // verify + assertEquals(true, relevanceQuery.isScoreTracked()); + } + + @Test + public void analyze_filter_visit_score_function_without_boost() { + + // setup + OpenSearchFunction scoreFunction = new OpenSearchFunction( + new FunctionName("match_phrase_prefix"), List.of()); + when(builtinFunctionRepository.compile(any(), any(), any())).thenReturn(scoreFunction); + + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query"))), + AstDSL.doubleLiteral(1.0))); + + // test + LogicalPlan logicalPlan = analyze(unresolvedPlan); + OpenSearchFunction relevanceQuery = + (OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); + + // verify + assertEquals(true, relevanceQuery.isScoreTracked()); + } + + + @Test + public void analyze_filter_visit_score_function_with_unsupported_boost_SemanticCheckException() { + // setup + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3"))), + AstDSL.stringLiteral("3.0"))); + + // Test + SemanticCheckException exception = + assertThrows(SemanticCheckException.class, () -> analyze(unresolvedPlan)); + + // Verify + assertEquals("Expected boost type 'DOUBLE' but got 'STRING'", exception.getMessage()); + } + + @Test + public void analyze_relevance_field_list() { + LinkedHashMap tuple = new LinkedHashMap(Map.of("Title", 1.0F, "Body", 4.2F, "Tags", 1.5F)); + + UnresolvedExpression unresolvedPlan = new RelevanceFieldList(tuple); + Expression relevanceFieldList = unresolvedPlan.accept(expressionAnalyzer, analysisContext); + assertEquals(STRUCT, relevanceFieldList.type()); + assertTrue(relevanceFieldList.valueOf() instanceof ExprTupleValue); + assertEquals(tuple.get("Tags"), relevanceFieldList.valueOf().tupleValue().get("Tags").floatValue()); + assertEquals(tuple.get("Body"), relevanceFieldList.valueOf().tupleValue().get("Body").floatValue()); + assertEquals(tuple.get("Title"), relevanceFieldList.valueOf().tupleValue().get("Title").floatValue()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java index fd886cdda3..e3954a6a24 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java @@ -78,26 +78,4 @@ protected static Environment valueEnv() { } }; } - - protected Environment typeEnv() { - return typeEnv; - } - - protected Function, FunctionExpression> functionMapping( - BuiltinFunctionName builtinFunctionName) { - switch (builtinFunctionName) { - case ADD: - return (expressions) -> DSL.add(expressions.get(0), expressions.get(1)); - case SUBTRACT: - return (expressions) -> DSL.subtract(expressions.get(0), expressions.get(1)); - case MULTIPLY: - return (expressions) -> DSL.multiply(expressions.get(0), expressions.get(1)); - case DIVIDE: - return (expressions) -> DSL.divide(expressions.get(0), expressions.get(1)); - case MODULUS: - return (expressions) -> DSL.modulus(expressions.get(0), expressions.get(1)); - default: - throw new RuntimeException(); - } - } } diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java b/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java index 023a3574aa..ecc665f653 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java @@ -6,6 +6,7 @@ package org.opensearch.sql.expression.datetime; import static org.opensearch.sql.data.model.ExprValueUtils.fromObjectValue; +import static org.opensearch.sql.datasource.model.EmptyDataSourceService.getEmptyDataSourceService; import java.time.Instant; import java.time.LocalDate; @@ -28,7 +29,7 @@ public class DateTimeTestBase extends ExpressionTestBase { protected final BuiltinFunctionRepository functionRepository = - BuiltinFunctionRepository.getInstance(); + BuiltinFunctionRepository.getInstance(getEmptyDataSourceService()); protected ExprValue eval(Expression expression) { return expression.valueOf(); diff --git a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index ad9e8a6661..a8412ef4f6 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.data.type.WideningTypeRule; import org.opensearch.sql.exception.ExpressionEvaluationException; @@ -125,4 +126,10 @@ void resolve_varargs_too_many_args_function_signature_not_match() { ExpressionEvaluationException.class, () -> resolver.resolve(functionSignature)); assertEquals("concat function expected 1-9 arguments, but got 10", exception.getMessage()); } + + @Test + void resolve_nested_function() { + functionName = FunctionName.of("nested"); + + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionTest.java new file mode 100644 index 0000000000..66711dbb15 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionTest.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.ExpressionTestBase; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Answers.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.withSettings; + +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +public class OpenSearchFunctionTest extends ExpressionTestBase { + private final NamedArgumentExpression field = + new NamedArgumentExpression("field", DSL.literal("message")); + private final NamedArgumentExpression query = + new NamedArgumentExpression("query", DSL.literal("search query")); + private final DataSourceMetadata defaultDataSourceMetadata = DataSourceMetadata.defaultOpenSearchDataSourceMetadata(); + + @Test + void test_opensearch_function() { + OpenSearchFunction function = new OpenSearchFunction( + new FunctionName("match"), + List.of(new NamedArgumentExpression("test", new LiteralExpression(new ExprStringValue("test"))))); + assertEquals(BOOLEAN, function.type()); + assertThrows(UnsupportedOperationException.class,() -> function.valueOf(null)); + assertEquals("match(test=\"test\")", function.toString()); + } + +// @Test +// void test_nested_function() { +//// OpenSearchFunction function = mock(OpenSearchFunction.class, withSettings().defaultAnswer(CALLS_REAL_METHODS)); +// OpenSearchFunction function = new OpenSearchFunction(new FunctionName("match"), List.of(new NamedArgumentExpression("a", new LiteralExpression(new ExprStringValue("a"))))); +// FunctionExpression expr = function; +//// assertEquals("match(field=\"message\", query=\"search query\")", expr.toString()); +// assertEquals(BOOLEAN, function.type()); +// assertThrows(UnsupportedOperationException.class,() -> function.valueOf(null)); +// assertEquals("match(a=\"a\")", function.toString()); +// } +} diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java b/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java index 73eb297fea..632aa78c0e 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorage.java @@ -80,6 +80,9 @@ public OpenSearchDataSourceMetadataStorage( @Override public List getDataSourceMetadata() { + if (!this.clusterService.getClusterApplierService().isInitialClusterStateSet()) { + return Collections.emptyList(); + } if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) { createDataSourcesIndex(); return Collections.emptyList(); diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorageTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorageTest.java index cc663d56e6..3aa454ba59 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorageTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/storage/OpenSearchDataSourceMetadataStorageTest.java @@ -209,6 +209,7 @@ public void testGetDataSourceMetadataWithBasicAuth() { @SneakyThrows @Test public void testGetDataSourceMetadataList() { + Mockito.when(clusterService.getClusterApplierService().isInitialClusterStateSet()).thenReturn(true); Mockito.when(clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) .thenReturn(true); Mockito.when(client.search(ArgumentMatchers.any())).thenReturn(searchResponseActionFuture); @@ -233,6 +234,7 @@ public void testGetDataSourceMetadataList() { @SneakyThrows @Test public void testGetDataSourceMetadataListWithNoIndex() { + Mockito.when(clusterService.getClusterApplierService().isInitialClusterStateSet()).thenReturn(true); Mockito.when(clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) .thenReturn(Boolean.FALSE); Mockito.when(client.admin().indices().create(ArgumentMatchers.any())) @@ -289,6 +291,16 @@ public void testCreateDataSourceMetadata() { Mockito.verify(client.threadPool().getThreadContext(), Mockito.times(2)).stashContext(); } + @Test + public void testCreateDataSourceMetadataUninitialized() { + Mockito.when(clusterService.getClusterApplierService().isInitialClusterStateSet()).thenReturn(false); + + List dataSourceMetadataList = + openSearchDataSourceMetadataStorage.getDataSourceMetadata(); + + Assertions.assertEquals(0, dataSourceMetadataList.size()); + } + @Test public void testCreateDataSourceMetadataWithOutCreatingIndex() { Mockito.when(clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index 8ef8787597..8eb6b1ed43 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -7,6 +7,7 @@ package org.opensearch.sql.ppl; import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.datasource.model.EmptyDataSourceService.getEmptyDataSourceService; import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY; import com.google.common.collect.ImmutableMap; @@ -181,7 +182,7 @@ public class StandaloneModule extends AbstractModule { private final DataSourceService dataSourceService; private final BuiltinFunctionRepository functionRepository = - BuiltinFunctionRepository.getInstance(); + BuiltinFunctionRepository.getInstance(getEmptyDataSourceService()); @Override protected void configure() {} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java index ad8afc47ca..b8c96ab081 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java @@ -36,6 +36,8 @@ import org.opensearch.sql.sql.antlr.SQLSyntaxParser; import org.opensearch.sql.storage.StorageEngine; +import static org.opensearch.sql.datasource.model.EmptyDataSourceService.getEmptyDataSourceService; + /** * A utility class which registers SQL engine singletons as `OpenSearchPluginModule` does. * It is needed to get access to those instances in test and validate their behavior. @@ -50,7 +52,7 @@ public class StandaloneModule extends AbstractModule { private final DataSourceService dataSourceService; private final BuiltinFunctionRepository functionRepository = - BuiltinFunctionRepository.getInstance(); + BuiltinFunctionRepository.getInstance(getEmptyDataSourceService()); @Override protected void configure() { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/OpenSearchFunctions.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/OpenSearchFunctions.java new file mode 100644 index 0000000000..1c4ceceeab --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/OpenSearchFunctions.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.functions; + +import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; + +import java.util.Collection; +import java.util.List; + +@UtilityClass +public class OpenSearchFunctions { + public Collection getResolvers() { + return List.of( + match_bool_prefix(), + multi_match(BuiltinFunctionName.MULTI_MATCH), + multi_match(BuiltinFunctionName.MULTIMATCH), + multi_match(BuiltinFunctionName.MULTIMATCHQUERY), + match(BuiltinFunctionName.MATCH), + match(BuiltinFunctionName.MATCHQUERY), + match(BuiltinFunctionName.MATCH_QUERY), + simple_query_string(), + query(), + query_string(), + + // Register MATCHPHRASE as MATCH_PHRASE as well for backwards + // compatibility. + match_phrase(BuiltinFunctionName.MATCH_PHRASE), + match_phrase(BuiltinFunctionName.MATCHPHRASE), + match_phrase(BuiltinFunctionName.MATCHPHRASEQUERY), + match_phrase_prefix(), + wildcard_query(BuiltinFunctionName.WILDCARD_QUERY), + wildcard_query(BuiltinFunctionName.WILDCARDQUERY), + score(BuiltinFunctionName.SCORE), + score(BuiltinFunctionName.SCOREQUERY), + score(BuiltinFunctionName.SCORE_QUERY), + nested() + ); + } + + private static FunctionResolver match_bool_prefix() { + FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); + return new RelevanceFunctionResolver(name); + } + + private static FunctionResolver match(BuiltinFunctionName match) { + FunctionName funcName = match.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver match_phrase_prefix() { + FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { + FunctionName funcName = matchPhrase.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) { + return new RelevanceFunctionResolver(multiMatchName.getName()); + } + + private static FunctionResolver simple_query_string() { + FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver query() { + FunctionName funcName = BuiltinFunctionName.QUERY.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver query_string() { + FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) { + FunctionName funcName = wildcardQuery.getName(); + return new RelevanceFunctionResolver(funcName); + } + + private static FunctionResolver nested() { + return new FunctionResolver() { + @Override + public Pair resolve( + FunctionSignature unresolvedSignature) { + return Pair.of(unresolvedSignature, + (functionProperties, arguments) -> + new FunctionExpression(BuiltinFunctionName.NESTED.getName(), arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + return valueEnv.resolve(getArguments().get(0)); + } + + @Override + public ExprType type() { + return getArguments().get(0).type(); + } + }); + } + + @Override + public FunctionName getFunctionName() { + return BuiltinFunctionName.NESTED.getName(); + } + }; + } + + private static FunctionResolver score(BuiltinFunctionName score) { + FunctionName funcName = score.getName(); + return new RelevanceFunctionResolver(funcName); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/RelevanceFunctionResolver.java similarity index 60% rename from core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/functions/RelevanceFunctionResolver.java index ae882897d0..867bb38483 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/functions/RelevanceFunctionResolver.java @@ -3,29 +3,33 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.expression.function; +package org.opensearch.sql.opensearch.functions; -import java.util.List; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.OpenSearchFunction; -@RequiredArgsConstructor -public class RelevanceFunctionResolver implements FunctionResolver { +import java.util.List; - @Getter private final FunctionName functionName; +@RequiredArgsConstructor +public class RelevanceFunctionResolver + implements FunctionResolver { + @Getter + private final FunctionName functionName; @Override public Pair resolve(FunctionSignature unresolvedSignature) { if (!unresolvedSignature.getFunctionName().equals(functionName)) { - throw new SemanticCheckException( - String.format( - "Expected '%s' but got '%s'", - functionName.getFunctionName(), - unresolvedSignature.getFunctionName().getFunctionName())); + throw new SemanticCheckException(String.format("Expected '%s' but got '%s'", + functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName())); } List paramTypes = unresolvedSignature.getParamTypeList(); // Check if all but the first parameter are of type STRING. @@ -37,15 +41,13 @@ public Pair resolve(FunctionSignature unreso } } - FunctionBuilder buildFunction = - (functionProperties, args) -> - new OpenSearchFunctions.OpenSearchFunction(functionName, args); + FunctionBuilder buildFunction = (functionProperties, args) + -> new OpenSearchFunction(functionName, args); return Pair.of(unresolvedSignature, buildFunction); } - /** - * Returns a helpful error message when expected parameter type does not match the specified - * parameter type. + /** Returns a helpful error message when expected parameter type does not match the + * specified parameter type. * * @param i 0-based index of the parameter in a function signature. * @param paramType the type of the ith parameter at run-time. @@ -53,8 +55,7 @@ public Pair resolve(FunctionSignature unreso * @return A user-friendly error message that informs of the type difference. */ private String getWrongParameterErrorMessage(int i, ExprType paramType, ExprType expectedType) { - return String.format( - "Expected type %s instead of %s for parameter #%d", + return String.format("Expected type %s instead of %s for parameter #%d", expectedType.typeName(), paramType.typeName(), i + 1); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java index c915fa549b..256a08e402 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java @@ -8,11 +8,14 @@ import static org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex; +import java.util.Collection; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.functions.OpenSearchFunctions; import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex; import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; @@ -35,4 +38,9 @@ public Table getTable(DataSourceSchemaName dataSourceSchemaName, String name) { return new OpenSearchIndex(client, settings, name); } } + + @Override + public Collection getFunctions() { + return OpenSearchFunctions.getResolvers(); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java index 590272a9f1..447c83b6fa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -21,7 +21,7 @@ import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.function.OpenSearchFunctions; +import org.opensearch.sql.expression.function.OpenSearchFunction; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; @@ -100,8 +100,8 @@ public boolean pushDownPageSize(LogicalPaginate paginate) { } private boolean trackScoresFromOpenSearchFunction(Expression condition) { - if (condition instanceof OpenSearchFunctions.OpenSearchFunction - && ((OpenSearchFunctions.OpenSearchFunction) condition).isScoreTracked()) { + if (condition instanceof OpenSearchFunction + && ((OpenSearchFunction) condition).isScoreTracked()) { return true; } if (condition instanceof FunctionExpression) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/config/TestConfig.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/config/TestConfig.java new file mode 100644 index 0000000000..cba43049a8 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/config/TestConfig.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.config; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.analysis.symbol.Namespace; +import org.opensearch.sql.analysis.symbol.Symbol; +import org.opensearch.sql.analysis.symbol.SymbolTable; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.storage.Table; + +/** + * Configuration will be used for UT. + */ +public class TestConfig { + public static final String INT_TYPE_NULL_VALUE_FIELD = "int_null_value"; + public static final String INT_TYPE_MISSING_VALUE_FIELD = "int_missing_value"; + public static final String DOUBLE_TYPE_NULL_VALUE_FIELD = "double_null_value"; + public static final String DOUBLE_TYPE_MISSING_VALUE_FIELD = "double_missing_value"; + public static final String BOOL_TYPE_NULL_VALUE_FIELD = "null_value_boolean"; + public static final String BOOL_TYPE_MISSING_VALUE_FIELD = "missing_value_boolean"; + public static final String STRING_TYPE_NULL_VALUE_FIELD = "string_null_value"; + public static final String STRING_TYPE_MISSING_VALUE_FIELD = "string_missing_value"; + + public static Map typeMapping = new ImmutableMap.Builder() + .put("integer_value", ExprCoreType.INTEGER) + .put(INT_TYPE_NULL_VALUE_FIELD, ExprCoreType.INTEGER) + .put(INT_TYPE_MISSING_VALUE_FIELD, ExprCoreType.INTEGER) + .put("long_value", ExprCoreType.LONG) + .put("float_value", ExprCoreType.FLOAT) + .put("double_value", ExprCoreType.DOUBLE) + .put(DOUBLE_TYPE_NULL_VALUE_FIELD, ExprCoreType.DOUBLE) + .put(DOUBLE_TYPE_MISSING_VALUE_FIELD, ExprCoreType.DOUBLE) + .put("boolean_value", ExprCoreType.BOOLEAN) + .put(BOOL_TYPE_NULL_VALUE_FIELD, ExprCoreType.BOOLEAN) + .put(BOOL_TYPE_MISSING_VALUE_FIELD, ExprCoreType.BOOLEAN) + .put("string_value", ExprCoreType.STRING) + .put(STRING_TYPE_NULL_VALUE_FIELD, ExprCoreType.STRING) + .put(STRING_TYPE_MISSING_VALUE_FIELD, ExprCoreType.STRING) + .put("struct_value", ExprCoreType.STRUCT) + .put("array_value", ExprCoreType.ARRAY) + .put("timestamp_value", ExprCoreType.TIMESTAMP) + .put("field_value1", ExprCoreType.STRING) + .put("field_value2", ExprCoreType.STRING) + .put("message", ExprCoreType.STRING) + .put("message.info", ExprCoreType.STRING) + .put("message.info.id", ExprCoreType.STRING) + .put("comment", ExprCoreType.STRING) + .put("comment.data", ExprCoreType.STRING) + .build(); +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/ExpressionTestBase.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/ExpressionTestBase.java new file mode 100644 index 0000000000..545dcdb13e --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/ExpressionTestBase.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.expression; + +import static org.opensearch.sql.opensearch.config.TestConfig.BOOL_TYPE_MISSING_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.BOOL_TYPE_NULL_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.DOUBLE_TYPE_MISSING_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.DOUBLE_TYPE_NULL_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.INT_TYPE_MISSING_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.INT_TYPE_NULL_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.STRING_TYPE_MISSING_VALUE_FIELD; +import static org.opensearch.sql.opensearch.config.TestConfig.STRING_TYPE_NULL_VALUE_FIELD; +import static org.opensearch.sql.data.model.ExprValueUtils.booleanValue; +import static org.opensearch.sql.data.model.ExprValueUtils.collectionValue; +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.data.model.ExprValueUtils.floatValue; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.longValue; +import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; +import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; +import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; +import java.util.function.Function; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; + +public class ExpressionTestBase { + + protected FunctionProperties functionProperties = new FunctionProperties(); + + protected Environment typeEnv; + + protected static Environment valueEnv() { + return var -> { + if (var instanceof ReferenceExpression) { + switch (((ReferenceExpression) var).getAttr()) { + case "integer_value": + return integerValue(1); + case "long_value": + return longValue(1L); + case "float_value": + return floatValue(1f); + case "double_value": + return doubleValue(1d); + case "boolean_value": + return booleanValue(true); + case "string_value": + return stringValue("str"); + case "struct_value": + return tupleValue(ImmutableMap.of("str", 1)); + case "array_value": + return collectionValue(ImmutableList.of(1)); + case BOOL_TYPE_NULL_VALUE_FIELD: + case INT_TYPE_NULL_VALUE_FIELD: + case DOUBLE_TYPE_NULL_VALUE_FIELD: + case STRING_TYPE_NULL_VALUE_FIELD: + return nullValue(); + case INT_TYPE_MISSING_VALUE_FIELD: + case BOOL_TYPE_MISSING_VALUE_FIELD: + case DOUBLE_TYPE_MISSING_VALUE_FIELD: + case STRING_TYPE_MISSING_VALUE_FIELD: + return missingValue(); + default: + throw new IllegalArgumentException("undefined reference"); + } + } else { + throw new IllegalArgumentException("var must be ReferenceExpression"); + } + }; + } + + protected Environment typeEnv() { + return typeEnv; + } + + protected Function, FunctionExpression> functionMapping( + BuiltinFunctionName builtinFunctionName) { + switch (builtinFunctionName) { + case ADD: + return (expressions) -> DSL.add(expressions.get(0), expressions.get(1)); + case SUBTRACT: + return (expressions) -> DSL.subtract(expressions.get(0), expressions.get(1)); + case MULTIPLY: + return (expressions) -> DSL.multiply(expressions.get(0), expressions.get(1)); + case DIVIDE: + return (expressions) -> DSL.divide(expressions.get(0), expressions.get(1)); + case MODULUS: + return (expressions) -> DSL.modulus(expressions.get(0), expressions.get(1)); + default: + throw new RuntimeException(); + } + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/OpenSearchDSL.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/OpenSearchDSL.java new file mode 100644 index 0000000000..a9da94ad03 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/OpenSearchDSL.java @@ -0,0 +1,311 @@ +package org.opensearch.sql.opensearch.expression; + +import com.google.common.collect.ImmutableMap; +import lombok.Getter; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.DataSourceSchemaName; +import org.opensearch.sql.ast.expression.SpanUnit; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasource.model.DataSource; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.aggregation.Aggregator; +import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.expression.conditional.cases.CaseClause; +import org.opensearch.sql.expression.conditional.cases.WhenClause; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.FunctionImplementation; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.parse.GrokExpression; +import org.opensearch.sql.expression.parse.ParseExpression; +import org.opensearch.sql.expression.parse.PatternsExpression; +import org.opensearch.sql.expression.parse.RegexExpression; +import org.opensearch.sql.expression.span.SpanExpression; +import org.opensearch.sql.expression.window.ranking.RankingWindowFunction; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.config.TestConfig; +import org.opensearch.sql.opensearch.functions.OpenSearchFunctions; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.storage.Table; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +public class OpenSearchDSL { + private OpenSearchDSL() {} + public static LiteralExpression literal(Byte value) { + return new LiteralExpression(ExprValueUtils.byteValue(value)); + } + + public static LiteralExpression literal(Short value) { + return new LiteralExpression(new ExprShortValue(value)); + } + + public static LiteralExpression literal(Integer value) { + return new LiteralExpression(ExprValueUtils.integerValue(value)); + } + + public static LiteralExpression literal(Long value) { + return new LiteralExpression(ExprValueUtils.longValue(value)); + } + + public static LiteralExpression literal(Float value) { + return new LiteralExpression(ExprValueUtils.floatValue(value)); + } + + public static LiteralExpression literal(Double value) { + return new LiteralExpression(ExprValueUtils.doubleValue(value)); + } + + public static LiteralExpression literal(String value) { + return new LiteralExpression(ExprValueUtils.stringValue(value)); + } + + public static LiteralExpression literal(Boolean value) { + return new LiteralExpression(ExprValueUtils.booleanValue(value)); + } + + public static LiteralExpression literal(ExprValue value) { + return new LiteralExpression(value); + } + + /** Wrap a number to {@link LiteralExpression}. */ + public static LiteralExpression literal(Number value) { + if (value instanceof Integer) { + return new LiteralExpression(ExprValueUtils.integerValue(value.intValue())); + } else if (value instanceof Long) { + return new LiteralExpression(ExprValueUtils.longValue(value.longValue())); + } else if (value instanceof Float) { + return new LiteralExpression(ExprValueUtils.floatValue(value.floatValue())); + } else { + return new LiteralExpression(ExprValueUtils.doubleValue(value.doubleValue())); + } + } + + public static ReferenceExpression ref(String ref, ExprType type) { + return new ReferenceExpression(ref, type); + } + + /** + * Wrap a named expression if not yet. The intent is that different languages may use Alias or not + * when building AST. This caused either named or unnamed expression is resolved by analyzer. To + * make unnamed expression acceptable for logical project, it is required to wrap it by named + * expression here before passing to logical project. + * + * @param expression expression + * @return expression if named already or expression wrapped by named expression. + */ + public static NamedExpression named(Expression expression) { + if (expression instanceof NamedExpression) { + return (NamedExpression) expression; + } + if (expression instanceof ParseExpression) { + return named( + ((ParseExpression) expression).getIdentifier().valueOf().stringValue(), expression); + } + return named(expression.toString(), expression); + } + + public static NamedExpression named(String name, Expression expression) { + return new NamedExpression(name, expression); + } + + public static NamedExpression named(String name, Expression expression, String alias) { + return new NamedExpression(name, expression, alias); + } + + public static NamedAggregator named(String name, Aggregator aggregator) { + return new NamedAggregator(name, aggregator); + } + + public static NamedArgumentExpression namedArgument(String argName, Expression value) { + return new NamedArgumentExpression(argName, value); + } + + public static NamedArgumentExpression namedArgument(String name, String value) { + return namedArgument(name, literal(value)); + } + + public static FunctionExpression nested(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.NESTED, expressions); + } + + public static FunctionExpression match(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH, args); + } + + public static FunctionExpression match_phrase(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE, args); + } + + public static FunctionExpression match_phrase_prefix(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE_PREFIX, args); + } + + public static FunctionExpression multi_match(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MULTI_MATCH, args); + } + + public static FunctionExpression simple_query_string(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SIMPLE_QUERY_STRING, args); + } + + public static FunctionExpression query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.QUERY, args); + } + + public static FunctionExpression query_string(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.QUERY_STRING, args); + } + + public static FunctionExpression match_bool_prefix(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_BOOL_PREFIX, args); + } + + public static FunctionExpression wildcard_query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args); + } + + public static FunctionExpression score(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args); + } + + public static FunctionExpression scorequery(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args); + } + + public static FunctionExpression score_query(Expression... args) { + return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args); + } + + // Make this add OpenSearchStorageEngine + @ExtendWith(MockitoExtension.class) + protected static StorageEngine storageEngine() { + return new StorageEngine() { + @Getter + @Mock + private OpenSearchClient client; + @Getter + @Mock + private Settings settings; + + @Override + public Table getTable(DataSourceSchemaName dataSourceSchemaName, String name) { + return table(); + } + + @Override + public Collection getFunctions() { + return OpenSearchFunctions.getResolvers(); + } + }; + } + + protected static Table table() { + return new Table() { + @Override + public boolean exists() { + return true; + } + + @Override + public void create(Map schema) { + throw new UnsupportedOperationException("Create table is not supported"); + } + + @Override + public Map getFieldTypes() { + return TestConfig.typeMapping; + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + throw new UnsupportedOperationException(); + } + + public Map getReservedFieldTypes() { + return ImmutableMap.of("_test", STRING); + } + }; + } + + + private static class DefaultDataSourceService implements DataSourceService { + + private final DataSource opensearchDataSource = new DataSource(DEFAULT_DATASOURCE_NAME, + DataSourceType.OPENSEARCH, storageEngine()); + + + @Override + public Set getDataSourceMetadata(boolean isDefaultDataSourceRequired) { + return Stream.of(opensearchDataSource) + .map(ds -> new DataSourceMetadata(ds.getName(), + ds.getConnectorType(), Collections.emptyList(), + ImmutableMap.of())).collect(Collectors.toSet()); + } + + @Override + public DataSourceMetadata getDataSourceMetadata(String name) { + return null; + } + + @Override + public void createDataSource(DataSourceMetadata metadata) { + throw new UnsupportedOperationException("unsupported operation"); + } + + @Override + public DataSource getDataSource(String dataSourceName) { + return opensearchDataSource; + } + + @Override + public void updateDataSource(DataSourceMetadata dataSourceMetadata) { + + } + + @Override + public void deleteDataSource(String dataSourceName) { + } + + @Override + public Boolean dataSourceExists(String dataSourceName) { + return dataSourceName.equals(DEFAULT_DATASOURCE_NAME); + } + } + + @SuppressWarnings("unchecked") + private static T compile( + FunctionProperties functionProperties, BuiltinFunctionName bfn, Expression... args) { + return (T) + BuiltinFunctionRepository.getInstance(new DefaultDataSourceService()) + .compile(functionProperties, bfn.getName(), Arrays.asList(args)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/function/OpenSearchFunctionsTest.java similarity index 70% rename from core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java rename to opensearch/src/test/java/org/opensearch/sql/opensearch/expression/function/OpenSearchFunctionsTest.java index 168b73acc4..38e9513840 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/function/OpenSearchFunctionsTest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.expression.function; +package org.opensearch.sql.opensearch.expression.function; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -17,75 +17,76 @@ import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ExpressionTestBase; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.opensearch.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; public class OpenSearchFunctionsTest extends ExpressionTestBase { private final NamedArgumentExpression field = - new NamedArgumentExpression("field", DSL.literal("message")); + new NamedArgumentExpression("field", OpenSearchDSL.literal("message")); private final NamedArgumentExpression fields = new NamedArgumentExpression( "fields", - DSL.literal( + OpenSearchDSL.literal( new ExprTupleValue( new LinkedHashMap<>( Map.of( "title", ExprValueUtils.floatValue(1.F), "body", ExprValueUtils.floatValue(.3F)))))); private final NamedArgumentExpression query = - new NamedArgumentExpression("query", DSL.literal("search query")); + new NamedArgumentExpression("query", OpenSearchDSL.literal("search query")); private final NamedArgumentExpression analyzer = - new NamedArgumentExpression("analyzer", DSL.literal("keyword")); + new NamedArgumentExpression("analyzer", OpenSearchDSL.literal("keyword")); private final NamedArgumentExpression autoGenerateSynonymsPhrase = - new NamedArgumentExpression("auto_generate_synonyms_phrase", DSL.literal("true")); + new NamedArgumentExpression("auto_generate_synonyms_phrase", OpenSearchDSL.literal("true")); private final NamedArgumentExpression fuzziness = - new NamedArgumentExpression("fuzziness", DSL.literal("AUTO")); + new NamedArgumentExpression("fuzziness", OpenSearchDSL.literal("AUTO")); private final NamedArgumentExpression maxExpansions = - new NamedArgumentExpression("max_expansions", DSL.literal("10")); + new NamedArgumentExpression("max_expansions", OpenSearchDSL.literal("10")); private final NamedArgumentExpression prefixLength = - new NamedArgumentExpression("prefix_length", DSL.literal("1")); + new NamedArgumentExpression("prefix_length", OpenSearchDSL.literal("1")); private final NamedArgumentExpression fuzzyTranspositions = - new NamedArgumentExpression("fuzzy_transpositions", DSL.literal("false")); + new NamedArgumentExpression("fuzzy_transpositions", OpenSearchDSL.literal("false")); private final NamedArgumentExpression fuzzyRewrite = - new NamedArgumentExpression("fuzzy_rewrite", DSL.literal("rewrite method")); + new NamedArgumentExpression("fuzzy_rewrite", OpenSearchDSL.literal("rewrite method")); private final NamedArgumentExpression lenient = - new NamedArgumentExpression("lenient", DSL.literal("true")); + new NamedArgumentExpression("lenient", OpenSearchDSL.literal("true")); private final NamedArgumentExpression operator = - new NamedArgumentExpression("operator", DSL.literal("OR")); + new NamedArgumentExpression("operator", OpenSearchDSL.literal("OR")); private final NamedArgumentExpression minimumShouldMatch = - new NamedArgumentExpression("minimum_should_match", DSL.literal("1")); + new NamedArgumentExpression("minimum_should_match", OpenSearchDSL.literal("1")); private final NamedArgumentExpression zeroTermsQueryAll = - new NamedArgumentExpression("zero_terms_query", DSL.literal("ALL")); + new NamedArgumentExpression("zero_terms_query", OpenSearchDSL.literal("ALL")); private final NamedArgumentExpression zeroTermsQueryNone = - new NamedArgumentExpression("zero_terms_query", DSL.literal("None")); + new NamedArgumentExpression("zero_terms_query", OpenSearchDSL.literal("None")); private final NamedArgumentExpression boost = - new NamedArgumentExpression("boost", DSL.literal("2.0")); + new NamedArgumentExpression("boost", OpenSearchDSL.literal("2.0")); private final NamedArgumentExpression slop = - new NamedArgumentExpression("slop", DSL.literal("3")); + new NamedArgumentExpression("slop", OpenSearchDSL.literal("3")); @Test void match() { - FunctionExpression expr = DSL.match(field, query); + FunctionExpression expr = OpenSearchDSL.match(field, query); assertEquals(BOOLEAN, expr.type()); - expr = DSL.match(field, query, analyzer); + expr = OpenSearchDSL.match(field, query, analyzer); assertEquals(BOOLEAN, expr.type()); - expr = DSL.match(field, query, analyzer, autoGenerateSynonymsPhrase); + expr = OpenSearchDSL.match(field, query, analyzer, autoGenerateSynonymsPhrase); assertEquals(BOOLEAN, expr.type()); - expr = DSL.match(field, query, analyzer, autoGenerateSynonymsPhrase, fuzziness); + expr = OpenSearchDSL.match(field, query, analyzer, autoGenerateSynonymsPhrase, fuzziness); assertEquals(BOOLEAN, expr.type()); - expr = DSL.match(field, query, analyzer, autoGenerateSynonymsPhrase, fuzziness, maxExpansions); + expr = OpenSearchDSL.match(field, query, analyzer, autoGenerateSynonymsPhrase, fuzziness, maxExpansions); assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -96,7 +97,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -108,7 +109,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -121,7 +122,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -135,7 +136,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -150,7 +151,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -165,7 +166,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -181,7 +182,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -198,7 +199,7 @@ void match() { assertEquals(BOOLEAN, expr.type()); expr = - DSL.match( + OpenSearchDSL.match( field, query, analyzer, @@ -225,14 +226,14 @@ void match_phrase() { List match_phrase_dsl_expressions() { return List.of( - DSL.match_phrase(field, query), - DSL.match_phrase(field, query, analyzer), - DSL.match_phrase(field, query, analyzer, zeroTermsQueryAll), - DSL.match_phrase(field, query, analyzer, zeroTermsQueryNone, slop)); + OpenSearchDSL.match_phrase(field, query), + OpenSearchDSL.match_phrase(field, query, analyzer), + OpenSearchDSL.match_phrase(field, query, analyzer, zeroTermsQueryAll), + OpenSearchDSL.match_phrase(field, query, analyzer, zeroTermsQueryNone, slop)); } List match_phrase_prefix_dsl_expressions() { - return List.of(DSL.match_phrase_prefix(field, query)); + return List.of(OpenSearchDSL.match_phrase_prefix(field, query)); } @Test @@ -244,7 +245,7 @@ public void match_phrase_prefix() { @Test void match_in_memory() { - FunctionExpression expr = DSL.match(field, query); + FunctionExpression expr = OpenSearchDSL.match(field, query); assertThrows( UnsupportedOperationException.class, () -> expr.valueOf(valueEnv()), @@ -253,13 +254,13 @@ void match_in_memory() { @Test void match_to_string() { - FunctionExpression expr = DSL.match(field, query); + FunctionExpression expr = OpenSearchDSL.match(field, query); assertEquals("match(field=\"message\", query=\"search query\")", expr.toString()); } @Test void multi_match() { - FunctionExpression expr = DSL.multi_match(fields, query); + FunctionExpression expr = OpenSearchDSL.multi_match(fields, query); assertEquals( String.format("multi_match(fields=%s, query=%s)", fields.getValue(), query.getValue()), expr.toString()); @@ -267,7 +268,7 @@ void multi_match() { @Test void simple_query_string() { - FunctionExpression expr = DSL.simple_query_string(fields, query); + FunctionExpression expr = OpenSearchDSL.simple_query_string(fields, query); assertEquals( String.format( "simple_query_string(fields=%s, query=%s)", fields.getValue(), query.getValue()), @@ -276,13 +277,13 @@ void simple_query_string() { @Test void query() { - FunctionExpression expr = DSL.query(query); + FunctionExpression expr = OpenSearchDSL.query(query); assertEquals(String.format("query(query=%s)", query.getValue()), expr.toString()); } @Test void query_string() { - FunctionExpression expr = DSL.query_string(fields, query); + FunctionExpression expr = OpenSearchDSL.query_string(fields, query); assertEquals( String.format("query_string(fields=%s, query=%s)", fields.getValue(), query.getValue()), expr.toString()); @@ -290,7 +291,7 @@ void query_string() { @Test void wildcard_query() { - FunctionExpression expr = DSL.wildcard_query(field, query); + FunctionExpression expr = OpenSearchDSL.wildcard_query(field, query); assertEquals( String.format("wildcard_query(field=%s, query=%s)", field.getValue(), query.getValue()), expr.toString()); @@ -298,7 +299,7 @@ void wildcard_query() { @Test void nested_query() { - FunctionExpression expr = DSL.nested(DSL.ref("message.info", STRING)); + FunctionExpression expr = OpenSearchDSL.nested(OpenSearchDSL.ref("message.info", STRING)); assertEquals( String.format( "FunctionExpression(functionName=%s, arguments=[message.info])", diff --git a/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/function/RelevanceFunctionResolverTest.java similarity index 84% rename from core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java rename to opensearch/src/test/java/org/opensearch/sql/opensearch/expression/function/RelevanceFunctionResolverTest.java index c678ac6eb4..f2abe9cb92 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/expression/function/RelevanceFunctionResolverTest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.sql.expression.function; +package org.opensearch.sql.opensearch.expression.function; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -15,6 +15,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.opensearch.functions.RelevanceFunctionResolver; class RelevanceFunctionResolverTest { private final FunctionName sampleFuncName = FunctionName.of("sample_function"); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/DefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/DefaultImplementorTest.java new file mode 100644 index 0000000000..fc6932fa3a --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/DefaultImplementorTest.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner; + +import static java.util.Collections.emptyList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.expression.DSL.literal; +import static org.opensearch.sql.expression.DSL.named; +import static org.opensearch.sql.expression.DSL.ref; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.eval; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.rareTopN; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.remove; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.rename; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.values; + +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.tree.RareTopN.CommandType; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.aggregation.AvgAggregator; +import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.planner.DefaultImplementor; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanDSL; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanDSL; + + +class DefaultImplementorTest { + + private final DefaultImplementor implementor = new DefaultImplementor<>(); + + @Test + public void visit_should_return_default_physical_operator() { + String indexName = "test"; + NamedExpression include = named("age", ref("age", INTEGER)); + ReferenceExpression exclude = ref("name", STRING); + ReferenceExpression dedupeField = ref("name", STRING); + Expression filterExpr = literal(ExprBooleanValue.of(true)); + List groupByExprs = Arrays.asList(OpenSearchDSL.named("age", ref("age", INTEGER))); + List aggExprs = Arrays.asList(ref("age", INTEGER)); + ReferenceExpression rareTopNField = ref("age", INTEGER); + List topByExprs = Arrays.asList(ref("age", INTEGER)); + List aggregators = + Arrays.asList(OpenSearchDSL.named("avg(age)", new AvgAggregator(aggExprs, ExprCoreType.DOUBLE))); + Map mappings = + ImmutableMap.of(ref("name", STRING), ref("lastname", STRING)); + Pair newEvalField = + ImmutablePair.of(ref("name1", STRING), ref("name", STRING)); + Pair sortField = + ImmutablePair.of(Sort.SortOption.DEFAULT_ASC, ref("name1", STRING)); + Integer limit = 1; + Integer offset = 1; + List> nestedArgs = + List.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING))); + List nestedProjectList = + List.of( + new NamedExpression("message.info", OpenSearchDSL.nested(OpenSearchDSL.ref("message.info", STRING)), null)); + Set nestedOperatorArgs = Set.of("message.info"); + Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); + + LogicalPlan plan = + project( + nested( + limit( + LogicalPlanDSL.dedupe( + rareTopN( + sort( + eval( + remove( + rename( + aggregation( + filter(values(emptyList()), filterExpr), + aggregators, + groupByExprs), + mappings), + exclude), + newEvalField), + sortField), + CommandType.TOP, + topByExprs, + rareTopNField), + dedupeField), + limit, + offset), + nestedArgs, + nestedProjectList), + include); + + PhysicalPlan actual = plan.accept(implementor, null); + + assertEquals( + PhysicalPlanDSL.project( + PhysicalPlanDSL.nested( + PhysicalPlanDSL.limit( + PhysicalPlanDSL.dedupe( + PhysicalPlanDSL.rareTopN( + PhysicalPlanDSL.sort( + PhysicalPlanDSL.eval( + PhysicalPlanDSL.remove( + PhysicalPlanDSL.rename( + PhysicalPlanDSL.agg( + PhysicalPlanDSL.filter( + PhysicalPlanDSL.values(emptyList()), + filterExpr), + aggregators, + groupByExprs), + mappings), + exclude), + newEvalField), + sortField), + CommandType.TOP, + topByExprs, + rareTopNField), + dedupeField), + limit, + offset), + nestedOperatorArgs, + groupedFieldsByPath), + include), + actual); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/LogicalPlanNodeVisitorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/LogicalPlanNodeVisitorTest.java new file mode 100644 index 0000000000..b9831009c2 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/LogicalPlanNodeVisitorTest.java @@ -0,0 +1,195 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.logical; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.expression.DSL.named; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.ast.tree.RareTopN.CommandType; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.aggregation.Aggregator; +import org.opensearch.sql.expression.window.WindowDefinition; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalCloseCursor; +import org.opensearch.sql.planner.logical.LogicalFetchCursor; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanDSL; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalRareTopN; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.logical.LogicalRename; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; +import org.opensearch.sql.storage.write.TableWriteBuilder; +import org.opensearch.sql.storage.write.TableWriteOperator; + +/** Added for UT coverage */ +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class LogicalPlanNodeVisitorTest { + + static Expression expression; + static ReferenceExpression ref; + static Aggregator aggregator; + static Table table; + + @BeforeAll + private static void initMocks() { + expression = mock(Expression.class); + ref = mock(ReferenceExpression.class); + aggregator = mock(Aggregator.class); + table = mock(Table.class); + } + +// @Test +// public void logical_plan_should_be_traversable() { +// LogicalPlan logicalPlan = +// LogicalPlanDSL.rename( +// LogicalPlanDSL.aggregation( +// LogicalPlanDSL.rareTopN( +// LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema", table), expression), +// CommandType.TOP, +// ImmutableList.of(expression), +// expression), +// ImmutableList.of(OpenSearchDSL.named("avg", aggregator)), +// ImmutableList.of(OpenSearchDSL.named("group", expression))), +// ImmutableMap.of(ref, ref)); +// +// Integer result = logicalPlan.accept(new NodesCount(), null); +// assertEquals(5, result); +// } +// +// @SuppressWarnings("unchecked") +// private static Stream getLogicalPlansForVisitorTest() { +// LogicalPlan relation = LogicalPlanDSL.relation("schema", table); +// LogicalPlan tableScanBuilder = +// new TableScanBuilder() { +// @Override +// public TableScanOperator build() { +// return null; +// } +// }; +// TableWriteBuilder tableWriteBuilder = +// new TableWriteBuilder(null) { +// @Override +// public TableWriteOperator build(PhysicalPlan child) { +// return null; +// } +// }; +// LogicalPlan write = LogicalPlanDSL.write(null, table, Collections.emptyList()); +// LogicalPlan filter = LogicalPlanDSL.filter(relation, expression); +// LogicalPlan aggregation = +// LogicalPlanDSL.aggregation( +// filter, +// ImmutableList.of(OpenSearchDSL.named("avg", aggregator)), +// ImmutableList.of(OpenSearchDSL.named("group", expression))); +// +// +// List> nestedArgs = +// List.of( +// Map.of( +// "field", new ReferenceExpression("message.info", STRING), +// "path", new ReferenceExpression("message", STRING))); +// List projectList = +// List.of( +// new NamedExpression("message.info", OpenSearchDSL.nested(OpenSearchDSL.ref("message.info", STRING)), null)); +// +// LogicalNested nested = new LogicalNested(null, nestedArgs, projectList); +// +// LogicalFetchCursor cursor = new LogicalFetchCursor("n:test", mock(StorageEngine.class)); +// +// LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); +// +// return Stream.of( +// relation, +// tableScanBuilder, +// write, +// tableWriteBuilder, +// filter, +// aggregation, +// nested, +// cursor, +// closeCursor) +// .map(Arguments::of); +// } +// +// @ParameterizedTest +// @MethodSource("getLogicalPlansForVisitorTest") +// public void abstract_plan_node_visitor_should_return_null(LogicalPlan plan) { +// assertNull(plan.accept(new LogicalPlanNodeVisitor() {}, null)); +// } + + private static class NodesCount extends LogicalPlanNodeVisitor { + @Override + public Integer visitRelation(LogicalRelation plan, Object context) { + return 1; + } + + @Override + public Integer visitFilter(LogicalFilter plan, Object context) { + return 1 + + plan.getChild().stream() + .map(child -> child.accept(this, context)) + .mapToInt(Integer::intValue) + .sum(); + } + + @Override + public Integer visitAggregation(LogicalAggregation plan, Object context) { + return 1 + + plan.getChild().stream() + .map(child -> child.accept(this, context)) + .mapToInt(Integer::intValue) + .sum(); + } + + @Override + public Integer visitRename(LogicalRename plan, Object context) { + return 1 + + plan.getChild().stream() + .map(child -> child.accept(this, context)) + .mapToInt(Integer::intValue) + .sum(); + } + + @Override + public Integer visitRareTopN(LogicalRareTopN plan, Object context) { + return 1 + + plan.getChild().stream() + .map(child -> child.accept(this, context)) + .mapToInt(Integer::intValue) + .sum(); + } + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/optimizer/LogicalPlanOptimizerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/optimizer/LogicalPlanOptimizerTest.java new file mode 100644 index 0000000000..640739c07c --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/optimizer/LogicalPlanOptimizerTest.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.planner.optimizer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.model.ExprValueUtils.longValue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.values; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.write; + +import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; +import org.opensearch.sql.storage.write.TableWriteBuilder; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class LogicalPlanOptimizerTest { + + @Mock private Table table; + + @Spy private TableScanBuilder tableScanBuilder; + + @BeforeEach + void setUp() { + lenient().when(table.createScanBuilder()).thenReturn(tableScanBuilder); + } + +// @Test +// void table_scan_builder_support_nested_push_down_can_apply_its_rule() { +// when(tableScanBuilder.pushDownNested(any())).thenReturn(true); +// +// assertEquals( +// tableScanBuilder, +// optimize( +// nested( +// relation("schema", table), +// List.of(Map.of("field", new ReferenceExpression("message.info", STRING))), +// List.of( +// new NamedExpression( +// "message.info", OpenSearchDSL.nested(OpenSearchDSL.ref("message.info", STRING)), null))))); +// } + + private LogicalPlan optimize(LogicalPlan plan) { + final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); + return optimizer.optimize(plan); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 483ea1290e..b7d541007c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -57,6 +57,7 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; @@ -342,7 +343,7 @@ void test_push_down_nested() { List projectList = List.of( - new NamedExpression("message.info", DSL.nested(DSL.ref("message.info", STRING)), null) + new NamedExpression("message.info", OpenSearchDSL.nested(DSL.ref("message.info", STRING)), null) ); LogicalNested nested = new LogicalNested(null, args, projectList); @@ -375,8 +376,8 @@ void test_push_down_multiple_nested_with_same_path() { ); List projectList = List.of( - new NamedExpression("message.info", DSL.nested(DSL.ref("message.info", STRING)), null), - new NamedExpression("message.from", DSL.nested(DSL.ref("message.from", STRING)), null) + new NamedExpression("message.info", OpenSearchDSL.nested(DSL.ref("message.info", STRING)), null), + new NamedExpression("message.from", OpenSearchDSL.nested(DSL.ref("message.from", STRING)), null) ); LogicalNested nested = new LogicalNested(null, args, projectList); @@ -405,7 +406,7 @@ void test_push_down_nested_with_filter() { List projectList = List.of( - new NamedExpression("message.info", DSL.nested(DSL.ref("message.info", STRING)), null) + new NamedExpression("message.info", OpenSearchDSL.nested(DSL.ref("message.info", STRING)), null) ); LogicalNested nested = new LogicalNested(null, args, projectList); @@ -442,7 +443,7 @@ void testPushDownNestedWithNestedFilter() { List projectList = List.of( - new NamedExpression("message.info", DSL.nested(DSL.ref("message.info", STRING)), null) + new NamedExpression("message.info", OpenSearchDSL.nested(DSL.ref("message.info", STRING)), null) ); QueryBuilder innerFilterQuery = QueryBuilders.rangeQuery("myNum").gt(3); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java index 1089e7e252..e2d37e625b 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java @@ -52,4 +52,10 @@ public void getSystemTable() { () -> assertTrue(table instanceof OpenSearchSystemIndex) ); } + + @Test + public void getFunctions() { + OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); + assertNotNull(engine.getFunctions()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java index e045bae3e3..7b552dc562 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -73,8 +73,9 @@ import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.function.OpenSearchFunctions; +import org.opensearch.sql.expression.function.OpenSearchFunction; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -164,14 +165,14 @@ void test_filter_on_opensearchfunction_with_trackedscores_push_down() { ), DSL.named("i", DSL.ref("intV", INTEGER)) ); - FunctionExpression queryString = DSL.query_string( + FunctionExpression queryString = OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "intV", ExprValueUtils.floatValue(1.5F)))))), - DSL.namedArgument("query", "QUERY"), - DSL.namedArgument("boost", "12.5")); + OpenSearchDSL.namedArgument("query", "QUERY"), + OpenSearchDSL.namedArgument("boost", "12.5")); - ((OpenSearchFunctions.OpenSearchFunction) queryString).setScoreTracked(true); + ((OpenSearchFunction) queryString).setScoreTracked(true); LogicalPlan logicalPlan = project( filter( @@ -204,20 +205,20 @@ void test_filter_on_multiple_opensearchfunctions_with_trackedscores_push_down() ), DSL.named("i", DSL.ref("intV", INTEGER)) ); - FunctionExpression firstQueryString = DSL.query_string( + FunctionExpression firstQueryString = OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "intV", ExprValueUtils.floatValue(1.5F)))))), - DSL.namedArgument("query", "QUERY"), - DSL.namedArgument("boost", "12.5")); - ((OpenSearchFunctions.OpenSearchFunction) firstQueryString).setScoreTracked(false); - FunctionExpression secondQueryString = DSL.query_string( + OpenSearchDSL.namedArgument("query", "QUERY"), + OpenSearchDSL.namedArgument("boost", "12.5")); + ((OpenSearchFunction) firstQueryString).setScoreTracked(false); + FunctionExpression secondQueryString = OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "intV", ExprValueUtils.floatValue(1.5F)))))), - DSL.namedArgument("query", "QUERY"), - DSL.namedArgument("boost", "12.5")); - ((OpenSearchFunctions.OpenSearchFunction) secondQueryString).setScoreTracked(true); + OpenSearchDSL.namedArgument("query", "QUERY"), + OpenSearchDSL.namedArgument("boost", "12.5")); + ((OpenSearchFunction) secondQueryString).setScoreTracked(true); LogicalPlan logicalPlan = project( filter( @@ -243,12 +244,12 @@ void test_filter_on_opensearchfunction_without_trackedscores_push_down() { ), DSL.named("i", DSL.ref("intV", INTEGER)) ); - FunctionExpression queryString = DSL.query_string( + FunctionExpression queryString = OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "intV", ExprValueUtils.floatValue(1.5F)))))), - DSL.namedArgument("query", "QUERY"), - DSL.namedArgument("boost", "12.5")); + OpenSearchDSL.namedArgument("query", "QUERY"), + OpenSearchDSL.namedArgument("boost", "12.5")); LogicalPlan logicalPlan = project( filter( @@ -413,7 +414,7 @@ void test_nested_push_down() { List projectList = List.of( - new NamedExpression("message.info", DSL.nested(DSL.ref("message.info", STRING)), null) + new NamedExpression("message.info", OpenSearchDSL.nested(DSL.ref("message.info", STRING)), null) ); LogicalNested nested = new LogicalNested(null, args, projectList); @@ -424,13 +425,13 @@ void test_nested_push_down() { indexScanBuilder( withNestedPushedDown(nested.getFields())), args, projectList), DSL.named("message.info", - DSL.nested(DSL.ref("message.info", STRING))) + OpenSearchDSL.nested(DSL.ref("message.info", STRING))) ), project( nested( relation("schema", table), args, projectList), DSL.named("message.info", - DSL.nested(DSL.ref("message.info", STRING))) + OpenSearchDSL.nested(DSL.ref("message.info", STRING))) ) ); } @@ -595,7 +596,7 @@ void test_nested_sort_filter_push_down() { DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), Pair.of( - SortOption.DEFAULT_ASC, DSL.nested(DSL.ref("message.info", STRING)) + SortOption.DEFAULT_ASC, OpenSearchDSL.nested(DSL.ref("message.info", STRING)) ) ), DSL.named("intV", DSL.ref("intV", INTEGER)) @@ -611,14 +612,14 @@ void test_function_expression_sort_returns_optimized_logical_sort() { indexScanBuilder(), Pair.of( SortOption.DEFAULT_ASC, - DSL.match(DSL.namedArgument("field", literal("message"))) + OpenSearchDSL.match(DSL.namedArgument("field", literal("message"))) ) ), sort( relation("schema", table), Pair.of( SortOption.DEFAULT_ASC, - DSL.match(DSL.namedArgument("field", literal("message")) + OpenSearchDSL.match(DSL.namedArgument("field", literal("message")) ) ) ) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index eb07076257..8ec125fbb6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -56,6 +56,7 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -320,7 +321,7 @@ void should_build_term_query_predicate_expression_with_nested_function() { + " }\n" + "}", buildQuery( - DSL.equal(DSL.nested( + DSL.equal(OpenSearchDSL.nested( DSL.ref("message.info", STRING), DSL.ref("message", STRING)), literal("string_value") @@ -352,7 +353,7 @@ void should_build_range_query_predicate_expression_with_nested_function() { + " }\n" + "}", buildQuery( - DSL.greater(DSL.nested( + DSL.greater(OpenSearchDSL.nested( DSL.ref("lottery.number.id", INTEGER)), literal(1234) ) ) @@ -365,7 +366,7 @@ void should_build_range_query_predicate_expression_with_nested_function() { void ensure_alternate_syntax_falls_back_to_legacy_engine() { assertThrows(SyntaxCheckException.class, () -> buildQuery( - DSL.nested( + OpenSearchDSL.nested( DSL.ref("message", STRING), DSL.equal(DSL.literal("message.info"), literal("a")) ) @@ -377,7 +378,7 @@ void ensure_alternate_syntax_falls_back_to_legacy_engine() { void nested_filter_wrong_right_side_type_in_predicate_throws_exception() { assertThrows(IllegalArgumentException.class, () -> buildQuery( - DSL.equal(DSL.nested( + DSL.equal(OpenSearchDSL.nested( DSL.ref("message.info", STRING), DSL.ref("message", STRING)), DSL.ref("string_value", STRING) @@ -390,7 +391,7 @@ void nested_filter_wrong_right_side_type_in_predicate_throws_exception() { void nested_filter_wrong_first_param_type_throws_exception() { assertThrows(IllegalArgumentException.class, () -> buildQuery( - DSL.equal(DSL.nested( + DSL.equal(OpenSearchDSL.nested( DSL.namedArgument("field", literal("message"))), literal("string_value") ) @@ -402,7 +403,7 @@ void nested_filter_wrong_first_param_type_throws_exception() { void nested_filter_wrong_second_param_type_throws_exception() { assertThrows(IllegalArgumentException.class, () -> buildQuery( - DSL.equal(DSL.nested( + DSL.equal(OpenSearchDSL.nested( DSL.ref("message.info", STRING), DSL.literal(2)), literal("string_value") @@ -415,7 +416,7 @@ void nested_filter_wrong_second_param_type_throws_exception() { void nested_filter_too_many_params_throws_exception() { assertThrows(IllegalArgumentException.class, () -> buildQuery( - DSL.equal(DSL.nested( + DSL.equal(OpenSearchDSL.nested( DSL.ref("message.info", STRING), DSL.ref("message", STRING), DSL.ref("message", STRING)), @@ -444,7 +445,7 @@ void should_build_match_query_with_default_parameters() { + " }\n" + "}", buildQuery( - DSL.match( + OpenSearchDSL.match( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query"))))); @@ -473,7 +474,7 @@ void should_build_match_query_with_custom_parameters() { + " }\n" + "}", buildQuery( - DSL.match( + OpenSearchDSL.match( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query")), @@ -493,7 +494,7 @@ void should_build_match_query_with_custom_parameters() { @Test void match_invalid_parameter() { - FunctionExpression expr = DSL.match( + FunctionExpression expr = OpenSearchDSL.match( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query")), @@ -504,7 +505,7 @@ void match_invalid_parameter() { @Test void match_disallow_duplicate_parameter() { - FunctionExpression expr = DSL.match( + FunctionExpression expr = OpenSearchDSL.match( DSL.namedArgument("field", literal("message")), DSL.namedArgument("query", literal("search query")), DSL.namedArgument("analyzer", literal("keyword")), @@ -515,7 +516,7 @@ void match_disallow_duplicate_parameter() { @Test void match_disallow_duplicate_query() { - FunctionExpression expr = DSL.match( + FunctionExpression expr = OpenSearchDSL.match( DSL.namedArgument("field", literal("message")), DSL.namedArgument("query", literal("search query")), DSL.namedArgument("analyzer", literal("keyword")), @@ -526,7 +527,7 @@ void match_disallow_duplicate_query() { @Test void match_disallow_duplicate_field() { - FunctionExpression expr = DSL.match( + FunctionExpression expr = OpenSearchDSL.match( DSL.namedArgument("field", literal("message")), DSL.namedArgument("query", literal("search query")), DSL.namedArgument("analyzer", literal("keyword")), @@ -537,7 +538,7 @@ void match_disallow_duplicate_field() { @Test void match_missing_field() { - FunctionExpression expr = DSL.match( + FunctionExpression expr = OpenSearchDSL.match( DSL.namedArgument("query", literal("search query")), DSL.namedArgument("analyzer", literal("keyword"))); var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); @@ -546,7 +547,7 @@ void match_missing_field() { @Test void match_missing_query() { - FunctionExpression expr = DSL.match( + FunctionExpression expr = OpenSearchDSL.match( DSL.namedArgument("field", literal("field1")), DSL.namedArgument("analyzer", literal("keyword"))); var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); @@ -567,7 +568,7 @@ void should_build_match_phrase_query_with_default_parameters() { + " }\n" + "}", buildQuery( - DSL.match_phrase( + OpenSearchDSL.match_phrase( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query"))))); @@ -592,7 +593,7 @@ void should_build_multi_match_query_with_default_parameters_single_field() { + " \"boost\" : 1.0,\n" + " }\n" + "}", - buildQuery(DSL.multi_match( + buildQuery(OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F)))))), @@ -618,7 +619,7 @@ void should_build_multi_match_query_with_default_parameters_all_fields() { + " \"boost\" : 1.0,\n" + " }\n" + "}", - buildQuery(DSL.multi_match( + buildQuery(OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "*", ExprValueUtils.floatValue(1.F)))))), @@ -642,7 +643,7 @@ void should_build_multi_match_query_with_default_parameters_no_fields() { + " \"boost\" : 1.0,\n" + " }\n" + "}", - buildQuery(DSL.multi_match( + buildQuery(OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of())))), DSL.namedArgument("query", literal("search query"))))); @@ -667,7 +668,7 @@ void should_build_multi_match_query_with_default_parameters_multiple_fields() { + " \"boost\" : 1.0,\n" + " }\n" + "}"; - var actual = buildQuery(DSL.multi_match( + var actual = buildQuery(OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -705,7 +706,7 @@ void should_build_multi_match_query_with_custom_parameters() { + " }\n" + "}"; var actual = buildQuery( - DSL.multi_match( + OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal( ExprValueUtils.tupleValue(ImmutableMap.of("field1", 1.F, "field2", .3F)))), DSL.namedArgument("query", literal("search query")), @@ -734,7 +735,7 @@ void should_build_multi_match_query_with_custom_parameters() { @Test void multi_match_invalid_parameter() { - FunctionExpression expr = DSL.multi_match( + FunctionExpression expr = OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -760,7 +761,7 @@ void should_build_match_phrase_query_with_custom_parameters() { + " }\n" + "}", buildQuery( - DSL.match_phrase( + OpenSearchDSL.match_phrase( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("boost", literal("1.2")), @@ -772,7 +773,7 @@ void should_build_match_phrase_query_with_custom_parameters() { @Test void wildcard_query_invalid_parameter() { - FunctionExpression expr = DSL.wildcard_query( + FunctionExpression expr = OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query*")), @@ -792,7 +793,7 @@ void wildcard_query_convert_sql_wildcard_to_lucene() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query%"))))); @@ -805,7 +806,7 @@ void wildcard_query_convert_sql_wildcard_to_lucene() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query_"))))); @@ -821,7 +822,7 @@ void wildcard_query_escape_wildcards_characters() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query\\%"))))); @@ -834,7 +835,7 @@ void wildcard_query_escape_wildcards_characters() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query\\_"))))); @@ -847,7 +848,7 @@ void wildcard_query_escape_wildcards_characters() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query\\*"))))); @@ -860,7 +861,7 @@ void wildcard_query_escape_wildcards_characters() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query\\?"))))); @@ -876,7 +877,7 @@ void should_build_wildcard_query_with_default_parameters() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query*"))))); @@ -894,7 +895,7 @@ void should_build_wildcard_query_query_with_custom_parameters() { + " }\n" + " }\n" + "}", - buildQuery(DSL.wildcard_query( + buildQuery(OpenSearchDSL.wildcard_query( DSL.namedArgument("field", new ReferenceExpression("field", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query*")), @@ -905,7 +906,7 @@ void should_build_wildcard_query_query_with_custom_parameters() { @Test void query_invalid_parameter() { - FunctionExpression expr = DSL.query( + FunctionExpression expr = OpenSearchDSL.query( DSL.namedArgument("invalid_parameter", literal("invalid_value"))); assertThrows(SemanticCheckException.class, () -> buildQuery(expr), "Parameter invalid_parameter is invalid for query function."); @@ -913,7 +914,7 @@ void query_invalid_parameter() { @Test void query_invalid_fields_parameter_exception_message() { - FunctionExpression expr = DSL.query( + FunctionExpression expr = OpenSearchDSL.query( DSL.namedArgument("fields", literal("field1")), DSL.namedArgument("query", literal("search query"))); @@ -942,7 +943,7 @@ void should_build_query_query_with_default_parameters() { + " }\n" + "}"; - assertJsonEquals(expected, buildQuery(DSL.query( + assertJsonEquals(expected, buildQuery(OpenSearchDSL.query( DSL.namedArgument("query", literal("field1:query_value"))))); } @@ -972,7 +973,7 @@ void should_build_query_query_with_custom_parameters() { + " }\n" + "}"; var actual = buildQuery( - DSL.query( + OpenSearchDSL.query( DSL.namedArgument("query", literal("field1:query_value")), DSL.namedArgument("analyze_wildcard", literal("true")), DSL.namedArgument("analyzer", literal("keyword")), @@ -992,7 +993,7 @@ void should_build_query_query_with_custom_parameters() { @Test void query_string_invalid_parameter() { - FunctionExpression expr = DSL.query_string( + FunctionExpression expr = OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -1023,7 +1024,7 @@ void should_build_query_string_query_with_default_parameters_multiple_fields() { + " \"boost\" : 1.0\n" + " }\n" + "}"; - var actual = buildQuery(DSL.query_string( + var actual = buildQuery(OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -1063,7 +1064,7 @@ void should_build_query_string_query_with_custom_parameters() { + " }\n" + "}"; var actual = buildQuery( - DSL.query_string( + OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal( ExprValueUtils.tupleValue(ImmutableMap.of("field1", 1.F, "field2", .3F)))), DSL.namedArgument("query", literal("query_value")), @@ -1109,7 +1110,7 @@ void should_build_query_string_query_with_default_parameters_single_field() { + " \"boost\" : 1.0,\n" + " }\n" + "}", - buildQuery(DSL.query_string( + buildQuery(OpenSearchDSL.query_string( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F)))))), @@ -1138,7 +1139,7 @@ void should_build_simple_query_string_query_with_default_parameters_single_field + " \"boost\" : 1.0\n" + " }\n" + "}", - buildQuery(DSL.simple_query_string( + buildQuery(OpenSearchDSL.simple_query_string( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F)))))), @@ -1161,7 +1162,7 @@ void should_build_simple_query_string_query_with_default_parameters_multiple_fie + " \"boost\" : 1.0\n" + " }\n" + "}"; - var actual = buildQuery(DSL.simple_query_string( + var actual = buildQuery(OpenSearchDSL.simple_query_string( DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -1195,7 +1196,7 @@ void should_build_simple_query_string_query_with_custom_parameters() { + " }\n" + "}"; var actual = buildQuery( - DSL.simple_query_string( + OpenSearchDSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( ExprValueUtils.tupleValue(ImmutableMap.of("field1", 1.F, "field2", .3F)))), DSL.namedArgument("query", literal("search query")), @@ -1220,7 +1221,7 @@ void should_build_simple_query_string_query_with_custom_parameters() { @Test void simple_query_string_invalid_parameter() { - FunctionExpression expr = DSL.simple_query_string( + FunctionExpression expr = OpenSearchDSL.simple_query_string( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -1233,7 +1234,7 @@ void simple_query_string_invalid_parameter() { @Test void match_phrase_invalid_parameter() { - FunctionExpression expr = DSL.match_phrase( + FunctionExpression expr = OpenSearchDSL.match_phrase( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query")), @@ -1252,55 +1253,55 @@ void relevancy_func_invalid_arg_values() { "field2", ExprValueUtils.floatValue(.3F)))))); final var query = DSL.namedArgument("query", literal("search query")); - var slopTest = DSL.match_phrase(field, query, + var slopTest = OpenSearchDSL.match_phrase(field, query, DSL.namedArgument("slop", literal("1.5"))); var msg = assertThrows(RuntimeException.class, () -> buildQuery(slopTest)).getMessage(); assertEquals("Invalid slop value: '1.5'. Accepts only integer values.", msg); - var ztqTest = DSL.match_phrase(field, query, + var ztqTest = OpenSearchDSL.match_phrase(field, query, DSL.namedArgument("zero_terms_query", literal("meow"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(ztqTest)).getMessage(); assertEquals( "Invalid zero_terms_query value: 'meow'. Available values are: NONE, ALL, NULL.", msg); - var boostTest = DSL.match(field, query, + var boostTest = OpenSearchDSL.match(field, query, DSL.namedArgument("boost", literal("pewpew"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(boostTest)).getMessage(); assertEquals( "Invalid boost value: 'pewpew'. Accepts only floating point values greater than 0.", msg); - var boolTest = DSL.query_string(fields, query, + var boolTest = OpenSearchDSL.query_string(fields, query, DSL.namedArgument("escape", literal("42"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(boolTest)).getMessage(); assertEquals( "Invalid escape value: '42'. Accepts only boolean values: 'true' or 'false'.", msg); - var typeTest = DSL.multi_match(fields, query, + var typeTest = OpenSearchDSL.multi_match(fields, query, DSL.namedArgument("type", literal("42"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(typeTest)).getMessage(); assertTrue(msg.startsWith("Invalid type value: '42'. Available values are:")); - var operatorTest = DSL.simple_query_string(fields, query, + var operatorTest = OpenSearchDSL.simple_query_string(fields, query, DSL.namedArgument("default_operator", literal("42"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(operatorTest)).getMessage(); assertTrue(msg.startsWith("Invalid default_operator value: '42'. Available values are:")); - var flagsTest = DSL.simple_query_string(fields, query, + var flagsTest = OpenSearchDSL.simple_query_string(fields, query, DSL.namedArgument("flags", literal("42"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(flagsTest)).getMessage(); assertTrue(msg.startsWith("Invalid flags value: '42'. Available values are:")); - var fuzzinessTest = DSL.match_bool_prefix(field, query, + var fuzzinessTest = OpenSearchDSL.match_bool_prefix(field, query, DSL.namedArgument("fuzziness", literal("AUTO:"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(fuzzinessTest)).getMessage(); assertTrue(msg.startsWith("Invalid fuzziness value: 'AUTO:'. Available values are:")); - var rewriteTest = DSL.match_bool_prefix(field, query, + var rewriteTest = OpenSearchDSL.match_bool_prefix(field, query, DSL.namedArgument("fuzzy_rewrite", literal("42"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(rewriteTest)).getMessage(); assertTrue(msg.startsWith("Invalid fuzzy_rewrite value: '42'. Available values are:")); - var timezoneTest = DSL.query_string(fields, query, + var timezoneTest = OpenSearchDSL.query_string(fields, query, DSL.namedArgument("time_zone", literal("42"))); msg = assertThrows(RuntimeException.class, () -> buildQuery(timezoneTest)).getMessage(); assertTrue(msg.startsWith("Invalid time_zone value: '42'.")); @@ -1322,7 +1323,7 @@ void should_build_match_bool_prefix_query_with_default_parameters() { + " }\n" + "}", buildQuery( - DSL.match_bool_prefix( + OpenSearchDSL.match_bool_prefix( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query"))))); @@ -1330,7 +1331,7 @@ void should_build_match_bool_prefix_query_with_default_parameters() { @Test void multi_match_missing_fields_even_with_struct() { - FunctionExpression expr = DSL.multi_match( + FunctionExpression expr = OpenSearchDSL.multi_match( DSL.namedArgument("something-but-not-fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "pewpew", ExprValueUtils.integerValue(42)))))), @@ -1342,7 +1343,7 @@ void multi_match_missing_fields_even_with_struct() { @Test void multi_match_missing_query_even_with_struct() { - FunctionExpression expr = DSL.multi_match( + FunctionExpression expr = OpenSearchDSL.multi_match( DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), @@ -1367,7 +1368,7 @@ void should_build_match_phrase_prefix_query_with_default_parameters() { + " }\n" + "}", buildQuery( - DSL.match_phrase_prefix( + OpenSearchDSL.match_phrase_prefix( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query"))))); @@ -1389,7 +1390,7 @@ void should_build_match_phrase_prefix_query_with_non_default_parameters() { + " }\n" + "}", buildQuery( - DSL.match_phrase_prefix( + OpenSearchDSL.match_phrase_prefix( DSL.namedArgument("field", new ReferenceExpression("message", OpenSearchTextType.of())), DSL.namedArgument("query", literal("search query")), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java index 6906619065..57bda5b4d8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java @@ -27,6 +27,7 @@ import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchBoolPrefixQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -63,7 +64,7 @@ public void test_valid_when_two_arguments() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("field_value", OpenSearchTextType.of())), - DSL.namedArgument("query", "query_value")); + OpenSearchDSL.namedArgument("query", "query_value")); Assertions.assertNotNull(matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -76,7 +77,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(DSL.namedArgument("field", "field_value")); + List arguments = List.of(OpenSearchDSL.namedArgument("field", "field_value")); assertThrows(SyntaxCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -86,8 +87,8 @@ public void test_SemanticCheckException_when_invalid_argument() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("field_value", OpenSearchTextType.of())), - DSL.namedArgument("query", "query_value"), - DSL.namedArgument("unsupported", "unsupported_value")); + OpenSearchDSL.namedArgument("query", "query_value"), + OpenSearchDSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java index 0defee0008..6e6ad8e5b1 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java @@ -24,6 +24,7 @@ import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhrasePrefixQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -52,8 +53,8 @@ public void test_SyntaxCheckException_when_invalid_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2"), - DSL.namedArgument("unsupported", "3")); + OpenSearchDSL.namedArgument("query", "test2"), + OpenSearchDSL.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -63,8 +64,8 @@ public void test_analyzer_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("analyzer", "standard") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -74,7 +75,7 @@ public void build_succeeds_with_two_arguments() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2")); + OpenSearchDSL.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -83,8 +84,8 @@ public void test_slop_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("slop", "2") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -94,8 +95,8 @@ public void test_zero_terms_query_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "ALL") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -105,8 +106,8 @@ public void test_zero_terms_query_parameter_lower_case() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "all") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -116,8 +117,8 @@ public void test_boost_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("boost", "0.1") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("boost", "0.1") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java index 20ecb869ba..6f67d364d3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java @@ -24,6 +24,7 @@ import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhraseQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -54,8 +55,8 @@ public void test_SyntaxCheckException_when_invalid_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2"), - DSL.namedArgument("unsupported", "3")); + OpenSearchDSL.namedArgument("query", "test2"), + OpenSearchDSL.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -65,8 +66,8 @@ public void test_analyzer_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("analyzer", "standard") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -76,7 +77,7 @@ public void build_succeeds_with_two_arguments() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2")); + OpenSearchDSL.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -85,8 +86,8 @@ public void test_slop_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("slop", "2") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -96,8 +97,8 @@ public void test_zero_terms_query_parameter() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "ALL") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -107,8 +108,8 @@ public void test_zero_terms_query_parameter_lower_case() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "all") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -136,8 +137,8 @@ public void test_SyntaxCheckException_when_invalid_parameter_match_phrase_syntax List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2"), - DSL.namedArgument("unsupported", "3")); + OpenSearchDSL.namedArgument("query", "test2"), + OpenSearchDSL.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseWithUnderscoreName))); @@ -148,8 +149,8 @@ public void test_analyzer_parameter_match_phrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("analyzer", "standard") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseWithUnderscoreName))); @@ -160,7 +161,7 @@ public void build_succeeds_with_two_arguments_match_phrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2")); + OpenSearchDSL.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseWithUnderscoreName))); } @@ -170,8 +171,8 @@ public void test_slop_parameter_match_phrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("slop", "2") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseWithUnderscoreName))); @@ -182,8 +183,8 @@ public void test_zero_terms_query_parameter_match_phrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "ALL") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseWithUnderscoreName))); @@ -194,8 +195,8 @@ public void test_zero_terms_query_parameter_lower_case_match_phrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "all") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseWithUnderscoreName))); @@ -224,8 +225,8 @@ public void test_SyntaxCheckException_when_invalid_parameter_matchphrase_syntax( List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2"), - DSL.namedArgument("unsupported", "3")); + OpenSearchDSL.namedArgument("query", "test2"), + OpenSearchDSL.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseQueryName))); @@ -236,8 +237,8 @@ public void test_analyzer_parameter_matchphrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("analyzer", "standard") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseQueryName))); @@ -248,7 +249,7 @@ public void build_succeeds_with_two_arguments_matchphrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("test", OpenSearchTextType.of())), - DSL.namedArgument("query", "test2")); + OpenSearchDSL.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseQueryName))); } @@ -258,8 +259,8 @@ public void test_slop_parameter_matchphrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("slop", "2") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseQueryName))); @@ -270,8 +271,8 @@ public void test_zero_terms_query_parameter_matchphrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "ALL") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseQueryName))); @@ -282,8 +283,8 @@ public void test_zero_terms_query_parameter_lower_case_matchphrase_syntax() { List arguments = List.of( DSL.namedArgument("field", new ReferenceExpression("t1", OpenSearchTextType.of())), - DSL.namedArgument("query", "t2"), - DSL.namedArgument("zero_terms_query", "all") + OpenSearchDSL.namedArgument("query", "t2"), + OpenSearchDSL.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression( arguments, matchPhraseQueryName))); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java index 93b0cdbc93..c447e7af8e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java @@ -31,6 +31,7 @@ import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MultiMatchQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -197,7 +198,7 @@ public void test_SemanticCheckException_when_invalid_parameter_multiMatch() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), - DSL.namedArgument("unsupported", "unsupported_value")); + OpenSearchDSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } @@ -207,7 +208,7 @@ public void test_SemanticCheckException_when_invalid_parameter_multi_match() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), - DSL.namedArgument("unsupported", "unsupported_value")); + OpenSearchDSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments, snakeCaseMultiMatchName))); } @@ -217,7 +218,7 @@ public void test_SemanticCheckException_when_invalid_parameter_multiMatchQuery() List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), - DSL.namedArgument("unsupported", "unsupported_value")); + OpenSearchDSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments, multiMatchQueryName))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java index 98bd7c5784..28807e424c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.WildcardQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -38,10 +39,10 @@ static Stream> generateValidData() { List.of( namedArgument("field", new ReferenceExpression("title", OpenSearchTextType.of())), - namedArgument("query", "query_value*"), - namedArgument("boost", "0.7"), - namedArgument("case_insensitive", "false"), - namedArgument("rewrite", "constant_score_boolean") + OpenSearchDSL.namedArgument("query", "query_value*"), + OpenSearchDSL.namedArgument("boost", "0.7"), + OpenSearchDSL.namedArgument("case_insensitive", "false"), + OpenSearchDSL.namedArgument("rewrite", "constant_score_boolean") ) ); } @@ -73,8 +74,8 @@ public void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("field", new ReferenceExpression("title", OpenSearchTextType.of())), - namedArgument("query", "query_value*"), - namedArgument("unsupported", "unsupported_value")); + OpenSearchDSL.namedArgument("query", "query_value*"), + OpenSearchDSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> wildcardQueryQuery.build(new WildcardQueryExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java index e84ed14e43..225f3e47d6 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java @@ -18,6 +18,7 @@ import org.opensearch.sql.data.model.ExprShortValue; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.opensearch.expression.OpenSearchDSL; class SortQueryBuilderTest { @@ -32,7 +33,7 @@ void build_sortbuilder_from_reference() { void build_sortbuilder_from_nested_function() { assertNotNull( sortQueryBuilder.build( - DSL.nested(DSL.ref("message.info", STRING)), + OpenSearchDSL.nested(DSL.ref("message.info", STRING)), Sort.SortOption.DEFAULT_ASC ) ); @@ -42,7 +43,7 @@ void build_sortbuilder_from_nested_function() { void build_sortbuilder_from_nested_function_with_path_param() { assertNotNull( sortQueryBuilder.build( - DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)), + OpenSearchDSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)), Sort.SortOption.DEFAULT_ASC ) ); @@ -53,7 +54,7 @@ void nested_with_too_many_args_throws_exception() { assertThrows( IllegalArgumentException.class, () -> sortQueryBuilder.build( - DSL.nested( + OpenSearchDSL.nested( DSL.ref("message.info", STRING), DSL.ref("message", STRING), DSL.ref("message", STRING) @@ -68,7 +69,7 @@ void nested_with_too_few_args_throws_exception() { assertThrows( IllegalArgumentException.class, () -> sortQueryBuilder.build( - DSL.nested(), + OpenSearchDSL.nested(), Sort.SortOption.DEFAULT_ASC ) ); @@ -79,7 +80,7 @@ void nested_with_invalid_arg_type_throws_exception() { assertThrows( IllegalArgumentException.class, () -> sortQueryBuilder.build( - DSL.nested( + OpenSearchDSL.nested( DSL.literal(1) ), Sort.SortOption.DEFAULT_ASC diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 5e156c2f5d..33d3c888c6 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -175,7 +175,7 @@ public Collection createComponents( LocalClusterState.state().setPluginSettings((OpenSearchSettings) pluginSettings); ModulesBuilder modules = new ModulesBuilder(); - modules.add(new OpenSearchPluginModule()); + modules.add(new OpenSearchPluginModule(dataSourceService)); modules.add(b -> { b.bind(NodeClient.class).toInstance((NodeClient) client); b.bind(org.opensearch.sql.common.setting.Settings.class).toInstance(pluginSettings); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java index f301a242fb..c6aaafd28b 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java @@ -41,8 +41,11 @@ @RequiredArgsConstructor public class OpenSearchPluginModule extends AbstractModule { - private final BuiltinFunctionRepository functionRepository = - BuiltinFunctionRepository.getInstance(); + private final BuiltinFunctionRepository functionRepository; + + public OpenSearchPluginModule(DataSourceService dataSourceService) { + functionRepository = BuiltinFunctionRepository.getInstance(dataSourceService); + } @Override protected void configure() { diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index 8a9d276673..7d097b4526 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -58,7 +58,7 @@ public TransportPPLQueryAction( super(PPLQueryAction.NAME, transportService, actionFilters, TransportPPLQueryRequest::new); ModulesBuilder modules = new ModulesBuilder(); - modules.add(new OpenSearchPluginModule()); + modules.add(new OpenSearchPluginModule(dataSourceService)); modules.add( b -> { b.bind(NodeClient.class).toInstance(client); diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageEngine.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageEngine.java index 738eb023b6..e9f0a2417a 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageEngine.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageEngine.java @@ -62,6 +62,4 @@ private Table resolveInformationSchemaTable(DataSourceSchemaName dataSourceSchem String.format("Information Schema doesn't contain %s table", tableName)); } } - - }