Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature](Nereids): support infer join when comparing mv #28988

Merged
merged 2 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.StructInfoNode;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.exploration.mv.ComparisonResult;
import org.apache.doris.nereids.rules.exploration.mv.LogicalCompatibilityContext;
import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -44,20 +42,14 @@
import org.apache.doris.nereids.util.PlanUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -272,11 +264,11 @@ private void makeFilterConflictRules(JoinEdge joinEdge) {
filterEdges.forEach(e -> {
if (LongBitmap.isSubset(e.getReferenceNodes(), leftSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_LEFT.contains(joinEdge.getJoinType())) {
e.addRejectEdge(joinEdge);
e.addLeftRejectEdge(joinEdge);
}
if (LongBitmap.isSubset(e.getReferenceNodes(), rightSubNodes)
&& !PushDownFilterThroughJoin.COULD_PUSH_THROUGH_RIGHT.contains(joinEdge.getJoinType())) {
e.addRejectEdge(joinEdge);
e.addRightRejectEdge(joinEdge);
}
});
}
Expand All @@ -293,23 +285,23 @@ private void makeJoinConflictRules(JoinEdge edgeB) {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addLeftRejectEdge(edgeB);
}
if (!JoinType.isLAssoc(childA.getJoinType(), edgeB.getJoinType())) {
leftRequired = LongBitmap.newBitmapUnion(leftRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addLeftRejectEdge(edgeB);
}
}

for (int i = rightSubTreeEdges.nextSetBit(0); i >= 0; i = rightSubTreeEdges.nextSetBit(i + 1)) {
JoinEdge childA = joinEdges.get(i);
if (!JoinType.isAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getRightSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addRightRejectEdge(edgeB);
}
if (!JoinType.isRAssoc(edgeB.getJoinType(), childA.getJoinType())) {
rightRequired = LongBitmap.newBitmapUnion(rightRequired, childA.getLeftSubNodes(joinEdges));
childA.addRejectEdge(edgeB);
childA.addRightRejectEdge(edgeB);
}
}
edgeB.setLeftExtendedNodes(leftRequired);
Expand Down Expand Up @@ -597,157 +589,6 @@ public int edgeSize() {
return joinEdges.size() + filterEdges.size();
}

/**
 * Compares this hyper graph (the query side) with a view hyper graph.
 *
 * @param viewHG the view hyper graph to compare against
 * @param ctx supplies the query-to-view node id mapping and the query-to-view edge expression mapping
 * @return the accumulated comparison result, or {@code ComparisonResult.INVALID} as soon as a step fails
 */
public ComparisonResult isLogicCompatible(HyperGraph viewHG, LogicalCompatibilityContext ctx) {
    // Step 1: pair query edges with view edges through the node id mapping.
    Map<Edge, Edge> edgeMapping = constructMapWithNode(viewHG, ctx.getQueryToViewNodeIDMapping());
    Set<Edge> mappedQueryEdges = edgeMapping.keySet();

    // Step 2: compare every pair by expression and keep the residual expressions.
    ComparisonResult pairedResult
            = compareEdgesWithExpr(edgeMapping, ctx.getQueryToViewEdgeExpressionMapping());
    if (pairedResult.isInvalid()) {
        return ComparisonResult.INVALID;
    }
    ComparisonResult.Builder resultBuilder = new ComparisonResult.Builder();
    resultBuilder.addComparisonResult(pairedResult);

    // Step 3: pulling up a join edge of the view makes no sense, so every view join edge must be paired.
    if (!edgeMapping.values().containsAll(viewHG.joinEdges)) {
        return ComparisonResult.INVALID;
    }

    // Step 4: unpaired (orphan) edges hand over their expressions, unless one of them cannot be pulled up.
    Set<Edge> queryJoinEdges = Sets.newHashSet(joinEdges);
    List<Expression> orphanQueryJoinExprs
            = processOrphanEdges(Sets.difference(queryJoinEdges, mappedQueryEdges));
    if (orphanQueryJoinExprs == null) {
        return ComparisonResult.INVALID;
    }
    resultBuilder.addQueryExpressions(orphanQueryJoinExprs);

    Set<Edge> queryFilterEdges = Sets.newHashSet(filterEdges);
    List<Expression> orphanQueryFilterExprs
            = processOrphanEdges(Sets.difference(queryFilterEdges, mappedQueryEdges));
    if (orphanQueryFilterExprs == null) {
        return ComparisonResult.INVALID;
    }
    resultBuilder.addQueryExpressions(orphanQueryFilterExprs);

    Set<Edge> viewFilterEdges = Sets.newHashSet(viewHG.filterEdges);
    Set<Edge> mappedViewEdges = Sets.newHashSet(edgeMapping.values());
    List<Expression> orphanViewFilterExprs
            = processOrphanEdges(Sets.difference(viewFilterEdges, mappedViewEdges));
    if (orphanViewFilterExprs == null) {
        return ComparisonResult.INVALID;
    }
    resultBuilder.addViewExpressions(orphanViewFilterExprs);

    return resultBuilder.build();
}

/**
 * Gathers the expressions of the given edges.
 *
 * @param edges the edges to collect expressions from
 * @return all expressions of the edges, or {@code null} if any edge cannot be pulled up
 */
private List<Expression> processOrphanEdges(Set<Edge> edges) {
    // A single edge that cannot be pulled up invalidates the whole set.
    boolean everyEdgePullable = edges.stream().allMatch(candidate -> candidate.canPullUp());
    if (!everyEdgePullable) {
        return null;
    }
    List<Expression> collected = new ArrayList<>();
    edges.forEach(candidate -> collected.addAll(candidate.getExpressions()));
    return collected;
}

/**
 * Builds the query-edge to view-edge mapping: each join/filter edge of this graph is paired with
 * the first edge of the same kind in {@code viewHG} that matches it under {@code nodeMap}.
 * Edges without a match are left out of the result.
 *
 * @param viewHG the view hyper graph supplying candidate edges
 * @param nodeMap node id mapping used when comparing edges
 * @return immutable mapping from matched edges of this graph to edges of {@code viewHG}
 */
private Map<Edge, Edge> constructMapWithNode(HyperGraph viewHG, Map<Integer, Integer> nodeMap) {
    // TODO use hash map to reduce loop
    ImmutableMap.Builder<Edge, Edge> edgePairs = ImmutableMap.builder();
    for (JoinEdge queryEdge : joinEdges) {
        for (JoinEdge viewEdge : viewHG.joinEdges) {
            if (compareEdgeWithNode(queryEdge, viewEdge, nodeMap)) {
                edgePairs.put(queryEdge, viewEdge);
                // only the first matching view edge is used
                break;
            }
        }
    }
    for (FilterEdge queryEdge : filterEdges) {
        for (FilterEdge viewEdge : viewHG.filterEdges) {
            if (compareEdgeWithNode(queryEdge, viewEdge, nodeMap)) {
                edgePairs.put(queryEdge, viewEdge);
                // only the first matching view edge is used
                break;
            }
        }
    }
    return edgePairs.build();
}

/**
 * Dispatches the node-level comparison by edge kind. Two filter edges or two join edges are
 * compared with the matching routine; any other combination never matches.
 *
 * @param t first edge
 * @param o second edge
 * @param nodeMap node id mapping used for the comparison
 * @return true if both edges are of the same kind and match under {@code nodeMap}
 */
private boolean compareEdgeWithNode(Edge t, Edge o, Map<Integer, Integer> nodeMap) {
    boolean bothFilterEdges = t instanceof FilterEdge && o instanceof FilterEdge;
    if (bothFilterEdges) {
        return compareEdgeWithFilter((FilterEdge) t, (FilterEdge) o, nodeMap);
    }
    boolean bothJoinEdges = t instanceof JoinEdge && o instanceof JoinEdge;
    if (bothJoinEdges) {
        return compareJoinEdge((JoinEdge) t, (JoinEdge) o, nodeMap);
    }
    return false;
}

/**
 * Compares two filter edges by their reference nodes.
 *
 * @param t first filter edge, whose reference nodes are translated through {@code nodeMap}
 * @param o second filter edge, whose reference nodes are the expected result
 * @param nodeMap node id mapping
 * @return the result of {@code compareNodeMap} on the two reference-node bitmaps
 */
private boolean compareEdgeWithFilter(FilterEdge t, FilterEdge o, Map<Integer, Integer> nodeMap) {
    return compareNodeMap(t.getReferenceNodes(), o.getReferenceNodes(), nodeMap);
}

/**
 * Compares two join edges by join type and by their left/right extended nodes.
 * The edges match when the sides line up directly, or, if the join type of {@code t} swapped
 * equals the join type of {@code o}, when the sides line up crosswise.
 *
 * @param t first join edge, whose nodes are translated through {@code nodeMap}
 * @param o second join edge
 * @param nodeMap node id mapping
 * @return true if the two join edges match
 */
private boolean compareJoinEdge(JoinEdge t, JoinEdge o, Map<Integer, Integer> nodeMap) {
    boolean sameType = t.getJoinType().equals(o.getJoinType());
    boolean swappedType = t.getJoinType().swap().equals(o.getJoinType());
    if (!(sameType || swappedType)) {
        return false;
    }

    long firstLeft = t.getLeftExtendedNodes();
    long firstRight = t.getRightExtendedNodes();
    long secondLeft = o.getLeftExtendedNodes();
    long secondRight = o.getRightExtendedNodes();

    // crosswise match is only allowed when the swapped join type fits
    if (swappedType
            && compareNodeMap(firstRight, secondLeft, nodeMap)
            && compareNodeMap(firstLeft, secondRight, nodeMap)) {
        return true;
    }
    return compareNodeMap(firstLeft, secondLeft, nodeMap)
            && compareNodeMap(firstRight, secondRight, nodeMap);
}

/**
 * Translates every node of {@code bitmap1} through {@code nodeIDMap} and checks that the
 * translated bitmap is exactly {@code bitmap2}.
 *
 * @param bitmap1 bitmap of node ids to translate
 * @param bitmap2 bitmap the translated node ids must be equal to
 * @param nodeIDMap node id mapping applied to the nodes of {@code bitmap1}
 * @return true only if every node of {@code bitmap1} has a mapping and the mapped nodes equal
 *         {@code bitmap2}
 */
private boolean compareNodeMap(long bitmap1, long bitmap2, Map<Integer, Integer> nodeIDMap) {
    long mappedBitmap = LongBitmap.newBitmap();
    for (int nodeId : LongBitmap.getIterator(bitmap1)) {
        Integer mappedId = nodeIDMap.get(nodeId);
        if (mappedId == null) {
            // A node without a mapping has no counterpart. Defaulting it to node 0 (as
            // getOrDefault(i, 0) did) could make unrelated edges compare equal.
            return false;
        }
        mappedBitmap = LongBitmap.set(mappedBitmap, mappedId);
    }
    return bitmap2 == mappedBitmap;
}

/**
 * Compares every (query edge, view edge) pair by expression and merges the per-pair results.
 *
 * @param queryToViewedgeMap mapping from query edges to their paired view edges
 * @param queryToView mapping from query expressions to view expressions
 * @return the merged result, or {@code ComparisonResult.INVALID} if any pair is invalid
 */
private ComparisonResult compareEdgesWithExpr(Map<Edge, Edge> queryToViewedgeMap,
        Map<Expression, Expression> queryToView) {
    ComparisonResult.Builder merged = new ComparisonResult.Builder();
    for (Map.Entry<Edge, Edge> pair : queryToViewedgeMap.entrySet()) {
        Edge queryEdge = pair.getKey();
        Edge viewEdge = pair.getValue();
        ComparisonResult pairResult = compareEdgeWithExpr(queryEdge, viewEdge, queryToView);
        if (!pairResult.isInvalid()) {
            merged.addComparisonResult(pairResult);
            continue;
        }
        // one invalid pair invalidates the whole comparison
        return ComparisonResult.INVALID;
    }
    return merged.build();
}

/**
 * Compares one query edge with one view edge by expression.
 * A query expression matches when its mapped expression is in the view edge's expression set;
 * everything else on either side is residual.
 *
 * @param query the query edge
 * @param view the view edge
 * @param queryToView mapping from query expressions to view expressions
 * @return the residual query and view expressions, or {@code ComparisonResult.INVALID} when a side
 *         has residual expressions but its edge cannot be pulled up
 */
private ComparisonResult compareEdgeWithExpr(Edge query, Edge view, Map<Expression, Expression> queryToView) {
    Set<? extends Expression> viewExprs = view.getExpressionSet();

    Set<Expression> matchedViewExprs = new HashSet<>();
    List<Expression> unmatchedQueryExprs = new ArrayList<>();
    for (Expression queryExpr : query.getExpressionSet()) {
        boolean matched = queryToView.containsKey(queryExpr)
                && viewExprs.contains(queryToView.get(queryExpr));
        if (!matched) {
            unmatchedQueryExprs.add(queryExpr);
            continue;
        }
        matchedViewExprs.add(queryToView.get(queryExpr));
    }
    List<Expression> unmatchedViewExprs = ImmutableList.copyOf(Sets.difference(viewExprs, matchedViewExprs));

    // residual expressions are only acceptable on an edge that can be pulled up
    boolean viewStuck = !unmatchedViewExprs.isEmpty() && !view.canPullUp();
    boolean queryStuck = !unmatchedQueryExprs.isEmpty() && !query.canPullUp();
    if (viewStuck || queryStuck) {
        return ComparisonResult.INVALID;
    }
    return new ComparisonResult(unmatchedQueryExprs, unmatchedViewExprs);
}

/**
* For the given hyperGraph, make a textual representation in the form
* of a dotty graph. You can save this to a file and then use Graphviz
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.collect.ImmutableSet;

import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -53,7 +54,8 @@ public abstract class Edge {
// records all sub nodes beneath this operator. It's the T function in the paper
private final long subTreeNodes;

private long rejectNodes = 0;
private final Set<JoinEdge> leftRejectEdges;
private final Set<JoinEdge> rightRejectEdges;

/**
* Create simple edge.
Expand All @@ -69,14 +71,36 @@ public abstract class Edge {
this.leftExtendedNodes = leftRequiredNodes;
this.rightExtendedNodes = rightRequiredNodes;
this.subTreeNodes = subTreeNodes;
this.leftRejectEdges = new HashSet<>();
this.rightRejectEdges = new HashSet<>();
}

/**
 * Tells whether this edge is simple, i.e. both its left and right extended node sets contain
 * exactly one node.
 *
 * @return true if each side has cardinality 1
 */
public boolean isSimple() {
    if (LongBitmap.getCardinality(leftExtendedNodes) != 1) {
        return false;
    }
    return LongBitmap.getCardinality(rightExtendedNodes) == 1;
}

public void addRejectEdge(Edge edge) {
rejectNodes = LongBitmap.newBitmapUnion(edge.getReferenceNodes(), rejectNodes);
public void addLeftRejectEdge(JoinEdge edge) {
leftRejectEdges.add(edge);
}

/**
 * Adds one join edge to the set of right reject edges.
 *
 * @param edge the join edge to add
 */
public void addRightRejectEdge(JoinEdge edge) {
    this.rightRejectEdges.add(edge);
}

/**
 * Adds all given join edges to the set of left reject edges.
 *
 * @param edge the join edges to add
 */
public void addLeftRejectEdges(Set<JoinEdge> edge) {
    this.leftRejectEdges.addAll(edge);
}

/**
 * Adds all given join edges to the set of right reject edges.
 *
 * @param edge the join edges to add
 */
public void addRightRejectEdges(Set<JoinEdge> edge) {
    this.rightRejectEdges.addAll(edge);
}

/**
 * Returns the left reject edges.
 *
 * @return an immutable copy of the left reject edge set
 */
public Set<JoinEdge> getLeftRejectEdge() {
    return ImmutableSet.copyOf(this.leftRejectEdges);
}

/**
 * Returns the right reject edges.
 *
 * @return an immutable copy of the right reject edge set
 */
public Set<JoinEdge> getRightRejectEdge() {
    return ImmutableSet.copyOf(this.rightRejectEdges);
}

public void addLeftExtendNode(long left) {
Expand Down Expand Up @@ -183,16 +207,6 @@ public Set<? extends Expression> getExpressionSet() {
return ImmutableSet.copyOf(getExpressions());
}

/**
 * Tells whether this edge can be pulled up.
 *
 * @return true if the edge has no reject nodes and, when it is a join edge, its join type is
 *         an inner join
 */
public boolean canPullUp() {
    // any reject node blocks pulling up
    if (rejectNodes != 0) {
        return false;
    }
    // among join edges only inner joins qualify; other edges have no further restriction
    if (this instanceof JoinEdge) {
        return ((JoinEdge) this).getJoinType().isInnerJoin();
    }
    return true;
}

/**
 * Returns the reject nodes of this edge.
 *
 * @return the reject node bitmap; 0 means there are none
 */
public long getRejectNodes() {
    return this.rejectNodes;
}

/**
 * Returns a single expression of this edge.
 *
 * @param i position in the list returned by {@code getExpressions()}
 * @return the expression at position {@code i}
 */
public Expression getExpression(int i) {
    return this.getExpressions().get(i);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ public JoinEdge(LogicalJoin<? extends Plan, ? extends Plan> join, int index,
this.join = join;
}

/**
 * Creates the swapped counterpart of this edge: the join is swapped and the left/right child
 * edges and required nodes trade places. The index and sub tree nodes are kept, and the
 * left and right reject edges are copied over unchanged.
 *
 * @return a new swapped join edge
 */
public JoinEdge swap() {
    JoinEdge mirrored = new JoinEdge(
            join.swap(), getIndex(),
            getRightChildEdges(), getLeftChildEdges(),
            getSubTreeNodes(),
            getRightRequiredNodes(), getLeftRequiredNodes());
    // reject edges are carried over side by side, not exchanged
    mirrored.addLeftRejectEdges(getLeftRejectEdge());
    mirrored.addRightRejectEdges(getRightRejectEdge());
    return mirrored;
}

/**
 * Returns the join type of the join this edge wraps.
 *
 * @return the join type reported by the underlying join
 */
public JoinType getJoinType() {
    return this.join.getJoinType();
}
Expand Down
Loading
Loading