Skip to content

Commit

Permalink
Enhance RequireTimeCondition to handle complex joins
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishek-chouhan authored and Abhishek Singh Chouhan committed Oct 24, 2024
1 parent c4b513e commit 7d3c8d3
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.druid.sql.calcite.run;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
Expand All @@ -32,9 +33,13 @@
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.BoundDimFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.OrDimFilter;
Expand Down Expand Up @@ -87,7 +92,7 @@ public QueryResponse<Object[]> runQuery(final DruidQuery druidQuery)

if (plannerContext.getPlannerConfig().isRequireTimeCondition()
&& !(druidQuery.getDataSource() instanceof InlineDataSource)) {
if (Intervals.ONLY_ETERNITY.equals(findBaseDataSourceIntervals(query))) {
if (!queryHasTimeFilter(query)) {
throw new CannotBuildQueryException(
"requireTimeCondition is enabled, all queries must include a filter condition on the __time column"
);
Expand Down Expand Up @@ -164,6 +169,70 @@ private List<Interval> findBaseDataSourceIntervals(Query<?> query)
.orElseGet(query::getIntervals);
}

/**
 * Decides whether {@code query} passes the requireTimeCondition check.
 *
 * <ul>
 * <li>A query over a {@link JoinDataSource} is delegated to {@link #joinDataSourceQueryHasTimeFilter(Query)}.</li>
 * <li>A query over an {@link InlineDataSource} always passes.</li>
 * <li>Any other query passes only when the intervals found by {@link #findBaseDataSourceIntervals(Query)}
 * are something other than {@link Intervals#ONLY_ETERNITY}.</li>
 * </ul>
 *
 * @param query the native query to inspect
 * @return true if the query is considered to have a time filter
 */
private boolean queryHasTimeFilter(Query<?> query)
{
  final DataSource source = query.getDataSource();
  if (source instanceof JoinDataSource) {
    return joinDataSourceQueryHasTimeFilter(query);
  }
  return source instanceof InlineDataSource
         || isIntervalNonEternity(findBaseDataSourceIntervals(query));
}

/**
 * Checks whether a query over a {@link JoinDataSource} has the time filters that requireTimeCondition demands.
 *
 * <p>If the left side of the join is a {@link TableDataSource}, the query's own intervals must differ from
 * eternity and the right side must pass {@link #dataSourceHasTimeFilter(DataSource)}. For any other left side,
 * the query's own intervals are not consulted; instead both sides of the join must pass, which is checked
 * recursively through {@link #joinDataSourceHasTimeFilter(JoinDataSource)}.
 *
 * @param query a query whose datasource is a {@link JoinDataSource}
 * @return true if the requirements described above are met
 * @throws IllegalArgumentException if the query's datasource is not a {@link JoinDataSource}
 */
private boolean joinDataSourceQueryHasTimeFilter(Query<?> query)
{
  final DataSource source = query.getDataSource();
  Preconditions.checkArgument(source instanceof JoinDataSource);
  final JoinDataSource join = (JoinDataSource) source;

  if (!(join.getLeft() instanceof TableDataSource)) {
    // Left side is not a plain table: every input of the join has to carry its own time filter.
    return joinDataSourceHasTimeFilter(join);
  }

  // Concrete table on the left: the time filter must sit on this query itself, and the right side
  // must be filtered as well.
  return isIntervalNonEternity(query.getIntervals()) && dataSourceHasTimeFilter(join.getRight());
}

/**
 * Returns true only when both inputs of {@code joinDataSource} pass {@link #dataSourceHasTimeFilter(DataSource)}.
 * The right side is not inspected if the left side already fails.
 */
private boolean joinDataSourceHasTimeFilter(JoinDataSource joinDataSource)
{
  if (!dataSourceHasTimeFilter(joinDataSource.getLeft())) {
    return false;
  }
  return dataSourceHasTimeFilter(joinDataSource.getRight());
}

/**
 * Checks whether a single datasource (typically one input of a join) is restricted by a time filter.
 *
 * <ul>
 * <li>{@link InlineDataSource}: always true.</li>
 * <li>{@link QueryDataSource}: decided by {@link #queryHasTimeFilter(Query)} on the wrapped query.</li>
 * <li>{@link JoinDataSource}: decided by {@link #joinDataSourceHasTimeFilter(JoinDataSource)}.</li>
 * <li>Anything else: true only if the datasource's analysis exposes a base query segment spec whose intervals
 * differ from eternity; false when no such spec is present.</li>
 * </ul>
 *
 * NOTE(review): a bare table or lookup reaching the last case presumably has no base query segment spec and is
 * therefore reported as unfiltered -- confirm this is intended for left-deep joins and lookup joins.
 *
 * @param dataSource the datasource to inspect
 * @return true if the datasource is considered to have a time filter
 */
private boolean dataSourceHasTimeFilter(DataSource dataSource)
{
  if (dataSource instanceof InlineDataSource) {
    return true;
  }
  if (dataSource instanceof QueryDataSource) {
    return queryHasTimeFilter(((QueryDataSource) dataSource).getQuery());
  }
  if (dataSource instanceof JoinDataSource) {
    return joinDataSourceHasTimeFilter((JoinDataSource) dataSource);
  }
  // Compute the analysis once and stay inside the Optional: an absent segment spec means "no time filter".
  return dataSource.getAnalysis()
                   .getBaseQuerySegmentSpec()
                   .map(QuerySegmentSpec::getIntervals)
                   .map(this::isIntervalNonEternity)
                   .orElse(false);
}

/**
 * Returns true when {@code intervals} is anything other than {@link Intervals#ONLY_ETERNITY}.
 */
private boolean isIntervalNonEternity(List<Interval> intervals)
{
  final boolean coversEternity = Intervals.ONLY_ETERNITY.equals(intervals);
  return !coversEternity;
}

@SuppressWarnings("unchecked")
private <T> QueryResponse<Object[]> execute(
Query<?> query, // Not final: may be reassigned with query ID added
Expand Down
155 changes: 155 additions & 0 deletions sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13235,6 +13235,161 @@ public void testRequireTimeConditionSemiJoinNegative()
assertTrue(exception.getMessage().contains("__time column"));
}

/**
 * Semi-join where the outer query filters on __time but the IN subquery does not; with requireTimeCondition
 * enabled the query is expected to be rejected with a message mentioning the __time column.
 */
@Test
public void testRequireTimeConditionSemiJoinNegative2()
{
  msqIncompatible();

  final String sql =
      "SELECT COUNT(*) FROM druid.foo\n"
      + "WHERE __time >= '2000-01-01' AND SUBSTRING(dim2, 1, 1) IN (\n"
      + " SELECT SUBSTRING(dim1, 1, 1) FROM druid.foo\n"
      + " WHERE dim1 <> ''\n"
      + ")";

  final Throwable thrown = assertThrows(
      CannotBuildQueryException.class,
      () -> testQuery(
          PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
          sql,
          CalciteTests.REGULAR_USER_AUTH_RESULT,
          ImmutableList.of(),
          ImmutableList.of()
      )
  );

  assertTrue(thrown.getMessage().contains("__time column"));
}

/**
 * Join of two subqueries that each filter on __time. With requireTimeCondition enabled the query is expected
 * to plan into a groupBy over a join of two scan queries, where the outer groupBy itself spans eternity and
 * each scan carries its own interval.
 */
@Test
public void testRequireTimeConditionNestedJoinPositive()
{
  msqIncompatible();
  skipVectorize();

  final String sql =
      "SELECT distinct T1.dim1, T2.dim2 FROM\n"
      + " (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n"
      + " (SELECT dim2 from druid.foo WHERE dim1 <> '' AND __time >= '2000-02-01') AS T2\n"
      + " WHERE T1.dim1=T2.dim2";

  // Left input: scan of dim1 from 2000-01-01 onwards.
  final QueryDataSource leftInput = new QueryDataSource(
      newScanQueryBuilder()
          .dataSource(CalciteTests.DATASOURCE1)
          .intervals(
              querySegmentSpec(
                  Intervals.utc(DateTimes.of("2000-01-01").getMillis(), JodaUtils.MAX_INSTANT)
              )
          )
          .columns("dim1")
          .filters(not(equality("dim1", "", ColumnType.STRING)))
          .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
          .context(QUERY_CONTEXT_DEFAULT)
          .build()
  );

  // Right input: scan of dim2 from 2000-02-01 onwards.
  final QueryDataSource rightInput = new QueryDataSource(
      newScanQueryBuilder()
          .dataSource(CalciteTests.DATASOURCE1)
          .intervals(
              querySegmentSpec(
                  Intervals.utc(DateTimes.of("2000-02-01").getMillis(), JodaUtils.MAX_INSTANT)
              )
          )
          .columns("dim2")
          .filters(not(equality("dim1", "", ColumnType.STRING)))
          .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
          .context(QUERY_CONTEXT_DEFAULT)
          .build()
  );

  final GroupByQuery expectedQuery =
      GroupByQuery.builder()
                  .setDataSource(
                      JoinDataSource.create(
                          leftInput,
                          rightInput,
                          "j0.",
                          "(\"dim1\" == \"j0.dim2\")",
                          JoinType.INNER,
                          null,
                          ExprMacroTable.nil(),
                          CalciteTests.createJoinableFactoryWrapper()
                      )
                  )
                  .setInterval(querySegmentSpec(Filtration.eternity()))
                  .setGranularity(Granularities.ALL)
                  .setDimensions(
                      dimensions(
                          new DefaultDimensionSpec("dim1", "d0"),
                          new DefaultDimensionSpec("j0.dim2", "d1")
                      )
                  )
                  .setContext(QUERY_CONTEXT_DEFAULT)
                  .build();

  testQuery(
      PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
      sql,
      CalciteTests.REGULAR_USER_AUTH_RESULT,
      ImmutableList.of(expectedQuery),
      ImmutableList.of(new Object[]{"abc", "abc"})
  );
}

/**
 * Join of two subqueries where only the first filters on __time; with requireTimeCondition enabled the query
 * is expected to be rejected with a message mentioning the __time column.
 */
@Test
public void testRequireTimeConditionNestedJoinNegative()
{
  msqIncompatible();
  skipVectorize();

  final String sql =
      "SELECT distinct T1.dim1, T2.dim2 FROM\n"
      + " (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n"
      + " (SELECT dim2 from druid.foo WHERE dim1 <> '') AS T2\n"
      + " WHERE T1.dim1=T2.dim2";

  final Throwable thrown = assertThrows(
      CannotBuildQueryException.class,
      () -> testQuery(
          PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
          sql,
          CalciteTests.REGULAR_USER_AUTH_RESULT,
          ImmutableList.of(),
          ImmutableList.of()
      )
  );

  assertTrue(thrown.getMessage().contains("__time column"));
}

/**
 * Join between a table subquery without a __time filter and a constant-only subquery; with
 * requireTimeCondition enabled the query is expected to be rejected with a message mentioning the
 * __time column.
 */
@Test
public void testRequireTimeConditionJoinWithInlineDatasourceNegative()
{
  msqIncompatible();
  skipVectorize();

  final String sql =
      "SELECT distinct T1.dim1, T2.A FROM\n"
      + " (SELECT dim1 from druid.foo WHERE dim1 <> '') AS T1,\n"
      + " (SELECT * FROM (SELECT 2 + 2 AS A)) AS T2\n"
      + " WHERE T1.dim1=T2.A";

  final Throwable thrown = assertThrows(
      CannotBuildQueryException.class,
      () -> testQuery(
          PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
          sql,
          CalciteTests.REGULAR_USER_AUTH_RESULT,
          ImmutableList.of(),
          ImmutableList.of()
      )
  );

  assertTrue(thrown.getMessage().contains("__time column"));
}

/**
 * Join between a table subquery that filters on __time and a constant-only subquery. With
 * requireTimeCondition enabled the query is expected to plan into a groupBy over a join of a scan query and
 * an inline datasource holding the single row {@code 4}, and to return no rows.
 */
@Test
public void testRequireTimeConditionJoinWithInlineDatasourcePositive()
{
  msqIncompatible();
  skipVectorize();

  final String sql =
      "SELECT distinct T1.dim1, T2.A FROM\n"
      + " (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n"
      + " (SELECT * FROM (SELECT 2 + 2 AS A)) AS T2\n"
      + " WHERE T1.dim1=T2.A";

  // Left input: scan of dim1 from 2000-01-01 onwards.
  final QueryDataSource leftInput = new QueryDataSource(
      newScanQueryBuilder()
          .dataSource(CalciteTests.DATASOURCE1)
          .intervals(
              querySegmentSpec(
                  Intervals.utc(DateTimes.of("2000-01-01").getMillis(), JodaUtils.MAX_INSTANT)
              )
          )
          .columns("dim1")
          .filters(not(equality("dim1", "", ColumnType.STRING)))
          .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
          .context(QUERY_CONTEXT_DEFAULT)
          .build()
  );

  // Right input: the constant subquery, expected as a one-row inline datasource.
  final InlineDataSource rightInput = InlineDataSource.fromIterable(
      ImmutableList.of(new Object[]{4L}),
      RowSignature.builder().add("A", ColumnType.LONG).build()
  );

  final GroupByQuery expectedQuery =
      GroupByQuery.builder()
                  .setDataSource(
                      JoinDataSource.create(
                          leftInput,
                          rightInput,
                          "j0.",
                          "1",
                          JoinType.INNER,
                          null,
                          ExprMacroTable.nil(),
                          CalciteTests.createJoinableFactoryWrapper()
                      )
                  )
                  .setInterval(querySegmentSpec(Filtration.eternity()))
                  .setGranularity(Granularities.ALL)
                  .setDimensions(
                      dimensions(
                          new DefaultDimensionSpec("dim1", "d0"),
                          new DefaultDimensionSpec("j0.A", "d1", ColumnType.LONG)
                      )
                  )
                  .setDimFilter(equality("dim1", 4, ColumnType.LONG))
                  .setContext(QUERY_CONTEXT_DEFAULT)
                  .build();

  testQuery(
      PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
      sql,
      CalciteTests.REGULAR_USER_AUTH_RESULT,
      ImmutableList.of(expectedQuery),
      ImmutableList.of()
  );
}

@Test
public void testFilterFloatDimension()
{
Expand Down

0 comments on commit 7d3c8d3

Please sign in to comment.