Skip to content

Commit

Permalink
[Backport] Allow casted literal values in SQL functions accepting lit…
Browse files Browse the repository at this point in the history
…erals (Part 2) (#15322)

Backport of #15316 to 28.0.0.
  • Loading branch information
LakshSingla authored Nov 4, 2023
1 parent 80dc45e commit 283529f
Show file tree
Hide file tree
Showing 9 changed files with 438 additions and 10 deletions.
6 changes: 6 additions & 0 deletions codestyle/druid-forbidden-apis.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ java.util.Random#<init>() @ Use ThreadLocalRandom.current() or the constructor w
java.lang.Math#random() @ Use ThreadLocalRandom.current()
java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly
org.apache.calcite.sql.type.OperandTypes#LITERAL @ LITERAL type checker throws when literals with CAST are passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead.
org.apache.calcite.sql.type.OperandTypes#BOOLEAN_LITERAL @ Create a type checker like org.apache.calcite.sql.type.POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#ARRAY_BOOLEAN_LITERAL @ Create a type checker like org.apache.calcite.sql.type.POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#POSITIVE_INTEGER_LITERAL @ Use org.apache.calcite.sql.type.POSITIVE_INTEGER_LITERAL instead
org.apache.calcite.sql.type.OperandTypes#UNIT_INTERVAL_NUMERIC_LITERAL @ Create a type checker like org.apache.calcite.sql.type.POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#NUMERIC_UNIT_INTERVAL_NUMERIC_LITERAL @ Create a type checker like org.apache.calcite.sql.type.POSITIVE_INTEGER_LITERAL and use that instead
org.apache.calcite.sql.type.OperandTypes#NULLABLE_LITERAL @ Create an instance of org.apache.calcite.sql.type.CastedLiteralOperandTypeChecker that allows nulls and use that instead
org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead
org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory()
org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
Expand Down Expand Up @@ -156,25 +157,25 @@ private CompressedBigDecimalSqlAggFunction(String name)
OperandTypes.sequence(
"'" + name + "(column, size)'",
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(
"'" + name + "(column, size, scale)'",
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(
"'" + name + "(column, size, scale, strictNumberParsing)'",
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.POSITIVE_INTEGER_LITERAL,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
OperandTypes.BOOLEAN
),
OperandTypes.family(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

//CHECKSTYLE.OFF: PackageName - Must be in Calcite

package org.apache.calcite.sql.type;

import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.util.Static;
import org.apache.calcite.util.Util;

/**
* Like {@link LiteralOperandTypeChecker}, but also allows casted literals.
*
* "Casted literals" are like `CAST(100 AS INTEGER)`. While it doesn't make sense to cast a literal that the user
* themselves enter, it is important to add a broader validation to allow these literals because Calcite's JDBC driver
* doesn't allow the wildcards (?)to work without a cast, and there's no workaround it.
* <p>
* This makes sure that the functions using the literal operand type checker can be workaround the JDBC's restriction,
* without being marked as invalid SQL input
*/

public class CastedLiteralOperandTypeChecker implements SqlSingleOperandTypeChecker
{
public static SqlSingleOperandTypeChecker LITERAL = new CastedLiteralOperandTypeChecker(false);

private final boolean allowNull;

CastedLiteralOperandTypeChecker(boolean allowNull)
{
this.allowNull = allowNull;
}

@Override
public boolean checkSingleOperandType(
SqlCallBinding callBinding,
SqlNode node,
int iFormalOperand,
boolean throwOnFailure
)
{
Util.discard(iFormalOperand);

if (SqlUtil.isNullLiteral(node, true)) {
if (allowNull) {
return true;
}
if (throwOnFailure) {
throw callBinding.newError(
Static.RESOURCE.argumentMustNotBeNull(
callBinding.getOperator().getName()));
}
return false;
}
// The following line of code is the only difference between the OperandTypes.LITERAL and this type checker
if (!SqlUtil.isLiteral(node, true) && !SqlUtil.isLiteralChain(node)) {
if (throwOnFailure) {
throw callBinding.newError(
Static.RESOURCE.argumentMustBeLiteral(
callBinding.getOperator().getName()));
}
return false;
}

return true;
}

@Override
public String getAllowedSignatures(SqlOperator op, String opName)
{
return "<LITERAL>";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

//CHECKSTYLE.OFF: PackageName - Must be in Calcite

package org.apache.calcite.sql.type;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.util.Static;
import org.apache.druid.error.DruidException;

import java.math.BigDecimal;

public class CastedLiteralOperandTypeCheckers
{
public static final SqlSingleOperandTypeChecker LITERAL = new CastedLiteralOperandTypeChecker(false);

/**
* Blatantly copied from {@link OperandTypes#POSITIVE_INTEGER_LITERAL}, however the reference to the {@link #LITERAL}
* is the one which accepts casted literals
*/
public static final SqlSingleOperandTypeChecker POSITIVE_INTEGER_LITERAL =
new FamilyOperandTypeChecker(
ImmutableList.of(SqlTypeFamily.INTEGER),
i -> false
)
{
@Override
public boolean checkSingleOperandType(
SqlCallBinding callBinding,
SqlNode operand,
int iFormalOperand,
SqlTypeFamily family,
boolean throwOnFailure
)
{
// This LITERAL refers to the above implementation, the one which allows casted literals
if (!LITERAL.checkSingleOperandType(
callBinding,
operand,
iFormalOperand,
throwOnFailure
)) {
return false;
}

if (!super.checkSingleOperandType(
callBinding,
operand,
iFormalOperand,
family,
throwOnFailure
)) {
return false;
}

final SqlLiteral arg = fetchPrimitiveLiteralFromCasts(operand);
final BigDecimal value = arg.getValueAs(BigDecimal.class);
if (value.compareTo(BigDecimal.ZERO) < 0
|| hasFractionalPart(value)) {
if (throwOnFailure) {
throw callBinding.newError(
Static.RESOURCE.argumentMustBePositiveInteger(
callBinding.getOperator().getName()));
}
return false;
}
if (value.compareTo(BigDecimal.valueOf(Integer.MAX_VALUE)) > 0) {
if (throwOnFailure) {
throw callBinding.newError(
Static.RESOURCE.numberLiteralOutOfRange(value.toString()));
}
return false;
}
return true;
}

/** Returns whether a number has any fractional part.
*
* @see BigDecimal#longValueExact() */
private boolean hasFractionalPart(BigDecimal bd)
{
return bd.precision() - bd.scale() <= 0;
}
};

public static boolean isLiteral(SqlNode node, boolean allowCast)
{
assert node != null;
if (node instanceof SqlLiteral) {
return true;
}
if (!allowCast) {
return false;
}
switch (node.getKind()) {
case CAST:
// "CAST(e AS type)" is literal if "e" is literal
return isLiteral(((SqlCall) node).operand(0), true);
case MAP_VALUE_CONSTRUCTOR:
case ARRAY_VALUE_CONSTRUCTOR:
return ((SqlCall) node).getOperandList().stream()
.allMatch(o -> isLiteral(o, true));
case DEFAULT:
return true; // DEFAULT is always NULL
default:
return false;
}
}

/**
* Fetches primitive literals from the casts, including NULL literal.
* It throws if the entered node isn't a primitive literal, which can be cast multiple times.
*
* Therefore, it would fail on the following types:
* 1. Nodes that are not of the form CAST(....(CAST LITERAL AS TYPE).....)
* 2. ARRAY and MAP literals. This won't be required since we are only using this method in the type checker for
* primitive types
*/
private static SqlLiteral fetchPrimitiveLiteralFromCasts(SqlNode node)
{
if (node == null) {
throw DruidException.defensive("'node' cannot be null");
}
if (node instanceof SqlLiteral) {
return (SqlLiteral) node;
}

switch (node.getKind()) {
case CAST:
return fetchPrimitiveLiteralFromCasts(((SqlCall) node).operand(0));
case DEFAULT:
return SqlLiteral.createNull(SqlParserPos.ZERO);
default:
throw DruidException.defensive("Expected a literal or a cast on the literal. Found [%s] instead", node.getKind());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
Expand Down Expand Up @@ -156,7 +157,7 @@ private static class ArrayConcatAggFunction extends SqlAggFunction
OperandTypes.sequence(
StringUtils.format("'%s(expr, maxSizeBytes)'", NAME),
OperandTypes.ARRAY,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
Expand Down Expand Up @@ -179,7 +180,11 @@ private static class ArrayAggFunction extends SqlAggFunction
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(StringUtils.format("'%s(expr, maxSizeBytes)'", NAME), OperandTypes.ANY, OperandTypes.POSITIVE_INTEGER_LITERAL),
OperandTypes.sequence(
StringUtils.format("'%s(expr, maxSizeBytes)'", NAME),
OperandTypes.ANY,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
Expand Down Expand Up @@ -251,7 +252,7 @@ private static class StringAggFunction extends SqlAggFunction
StringUtils.format("'%s(expr, separator, maxSizeBytes)'", name),
OperandTypes.ANY,
OperandTypes.STRING,
OperandTypes.POSITIVE_INTEGER_LITERAL
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.NUMERIC)
)
Expand Down
Loading

0 comments on commit 283529f

Please sign in to comment.