From 65550e6939b895aed7cf7afced83548bf7fcc297 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Mon, 23 Sep 2024 17:49:13 +0800 Subject: [PATCH] introduce proto `AggType` message Signed-off-by: Richard Chien --- proto/expr.proto | 17 +++++++- src/expr/core/src/aggregate/def.rs | 39 +++++++++++++++++-- src/expr/core/src/window_function/kind.rs | 7 +++- src/frontend/src/binder/expr/function/mod.rs | 4 +- .../plan_node/generic/over_window.rs | 2 +- 5 files changed, 59 insertions(+), 10 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index c5c66733d0d9..4cbe7d2abb1c 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -471,6 +471,20 @@ message AggCall { ExprNode scalar = 9; } +// The aggregation type. +// +// Ideally this should be used to encode the Rust `AggCall::agg_type` field, but historically we +// flattened it into multiple fields in proto `AggCall` - `kind` + `udf` + `scalar`. So this +// `AggType` proto type is only used by `WindowFunction` currently. +message AggType { + AggCall.Kind kind = 1; + + // UDF metadata. Only present when the kind is `USER_DEFINED`. + optional UserDefinedFunctionMetadata udf_meta = 8; + // Wrapped scalar expression. Only present when the kind is `WRAP_SCALAR`. + optional ExprNode scalar_expr = 9; +} + message WindowFrame { enum Type { TYPE_UNSPECIFIED = 0; @@ -562,7 +576,8 @@ message WindowFunction { oneof type { GeneralType general = 1; - AggCall.Kind aggregate = 2; + AggCall.Kind aggregate_simple = 2 [deprecated = true]; // Deprecated since we have a new `aggregate` variant. + AggType aggregate = 103; } repeated InputRef args = 3; data.DataType return_type = 4; diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index 3abe80dcd4d3..8554affb23af 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_common::util::value_encoding::DatumFromProtoExt; pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind; -use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata}; +use risingwave_pb::expr::{ + PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata, +}; use crate::expr::{ build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token, @@ -65,7 +67,7 @@ pub struct AggCall { impl AggCall { pub fn from_protobuf(agg_call: &PbAggCall) -> Result { - let agg_type = AggType::from_protobuf( + let agg_type = AggType::from_protobuf_flatten( agg_call.get_kind()?, agg_call.udf.as_ref(), agg_call.scalar.as_ref(), @@ -160,7 +162,7 @@ impl> Parser { self.tokens.next(); // Consume the RParen AggCall { - agg_type: AggType::from_protobuf(func, None, None).unwrap(), + agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(), args: AggArgs { data_types: children.iter().map(|(_, ty)| ty.clone()).collect(), val_indices: children.iter().map(|(idx, _)| *idx).collect(), @@ -260,7 +262,7 @@ impl From for AggType { } impl AggType { - pub fn from_protobuf( + pub fn from_protobuf_flatten( pb_kind: PbAggKind, user_defined: Option<&PbUserDefinedFunctionMetadata>, scalar: Option<&PbExprNode>, @@ -286,6 +288,35 @@ impl AggType { Self::WrapScalar(_) => PbAggKind::WrapScalar, } } + + pub fn from_protobuf(pb_type: &PbAggType) -> Result { + match pb_type.kind() { + PbAggKind::Unspecified => bail!("Unrecognized agg."), + PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())), + PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())), + kind => Ok(AggType::Builtin(kind)), + } + } + + pub fn to_protobuf(&self) -> PbAggType { + match self { + Self::Builtin(kind) => PbAggType { + kind: *kind as _, + udf_meta: None, + scalar_expr: None, + }, + Self::UserDefined(udf_meta) => PbAggType { + kind: PbAggKind::UserDefined as _, + udf_meta: Some(udf_meta.clone()), + scalar_expr: None, + }, + Self::WrapScalar(scalar_expr) => PbAggType { + kind: PbAggKind::WrapScalar as _, + udf_meta: None, + scalar_expr: Some(scalar_expr.clone()), + }, + } + } } /// Macros to generate match arms for `AggType`. diff --git a/src/expr/core/src/window_function/kind.rs b/src/expr/core/src/window_function/kind.rs index 32c5f746020d..4d2f1ef00b8f 100644 --- a/src/expr/core/src/window_function/kind.rs +++ b/src/expr/core/src/window_function/kind.rs @@ -51,11 +51,14 @@ impl WindowFuncKind { Ok(PbGeneralType::Lead) => Self::Lead, Err(_) => bail!("no such window function type"), }, - PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) { + PbType::AggregateSimple(agg_type) => match PbAggKind::try_from(*agg_type) { // TODO(runji): support UDAF and wrapped scalar functions - Ok(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type, None, None)?), + Ok(agg_type) => { + Self::Aggregate(AggType::from_protobuf_flatten(agg_type, None, None)?) + } Err(_) => bail!("no such aggregate function type"), }, + PbType::Aggregate(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type)?), }; Ok(kind) } diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index 5d3dfb79300d..00f2438cb35a 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -227,8 +227,8 @@ impl Binder { None }; - let agg_type = if let Some(wrapped_agg_type) = wrapped_agg_type { - Some(wrapped_agg_type) + let agg_type = if wrapped_agg_type.is_some() { + wrapped_agg_type } else if let Some(ref udf) = udf && udf.kind.is_aggregate() { diff --git a/src/frontend/src/optimizer/plan_node/generic/over_window.rs b/src/frontend/src/optimizer/plan_node/generic/over_window.rs index 5622d1e8952c..100586506d9b 100644 --- a/src/frontend/src/optimizer/plan_node/generic/over_window.rs +++ b/src/frontend/src/optimizer/plan_node/generic/over_window.rs @@ -121,7 +121,7 @@ impl PlanWindowFunction { DenseRank => PbType::General(PbGeneralType::DenseRank as _), Lag => PbType::General(PbGeneralType::Lag as _), Lead => PbType::General(PbGeneralType::Lead as _), - Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf_simple() as _), + Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf()), }; PbWindowFunction {