Skip to content

Commit

Permalink
[enhancement](nereids) Speedup partition pruner (#38191) (#38405)
Browse files Browse the repository at this point in the history
1. fast return when partition predicate is true/false/null
2. fast compute table's hash code
3. fast merge two ranges when equals
  • Loading branch information
924060929 authored Jul 26, 2024
1 parent 2f6b2db commit ee65195
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 20 deletions.
10 changes: 6 additions & 4 deletions fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.base.Suppliers;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.commons.collections.CollectionUtils;
Expand All @@ -71,6 +72,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -272,7 +274,7 @@ public void setSelectivity() {
protected Function fn;

// Cached value of IsConstant(), set during analyze() and valid if isAnalyzed_ is true.
private boolean isConstant;
private Supplier<Boolean> isConstant = Suppliers.memoize(() -> false);

// Flag to indicate whether to wrap this expr's toSql() in parenthesis. Set by parser.
// Needed for properly capturing expr precedences in the SQL string.
Expand Down Expand Up @@ -455,7 +457,7 @@ protected void analysisDone() {
Preconditions.checkState(!isAnalyzed);
// We need to compute the const-ness as the last step, since analysis may change
// the result, e.g. by resolving function.
isConstant = isConstantImpl();
isConstant = Suppliers.memoize(this::isConstantImpl);
isAnalyzed = true;
}

Expand Down Expand Up @@ -1348,7 +1350,7 @@ public boolean isLiteral() {
*/
public final boolean isConstant() {
if (isAnalyzed) {
return isConstant;
return isConstant.get();
}
return isConstantImpl();
}
Expand Down Expand Up @@ -2567,7 +2569,7 @@ public boolean refToCountStar() {
// In this case, agg output must be materialized whether outer query block required or not.
if (f.getFunctionName().getFunction().equals("count")) {
for (Expr expr : funcExpr.children) {
if (expr.isConstant && !(expr instanceof LiteralExpr)) {
if (expr.isConstant() && !(expr instanceof LiteralExpr)) {
return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1935,9 +1935,7 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
return Objects.hash(state, indexIdToMeta, indexNameToId, keysType, partitionInfo, idToPartition,
nameToPartition, defaultDistributionInfo, tempPartitions, bfColumns, bfFpp, colocateGroup,
hasSequenceCol, sequenceType, indexes, baseIndexId, tableProperty);
return (int) baseIndexId;
}

public Column getBaseColumn(String columnName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
* Generic tree structure. Only concrete subclasses of this can be instantiated.
*/
public class TreeNode<NodeType extends TreeNode<NodeType>> {
protected ArrayList<NodeType> children = Lists.newArrayList();
protected ArrayList<NodeType> children = Lists.newArrayListWithCapacity(2);

public NodeType getChild(int i) {
return hasChild(i) ? children.get(i) : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ private EvaluateRangeResult mergeRanges(
Map<Slot, ColumnRange> leftRanges = left.columnRanges;
Map<Slot, ColumnRange> rightRanges = right.columnRanges;

if (leftRanges.equals(rightRanges)) {
return new EvaluateRangeResult(originResult, leftRanges, ImmutableList.of(left, right));
}

Set<Slot> slots = ImmutableSet.<Slot>builder()
.addAll(leftRanges.keySet())
.addAll(rightRanges.keySet())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
Expand Down Expand Up @@ -125,14 +126,19 @@ public static List<Long> prune(List<Slot> partitionSlots, Expression partitionPr
"partitionPruningExpandThreshold",
10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold);

partitionPredicate = OrToIn.INSTANCE.rewriteTree(
partitionPredicate, new ExpressionRewriteContext(cascadesContext));
if (BooleanLiteral.TRUE.equals(partitionPredicate)) {
return Utils.fastToImmutableList(idToPartitions.keySet());
} else if (Boolean.FALSE.equals(partitionPredicate) || partitionPredicate.isNullLiteral()) {
return ImmutableList.of();
}

List<OnePartitionEvaluator> evaluators = Lists.newArrayListWithCapacity(idToPartitions.size());
for (Entry<Long, PartitionItem> kv : idToPartitions.entrySet()) {
evaluators.add(toPartitionEvaluator(
kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, expandThreshold));
}

partitionPredicate = OrToIn.INSTANCE.rewriteTree(
partitionPredicate, new ExpressionRewriteContext(cascadesContext));
PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate);
//TODO: we keep default partition because it's too hard to prune it, we return false in canPrune().
return partitionPruner.prune();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,25 @@

package org.apache.doris.nereids.trees.expressions.literal;

import org.apache.doris.analysis.LiteralExpr;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateV2Type;

import com.google.common.base.Suppliers;

import java.time.LocalDateTime;
import java.util.function.Supplier;

/**
* date v2 literal for nereids
*/
public class DateV2Literal extends DateLiteral {
private final Supplier<org.apache.doris.analysis.DateLiteral> legacyLiteral = Suppliers.memoize(() ->
new org.apache.doris.analysis.DateLiteral(year, month, day, Type.DATEV2)
);

public DateV2Literal(String s) throws AnalysisException {
super(DateV2Type.INSTANCE, s);
Expand All @@ -41,8 +46,8 @@ public DateV2Literal(long year, long month, long day) {
}

@Override
public LiteralExpr toLegacyLiteral() {
return new org.apache.doris.analysis.DateLiteral(year, month, day, Type.DATEV2);
public org.apache.doris.analysis.DateLiteral toLegacyLiteral() {
return legacyLiteral.get();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import io.netty.util.concurrent.FastThreadLocal;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.json.JSONObject;
Expand All @@ -90,7 +91,7 @@
// Use `volatile` to make the reference change atomic.
public class ConnectContext {
private static final Logger LOG = LogManager.getLogger(ConnectContext.class);
protected static ThreadLocal<ConnectContext> threadLocalInfo = new ThreadLocal<>();
protected static FastThreadLocal<ConnectContext> threadLocalInfo = new FastThreadLocal<>();

private static final String SSL_PROTOCOL = "TLS";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3048,20 +3048,21 @@ public int getNthOptimizedPlan() {
public Set<String> getDisableNereidsRuleNames() {
String checkPrivilege = RuleType.CHECK_PRIVILEGES.name();
String checkRowPolicy = RuleType.CHECK_ROW_POLICY.name();
return Arrays.stream(disableNereidsRules.split(",[\\s]*"))
.map(rule -> rule.toUpperCase(Locale.ROOT))
.filter(rule -> !StringUtils.equalsIgnoreCase(rule, checkPrivilege)
return Arrays.stream(disableNereidsRules.split(","))
.map(rule -> rule.trim().toUpperCase(Locale.ROOT))
.filter(rule -> !rule.isEmpty()
&& !StringUtils.equalsIgnoreCase(rule, checkPrivilege)
&& !StringUtils.equalsIgnoreCase(rule, checkRowPolicy))
.collect(ImmutableSet.toImmutableSet());
}

public BitSet getDisableNereidsRules() {
BitSet bitSet = new BitSet();
for (String ruleName : disableNereidsRules.split(",[\\s]*")) {
for (String ruleName : disableNereidsRules.split(",")) {
ruleName = ruleName.trim().toUpperCase(Locale.ROOT);
if (ruleName.isEmpty()) {
continue;
}
ruleName = ruleName.toUpperCase(Locale.ROOT);
RuleType ruleType = RuleType.valueOf(ruleName);
if (ruleType == RuleType.CHECK_PRIVILEGES || ruleType == RuleType.CHECK_ROW_POLICY) {
continue;
Expand Down

0 comments on commit ee65195

Please sign in to comment.