Skip to content

Commit

Permalink
[fix](mtmv) Fix result wrong when query rewrite by mv if query contai…
Browse files Browse the repository at this point in the history
…ns null_unsafe equals expression (apache#39629)

Fix result wrong when query rewrite by mv if query contains null_unsafe
equals expression and the expression both side is slot
table orders data is as following:

    (null, 1, 'o', 9.5, '2023-12-08', 'a', 'b', 1, 'yy'),
    (1, null, 'o', 10.5, '2023-12-08', 'a', 'b', 1, 'yy'),
    (2, 1, null, 11.5, '2023-12-09', 'a', 'b', 1, 'yy'),
    (3, 1, 'o', null, '2023-12-10', 'a', 'b', 1, 'yy'),
    (3, 1, 'o', 33.5, null, 'a', 'b', 1, 'yy'),
    (4, 2, 'o', 43.2, '2023-12-11', null,'d',2, 'mm'),
    (5, 2, 'o', 56.2, '2023-12-12', 'c',null, 2, 'mi'),
    (5, 2, 'o', 1.2, '2023-12-12', 'c','d', null, 'mi');  

such as mv def is 

select count(*), o_orderstatus, o_comment
            from orders
            group by
            o_orderstatus, o_comment;

query is as following:

           select count(*), o_orderstatus, o_comment
            from orders
            where o_orderstatus = o_orderstatus
            group by
            o_orderstatus, o_comment;

after rewrite by materialized view, the result is wrong as following,
the row contains null should not appear

+----------+---------------+-----------+
| count(*) | o_orderstatus | o_comment |
+----------+---------------+-----------+
|        1 | NULL          | yy        |
|        1 | o             | mm        |
|        2 | o             | mi        |
|        4 | o             | yy        |
+----------+---------------+-----------+
  • Loading branch information
seawinde committed Aug 28, 2024
1 parent 42f30d4 commit fc02d8e
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* EquivalenceClass, this is used for equality propagation when predicate compensation
Expand All @@ -40,14 +37,19 @@ public class EquivalenceClass {
* a: [a, b],
* b: [a, b]
* }
* or column a = a,
* this would be
* {
* a: [a, a]
* }
*/
private Map<SlotReference, Set<SlotReference>> equivalenceSlotMap = new LinkedHashMap<>();
private List<Set<SlotReference>> equivalenceSlotList;
private Map<SlotReference, List<SlotReference>> equivalenceSlotMap = new LinkedHashMap<>();
private List<List<SlotReference>> equivalenceSlotList;

public EquivalenceClass() {
}

public EquivalenceClass(Map<SlotReference, Set<SlotReference>> equivalenceSlotMap) {
public EquivalenceClass(Map<SlotReference, List<SlotReference>> equivalenceSlotMap) {
this.equivalenceSlotMap = equivalenceSlotMap;
}

Expand All @@ -56,13 +58,13 @@ public EquivalenceClass(Map<SlotReference, Set<SlotReference>> equivalenceSlotMa
*/
public void addEquivalenceClass(SlotReference leftSlot, SlotReference rightSlot) {

Set<SlotReference> leftSlotSet = equivalenceSlotMap.get(leftSlot);
Set<SlotReference> rightSlotSet = equivalenceSlotMap.get(rightSlot);
List<SlotReference> leftSlotSet = equivalenceSlotMap.get(leftSlot);
List<SlotReference> rightSlotSet = equivalenceSlotMap.get(rightSlot);
if (leftSlotSet != null && rightSlotSet != null) {
// Both present, we need to merge
if (leftSlotSet.size() < rightSlotSet.size()) {
// We swap them to merge
Set<SlotReference> tmp = rightSlotSet;
List<SlotReference> tmp = rightSlotSet;
rightSlotSet = leftSlotSet;
leftSlotSet = tmp;
}
Expand All @@ -80,15 +82,15 @@ public void addEquivalenceClass(SlotReference leftSlot, SlotReference rightSlot)
equivalenceSlotMap.put(leftSlot, rightSlotSet);
} else {
// None are present, add to same equivalence class
Set<SlotReference> equivalenceClass = new LinkedHashSet<>();
List<SlotReference> equivalenceClass = new ArrayList<>();
equivalenceClass.add(leftSlot);
equivalenceClass.add(rightSlot);
equivalenceSlotMap.put(leftSlot, equivalenceClass);
equivalenceSlotMap.put(rightSlot, equivalenceClass);
}
}

public Map<SlotReference, Set<SlotReference>> getEquivalenceSlotMap() {
public Map<SlotReference, List<SlotReference>> getEquivalenceSlotMap() {
return equivalenceSlotMap;
}

Expand All @@ -101,15 +103,15 @@ public boolean isEmpty() {
*/
public EquivalenceClass permute(Map<SlotReference, SlotReference> mapping) {

Map<SlotReference, Set<SlotReference>> permutedEquivalenceSlotMap = new HashMap<>();
for (Map.Entry<SlotReference, Set<SlotReference>> slotReferenceSetEntry : equivalenceSlotMap.entrySet()) {
Map<SlotReference, List<SlotReference>> permutedEquivalenceSlotMap = new HashMap<>();
for (Map.Entry<SlotReference, List<SlotReference>> slotReferenceSetEntry : equivalenceSlotMap.entrySet()) {
SlotReference mappedSlotReferenceKey = mapping.get(slotReferenceSetEntry.getKey());
if (mappedSlotReferenceKey == null) {
// can not permute then need to return null
return null;
}
Set<SlotReference> equivalenceValueSet = slotReferenceSetEntry.getValue();
final Set<SlotReference> mappedSlotReferenceSet = new HashSet<>();
List<SlotReference> equivalenceValueSet = slotReferenceSetEntry.getValue();
final List<SlotReference> mappedSlotReferenceSet = new ArrayList<>();
for (SlotReference target : equivalenceValueSet) {
SlotReference mappedSlotReferenceValue = mapping.get(target);
if (mappedSlotReferenceValue == null) {
Expand All @@ -123,15 +125,14 @@ public EquivalenceClass permute(Map<SlotReference, SlotReference> mapping) {
}

/**
* Return the list of equivalence set, remove duplicate
* Return the list of equivalence list, remove duplicate
*/
public List<Set<SlotReference>> getEquivalenceSetList() {

public List<List<SlotReference>> getEquivalenceSetList() {
if (equivalenceSlotList != null) {
return equivalenceSlotList;
}
List<Set<SlotReference>> equivalenceSets = new ArrayList<>();
Set<Set<SlotReference>> visited = new HashSet<>();
List<List<SlotReference>> equivalenceSets = new ArrayList<>();
List<List<SlotReference>> visited = new ArrayList<>();
equivalenceSlotMap.values().forEach(slotSet -> {
if (!visited.contains(slotSet)) {
equivalenceSets.add(slotSet);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.doris.nereids.rules.exploration.mv;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.EquivalenceClassSetMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.EquivalenceClassMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
Expand All @@ -33,6 +33,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
Expand Down Expand Up @@ -98,15 +99,15 @@ public static Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
if (queryEquivalenceClass.isEmpty() && !viewEquivalenceClass.isEmpty()) {
return null;
}
EquivalenceClassSetMapping queryToViewEquivalenceMapping =
EquivalenceClassSetMapping.generate(queryEquivalenceClass, viewEquivalenceClassQueryBased);
EquivalenceClassMapping queryToViewEquivalenceMapping =
EquivalenceClassMapping.generate(queryEquivalenceClass, viewEquivalenceClassQueryBased);
// can not map all target equivalence class, can not compensate
if (queryToViewEquivalenceMapping.getEquivalenceClassSetMap().size()
< viewEquivalenceClass.getEquivalenceSetList().size()) {
return null;
}
// do equal compensate
Set<Set<SlotReference>> mappedQueryEquivalenceSet =
Set<List<SlotReference>> mappedQueryEquivalenceSet =
queryToViewEquivalenceMapping.getEquivalenceClassSetMap().keySet();
queryEquivalenceClass.getEquivalenceSetList().forEach(
queryEquivalenceSet -> {
Expand All @@ -120,9 +121,9 @@ public static Set<Expression> compensateEquivalence(StructInfo queryStructInfo,
}
} else {
// compensate the equivalence both in query and view, but query has more equivalence
Set<SlotReference> viewEquivalenceSet =
List<SlotReference> viewEquivalenceSet =
queryToViewEquivalenceMapping.getEquivalenceClassSetMap().get(queryEquivalenceSet);
Set<SlotReference> copiedQueryEquivalenceSet = new HashSet<>(queryEquivalenceSet);
List<SlotReference> copiedQueryEquivalenceSet = new ArrayList<>(queryEquivalenceSet);
copiedQueryEquivalenceSet.removeAll(viewEquivalenceSet);
SlotReference first = viewEquivalenceSet.iterator().next();
for (SlotReference slotReference : copiedQueryEquivalenceSet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -30,39 +31,41 @@
* This will extract the equivalence class set in EquivalenceClass and mapping set in
* two different EquivalenceClass.
*/
public class EquivalenceClassSetMapping extends Mapping {
public class EquivalenceClassMapping extends Mapping {

private final Map<Set<SlotReference>, Set<SlotReference>> equivalenceClassSetMap;
private final Map<List<SlotReference>, List<SlotReference>> equivalenceClassSetMap;

public EquivalenceClassSetMapping(Map<Set<SlotReference>,
Set<SlotReference>> equivalenceClassSetMap) {
public EquivalenceClassMapping(Map<List<SlotReference>,
List<SlotReference>> equivalenceClassSetMap) {
this.equivalenceClassSetMap = equivalenceClassSetMap;
}

public static EquivalenceClassSetMapping of(Map<Set<SlotReference>, Set<SlotReference>> equivalenceClassSetMap) {
return new EquivalenceClassSetMapping(equivalenceClassSetMap);
public static EquivalenceClassMapping of(Map<List<SlotReference>, List<SlotReference>> equivalenceClassSetMap) {
return new EquivalenceClassMapping(equivalenceClassSetMap);
}

/**
* Generate source equivalence set map to target equivalence set
*/
public static EquivalenceClassSetMapping generate(EquivalenceClass source, EquivalenceClass target) {
public static EquivalenceClassMapping generate(EquivalenceClass source, EquivalenceClass target) {

Map<Set<SlotReference>, Set<SlotReference>> equivalenceClassSetMap = new HashMap<>();
List<Set<SlotReference>> sourceSets = source.getEquivalenceSetList();
List<Set<SlotReference>> targetSets = target.getEquivalenceSetList();
Map<List<SlotReference>, List<SlotReference>> equivalenceClassSetMap = new HashMap<>();
List<List<SlotReference>> sourceSets = source.getEquivalenceSetList();
List<List<SlotReference>> targetSets = target.getEquivalenceSetList();

for (Set<SlotReference> sourceSet : sourceSets) {
for (Set<SlotReference> targetSet : targetSets) {
for (List<SlotReference> sourceList : sourceSets) {
Set<SlotReference> sourceSet = new HashSet<>(sourceList);
for (List<SlotReference> targetList : targetSets) {
Set<SlotReference> targetSet = new HashSet<>(targetList);
if (sourceSet.containsAll(targetSet)) {
equivalenceClassSetMap.put(sourceSet, targetSet);
equivalenceClassSetMap.put(sourceList, targetList);
}
}
}
return EquivalenceClassSetMapping.of(equivalenceClassSetMap);
return EquivalenceClassMapping.of(equivalenceClassSetMap);
}

public Map<Set<SlotReference>, Set<SlotReference>> getEquivalenceClassSetMap() {
public Map<List<SlotReference>, List<SlotReference>> getEquivalenceClassSetMap() {
return equivalenceClassSetMap;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !query1_0_before --
1 o mm
2 o mi
4 o yy

-- !query1_0_after --
1 o mm
2 o mi
4 o yy

Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package mv.unsafe_equals
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("null_unsafe_equals") {
String db = context.config.getDbNameByFile(context.file)
sql "use ${db}"
sql "set runtime_filter_mode=OFF";
sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"

sql """
drop table if exists orders
"""

sql """
CREATE TABLE IF NOT EXISTS orders (
o_orderkey INTEGER NULL,
o_custkey INTEGER NULL,
o_orderstatus CHAR(1) NULL,
o_totalprice DECIMALV3(15,2) NULL,
o_orderdate DATE NULL,
o_orderpriority CHAR(15) NULL,
o_clerk CHAR(15) NULL,
o_shippriority INTEGER NULL,
O_COMMENT VARCHAR(79) NULL
)
DUPLICATE KEY(o_orderkey, o_custkey)
PARTITION BY RANGE(o_orderdate) (
PARTITION `day_2` VALUES LESS THAN ('2023-12-9'),
PARTITION `day_3` VALUES LESS THAN ("2023-12-11"),
PARTITION `day_4` VALUES LESS THAN ("2023-12-30")
)
DISTRIBUTED BY HASH(o_orderkey) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
"""

sql """
insert into orders values
(null, 1, 'o', 9.5, '2023-12-08', 'a', 'b', 1, 'yy'),
(1, null, 'o', 10.5, '2023-12-08', 'a', 'b', 1, 'yy'),
(2, 1, null, 11.5, '2023-12-09', 'a', 'b', 1, 'yy'),
(3, 1, 'o', null, '2023-12-10', 'a', 'b', 1, 'yy'),
(3, 1, 'o', 33.5, null, 'a', 'b', 1, 'yy'),
(4, 2, 'o', 43.2, '2023-12-11', null,'d',2, 'mm'),
(5, 2, 'o', 56.2, '2023-12-12', 'c',null, 2, 'mi'),
(5, 2, 'o', 1.2, '2023-12-12', 'c','d', null, 'mi');
"""

def mv1_0 =
"""
select count(*), o_orderstatus, o_comment
from orders
group by
o_orderstatus, o_comment;
"""
// query contains the filter which is 'o_orderstatus = o_orderstatus' should reject null
def query1_0 =
"""
select count(*), o_orderstatus, o_comment
from orders
where o_orderstatus = o_orderstatus
group by
o_orderstatus, o_comment;
"""
order_qt_query1_0_before "${query1_0}"
async_mv_rewrite_success(db, mv1_0, query1_0, "mv1_0")
order_qt_query1_0_after "${query1_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv1_0"""
}

0 comments on commit fc02d8e

Please sign in to comment.