Skip to content

Commit

Permalink
Add separate rule for dealing with nulls in aggregations
Browse files Browse the repository at this point in the history
Duplicate SubstituteSurrogate in "Operator Optimization" batch
Many more tests
Add tests for mad
Add mv handling to top function
  • Loading branch information
astefan committed Aug 30, 2024
1 parent f444ce6 commit 5159109
Show file tree
Hide file tree
Showing 17 changed files with 357 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ mv_median |Converts a multivalued field into a single valued field containin
mv_min |Converts a multivalued expression into a single valued column containing the minimum value.
mv_percentile |Converts a multivalued field into a single valued field containing the value at which a certain percentage of observed values occur.
mv_pseries_wei|Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum.
mv_slice |Returns a subset of the multivalued field using the start and end index values.
mv_slice |Returns a subset of the multivalued field using the start and end index values. The function uses 0-based indexing.
mv_sort |Sorts a multivalued field in lexicographical order.
mv_sum |Converts a multivalued field into a single valued field containing the sum of all of the values.
mv_zip |Combines the values from two multivalued fields with a delimiter that joins them together.
Expand Down
38 changes: 26 additions & 12 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/row.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ a:integer | b:integer | c:integer
;

evalRowWithNull
row a = 1, b = 2, c = null | eval z = c+b+a;
row a = 1, b = 2, c = null | eval z = c+b+a, t = a + null, u = c + null;

a:integer | b:integer | c:null | z:integer
1 | 2 | null | null
a:integer | b:integer | c:null | z:integer | t:integer | u:null
1 | 2 | null | null | null | null
;

evalRowWithNull2
Expand All @@ -96,10 +96,10 @@ a:integer | b:integer | c:null | null:null | z:integer
;

evalRowWithNull3
row a = 1, b = 2, x = round(null) | eval z = a+b+x;
row a = 1, b = 2, x = round(null) | eval z = a+b+x, t = a + round(null);

a:integer | b:integer | x:null | z:integer
1 | 2 | null | null
a:integer | b:integer | x:null | z:integer | t:integer
1 | 2 | null | null | null
;

evalRowWithRound
Expand Down Expand Up @@ -251,27 +251,41 @@ avg:double | min(x):integer | max(x):integer | count(x):long | avg(x):double | a
8.0 | 8 | 8 | 1 | 8.0 | 5.0 | 4.0
;

rowWithMultipleStats2
row a = 1+3 | eval a = 1 + a | stats avg = avg(1+3), min(5*2), max(3*3), count(123-123), avg(1-123), avg(a+123), count_distinct(5+6);

avg:double | min(5*2):integer | max(3*3):integer |count(123-123):long |avg(1-123):double |avg(a+123):double |count_distinct(5+6):long
4.0 |10 |9 |1 |-122.0 |128.0 |1
;

rowWithMultipleStatsOverNull
row x=1, y=2 | eval tot = null + y + x | stats c=count(tot), a=avg(tot), mi=min(tot), ma=max(tot), s=sum(tot);

c:long | a:double | mi:integer | ma:integer | s:long
0 | null | null | null | null
;

rowWithMultipleStatsOverNull2
row x=1, y=2 | stats c=count(null + 1), c_d=count_distinct(1 + null), a=avg(null + x), mi=min(y + null), ma=max(y + x * null), s=sum(null);

c:long | c_d:long | a:double | mi:integer | ma:integer | s:double
0 |0 |null |null |null |null
;


min
row l=1, d=1.0, ln=1 + null, dn=1.0 + null | stats min(l), min(d), min(ln), min(dn);
row l=1, d=1.0, ln=1 + null, dn=1.0 + null, n=null | stats min(l), min(d), min(ln), min(dn), n1=min(null), n2=min(null+123), n3=min(n);

min(l):integer | min(d):double | min(ln):integer | min(dn):double
1 | 1.0 | null | null
min(l):integer | min(d):double | min(ln):integer | min(dn):double | n1:null | n2:integer | n3:null
1 | 1.0 | null | null | null | null | null
;


sum
row l=1, d=1.0, ln=1 + null, dn=1.0 + null | stats sum(l), sum(d), sum(ln), sum(dn);
row l=1, d=1.0, ln=1 + null, dn=1.0 + null, n=null | stats sum(l), sum(d), sum(ln), sum(dn), s1=sum(null), s2=sum(123-null), s3=sum(n);

sum(l):long | sum(d):double | sum(ln):long | sum(dn):double
1 | 1.0 | null | null
sum(l):long | sum(d):double | sum(ln):long | sum(dn):double | s1:double | s2:long | s3:double
1 | 1.0 | null | null | null | null | null
;

boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ wkt:keyword
["POINT(42.97109630194 14.7552534413725)", "POINT(75.8092915005895 22.727749187571)"] |[POINT(42.97109630194 14.7552534413725), POINT(75.8092915005895 22.727749187571)]
;

centroidOfNull
FROM airports | eval z = TO_GEOPOINT(null) | STATS centroidNull = ST_CENTROID_AGG(null), centroidExpNull = ST_CENTROID_AGG(TO_GEOPOINT(null::string)), centroidEvalNull = ST_CENTROID_AGG(z);

centroidNull:null|centroidExpNull:geo_point|centroidEvalNull:geo_point
null |null |null
;

########### failing :-( with InvalidArgumentException: Does not support yet aggregations over constants
centroidFromStringNested
required_capability: st_centroid_agg

Expand Down Expand Up @@ -1304,6 +1312,7 @@ wkt:keyword |pt:cartesian_point
["POINT(4297.11 -1475.53)", "POINT(7580.93 2272.77)"] |[POINT(4297.11 -1475.53), POINT(7580.93 2272.77)]
;

########### failing :-( with InvalidArgumentException: Does not support yet aggregations over constants
centroidCartesianFromStringNested
required_capability: st_centroid_agg

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,67 @@ s2point1:l | s_mv:l | s_param:l | s_expr:l | s_expr_null:l | languages:i
1 | 4 | 4 | 1 | 0 | null
;

emptyStatsBy1
from employees | eval x = [1,2,3] | stats by x | eval z = case(null is null, null, 5);

x:integer |z:integer
1 |null
2 |null
3 |null
;

emptyStatsBy2
from employees | eval x = [1,2,3], y = null | stats max(y) by x;

max(y):null |x:integer
null |1
null |2
null |3
;

statsOfPropagateableConst
row foo="unused" | eval mv=[1,2,3] | stats avg = avg(mv), avg([5,6]), min(mv), min([5,6]), max(mv), max([5,6]), count(mv), count([5,6]), count_distinct(mv), count_distinct([5,5,6]);

avg:double |avg([5,6]):double|min(mv):integer|min([5,6]):integer|max(mv):integer|max([5,6]):integer|count(mv):long |count([5,6]):long|count_distinct(mv):long|count_distinct([5,5,6]):long
2.0 |5.5 |1 |5 |3 |6 |3 |2 |3 |2
;

statsOfPropagateableConstWithGrouping_Count
ROW a = [1,2,3], c = 5 | EVAL c = null + c | STATS COUNT(c), COUNT(null), COUNT(null - 1) BY a;

COUNT(c):long |COUNT(null):long|COUNT(null - 1):long|a:integer
0 |0 |0 |1
0 |0 |0 |2
0 |0 |0 |3
;

statsOfPropagateableConstWithGrouping_DistinctCount
ROW a = [1,2,3], c = 5 | EVAL c = null + c | STATS COUNT_DISTINCT(c), COUNT_DISTINCT(null), COUNT_DISTINCT(null - 1) BY a;

COUNT_DISTINCT(c):long|COUNT_DISTINCT(null):long|COUNT_DISTINCT(null - 1):long|a:integer
0 |0 |0 |1
0 |0 |0 |2
0 |0 |0 |3
;

countDistinctWithGrouping
from employees | EVAL c = 5 + 6, d = 6 + null, e = [1,2,2] | STATS `cd(5+6)`=COUNT_DISTINCT(c), `cd(null)`=COUNT_DISTINCT(null), `cd(null-1)`=COUNT_DISTINCT(null - 1), `cd(eval_null)`=COUNT_DISTINCT(d), `cd(mv)`=COUNT_DISTINCT(e) BY gender | sort gender;

cd(5+6):long |cd(null):long |cd(null-1):long|cd(eval_null):long|cd(mv):long |gender:keyword
1 |0 |0 |0 |2 |F
1 |0 |0 |0 |2 |M
1 |0 |0 |0 |2 |null
;

countWithGrouping
from employees | EVAL c = 5 + 6, d = 6 + null, e = [1,2,2] | STATS `cd(5+6)`=COUNT(c), `cd(null)`=COUNT(null), `cd(null-1)`=COUNT(null - 1), `cd(eval_null)`=COUNT(d), `cd(mv)`=COUNT(e) BY gender | sort gender;

cd(5+6):long |cd(null):long |cd(null-1):long|cd(eval_null):long|cd(mv):long |gender:keyword
33 |0 |0 |0 |99 |F
57 |0 |0 |0 |171 |M
10 |0 |0 |0 |30 |null
;

evalOverridingKey#[skip:-8.13.1,reason:fixed in 8.13.2]
FROM employees
| EVAL k = languages
Expand Down Expand Up @@ -2290,3 +2351,12 @@ from employees
m:integer |a:double |x:integer
74999 |48249.0 |0
;

valuesOfConstants
row mv_int = [1,2,3], string = "abc", mv_string = ["bar","foo"], null_exp = 234 + null, to_be_casted_ip = "127.0.0.1", row_null = null
| stats values(null), values(mv_int), values(string), values(mv_string), values(null_exp), values(123 + null), values(to_be_casted_ip::ip), values(row_null)
;

values(null):null|values(mv_int):integer|values(string):keyword|values(mv_string):keyword|values(null_exp):integer|values(123 + null):integer|values(to_be_casted_ip::ip):ip|values(row_null):null
null |[1, 2, 3] |abc |[bar, foo] |null |null |127.0.0.1 |null
;
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
madOfNulls
row x = null::integer, y = null, z = 123 - null, a = [1,2,3], b = 5 | stats x = MEDIAN_ABSOLUTE_DEVIATION(x), y = MEDIAN_ABSOLUTE_DEVIATION(y), z = MEDIAN_ABSOLUTE_DEVIATION(null), madnull = MEDIAN_ABSOLUTE_DEVIATION(2+null), madnullEvalNull = MEDIAN_ABSOLUTE_DEVIATION(z) by a, b;

x:double|y:double |z:double |madnull:double |madnullEvalNull:double|a:integer |b:integer
null |null |null |null |null |1 |5
null |null |null |null |null |2 |5
null |null |null |null |null |3 |5
;

Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
percentileOfNull
row x = null::integer, y = null, z = 123 - null | stats percIntNull = percentile(y, 90), percNull = percentile(y, 90), percentile(null, 90), percentile(null+2, 90), percentile(z, 90);

percIntNull:double|percNull:double|percentile(null, 90):double|percentile(null+2, 90):double|percentile(z, 90):double
null |null |null |null |null
;

percentileOfNullsOnRealIndex
from employees | eval x = null::integer, y = null, z = 123 - null | stats percIntNull = percentile(y, 90), percNull = percentile(y, 90), percentile(null, 90), percentile(null+2, 90), percentile(z, 90) by languages | sort languages desc;

percIntNull:double|percNull:double|percentile(null, 90):double|percentile(null+2, 90):double|percentile(z, 90):double|languages:integer
null |null |null |null |null |null
null |null |null |null |null |5
null |null |null |null |null |4
null |null |null |null |null |3
null |null |null |null |null |2
null |null |null |null |null |1
;

percentileOfLong
from employees | stats p0 = percentile(salary_change.long, 0), p50 = percentile(salary_change.long, 50), p99 = percentile(salary_change.long, 99);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,55 @@ date:date | double:double | integer:integer | long:long
[1999-04-30T00:00:00.000Z,1997-05-19T00:00:00.000Z] | [14.74,14.68] | [74999,74970] | [14,14]
;

topRowWithEval
required_capability: agg_top
ROW d=-9.81, i=25324, l=TO_LONG(-9) | EVAL x=d+1, y=i/2, z=l+3 | STATS double = TOP(x, 2, "asc"), integer = TOP(y, 2, "asc"), long = TOP(z, 2, "asc") | keep double, integer, long
;

double:double |integer:integer| long:long
-8.81 |12662 |-6
;

topRowWithNull
required_capability: agg_top
ROW d=-9.81, i=25324, l=TO_LONG(-9), n=null | EVAL x=d+1+null, y=i/2+n, z=l+3, t=null | STATS double = TOP(x, 2, "asc"), integer = TOP(y, 5, "desc"), long = TOP(z, 8, "desc"), null1 = TOP(n, 1000, "asc"), null2 = TOP(t, 444, "desc") | keep double, long, integer, null1, null2
;

double:double |long:long |integer:integer| null1:null |null2:null
null |-6 |null |null |null
;

topRowWithNull_WithBY
required_capability: agg_top
ROW d=-9.81, i=25324, l=TO_LONG(-9), n=null, `by` = [1,2,3] | EVAL x=d+1+null, y=i/2+n, z=l+3, t=null | STATS double = TOP(x, 2, "asc"), integer = TOP(y, 5, "desc"), long = TOP(z, 8, "desc"), null1 = TOP(n, 1000, "asc"), null2 = TOP(t, 444, "desc") BY `by` | keep double, long, integer, null1, null2, `by`
;

double:double |long:long |integer:integer| null1:null |null2:null | by:integer
null |-6 |null |null |null |1
null |-6 |null |null |null |2
null |-6 |null |null |null |3
;

topRowWithMultiValues
required_capability: agg_top
row x = [1,2,3,4,5], z = [1,2,3,4,5] | eval y = mv_slice(x, 1, 3) | stats z2a=top(z, 2, "asc"), z3d=top(z, 3, "desc"), y2a=top(y, 2, "asc"), y2d=top(y, 2, "desc")
;

z2a:integer|z3d:integer|y2a:integer|y2d:integer
[1, 2] |[5, 4, 3] |[2, 3] |[4, 3]
;

topRowWithMultiValues_WithBY
required_capability: agg_top
row x = [1,2,3,4,5], z = [1,2,3,4,5] | eval y = mv_slice(x, 1, 3) | stats z2a=top(z, 2, "asc"), z3d=top(z, 3, "desc"), y2a=top(y, 2, "asc"), y2d=top(y, 2, "desc") by y
;

z2a:integer|z3d:integer|y2a:integer|y2d:integer| y:integer
[1, 2] |[5, 4, 3] |[2, 3] |[4, 3] | 2
[1, 2] |[5, 4, 3] |[2, 3] |[4, 3] | 3
[1, 2] |[5, 4, 3] |[2, 3] |[4, 3] | 4
;

topAllTypesRow
required_capability: agg_top
ROW
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
Expand Down Expand Up @@ -151,6 +152,11 @@ public DataType dataType() {
return DataType.LONG;
}

@Override
public Nullability nullable() {
return Nullability.FALSE;
}

@Override
protected TypeResolution resolveType() {
if (childrenResolved() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@
import org.elasticsearch.compute.aggregation.TopLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSlice;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSort;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

Expand Down Expand Up @@ -197,11 +202,18 @@ public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
public Expression surrogate() {
var s = source();

if (limitValue() == 1) {
if (orderValue()) {
return new Min(s, field());
if (field().foldable()) {
if (limitValue() == 1) {
if (orderValue()) {
return new MvMin(s, field());
} else {
return new MvMax(s, field());
}
} else {
return new Max(s, field());
var start = new Literal(s, 0, DataType.INTEGER);
var end = new Literal(s, limitValue() - 1, DataType.INTEGER);
MvSort sort = new MvSort(s, field(), new Literal(s, orderValue() ? ORDER_ASC : ORDER_DESC, DataType.KEYWORD));
return new MvSlice(s, sort, start, end);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
Expand All @@ -32,7 +33,7 @@
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG;

public class Values extends AggregateFunction implements ToAggregator {
public class Values extends AggregateFunction implements ToAggregator, SurrogateExpression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Values", Values::new);

@FunctionInfo(
Expand Down Expand Up @@ -115,4 +116,9 @@ public AggregatorFunctionSupplier supplier(List<Integer> inputChannels) {
// TODO cartesian_point, geo_point
throw EsqlIllegalArgumentException.illegalDataType(type);
}

@Override
public Expression surrogate() {
return field().foldable() ? field() : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public class MvSlice extends EsqlScalarFunction implements OptionalArgument, Eva
"long",
"text",
"version" },
description = "Returns a subset of the multivalued field using the start and end index values.",
description = "Returns a subset of the multivalued field using the start and end index values. The function uses 0-based indexing.",
examples = { @Example(file = "ints", tag = "mv_slice_positive"), @Example(file = "ints", tag = "mv_slice_negative") }
)
public MvSlice(
Expand Down
Loading

0 comments on commit 5159109

Please sign in to comment.