diff --git a/docs/querying/sql-operators.md b/docs/querying/sql-operators.md index 81c441c03367..da295821ecf9 100644 --- a/docs/querying/sql-operators.md +++ b/docs/querying/sql-operators.md @@ -79,7 +79,9 @@ Also see the [CONCAT function](sql-scalar.md#string-functions). |Operator|Description| |--------|-----------| |`x = y` |Equal to| +|`x IS NOT DISTINCT FROM y`|Equal to, considering `NULL` as a value. Never returns `NULL`.| |`x <> y`|Not equal to| +|`x IS DISTINCT FROM y`|Not equal to, considering `NULL` as a value. Never returns `NULL`.| |`x > y` |Greater than| |`x >= y`|Greater than or equal to| |`x < y` |Less than| diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java index 477c3e0e1982..a826f1928e6b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java @@ -218,7 +218,7 @@ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgo JoinAlgorithm deducedJoinAlgorithm; if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) { deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; - } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) { + } else if (canUseSortMergeJoin(joinDataSource.getConditionAnalysis())) { deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE; } else { deducedJoinAlgorithm = JoinAlgorithm.BROADCAST; @@ -237,15 +237,21 @@ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgo } /** - * Checks if the join condition on two tables "table1" and "table2" is of the form + * Checks if the sortMerge algorithm can execute a particular join condition. 
+ * + * Two checks: + * (1) join condition on two tables "table1" and "table2" is of the form * table1.columnA = table2.columnA && table1.columnB = table2.columnB && .... - * sortMerge algorithm can help these types of join conditions + * + * (2) join condition uses equals, not IS NOT DISTINCT FROM [sortMerge processor does not currently implement + * IS NOT DISTINCT FROM] */ - private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis) + private static boolean canUseSortMergeJoin(JoinConditionAnalysis joinConditionAnalysis) { - return joinConditionAnalysis.getEquiConditions() - .stream() - .allMatch(equality -> equality.getLeftExpr().isIdentifier()); + return joinConditionAnalysis + .getEquiConditions() + .stream() + .allMatch(equality -> equality.getLeftExpr().isIdentifier() && !equality.isIncludeNull()); } /** diff --git a/processing/src/main/java/org/apache/druid/math/expr/Exprs.java b/processing/src/main/java/org/apache/druid/math/expr/Exprs.java index e8ad020fe700..72ad9fabf825 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Exprs.java +++ b/processing/src/main/java/org/apache/druid/math/expr/Exprs.java @@ -19,11 +19,13 @@ package org.apache.druid.math.expr; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.UOE; +import org.apache.druid.segment.join.Equality; +import org.apache.druid.segment.join.JoinPrefixUtils; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.Optional; import java.util.Stack; @@ -79,16 +81,56 @@ public static List decomposeAnd(final Expr expr) } /** - * Decomposes an equality expr into the left- and right-hand side. + * Decomposes an equality expr into an {@link Equality}. Used by join-related code to identify equi-joins. 
* * @return decomposed equality, or empty if the input expr was not an equality expr */ - public static Optional> decomposeEquals(final Expr expr) + public static Optional decomposeEquals(final Expr expr, final String rightPrefix) { + final Expr lhs; + final Expr rhs; + final boolean includeNull; + if (expr instanceof BinEqExpr) { - return Optional.of(Pair.of(((BinEqExpr) expr).left, ((BinEqExpr) expr).right)); + lhs = ((BinEqExpr) expr).left; + rhs = ((BinEqExpr) expr).right; + includeNull = false; + } else if (expr instanceof FunctionExpr + && ((FunctionExpr) expr).function instanceof Function.IsNotDistinctFromFunc) { + final List args = ((FunctionExpr) expr).args; + lhs = args.get(0); + rhs = args.get(1); + includeNull = true; + } else { + return Optional.empty(); + } + + if (isLeftExprAndRightColumn(lhs, rhs, rightPrefix)) { + // rhs is a right-hand column; lhs is an expression solely of the left-hand side. + return Optional.of( + new Equality( + lhs, + Objects.requireNonNull(rhs.getBindingIfIdentifier()).substring(rightPrefix.length()), + includeNull + ) + ); + } else if (isLeftExprAndRightColumn(rhs, lhs, rightPrefix)) { + return Optional.of( + new Equality( + rhs, + Objects.requireNonNull(lhs.getBindingIfIdentifier()).substring(rightPrefix.length()), + includeNull + ) + ); } else { return Optional.empty(); } } + + private static boolean isLeftExprAndRightColumn(final Expr a, final Expr b, final String rightPrefix) + { + return a.analyzeInputs().getRequiredBindings().stream().noneMatch(c -> JoinPrefixUtils.isPrefixedBy(c, rightPrefix)) + && b.getBindingIfIdentifier() != null + && JoinPrefixUtils.isPrefixedBy(b.getBindingIfIdentifier(), rightPrefix); + } } diff --git a/processing/src/main/java/org/apache/druid/math/expr/Function.java b/processing/src/main/java/org/apache/druid/math/expr/Function.java index 406ffac1ea7d..cabeb557792c 100644 --- a/processing/src/main/java/org/apache/druid/math/expr/Function.java +++ 
b/processing/src/main/java/org/apache/druid/math/expr/Function.java @@ -19,6 +19,7 @@ package org.apache.druid.math.expr; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.DateTimes; @@ -2225,6 +2226,108 @@ public ExprVectorProcessor asVectorProcessor(Expr.VectorInputBindingInspe } } + /** + * SQL function "x IS NOT DISTINCT FROM y". Very similar to "x = y", i.e. {@link BinEqExpr}, except this function + * never returns null, and this function considers NULL as a value, so NULL itself is not-distinct-from NULL. For + * example: `x == null` returns `null` in SQL-compatible null handling mode, but `notdistinctfrom(x, null)` is + * true if `x` is null. + */ + class IsNotDistinctFromFunc implements Function + { + @Override + public String name() + { + return "notdistinctfrom"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + final ExprEval leftVal = args.get(0).eval(bindings); + final ExprEval rightVal = args.get(1).eval(bindings); + + if (leftVal.value() == null || rightVal.value() == null) { + return ExprEval.ofLongBoolean(leftVal.value() == null && rightVal.value() == null); + } + + // Code copied and adapted from BinaryBooleanOpExprBase and BinEqExpr. + // The code isn't shared due to differences in code structure: BinaryBooleanOpExprBase + BinEqExpr have logic + // interleaved between parent and child class, but we can't use BinaryBooleanOpExprBase as a parent here, because + // (a) this is a function, not an expr; and (b) our logic for handling and returning nulls is different from most + // binary exprs, where null in means null out. 
+ final ExpressionType comparisonType = ExpressionTypeConversion.autoDetect(leftVal, rightVal); + switch (comparisonType.getType()) { + case STRING: + return ExprEval.ofLongBoolean(Objects.equals(leftVal.asString(), rightVal.asString())); + case LONG: + return ExprEval.ofLongBoolean(leftVal.asLong() == rightVal.asLong()); + case ARRAY: + final ExpressionType type = Preconditions.checkNotNull( + ExpressionTypeConversion.leastRestrictiveType(leftVal.type(), rightVal.type()), + "Cannot be null because ExprEval type is not nullable" + ); + return ExprEval.ofLongBoolean( + type.getNullableStrategy().compare(leftVal.castTo(type).asArray(), rightVal.castTo(type).asArray()) == 0 + ); + case DOUBLE: + default: + if (leftVal.isNumericNull() || rightVal.isNumericNull()) { + return ExprEval.ofLongBoolean(leftVal.isNumericNull() && rightVal.isNumericNull()); + } else { + return ExprEval.ofLongBoolean(leftVal.asDouble() == rightVal.asDouble()); + } + } + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 2); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + } + + /** + * SQL function "x IS DISTINCT FROM y". Very similar to "x <> y", i.e. {@link BinNeqExpr}, except this function + * never returns null. + * + * Implemented as a subclass of IsNotDistinctFromFunc to keep the code simple, and because we expect "notdistinctfrom" + * to be more common than "isdistinctfrom" in actual usage. 
+ */ + class IsDistinctFromFunc extends IsNotDistinctFromFunc + { + @Override + public String name() + { + return "isdistinctfrom"; + } + + @Override + public ExprEval apply(List args, Expr.ObjectBinding bindings) + { + return ExprEval.ofLongBoolean(!super.apply(args, bindings).asBoolean()); + } + + @Override + public void validateArguments(List args) + { + validationHelperCheckArgumentCount(args, 2); + } + + @Nullable + @Override + public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List args) + { + return ExpressionType.LONG; + } + } + /** * SQL function "IS NOT FALSE". Different from "IS TRUE" in that it returns true for NULL as well. */ diff --git a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java index 47c3d78a237b..fcbd6aa49605 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/InDimFilter.java @@ -659,9 +659,9 @@ public ValuesSet() /** * Create a ValuesSet from another Collection. The Collection will be reused if it is a {@link SortedSet} with - * an appropriate comparator. + * the {@link Comparators#naturalNullsFirst()} comparator. */ - public ValuesSet(final Collection values) + private ValuesSet(final Collection values) { if (values instanceof SortedSet && Comparators.naturalNullsFirst() .equals(((SortedSet) values).comparator())) { @@ -672,6 +672,36 @@ public ValuesSet(final Collection values) } } + /** + * Creates an empty ValuesSet. + */ + public static ValuesSet create() + { + return new ValuesSet(new TreeSet<>(Comparators.naturalNullsFirst())); + } + + /** + * Creates a ValuesSet wrapping the provided single value. 
+ * + * @param value the single value to wrap; may be null + */ + public static ValuesSet of(@Nullable final String value) + { + final ValuesSet retVal = ValuesSet.create(); + retVal.add(value); + return retVal; + } + + /** + * Creates a ValuesSet copying the provided collection. + */ + public static ValuesSet copyOf(final Collection values) + { + final TreeSet copyOfValues = new TreeSet<>(Comparators.naturalNullsFirst()); + copyOfValues.addAll(values); + return new ValuesSet(copyOfValues); + } + public SortedSet toUtf8() { final TreeSet valuesUtf8 = new TreeSet<>(ByteBufferUtils.utf8Comparator()); diff --git a/processing/src/main/java/org/apache/druid/segment/join/Equality.java b/processing/src/main/java/org/apache/druid/segment/join/Equality.java index 6b839c1f0dc8..3e1e4ea31492 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/Equality.java +++ b/processing/src/main/java/org/apache/druid/segment/join/Equality.java @@ -32,11 +32,13 @@ public class Equality { private final Expr leftExpr; private final String rightColumn; + private final boolean includeNull; - public Equality(final Expr leftExpr, final String rightColumn) + public Equality(final Expr leftExpr, final String rightColumn, final boolean includeNull) { this.leftExpr = leftExpr; this.rightColumn = rightColumn; + this.includeNull = includeNull; } public Expr getLeftExpr() @@ -49,12 +51,22 @@ public String getRightColumn() return rightColumn; } + /** + * Whether null is treated as a value that can be equal to itself. True for conditions using "IS NOT DISTINCT FROM", + * false for conditions using regular equals. 
+ */ + public boolean isIncludeNull() + { + return includeNull; + } + @Override public String toString() { return "Equality{" + "leftExpr=" + leftExpr + ", rightColumn='" + rightColumn + '\'' + + ", includeNull=" + includeNull + '}'; } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java index 2a33da22d131..77d474720201 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinConditionAnalysis.java @@ -20,7 +20,6 @@ package org.apache.druid.segment.join; import com.google.common.base.Preconditions; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.Exprs; @@ -121,40 +120,18 @@ public static JoinConditionAnalysis forExpression( final List exprs = Exprs.decomposeAnd(conditionExpr); for (Expr childExpr : exprs) { - final Optional> maybeDecomposed = Exprs.decomposeEquals(childExpr); + final Optional maybeEquality = Exprs.decomposeEquals(childExpr, rightPrefix); - if (!maybeDecomposed.isPresent()) { + if (!maybeEquality.isPresent()) { nonEquiConditions.add(childExpr); } else { - final Pair decomposed = maybeDecomposed.get(); - final Expr lhs = Objects.requireNonNull(decomposed.lhs); - final Expr rhs = Objects.requireNonNull(decomposed.rhs); - - if (isLeftExprAndRightColumn(lhs, rhs, rightPrefix)) { - // rhs is a right-hand column; lhs is an expression solely of the left-hand side. 
- equiConditions.add( - new Equality(lhs, Objects.requireNonNull(rhs.getBindingIfIdentifier()).substring(rightPrefix.length())) - ); - } else if (isLeftExprAndRightColumn(rhs, lhs, rightPrefix)) { - equiConditions.add( - new Equality(rhs, Objects.requireNonNull(lhs.getBindingIfIdentifier()).substring(rightPrefix.length())) - ); - } else { - nonEquiConditions.add(childExpr); - } + equiConditions.add(maybeEquality.get()); } } return new JoinConditionAnalysis(condition, rightPrefix, equiConditions, nonEquiConditions); } - private static boolean isLeftExprAndRightColumn(final Expr a, final Expr b, final String rightPrefix) - { - return a.analyzeInputs().getRequiredBindings().stream().noneMatch(c -> JoinPrefixUtils.isPrefixedBy(c, rightPrefix)) - && b.getBindingIfIdentifier() != null - && JoinPrefixUtils.isPrefixedBy(b.getBindingIfIdentifier(), rightPrefix); - } - /** * Return the condition expression. */ diff --git a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java index b4f35fbc4b31..dab102d44932 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/Joinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/Joinable.java @@ -20,6 +20,7 @@ package org.apache.druid.segment.join; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ReferenceCountedObject; import org.apache.druid.segment.column.ColumnCapabilities; @@ -86,26 +87,29 @@ JoinMatcher makeJoinMatcher( ); /** - * Returns all non-null values from a particular column along with a flag to tell if they are all unique in the column. 
- * If the non-null values are greater than "maxNumValues" or if the column doesn't exists or doesn't supports this + * Returns all matchable values from a particular column along with a flag to tell if they are all unique in the column. + * If the number of matchable values is greater than "maxNumValues" or if the column doesn't exist or doesn't support this * operation, returns an object with empty set for column values and false for uniqueness flag. - * The uniqueness flag will only be true if we've collected all non-null values in the column and found that they're + * The uniqueness flag will only be true if we've collected all matchable values in the column and found that they're * all unique. In all other cases it will be false. * - * The returned set may be passed to {@link org.apache.druid.query.filter.InDimFilter}. For efficiency, + * The returned set may be passed to {@link InDimFilter}. For efficiency, * implementations should prefer creating the returned set with * {@code new TreeSet(Comparators.naturalNullsFirst()}}. This avoids a copy in the filter's constructor. * * @param columnName name of the column - * @param maxNumValues maximum number of values to return + * @param includeNull whether null should be considered a matchable value. If true, this method returns all values + * that are present in the column. If false, this method returns all non-null values. + * @param maxNumValues maximum number of values to return. If exceeded, returns an empty set with the "allUnique" + * flag set to false. */ - ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, int maxNumValues); + ColumnValuesWithUniqueFlag getMatchableColumnValues(String columnName, boolean includeNull, int maxNumValues); /** * Searches a column from this Joinable for a particular value, finds rows that match, * and returns values of a second column for those rows. * - * The returned set may be passed to {@link org.apache.druid.query.filter.InDimFilter}. 
For efficiency, + * The returned set may be passed to {@link InDimFilter}. For efficiency, * implementations should prefer creating the returned set with * {@code new TreeSet(Comparators.naturalNullsFirst()}}. This avoids a copy in the filter's constructor. * @@ -121,7 +125,7 @@ JoinMatcher makeJoinMatcher( * * In case either the search or retrieval column names are not found, this will return absent. */ - Optional> getCorrelatedColumnValues( + Optional getCorrelatedColumnValues( String searchColumnName, String searchColumnValue, String retrievalColumnName, diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java index df4d14cf621c..0bebc5aab592 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java +++ b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java @@ -170,7 +170,9 @@ static JoinClauseToFilterConversion convertJoinToFilter( } Joinable.ColumnValuesWithUniqueFlag columnValuesWithUniqueFlag = - clause.getJoinable().getNonNullColumnValues(condition.getRightColumn(), maxNumFilterValues); + clause.getJoinable() + .getMatchableColumnValues(condition.getRightColumn(), condition.isIncludeNull(), maxNumFilterValues); + // For an empty values set, isAllUnique flag will be true only if the column had no non-null values. 
if (columnValuesWithUniqueFlag.getColumnValues().isEmpty()) { if (columnValuesWithUniqueFlag.isAllUnique()) { diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java index ddbaa34d4061..a4c06e79826c 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterAnalyzer.java @@ -452,7 +452,7 @@ private static JoinFilterAnalysis rewriteSelectorFilter( for (JoinFilterColumnCorrelationAnalysis correlationAnalysis : correlationAnalyses) { if (correlationAnalysis.supportsPushDown()) { - Optional> correlatedValues = correlationAnalysis.getCorrelatedValuesMap().get( + Optional correlatedValues = correlationAnalysis.getCorrelatedValuesMap().get( Pair.of(filteringColumn, filteringValue) ); @@ -460,7 +460,7 @@ private static JoinFilterAnalysis rewriteSelectorFilter( return JoinFilterAnalysis.createNoPushdownFilterAnalysis(selectorFilter); } - Set newFilterValues = correlatedValues.get(); + InDimFilter.ValuesSet newFilterValues = correlatedValues.get(); // in nothing => match nothing if (newFilterValues.isEmpty()) { return new JoinFilterAnalysis( diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java index 6071f404e499..8c8ec795d187 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterColumnCorrelationAnalysis.java @@ -21,6 +21,7 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; +import org.apache.druid.query.filter.InDimFilter; import javax.annotation.Nonnull; import java.util.ArrayList; @@ 
-43,7 +44,7 @@ public class JoinFilterColumnCorrelationAnalysis private final String joinColumn; @Nonnull private final List baseColumns; @Nonnull private final List baseExpressions; - private final Map, Optional>> correlatedValuesMap; + private final Map, Optional> correlatedValuesMap; public JoinFilterColumnCorrelationAnalysis( String joinColumn, @@ -75,7 +76,7 @@ public List getBaseExpressions() return baseExpressions; } - public Map, Optional>> getCorrelatedValuesMap() + public Map, Optional> getCorrelatedValuesMap() { return correlatedValuesMap; } diff --git a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java index ed9fe0756251..39c5188984ef 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java +++ b/processing/src/main/java/org/apache/druid/segment/join/filter/JoinFilterCorrelations.java @@ -22,6 +22,7 @@ import org.apache.druid.java.util.common.Pair; import org.apache.druid.math.expr.Expr; import org.apache.druid.query.filter.Filter; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.join.Equality; import org.apache.druid.segment.join.JoinConditionAnalysis; import org.apache.druid.segment.join.JoinableClause; @@ -155,7 +156,7 @@ public static JoinFilterCorrelations computeJoinFilterCorrelations( correlationForPrefix.getValue().getCorrelatedValuesMap().computeIfAbsent( Pair.of(rhsRewriteCandidate.getRhsColumn(), rhsRewriteCandidate.getValueForRewrite()), (rhsVal) -> { - Optional> correlatedValues = getCorrelatedValuesForPushDown( + Optional correlatedValues = getCorrelatedValuesForPushDown( rhsRewriteCandidate.getRhsColumn(), rhsRewriteCandidate.getValueForRewrite(), correlationForPrefix.getValue().getJoinColumn(), @@ -244,7 +245,7 @@ private static List eliminateCorrelationDup * @return A list of values of the correlatedJoinColumn that appear in 
rows where filterColumn = filterValue * Returns absent if we cannot determine the correlated values. */ - private static Optional> getCorrelatedValuesForPushDown( + private static Optional getCorrelatedValuesForPushDown( String filterColumn, String filterValue, String correlatedJoinColumn, diff --git a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java index 813d412735c4..d74817227b07 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/lookup/LookupJoinable.java @@ -24,6 +24,7 @@ import com.google.common.collect.Sets; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.column.ColumnCapabilities; @@ -96,32 +97,38 @@ public JoinMatcher makeJoinMatcher( } @Override - public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, int maxNumValues) + public ColumnValuesWithUniqueFlag getMatchableColumnValues(String columnName, boolean includeNull, int maxNumValues) { if (LookupColumnSelectorFactory.KEY_COLUMN.equals(columnName) && extractor.canGetKeySet()) { final Set keys = extractor.keySet(); - final Set nullEquivalentValues = new HashSet<>(); - nullEquivalentValues.add(null); - if (NullHandling.replaceWithDefault()) { - nullEquivalentValues.add(NullHandling.defaultStringValue()); + final Set nonMatchingValues; + + if (includeNull) { + nonMatchingValues = Collections.emptySet(); + } else { + nonMatchingValues = new HashSet<>(); + nonMatchingValues.add(null); + if (NullHandling.replaceWithDefault()) { + nonMatchingValues.add(NullHandling.defaultStringValue()); + } } // size() of Sets.difference is 
slow; avoid it. - int nonNullKeys = keys.size(); + int matchingKeys = keys.size(); - for (String value : nullEquivalentValues) { + for (String value : nonMatchingValues) { if (keys.contains(value)) { - nonNullKeys--; + matchingKeys--; } } - if (nonNullKeys > maxNumValues) { + if (matchingKeys > maxNumValues) { return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); - } else if (nonNullKeys == keys.size()) { + } else if (matchingKeys == keys.size()) { return new ColumnValuesWithUniqueFlag(keys, true); } else { - return new ColumnValuesWithUniqueFlag(Sets.difference(keys, nullEquivalentValues), true); + return new ColumnValuesWithUniqueFlag(Sets.difference(keys, nonMatchingValues), true); } } else { return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); @@ -129,7 +136,7 @@ public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, int } @Override - public Optional> getCorrelatedColumnValues( + public Optional getCorrelatedColumnValues( String searchColumnName, String searchColumnValue, String retrievalColumnName, @@ -140,13 +147,13 @@ public Optional> getCorrelatedColumnValues( if (!ALL_COLUMNS.contains(searchColumnName) || !ALL_COLUMNS.contains(retrievalColumnName)) { return Optional.empty(); } - Set correlatedValues; + InDimFilter.ValuesSet correlatedValues; if (LookupColumnSelectorFactory.KEY_COLUMN.equals(searchColumnName)) { if (LookupColumnSelectorFactory.KEY_COLUMN.equals(retrievalColumnName)) { - correlatedValues = ImmutableSet.of(searchColumnValue); + correlatedValues = InDimFilter.ValuesSet.of(searchColumnValue); } else { // This should not happen in practice because the column to be joined on must be a key. 
- correlatedValues = Collections.singleton(extractor.apply(searchColumnValue)); + correlatedValues = InDimFilter.ValuesSet.of(extractor.apply(searchColumnValue)); } } else { if (!allowNonKeyColumnSearch) { @@ -154,11 +161,11 @@ public Optional> getCorrelatedColumnValues( } if (LookupColumnSelectorFactory.VALUE_COLUMN.equals(retrievalColumnName)) { // This should not happen in practice because the column to be joined on must be a key. - correlatedValues = ImmutableSet.of(searchColumnValue); + correlatedValues = InDimFilter.ValuesSet.of(searchColumnValue); } else { // Lookup extractor unapply only provides a list of strings, so we can't respect // maxCorrelationSetSize easily. This should be handled eventually. - correlatedValues = ImmutableSet.copyOf(extractor.unapply(searchColumnValue)); + correlatedValues = InDimFilter.ValuesSet.copyOf(extractor.unapply(searchColumnValue)); } } return Optional.of(correlatedValues); diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java index d88d81a87f40..221fa67fc0f7 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTable.java @@ -95,6 +95,7 @@ default ColumnSelectorFactory makeColumnSelectorFactory(ReadableOffset offset, b * see {@link org.apache.druid.segment.join.JoinableFactory#computeJoinCacheKey} * * @return the byte array for cache key + * * @throws {@link IAE} if caching is not supported */ default byte[] computeCacheKey() @@ -125,8 +126,10 @@ interface Index /** * Returns whether keys are unique in this index. If this returns true, then {@link #find(Object)} will only ever * return a zero- or one-element list. 
+ * + * @param includeNull whether null is considered a valid key */ - boolean areKeysUnique(); + boolean areKeysUnique(boolean includeNull); /** * Returns the list of row numbers corresponding to "key" in this index. @@ -134,14 +137,14 @@ interface Index * If "key" is some type other than the natural type {@link #keyType()}, it will be converted before checking * the index. */ - IntSortedSet find(Object key); + IntSortedSet find(@Nullable Object key); /** * Returns the row number corresponding to "key" in this index, or {@link #NOT_FOUND} if the key does not exist * in the index. * - * It is only valid to call this method if {@link #keyType()} is {@link ValueType#LONG} and {@link #areKeysUnique()} - * returns true. + * It is only valid to call this method if {@link #keyType()} is {@link ValueType#LONG} and + * {@link #areKeysUnique(boolean)} returns true. * * @throws UnsupportedOperationException if preconditions are not met */ diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java index a433a0ae5522..f96e4260f8e3 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcher.java @@ -125,7 +125,7 @@ public class IndexedTableJoinMatcher implements JoinMatcher .map(pair -> makeConditionMatcher(pair.lhs, leftSelectorFactory, pair.rhs)) .collect(Collectors.toList()); - this.singleRowMatching = indexes.stream().allMatch(pair -> pair.lhs.areKeysUnique()); + this.singleRowMatching = indexes.stream().allMatch(pair -> pair.lhs.areKeysUnique(pair.rhs.isIncludeNull())); } else { throw new IAE( "Cannot build hash-join matcher on non-equi-join condition: %s", @@ -169,7 +169,7 @@ private static ConditionMatcher makeConditionMatcher( return ColumnProcessors.makeProcessor( condition.getLeftExpr(), 
index.keyType(), - new ConditionMatcherFactory(index), + new ConditionMatcherFactory(index, condition.isIncludeNull()), selectorFactory ); } @@ -374,21 +374,23 @@ static class ConditionMatcherFactory implements ColumnProcessorFactory (int) dimension id -> (IntSortedSet) row numbers @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") // updated via computeIfAbsent private final LruLoadingHashMap dimensionCaches; - ConditionMatcherFactory(IndexedTable.Index index) + ConditionMatcherFactory(IndexedTable.Index index, boolean includeNull) { this.keyType = index.keyType(); this.index = index; + this.includeNull = includeNull; this.dimensionCaches = new LruLoadingHashMap<>( MAX_NUM_CACHE, selector -> { int cardinality = selector.getValueCardinality(); - IntFunction loader = dimensionId -> getRowNumbers(selector, dimensionId); + IntFunction loader = dimensionId -> getRowNumbers(selector.lookupName(dimensionId)); return cardinality <= CACHE_MAX_SIZE ? new Int2IntSortedSetLookupTable(cardinality, loader) : new Int2IntSortedSetLruCache(CACHE_MAX_SIZE, loader); @@ -396,10 +398,13 @@ static class ConditionMatcherFactory implements ColumnProcessorFactory index.find(selector.getFloat()); + } else if (includeNull) { + return () -> selector.isNull() ? index.find(null) : index.find(selector.getFloat()); } else { return () -> selector.isNull() ? IntSortedSets.EMPTY_SET : index.find(selector.getFloat()); } @@ -475,6 +482,8 @@ public ConditionMatcher makeDoubleProcessor(BaseDoubleColumnValueSelector select { if (NullHandling.replaceWithDefault()) { return () -> index.find(selector.getDouble()); + } else if (includeNull) { + return () -> selector.isNull() ? index.find(null) : index.find(selector.getDouble()); } else { return () -> selector.isNull() ? 
IntSortedSets.EMPTY_SET : index.find(selector.getDouble()); } @@ -487,6 +496,8 @@ public ConditionMatcher makeLongProcessor(BaseLongColumnValueSelector selector) return makePrimitiveLongMatcher(selector); } else if (NullHandling.replaceWithDefault()) { return () -> index.find(selector.getLong()); + } else if (includeNull) { + return () -> selector.isNull() ? index.find(null) : index.find(selector.getLong()); } else { return () -> selector.isNull() ? IntSortedSets.EMPTY_SET : index.find(selector.getLong()); } @@ -543,6 +554,27 @@ public IntSortedSet match() return index.find(selector.getLong()); } }; + } else if (includeNull) { + return new ConditionMatcher() + { + @Override + public int matchSingleRow() + { + if (selector.isNull()) { + final IntSortedSet rowNumbers = index.find(null); + // find() returns an empty set (not null) when nothing matches; firstInt() would throw on it. + return (rowNumbers == null || rowNumbers.isEmpty()) ? NO_CONDITION_MATCH : rowNumbers.firstInt(); + } else { + return index.findUniqueLong(selector.getLong()); + } + } + + @Override + public IntSortedSet match() + { + return selector.isNull() ? 
index.find(null) : index.find(selector.getLong()); + } + }; } else { return new ConditionMatcher() { diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java index cf7ced874360..4e9c5b5b3524 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java @@ -23,8 +23,8 @@ import it.unimi.dsi.fastutil.ints.IntBidirectionalIterator; import it.unimi.dsi.fastutil.ints.IntSortedSet; import org.apache.druid.common.config.NullHandling; -import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnCapabilities; @@ -38,8 +38,6 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; -import java.util.TreeSet; public class IndexedTableJoinable implements Joinable { @@ -94,35 +92,34 @@ public JoinMatcher makeJoinMatcher( } @Override - public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, final int maxNumValues) + public ColumnValuesWithUniqueFlag getMatchableColumnValues(String columnName, boolean includeNull, int maxNumValues) { final int columnPosition = table.rowSignature().indexOf(columnName); + final InDimFilter.ValuesSet matchableValues = InDimFilter.ValuesSet.create(); if (columnPosition < 0) { - return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); + return new ColumnValuesWithUniqueFlag(matchableValues /* empty set */, false); } try (final IndexedTable.Reader reader = table.columnReader(columnPosition)) { - // Use a SortedSet so InDimFilter doesn't need to create its own - final Set 
allValues = createValuesSet(); boolean allUnique = true; for (int i = 0; i < table.numRows(); i++) { final String s = DimensionHandlerUtils.convertObjectToString(reader.read(i)); - if (!NullHandling.isNullOrEquivalent(s)) { - if (!allValues.add(s)) { + if (includeNull || !NullHandling.isNullOrEquivalent(s)) { + if (!matchableValues.add(s)) { // Duplicate found allUnique = false; } - if (allValues.size() > maxNumValues) { + if (matchableValues.size() > maxNumValues) { return new ColumnValuesWithUniqueFlag(ImmutableSet.of(), false); } } } - return new ColumnValuesWithUniqueFlag(allValues, allUnique); + return new ColumnValuesWithUniqueFlag(matchableValues, allUnique); } catch (IOException e) { throw new RuntimeException(e); @@ -130,7 +127,7 @@ public ColumnValuesWithUniqueFlag getNonNullColumnValues(String columnName, fina } @Override - public Optional> getCorrelatedColumnValues( + public Optional getCorrelatedColumnValues( String searchColumnName, String searchColumnValue, String retrievalColumnName, @@ -145,7 +142,7 @@ public Optional> getCorrelatedColumnValues( return Optional.empty(); } try (final Closer closer = Closer.create()) { - Set correlatedValues = createValuesSet(); + InDimFilter.ValuesSet correlatedValues = InDimFilter.ValuesSet.create(); if (table.keyColumns().contains(searchColumnName)) { IndexedTable.Index index = table.columnIndex(filterColumnPosition); IndexedTable.Reader reader = table.columnReader(correlatedColumnPosition); @@ -195,12 +192,4 @@ public Optional acquireReferences() { return table.acquireReferences(); } - - /** - * Create a Set that InDimFilter will accept without incurring a copy. 
- */ - private static Set createValuesSet() - { - return new TreeSet<>(Comparators.naturalNullsFirst()); - } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java b/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java index 25464deffb4f..973d9951a0fc 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/MapIndex.java @@ -26,6 +26,7 @@ import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnType; +import javax.annotation.Nullable; import java.util.Map; /** @@ -33,25 +34,53 @@ */ public class MapIndex implements IndexedTable.Index { + /** + * Type of keys in {@link #index}. + */ private final ColumnType keyType; + + /** + * Index of all nonnull keys -> rows with those keys. + */ private final Map index; - private final boolean keysUnique; + + /** + * Rows containing a null key. + */ + @Nullable + private final IntSortedSet nullIndex; + + /** + * Whether nonnull keys are unique, i.e. everything in {@link #index} has exactly 1 element. + */ + private final boolean nonNullKeysUnique; + + /** + * Whether {@link #index} is a {@link Long2ObjectMap}. + */ private final boolean isLong2ObjectMap; /** * Creates a new instance based on a particular map. 
* - * @param keyType type of keys in "index" - * @param index a map of keys to matching row numbers - * @param keysUnique whether the keys are unique (if true: all IntLists in the index must be exactly 1 element) + * @param keyType type of keys in "index" + * @param index a map of keys to matching row numbers + * @param nonNullKeysUnique whether nonnull keys are unique (if true: all IntLists in the index must be exactly 1 + * element, except possibly the one corresponding to null) * * @see RowBasedIndexBuilder#build() the main caller */ - MapIndex(final ColumnType keyType, final Map index, final boolean keysUnique) + MapIndex( + final ColumnType keyType, + final Map index, + final IntSortedSet nullIndex, + final boolean nonNullKeysUnique + ) { this.keyType = Preconditions.checkNotNull(keyType, "keyType"); this.index = Preconditions.checkNotNull(index, "index"); - this.keysUnique = keysUnique; + this.nullIndex = nullIndex; + this.nonNullKeysUnique = nonNullKeysUnique; this.isLong2ObjectMap = index instanceof Long2ObjectMap; } @@ -62,23 +91,35 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(final boolean includeNull) { - return keysUnique; + if (includeNull) { + return nonNullKeysUnique && find(null).size() < 2; + } else { + return nonNullKeysUnique; + } } @Override - public IntSortedSet find(Object key) + public IntSortedSet find(@Nullable Object key) { - final Object convertedKey = DimensionHandlerUtils.convertObjectToType(key, keyType, false); + final IntSortedSet found; + + if (key == null) { + found = nullIndex; + } else { + final Object convertedKey = DimensionHandlerUtils.convertObjectToType(key, keyType, false); - if (convertedKey != null) { - final IntSortedSet found = index.get(convertedKey); - if (found != null) { - return found; + if (convertedKey != null) { + found = index.get(convertedKey); } else { - return IntSortedSets.EMPTY_SET; + // Don't look up null in the index, since this convertedKey 
is null because it's a failed cast, not a true null. + found = null; } + } + + if (found != null) { + return found; } else { return IntSortedSets.EMPTY_SET; } @@ -87,7 +128,7 @@ public IntSortedSet find(Object key) @Override public int findUniqueLong(long key) { - if (isLong2ObjectMap && keysUnique) { + if (isLong2ObjectMap && nonNullKeysUnique) { final IntSortedSet rows = ((Long2ObjectMap) (Map) index).get(key); assert rows == null || rows.size() == 1; return rows != null ? rows.firstInt() : NOT_FOUND; diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java b/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java index dc4b618dadae..5574f607e83d 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/RowBasedIndexBuilder.java @@ -50,9 +50,10 @@ public class RowBasedIndexBuilder private static final long INT_ARRAY_SMALL_SIZE_OK = 250_000; private int currentRow = 0; - private int nullKeys = 0; + private int nonNullKeys = 0; private final ColumnType keyType; private final Map index; + private IntSortedSet nullIndex; private long minLongKey = Long.MAX_VALUE; private long maxLongKey = Long.MIN_VALUE; @@ -79,22 +80,30 @@ public RowBasedIndexBuilder(ColumnType keyType) */ public RowBasedIndexBuilder add(@Nullable final Object key) { - final Object castKey = DimensionHandlerUtils.convertObjectToType(key, keyType); + if (key == null) { + // Use "nullIndex" instead of "index" because "index" may be specialized as Long2ObjectMap, which cannot + // accept null keys. 
+ if (nullIndex == null) { + nullIndex = new IntAVLTreeSet(); + } - if (castKey != null) { - final IntSortedSet rowNums = index.computeIfAbsent(castKey, k -> new IntAVLTreeSet()); - rowNums.add(currentRow); + nullIndex.add(currentRow); + } else { + final Object castKey = DimensionHandlerUtils.convertObjectToType(key, keyType); - // Track min, max long value so we can decide later on if it's appropriate to use an array-backed implementation. - if (keyType.is(ValueType.LONG) && (long) castKey < minLongKey) { - minLongKey = (long) castKey; - } + if (castKey != null) { + index.computeIfAbsent(castKey, k -> new IntAVLTreeSet()).add(currentRow); + nonNullKeys++; - if (keyType.is(ValueType.LONG) && (long) castKey > maxLongKey) { - maxLongKey = (long) castKey; + // Track min, max long value so we can decide later on if it's appropriate to use an array-backed implementation. + if (keyType.is(ValueType.LONG) && (long) castKey < minLongKey) { + minLongKey = (long) castKey; + } + + if (keyType.is(ValueType.LONG) && (long) castKey > maxLongKey) { + maxLongKey = (long) castKey; + } } - } else { - nullKeys++; } currentRow++; @@ -107,9 +116,9 @@ public RowBasedIndexBuilder add(@Nullable final Object key) */ public IndexedTable.Index build() { - final boolean keysUnique = index.size() == currentRow - nullKeys; + final boolean nonNullKeysUnique = index.size() == nonNullKeys; - if (keyType.is(ValueType.LONG) && keysUnique && index.size() > 0) { + if (keyType.is(ValueType.LONG) && nonNullKeysUnique && !index.isEmpty() && nullIndex == null) { // May be a good candidate for UniqueLongArrayIndex. Check the range of values as compared to min and max. 
long range; @@ -155,6 +164,6 @@ public IndexedTable.Index build() } } - return new MapIndex(keyType, index, keysUnique); + return new MapIndex(keyType, index, nullIndex, nonNullKeysUnique); } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java b/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java index 5c5fd959de33..034ff03f9001 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java +++ b/processing/src/main/java/org/apache/druid/segment/join/table/UniqueLongArrayIndex.java @@ -24,13 +24,19 @@ import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.column.ColumnType; +import javax.annotation.Nullable; + /** * An {@link IndexedTable.Index} backed by an int array. * - * This is for long-typed keys whose values all fall in a "reasonable" range. + * This is for nonnull long-typed keys whose values all fall in a "reasonable" range. Built by + * {@link RowBasedIndexBuilder#build()} when these conditions are met. */ public class UniqueLongArrayIndex implements IndexedTable.Index { + /** + * Array index is the key, value is the row number. + */ private final int[] index; private final long minKey; @@ -55,14 +61,19 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(final boolean includeNull) { return true; } @Override - public IntSortedSet find(Object key) + public IntSortedSet find(@Nullable Object key) { + if (key == null) { + // This index class never contains null keys. 
+ return IntSortedSets.EMPTY_SET; + } + final Long longKey = DimensionHandlerUtils.convertObjectToLong(key); if (longKey != null) { diff --git a/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java b/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java index c49959aa1ae3..2f68840955fa 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -99,6 +99,11 @@ public void testDoubleEval() Assert.assertTrue(evalDouble("2.0 == 2.0", bindings) > 0.0); Assert.assertTrue(evalDouble("2.0 != 1.0", bindings) > 0.0); + Assert.assertEquals(1L, evalLong("notdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(1L, evalLong("isdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("notdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("isdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(0L, evalLong("istrue(0.0)", bindings)); Assert.assertEquals(1L, evalLong("isfalse(0.0)", bindings)); Assert.assertEquals(1L, evalLong("nottrue(0.0)", bindings)); @@ -131,6 +136,11 @@ public void testDoubleEval() Assert.assertEquals(1L, evalLong("2.0 == 2.0", bindings)); Assert.assertEquals(1L, evalLong("2.0 != 1.0", bindings)); + Assert.assertEquals(1L, evalLong("notdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(1L, evalLong("isdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("notdistinctfrom(2.0, 1.0)", bindings)); + Assert.assertEquals(0L, evalLong("isdistinctfrom(2.0, 2.0)", bindings)); + Assert.assertEquals(0L, evalLong("istrue(0.0)", bindings)); Assert.assertEquals(1L, evalLong("isfalse(0.0)", bindings)); Assert.assertEquals(1L, evalLong("nottrue(0.0)", bindings)); @@ -186,6 +196,8 @@ public void testLongEval() Assert.assertTrue(evalLong("9223372036854775807 <= 9223372036854775807", bindings) > 0); Assert.assertTrue(evalLong("9223372036854775807 == 9223372036854775807", bindings) > 0); 
Assert.assertTrue(evalLong("9223372036854775807 != 9223372036854775806", bindings) > 0); + Assert.assertTrue(evalLong("notdistinctfrom(9223372036854775807, 9223372036854775807)", bindings) > 0); + Assert.assertTrue(evalLong("isdistinctfrom(9223372036854775807, 9223372036854775806)", bindings) > 0); assertEquals(9223372036854775807L, evalLong("9223372036854775806 + 1", bindings)); assertEquals(9223372036854775806L, evalLong("9223372036854775807 - 1", bindings)); @@ -221,6 +233,92 @@ public void testLongEval() assertEquals("x", eval("nvl(if(x == 9223372036854775806, '', 'x'), 'NULL')", bindings).asString()); } + @Test + public void testIsNotDistinctFrom() + { + assertEquals( + 1L, + new Function.IsNotDistinctFromFunc() + .apply( + ImmutableList.of( + new NullLongExpr(), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 0L, + new Function.IsNotDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 1L, + new Function.IsNotDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new LongExpr(0L) + ), + InputBindings.nilBindings() + ) + .value() + ); + } + + @Test + public void testIsDistinctFrom() + { + assertEquals( + 0L, + new Function.IsDistinctFromFunc() + .apply( + ImmutableList.of( + new NullLongExpr(), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 1L, + new Function.IsDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new NullLongExpr() + ), + InputBindings.nilBindings() + ) + .value() + ); + + assertEquals( + 0L, + new Function.IsDistinctFromFunc() + .apply( + ImmutableList.of( + new LongExpr(0L), + new LongExpr(0L) + ), + InputBindings.nilBindings() + ) + .value() + ); + } + @Test public void testIsFalse() { @@ -1151,6 +1249,8 @@ public void testArrayComparison() Assert.assertEquals(1L, eval("['a','b',null,'c'] >= 
stringArray", bindings).value()); Assert.assertEquals(1L, eval("['a','b',null,'c'] == stringArray", bindings).value()); Assert.assertEquals(0L, eval("['a','b',null,'c'] != stringArray", bindings).value()); + Assert.assertEquals(1L, eval("notdistinctfrom(['a','b',null,'c'], stringArray)", bindings).value()); + Assert.assertEquals(0L, eval("isdistinctfrom(['a','b',null,'c'], stringArray)", bindings).value()); Assert.assertEquals(1L, eval("['a','b',null,'c'] <= stringArray", bindings).value()); Assert.assertEquals(0L, eval("['a','b',null,'c'] < stringArray", bindings).value()); @@ -1158,6 +1258,8 @@ public void testArrayComparison() Assert.assertEquals(1L, eval("[1,null,2,3] >= longArray", bindings).value()); Assert.assertEquals(1L, eval("[1,null,2,3] == longArray", bindings).value()); Assert.assertEquals(0L, eval("[1,null,2,3] != longArray", bindings).value()); + Assert.assertEquals(1L, eval("notdistinctfrom([1,null,2,3], longArray)", bindings).value()); + Assert.assertEquals(0L, eval("isdistinctfrom([1,null,2,3], longArray)", bindings).value()); Assert.assertEquals(1L, eval("[1,null,2,3] <= longArray", bindings).value()); Assert.assertEquals(0L, eval("[1,null,2,3] < longArray", bindings).value()); @@ -1165,6 +1267,8 @@ public void testArrayComparison() Assert.assertEquals(1L, eval("[1.1,2.2,3.3,null] >= doubleArray", bindings).value()); Assert.assertEquals(1L, eval("[1.1,2.2,3.3,null] == doubleArray", bindings).value()); Assert.assertEquals(0L, eval("[1.1,2.2,3.3,null] != doubleArray", bindings).value()); + Assert.assertEquals(1L, eval("notdistinctfrom([1.1,2.2,3.3,null], doubleArray)", bindings).value()); + Assert.assertEquals(0L, eval("isdistinctfrom([1.1,2.2,3.3,null], doubleArray)", bindings).value()); Assert.assertEquals(1L, eval("[1.1,2.2,3.3,null] <= doubleArray", bindings).value()); Assert.assertEquals(0L, eval("[1.1,2.2,3.3,null] < doubleArray", bindings).value()); } diff --git a/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java 
b/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java index a5b6844d5a23..aeb23d257d50 100644 --- a/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java +++ b/processing/src/test/java/org/apache/druid/math/expr/ExprsTest.java @@ -20,8 +20,9 @@ package org.apache.druid.math.expr; import com.google.common.collect.ImmutableList; -import org.apache.druid.java.util.common.Pair; +import org.apache.druid.segment.join.Equality; import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Test; @@ -73,27 +74,52 @@ public void test_decomposeAnd_basic() @Test public void test_decomposeEquals_notAnEquals() { - final Optional> optionalPair = Exprs.decomposeEquals(new IdentifierExpr("foo")); - Assert.assertFalse(optionalPair.isPresent()); + final Optional result = Exprs.decomposeEquals(new IdentifierExpr("foo"), "j."); + Assert.assertFalse(result.isPresent()); } @Test public void test_decomposeEquals_basic() { - final Optional> optionalPair = Exprs.decomposeEquals( + final Optional result = Exprs.decomposeEquals( new BinEqExpr( "==", new IdentifierExpr("foo"), - new IdentifierExpr("bar") - ) + new IdentifierExpr("j.bar") + ), + "j." + ); + + Assert.assertTrue(result.isPresent()); + + final Equality equality = result.get(); + MatcherAssert.assertThat(equality.getLeftExpr(), CoreMatchers.instanceOf(IdentifierExpr.class)); + Assert.assertEquals("foo", ((IdentifierExpr) equality.getLeftExpr()).getIdentifier()); + Assert.assertEquals("bar", equality.getRightColumn()); + Assert.assertFalse(equality.isIncludeNull()); + } + + @Test + public void test_decomposeEquals_notDistinctFrom() + { + final Optional result = Exprs.decomposeEquals( + new FunctionExpr( + new Function.IsNotDistinctFromFunc(), + "notdistinctfrom", + ImmutableList.of( + new IdentifierExpr("foo"), + new IdentifierExpr("j.bar") + ) + ), + "j." 
); - Assert.assertTrue(optionalPair.isPresent()); + Assert.assertTrue(result.isPresent()); - final Pair pair = optionalPair.get(); - Assert.assertThat(pair.lhs, CoreMatchers.instanceOf(IdentifierExpr.class)); - Assert.assertThat(pair.rhs, CoreMatchers.instanceOf(IdentifierExpr.class)); - Assert.assertEquals("foo", ((IdentifierExpr) pair.lhs).getIdentifier()); - Assert.assertEquals("bar", ((IdentifierExpr) pair.rhs).getIdentifier()); + final Equality equality = result.get(); + MatcherAssert.assertThat(equality.getLeftExpr(), CoreMatchers.instanceOf(IdentifierExpr.class)); + Assert.assertEquals("foo", ((IdentifierExpr) equality.getLeftExpr()).getIdentifier()); + Assert.assertEquals("bar", equality.getRightColumn()); + Assert.assertTrue(equality.isIncludeNull()); } } diff --git a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java index 41f2480621da..9508de7bcac6 100644 --- a/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java +++ b/processing/src/test/java/org/apache/druid/query/filter/InDimFilterTest.java @@ -92,7 +92,7 @@ public void testGetValuesWithValuesSetOfNonEmptyStringsUseTheGivenSet() @Test public void testGetValuesWithValuesSetIncludingEmptyString() { - final InDimFilter.ValuesSet values = new InDimFilter.ValuesSet(ImmutableSet.of("v1", "", "v3")); + final InDimFilter.ValuesSet values = InDimFilter.ValuesSet.copyOf(ImmutableSet.of("v1", "", "v3")); final InDimFilter filter = new InDimFilter("dim", values); if (NullHandling.replaceWithDefault()) { Assert.assertSame(values, filter.getValues()); diff --git a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java index d1de10fdaa54..e5ab7d1be8aa 100644 --- 
a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java @@ -183,6 +183,25 @@ protected JoinableClause factToRegion(final JoinType joinType) ); } + protected JoinableClause factToRegionIncludeNull(final JoinType joinType) + { + return new JoinableClause( + FACT_TO_REGION_PREFIX, + new IndexedTableJoinable(regionsTable), + joinType, + JoinConditionAnalysis.forExpression( + StringUtils.format( + "notdistinctfrom(\"%sregionIsoCode\", regionIsoCode) && " + + "notdistinctfrom(\"%scountryIsoCode\", countryIsoCode)", + FACT_TO_REGION_PREFIX, + FACT_TO_REGION_PREFIX + ), + FACT_TO_REGION_PREFIX, + ExprMacroTable.nil() + ) + ); + } + protected JoinableClause regionToCountry(final JoinType joinType) { return new JoinableClause( diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java index 5f7a10b9705e..20d032aba381 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java @@ -1269,6 +1269,69 @@ public void test_makeCursors_factToRegionToCountryLeft() ); } + @Test + public void test_makeCursors_factToRegionToCountryInnerIncludeNull() + { + List joinableClauses = ImmutableList.of( + factToRegionIncludeNull(JoinType.INNER), + regionToCountry(JoinType.LEFT) + ); + JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis( + null, + joinableClauses, + VirtualColumns.EMPTY + ); + JoinTestHelper.verifyCursors( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + joinableClauses, + joinFilterPreAnalysis + ).makeCursors( + null, + Intervals.ETERNITY, + VirtualColumns.EMPTY, + Granularities.ALL, + false, + null + 
), + ImmutableList.of( + "page", + FACT_TO_REGION_PREFIX + "regionName", + REGION_TO_COUNTRY_PREFIX + "countryName" + ), + ImmutableList.of( + new Object[]{"Talk:Oswald Tilghman", "Nulland", null}, + new Object[]{"Rallicula", "Nulland", null}, + new Object[]{"Peremptory norm", "New South Wales", "Australia"}, + new Object[]{"Apamea abruzzorum", "Nulland", null}, + new Object[]{"Atractus flammigerus", "Nulland", null}, + new Object[]{"Agama mossambica", "Nulland", null}, + new Object[]{"Mathis Bolly", "Mexico City", "Mexico"}, + new Object[]{"유희왕 GX", "Seoul", "Republic of Korea"}, + new Object[]{"青野武", "Tōkyō", "Japan"}, + new Object[]{"Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "Chile"}, + new Object[]{"President of India", "California", "United States"}, + new Object[]{"Diskussion:Sebastian Schulz", "Hesse", "Germany"}, + new Object[]{"Saison 9 de Secret Story", "Val d'Oise", "France"}, + new Object[]{"Glasgow", "Kingston upon Hull", "United Kingdom"}, + new Object[]{"Didier Leclair", "Ontario", "Canada"}, + new Object[]{"Les Argonautes", "Quebec", "Canada"}, + new Object[]{"Otjiwarongo Airport", "California", "United States"}, + new Object[]{"Sarah Michelle Gellar", "Ontario", "Canada"}, + new Object[]{"DirecTV", "North Carolina", "United States"}, + new Object[]{"Carlo Curti", "California", "United States"}, + new Object[]{"Giusy Ferreri discography", "Provincia di Varese", "Italy"}, + new Object[]{"Roma-Bangkok", "Provincia di Varese", "Italy"}, + new Object[]{"Wendigo", "Departamento de San Salvador", "El Salvador"}, + new Object[]{"Алиса в Зазеркалье", "Finnmark Fylke", "Norway"}, + new Object[]{"Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "Ecuador"}, + new Object[]{"Old Anatolian Turkish", "Virginia", "United States"}, + new Object[]{"Cream Soda", "Ainigriv", "States United"}, + new Object[]{"History of Fourems", "Fourems Province", "Fourems"} + ) + ); + } + @Test public void test_makeCursors_factToCountryAlwaysTrue() { 
@@ -1850,7 +1913,7 @@ public void test_makeCursors_errorOnNonKeyBasedJoin() { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Cannot build hash-join matcher on non-key-based condition: " - + "Equality{leftExpr=x, rightColumn='countryName'}"); + + "Equality{leftExpr=x, rightColumn='countryName', includeNull=false}"); List joinableClauses = ImmutableList.of( new JoinableClause( FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX, diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java index b6fd9f4f0e0f..1b7f250f8479 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/JoinFilterAnalyzerTest.java @@ -2043,7 +2043,8 @@ public void test_filterPushDown_factToRegionThreeRHSColumnsAllDirectAndFilterOnR // filter rewrites. expectedException.expect(IAE.class); expectedException.expectMessage( - "Cannot build hash-join matcher on non-key-based condition: Equality{leftExpr=user, rightColumn='regionName'}" + "Cannot build hash-join matcher on non-key-based condition: " + + "Equality{leftExpr=user, rightColumn='regionName', includeNull=false}" ); JoinTestHelper.verifyCursors( diff --git a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java index ce1dc7fc8b49..0fa492211f07 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/lookup/LookupJoinableTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.query.filter.InDimFilter; import 
org.apache.druid.query.lookup.LookupExtractor; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ValueType; @@ -36,6 +37,7 @@ import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; @@ -129,7 +131,7 @@ public void getColumnCapabilitiesForUnknownColumnShouldReturnNull() @Test public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmptySet() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( UNKNOWN_COLUMN, SEARCH_KEY_VALUE, @@ -144,7 +146,7 @@ public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmptySet() @Test public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmptySet() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_VALUE, @@ -159,7 +161,7 @@ public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmptySet @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, @@ -172,7 +174,7 @@ public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldRetur @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -185,7 +187,7 @@ public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldRet @Test public void 
getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, SEARCH_KEY_NULL_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -198,7 +200,7 @@ public void getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnSh @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonKeyColumnSearchDisabledShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -219,7 +221,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonK @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.VALUE_COLUMN, @@ -232,7 +234,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldR @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldReturnUnAppliedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, @@ -250,7 +252,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldRet */ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLimitSetShouldHonorMaxLimit() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = 
target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_VALUE, LookupColumnSelectorFactory.KEY_COLUMN, @@ -263,7 +265,7 @@ public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLi @Test public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnShouldReturnNoCorrelatedValues() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, SEARCH_VALUE_UNKNOWN, LookupColumnSelectorFactory.KEY_COLUMN, @@ -274,10 +276,11 @@ public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnSh } @Test - public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues( + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( LookupColumnSelectorFactory.VALUE_COLUMN, + false, Integer.MAX_VALUE ); @@ -285,24 +288,41 @@ public void getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnEmpty() } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() + public void getMatchableColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues( + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, + false, Integer.MAX_VALUE ); Assert.assertEquals( - NullHandling.replaceWithDefault() ? ImmutableSet.of("foo", "bar") : ImmutableSet.of("foo", "bar", ""), + NullHandling.sqlCompatible() ? 
ImmutableSet.of("foo", "bar", "") : ImmutableSet.of("foo", "bar"), values.getColumnValues() ); } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + public void getMatchableColumnValuesWithIncludeNullIfAllUniqueForKeyColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues( + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( LookupColumnSelectorFactory.KEY_COLUMN, + true, + Integer.MAX_VALUE + ); + + Assert.assertEquals( + InDimFilter.ValuesSet.copyOf(Arrays.asList("foo", "bar", "", null)), + values.getColumnValues() + ); + } + + @Test + public void getMatchableColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + { + final Joinable.ColumnValuesWithUniqueFlag values = target.getMatchableColumnValues( + LookupColumnSelectorFactory.KEY_COLUMN, + false, 1 ); diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java index 22ff1f3c5c27..2d929f079bfd 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/BroadcastSegmentIndexedTableTest.java @@ -266,14 +266,10 @@ private void checkIndexAndReader(String columnName, Object[] vals, Object[] nonm // lets try a few values out for (Object val : vals) { final IntSortedSet valIndex = valueIndex.find(val); - if (val == null) { - Assert.assertEquals(0, valIndex.size()); - } else { - Assert.assertTrue(valIndex.size() > 0); - final IntBidirectionalIterator rowIterator = valIndex.iterator(); - while (rowIterator.hasNext()) { - Assert.assertEquals(val, reader.read(rowIterator.nextInt())); - } + Assert.assertTrue(valIndex.size() > 0); + final IntBidirectionalIterator rowIterator = valIndex.iterator(); + 
while (rowIterator.hasNext()) { + Assert.assertEquals(val, reader.read(rowIterator.nextInt())); } } for (Object val : nonmatchingVals) { diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java index a0fbd4fcc1d5..ed59f5f80652 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/FrameBasedIndexedTableTest.java @@ -325,14 +325,10 @@ private void checkIndexAndReader(String columnName, Object[] vals, Object[] nonm for (Object val : vals) { final IntSortedSet valIndex = valueIndex.find(val); - if (val == null) { - Assert.assertEquals(0, valIndex.size()); - } else { - Assert.assertTrue(valIndex.size() > 0); - final IntBidirectionalIterator rowIterator = valIndex.iterator(); - while (rowIterator.hasNext()) { - Assert.assertEquals(val, reader.read(rowIterator.nextInt())); - } + Assert.assertTrue(valIndex.size() > 0); + final IntBidirectionalIterator rowIterator = valIndex.iterator(); + while (rowIterator.hasNext()) { + Assert.assertEquals(val, reader.read(rowIterator.nextInt())); } } for (Object val : nonmatchingVals) { diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java index 57b3896648bf..b1a47355f541 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinMatcherTest.java @@ -88,7 +88,7 @@ public void tearDown() throws Exception public void testMatchToUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new 
IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(ImmutableList.of(2), ImmutableList.copyOf(processor.match())); @@ -98,7 +98,7 @@ public void testMatchToUniqueLongIndex() public void testMatchSingleRowToUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(2, processor.matchSingleRow()); @@ -108,7 +108,7 @@ public void testMatchSingleRowToUniqueLongIndex() public void testMatchToNonUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(ImmutableList.of(1, 2, 3), ImmutableList.copyOf(processor.match())); @@ -118,7 +118,7 @@ public void testMatchToNonUniqueLongIndex() public void testMatchSingleRowToNonUniqueLongIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longAlwaysOneTwoThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertThrows(UnsupportedOperationException.class, processor::matchSingleRow); @@ -128,7 +128,7 @@ public void testMatchSingleRowToNonUniqueLongIndex() public void 
testMatchToUniqueStringIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(ImmutableList.of(3), ImmutableList.copyOf(processor.match())); @@ -138,7 +138,7 @@ public void testMatchToUniqueStringIndex() public void testMatchSingleRowToUniqueStringIndex() { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(certainStringToThreeIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeLongProcessor(selector); Assert.assertEquals(3, processor.matchSingleRow()); @@ -170,7 +170,7 @@ public void tearDown() throws Exception public void testMatch() { final IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeComplexProcessor(selector); @@ -182,7 +182,7 @@ public void testMatch() public void testMatchSingleRow() { final IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(longPlusOneIndex(), false); final IndexedTableJoinMatcher.ConditionMatcher processor = conditionMatcherFactory.makeComplexProcessor(selector); @@ -212,7 +212,7 @@ public void testMatchMultiValuedRowCardinalityUnknownShouldThrowException() thro 
.getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -233,7 +233,7 @@ public void testMatchMultiValuedRowCardinalityKnownShouldThrowException() throws Mockito.doReturn(3).when(dimensionSelector).getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -256,7 +256,7 @@ public void testMatchEmptyRowCardinalityUnknown() throws Exception .getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -278,7 +278,7 @@ public void testMatchEmptyRowCardinalityKnown() throws Exception Mockito.doReturn(0).when(dimensionSelector).getValueCardinality(); IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); IndexedTableJoinMatcher.ConditionMatcher dimensionProcessor = conditionMatcherFactory.makeDimensionProcessor( dimensionSelector, false @@ -324,7 +324,7 @@ public void 
getsCorrectResultWhenSelectorCardinalityHigh() private static IndexedTableJoinMatcher.ConditionMatcher makeConditionMatcher(int valueCardinality) { IndexedTableJoinMatcher.ConditionMatcherFactory conditionMatcherFactory = - new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex()); + new IndexedTableJoinMatcher.ConditionMatcherFactory(stringToLengthIndex(), false); return conditionMatcherFactory.makeDimensionProcessor( new TestDimensionSelector(KEY, valueCardinality), false @@ -503,7 +503,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return false; } @@ -533,7 +533,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return true; } @@ -567,7 +567,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return true; } @@ -603,7 +603,7 @@ public ColumnType keyType() } @Override - public boolean areKeysUnique() + public boolean areKeysUnique(boolean includeNull) { return false; } diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java index 93f2d5df13eb..09ddec72b6bc 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/IndexedTableJoinableTest.java @@ -28,6 +28,7 @@ import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnValueSelector; import org.apache.druid.segment.ConstantDimensionSelector; @@ -43,9 +44,9 @@ 
import org.junit.Before; import org.junit.Test; +import java.util.Arrays; import java.util.Collections; import java.util.Optional; -import java.util.Set; public class IndexedTableJoinableTest { @@ -89,7 +90,8 @@ public ColumnCapabilities getColumnCapabilities(String columnName) ImmutableList.of( new Object[]{"foo", 1L, 1L}, new Object[]{"bar", 2L, 1L}, - new Object[]{"baz", null, 1L} + new Object[]{"baz", null, 1L}, + new Object[]{null, 3L, 1L} ), RowSignature.builder() .add(KEY_COLUMN, ColumnType.STRING) @@ -187,7 +189,7 @@ public void makeJoinMatcherWithDimensionSelectorOnString() .makeDimensionSelector(DefaultDimensionSpec.of("str")); // getValueCardinality - Assert.assertEquals(4, selector.getValueCardinality()); + Assert.assertEquals(5, selector.getValueCardinality()); // nameLookupPossibleInAdvance Assert.assertTrue(selector.nameLookupPossibleInAdvance()); @@ -197,6 +199,7 @@ public void makeJoinMatcherWithDimensionSelectorOnString() Assert.assertEquals("bar", selector.lookupName(1)); Assert.assertEquals("baz", selector.lookupName(2)); Assert.assertNull(selector.lookupName(3)); + Assert.assertNull(selector.lookupName(4)); // lookupId Assert.assertNull(selector.idLookup()); @@ -205,13 +208,14 @@ public void makeJoinMatcherWithDimensionSelectorOnString() @Test public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmpty() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( UNKNOWN_COLUMN, "foo", VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @@ -219,13 +223,14 @@ public void getCorrelatedColummnValuesMissingSearchColumnShouldReturnEmpty() @Test public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmpty() { - Optional> correlatedValues = + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, "foo", UNKNOWN_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); 
Assert.assertEquals(Optional.empty(), correlatedValues); } @@ -233,149 +238,179 @@ public void getCorrelatedColummnValuesMissingRetrievalColumnShouldReturnEmpty() @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_VALUE, KEY_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveKeyColumnAboveLimitShouldReturnEmpty() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_VALUE, KEY_COLUMN, 0, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchKeyAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchKeyMissingAndRetrieveValueColumnShouldReturnExtractedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( KEY_COLUMN, SEARCH_KEY_NULL_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.of(Collections.singleton(null)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnAndNonKeyColumnSearchDisabledShouldReturnEmpty() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional 
correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, KEY_COLUMN, 10, - false); + false + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveValueColumnShouldReturnSearchValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, VALUE_COLUMN, MAX_CORRELATION_SET_SIZE, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_VALUE_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnShouldReturnUnAppliedValue() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, KEY_COLUMN, 10, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of(SEARCH_KEY_VALUE)), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchValueAndRetrieveKeyColumnWithMaxLimitSetShouldHonorMaxLimit() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_VALUE, KEY_COLUMN, 0, - true); + true + ); Assert.assertEquals(Optional.empty(), correlatedValues); } @Test public void getCorrelatedColumnValuesForSearchUnknownValueAndRetrieveKeyColumnShouldReturnNoCorrelatedValues() { - Optional> correlatedValues = target.getCorrelatedColumnValues( + Optional correlatedValues = target.getCorrelatedColumnValues( VALUE_COLUMN, SEARCH_VALUE_UNKNOWN, KEY_COLUMN, 10, - true); + true + ); Assert.assertEquals(Optional.of(ImmutableSet.of()), correlatedValues); } @Test - public void 
getNonNullColumnValuesIfAllUniqueForValueColumnShouldReturnValues() + public void getMatchableColumnValuesIfAllUniqueForValueColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(VALUE_COLUMN, Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(VALUE_COLUMN, false, Integer.MAX_VALUE); - Assert.assertEquals(ImmutableSet.of("1", "2"), values.getColumnValues()); + Assert.assertEquals(ImmutableSet.of("1", "2", "3"), values.getColumnValues()); } @Test - public void getNonNullColumnValuesIfAllUniqueForNonexistentColumnShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForNonexistentColumnShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues("nonexistent", Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues("nonexistent", false, Integer.MAX_VALUE); Assert.assertEquals(ImmutableSet.of(), values.getColumnValues()); } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() + public void getMatchableColumnValuesIfAllUniqueForKeyColumnShouldReturnValues() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(KEY_COLUMN, Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(KEY_COLUMN, false, Integer.MAX_VALUE); Assert.assertEquals( ImmutableSet.of("foo", "bar", "baz"), values.getColumnValues() ); + + Assert.assertTrue(values.isAllUnique()); + } + + @Test + public void getMatchableColumnValuesWithIncludeNullIfAllUniqueForKeyColumnShouldReturnValues() + { + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(KEY_COLUMN, true, Integer.MAX_VALUE); + + Assert.assertEquals( + InDimFilter.ValuesSet.copyOf(Arrays.asList(null, "foo", "bar", "baz")), + values.getColumnValues() + ); + + 
Assert.assertTrue(values.isAllUnique()); } @Test - public void getNonNullColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(ALL_SAME_COLUMN, Integer.MAX_VALUE); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(ALL_SAME_COLUMN, false, Integer.MAX_VALUE); Assert.assertEquals( ImmutableSet.of("1"), @@ -385,9 +420,10 @@ public void getNonNullColumnValuesIfAllUniqueForAllSameColumnShouldReturnEmpty() } @Test - public void getNonNullColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() + public void getMatchableColumnValuesIfAllUniqueForKeyColumnWithLowMaxValuesShouldReturnEmpty() { - final Joinable.ColumnValuesWithUniqueFlag values = target.getNonNullColumnValues(KEY_COLUMN, 1); + final Joinable.ColumnValuesWithUniqueFlag values = + target.getMatchableColumnValues(KEY_COLUMN, false, 1); Assert.assertEquals(ImmutableSet.of(), values.getColumnValues()); } diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java index d399b971dafa..d6cd74f55451 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexBuilderTest.java @@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.ints.IntSortedSet; import org.apache.druid.segment.column.ColumnType; import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -35,6 +36,35 @@ public class RowBasedIndexBuilderTest @Test public void test_stringKey_uniqueKeys() + { + final RowBasedIndexBuilder builder = + new RowBasedIndexBuilder(ColumnType.STRING) + .add("abc") + .add("") 
+ .add("1") + .add("def"); + + final IndexedTable.Index index = builder.build(); + + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + Assert.assertEquals(ColumnType.STRING, index.keyType()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertTrue(index.areKeysUnique(true)); + + Assert.assertEquals(intSet(0), index.find("abc")); + Assert.assertEquals(intSet(1), index.find("")); + Assert.assertEquals(intSet(2), index.find(1L)); + Assert.assertEquals(intSet(2), index.find("1")); + Assert.assertEquals(intSet(3), index.find("def")); + Assert.assertEquals(intSet(), index.find(null)); + Assert.assertEquals(intSet(), index.find("nonexistent")); + + expectedException.expect(UnsupportedOperationException.class); + index.findUniqueLong(0L); + } + + @Test + public void test_stringKey_uniqueKeysWithNull() { final RowBasedIndexBuilder builder = new RowBasedIndexBuilder(ColumnType.STRING) @@ -46,16 +76,48 @@ public void test_stringKey_uniqueKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.STRING, index.keyType()); - Assert.assertTrue(index.areKeysUnique()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertTrue(index.areKeysUnique(true)); Assert.assertEquals(intSet(0), index.find("abc")); Assert.assertEquals(intSet(1), index.find("")); Assert.assertEquals(intSet(3), index.find(1L)); Assert.assertEquals(intSet(3), index.find("1")); Assert.assertEquals(intSet(4), index.find("def")); - Assert.assertEquals(intSet(), index.find(null)); + Assert.assertEquals(intSet(2), index.find(null)); + Assert.assertEquals(intSet(), index.find("nonexistent")); + + expectedException.expect(UnsupportedOperationException.class); + index.findUniqueLong(0L); + } + + @Test + public void test_stringKey_duplicateNullKey() + { + final RowBasedIndexBuilder builder = + new 
RowBasedIndexBuilder(ColumnType.STRING) + .add("abc") + .add("") + .add(null) + .add("1") + .add(null) + .add("def"); + + final IndexedTable.Index index = builder.build(); + + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + Assert.assertEquals(ColumnType.STRING, index.keyType()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertFalse(index.areKeysUnique(true)); + + Assert.assertEquals(intSet(0), index.find("abc")); + Assert.assertEquals(intSet(1), index.find("")); + Assert.assertEquals(intSet(3), index.find(1L)); + Assert.assertEquals(intSet(3), index.find("1")); + Assert.assertEquals(intSet(5), index.find("def")); + Assert.assertEquals(intSet(2, 4), index.find(null)); Assert.assertEquals(intSet(), index.find("nonexistent")); expectedException.expect(UnsupportedOperationException.class); @@ -76,16 +138,17 @@ public void test_stringKey_duplicateKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.STRING, index.keyType()); - Assert.assertFalse(index.areKeysUnique()); + Assert.assertFalse(index.areKeysUnique(false)); + Assert.assertFalse(index.areKeysUnique(true)); Assert.assertEquals(intSet(0, 3), index.find("abc")); Assert.assertEquals(intSet(1), index.find("")); Assert.assertEquals(intSet(4), index.find(1L)); Assert.assertEquals(intSet(4), index.find("1")); Assert.assertEquals(intSet(5), index.find("def")); - Assert.assertEquals(intSet(), index.find(null)); + Assert.assertEquals(intSet(2), index.find(null)); Assert.assertEquals(intSet(), index.find("nonexistent")); expectedException.expect(UnsupportedOperationException.class); @@ -103,14 +166,44 @@ public void test_longKey_uniqueKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(UniqueLongArrayIndex.class)); + 
MatcherAssert.assertThat(index, CoreMatchers.instanceOf(UniqueLongArrayIndex.class)); Assert.assertEquals(ColumnType.LONG, index.keyType()); - Assert.assertTrue(index.areKeysUnique()); + Assert.assertTrue(index.areKeysUnique(false)); Assert.assertEquals(intSet(0), index.find(1L)); Assert.assertEquals(intSet(1), index.find(5L)); Assert.assertEquals(intSet(2), index.find(2L)); Assert.assertEquals(intSet(), index.find(3L)); + Assert.assertEquals(intSet(), index.find(null)); + + Assert.assertEquals(0, index.findUniqueLong(1L)); + Assert.assertEquals(1, index.findUniqueLong(5L)); + Assert.assertEquals(2, index.findUniqueLong(2L)); + Assert.assertEquals(IndexedTable.Index.NOT_FOUND, index.findUniqueLong(3L)); + } + + @Test + public void test_longKey_uniqueKeysWithNull() + { + final RowBasedIndexBuilder builder = + new RowBasedIndexBuilder(ColumnType.LONG) + .add(1) + .add(5) + .add(2) + .add(null); + + final IndexedTable.Index index = builder.build(); + + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + Assert.assertEquals(ColumnType.LONG, index.keyType()); + Assert.assertTrue(index.areKeysUnique(false)); + Assert.assertTrue(index.areKeysUnique(true)); + + Assert.assertEquals(intSet(0), index.find(1L)); + Assert.assertEquals(intSet(1), index.find(5L)); + Assert.assertEquals(intSet(2), index.find(2L)); + Assert.assertEquals(intSet(), index.find(3L)); + Assert.assertEquals(intSet(3), index.find(null)); Assert.assertEquals(0, index.findUniqueLong(1L)); Assert.assertEquals(1, index.findUniqueLong(5L)); @@ -129,14 +222,15 @@ public void test_longKey_uniqueKeys_farApart() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.LONG, index.keyType()); - Assert.assertTrue(index.areKeysUnique()); + Assert.assertTrue(index.areKeysUnique(false)); Assert.assertEquals(intSet(0), 
index.find(1L)); Assert.assertEquals(intSet(1), index.find(10_000_000L)); Assert.assertEquals(intSet(2), index.find(2L)); Assert.assertEquals(intSet(), index.find(3L)); + Assert.assertEquals(intSet(), index.find(null)); Assert.assertEquals(0, index.findUniqueLong(1L)); Assert.assertEquals(1, index.findUniqueLong(10_000_000L)); @@ -156,9 +250,9 @@ public void test_longKey_duplicateKeys() final IndexedTable.Index index = builder.build(); - Assert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); + MatcherAssert.assertThat(index, CoreMatchers.instanceOf(MapIndex.class)); Assert.assertEquals(ColumnType.LONG, index.keyType()); - Assert.assertFalse(index.areKeysUnique()); + Assert.assertFalse(index.areKeysUnique(false)); Assert.assertEquals(intSet(0, 2), index.find("1")); Assert.assertEquals(intSet(0, 2), index.find(1)); diff --git a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java index 789bac28ea9a..aef371bcf556 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/table/RowBasedIndexedTableTest.java @@ -131,7 +131,7 @@ public void test_columnIndex_regionsRegionIsoCode() { final IndexedTable.Index index = regionsTable.columnIndex(INDEX_REGIONS_REGION_ISO_CODE); - Assert.assertEquals(ImmutableSet.of(), index.find(null)); + Assert.assertEquals(ImmutableSet.of(21), index.find(null)); Assert.assertEquals(ImmutableSet.of(0), index.find("11")); Assert.assertEquals(ImmutableSet.of(1), index.find(13)); Assert.assertEquals(ImmutableSet.of(12), index.find("QC")); diff --git a/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java b/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java index fffb9068a3ef..a691f4470a98 100644 --- 
a/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java +++ b/processing/src/test/java/org/apache/druid/segment/nested/NestedFieldColumnIndexSupplierTest.java @@ -442,7 +442,7 @@ public void testSingleTypeStringColumnPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("b", "z")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("b", "z")) ); // 10 rows @@ -607,7 +607,7 @@ public void testSingleValueStringWithNullPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("b", "z")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("b", "z")) ); // 10 rows @@ -728,7 +728,7 @@ public void testSingleTypeLongColumnPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("1", "3")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("1", "3")) ); // 10 rows @@ -880,7 +880,7 @@ public void testSingleValueLongWithNullPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("3", "100")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("3", "100")) ); // 10 rows @@ -1025,7 +1025,7 @@ public void testSingleTypeDoubleColumnPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("1.2", "3.3", "5.0")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("1.2", "3.3", "5.0")) ); // 10 rows @@ 
-1162,7 +1162,7 @@ public void testSingleValueDoubleWithNullPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("1.2", "3.3")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("1.2", "3.3")) ); // 10 rows @@ -1277,7 +1277,7 @@ public void testVariantPredicateIndex() throws IOException Assert.assertNotNull(predicateIndex); DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("b", "z", "9.9", "300")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("b", "z", "9.9", "300")) ); // 10 rows @@ -1485,7 +1485,7 @@ public double skipValuePredicateIndexScale() // circuit early and return nothing DruidPredicateFactory predicateFactory = new InDimFilter.InFilterDruidPredicateFactory( null, - new InDimFilter.ValuesSet(ImmutableSet.of("0")) + InDimFilter.ValuesSet.copyOf(ImmutableSet.of("0")) ); Assert.assertNull(singleTypeStringSupplier.as(DruidPredicateIndexes.class).forPredicate(predicateFactory)); Assert.assertNull(singleTypeLongSupplier.as(DruidPredicateIndexes.class).forPredicate(predicateFactory)); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java index ef572c8b6219..438c666227e6 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/Expressions.java @@ -552,6 +552,8 @@ private static DimFilter toSimpleLeafFilter( return kind == SqlKind.IS_NOT_NULL ? 
new NotDimFilter(equalFilter) : equalFilter; } else if (kind == SqlKind.EQUALS || kind == SqlKind.NOT_EQUALS + || kind == SqlKind.IS_NOT_DISTINCT_FROM + || kind == SqlKind.IS_DISTINCT_FROM || kind == SqlKind.GREATER_THAN || kind == SqlKind.GREATER_THAN_OR_EQUAL || kind == SqlKind.LESS_THAN @@ -577,6 +579,8 @@ private static DimFilter toSimpleLeafFilter( switch (kind) { case EQUALS: case NOT_EQUALS: + case IS_NOT_DISTINCT_FROM: + case IS_DISTINCT_FROM: flippedKind = kind; break; case GREATER_THAN: @@ -688,9 +692,13 @@ private static DimFilter toSimpleLeafFilter( // Always use BoundDimFilters, to simplify filter optimization later (it helps to remember the comparator). switch (flippedKind) { case EQUALS: + case IS_NOT_DISTINCT_FROM: + // OK to treat EQUALS, IS_NOT_DISTINCT_FROM the same since we know stringVal is nonnull. filter = Bounds.equalTo(boundRefKey, stringVal); break; case NOT_EQUALS: + case IS_DISTINCT_FROM: + // OK to treat NOT_EQUALS, IS_DISTINCT_FROM the same since we know stringVal is nonnull. 
filter = new NotDimFilter(Bounds.equalTo(boundRefKey, stringVal)); break; case GREATER_THAN: @@ -724,9 +732,11 @@ private static DimFilter toSimpleLeafFilter( // Always use RangeFilter, to simplify filter optimization later switch (flippedKind) { case EQUALS: + case IS_NOT_DISTINCT_FROM: filter = Ranges.equalTo(rangeRefKey, val); break; case NOT_EQUALS: + case IS_DISTINCT_FROM: filter = new NotDimFilter(Ranges.equalTo(rangeRefKey, val)); break; case GREATER_THAN: diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java index 8e09ea0c7340..24fac69d11d2 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOverlapOperatorConversion.java @@ -41,9 +41,7 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; import javax.annotation.Nullable; -import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; public class ArrayOverlapOperatorConversion extends BaseExpressionDimFilterOperatorConversion { @@ -138,9 +136,14 @@ public DimFilter toDruidFilter( ); } } else { + final InDimFilter.ValuesSet valuesSet = InDimFilter.ValuesSet.create(); + for (final Object arrayElement : arrayElements) { + valuesSet.add(Evals.asString(arrayElement)); + } + return new InDimFilter( simpleExtractionExpr.getSimpleExtraction().getColumn(), - new InDimFilter.ValuesSet(Arrays.stream(arrayElements).map(Evals::asString).collect(Collectors.toList())), + valuesSet, simpleExtractionExpr.getSimpleExtraction().getExtractionFn(), null ); diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java index 16748b0b6ab5..e392ad8a47bf 100644 --- 
a/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/DruidOperatorTable.java @@ -371,6 +371,8 @@ public class DruidOperatorTable implements SqlOperatorTable .add(new UnaryPrefixOperatorConversion(SqlStdOperatorTable.UNARY_MINUS, "-")) .add(new UnaryFunctionOperatorConversion(SqlStdOperatorTable.IS_NULL, "isnull")) .add(new UnaryFunctionOperatorConversion(SqlStdOperatorTable.IS_NOT_NULL, "notnull")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_DISTINCT_FROM, "isdistinctfrom")) + .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, "notdistinctfrom")) .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_FALSE, "isfalse")) .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_TRUE, "istrue")) .add(new DirectOperatorConversion(SqlStdOperatorTable.IS_NOT_FALSE, "notfalse")) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java index 94c13f6a94c0..6dc8ff00531b 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java @@ -38,11 +38,12 @@ import org.apache.calcite.rex.RexSlot; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.druid.java.util.common.Pair; +import org.apache.druid.error.DruidException; import org.apache.druid.query.LookupDataSource; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel; @@ -53,7 +54,6 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; 
-import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.Stack; @@ -242,7 +242,7 @@ private Optional analyzeCondition( ) { final List subConditions = decomposeAnd(condition); - final List> equalitySubConditions = new ArrayList<>(); + final List equalitySubConditions = new ArrayList<>(); final List literalSubConditions = new ArrayList<>(); final int numLeftFields = leftRowType.getFieldCount(); final Set rightColumns = new HashSet<>(); @@ -271,10 +271,12 @@ private Optional analyzeCondition( RexNode firstOperand; RexNode secondOperand; + SqlKind comparisonKind; if (subCondition.isA(SqlKind.INPUT_REF)) { firstOperand = rexBuilder.makeLiteral(true); secondOperand = subCondition; + comparisonKind = SqlKind.EQUALS; if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) { plannerContext.setPlanningError( @@ -285,11 +287,12 @@ private Optional analyzeCondition( return Optional.empty(); } - } else if (subCondition.isA(SqlKind.EQUALS)) { + } else if (subCondition.isA(SqlKind.EQUALS) || subCondition.isA(SqlKind.IS_NOT_DISTINCT_FROM)) { final List operands = ((RexCall) subCondition).getOperands(); Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size()); firstOperand = operands.get(0); secondOperand = operands.get(1); + comparisonKind = subCondition.getKind(); } else { // If it's not EQUALS or a BOOLEAN input ref, it's not supported. 
plannerContext.setPlanningError( @@ -300,11 +303,11 @@ private Optional analyzeCondition( } if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) { - equalitySubConditions.add(Pair.of(firstOperand, (RexInputRef) secondOperand)); + equalitySubConditions.add(new RexEquality(firstOperand, (RexInputRef) secondOperand, comparisonKind)); rightColumns.add((RexInputRef) secondOperand); } else if (isRightInputRef(firstOperand, numLeftFields) && isLeftExpression(secondOperand, numLeftFields)) { - equalitySubConditions.add(Pair.of(secondOperand, (RexInputRef) firstOperand)); + equalitySubConditions.add(new RexEquality(secondOperand, (RexInputRef) firstOperand, subCondition.getKind())); rightColumns.add((RexInputRef) firstOperand); } else { // Cannot handle this condition. @@ -336,7 +339,8 @@ && isLeftExpression(secondOperand, numLeftFields)) { numLeftFields, equalitySubConditions, literalSubConditions - )); + ) + ); } @VisibleForTesting @@ -375,7 +379,6 @@ private static boolean isRightInputRef(final RexNode rexNode, final int numLeftF return rexNode.isA(SqlKind.INPUT_REF) && ((RexInputRef) rexNode).getIndex() >= numLeftFields; } - @VisibleForTesting static class ConditionAnalysis { /** @@ -387,17 +390,16 @@ static class ConditionAnalysis /** * Each equality subcondition is an equality of the form f(LeftRel) = g(RightRel). */ - private final List> equalitySubConditions; + private final List equalitySubConditions; /** * Each literal subcondition is... a literal. 
*/ private final List literalSubConditions; - ConditionAnalysis( int numLeftFields, - List> equalitySubConditions, + List equalitySubConditions, List literalSubConditions ) { @@ -417,9 +419,10 @@ public ConditionAnalysis pushThroughLeftProject(final Project leftProject) equalitySubConditions .stream() .map( - equality -> Pair.of( - RelOptUtil.pushPastProject(equality.lhs, leftProject), - (RexInputRef) RexUtil.shift(equality.rhs, rhsShift) + equality -> new RexEquality( + RelOptUtil.pushPastProject(equality.left, leftProject), + (RexInputRef) RexUtil.shift(equality.right, rhsShift), + equality.kind ) ) .collect(Collectors.toList()), @@ -436,15 +439,16 @@ public ConditionAnalysis pushThroughRightProject(final Project rightProject) equalitySubConditions .stream() .map( - equality -> Pair.of( - equality.lhs, + equality -> new RexEquality( + equality.left, (RexInputRef) RexUtil.shift( RelOptUtil.pushPastProject( - RexUtil.shift(equality.rhs, -numLeftFields), + RexUtil.shift(equality.right, -numLeftFields), rightProject ), numLeftFields - ) + ), + equality.kind ) ) .collect(Collectors.toList()), @@ -454,8 +458,8 @@ public ConditionAnalysis pushThroughRightProject(final Project rightProject) public boolean onlyUsesMappingsFromRightProject(final Project rightProject) { - for (Pair equality : equalitySubConditions) { - final int rightIndex = equality.rhs.getIndex() - numLeftFields; + for (final RexEquality equality : equalitySubConditions) { + final int rightIndex = equality.right.getIndex() - numLeftFields; if (!rightProject.getProjects().get(rightIndex).isA(SqlKind.INPUT_REF)) { return false; @@ -473,7 +477,7 @@ public RexNode getCondition(final RexBuilder rexBuilder) literalSubConditions, equalitySubConditions .stream() - .map(equality -> rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, equality.lhs, equality.rhs)) + .map(equality -> equality.makeCall(rexBuilder)) .collect(Collectors.toList()) ), false @@ -481,31 +485,55 @@ public RexNode getCondition(final RexBuilder 
rexBuilder) } @Override - public boolean equals(Object o) + public String toString() { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ConditionAnalysis that = (ConditionAnalysis) o; - return Objects.equals(equalitySubConditions, that.equalitySubConditions) && - Objects.equals(literalSubConditions, that.literalSubConditions); + return "ConditionAnalysis{" + + "numLeftFields=" + numLeftFields + + ", equalitySubConditions=" + equalitySubConditions + + ", literalSubConditions=" + literalSubConditions + + '}'; } + } - @Override - public int hashCode() + /** + * Like {@link org.apache.druid.segment.join.Equality} but uses {@link RexNode} instead of + * {@link org.apache.druid.math.expr.Expr}. + */ + static class RexEquality + { + private final RexNode left; + private final RexInputRef right; + private final SqlKind kind; + + public RexEquality(RexNode left, RexInputRef right, SqlKind kind) + { + this.left = left; + this.right = right; + this.kind = kind; + } + + public RexNode makeCall(final RexBuilder builder) { - return Objects.hash(equalitySubConditions, literalSubConditions); + final SqlOperator operator; + + if (kind == SqlKind.EQUALS) { + operator = SqlStdOperatorTable.EQUALS; + } else if (kind == SqlKind.IS_NOT_DISTINCT_FROM) { + operator = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM; + } else { + throw DruidException.defensive("Unexpected operator kind[%s]", kind); + } + + return builder.makeCall(operator, left, right); } @Override public String toString() { - return "ConditionAnalysis{" + - "equalitySubConditions=" + equalitySubConditions + - ", literalSubConditions=" + literalSubConditions + + return "RexEquality{" + + "left=" + left + + ", right=" + right + + ", kind=" + kind + '}'; } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java index d0d7935a334d..a0e96a876379 100644 --- 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java @@ -3619,6 +3619,105 @@ public void testLeftJoinWithNotNullFilter(Map queryContext) ); } + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testInnerJoin(Map queryContext) + { + testQuery( + "SELECT s.dim1, t.dim1\n" + + "FROM foo as s\n" + + "INNER JOIN foo as t " + + "ON s.dim1 = t.dim1", + queryContext, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns(ImmutableList.of("dim1")) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + "j0.", + "(\"dim1\" == \"j0.dim1\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim1", "j0.dim1") + .context(queryContext) + .build() + ), + sortIfSortBased( + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"", ""}, + new Object[]{"10.1", "10.1"}, + new Object[]{"2", "2"}, + new Object[]{"1", "1"}, + new Object[]{"def", "def"}, + new Object[]{"abc", "abc"} + ) + : ImmutableList.of( + new Object[]{"10.1", "10.1"}, + new Object[]{"2", "2"}, + new Object[]{"1", "1"}, + new Object[]{"def", "def"}, + new Object[]{"abc", "abc"} + ), + 0 + ) + ); + } + + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testJoinWithExplicitIsNotDistinctFromCondition(Map queryContext) + { + // Like "testInnerJoin", but uses IS NOT DISTINCT FROM instead of equals. 
+ + testQuery( + "SELECT s.dim1, t.dim1\n" + + "FROM foo as s\n" + + "INNER JOIN foo as t " + + "ON s.dim1 IS NOT DISTINCT FROM t.dim1", + queryContext, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns(ImmutableList.of("dim1")) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + "j0.", + "notdistinctfrom(\"dim1\",\"j0.dim1\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim1", "j0.dim1") + .context(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{"", ""}, + new Object[]{"10.1", "10.1"}, + new Object[]{"2", "2"}, + new Object[]{"1", "1"}, + new Object[]{"def", "def"}, + new Object[]{"abc", "abc"} + ) + ); + } + @Test @Parameters(source = QueryContextForJoinProvider.class) public void testInnerJoinSubqueryWithSelectorFilter(Map queryContext) @@ -4416,6 +4515,51 @@ public void testCountDistinctOfLookupUsingJoinOperator(Map query ); } + @Test + @Parameters(source = QueryContextForJoinProvider.class) + public void testJoinWithImplicitIsNotDistinctFromCondition(Map queryContext) + { + // Like "testInnerJoin", but uses an implied is-not-distinct-from instead of equals. 
+ cannotVectorize(); + + testQuery( + "SELECT x.m1, y.m1\n" + + "FROM foo x INNER JOIN foo y ON (x.m1 = y.m1) OR (x.m1 IS NULL AND y.m1 IS NULL)", + queryContext, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + new TableDataSource(CalciteTests.DATASOURCE1), + new QueryDataSource( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("m1") + .context(queryContext) + .build() + ), + "j0.", + "notdistinctfrom(\"m1\",\"j0.m1\")", + JoinType.INNER + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("j0.m1", "m1") + .context(queryContext) + .build() + ), + ImmutableList.of( + new Object[]{1.0f, 1.0f}, + new Object[]{2.0f, 2.0f}, + new Object[]{3.0f, 3.0f}, + new Object[]{4.0f, 4.0f}, + new Object[]{5.0f, 5.0f}, + new Object[]{6.0f, 6.0f} + ) + ); + } + @Test @Parameters(source = QueryContextForJoinProvider.class) public void testJoinWithNonEquiCondition(Map queryContext) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 042b17368278..6516358b1b59 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -84,6 +84,7 @@ import org.apache.druid.query.filter.EqualityFilter; import org.apache.druid.query.filter.InDimFilter; import org.apache.druid.query.filter.LikeDimFilter; +import org.apache.druid.query.filter.NotDimFilter; import org.apache.druid.query.filter.NullFilter; import org.apache.druid.query.filter.RangeFilter; import org.apache.druid.query.filter.RegexDimFilter; @@ -5687,25 +5688,34 @@ public void testCountStarWithBoundFilterSimplifyOr() } @Test - public void testUnplannableTwoExactCountDistincts() + public void testUnplannableExactCountDistinctOnSketch() { - // Requires GROUPING SETS + GROUPING to be translated by 
AggregateExpandDistinctAggregatesRule. - + // COUNT DISTINCT on a sketch cannot be exact. assertQueryIsUnplannable( PLANNER_CONFIG_NO_HLL, - "SELECT dim2, COUNT(distinct dim1), COUNT(distinct dim2) FROM druid.foo GROUP BY dim2", - "SQL query requires 'IS NOT DISTINCT FROM' operator that is not supported." + "SELECT COUNT(distinct unique_dim1) FROM druid.foo", + "SQL requires a group-by on a column of type COMPLEX that is unsupported." ); } @Test - public void testUnplannableExactCountDistinctOnSketch() + public void testIsNotDistinctFromLiteral() { - // COUNT DISTINCT on a sketch cannot be exact. - assertQueryIsUnplannable( - PLANNER_CONFIG_NO_HLL, - "SELECT COUNT(distinct unique_dim1) FROM druid.foo", - "SQL requires a group-by on a column of type COMPLEX that is unsupported." + testQuery( + "SELECT COUNT(*) FROM druid.foo WHERE (dim1 >= 'a' and dim1 < 'b') OR dim1 IS NOT DISTINCT FROM 'ab'", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .filters(range("dim1", ColumnType.STRING, "a", "b", false, true)) + .aggregators(aggregators(new CountAggregatorFactory("a0"))) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{1L} + ) ); } @@ -6726,7 +6736,117 @@ public void testExactCountDistinctWithGroupingAndOtherAggregators() } @Test - public void testMultipleExactCountDistinctWithGroupingAndOtherAggregators() + public void testMultipleExactCountDistinctWithGroupingAndOtherAggregatorsUsingJoin() + { + // When HLL is disabled, do multiple exact count distincts through joins of nested queries. 
+ + testQuery( + PLANNER_CONFIG_NO_HLL, + "SELECT dim2, COUNT(*), COUNT(distinct dim1), COUNT(distinct cnt) FROM druid.foo GROUP BY dim2", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of( + newScanQueryBuilder() + .dataSource( + join( + join( + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions(new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING)) + .setAggregatorSpecs(new CountAggregatorFactory("a0")) + .build() + ), + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource( + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING), + new DefaultDimensionSpec("dim1", "d1", ColumnType.STRING) + ) + .build() + ) + ) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions(new DefaultDimensionSpec("d0", "_d0", ColumnType.STRING)) + .setAggregatorSpecs( + new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new NotDimFilter(isNull("d1", null)) + ) + ) + .build() + ), + "j0.", + "notdistinctfrom(\"d0\",\"j0._d0\")", + JoinType.INNER + ), + new QueryDataSource( + GroupByQuery.builder() + .setGranularity(Granularities.ALL) + .setDataSource( + new QueryDataSource( + GroupByQuery + .builder() + .setDataSource(CalciteTests.DATASOURCE1) + .setGranularity(Granularities.ALL) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions( + new DefaultDimensionSpec("dim2", "d0", ColumnType.STRING), + new DefaultDimensionSpec("cnt", "d1", ColumnType.LONG) + ) + .build() + ) + ) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setDimensions(new DefaultDimensionSpec("d0", "_d0", ColumnType.STRING)) + 
.setAggregatorSpecs( + NullHandling.sqlCompatible() + ? new FilteredAggregatorFactory( + new CountAggregatorFactory("a0"), + new NotDimFilter(isNull("d1", null)) + ) + : new CountAggregatorFactory("a0") + ) + .build() + ), + "_j0.", + "notdistinctfrom(\"d0\",\"_j0._d0\")", + JoinType.INNER + ) + ) + .context(QUERY_CONTEXT_DEFAULT) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns(ImmutableList.of("_j0.a0", "a0", "d0", "j0.a0")) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{null, 2L, 2L, 1L}, + new Object[]{"", 1L, 1L, 1L}, + new Object[]{"a", 2L, 2L, 1L}, + new Object[]{"abc", 1L, 1L, 1L} + ) + : ImmutableList.of( + new Object[]{"", 3L, 3L, 1L}, + new Object[]{"a", 2L, 1L, 1L}, + new Object[]{"abc", 1L, 1L, 1L} + ) + ); + } + + @Test + public void testMultipleExactCountDistinctWithGroupingUsingGroupingSets() { notMsqCompatible(); requireMergeBuffers(4); @@ -12803,6 +12923,42 @@ public void testLookupWithNull() ); } + @Test + public void testLookupWithIsNotDistinctFromNull() + { + List expected; + if (useDefault) { + expected = ImmutableList.builder().add( + new Object[]{NULL_STRING, NULL_STRING}, + new Object[]{NULL_STRING, NULL_STRING}, + new Object[]{NULL_STRING, NULL_STRING} + ).build(); + } else { + expected = ImmutableList.builder().add( + new Object[]{NULL_STRING, NULL_STRING}, + new Object[]{NULL_STRING, NULL_STRING} + ).build(); + } + testQuery( + "SELECT dim2 ,lookup(dim2,'lookyloo') from foo where dim2 is not distinct from null", + ImmutableList.of( + new Druids.ScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn("v0", "null", ColumnType.STRING) + ) + .columns("v0") + .legacy(false) + .filters(isNull("dim2")) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + expected + ); + } + @Test public void testRoundFunc() { diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java index c1552b0cfb85..7375ff385d17 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/DecoupledPlanningCalciteQueryTest.java @@ -187,7 +187,7 @@ public void testGroupByWithSortOnPostAggregationNoTopNContext() @Override @Ignore - public void testUnplannableTwoExactCountDistincts() + public void testMultipleExactCountDistinctWithGroupingAndOtherAggregatorsUsingJoin() { }