diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java index 7b1e1ec7091d..829289915744 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeQueryMaker.java @@ -20,6 +20,7 @@ package org.apache.druid.sql.calcite.run; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; import com.google.common.collect.Iterables; import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; @@ -32,9 +33,13 @@ import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.query.DataSource; import org.apache.druid.query.InlineDataSource; +import org.apache.druid.query.JoinDataSource; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryToolChest; +import org.apache.druid.query.TableDataSource; import org.apache.druid.query.filter.BoundDimFilter; import org.apache.druid.query.filter.DimFilter; import org.apache.druid.query.filter.OrDimFilter; @@ -87,7 +92,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) if (plannerContext.getPlannerConfig().isRequireTimeCondition() && !(druidQuery.getDataSource() instanceof InlineDataSource)) { - if (Intervals.ONLY_ETERNITY.equals(findBaseDataSourceIntervals(query))) { + if (!queryHasTimeFilter(query)) { throw new CannotBuildQueryException( "requireTimeCondition is enabled, all queries must include a filter condition on the __time column" ); @@ -164,6 +169,70 @@ private List findBaseDataSourceIntervals(Query query) .orElseGet(query::getIntervals); } + /** Returns whether the query has a time filter; inline datasources always pass and joins are checked separately. */ private boolean queryHasTimeFilter(Query query) + { + DataSource dataSource = query.getDataSource(); + if 
(dataSource instanceof InlineDataSource) { + return true; + } + if (dataSource instanceof JoinDataSource) { + return joinDataSourceQueryHasTimeFilter(query); + } + return isIntervalNonEternity(findBaseDataSourceIntervals(query)); + } + + /** + * Checks whether a join query has a time filter by inspecting both sides of the join and any subqueries within them. + * When the left side is a {@link TableDataSource}, the top-level query intervals must not be eternity and the right + * side must have a time filter. Otherwise, the left and right sides of the join are each checked recursively. + */ + private boolean joinDataSourceQueryHasTimeFilter(Query query) + { + Preconditions.checkArgument(query.getDataSource() instanceof JoinDataSource); + JoinDataSource joinDataSource = (JoinDataSource) query.getDataSource(); + if (joinDataSource.getLeft() instanceof TableDataSource) { + // The left side is a concrete TableDataSource, so the top level query must itself carry a time filter. + // On top of that, the right side of the join must have a time filter too. + if (isIntervalNonEternity(query.getIntervals())) { + return dataSourceHasTimeFilter(joinDataSource.getRight()); + } else { + return false; + } + } + // Left side is not a table, so the top level interval is not checked: both sides of the join need a time filter + return joinDataSourceHasTimeFilter(joinDataSource); + } + /** Returns true only when both the left and the right side of the join have a time filter. */ + private boolean joinDataSourceHasTimeFilter(JoinDataSource joinDataSource) + { + return dataSourceHasTimeFilter(joinDataSource.getLeft()) && dataSourceHasTimeFilter(joinDataSource.getRight()); + } + /** Inline datasources pass; query and join datasources recurse; others need a non-eternity base query segment spec. */ + private boolean dataSourceHasTimeFilter(DataSource dataSource) + { + if (dataSource instanceof InlineDataSource) { + return true; + } + if (dataSource instanceof QueryDataSource) { + return queryHasTimeFilter(((QueryDataSource) dataSource).getQuery()); + } + if (dataSource instanceof JoinDataSource) { + return joinDataSourceHasTimeFilter((JoinDataSource) dataSource); + } + if (dataSource.getAnalysis().getBaseQuerySegmentSpec().isPresent()) 
{ + return isIntervalNonEternity(dataSource.getAnalysis() + .getBaseQuerySegmentSpec() + .map(QuerySegmentSpec::getIntervals) + .get()); + } + return false; + } + /** Returns true unless {@code intervals} equals {@link Intervals#ONLY_ETERNITY}. */ + private boolean isIntervalNonEternity(List intervals) + { + return !Intervals.ONLY_ETERNITY.equals(intervals); + } + @SuppressWarnings("unchecked") private QueryResponse execute( Query query, // Not final: may be reassigned with query ID added diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 26ee0685a13a..ce5d067b0b6b 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -13235,6 +13235,171 @@ public void testRequireTimeConditionSemiJoinNegative() assertTrue(exception.getMessage().contains("__time column")); } + /** Outer query filters on __time but the IN subquery does not; expects CannotBuildQueryException. */ @Test + public void testRequireTimeConditionSemiJoinNegative2() + { + msqIncompatible(); + Throwable exception = assertThrows(CannotBuildQueryException.class, () -> { + testQuery( + PLANNER_CONFIG_REQUIRE_TIME_CONDITION, + "SELECT COUNT(*) FROM druid.foo\n" + + "WHERE __time >= '2000-01-01' AND SUBSTRING(dim2, 1, 1) IN (\n" + + " SELECT SUBSTRING(dim1, 1, 1) FROM druid.foo\n" + + " WHERE dim1 <> ''\n" + + ")", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(), + ImmutableList.of() + ); + }); + assertTrue(exception.getMessage().contains("__time column")); + } + /** Both joined subqueries filter on __time; accepted although the expected outer GroupBy interval is eternity. */ + @Test + public void testRequireTimeConditionNestedJoinPositive() + { + msqIncompatible(); + skipVectorize(); + testQuery( + PLANNER_CONFIG_REQUIRE_TIME_CONDITION, + "SELECT distinct T1.dim1, T2.dim2 FROM\n" + + " (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n" + + " (SELECT dim2 from druid.foo WHERE dim1 <> '' AND __time >= '2000-02-01') AS T2\n" + + " WHERE T1.dim1=T2.dim2", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(GroupByQuery.builder() + 
.setDataSource(JoinDataSource.create( + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.utc( + DateTimes.of("2000-01-01").getMillis(), + JodaUtils.MAX_INSTANT + ))) + .columns("dim1") + .filters(not(equality("dim1", "", ColumnType.STRING))) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.utc( + DateTimes.of("2000-02-01").getMillis(), + JodaUtils.MAX_INSTANT + ))) + .columns("dim2") + .filters(not(equality("dim1", "", ColumnType.STRING))) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + "j0.", + "(\"dim1\" == \"j0.dim2\")", + JoinType.INNER, + null, + ExprMacroTable.nil(), + CalciteTests.createJoinableFactoryWrapper() + + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("j0.dim2", "d1") + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build()), + ImmutableList.of(new Object[]{"abc", "abc"}) + ); + } + /** Only the first joined subquery filters on __time; expects CannotBuildQueryException. */ + @Test + public void testRequireTimeConditionNestedJoinNegative() + { + msqIncompatible(); + skipVectorize(); + Throwable exception = assertThrows(CannotBuildQueryException.class, () -> { + testQuery( + PLANNER_CONFIG_REQUIRE_TIME_CONDITION, + "SELECT distinct T1.dim1, T2.dim2 FROM\n" + + " (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n" + + " (SELECT dim2 from druid.foo WHERE dim1 <> '') AS T2\n" + + " WHERE T1.dim1=T2.dim2", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(), + ImmutableList.of() + ); + }); + assertTrue(exception.getMessage().contains("__time column")); + } + /** The table subquery has no __time filter; joining it with a constant subquery is still rejected. */ + @Test + public void 
testRequireTimeConditionJoinWithInlineDatasourceNegative() + { + msqIncompatible(); + skipVectorize(); + Throwable exception = assertThrows(CannotBuildQueryException.class, () -> { + testQuery( + PLANNER_CONFIG_REQUIRE_TIME_CONDITION, + "SELECT distinct T1.dim1, T2.A FROM\n" + + " (SELECT dim1 from druid.foo WHERE dim1 <> '') AS T1,\n" + + " (SELECT * FROM (SELECT 2 + 2 AS A)) AS T2\n" + + " WHERE T1.dim1=T2.A", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(), + ImmutableList.of() + ); + }); + assertTrue(exception.getMessage().contains("__time column")); + } + /** The table subquery filters on __time and the other side is expected to plan as an InlineDataSource; accepted. */ + @Test + public void testRequireTimeConditionJoinWithInlineDatasourcePositive() + { + msqIncompatible(); + skipVectorize(); + testQuery( + PLANNER_CONFIG_REQUIRE_TIME_CONDITION, + "SELECT distinct T1.dim1, T2.A FROM\n" + + " (SELECT dim1 from druid.foo WHERE dim1 <> '' AND __time >= '2000-01-01') AS T1,\n" + + " (SELECT * FROM (SELECT 2 + 2 AS A)) AS T2\n" + + " WHERE T1.dim1=T2.A", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(GroupByQuery.builder() + .setDataSource(JoinDataSource.create( + new QueryDataSource(newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Intervals.utc( + DateTimes.of("2000-01-01").getMillis(), + JodaUtils.MAX_INSTANT + ))) + .columns("dim1") + .filters(not(equality("dim1", "", ColumnType.STRING))) + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build()), + InlineDataSource.fromIterable( + ImmutableList.of(new Object[]{4L}), + RowSignature.builder().add("A", ColumnType.LONG).build() + ), + "j0.", + "1", + JoinType.INNER, + null, + ExprMacroTable.nil(), + CalciteTests.createJoinableFactoryWrapper() + )) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions(dimensions( + new DefaultDimensionSpec("dim1", "d0"), + new DefaultDimensionSpec("j0.A", "d1", ColumnType.LONG) + )) + .setDimFilter(equality("dim1", 4, 
ColumnType.LONG)) + .setContext(QUERY_CONTEXT_DEFAULT) + .build()), + ImmutableList.of() + ); + } + @Test public void testFilterFloatDimension() {