diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java index af7b713317a7..b391100ff3a1 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/StringSqlAggregator.java @@ -25,6 +25,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlFunctionCategory; @@ -109,36 +110,40 @@ public Aggregation toDruidAggregation( return null; } + RexNode separatorNode = Expressions.fromFieldAccess( + rexBuilder.getTypeFactory(), + rowSignature, + project, + aggregateCall.getArgList().get(1) + ); + if (!separatorNode.isA(SqlKind.LITERAL)) { + // separator must be a literal + return null; + } + final String separator; - if (arguments.size() > 1) { - separator = RexLiteral.stringValue( - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(1) - ) - ); - } else { - separator = ""; + separator = RexLiteral.stringValue(separatorNode); + + if (separator == null) { + // separator must not be null + return null; } - final HumanReadableBytes maxSizeBytes; + Integer maxSizeBytes = null; if (arguments.size() > 2) { - maxSizeBytes = HumanReadableBytes.valueOf( - RexLiteral.intValue( - Expressions.fromFieldAccess( - rexBuilder.getTypeFactory(), - rowSignature, - project, - aggregateCall.getArgList().get(2) - ) - ) + RexNode maxBytes = Expressions.fromFieldAccess( + rexBuilder.getTypeFactory(), + rowSignature, + project, + aggregateCall.getArgList().get(2) ); - } else { - maxSizeBytes = null; + if (!maxBytes.isA(SqlKind.LITERAL)) { + // maxBytes must be a literal + return null; + } + maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue(); } final DruidExpression arg = arguments.get(0); @@ -176,7 +181,7 @@ public Aggregation toDruidAggregation( StringUtils.format("array_set_add_all(\"__acc\", \"%s\")", name), null, finalizer, - maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes.getBytes()) : null, + maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, macroTable ), dimFilter @@ -199,7 +204,7 @@ public Aggregation toDruidAggregation( StringUtils.format("array_concat(\"__acc\", \"%s\")", name), null, finalizer, - maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes.getBytes()) : null, + maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null, macroTable ), dimFilter