diff --git a/dashboard/proto/gen/common.ts b/dashboard/proto/gen/common.ts index f3cefe80d671..4ebe859a7f48 100644 --- a/dashboard/proto/gen/common.ts +++ b/dashboard/proto/gen/common.ts @@ -102,6 +102,47 @@ export function directionToJSON(object: Direction): string { } } +export const NullsAre = { + NULLS_ARE_UNSPECIFIED: "NULLS_ARE_UNSPECIFIED", + NULLS_ARE_LARGEST: "NULLS_ARE_LARGEST", + NULLS_ARE_SMALLEST: "NULLS_ARE_SMALLEST", + UNRECOGNIZED: "UNRECOGNIZED", +} as const; + +export type NullsAre = typeof NullsAre[keyof typeof NullsAre]; + +export function nullsAreFromJSON(object: any): NullsAre { + switch (object) { + case 0: + case "NULLS_ARE_UNSPECIFIED": + return NullsAre.NULLS_ARE_UNSPECIFIED; + case 1: + case "NULLS_ARE_LARGEST": + return NullsAre.NULLS_ARE_LARGEST; + case 2: + case "NULLS_ARE_SMALLEST": + return NullsAre.NULLS_ARE_SMALLEST; + case -1: + case "UNRECOGNIZED": + default: + return NullsAre.UNRECOGNIZED; + } +} + +export function nullsAreToJSON(object: NullsAre): string { + switch (object) { + case NullsAre.NULLS_ARE_UNSPECIFIED: + return "NULLS_ARE_UNSPECIFIED"; + case NullsAre.NULLS_ARE_LARGEST: + return "NULLS_ARE_LARGEST"; + case NullsAre.NULLS_ARE_SMALLEST: + return "NULLS_ARE_SMALLEST"; + case NullsAre.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + export interface Status { code: Status_Code; message: string; @@ -267,11 +308,8 @@ export interface BatchQueryEpoch { } export interface OrderType { - /** - * TODO(rc): enable `NULLS FIRST | LAST` - * NullsAre nulls_are = 2; - */ direction: Direction; + nullsAre: NullsAre; } /** Column index with an order type (ASC or DESC). Used to represent a sort key (`repeated ColumnOrder`). */ @@ -545,25 +583,28 @@ export const BatchQueryEpoch = { }; function createBaseOrderType(): OrderType { - return { direction: Direction.DIRECTION_UNSPECIFIED }; + return { direction: Direction.DIRECTION_UNSPECIFIED, nullsAre: NullsAre.NULLS_ARE_UNSPECIFIED }; } export const OrderType = { fromJSON(object: any): OrderType { return { direction: isSet(object.direction) ? directionFromJSON(object.direction) : Direction.DIRECTION_UNSPECIFIED, + nullsAre: isSet(object.nullsAre) ? nullsAreFromJSON(object.nullsAre) : NullsAre.NULLS_ARE_UNSPECIFIED, }; }, toJSON(message: OrderType): unknown { const obj: any = {}; message.direction !== undefined && (obj.direction = directionToJSON(message.direction)); + message.nullsAre !== undefined && (obj.nullsAre = nullsAreToJSON(message.nullsAre)); return obj; }, fromPartial, I>>(object: I): OrderType { const message = createBaseOrderType(); message.direction = object.direction ?? Direction.DIRECTION_UNSPECIFIED; + message.nullsAre = object.nullsAre ?? NullsAre.NULLS_ARE_UNSPECIFIED; return message; }, }; diff --git a/dashboard/proto/gen/order.ts b/dashboard/proto/gen/order.ts deleted file mode 100644 index 6037394eadce..000000000000 --- a/dashboard/proto/gen/order.ts +++ /dev/null @@ -1,128 +0,0 @@ -/* eslint-disable */ - -export const protobufPackage = "order"; - -export const PbDirection = { - PbDirection_UNSPECIFIED: "PbDirection_UNSPECIFIED", - PbDirection_ASCENDING: "PbDirection_ASCENDING", - PbDirection_DESCENDING: "PbDirection_DESCENDING", - UNRECOGNIZED: "UNRECOGNIZED", -} as const; - -export type PbDirection = typeof PbDirection[keyof typeof PbDirection]; - -export function pbDirectionFromJSON(object: any): PbDirection { - switch (object) { - case 0: - case "PbDirection_UNSPECIFIED": - return PbDirection.PbDirection_UNSPECIFIED; - case 1: - case "PbDirection_ASCENDING": - return PbDirection.PbDirection_ASCENDING; - case 2: - case "PbDirection_DESCENDING": - return PbDirection.PbDirection_DESCENDING; - case -1: - case "UNRECOGNIZED": - default: - return PbDirection.UNRECOGNIZED; - } -} - -export function pbDirectionToJSON(object: PbDirection): string { - switch (object) { - case PbDirection.PbDirection_UNSPECIFIED: - return "PbDirection_UNSPECIFIED"; - case PbDirection.PbDirection_ASCENDING: - return "PbDirection_ASCENDING"; - case PbDirection.PbDirection_DESCENDING: - return "PbDirection_DESCENDING"; - case PbDirection.UNRECOGNIZED: - default: - return "UNRECOGNIZED"; - } -} - -export interface PbOrderType { - /** - * TODO(rc): enable `NULLS FIRST | LAST` - * PbNullsAre nulls_are = 2; - */ - direction: PbDirection; -} - -/** Column index with an order type (ASC or DESC). Used to represent a sort key (`repeated PbColumnOrder`). */ -export interface PbColumnOrder { - columnIndex: number; - orderType: PbOrderType | undefined; -} - -function createBasePbOrderType(): PbOrderType { - return { direction: PbDirection.PbDirection_UNSPECIFIED }; -} - -export const PbOrderType = { - fromJSON(object: any): PbOrderType { - return { - direction: isSet(object.direction) ? pbDirectionFromJSON(object.direction) : PbDirection.PbDirection_UNSPECIFIED, - }; - }, - - toJSON(message: PbOrderType): unknown { - const obj: any = {}; - message.direction !== undefined && (obj.direction = pbDirectionToJSON(message.direction)); - return obj; - }, - - fromPartial, I>>(object: I): PbOrderType { - const message = createBasePbOrderType(); - message.direction = object.direction ?? PbDirection.PbDirection_UNSPECIFIED; - return message; - }, -}; - -function createBasePbColumnOrder(): PbColumnOrder { - return { columnIndex: 0, orderType: undefined }; -} - -export const PbColumnOrder = { - fromJSON(object: any): PbColumnOrder { - return { - columnIndex: isSet(object.columnIndex) ? Number(object.columnIndex) : 0, - orderType: isSet(object.orderType) ? PbOrderType.fromJSON(object.orderType) : undefined, - }; - }, - - toJSON(message: PbColumnOrder): unknown { - const obj: any = {}; - message.columnIndex !== undefined && (obj.columnIndex = Math.round(message.columnIndex)); - message.orderType !== undefined && - (obj.orderType = message.orderType ? PbOrderType.toJSON(message.orderType) : undefined); - return obj; - }, - - fromPartial, I>>(object: I): PbColumnOrder { - const message = createBasePbColumnOrder(); - message.columnIndex = object.columnIndex ?? 0; - message.orderType = (object.orderType !== undefined && object.orderType !== null) - ? PbOrderType.fromPartial(object.orderType) - : undefined; - return message; - }, -}; - -type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; - -export type DeepPartial = T extends Builtin ? T - : T extends Array ? Array> : T extends ReadonlyArray ? ReadonlyArray> - : T extends { $case: string } ? { [K in keyof Omit]?: DeepPartial } & { $case: T["$case"] } - : T extends {} ? { [K in keyof T]?: DeepPartial } - : Partial; - -type KeysOfUnion = T extends T ? keyof T : never; -export type Exact = P extends Builtin ? P - : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; - -function isSet(value: any): boolean { - return value !== null && value !== undefined; -} diff --git a/e2e_test/batch/aggregate/array_agg.slt.part b/e2e_test/batch/aggregate/array_agg.slt.part index 4403d8089aa5..537af2eab3dc 100644 --- a/e2e_test/batch/aggregate/array_agg.slt.part +++ b/e2e_test/batch/aggregate/array_agg.slt.part @@ -2,7 +2,7 @@ statement ok SET RW_IMPLICIT_FLUSH TO true; statement ok -create table t(v1 varchar, v2 int, v3 int) +create table t(v1 varchar, v2 int, v3 int); query T select array_agg(v1) from t; @@ -18,10 +18,7 @@ select array_agg(v1) from t; {NULL} statement ok -delete from t; - -statement ok -insert into t values ('aaa', 1, 1), ('bbb', 0, 2), ('ccc', 0, 5), ('ddd', 1, 4) +insert into t values ('aaa', 1, 1), ('bbb', 0, 2), ('ccc', 0, 5), ('ddd', 1, 4); query T select b from (select unnest(a) from (select array_agg(v3) as v3_arr from t) g(a)) p(b) order by b; @@ -30,16 +27,22 @@ select b from (select unnest(a) from (select array_agg(v3) as v3_arr from t) g(a 2 4 5 +NULL + +query T +select array_agg(v1 order by v3 asc nulls first) from t; +---- +{NULL,aaa,bbb,ddd,ccc} query T -select array_agg(v1 order by v3 desc) from t +select array_agg(v1 order by v3 desc) from t; ---- -{ccc,ddd,bbb,aaa} +{NULL,ccc,ddd,bbb,aaa} query T -select array_agg(v1 order by v2 asc, v3 desc) from t +select array_agg(v1 order by v2 asc nulls last, v3 desc) from t; ---- -{ccc,bbb,ddd,aaa} +{ccc,bbb,ddd,aaa,NULL} statement ok -drop table t +drop table t; diff --git a/e2e_test/batch/basic/index.slt.part b/e2e_test/batch/basic/index.slt.part index 007786702371..a45cd724bb45 100644 --- a/e2e_test/batch/basic/index.slt.part +++ b/e2e_test/batch/basic/index.slt.part @@ -11,23 +11,23 @@ statement ok create index idx2 on t1(v2); statement ok -insert into t1 values(1, 2),(3,4),(5,6); +insert into t1 values (1, 2), (3, 4), (5, 6); statement ok -explain select v1,v2 from t1 where v1 = 1; +explain select v1, v2 from t1 where v1 = 1; query II -select v1,v2 from t1 where v1 = 1; +select v1, v2 from t1 where v1 = 1; ---- 1 2 query II -select v1,v2 from t1 where v2 = 4; +select v1, v2 from t1 where v2 = 4; ---- 3 4 query II -select v1,v2 from t1 where v1 = 1 or v2 = 4 order by v1, v2; +select v1, v2 from t1 where v1 = 1 or v2 = 4 order by v1, v2; ---- 1 2 3 4 @@ -36,10 +36,57 @@ statement ok delete from t1 where v1 = 1; query II -select v1,v2 from t1 order by v1, v2; +select v1, v2 from t1 order by v1, v2; ---- 3 4 5 6 +statement ok +insert into t1 values (NULL, 5); + +statement ok +create index idx3 on t1(v1 desc); + +statement ok +create index idx4 on t1(v1 nulls first); + +statement ok +create index idx5 on t1(v1 desc nulls last); + +query II +select v1, v2 from t1 order by v1; +---- +3 4 +5 6 +NULL 5 + +query II +select v1, v2 from t1 order by v1 desc; +---- +NULL 5 +5 6 +3 4 + +query II +select v1, v2 from t1 order by v1 asc nulls first; +---- +NULL 5 +3 4 +5 6 + +query II +select v1, v2 from t1 order by v1 desc nulls last; +---- +5 6 +3 4 +NULL 5 + +query II +select v1, v2 from t1 order by v1 desc nulls first; +---- +NULL 5 +5 6 +3 4 + statement ok drop table t1; diff --git a/e2e_test/batch/basic/order_by.slt.part b/e2e_test/batch/basic/order_by.slt.part index 670280e8f5ff..3fae22b6386d 100644 --- a/e2e_test/batch/basic/order_by.slt.part +++ b/e2e_test/batch/basic/order_by.slt.part @@ -5,10 +5,10 @@ statement ok create table t (v1 int, v2 int, v3 int); statement ok -insert into t values (1,4,2), (2,3,3), (3,4,4), (4,3,5) +insert into t values (1,4,2), (2,3,3), (3,4,4), (4,3,5); query III rowsort -select * from t +select * from t; ---- 1 4 2 2 3 3 @@ -16,7 +16,7 @@ select * from t 4 3 5 query III -select * from t order by v1 desc +select * from t order by v1 desc; ---- 4 3 5 3 4 4 @@ -48,17 +48,17 @@ select * from t order by v1 + v2, v1; 4 3 5 query III -select * from t order by v1 desc limit 1 +select * from t order by v1 desc limit 1; ---- 4 3 5 query III -select * from t order by v1 desc limit 1 offset 1 +select * from t order by v1 desc limit 1 offset 1; ---- 3 4 4 query III -select * from t order by v2, v1 +select * from t order by v2, v1; ---- 2 3 3 4 3 5 @@ -66,13 +66,13 @@ select * from t order by v2, v1 3 4 4 query III -select * from t order by v2, v1 limit 2 +select * from t order by v2, v1 limit 2; ---- 2 3 3 4 3 5 query III -select * from t order by v2, v1 limit 10 +select * from t order by v2, v1 limit 10; ---- 2 3 3 4 3 5 @@ -80,7 +80,7 @@ select * from t order by v2, v1 limit 10 3 4 4 query III -select * from t order by v2 desc, v1 limit 2 +select * from t order by v2 desc, v1 limit 2; ---- 1 4 2 3 4 4 @@ -94,6 +94,24 @@ select * from t order by v1 limit 2; 1 4 2 2 3 3 +query III +select * from t order by v1 asc limit 2; +---- +1 4 2 +2 3 3 + +query III +select * from t order by v1 nulls first limit 2; +---- +NULL 7 NULL +1 4 2 + +query III +select * from t order by v1 asc nulls last limit 2; +---- +1 4 2 +2 3 3 + query III select * from t order by v1 desc limit 7; ---- @@ -103,5 +121,23 @@ NULL 7 NULL 2 3 3 1 4 2 +query III +select * from t order by v1 desc nulls first limit 7; +---- +NULL 7 NULL +4 3 5 +3 4 4 +2 3 3 +1 4 2 + +query III +select * from t order by v1 desc nulls last limit 7; +---- +4 3 5 +3 4 4 +2 3 3 +1 4 2 +NULL 7 NULL + statement ok drop table t; diff --git a/e2e_test/ddl/show.slt b/e2e_test/ddl/show.slt index 5c1a69d2c993..f70dfb853c5f 100644 --- a/e2e_test/ddl/show.slt +++ b/e2e_test/ddl/show.slt @@ -32,7 +32,7 @@ v1 Int32 v2 Int32 v3 Int32 primary key _row_id -idx1 index(v1, v2) include(v3) distributed by(v1, v2) +idx1 index(v1 ASC, v2 ASC) include(v3) distributed by(v1, v2) statement ok drop index idx1; diff --git a/e2e_test/streaming/array_agg.slt b/e2e_test/streaming/array_agg.slt index 42c02a05c47a..cf76386e3866 100644 --- a/e2e_test/streaming/array_agg.slt +++ b/e2e_test/streaming/array_agg.slt @@ -13,9 +13,6 @@ create materialized view mv1 as select array_agg(c) as res from t; statement ok create materialized view mv2 as select array_agg(a order by b asc, a desc) as res from t; -statement ok -flush; - query T select u from (select unnest(res) from mv1) p(u) order by u; ---- @@ -47,11 +44,25 @@ select * from mv2; ---- {ccc,bbb,x,ddd,aaa,y} +statement ok +create materialized view mv3 as select array_agg(a order by b nulls first, a nulls last) as res from t; + +statement ok +insert into t values (NULL, NULL, 2), ('z', NULL, 6); + +query T +select * from mv3; +---- +{z,NULL,bbb,ccc,aaa,ddd,x,y} + statement ok drop materialized view mv1; statement ok drop materialized view mv2; +statement ok +drop materialized view mv3; + statement ok drop table t; diff --git a/e2e_test/streaming/order_by.slt b/e2e_test/streaming/order_by.slt index 4a6e44c9a819..c87b937b3d71 100644 --- a/e2e_test/streaming/order_by.slt +++ b/e2e_test/streaming/order_by.slt @@ -16,9 +16,6 @@ create materialized view mv2 as select * from t1 order by v1 limit 3; statement ok create materialized view mv3 as select * from t1 order by v1 limit 3 offset 1; -statement ok -flush; - query III rowsort select v1, v2, v3 from mv1; ---- @@ -43,13 +40,60 @@ select v1, v2, v3 from mv3; 5 1 4 statement ok -drop materialized view mv1 +insert into t1 values (NULL,0,0); + +statement ok +create materialized view mv4 as select * from t1 order by v1 desc limit 1; + +statement ok +create materialized view mv5 as select * from t1 order by v1 nulls first limit 1; + +statement ok +create materialized view mv6 as select * from t1 order by v1 nulls last limit 1; + +statement ok +create materialized view mv7 as select * from t1 order by v1 desc nulls last limit 1; + +query III +select v1, v2, v3 from mv4; +---- +NULL 0 0 + +query III +select v1, v2, v3 from mv5; +---- +NULL 0 0 + +query III +select v1, v2, v3 from mv6; +---- +0 2 3 + +query III +select v1, v2, v3 from mv7; +---- +9 8 1 + +statement ok +drop materialized view mv1; + +statement ok +drop materialized view mv2; + +statement ok +drop materialized view mv3; + +statement ok +drop materialized view mv4; + +statement ok +drop materialized view mv5; statement ok -drop materialized view mv2 +drop materialized view mv6; statement ok -drop materialized view mv3 +drop materialized view mv7; statement ok -drop table t1 +drop table t1; diff --git a/proto/common.proto b/proto/common.proto index 546232538b84..c7787072df08 100644 --- a/proto/common.proto +++ b/proto/common.proto @@ -82,17 +82,15 @@ enum Direction { DIRECTION_DESCENDING = 2; } -// TODO(rc): enable `NULLS FIRST | LAST` -// enum NullsAre { -// NULLS_ARE_UNSPECIFIED = 0; -// NULLS_ARE_SMALLEST = 1; -// NULLS_ARE_LARGEST = 2; -// } +enum NullsAre { + NULLS_ARE_UNSPECIFIED = 0; + NULLS_ARE_LARGEST = 1; + NULLS_ARE_SMALLEST = 2; +} message OrderType { Direction direction = 1; - // TODO(rc): enable `NULLS FIRST | LAST` - // NullsAre nulls_are = 2; + NullsAre nulls_are = 2; } // Column index with an order type (ASC or DESC). Used to represent a sort key (`repeated ColumnOrder`). diff --git a/src/batch/src/executor/group_top_n.rs b/src/batch/src/executor/group_top_n.rs index 20024c232758..8d0da917b0b3 100644 --- a/src/batch/src/executor/group_top_n.rs +++ b/src/batch/src/executor/group_top_n.rs @@ -25,8 +25,8 @@ use risingwave_common::error::{Result, RwError}; use risingwave_common::hash::{HashKey, HashKeyDispatcher}; use risingwave_common::types::DataType; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; -use risingwave_common::util::encoding_for_comparison::encode_chunk; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_common::util::memcmp_encoding::encode_chunk; use risingwave_common::util::sort_util::ColumnOrder; use risingwave_pb::batch_plan::plan_node::NodeBody; diff --git a/src/batch/src/executor/order_by.rs b/src/batch/src/executor/order_by.rs index 3d02af44e7cb..88438f8b63c6 100644 --- a/src/batch/src/executor/order_by.rs +++ b/src/batch/src/executor/order_by.rs @@ -17,7 +17,7 @@ use risingwave_common::array::DataChunk; use risingwave_common::catalog::Schema; use risingwave_common::error::{Result, RwError}; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; -use risingwave_common::util::encoding_for_comparison::encode_chunk; +use risingwave_common::util::memcmp_encoding::encode_chunk; use risingwave_common::util::sort_util::ColumnOrder; use risingwave_pb::batch_plan::plan_node::NodeBody; diff --git a/src/batch/src/executor/row_seq_scan.rs b/src/batch/src/executor/row_seq_scan.rs index a4eab294934d..5b8b97cf8bd5 100644 --- a/src/batch/src/executor/row_seq_scan.rs +++ b/src/batch/src/executor/row_seq_scan.rs @@ -26,7 +26,7 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; use risingwave_common::util::select_all; -use risingwave_common::util::sort_util::{Direction, OrderType}; +use risingwave_common::util::sort_util::OrderType; use risingwave_common::util::value_encoding::deserialize_datum; use risingwave_pb::batch_plan::plan_node::NodeBody; use risingwave_pb::batch_plan::{scan_range, PbScanRange}; @@ -406,9 +406,10 @@ impl RowSeqScanExecutor { } = scan_range; let (start_bound, end_bound) = - match table.pk_serializer().get_order_types()[pk_prefix.len()].direction() { - Direction::Ascending => (next_col_bounds.0, next_col_bounds.1), - Direction::Descending => (next_col_bounds.1, next_col_bounds.0), + if table.pk_serializer().get_order_types()[pk_prefix.len()].is_ascending() { + (next_col_bounds.0, next_col_bounds.1) + } else { + (next_col_bounds.1, next_col_bounds.0) }; // Range Scan. diff --git a/src/batch/src/executor/top_n.rs b/src/batch/src/executor/top_n.rs index 37aaba8f8b95..6ae829bdf522 100644 --- a/src/batch/src/executor/top_n.rs +++ b/src/batch/src/executor/top_n.rs @@ -23,7 +23,7 @@ use risingwave_common::array::DataChunk; use risingwave_common::catalog::Schema; use risingwave_common::error::{Result, RwError}; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; -use risingwave_common::util::encoding_for_comparison::encode_chunk; +use risingwave_common::util::memcmp_encoding::encode_chunk; use risingwave_common::util::sort_util::ColumnOrder; use risingwave_pb::batch_plan::plan_node::NodeBody; diff --git a/src/common/benches/bench_encoding.rs b/src/common/benches/bench_encoding.rs index 24eecabd407f..b47904375f51 100644 --- a/src/common/benches/bench_encoding.rs +++ b/src/common/benches/bench_encoding.rs @@ -19,10 +19,11 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::struct_type::StructType; use risingwave_common::types::{ - memcmp_deserialize_datum_from, memcmp_serialize_datum_into, DataType, Datum, IntervalUnit, - NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, ScalarImpl, + DataType, Datum, IntervalUnit, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, + ScalarImpl, }; -use risingwave_common::util::value_encoding; +use risingwave_common::util::sort_util::OrderType; +use risingwave_common::util::{memcmp_encoding, value_encoding}; const ENV_BENCH_SER: &str = "BENCH_SER"; const ENV_BENCH_DE: &str = "BENCH_DE"; @@ -45,9 +46,12 @@ impl Case { } fn key_serialization(datum: &Datum) -> Vec { - let mut serializer = memcomparable::Serializer::new(vec![]); - memcmp_serialize_datum_into(datum, &mut serializer).unwrap(); - black_box(serializer.into_inner()) + let result = memcmp_encoding::encode_value( + datum.as_ref().map(ScalarImpl::as_scalar_ref_impl), + OrderType::default(), + ) + .unwrap(); + black_box(result) } fn value_serialization(datum: &Datum) -> Vec { @@ -55,8 +59,7 @@ fn value_serialization(datum: &Datum) -> Vec { } fn key_deserialization(ty: &DataType, datum: &[u8]) { - let mut deserializer = memcomparable::Deserializer::new(datum); - let result = memcmp_deserialize_datum_from(ty, &mut deserializer); + let result = memcmp_encoding::decode_value(ty, datum, OrderType::default()).unwrap(); let _ = black_box(result); } diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 3d5b1e0c3e06..a90d3fb75067 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -27,10 +27,8 @@ use super::{Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, ArrayMeta, ArrayRe use crate::buffer::{Bitmap, BitmapBuilder}; use crate::row::Row; use crate::types::to_text::ToText; -use crate::types::{ - hash_datum, memcmp_deserialize_datum_from, memcmp_serialize_datum_into, DataType, Datum, - DatumRef, Scalar, ScalarRefImpl, ToDatumRef, -}; +use crate::types::{hash_datum, DataType, Datum, DatumRef, Scalar, ScalarRefImpl, ToDatumRef}; +use crate::util::memcmp_encoding; #[derive(Debug)] pub struct ListArrayBuilder { @@ -359,7 +357,7 @@ impl ListValue { let mut inner_deserializer = memcomparable::Deserializer::new(bytes.as_slice()); let mut values = Vec::new(); while inner_deserializer.has_remaining() { - values.push(memcmp_deserialize_datum_from( + values.push(memcmp_encoding::deserialize_datum_in_composite( datatype, &mut inner_deserializer, )?) @@ -434,7 +432,7 @@ impl<'a> ListRef<'a> { let mut inner_serializer = memcomparable::Serializer::new(vec![]); iter_elems_ref!(self, it, { for datum_ref in it { - memcmp_serialize_datum_into(datum_ref, &mut inner_serializer)? + memcmp_encoding::serialize_datum_in_composite(datum_ref, &mut inner_serializer)? } }); serializer.serialize_bytes(&inner_serializer.into_inner()) diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index 2e688396623b..964a49273a2e 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -26,11 +26,9 @@ use super::{Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, ArrayMeta, ArrayRe use crate::array::ArrayRef; use crate::buffer::{Bitmap, BitmapBuilder}; use crate::types::to_text::ToText; -use crate::types::{ - hash_datum, memcmp_deserialize_datum_from, memcmp_serialize_datum_into, DataType, Datum, - DatumRef, Scalar, ScalarRefImpl, ToDatumRef, -}; +use crate::types::{hash_datum, DataType, Datum, DatumRef, Scalar, ScalarRefImpl, ToDatumRef}; use crate::util::iter_util::ZipEqFast; +use crate::util::memcmp_encoding; #[derive(Debug)] pub struct StructArrayBuilder { @@ -345,7 +343,7 @@ impl StructValue { ) -> memcomparable::Result { fields .iter() - .map(|field| memcmp_deserialize_datum_from(field, deserializer)) + .map(|field| memcmp_encoding::deserialize_datum_in_composite(field, deserializer)) .try_collect() .map(Self::new) } @@ -384,7 +382,7 @@ impl<'a> StructRef<'a> { ) -> memcomparable::Result<()> { iter_fields_ref!(self, it, { for datum_ref in it { - memcmp_serialize_datum_into(datum_ref, serializer)? + memcmp_encoding::serialize_datum_in_composite(datum_ref, serializer)? } Ok(()) }) diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index c653e297a517..7ab0108dd5cb 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -555,57 +555,6 @@ for_all_scalar_variants! { scalar_impl_partial_ord } pub type Datum = Option; pub type DatumRef<'a> = Option>; -// TODO(MrCroxx): turn Datum into a struct, and impl ser/de as its member functions. (#477) -// TODO: specify `NULL FIRST` or `NULL LAST`. -pub fn memcmp_serialize_datum_into( - datum: impl ToDatumRef, - serializer: &mut memcomparable::Serializer, -) -> memcomparable::Result<()> { - // By default, `null` is treated as largest in PostgreSQL. - if let Some(datum) = datum.to_datum_ref() { - 0u8.serialize(&mut *serializer)?; - datum.serialize(serializer)?; - } else { - 1u8.serialize(serializer)?; - } - Ok(()) -} - -// TODO(MrCroxx): turn Datum into a struct, and impl ser/de as its member functions. (#477) -#[cfg_attr(not(test), expect(dead_code))] -fn memcmp_serialize_datum_not_null_into( - datum: impl ToDatumRef, - serializer: &mut memcomparable::Serializer, -) -> memcomparable::Result<()> { - datum - .to_datum_ref() - .as_ref() - .expect("datum cannot be null") - .serialize(serializer) -} - -// TODO(MrCroxx): turn Datum into a struct, and impl ser/de as its member functions. (#477) -pub fn memcmp_deserialize_datum_from( - ty: &DataType, - deserializer: &mut memcomparable::Deserializer, -) -> memcomparable::Result { - let null_tag = u8::deserialize(&mut *deserializer)?; - match null_tag { - 1 => Ok(None), - 0 => Ok(Some(ScalarImpl::deserialize(ty, deserializer)?)), - _ => Err(memcomparable::Error::InvalidTagEncoding(null_tag as _)), - } -} - -// TODO(MrCroxx): turn Datum into a struct, and impl ser/de as its member functions. (#477) -#[cfg_attr(not(test), expect(dead_code))] -fn memcmp_deserialize_datum_not_null_from( - ty: &DataType, - deserializer: &mut memcomparable::Deserializer, -) -> memcomparable::Result { - Ok(Some(ScalarImpl::deserialize(ty, deserializer)?)) -} - /// This trait is to implement `to_owned_datum` for `Option` pub trait ToOwnedDatum { /// Convert the datum to an owned [`Datum`]. @@ -1108,64 +1057,6 @@ impl ScalarImpl { }) } - /// Deserialize the `data_size` of `input_data_type` in `storage_encoding`. This function will - /// consume the offset of deserializer then return the length (without memcopy, only length - /// calculation). The difference between `encoding_data_size` and `ScalarImpl::data_size` is - /// that `ScalarImpl::data_size` calculates the `memory_length` of type instead of - /// `storage_encoding` - pub fn encoding_data_size( - data_type: &DataType, - deserializer: &mut memcomparable::Deserializer, - ) -> memcomparable::Result { - let base_position = deserializer.position(); - let null_tag = u8::deserialize(&mut *deserializer)?; - match null_tag { - 1 => {} - 0 => { - use std::mem::size_of; - let len = match data_type { - DataType::Int16 => size_of::(), - DataType::Int32 => size_of::(), - DataType::Int64 => size_of::(), - DataType::Serial => size_of::(), - DataType::Float32 => size_of::(), - DataType::Float64 => size_of::(), - DataType::Date => size_of::(), - DataType::Time => size_of::(), - DataType::Timestamp => size_of::(), - DataType::Timestamptz => size_of::(), - DataType::Boolean => size_of::(), - // IntervalUnit is serialized as (i32, i32, i64) - DataType::Interval => size_of::<(i32, i32, i64)>(), - DataType::Decimal => { - deserializer.deserialize_decimal()?; - 0 // the len is not used since decimal is not a fixed length type - } - // these two types is var-length and should only be determine at runtime. - // TODO: need some test for this case (e.g. e2e test) - DataType::List { .. } => deserializer.skip_bytes()?, - DataType::Struct(t) => t - .fields - .iter() - .map(|field| Self::encoding_data_size(field, deserializer)) - .try_fold(0, |a, b| b.map(|b| a + b))?, - DataType::Jsonb => deserializer.skip_bytes()?, - DataType::Varchar => deserializer.skip_bytes()?, - DataType::Bytea => deserializer.skip_bytes()?, - }; - - // consume offset of fixed_type - if deserializer.position() == base_position + 1 { - // fixed type - deserializer.advance(len); - } - } - _ => return Err(memcomparable::Error::InvalidTagEncoding(null_tag as _)), - } - - Ok(deserializer.position() - base_position) - } - pub fn as_integral(&self) -> i64 { match self { Self::Int16(v) => *v as i64, @@ -1232,89 +1123,12 @@ pub fn literal_type_match(data_type: &DataType, literal: Option<&ScalarImpl>) -> #[cfg(test)] mod tests { use std::hash::{BuildHasher, Hasher}; - use std::ops::Neg; - use itertools::Itertools; - use rand::thread_rng; use strum::IntoEnumIterator; use super::*; use crate::util::hash_util::Crc32FastBuilder; - fn serialize_datum_not_null_into_vec(data: i64) -> Vec { - let mut serializer = memcomparable::Serializer::new(vec![]); - memcmp_serialize_datum_not_null_into(&Some(ScalarImpl::Int64(data)), &mut serializer) - .unwrap(); - serializer.into_inner() - } - - #[test] - fn test_memcomparable() { - let memcmp_minus_1 = serialize_datum_not_null_into_vec(-1); - let memcmp_3874 = serialize_datum_not_null_into_vec(3874); - let memcmp_45745 = serialize_datum_not_null_into_vec(45745); - let memcmp_21234 = serialize_datum_not_null_into_vec(21234); - assert!(memcmp_3874 < memcmp_45745); - assert!(memcmp_3874 < memcmp_21234); - assert!(memcmp_21234 < memcmp_45745); - - assert!(memcmp_minus_1 < memcmp_3874); - assert!(memcmp_minus_1 < memcmp_21234); - assert!(memcmp_minus_1 < memcmp_45745); - } - - #[test] - fn test_issue_legacy_2057_ordered_float_memcomparable() { - use num_traits::*; - use rand::seq::SliceRandom; - - fn serialize(f: OrderedF32) -> Vec { - let mut serializer = memcomparable::Serializer::new(vec![]); - memcmp_serialize_datum_not_null_into(&Some(f.into()), &mut serializer).unwrap(); - serializer.into_inner() - } - - fn deserialize(data: Vec) -> OrderedF32 { - let mut deserializer = memcomparable::Deserializer::new(data.as_slice()); - let datum = - memcmp_deserialize_datum_not_null_from(&DataType::Float32, &mut deserializer) - .unwrap(); - datum.unwrap().try_into().unwrap() - } - - let floats = vec![ - // -inf - OrderedF32::neg_infinity(), - // -1 - OrderedF32::one().neg(), - // 0, -0 should be treated the same - OrderedF32::zero(), - OrderedF32::neg_zero(), - OrderedF32::zero(), - // 1 - OrderedF32::one(), - // inf - OrderedF32::infinity(), - // nan, -nan should be treated the same - OrderedF32::nan(), - OrderedF32::nan().neg(), - OrderedF32::nan(), - ]; - assert!(floats.is_sorted()); - - let mut floats_clone = floats.clone(); - floats_clone.shuffle(&mut thread_rng()); - floats_clone.sort(); - assert_eq!(floats, floats_clone); - - let memcomparables = floats.clone().into_iter().map(serialize).collect_vec(); - assert!(memcomparables.is_sorted()); - - let decoded_floats = memcomparables.into_iter().map(deserialize).collect_vec(); - assert!(decoded_floats.is_sorted()); - assert_eq!(floats, decoded_floats); - } - #[test] fn test_size() { use static_assertions::const_assert_eq; diff --git a/src/common/src/util/encoding_for_comparison.rs b/src/common/src/util/encoding_for_comparison.rs deleted file mode 100644 index 409a3b2f4252..000000000000 --- a/src/common/src/util/encoding_for_comparison.rs +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use itertools::Itertools; - -use super::iter_util::ZipEqFast; -use super::sort_util::Direction; -use crate::array::{ArrayImpl, DataChunk}; -use crate::error::Result; -use crate::row::OwnedRow; -use crate::types::{memcmp_serialize_datum_into, ScalarRefImpl}; -use crate::util::sort_util::{ColumnOrder, OrderType}; - -fn encode_value(value: Option>, order: &OrderType) -> Result> { - let mut serializer = memcomparable::Serializer::new(vec![]); - serializer.set_reverse(order.direction() == Direction::Descending); - memcmp_serialize_datum_into(value, &mut serializer)?; - Ok(serializer.into_inner()) -} - -fn encode_array(array: &ArrayImpl, order: &OrderType) -> Result>> { - let mut data = Vec::with_capacity(array.len()); - for datum in array.iter() { - data.push(encode_value(datum, order)?); - } - Ok(data) -} - -/// This function is used to accelerate the comparison of tuples. It takes datachunk and -/// user-defined order as input, yield encoded binary string with order preserved for each tuple in -/// the datachunk. -/// -/// TODO: specify the order for `NULL`. -pub fn encode_chunk(chunk: &DataChunk, column_orders: &[ColumnOrder]) -> Vec> { - let encoded_columns = column_orders - .iter() - .map(|o| encode_array(chunk.column_at(o.column_index).array_ref(), &o.order_type).unwrap()) - .collect_vec(); - - let mut encoded_chunk = vec![vec![]; chunk.capacity()]; - for encoded_column in encoded_columns { - for (encoded_row, data) in encoded_chunk.iter_mut().zip_eq_fast(encoded_column) { - encoded_row.extend(data); - } - } - - encoded_chunk -} - -pub fn encode_row(row: &OwnedRow, column_orders: &[ColumnOrder]) -> Vec { - let mut encoded_row = vec![]; - column_orders.iter().for_each(|o| { - let value = row[o.column_index].as_ref(); - encoded_row - .extend(encode_value(value.map(|x| x.as_scalar_ref_impl()), &o.order_type).unwrap()); - }); - encoded_row -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - - use super::{encode_chunk, encode_row, encode_value}; - use crate::array::DataChunk; - use crate::row::OwnedRow; - use crate::types::{DataType, ScalarImpl}; - use crate::util::sort_util::{ColumnOrder, OrderType}; - - #[test] - fn test_encode_row() { - let v10 = Some(ScalarImpl::Int32(42)); - let v10_cloned = v10.clone(); - let v11 = Some(ScalarImpl::Utf8("hello".into())); - let v11_cloned = v11.clone(); - let v12 = Some(ScalarImpl::Float32(4.0.into())); - let v20 = Some(ScalarImpl::Int32(42)); - let v21 = Some(ScalarImpl::Utf8("hell".into())); - let v22 = Some(ScalarImpl::Float32(3.0.into())); - - let row1 = OwnedRow::new(vec![v10, v11, v12]); - let row2 = OwnedRow::new(vec![v20, v21, v22]); - let column_orders = vec![ - ColumnOrder::new(0, OrderType::ascending()), - ColumnOrder::new(1, OrderType::descending()), - ]; - - let encoded_row1 = encode_row(&row1, &column_orders); - let encoded_v10 = encode_value( - v10_cloned.as_ref().map(|x| x.as_scalar_ref_impl()), - &OrderType::ascending(), - ) - .unwrap(); - let encoded_v11 = encode_value( - v11_cloned.as_ref().map(|x| x.as_scalar_ref_impl()), - &OrderType::descending(), - ) - .unwrap(); - let concated_encoded_row1 = encoded_v10 - .into_iter() - .chain(encoded_v11.into_iter()) - .collect_vec(); - assert_eq!(encoded_row1, concated_encoded_row1); - - let encoded_row2 = encode_row(&row2, &column_orders); - assert!(encoded_row1 < encoded_row2); - } - - #[test] - fn test_encode_chunk() { - let v10 = Some(ScalarImpl::Int32(42)); - let v11 = Some(ScalarImpl::Utf8("hello".into())); - let v12 = Some(ScalarImpl::Float32(4.0.into())); - let v20 = Some(ScalarImpl::Int32(42)); - let v21 = Some(ScalarImpl::Utf8("hell".into())); - let v22 = Some(ScalarImpl::Float32(3.0.into())); - - let row1 = OwnedRow::new(vec![v10, v11, v12]); - let row2 = OwnedRow::new(vec![v20, v21, v22]); - let chunk = DataChunk::from_rows( - &[row1.clone(), row2.clone()], - &[DataType::Int32, DataType::Varchar, DataType::Float32], - ); - let column_orders = vec![ - ColumnOrder::new(0, OrderType::ascending()), - ColumnOrder::new(1, OrderType::descending()), - ]; - - let encoded_row1 = encode_row(&row1, &column_orders); - let encoded_row2 = encode_row(&row2, &column_orders); - let encoded_chunk = encode_chunk(&chunk, &column_orders); - assert_eq!(&encoded_chunk, &[encoded_row1, encoded_row2]); - } -} diff --git a/src/common/src/util/memcmp_encoding.rs b/src/common/src/util/memcmp_encoding.rs new file mode 100644 index 000000000000..ea8fec2cab5a --- /dev/null +++ b/src/common/src/util/memcmp_encoding.rs @@ -0,0 +1,558 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::{Buf, BufMut}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use super::iter_util::ZipEqFast; +use crate::array::serial_array::Serial; +use crate::array::{ArrayImpl, DataChunk}; +use crate::error::Result; +use crate::row::Row; +use crate::types::{ + DataType, Datum, NaiveDateTimeWrapper, NaiveDateWrapper, NaiveTimeWrapper, OrderedF32, + OrderedF64, ScalarImpl, ToDatumRef, +}; +use crate::util::sort_util::{ColumnOrder, OrderType}; + +// NULL > any non-NULL value by default +const DEFAULT_NULL_TAG_NONE: u8 = 1; +const DEFAULT_NULL_TAG_SOME: u8 = 0; + +pub(crate) fn serialize_datum( + datum: impl ToDatumRef, + order: OrderType, + serializer: &mut memcomparable::Serializer, +) -> memcomparable::Result<()> { + serializer.set_reverse(order.is_descending()); + let (null_tag_none, null_tag_some) = if order.nulls_are_largest() { + (1u8, 0u8) // None > Some + } else { + (0u8, 1u8) // None < Some + }; + if let Some(scalar) = datum.to_datum_ref() { + null_tag_some.serialize(&mut *serializer)?; + scalar.serialize(serializer)?; + } else { + null_tag_none.serialize(serializer)?; + } + Ok(()) +} + +pub(crate) fn serialize_datum_in_composite( + datum: impl ToDatumRef, + serializer: &mut memcomparable::Serializer, +) -> memcomparable::Result<()> { + // NOTE: No need to call `serializer.set_reverse` because we are inside a + // composite type value, we should follow the outside order, except for `NULL`s. + if let Some(scalar) = datum.to_datum_ref() { + DEFAULT_NULL_TAG_SOME.serialize(&mut *serializer)?; + scalar.serialize(serializer)?; + } else { + DEFAULT_NULL_TAG_NONE.serialize(serializer)?; + } + Ok(()) +} + +pub(crate) fn deserialize_datum( + ty: &DataType, + order: OrderType, + deserializer: &mut memcomparable::Deserializer, +) -> memcomparable::Result { + deserializer.set_reverse(order.is_descending()); + let null_tag = u8::deserialize(&mut *deserializer)?; + let (null_tag_none, null_tag_some) = if order.nulls_are_largest() { + (1u8, 0u8) // None > Some + } else { + (0u8, 1u8) // None < Some + }; + if null_tag == null_tag_none { + Ok(None) + } else if null_tag == null_tag_some { + Ok(Some(ScalarImpl::deserialize(ty, deserializer)?)) + } else { + Err(memcomparable::Error::InvalidTagEncoding(null_tag as _)) + } +} + +pub(crate) fn deserialize_datum_in_composite( + ty: &DataType, + deserializer: &mut memcomparable::Deserializer, +) -> memcomparable::Result { + // NOTE: Similar to serialization, we should follow the outside order, except for `NULL`s. + let null_tag = u8::deserialize(&mut *deserializer)?; + if null_tag == DEFAULT_NULL_TAG_NONE { + Ok(None) + } else if null_tag == DEFAULT_NULL_TAG_SOME { + Ok(Some(ScalarImpl::deserialize(ty, deserializer)?)) + } else { + Err(memcomparable::Error::InvalidTagEncoding(null_tag as _)) + } +} + +/// Deserialize the `data_size` of `input_data_type` in `memcmp_encoding`. This function will +/// consume the offset of deserializer then return the length (without memcopy, only length +/// calculation). +pub(crate) fn calculate_encoded_size( + ty: &DataType, + order: OrderType, + encoded_data: &[u8], +) -> memcomparable::Result { + let mut deserializer = memcomparable::Deserializer::new(encoded_data); + let (null_tag_none, null_tag_some) = if order.nulls_are_largest() { + (1u8, 0u8) // None > Some + } else { + (0u8, 1u8) // None < Some + }; + deserializer.set_reverse(order.is_descending()); + calculate_encoded_size_inner(ty, null_tag_none, null_tag_some, &mut deserializer) +} + +fn calculate_encoded_size_inner( + ty: &DataType, + null_tag_none: u8, + null_tag_some: u8, + deserializer: &mut memcomparable::Deserializer, +) -> memcomparable::Result { + let base_position = deserializer.position(); + let null_tag = u8::deserialize(&mut *deserializer)?; + if null_tag == null_tag_none { + // deserialize nothing more + } else if null_tag == null_tag_some { + use std::mem::size_of; + let len = match ty { + DataType::Int16 => size_of::(), + DataType::Int32 => size_of::(), + DataType::Int64 => size_of::(), + DataType::Serial => size_of::(), + DataType::Float32 => size_of::(), + DataType::Float64 => size_of::(), + DataType::Date => size_of::(), + DataType::Time => size_of::(), + DataType::Timestamp => size_of::(), + DataType::Timestamptz => size_of::(), + DataType::Boolean => size_of::(), + // IntervalUnit is serialized as (i32, i32, i64) + DataType::Interval => size_of::<(i32, i32, i64)>(), + DataType::Decimal => { + deserializer.deserialize_decimal()?; + 0 // the len is not used since decimal is not a fixed length type + } + // these two types is var-length and should only be determine at runtime. + // TODO: need some test for this case (e.g. e2e test) + DataType::List { .. } => deserializer.skip_bytes()?, + DataType::Struct(t) => t + .fields + .iter() + .map(|field| { + // use default null tags inside composite type + calculate_encoded_size_inner( + field, + DEFAULT_NULL_TAG_NONE, + DEFAULT_NULL_TAG_SOME, + deserializer, + ) + }) + .try_fold(0, |a, b| b.map(|b| a + b))?, + DataType::Jsonb => deserializer.skip_bytes()?, + DataType::Varchar => deserializer.skip_bytes()?, + DataType::Bytea => deserializer.skip_bytes()?, + }; + + // consume offset of fixed_type + if deserializer.position() == base_position + 1 { + // fixed type + deserializer.advance(len); + } + } else { + return Err(memcomparable::Error::InvalidTagEncoding(null_tag as _)); + } + + Ok(deserializer.position() - base_position) +} + +pub fn encode_value(value: impl ToDatumRef, order: OrderType) -> Result> { + let mut serializer = memcomparable::Serializer::new(vec![]); + serialize_datum(value, order, &mut serializer)?; + Ok(serializer.into_inner()) +} + +pub fn decode_value(ty: &DataType, encoded_value: &[u8], order: OrderType) -> Result { + let mut deserializer = memcomparable::Deserializer::new(encoded_value); + Ok(deserialize_datum(ty, order, &mut deserializer)?) +} + +pub fn encode_array(array: &ArrayImpl, order: OrderType) -> Result>> { + let mut data = Vec::with_capacity(array.len()); + for datum in array.iter() { + data.push(encode_value(datum, order)?); + } + Ok(data) +} + +/// This function is used to accelerate the comparison of tuples. It takes datachunk and +/// user-defined order as input, yield encoded binary string with order preserved for each tuple in +/// the datachunk. +pub fn encode_chunk(chunk: &DataChunk, column_orders: &[ColumnOrder]) -> Vec> { + let encoded_columns = column_orders + .iter() + .map(|o| encode_array(chunk.column_at(o.column_index).array_ref(), o.order_type).unwrap()) + .collect_vec(); + + let mut encoded_chunk = vec![vec![]; chunk.capacity()]; + for encoded_column in encoded_columns { + for (encoded_row, data) in encoded_chunk.iter_mut().zip_eq_fast(encoded_column) { + encoded_row.extend(data); + } + } + + encoded_chunk +} + +/// Encode a row into memcomparable format. +pub fn encode_row(row: impl Row, column_orders: &[ColumnOrder]) -> Vec { + let mut encoded_row = vec![]; + column_orders.iter().for_each(|o| { + encoded_row.extend(encode_value(row.datum_at(o.column_index), o.order_type).unwrap()); + }); + encoded_row +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use itertools::Itertools; + use rand::thread_rng; + + use super::*; + use crate::array::{DataChunk, ListValue, StructValue}; + use crate::row::OwnedRow; + use crate::types::{DataType, OrderedF32, ScalarImpl}; + use crate::util::sort_util::{ColumnOrder, OrderType}; + + #[test] + fn test_memcomparable() { + fn encode_num(num: Option, order_type: OrderType) -> Vec { + encode_value(num.map(ScalarImpl::from), order_type).unwrap() + } + + { + // default ascending + let order_type = OrderType::ascending(); + let memcmp_minus_1 = encode_num(Some(-1), order_type); + let memcmp_3874 = encode_num(Some(3874), order_type); + let memcmp_45745 = encode_num(Some(45745), order_type); + let memcmp_i32_min = encode_num(Some(i32::MIN), order_type); + let memcmp_i32_max = encode_num(Some(i32::MAX), order_type); + let memcmp_none = encode_num(None, order_type); + + assert!(memcmp_3874 < memcmp_45745); + assert!(memcmp_3874 < memcmp_i32_max); + assert!(memcmp_45745 < memcmp_i32_max); + + assert!(memcmp_i32_min < memcmp_i32_max); + assert!(memcmp_i32_min < memcmp_3874); + assert!(memcmp_i32_min < memcmp_45745); + + assert!(memcmp_minus_1 < memcmp_3874); + assert!(memcmp_minus_1 < memcmp_45745); + assert!(memcmp_minus_1 < memcmp_i32_max); + assert!(memcmp_minus_1 > memcmp_i32_min); + + assert!(memcmp_none > memcmp_minus_1); + assert!(memcmp_none > memcmp_3874); + assert!(memcmp_none > memcmp_i32_min); + assert!(memcmp_none > memcmp_i32_max); + } + { + // default descending + let order_type = OrderType::descending(); + let memcmp_minus_1 = encode_num(Some(-1), order_type); + let memcmp_3874 = encode_num(Some(3874), order_type); + let memcmp_none = encode_num(None, order_type); + + assert!(memcmp_none < memcmp_minus_1); + assert!(memcmp_none < memcmp_3874); + assert!(memcmp_3874 < memcmp_minus_1); + } + { + // ASC NULLS FIRST (NULLS SMALLEST) + let order_type = OrderType::ascending_nulls_first(); + let memcmp_minus_1 = encode_num(Some(-1), order_type); + let memcmp_3874 = encode_num(Some(3874), order_type); + let memcmp_none = encode_num(None, order_type); + assert!(memcmp_none < memcmp_minus_1); + assert!(memcmp_none < memcmp_3874); + } + { + // ASC NULLS LAST (NULLS LARGEST) + let order_type = OrderType::ascending_nulls_last(); + let memcmp_minus_1 = encode_num(Some(-1), order_type); + let memcmp_3874 = encode_num(Some(3874), order_type); + let memcmp_none = encode_num(None, order_type); + assert!(memcmp_none > memcmp_minus_1); + assert!(memcmp_none > memcmp_3874); + } + { + // DESC NULLS FIRST (NULLS LARGEST) + let order_type = OrderType::descending_nulls_first(); + let memcmp_minus_1 = encode_num(Some(-1), order_type); + let memcmp_3874 = encode_num(Some(3874), order_type); + let memcmp_none = encode_num(None, order_type); + assert!(memcmp_none < memcmp_minus_1); + assert!(memcmp_none < memcmp_3874); + } + { + // DESC NULLS LAST (NULLS SMALLEST) + let order_type = OrderType::descending_nulls_last(); + let memcmp_minus_1 = encode_num(Some(-1), order_type); + let memcmp_3874 = encode_num(Some(3874), order_type); + let memcmp_none = encode_num(None, order_type); + assert!(memcmp_none > memcmp_minus_1); + assert!(memcmp_none > memcmp_3874); + } + } + + #[test] + fn test_memcomparable_structs() { + // NOTE: `NULL`s inside composite type values are always the largest. + + let struct_none = None; + let struct_1 = Some( + StructValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into(), + ); + let struct_2 = Some( + StructValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into(), + ); + let struct_3 = Some(StructValue::new(vec![Some(ScalarImpl::from(1)), None]).into()); + + { + // ASC NULLS FIRST (NULLS SMALLEST) + let order_type = OrderType::ascending_nulls_first(); + let memcmp_struct_none = encode_value(&struct_none, order_type).unwrap(); + let memcmp_struct_1 = encode_value(&struct_1, order_type).unwrap(); + let memcmp_struct_2 = encode_value(&struct_2, order_type).unwrap(); + let memcmp_struct_3 = encode_value(&struct_3, order_type).unwrap(); + assert!(memcmp_struct_none < memcmp_struct_1); + assert!(memcmp_struct_1 < memcmp_struct_2); + assert!(memcmp_struct_2 < memcmp_struct_3); + } + { + // ASC NULLS LAST (NULLS LARGEST) + let order_type = OrderType::ascending_nulls_last(); + let memcmp_struct_none = encode_value(&struct_none, order_type).unwrap(); + let memcmp_struct_1 = encode_value(&struct_1, order_type).unwrap(); + let memcmp_struct_2 = encode_value(&struct_2, order_type).unwrap(); + let memcmp_struct_3 = encode_value(&struct_3, order_type).unwrap(); + assert!(memcmp_struct_1 < memcmp_struct_2); + assert!(memcmp_struct_2 < memcmp_struct_3); + assert!(memcmp_struct_3 < memcmp_struct_none); + } + { + // DESC NULLS FIRST (NULLS LARGEST) + let order_type = OrderType::descending_nulls_first(); + let memcmp_struct_none = encode_value(&struct_none, order_type).unwrap(); + let memcmp_struct_1 = encode_value(&struct_1, order_type).unwrap(); + let memcmp_struct_2 = encode_value(&struct_2, order_type).unwrap(); + let memcmp_struct_3 = encode_value(&struct_3, order_type).unwrap(); + assert!(memcmp_struct_none < memcmp_struct_3); + assert!(memcmp_struct_3 < memcmp_struct_2); + assert!(memcmp_struct_2 < memcmp_struct_1); + } + { + // DESC NULLS LAST (NULLS SMALLEST) + let order_type = OrderType::descending_nulls_last(); + let memcmp_struct_none = encode_value(&struct_none, order_type).unwrap(); + let memcmp_struct_1 = encode_value(&struct_1, order_type).unwrap(); + let memcmp_struct_2 = encode_value(&struct_2, order_type).unwrap(); + let memcmp_struct_3 = encode_value(&struct_3, order_type).unwrap(); + assert!(memcmp_struct_3 < memcmp_struct_2); + assert!(memcmp_struct_2 < memcmp_struct_1); + assert!(memcmp_struct_1 < memcmp_struct_none); + } + } + + #[test] + fn test_memcomparable_lists() { + // NOTE: `NULL`s inside composite type values are always the largest. + + let list_none = None; + let list_1 = + Some(ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(2))]).into()); + let list_2 = + Some(ListValue::new(vec![Some(ScalarImpl::from(1)), Some(ScalarImpl::from(3))]).into()); + let list_3 = Some(ListValue::new(vec![Some(ScalarImpl::from(1)), None]).into()); + + { + // ASC NULLS FIRST (NULLS SMALLEST) + let order_type = OrderType::ascending_nulls_first(); + let memcmp_list_none = encode_value(&list_none, order_type).unwrap(); + let memcmp_list_1 = encode_value(&list_1, order_type).unwrap(); + let memcmp_list_2 = encode_value(&list_2, order_type).unwrap(); + let memcmp_list_3 = encode_value(&list_3, order_type).unwrap(); + assert!(memcmp_list_none < memcmp_list_1); + assert!(memcmp_list_1 < memcmp_list_2); + assert!(memcmp_list_2 < memcmp_list_3); + } + { + // ASC NULLS LAST (NULLS LARGEST) + let order_type = OrderType::ascending_nulls_last(); + let memcmp_list_none = encode_value(&list_none, order_type).unwrap(); + let memcmp_list_1 = encode_value(&list_1, order_type).unwrap(); + let memcmp_list_2 = encode_value(&list_2, order_type).unwrap(); + let memcmp_list_3 = encode_value(&list_3, order_type).unwrap(); + assert!(memcmp_list_1 < memcmp_list_2); + assert!(memcmp_list_2 < memcmp_list_3); + assert!(memcmp_list_3 < memcmp_list_none); + } + { + // DESC NULLS FIRST (NULLS LARGEST) + let order_type = OrderType::descending_nulls_first(); + let memcmp_list_none = encode_value(&list_none, order_type).unwrap(); + let memcmp_list_1 = encode_value(&list_1, order_type).unwrap(); + let memcmp_list_2 = encode_value(&list_2, order_type).unwrap(); + let memcmp_list_3 = encode_value(&list_3, order_type).unwrap(); + assert!(memcmp_list_none < memcmp_list_3); + assert!(memcmp_list_3 < memcmp_list_2); + assert!(memcmp_list_2 < memcmp_list_1); + } + { + // DESC NULLS LAST (NULLS SMALLEST) + let order_type = OrderType::descending_nulls_last(); + let memcmp_list_none = encode_value(&list_none, order_type).unwrap(); + let memcmp_list_1 = encode_value(&list_1, order_type).unwrap(); + let memcmp_list_2 = encode_value(&list_2, order_type).unwrap(); + let memcmp_list_3 = encode_value(&list_3, order_type).unwrap(); + assert!(memcmp_list_3 < memcmp_list_2); + assert!(memcmp_list_2 < memcmp_list_1); + assert!(memcmp_list_1 < memcmp_list_none); + } + } + + #[test] + fn test_issue_legacy_2057_ordered_float_memcomparable() { + use num_traits::*; + use rand::seq::SliceRandom; + + fn serialize(f: OrderedF32) -> Vec { + encode_value(&Some(ScalarImpl::from(f)), OrderType::default()).unwrap() + } + + fn deserialize(data: Vec) -> OrderedF32 { + decode_value(&DataType::Float32, &data, OrderType::default()) + .unwrap() + .unwrap() + .into_float32() + } + + let floats = vec![ + // -inf + OrderedF32::neg_infinity(), + // -1 + OrderedF32::one().neg(), + // 0, -0 should be treated the same + OrderedF32::zero(), + OrderedF32::neg_zero(), + OrderedF32::zero(), + // 1 + OrderedF32::one(), + // inf + OrderedF32::infinity(), + // nan, -nan should be treated the same + OrderedF32::nan(), + OrderedF32::nan().neg(), + OrderedF32::nan(), + ]; + assert!(floats.is_sorted()); + + let mut floats_clone = floats.clone(); + floats_clone.shuffle(&mut thread_rng()); + floats_clone.sort(); + assert_eq!(floats, floats_clone); + + let memcomparables = floats.clone().into_iter().map(serialize).collect_vec(); + assert!(memcomparables.is_sorted()); + + let decoded_floats = memcomparables.into_iter().map(deserialize).collect_vec(); + assert!(decoded_floats.is_sorted()); + assert_eq!(floats, decoded_floats); + } + + #[test] + fn test_encode_row() { + let v10 = Some(ScalarImpl::Int32(42)); + let v10_cloned = v10.clone(); + let v11 = Some(ScalarImpl::Utf8("hello".into())); + let v11_cloned = v11.clone(); + let v12 = Some(ScalarImpl::Float32(4.0.into())); + let v20 = Some(ScalarImpl::Int32(42)); + let v21 = Some(ScalarImpl::Utf8("hell".into())); + let v22 = Some(ScalarImpl::Float32(3.0.into())); + + let row1 = OwnedRow::new(vec![v10, v11, v12]); + let row2 = OwnedRow::new(vec![v20, v21, v22]); + let column_orders = vec![ + ColumnOrder::new(0, OrderType::ascending()), + ColumnOrder::new(1, OrderType::descending()), + ]; + + let encoded_row1 = encode_row(&row1, &column_orders); + let encoded_v10 = encode_value( + v10_cloned.as_ref().map(|x| x.as_scalar_ref_impl()), + OrderType::ascending(), + ) + .unwrap(); + let encoded_v11 = encode_value( + v11_cloned.as_ref().map(|x| x.as_scalar_ref_impl()), + OrderType::descending(), + ) + .unwrap(); + let concated_encoded_row1 = encoded_v10 + .into_iter() + .chain(encoded_v11.into_iter()) + .collect_vec(); + assert_eq!(encoded_row1, concated_encoded_row1); + + let encoded_row2 = encode_row(&row2, &column_orders); + assert!(encoded_row1 < encoded_row2); + } + + #[test] + fn test_encode_chunk() { + let v10 = Some(ScalarImpl::Int32(42)); + let v11 = Some(ScalarImpl::Utf8("hello".into())); + let v12 = Some(ScalarImpl::Float32(4.0.into())); + let v20 = Some(ScalarImpl::Int32(42)); + let v21 = Some(ScalarImpl::Utf8("hell".into())); + let v22 = Some(ScalarImpl::Float32(3.0.into())); + + let row1 = OwnedRow::new(vec![v10, v11, v12]); + let row2 = OwnedRow::new(vec![v20, v21, v22]); + let chunk = DataChunk::from_rows( + &[row1.clone(), row2.clone()], + &[DataType::Int32, DataType::Varchar, DataType::Float32], + ); + let column_orders = vec![ + ColumnOrder::new(0, OrderType::ascending()), + ColumnOrder::new(1, OrderType::descending()), + ]; + + let encoded_row1 = encode_row(&row1, &column_orders); + let encoded_row2 = encode_row(&row2, &column_orders); + let encoded_chunk = encode_chunk(&chunk, &column_orders); + assert_eq!(&encoded_chunk, &[encoded_row1, encoded_row2]); + } +} diff --git a/src/common/src/util/mod.rs b/src/common/src/util/mod.rs index 0420b0c59690..9c7d214c660c 100644 --- a/src/common/src/util/mod.rs +++ b/src/common/src/util/mod.rs @@ -23,12 +23,12 @@ pub mod addr; pub mod chunk_coalesce; pub mod column_index_mapping; pub mod compress; -pub mod encoding_for_comparison; pub mod env_var; pub mod epoch; mod future_utils; pub mod hash_util; pub mod iter_util; +pub mod memcmp_encoding; pub mod ordered; pub mod prost; pub mod resource_util; diff --git a/src/common/src/util/ordered/mod.rs b/src/common/src/util/ordered/mod.rs index 08e10965fda5..394ceee2fe31 100644 --- a/src/common/src/util/ordered/mod.rs +++ b/src/common/src/util/ordered/mod.rs @@ -14,159 +14,4 @@ mod serde; -use std::cmp::Reverse; - -use OrderedDatum::{NormalOrder, ReversedOrder}; - pub use self::serde::*; -use super::iter_util::ZipEqFast; -use super::sort_util::Direction; -use crate::row::OwnedRow; -use crate::types::{memcmp_serialize_datum_into, Datum}; -use crate::util::sort_util::OrderType; - -// TODO(rc): support `NULLS FIRST | LAST` -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub enum OrderedDatum { - NormalOrder(Datum), - ReversedOrder(Reverse), -} - -impl std::fmt::Debug for OrderedDatum { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - NormalOrder(d) => match d { - Some(s) => write!(f, "{:?}", s), - None => write!(f, "NULL"), - }, - ReversedOrder(d) => match &d.0 { - Some(s) => write!(f, "{:?}", s), - None => write!(f, "NULL"), - }, - } - } -} - -/// `OrderedRow` is used for the pk in those states whose primary key contains several columns and -/// requires comparison. -#[derive(Clone, Eq, PartialEq, Ord, PartialOrd)] -pub struct OrderedRow(Vec); - -impl std::fmt::Debug for OrderedRow { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - -impl OrderedRow { - pub fn new(row: OwnedRow, order_types: &[OrderType]) -> Self { - OrderedRow( - row.into_inner() - .into_iter() - .zip_eq_fast(order_types.iter()) - .map(|(datum, order_type)| match order_type.direction() { - Direction::Ascending => NormalOrder(datum), - Direction::Descending => ReversedOrder(Reverse(datum)), - }) - .collect::>(), - ) - } - - pub fn into_vec(self) -> Vec { - self.0 - .into_iter() - .map(|ordered_datum| match ordered_datum { - NormalOrder(datum) => datum, - ReversedOrder(datum) => datum.0, - }) - .collect::>() - } - - pub fn into_row(self) -> OwnedRow { - OwnedRow::new(self.into_vec()) - } - - /// Serialize the row into a memcomparable bytes. - /// - /// All values are nullable. Each value will have 1 extra byte to indicate whether it is null. - pub fn serialize(&self) -> Result, memcomparable::Error> { - let mut serializer = memcomparable::Serializer::new(vec![]); - for v in &self.0 { - let datum = match v { - NormalOrder(datum) => { - serializer.set_reverse(false); - datum - } - ReversedOrder(datum) => { - serializer.set_reverse(true); - &datum.0 - } - }; - memcmp_serialize_datum_into(datum, &mut serializer)?; - } - Ok(serializer.into_inner()) - } - - pub fn reverse_serialize(&self) -> Result, memcomparable::Error> { - let mut res = self.serialize()?; - res.iter_mut().for_each(|byte| *byte = !*byte); - Ok(res) - } - - pub fn prefix(&self, n: usize) -> Self { - assert!(n <= self.0.len()); - OrderedRow(self.0[..n].to_vec()) - } - - pub fn starts_with(&self, other: &Self) -> bool { - self.0.starts_with(&other.0) - } - - pub fn skip(&self, n: usize) -> Self { - assert!(n < self.0.len()); - OrderedRow(self.0[n..].to_vec()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::types::ScalarImpl; - - fn make_row(values: Vec) -> OrderedRow { - let row = OwnedRow::new( - values - .into_iter() - .map(|v| Some(ScalarImpl::Int64(v))) - .collect(), - ); - OrderedRow::new(row, ORDER_TYPES) - } - - const ORDER_TYPES: &[OrderType] = &[ - OrderType::ascending(), - OrderType::descending(), - OrderType::ascending(), - ]; - - #[test] - fn test_prefix() { - let row = make_row(vec![1, 2, 3]); - - assert!(row.prefix(0) < row.prefix(1)); - assert!(row.prefix(1) < row.prefix(2)); - assert!(row.prefix(2) < row.prefix(3)); - assert_eq!(row.prefix(3), row); - - let row2 = make_row(vec![1, 3, 3]); - assert!(row.prefix(1) < row2); - assert!(row.prefix(2) > row2); - } - - #[should_panic] - #[test] - fn test_prefix_panic() { - let row = make_row(vec![1, 2, 3]); - row.prefix(4); - } -} diff --git a/src/common/src/util/ordered/serde.rs b/src/common/src/util/ordered/serde.rs index 2cff721756f4..6cafbf6b782b 100644 --- a/src/common/src/util/ordered/serde.rs +++ b/src/common/src/util/ordered/serde.rs @@ -17,11 +17,10 @@ use std::borrow::Cow; use bytes::BufMut; use crate::row::{OwnedRow, Row}; -use crate::types::{ - memcmp_deserialize_datum_from, memcmp_serialize_datum_into, DataType, ToDatumRef, -}; +use crate::types::{DataType, ToDatumRef}; use crate::util::iter_util::{ZipEqDebug, ZipEqFast}; -use crate::util::sort_util::{Direction, OrderType}; +use crate::util::memcmp_encoding; +use crate::util::sort_util::OrderType; /// `OrderedRowSerde` is responsible for serializing and deserializing Ordered Row. #[derive(Clone)] @@ -64,19 +63,21 @@ impl OrderedRowSerde { datum_refs: impl Iterator, mut append_to: impl BufMut, ) { - for (datum, order_type) in datum_refs.zip_eq_debug(self.order_types.iter()) { - let mut serializer = memcomparable::Serializer::new(&mut append_to); - serializer.set_reverse(order_type.direction() == Direction::Descending); - memcmp_serialize_datum_into(datum, &mut serializer).unwrap(); + let mut serializer = memcomparable::Serializer::new(&mut append_to); + for (datum, order) in datum_refs.zip_eq_debug(self.order_types.iter().copied()) { + memcmp_encoding::serialize_datum(datum, order, &mut serializer).unwrap(); } } pub fn deserialize(&self, data: &[u8]) -> memcomparable::Result { let mut values = Vec::with_capacity(self.schema.len()); let mut deserializer = memcomparable::Deserializer::new(data); - for (data_type, order_type) in self.schema.iter().zip_eq_fast(self.order_types.iter()) { - deserializer.set_reverse(order_type.direction() == Direction::Descending); - let datum = memcmp_deserialize_datum_from(data_type, &mut deserializer)?; + for (data_type, order) in self + .schema + .iter() + .zip_eq_fast(self.order_types.iter().copied()) + { + let datum = memcmp_encoding::deserialize_datum(data_type, order, &mut deserializer)?; values.push(datum); } Ok(OwnedRow::new(values)) @@ -95,18 +96,13 @@ impl OrderedRowSerde { key: &[u8], prefix_len: usize, ) -> memcomparable::Result { - use crate::types::ScalarImpl; let mut len: usize = 0; for index in 0..prefix_len { let data_type = &self.schema[index]; - let order_type = &self.order_types[index]; let data = &key[len..]; - let mut deserializer = memcomparable::Deserializer::new(data); - deserializer.set_reverse(order_type.direction() == Direction::Descending); - - len += ScalarImpl::encoding_data_size(data_type, &mut deserializer)?; + len += + memcmp_encoding::calculate_encoded_size(data_type, self.order_types[index], data)?; } - Ok(len) } } @@ -247,7 +243,7 @@ mod tests { let order_types = vec![OrderType::ascending()]; let schema = vec![DataType::Int16]; - let serde = OrderedRowSerde::new(schema, order_types); + let serde = OrderedRowSerde::new(schema, order_types.clone()); // test fixed_size { @@ -256,9 +252,12 @@ mod tests { let row = OwnedRow::new(vec![None]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Int16, &mut deserializer).unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Int16, + order_types[0], + &row_bytes[..], + ) + .unwrap(); assert_eq!(1, encoding_data_size); } @@ -267,9 +266,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Float64(6.4.into()))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Float64, &mut deserializer).unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Float64, + order_types[0], + &row_bytes[..], + ) + .unwrap(); let data_size = size_of::(); assert_eq!(8, data_size); assert_eq!(1 + data_size, encoding_data_size); @@ -280,9 +282,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Bool(false))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Boolean, &mut deserializer).unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Boolean, + order_types[0], + &row_bytes[..], + ) + .unwrap(); let data_size = size_of::(); assert_eq!(1, data_size); @@ -296,10 +301,12 @@ mod tests { ))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Timestamp, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Timestamp, + order_types[0], + &row_bytes[..], + ) + .unwrap(); let data_size = size_of::(); assert_eq!(12, data_size); assert_eq!(1 + data_size, encoding_data_size); @@ -310,10 +317,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Int64(1111111111))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Timestamptz, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Timestamptz, + order_types[0], + &row_bytes[..], + ) + .unwrap(); let data_size = size_of::(); assert_eq!(8, data_size); assert_eq!(1 + data_size, encoding_data_size); @@ -326,9 +335,12 @@ mod tests { ))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Interval, &mut deserializer).unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Interval, + order_types[0], + &row_bytes[..], + ) + .unwrap(); let data_size = size_of::(); assert_eq!(16, data_size); assert_eq!(1 + data_size, encoding_data_size); @@ -346,10 +358,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Decimal(d))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Decimal, + order_types[0], + &row_bytes[..], + ) + .unwrap(); // [nulltag, flag, decimal_chunk] assert_eq!(17, encoding_data_size); } @@ -359,10 +373,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Decimal(d))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Decimal, + order_types[0], + &row_bytes[..], + ) + .unwrap(); // [nulltag, flag, decimal_chunk] assert_eq!(3, encoding_data_size); } @@ -372,10 +388,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Decimal(d))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Decimal, + order_types[0], + &row_bytes[..], + ) + .unwrap(); assert_eq!(2, encoding_data_size); // [1, 35] } @@ -385,10 +403,12 @@ mod tests { let row = OwnedRow::new(vec![Some(ScalarImpl::Decimal(d))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Decimal, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Decimal, + order_types[0], + &row_bytes[..], + ) + .unwrap(); assert_eq!(2, encoding_data_size); // [1, 6] } @@ -402,10 +422,12 @@ mod tests { let row = OwnedRow::new(vec![Some(Utf8(varchar.into()))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Varchar, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Varchar, + order_types[0], + &row_bytes[..], + ) + .unwrap(); // [1, 1, 97, 98, 99, 100, 101, 102, 103, 104, 9, 105, 106, 107, 108, 109, 110, // 0, 0, 6] assert_eq!(6 + varchar.len(), encoding_data_size); @@ -416,16 +438,17 @@ mod tests { // test varchar Descending let order_types = vec![OrderType::descending()]; let schema = vec![DataType::Varchar]; - let serde = OrderedRowSerde::new(schema, order_types); + let serde = OrderedRowSerde::new(schema, order_types.clone()); let varchar = "abcdefghijklmnopq"; let row = OwnedRow::new(vec![Some(Utf8(varchar.into()))]); let mut row_bytes = vec![]; serde.serialize(&row, &mut row_bytes); - let mut deserializer = memcomparable::Deserializer::new(&row_bytes[..]); - deserializer.set_reverse(true); - let encoding_data_size = - ScalarImpl::encoding_data_size(&DataType::Varchar, &mut deserializer) - .unwrap(); + let encoding_data_size = memcmp_encoding::calculate_encoded_size( + &DataType::Varchar, + order_types[0], + &row_bytes[..], + ) + .unwrap(); // [254, 254, 158, 157, 156, 155, 154, 153, 152, 151, 246, 150, 149, 148, // 147, 146, 145, 144, 143, 246, 142, 255, 255, 255, 255, 255, 255, 255, diff --git a/src/common/src/util/sort_util.rs b/src/common/src/util/sort_util.rs index 65bae034a26b..f795382abf2d 100644 --- a/src/common/src/util/sort_util.rs +++ b/src/common/src/util/sort_util.rs @@ -17,17 +17,17 @@ use std::fmt; use std::sync::Arc; use parse_display::Display; -use risingwave_pb::common::{PbColumnOrder, PbDirection, PbOrderType}; +use risingwave_pb::common::{PbColumnOrder, PbDirection, PbNullsAre, PbOrderType}; use crate::array::{Array, ArrayImpl, DataChunk}; use crate::catalog::{FieldDisplay, Schema}; use crate::error::ErrorCode::InternalError; use crate::error::Result; +use crate::types::ToDatumRef; -// TODO(rc): to support `NULLS FIRST | LAST`, we may need to hide this enum, forcing developers use -// `OrderType` instead. +/// Sort direction, ascending/descending. #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug, Display, Default)] -pub enum Direction { +enum Direction { #[default] #[display("ASC")] Ascending, @@ -36,25 +36,25 @@ pub enum Direction { } impl Direction { - pub fn from_protobuf(order_type: &PbDirection) -> Direction { - match order_type { - PbDirection::Ascending => Direction::Ascending, - PbDirection::Descending => Direction::Descending, + pub fn from_protobuf(direction: &PbDirection) -> Self { + match direction { + PbDirection::Ascending => Self::Ascending, + PbDirection::Descending => Self::Descending, PbDirection::Unspecified => unreachable!(), } } pub fn to_protobuf(self) -> PbDirection { match self { - Direction::Ascending => PbDirection::Ascending, - Direction::Descending => PbDirection::Descending, + Self::Ascending => PbDirection::Ascending, + Self::Descending => PbDirection::Descending, } } } -#[allow(dead_code)] +/// Nulls are largest/smallest. #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug, Display, Default)] -pub enum NullsAre { +enum NullsAre { #[default] #[display("LARGEST")] Largest, @@ -62,55 +62,159 @@ pub enum NullsAre { Smallest, } +impl NullsAre { + pub fn from_protobuf(nulls_are: &PbNullsAre) -> Self { + match nulls_are { + PbNullsAre::Largest => Self::Largest, + PbNullsAre::Smallest => Self::Smallest, + PbNullsAre::Unspecified => unreachable!(), + } + } + + pub fn to_protobuf(self) -> PbNullsAre { + match self { + Self::Largest => PbNullsAre::Largest, + Self::Smallest => PbNullsAre::Smallest, + } + } +} + +/// Order type of a column. #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug, Default)] pub struct OrderType { direction: Direction, - // TODO(rc): enable `NULLS FIRST | LAST` - // nulls_are: NullsAre, + nulls_are: NullsAre, } impl OrderType { pub fn from_protobuf(order_type: &PbOrderType) -> OrderType { OrderType { direction: Direction::from_protobuf(&order_type.direction()), + nulls_are: NullsAre::from_protobuf(&order_type.nulls_are()), } } pub fn to_protobuf(self) -> PbOrderType { PbOrderType { direction: self.direction.to_protobuf() as _, + nulls_are: self.nulls_are.to_protobuf() as _, } } } impl OrderType { - pub const fn new(direction: Direction) -> Self { - Self { direction } + fn new(direction: Direction, nulls_are: NullsAre) -> Self { + Self { + direction, + nulls_are, + } + } + + fn nulls_first(direction: Direction) -> Self { + match direction { + Direction::Ascending => Self::new(direction, NullsAre::Smallest), + Direction::Descending => Self::new(direction, NullsAre::Largest), + } + } + + fn nulls_last(direction: Direction) -> Self { + match direction { + Direction::Ascending => Self::new(direction, NullsAre::Largest), + Direction::Descending => Self::new(direction, NullsAre::Smallest), + } } - /// Create an ascending order type, with other options set to default. - pub const fn ascending() -> Self { + pub fn from_bools(asc: Option, nulls_first: Option) -> Self { + let direction = match asc { + None => Direction::default(), + Some(true) => Direction::Ascending, + Some(false) => Direction::Descending, + }; + match nulls_first { + None => Self::new(direction, NullsAre::default()), + Some(true) => Self::nulls_first(direction), + Some(false) => Self::nulls_last(direction), + } + } + + // TODO(rc): Many places that call `ascending` should've call `default`. + /// Create an `ASC` order type. + pub fn ascending() -> Self { Self { direction: Direction::Ascending, + nulls_are: NullsAre::default(), } } - /// Create an descending order type, with other options set to default. - pub const fn descending() -> Self { + /// Create a `DESC` order type. + pub fn descending() -> Self { Self { direction: Direction::Descending, + nulls_are: NullsAre::default(), } } - /// Get the order direction. - pub fn direction(&self) -> Direction { - self.direction + /// Create an `ASC NULLS FIRST` order type. + pub fn ascending_nulls_first() -> Self { + Self::nulls_first(Direction::Ascending) + } + + /// Create an `ASC NULLS LAST` order type. + pub fn ascending_nulls_last() -> Self { + Self::nulls_last(Direction::Ascending) + } + + /// Create a `DESC NULLS FIRST` order type. + pub fn descending_nulls_first() -> Self { + Self::nulls_first(Direction::Descending) + } + + /// Create a `DESC NULLS LAST` order type. + pub fn descending_nulls_last() -> Self { + Self::nulls_last(Direction::Descending) + } + + pub fn is_ascending(&self) -> bool { + self.direction == Direction::Ascending + } + + pub fn is_descending(&self) -> bool { + self.direction == Direction::Descending + } + + pub fn nulls_are_largest(&self) -> bool { + self.nulls_are == NullsAre::Largest + } + + pub fn nulls_are_smallest(&self) -> bool { + self.nulls_are == NullsAre::Smallest + } + + pub fn nulls_are_first(&self) -> bool { + self.is_ascending() && self.nulls_are_smallest() + || self.is_descending() && self.nulls_are_largest() + } + + pub fn nulls_are_last(&self) -> bool { + !self.nulls_are_first() } } impl fmt::Display for OrderType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.direction) + write!(f, "{}", self.direction)?; + if self.nulls_are != NullsAre::default() { + write!( + f, + " NULLS {}", + if self.nulls_are_first() { + "FIRST" + } else { + "LAST" + } + )?; + } + Ok(()) } } @@ -232,18 +336,19 @@ impl PartialEq for HeapElem { impl Eq for HeapElem {} -fn compare_values(lhs: Option<&T>, rhs: Option<&T>, order_type: &OrderType) -> Ordering +fn compare_values(lhs: Option<&T>, rhs: Option<&T>, order_type: OrderType) -> Ordering where T: Ord, { - let ord = match (lhs, rhs) { - (Some(l), Some(r)) => l.cmp(r), - (None, None) => Ordering::Equal, - // TODO(yuchao): `null first` / `null last` is not supported yet. - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, + let ord = match (lhs, rhs, order_type.nulls_are) { + (Some(l), Some(r), _) => l.cmp(r), + (None, None, _) => Ordering::Equal, + (Some(_), None, NullsAre::Largest) => Ordering::Less, + (Some(_), None, NullsAre::Smallest) => Ordering::Greater, + (None, Some(_), NullsAre::Largest) => Ordering::Greater, + (None, Some(_), NullsAre::Smallest) => Ordering::Less, }; - if order_type.direction == Direction::Descending { + if order_type.is_descending() { ord.reverse() } else { ord @@ -255,7 +360,7 @@ fn compare_values_in_array<'a, T>( lhs_idx: usize, rhs_array: &'a T, rhs_idx: usize, - order_type: &'a OrderType, + order_type: OrderType, ) -> Ordering where T: Array, @@ -281,7 +386,7 @@ pub fn compare_rows_in_chunk( macro_rules! gen_match { ( $( { $variant_name:ident, $suffix_name:ident, $array:ty, $builder:ty } ),*) => { match (lhs_array.as_ref(), rhs_array.as_ref()) { - $((ArrayImpl::$variant_name(lhs_inner), ArrayImpl::$variant_name(rhs_inner)) => Ok(compare_values_in_array(lhs_inner, lhs_idx, rhs_inner, rhs_idx, &column_order.order_type)),)* + $((ArrayImpl::$variant_name(lhs_inner), ArrayImpl::$variant_name(rhs_inner)) => Ok(compare_values_in_array(lhs_inner, lhs_idx, rhs_inner, rhs_idx, column_order.order_type)),)* (l_arr, r_arr) => Err(InternalError(format!("Unmatched array types, lhs array is: {}, rhs array is: {}", l_arr.get_ident(), r_arr.get_ident()))), }? } @@ -294,17 +399,67 @@ pub fn compare_rows_in_chunk( Ok(Ordering::Equal) } +/// Compare two `Datum`s with specified order type. +pub fn compare_datum( + lhs: impl ToDatumRef, + rhs: impl ToDatumRef, + order_type: OrderType, +) -> Ordering { + compare_values( + lhs.to_datum_ref().as_ref(), + rhs.to_datum_ref().as_ref(), + order_type, + ) +} + #[cfg(test)] mod tests { use std::cmp::Ordering; use itertools::Itertools; - use super::{ColumnOrder, OrderType}; + use super::*; use crate::array::{DataChunk, ListValue, StructValue}; use crate::row::{OwnedRow, Row}; - use crate::types::{DataType, ScalarImpl}; - use crate::util::sort_util::compare_rows_in_chunk; + use crate::types::{DataType, Datum, ScalarImpl}; + + #[test] + fn test_order_type() { + assert_eq!(OrderType::default(), OrderType::ascending()); + assert_eq!( + OrderType::default(), + OrderType::new(Direction::Ascending, NullsAre::Largest) + ); + assert_eq!( + OrderType::default(), + OrderType::from_bools(Some(true), Some(false)) + ); + assert_eq!(OrderType::default(), OrderType::from_bools(None, None)); + + assert!(OrderType::ascending().is_ascending()); + assert!(OrderType::ascending().nulls_are_largest()); + assert!(OrderType::ascending().nulls_are_last()); + + assert!(OrderType::descending().is_descending()); + assert!(OrderType::descending().nulls_are_largest()); + assert!(OrderType::descending().nulls_are_first()); + + assert!(OrderType::ascending_nulls_first().is_ascending()); + assert!(OrderType::ascending_nulls_first().nulls_are_smallest()); + assert!(OrderType::ascending_nulls_first().nulls_are_first()); + + assert!(OrderType::ascending_nulls_last().is_ascending()); + assert!(OrderType::ascending_nulls_last().nulls_are_largest()); + assert!(OrderType::ascending_nulls_last().nulls_are_last()); + + assert!(OrderType::descending_nulls_first().is_descending()); + assert!(OrderType::descending_nulls_first().nulls_are_largest()); + assert!(OrderType::descending_nulls_first().nulls_are_first()); + + assert!(OrderType::descending_nulls_last().is_descending()); + assert!(OrderType::descending_nulls_last().nulls_are_smallest()); + assert!(OrderType::descending_nulls_last().nulls_are_last()); + } #[test] fn test_compare_rows_in_chunk() { @@ -417,4 +572,60 @@ mod tests { compare_rows_in_chunk(&chunk, 0, &chunk, 1, &column_orders).unwrap() ); } + + #[test] + fn test_compare_datum() { + assert_eq!( + Ordering::Equal, + compare_datum( + Some(ScalarImpl::from(42)), + Some(ScalarImpl::from(42)), + OrderType::default(), + ) + ); + assert_eq!( + Ordering::Equal, + compare_datum(None as Datum, None as Datum, OrderType::default(),) + ); + assert_eq!( + Ordering::Less, + compare_datum( + Some(ScalarImpl::from(42)), + Some(ScalarImpl::from(100)), + OrderType::ascending(), + ) + ); + assert_eq!( + Ordering::Greater, + compare_datum( + Some(ScalarImpl::from(42)), + None as Datum, + OrderType::ascending_nulls_first(), + ) + ); + assert_eq!( + Ordering::Less, + compare_datum( + Some(ScalarImpl::from(42)), + None as Datum, + OrderType::ascending_nulls_last(), + ) + ); + assert_eq!( + Ordering::Greater, + compare_datum( + Some(ScalarImpl::from(42)), + None as Datum, + OrderType::descending_nulls_first(), + ) + ); + assert_eq!( + Ordering::Less, + compare_datum( + Some(ScalarImpl::from(42)), + None as Datum, + OrderType::descending_nulls_last(), + ) + ); + } } diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index aff68f1e930c..75b5eee93e55 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -77,8 +77,6 @@ impl AggStateFactory { .map(|col_order| { let col_idx = col_order.get_column_index() as usize; let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap()); - // TODO(yuchao): `nulls first/last` is not supported yet, so it's ignore here, - // see also `risingwave_common::util::sort_util::compare_values` ColumnOrder::new(col_idx, order_type) }) .collect(); diff --git a/src/expr/src/vector_op/agg/array_agg.rs b/src/expr/src/vector_op/agg/array_agg.rs index f810149b74ab..b3d475b0602c 100644 --- a/src/expr/src/vector_op/agg/array_agg.rs +++ b/src/expr/src/vector_op/agg/array_agg.rs @@ -14,10 +14,10 @@ use risingwave_common::array::{ArrayBuilder, ArrayBuilderImpl, DataChunk, ListValue, RowRef}; use risingwave_common::bail; -use risingwave_common::row::{Row, RowExt}; +use risingwave_common::row::Row; use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum}; -use risingwave_common::util::ordered::OrderedRow; -use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; +use risingwave_common::util::memcmp_encoding; +use risingwave_common::util::sort_util::ColumnOrder; use crate::vector_op::agg::aggregator::Aggregator; use crate::Result; @@ -91,43 +91,36 @@ impl Aggregator for ArrayAggUnordered { } } +type OrderKey = Vec; + #[derive(Clone)] struct ArrayAggOrdered { return_type: DataType, agg_col_idx: usize, - order_col_indices: Vec, - order_types: Vec, - unordered_values: Vec<(OrderedRow, Datum)>, + column_orders: Vec, + unordered_values: Vec<(OrderKey, Datum)>, } impl ArrayAggOrdered { fn new(return_type: DataType, agg_col_idx: usize, column_orders: Vec) -> Self { - debug_assert!(matches!(return_type, DataType::List { datatype: _ })); - let (order_col_indices, order_types) = column_orders - .into_iter() - .map(|p| (p.column_index, p.order_type)) - .unzip(); + assert!(matches!(return_type, DataType::List { datatype: _ })); ArrayAggOrdered { return_type, agg_col_idx, - order_col_indices, - order_types, + column_orders, unordered_values: vec![], } } fn push_row(&mut self, row: RowRef<'_>) { - let key = OrderedRow::new( - row.project(&self.order_col_indices).into_owned_row(), - &self.order_types, - ); + let key = memcmp_encoding::encode_row(row, &self.column_orders); let datum = row.datum_at(self.agg_col_idx).to_owned_datum(); self.unordered_values.push((key, datum)); } fn get_result_and_reset(&mut self) -> ListValue { let mut rows = std::mem::take(&mut self.unordered_values); - rows.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + rows.sort_unstable_by(|(key_a, _), (key_b, _)| key_a.cmp(key_b)); ListValue::new(rows.into_iter().map(|(_, datum)| datum).collect()) } } @@ -190,6 +183,7 @@ mod tests { use risingwave_common::array::Array; use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::ScalarRef; + use risingwave_common::util::sort_util::OrderType; use super::*; diff --git a/src/expr/src/vector_op/agg/string_agg.rs b/src/expr/src/vector_op/agg/string_agg.rs index bf829f34dfea..32ef9330d1d7 100644 --- a/src/expr/src/vector_op/agg/string_agg.rs +++ b/src/expr/src/vector_op/agg/string_agg.rs @@ -16,11 +16,10 @@ use risingwave_common::array::{ Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, DataChunk, RowRef, }; use risingwave_common::bail; -use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_common::util::ordered::OrderedRow; -use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; +use risingwave_common::util::memcmp_encoding; +use risingwave_common::util::sort_util::ColumnOrder; use crate::vector_op::agg::aggregator::Aggregator; use crate::Result; @@ -113,6 +112,8 @@ impl Aggregator for StringAggUnordered { } } +type OrderKey = Vec; + #[derive(Clone)] struct StringAggData { value: String, @@ -123,31 +124,22 @@ struct StringAggData { struct StringAggOrdered { agg_col_idx: usize, delim_col_idx: usize, - order_col_indices: Vec, - order_types: Vec, - unordered_values: Vec<(OrderedRow, StringAggData)>, + column_orders: Vec, + unordered_values: Vec<(OrderKey, StringAggData)>, } impl StringAggOrdered { fn new(agg_col_idx: usize, delim_col_idx: usize, column_orders: Vec) -> Self { - let (order_col_indices, order_types) = column_orders - .into_iter() - .map(|p| (p.column_index, p.order_type)) - .unzip(); Self { agg_col_idx, delim_col_idx, - order_col_indices, - order_types, + column_orders, unordered_values: vec![], } } fn push_row(&mut self, value: &str, delim: &str, row: RowRef<'_>) { - let key = OrderedRow::new( - row.project(&self.order_col_indices).into_owned_row(), - &self.order_types, - ); + let key = memcmp_encoding::encode_row(row, &self.column_orders); self.unordered_values.push(( key, StringAggData { @@ -162,7 +154,7 @@ impl StringAggOrdered { if rows.is_empty() { return None; } - rows.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + rows.sort_unstable_by(|(key_a, _), (key_b, _)| key_a.cmp(key_b)); let mut rows_iter = rows.into_iter(); let mut result = rows_iter.next().unwrap().1.value; for (_, data) in rows_iter { diff --git a/src/frontend/src/binder/expr/order_by.rs b/src/frontend/src/binder/expr/order_by.rs index cdf062cae76c..c6d494ea0b15 100644 --- a/src/frontend/src/binder/expr/order_by.rs +++ b/src/frontend/src/binder/expr/order_by.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::error::Result; use risingwave_common::util::sort_util::OrderType; use risingwave_sqlparser::ast::OrderByExpr; @@ -34,19 +34,7 @@ impl Binder { nulls_first, }: OrderByExpr, ) -> Result { - // TODO(rc): support `NULLS FIRST | LAST` - if nulls_first.is_some() { - return Err(ErrorCode::NotImplemented( - "NULLS FIRST or NULLS LAST".to_string(), - 4743.into(), - ) - .into()); - } - let order_type = match asc { - None => OrderType::default(), - Some(true) => OrderType::ascending(), - Some(false) => OrderType::descending(), - }; + let order_type = OrderType::from_bools(asc, nulls_first); let expr = self.bind_expr(expr)?; Ok(BoundOrderByExpr { expr, order_type }) } diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 962f37cce54d..c758440bcc5b 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -109,7 +109,7 @@ pub struct Binder { /// 4. After bind finished: /// (a) parameter not in `ParameterTypes` means that the user didn't specify it and it didn't /// occur in the query. `export` will return error if there is a kind of -/// parameter. This rule is compatible with PostgreSQL +/// parameter. This rule is compatible with PostgreSQL /// (b) parameter is None means that it's a unknown type. The user didn't specify it /// and we can't infer it in the query. We will treat it as VARCHAR type finally. This rule is /// compatible with PostgreSQL. diff --git a/src/frontend/src/binder/query.rs b/src/frontend/src/binder/query.rs index 74b5c591fcf3..3cac14efd1cd 100644 --- a/src/frontend/src/binder/query.rs +++ b/src/frontend/src/binder/query.rs @@ -221,19 +221,7 @@ impl Binder { extra_order_exprs: &mut Vec, visible_output_num: usize, ) -> Result { - // TODO(rc): support `NULLS FIRST | LAST` - if nulls_first.is_some() { - return Err(ErrorCode::NotImplemented( - "NULLS FIRST or NULLS LAST".to_string(), - 4743.into(), - ) - .into()); - } - let order_type = match asc { - None => OrderType::default(), - Some(true) => OrderType::ascending(), - Some(false) => OrderType::descending(), - }; + let order_type = OrderType::from_bools(asc, nulls_first); let column_index = match expr { Expr::Identifier(name) if let Some(index) = name_to_index.get(&name.real_value()) => match *index != usize::MAX { true => *index, diff --git a/src/frontend/src/handler/create_index.rs b/src/frontend/src/handler/create_index.rs index df9af08fb413..0025eb234d3d 100644 --- a/src/frontend/src/handler/create_index.rs +++ b/src/frontend/src/handler/create_index.rs @@ -337,28 +337,12 @@ fn check_columns(columns: Vec) -> Result> { columns .into_iter() .map(|column| { - // TODO(rc): support `NULLS FIRST | LAST` - if column.nulls_first.is_some() { - return Err(ErrorCode::NotImplemented( - "nulls_first not supported".into(), - None.into(), - ) - .into()); - } + let order_type = OrderType::from_bools(column.asc, column.nulls_first); use risingwave_sqlparser::ast::Expr; if let Expr::Identifier(ident) = column.expr { - Ok::<(_, _), RwError>(( - ident, - column.asc.map_or(OrderType::ascending(), |x| { - if x { - OrderType::ascending() - } else { - OrderType::descending() - } - }), - )) + Ok::<(_, _), RwError>((ident, order_type)) } else { Err(ErrorCode::NotImplemented( "only identifier is supported for create index".into(), diff --git a/src/frontend/src/handler/describe.rs b/src/frontend/src/handler/describe.rs index b14159e3a816..c7195cb0f391 100644 --- a/src/frontend/src/handler/describe.rs +++ b/src/frontend/src/handler/describe.rs @@ -22,7 +22,6 @@ use pgwire::types::Row; use risingwave_common::catalog::ColumnDesc; use risingwave_common::error::Result; use risingwave_common::types::DataType; -use risingwave_common::util::sort_util::Direction; use risingwave_sqlparser::ast::{display_comma_separated, ObjectName}; use super::RwPgResponse; @@ -120,11 +119,7 @@ pub fn handle_describe(handler_args: HandlerArgs, table_name: ObjectName) -> Res .filter(|x| !index_table.columns[x.column_index].is_hidden) .map(|x| { let index_column_name = index_table.columns[x.column_index].name().to_string(); - if Direction::Descending == x.order_type.direction() { - index_column_name + " DESC" - } else { - index_column_name - } + format!("{} {}", index_column_name, x.order_type) }) .collect_vec(); @@ -241,7 +236,7 @@ mod tests { "v3".into() => "Int32".into(), "v4".into() => "Int32".into(), "primary key".into() => "v3".into(), - "idx1".into() => "index(v1 DESC, v2, v3) include(v4) distributed by(v1, v2)".into(), + "idx1".into() => "index(v1 DESC, v2 ASC, v3 ASC) include(v4) distributed by(v1, v2)".into(), }; assert_eq!(columns, expected_columns); diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 03ac112e9178..be36bd7b288d 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -333,17 +333,12 @@ impl LogicalAgg { .group_key() .iter() .map(|group_by_idx| { - let order_type = if required_order + required_order .column_orders - .contains(&ColumnOrder::new(*group_by_idx, OrderType::descending())) - { - // If output requires descending order, use descending order - OrderType::descending() - } else { - // In all other cases use ascending order - OrderType::ascending() - }; - ColumnOrder::new(*group_by_idx, order_type) + .iter() + .find(|o| o.column_index == *group_by_idx) + .cloned() + .unwrap_or_else(|| ColumnOrder::new(*group_by_idx, OrderType::ascending())) }) .collect(), }; diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index 91ea7647fcd7..616f6c38d7a4 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -19,7 +19,9 @@ use std::vec; use itertools::Itertools; use risingwave_common::catalog::{DatabaseId, SchemaId, TableId}; use risingwave_pb::catalog::PbTable; -use risingwave_pb::common::{ParallelUnit, PbColumnOrder, PbDirection, PbOrderType, WorkerNode}; +use risingwave_pb::common::{ + ParallelUnit, PbColumnOrder, PbDirection, PbNullsAre, PbOrderType, WorkerNode, +}; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType; use risingwave_pb::expr::agg_call::Type; @@ -96,6 +98,7 @@ fn make_column_order(column_index: u32) -> PbColumnOrder { column_index, order_type: Some(PbOrderType { direction: PbDirection::Ascending as _, + nulls_are: PbNullsAre::Largest as _, }), } } @@ -129,6 +132,7 @@ fn make_source_internal_table(id: u32) -> PbTable { column_index: 0, order_type: Some(PbOrderType { direction: PbDirection::Descending as _, + nulls_are: PbNullsAre::Largest as _, }), }], ..Default::default() @@ -150,6 +154,7 @@ fn make_internal_table(id: u32, is_agg_value: bool) -> PbTable { column_index: 0, order_type: Some(PbOrderType { direction: PbDirection::Descending as _, + nulls_are: PbNullsAre::Largest as _, }), }], stream_key: vec![2], diff --git a/src/storage/hummock_sdk/src/filter_key_extractor.rs b/src/storage/hummock_sdk/src/filter_key_extractor.rs index a576de5b647e..dd9537f40b8e 100644 --- a/src/storage/hummock_sdk/src/filter_key_extractor.rs +++ b/src/storage/hummock_sdk/src/filter_key_extractor.rs @@ -348,7 +348,7 @@ mod tests { use risingwave_common::util::sort_util::OrderType; use risingwave_pb::catalog::table::TableType; use risingwave_pb::catalog::PbTable; - use risingwave_pb::common::{PbColumnOrder, PbDirection, PbOrderType}; + use risingwave_pb::common::{PbColumnOrder, PbDirection, PbNullsAre, PbOrderType}; use risingwave_pb::plan_common::PbColumnCatalog; use tokio::task; @@ -443,12 +443,14 @@ mod tests { column_index: 1, order_type: Some(PbOrderType { direction: PbDirection::Ascending as _, + nulls_are: PbNullsAre::Largest as _, }), }, PbColumnOrder { column_index: 3, order_type: Some(PbOrderType { direction: PbDirection::Ascending as _, + nulls_are: PbNullsAre::Largest as _, }), }, ], diff --git a/src/stream/src/executor/backfill.rs b/src/stream/src/executor/backfill.rs index 1bbf6a5d2bb3..a8677a85719d 100644 --- a/src/stream/src/executor/backfill.rs +++ b/src/stream/src/executor/backfill.rs @@ -25,7 +25,7 @@ use risingwave_common::buffer::BitmapBuilder; use risingwave_common::catalog::Schema; use risingwave_common::row::{self, OwnedRow, Row, RowExt}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_common::util::sort_util::{Direction, OrderType}; +use risingwave_common::util::sort_util::{compare_datum, OrderType}; use risingwave_hummock_sdk::HummockReadEpoch; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::table::batch_table::storage_table::StorageTable; @@ -373,12 +373,9 @@ where match row .project(pk_in_output_indices) .iter() - .zip_eq_fast(pk_order.iter()) + .zip_eq_fast(pk_order.iter().copied()) .cmp_by(current_pos.iter(), |(x, order), y| { - match order.direction() { - Direction::Ascending => x.cmp(&y), - Direction::Descending => y.cmp(&x), - } + compare_datum(x, y, order) }) { Ordering::Less | Ordering::Equal => true, Ordering::Greater => false, diff --git a/src/stream/src/executor/source/state_table_handler.rs b/src/stream/src/executor/source/state_table_handler.rs index 0b7a15bcb632..f578db63221e 100644 --- a/src/stream/src/executor/source/state_table_handler.rs +++ b/src/stream/src/executor/source/state_table_handler.rs @@ -28,7 +28,7 @@ use risingwave_connector::source::{SplitId, SplitImpl, SplitMetaData}; use risingwave_hummock_sdk::key::next_key; use risingwave_pb::catalog::table::TableType; use risingwave_pb::catalog::PbTable; -use risingwave_pb::common::{PbColumnOrder, PbDirection, PbOrderType}; +use risingwave_pb::common::{PbColumnOrder, PbDirection, PbNullsAre, PbOrderType}; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType; use risingwave_pb::plan_common::{ColumnCatalog, ColumnDesc}; @@ -230,6 +230,7 @@ pub fn default_source_internal_table(id: u32) -> PbTable { column_index: 0, order_type: Some(PbOrderType { direction: PbDirection::Ascending as _, + nulls_are: PbNullsAre::Largest as _, }), }], ..Default::default() diff --git a/src/stream/src/from_proto/agg_common.rs b/src/stream/src/from_proto/agg_common.rs index 29f4001837b6..9a527df24513 100644 --- a/src/stream/src/from_proto/agg_common.rs +++ b/src/stream/src/from_proto/agg_common.rs @@ -52,8 +52,6 @@ pub fn build_agg_call_from_prost( .map(|col_order| { let col_idx = col_order.get_column_index() as usize; let order_type = OrderType::from_protobuf(col_order.get_order_type().unwrap()); - // TODO(yuchao): `nulls first/last` is not supported yet, so it's ignore here, - // see also `risingwave_common::util::sort_util::compare_values` ColumnOrder::new(col_idx, order_type) }) .collect();