Enhance RequireTimeCondition to handle complex joins #17408

Open · wants to merge 2 commits into base: master
@@ -20,6 +20,7 @@
package org.apache.druid.sql.calcite.run;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
@@ -32,9 +33,13 @@
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.BoundDimFilter;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.OrDimFilter;
@@ -87,7 +92,7 @@ public QueryResponse<Object[]> runQuery(final DruidQuery druidQuery)

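// Note: PlannerConfig#isRequireTimeCondition is driven by the druid.sql.planner.requireTimeCondition property.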
if (plannerContext.getPlannerConfig().isRequireTimeCondition()
&& !(druidQuery.getDataSource() instanceof InlineDataSource)) {
if (Intervals.ONLY_ETERNITY.equals(findBaseDataSourceIntervals(query))) {
if (!queryHasTimeFilter(query)) {
throw new CannotBuildQueryException(
"requireTimeCondition is enabled, all queries must include a filter condition on the __time column"
);
@@ -164,6 +169,70 @@ private List<Interval> findBaseDataSourceIntervals(Query<?> query)
.orElseGet(query::getIntervals);
}

private boolean queryHasTimeFilter(Query<?> query)
{
DataSource dataSource = query.getDataSource();
if (dataSource instanceof InlineDataSource) {
return true;
}
if (dataSource instanceof JoinDataSource) {
return joinDataSourceQueryHasTimeFilter(query);
}
return isIntervalNonEternity(findBaseDataSourceIntervals(query));
}

/**
* Checks whether a join query has a valid time filter by inspecting each side of the join and any subqueries used within.
* If the left datasource is a {@link TableDataSource}, we require a time filter on the top-level query as well as a time
* filter on the right datasource; otherwise we recursively check for time filters on both the left and right datasources
* that make up the join.
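* <p>
* For example, in {@code SELECT ... FROM foo JOIN (SELECT ... FROM bar WHERE __time >= '2000-01-01') ON ...} the
* left side is a concrete table, so the outer query must also filter on {@code __time} for the check to pass.
* (The table names here are illustrative only.)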
*/
private boolean joinDataSourceQueryHasTimeFilter(Query<?> query)
{
Preconditions.checkArgument(query.getDataSource() instanceof JoinDataSource);
JoinDataSource joinDataSource = (JoinDataSource) query.getDataSource();
if (joinDataSource.getLeft() instanceof TableDataSource) {
// Make sure we have a time filter on the base query since we have a concrete TableDataSource on the left
// And then make sure that the right-hand datasource also has a time filter
if (isIntervalNonEternity(query.getIntervals())) {
return dataSourceHasTimeFilter(joinDataSource.getRight());
} else {
return false;
}
}
// The left side is not a concrete table, so require a time filter on both the left and right sides of the join
return joinDataSourceHasTimeFilter(joinDataSource);
}

private boolean joinDataSourceHasTimeFilter(JoinDataSource joinDataSource)
{
return dataSourceHasTimeFilter(joinDataSource.getLeft()) && dataSourceHasTimeFilter(joinDataSource.getRight());
}

private boolean dataSourceHasTimeFilter(DataSource dataSource)
{
if (dataSource instanceof InlineDataSource) {
return true;
}
if (dataSource instanceof QueryDataSource) {
return queryHasTimeFilter(((QueryDataSource) dataSource).getQuery());
}
if (dataSource instanceof JoinDataSource) {
return joinDataSourceHasTimeFilter((JoinDataSource) dataSource);
}
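// Otherwise fall back to the datasource analysis: an explicit, non-eternity base query segment spec counts as a time filter.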
if (dataSource.getAnalysis().getBaseQuerySegmentSpec().isPresent()) {
return isIntervalNonEternity(dataSource.getAnalysis()
.getBaseQuerySegmentSpec()
.map(QuerySegmentSpec::getIntervals)
.get());
}
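// A plain table (or any other datasource) without an explicit interval carries no time filter.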
return false;
}

private boolean isIntervalNonEternity(List<Interval> intervals)
{
return !Intervals.ONLY_ETERNITY.equals(intervals);
}

@SuppressWarnings("unchecked")
private <T> QueryResponse<Object[]> execute(
Query<?> query, // Not final: may be reassigned with query ID added
sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java (165 additions, 0 deletions)
@@ -13235,6 +13235,171 @@ public void testRequireTimeConditionSemiJoinNegative()
assertTrue(exception.getMessage().contains("__time column"));
}

@Test
public void testRequireTimeConditionSemiJoinNegative2()
{
msqIncompatible();
Throwable exception = assertThrows(CannotBuildQueryException.class, () -> {
testQuery(
PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
"SELECT COUNT(*) FROM druid.foo\n"
+ "WHERE __time >= '2000-01-01' AND SUBSTRING(dim2, 1, 1) IN (\n"
+ " SELECT SUBSTRING(dim1, 1, 1) FROM druid.foo\n"
+ " WHERE dim1 <> ''\n"
+ ")",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(),
ImmutableList.of()
);
});
assertTrue(exception.getMessage().contains("__time column"));
}

@Test
public void testRequireTimeConditionNestedJoinPositive()
{
msqIncompatible();
skipVectorize();
testQuery(
PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
"SELECT distinct T1.dim1, T2.dim2 FROM\n" +
" (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n" +
" (SELECT dim2 from druid.foo WHERE dim1 <> '' AND __time >= '2000-02-01') AS T2\n" +
" WHERE T1.dim1=T2.dim2",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(GroupByQuery.builder()
.setDataSource(JoinDataSource.create(
new QueryDataSource(newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.utc(
DateTimes.of("2000-01-01").getMillis(),
JodaUtils.MAX_INSTANT
)))
.columns("dim1")
.filters(not(equality("dim1", "", ColumnType.STRING)))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(QUERY_CONTEXT_DEFAULT)
.build()),
new QueryDataSource(newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.utc(
DateTimes.of("2000-02-01").getMillis(),
JodaUtils.MAX_INSTANT
)))
.columns("dim2")
.filters(not(equality("dim1", "", ColumnType.STRING)))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(QUERY_CONTEXT_DEFAULT)
.build()),
"j0.",
"(\"dim1\" == \"j0.dim2\")",
JoinType.INNER,
null,
ExprMacroTable.nil(),
CalciteTests.createJoinableFactoryWrapper()

))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(
new DefaultDimensionSpec("dim1", "d0"),
new DefaultDimensionSpec("j0.dim2", "d1")
))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()),
ImmutableList.of(new Object[]{"abc", "abc"})
);
}

@Test
public void testRequireTimeConditionNestedJoinNegative()
{
msqIncompatible();
skipVectorize();
Throwable exception = assertThrows(CannotBuildQueryException.class, () -> {
testQuery(
PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
"SELECT distinct T1.dim1, T2.dim2 FROM\n" +
" (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n" +
" (SELECT dim2 from druid.foo WHERE dim1 <> '') AS T2\n" +
" WHERE T1.dim1=T2.dim2",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(),
ImmutableList.of()
);
});
assertTrue(exception.getMessage().contains("__time column"));
}

@Test
public void testRequireTimeConditionJoinWithInlineDatasourceNegative()
{
msqIncompatible();
skipVectorize();
Throwable exception = assertThrows(CannotBuildQueryException.class, () -> {
testQuery(
PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
"SELECT distinct T1.dim1, T2.A FROM\n" +
" (SELECT dim1 from druid.foo WHERE dim1 <> '') AS T1,\n" +
" (SELECT * FROM (SELECT 2 + 2 AS A)) AS T2\n" +
" WHERE T1.dim1=T2.A",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(),
ImmutableList.of()
);
});
assertTrue(exception.getMessage().contains("__time column"));
}

@Test
public void testRequireTimeConditionJoinWithInlineDatasourcePositive()
{
msqIncompatible();
skipVectorize();
testQuery(
PLANNER_CONFIG_REQUIRE_TIME_CONDITION,
"SELECT distinct T1.dim1, T2.A FROM\n" +
" (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n" +
" (SELECT * FROM (SELECT 2 + 2 AS A)) AS T2\n" +
" WHERE T1.dim1=T2.A",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(GroupByQuery.builder()
.setDataSource(JoinDataSource.create(
new QueryDataSource(newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.utc(
DateTimes.of("2000-01-01").getMillis(),
JodaUtils.MAX_INSTANT
)))
.columns("dim1")
.filters(not(equality("dim1", "", ColumnType.STRING)))
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(QUERY_CONTEXT_DEFAULT)
.build()),
InlineDataSource.fromIterable(
ImmutableList.of(new Object[]{4L}),
RowSignature.builder().add("A", ColumnType.LONG).build()
),
"j0.",
"1",
JoinType.INNER,
null,
ExprMacroTable.nil(),
CalciteTests.createJoinableFactoryWrapper()
))
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(
new DefaultDimensionSpec("dim1", "d0"),
new DefaultDimensionSpec("j0.A", "d1", ColumnType.LONG)
))
.setDimFilter(equality("dim1", 4, ColumnType.LONG))
.setContext(QUERY_CONTEXT_DEFAULT)
.build()),
ImmutableList.of()
);
}

@Test
public void testFilterFloatDimension()
{