diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index 4328f080d47..b9c7f0f57df 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -35,10 +35,9 @@ import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException; +import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.childAt; -import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.compare; -import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.evalNullability; import static io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo; /** @@ -421,44 +420,37 @@ ColumnVector visitAlwaysFalse(AlwaysFalse alwaysFalse) { @Override ColumnVector visitComparator(Predicate predicate) { PredicateChildrenEvalResult argResults = evalBinaryExpressionChildren(predicate); - - int numRows = argResults.rowCount; - boolean[] result = new boolean[numRows]; - boolean[] nullability = evalNullability(argResults.leftResult, argResults.rightResult); - int[] compareResult = compare(argResults.leftResult, argResults.rightResult); switch (predicate.getName()) { case "=": - for (int rowId = 0; rowId < numRows; rowId++) { - result[rowId] = compareResult[rowId] == 0; - } - break; + 
return comparatorVector( + argResults.leftResult, + argResults.rightResult, + (compareResult) -> (compareResult == 0)); case ">": - for (int rowId = 0; rowId < numRows; rowId++) { - result[rowId] = compareResult[rowId] > 0; - } - break; + return comparatorVector( + argResults.leftResult, + argResults.rightResult, + (compareResult) -> (compareResult > 0)); case ">=": - for (int rowId = 0; rowId < numRows; rowId++) { - result[rowId] = compareResult[rowId] >= 0; - } - break; + return comparatorVector( + argResults.leftResult, + argResults.rightResult, + (compareResult) -> (compareResult >= 0)); case "<": - for (int rowId = 0; rowId < numRows; rowId++) { - result[rowId] = compareResult[rowId] < 0; - } - break; + return comparatorVector( + argResults.leftResult, + argResults.rightResult, + (compareResult) -> (compareResult < 0)); case "<=": - for (int rowId = 0; rowId < numRows; rowId++) { - result[rowId] = compareResult[rowId] <= 0; - } - break; + return comparatorVector( + argResults.leftResult, + argResults.rightResult, + (compareResult) -> (compareResult <= 0)); default: // We should never reach this based on the ExpressionVisitor throw new IllegalStateException( String.format("%s is not a recognized comparator", predicate.getName())); } - - return new DefaultBooleanVector(numRows, Optional.of(nullability), result); } @Override diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index ef931b4f109..46047ed989d 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -19,6 +19,7 @@ import java.util.Comparator; import java.util.List; import java.util.function.Function; +import 
java.util.function.IntPredicate; import java.util.stream.Collectors; import io.delta.kernel.data.ArrayValue; @@ -33,6 +34,20 @@ * Utility methods used by the default expression evaluator. */ class DefaultExpressionUtils { + + static final Comparator BIGDECIMAL_COMPARATOR = Comparator.naturalOrder(); + static final Comparator STRING_COMPARATOR = Comparator.naturalOrder(); + static final Comparator BINARY_COMPARTOR = (leftOp, rightOp) -> { + int i = 0; + while (i < leftOp.length && i < rightOp.length) { + if (leftOp[i] != rightOp[i]) { + return Byte.compare(leftOp[i], rightOp[i]); + } + i++; + } + return Integer.compare(leftOp.length, rightOp.length); + }; + private DefaultExpressionUtils() {} /** @@ -87,138 +102,91 @@ public boolean getBoolean(int rowId) { } /** - * Utility method to compare the left and right according to the natural ordering - * and return an integer array where each row contains the comparison result (-1, 0, 1) for - * corresponding rows in the input vectors compared. + * Utility method to create a column vector that lazily evaluates the + * comparison expression (i.e. ==, >=, <=, ...) for the left and right + * column vectors according to the natural ordering of numbers *

* Only primitive data types are supported. */ - static int[] compare(ColumnVector left, ColumnVector right) { + static ColumnVector comparatorVector( + ColumnVector left, + ColumnVector right, + IntPredicate booleanComparator) { checkArgument( - left.getSize() == right.getSize(), - "Left and right operand have different vector sizes."); - DataType dataType = left.getDataType(); + left.getSize() == right.getSize(), + "Left and right operand have different vector sizes."); - int numRows = left.getSize(); - int[] result = new int[numRows]; + DataType dataType = left.getDataType(); + IntPredicate vectorValueComparator; if (dataType instanceof BooleanType) { - compareBoolean(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId))); } else if (dataType instanceof ByteType) { - compareByte(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Byte.compare(left.getByte(rowId), right.getByte(rowId))); } else if (dataType instanceof ShortType) { - compareShort(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Short.compare(left.getShort(rowId), right.getShort(rowId))); } else if (dataType instanceof IntegerType || dataType instanceof DateType) { - compareInt(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Integer.compare(left.getInt(rowId), right.getInt(rowId))); } else if (dataType instanceof LongType || dataType instanceof TimestampType || dataType instanceof TimestampNTZType) { - compareLong(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Long.compare(left.getLong(rowId), right.getLong(rowId))); } else if (dataType instanceof FloatType) { - compareFloat(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Float.compare(left.getFloat(rowId), right.getFloat(rowId))); } else if (dataType instanceof DoubleType) { - 
compareDouble(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + Double.compare(left.getDouble(rowId), right.getDouble(rowId))); } else if (dataType instanceof DecimalType) { - compareDecimal(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + BIGDECIMAL_COMPARATOR.compare( + left.getDecimal(rowId), right.getDecimal(rowId))); } else if (dataType instanceof StringType) { - compareString(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + STRING_COMPARATOR.compare( + left.getString(rowId), right.getString(rowId))); } else if (dataType instanceof BinaryType) { - compareBinary(left, right, result); + vectorValueComparator = rowId -> booleanComparator.test( + BINARY_COMPARTOR.compare( + left.getBinary(rowId), right.getBinary(rowId))); } else { throw new UnsupportedOperationException(dataType + " can not be compared."); } - return result; - } - - static void compareBoolean(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = Boolean.compare(left.getBoolean(rowId), right.getBoolean(rowId)); - } - } - } - static void compareByte(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = Byte.compare(left.getByte(rowId), right.getByte(rowId)); - } - } - } - - static void compareShort(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = Short.compare(left.getShort(rowId), right.getShort(rowId)); - } - } - } - - static void compareInt(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && 
!right.isNullAt(rowId)) { - result[rowId] = Integer.compare(left.getInt(rowId), right.getInt(rowId)); - } - } - } - - static void compareLong(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = Long.compare(left.getLong(rowId), right.getLong(rowId)); - } - } - } + return new ColumnVector() { - static void compareFloat(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = Float.compare(left.getFloat(rowId), right.getFloat(rowId)); + @Override + public DataType getDataType() { + return BooleanType.BOOLEAN; } - } - } - static void compareDouble(ColumnVector left, ColumnVector right, int[] result) { - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = Double.compare(left.getDouble(rowId), right.getDouble(rowId)); + @Override + public void close() { + Utils.closeCloseables(left, right); } - } - } - static void compareString(ColumnVector left, ColumnVector right, int[] result) { - Comparator comparator = Comparator.naturalOrder(); - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = comparator.compare(left.getString(rowId), right.getString(rowId)); + @Override + public int getSize() { + return left.getSize(); } - } - } - static void compareDecimal(ColumnVector left, ColumnVector right, int[] result) { - Comparator comparator = Comparator.naturalOrder(); - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = comparator.compare(left.getDecimal(rowId), right.getDecimal(rowId)); + @Override + public boolean isNullAt(int rowId) { + return left.isNullAt(rowId) || 
right.isNullAt(rowId); } - } - } - static void compareBinary(ColumnVector left, ColumnVector right, int[] result) { - Comparator comparator = (leftOp, rightOp) -> { - int i = 0; - while (i < leftOp.length && i < rightOp.length) { - if (leftOp[i] != rightOp[i]) { - return Byte.compare(leftOp[i], rightOp[i]); + @Override + public boolean getBoolean(int rowId) { + if (isNullAt(rowId)) { + return false; } - i++; + return vectorValueComparator.test(rowId); } - return Integer.compare(leftOp.length, rightOp.length); }; - for (int rowId = 0; rowId < left.getSize(); rowId++) { - if (!left.isNullAt(rowId) && !right.isNullAt(rowId)) { - result[rowId] = comparator.compare(left.getBinary(rowId), right.getBinary(rowId)); - } - } } static Expression childAt(Expression expression, int index) {