diff --git a/dashboard/proto/gen/catalog.ts b/dashboard/proto/gen/catalog.ts index 0590f0a959ce..1f6feadf4579 100644 --- a/dashboard/proto/gen/catalog.ts +++ b/dashboard/proto/gen/catalog.ts @@ -204,10 +204,24 @@ export interface Function { name: string; owner: number; argTypes: DataType[]; - returnType: DataType | undefined; language: string; link: string; identifier: string; + kind?: { $case: "scalar"; scalar: Function_ScalarFunction } | { $case: "table"; table: Function_TableFunction } | { + $case: "aggregate"; + aggregate: Function_AggregateFunction; + }; +} + +export interface Function_ScalarFunction { + returnType: DataType | undefined; +} + +export interface Function_TableFunction { + returnTypes: DataType[]; +} + +export interface Function_AggregateFunction { } /** See `TableCatalog` struct in frontend crate for more information. */ @@ -810,10 +824,10 @@ function createBaseFunction(): Function { name: "", owner: 0, argTypes: [], - returnType: undefined, language: "", link: "", identifier: "", + kind: undefined, }; } @@ -828,10 +842,16 @@ export const Function = { argTypes: Array.isArray(object?.argTypes) ? object.argTypes.map((e: any) => DataType.fromJSON(e)) : [], - returnType: isSet(object.returnType) ? DataType.fromJSON(object.returnType) : undefined, language: isSet(object.language) ? String(object.language) : "", link: isSet(object.link) ? String(object.link) : "", identifier: isSet(object.identifier) ? String(object.identifier) : "", + kind: isSet(object.scalar) + ? { $case: "scalar", scalar: Function_ScalarFunction.fromJSON(object.scalar) } + : isSet(object.table) + ? { $case: "table", table: Function_TableFunction.fromJSON(object.table) } + : isSet(object.aggregate) + ? { $case: "aggregate", aggregate: Function_AggregateFunction.fromJSON(object.aggregate) } + : undefined, }; }, @@ -847,11 +867,17 @@ export const Function = { } else { obj.argTypes = []; } - message.returnType !== undefined && - (obj.returnType = message.returnType ? DataType.toJSON(message.returnType) : undefined); message.language !== undefined && (obj.language = message.language); message.link !== undefined && (obj.link = message.link); message.identifier !== undefined && (obj.identifier = message.identifier); + message.kind?.$case === "scalar" && + (obj.scalar = message.kind?.scalar ? Function_ScalarFunction.toJSON(message.kind?.scalar) : undefined); + message.kind?.$case === "table" && + (obj.table = message.kind?.table ? Function_TableFunction.toJSON(message.kind?.table) : undefined); + message.kind?.$case === "aggregate" && + (obj.aggregate = message.kind?.aggregate + ? Function_AggregateFunction.toJSON(message.kind?.aggregate) + : undefined); return obj; }, @@ -863,12 +889,91 @@ export const Function = { message.name = object.name ?? ""; message.owner = object.owner ?? 0; message.argTypes = object.argTypes?.map((e) => DataType.fromPartial(e)) || []; - message.returnType = (object.returnType !== undefined && object.returnType !== null) - ? DataType.fromPartial(object.returnType) - : undefined; message.language = object.language ?? ""; message.link = object.link ?? ""; message.identifier = object.identifier ?? ""; + if (object.kind?.$case === "scalar" && object.kind?.scalar !== undefined && object.kind?.scalar !== null) { + message.kind = { $case: "scalar", scalar: Function_ScalarFunction.fromPartial(object.kind.scalar) }; + } + if (object.kind?.$case === "table" && object.kind?.table !== undefined && object.kind?.table !== null) { + message.kind = { $case: "table", table: Function_TableFunction.fromPartial(object.kind.table) }; + } + if (object.kind?.$case === "aggregate" && object.kind?.aggregate !== undefined && object.kind?.aggregate !== null) { + message.kind = { $case: "aggregate", aggregate: Function_AggregateFunction.fromPartial(object.kind.aggregate) }; + } + return message; + }, +}; + +function createBaseFunction_ScalarFunction(): Function_ScalarFunction { + return { returnType: undefined }; +} + +export const Function_ScalarFunction = { + fromJSON(object: any): Function_ScalarFunction { + return { returnType: isSet(object.returnType) ? DataType.fromJSON(object.returnType) : undefined }; + }, + + toJSON(message: Function_ScalarFunction): unknown { + const obj: any = {}; + message.returnType !== undefined && + (obj.returnType = message.returnType ? DataType.toJSON(message.returnType) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): Function_ScalarFunction { + const message = createBaseFunction_ScalarFunction(); + message.returnType = (object.returnType !== undefined && object.returnType !== null) + ? DataType.fromPartial(object.returnType) + : undefined; + return message; + }, +}; + +function createBaseFunction_TableFunction(): Function_TableFunction { + return { returnTypes: [] }; +} + +export const Function_TableFunction = { + fromJSON(object: any): Function_TableFunction { + return { + returnTypes: Array.isArray(object?.returnTypes) ? object.returnTypes.map((e: any) => DataType.fromJSON(e)) : [], + }; + }, + + toJSON(message: Function_TableFunction): unknown { + const obj: any = {}; + if (message.returnTypes) { + obj.returnTypes = message.returnTypes.map((e) => e ? DataType.toJSON(e) : undefined); + } else { + obj.returnTypes = []; + } + return obj; + }, + + fromPartial, I>>(object: I): Function_TableFunction { + const message = createBaseFunction_TableFunction(); + message.returnTypes = object.returnTypes?.map((e) => DataType.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseFunction_AggregateFunction(): Function_AggregateFunction { + return {}; +} + +export const Function_AggregateFunction = { + fromJSON(_: any): Function_AggregateFunction { + return {}; + }, + + toJSON(_: Function_AggregateFunction): unknown { + const obj: any = {}; + return obj; + }, + + fromPartial, I>>(_: I): Function_AggregateFunction { + const message = createBaseFunction_AggregateFunction(); return message; }, }; diff --git a/dashboard/proto/gen/expr.ts b/dashboard/proto/gen/expr.ts index 0bebf5e1d928..1b610347b05f 100644 --- a/dashboard/proto/gen/expr.ts +++ b/dashboard/proto/gen/expr.ts @@ -644,7 +644,9 @@ export function exprNode_TypeToJSON(object: ExprNode_Type): string { export interface TableFunction { functionType: TableFunction_Type; args: ExprNode[]; - returnType: DataType | undefined; + returnTypes: DataType[]; + /** optional. only used when the type is UDTF. */ + udtf: UserDefinedTableFunction | undefined; } export const TableFunction_Type = { @@ -653,6 +655,8 @@ export const TableFunction_Type = { UNNEST: "UNNEST", REGEXP_MATCHES: "REGEXP_MATCHES", RANGE: "RANGE", + /** UDTF - User defined table function */ + UDTF: "UDTF", UNRECOGNIZED: "UNRECOGNIZED", } as const; @@ -675,6 +679,9 @@ export function tableFunction_TypeFromJSON(object: any): TableFunction_Type { case 4: case "RANGE": return TableFunction_Type.RANGE; + case 100: + case "UDTF": + return TableFunction_Type.UDTF; case -1: case "UNRECOGNIZED": default: @@ -694,6 +701,8 @@ export function tableFunction_TypeToJSON(object: TableFunction_Type): string { return "REGEXP_MATCHES"; case TableFunction_Type.RANGE: return "RANGE"; + case TableFunction_Type.UDTF: + return "UDTF"; case TableFunction_Type.UNRECOGNIZED: default: return "UNRECOGNIZED"; @@ -874,6 +883,13 @@ export interface UserDefinedFunction { identifier: string; } +export interface UserDefinedTableFunction { + argTypes: DataType[]; + language: string; + link: string; + identifier: string; +} + function createBaseExprNode(): ExprNode { return { exprType: ExprNode_Type.UNSPECIFIED, returnType: undefined, rexNode: undefined }; } @@ -945,7 +961,7 @@ export const ExprNode = { }; function createBaseTableFunction(): TableFunction { - return { functionType: TableFunction_Type.UNSPECIFIED, args: [], returnType: undefined }; + return { functionType: TableFunction_Type.UNSPECIFIED, args: [], returnTypes: [], udtf: undefined }; } export const TableFunction = { @@ -957,7 +973,8 @@ export const TableFunction = { args: Array.isArray(object?.args) ? object.args.map((e: any) => ExprNode.fromJSON(e)) : [], - returnType: isSet(object.returnType) ? DataType.fromJSON(object.returnType) : undefined, + returnTypes: Array.isArray(object?.returnTypes) ? object.returnTypes.map((e: any) => DataType.fromJSON(e)) : [], + udtf: isSet(object.udtf) ? UserDefinedTableFunction.fromJSON(object.udtf) : undefined, }; }, @@ -969,8 +986,12 @@ export const TableFunction = { } else { obj.args = []; } - message.returnType !== undefined && - (obj.returnType = message.returnType ? DataType.toJSON(message.returnType) : undefined); + if (message.returnTypes) { + obj.returnTypes = message.returnTypes.map((e) => e ? DataType.toJSON(e) : undefined); + } else { + obj.returnTypes = []; + } + message.udtf !== undefined && (obj.udtf = message.udtf ? UserDefinedTableFunction.toJSON(message.udtf) : undefined); return obj; }, @@ -978,8 +999,9 @@ export const TableFunction = { const message = createBaseTableFunction(); message.functionType = object.functionType ?? TableFunction_Type.UNSPECIFIED; message.args = object.args?.map((e) => ExprNode.fromPartial(e)) || []; - message.returnType = (object.returnType !== undefined && object.returnType !== null) - ? DataType.fromPartial(object.returnType) + message.returnTypes = object.returnTypes?.map((e) => DataType.fromPartial(e)) || []; + message.udtf = (object.udtf !== undefined && object.udtf !== null) + ? UserDefinedTableFunction.fromPartial(object.udtf) : undefined; return message; }, @@ -1190,6 +1212,43 @@ export const UserDefinedFunction = { }, }; +function createBaseUserDefinedTableFunction(): UserDefinedTableFunction { + return { argTypes: [], language: "", link: "", identifier: "" }; +} + +export const UserDefinedTableFunction = { + fromJSON(object: any): UserDefinedTableFunction { + return { + argTypes: Array.isArray(object?.argTypes) ? object.argTypes.map((e: any) => DataType.fromJSON(e)) : [], + language: isSet(object.language) ? String(object.language) : "", + link: isSet(object.link) ? String(object.link) : "", + identifier: isSet(object.identifier) ? String(object.identifier) : "", + }; + }, + + toJSON(message: UserDefinedTableFunction): unknown { + const obj: any = {}; + if (message.argTypes) { + obj.argTypes = message.argTypes.map((e) => e ? DataType.toJSON(e) : undefined); + } else { + obj.argTypes = []; + } + message.language !== undefined && (obj.language = message.language); + message.link !== undefined && (obj.link = message.link); + message.identifier !== undefined && (obj.identifier = message.identifier); + return obj; + }, + + fromPartial, I>>(object: I): UserDefinedTableFunction { + const message = createBaseUserDefinedTableFunction(); + message.argTypes = object.argTypes?.map((e) => DataType.fromPartial(e)) || []; + message.language = object.language ?? ""; + message.link = object.link ?? ""; + message.identifier = object.identifier ?? ""; + return message; + }, +}; + type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; export type DeepPartial = T extends Builtin ? T diff --git a/e2e_test/udf/python.slt b/e2e_test/udf/python.slt index 56683b6b6d9e..560116e2e05f 100644 --- a/e2e_test/udf/python.slt +++ b/e2e_test/udf/python.slt @@ -24,6 +24,13 @@ create function gcd(int, int, int) returns int language python as gcd3 using lin statement error exists create function gcd(int, int) returns int language python as gcd using link 'http://localhost:8815'; +# Create a table function. +statement ok +create function series(int) returns table (x int) language python as series using link 'http://localhost:8815'; + +statement ok +create function series2(int) returns table (x int, s varchar) language python as series2 using link 'http://localhost:8815'; + query I select int_42(); ---- @@ -39,6 +46,26 @@ select gcd(25, 15, 3); ---- 1 +query I +select series(5); +---- +0 +1 +2 +3 +4 + +# FIXME: support table function with multiple columns +# query IT +# select series2(5); +# ---- +# (0,0) +# (1,1) +# (2,2) +# (3,3) +# (4,4) + + # TODO: drop function without arguments # # Drop a function but ambiguous. diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index 4ef01fd9700d..eabc10ad9290 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -1,7 +1,8 @@ import sys +from typing import Iterator sys.path.append('src/udf/python') # noqa -from risingwave.udf import udf, UdfServer +from risingwave.udf import udf, udtf, UdfServer @udf(input_types=[], result_type='INT') @@ -21,9 +22,23 @@ def gcd3(x: int, y: int, z: int) -> int: return gcd(gcd(x, y), z) +@udtf(input_types='INT', result_types='INT') +def series(n: int) -> Iterator[int]: + for i in range(n): + yield i + + +@udtf(input_types=['INT'], result_types=['INT', 'VARCHAR']) +def series2(n: int) -> Iterator[tuple[int, str]]: + for i in range(n): + yield i, str(i) + + if __name__ == '__main__': server = UdfServer() server.add_function(int_42) server.add_function(gcd) server.add_function(gcd3) + server.add_function(series) + server.add_function(series2) server.serve() diff --git a/proto/catalog.proto b/proto/catalog.proto index fbd40076fd67..0da76e7a4a2e 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -109,6 +109,15 @@ message Function { string language = 7; string link = 8; string identifier = 10; + + oneof kind { + ScalarFunction scalar = 11; + TableFunction table = 12; + AggregateFunction aggregate = 13; + } + message ScalarFunction {} + message TableFunction {} + message AggregateFunction {} } // See `TableCatalog` struct in frontend crate for more information. diff --git a/proto/expr.proto b/proto/expr.proto index 4464023b8611..cc47f2242a14 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -159,10 +159,14 @@ message TableFunction { UNNEST = 2; REGEXP_MATCHES = 3; RANGE = 4; + // User defined table function + UDTF = 100; } Type function_type = 1; repeated expr.ExprNode args = 2; data.DataType return_type = 3; + // optional. only used when the type is UDTF. + UserDefinedTableFunction udtf = 4; } // Reference to an upstream column, containing its index and data type. @@ -243,3 +247,10 @@ message UserDefinedFunction { string link = 5; string identifier = 6; } + +message UserDefinedTableFunction { + repeated data.DataType arg_types = 3; + string language = 4; + string link = 5; + string identifier = 6; +} diff --git a/src/batch/src/task/task_manager.rs b/src/batch/src/task/task_manager.rs index 37ffd22ec0bb..b9c89c794bff 100644 --- a/src/batch/src/task/task_manager.rs +++ b/src/batch/src/task/task_manager.rs @@ -362,9 +362,8 @@ mod tests { make_i32_literal(i32::MAX), make_i32_literal(1), ], - // This is a bit hacky as we want to make sure the task lasts long enough - // for us to abort it. return_type: Some(DataType::Int32.to_protobuf()), + udtf: None, }), })), }), diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 06aee7955bb3..0d8c904d45eb 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -355,6 +355,13 @@ impl DataType { ) } + pub fn as_struct(&self) -> &StructType { + match self { + DataType::Struct(t) => t, + _ => panic!("expect struct type"), + } + } + /// WARNING: Currently this should only be used in `WatermarkFilterExecutor`. Please be careful /// if you want to use this. pub fn min(&self) -> ScalarImpl { diff --git a/src/expr/src/table_function/mod.rs b/src/expr/src/table_function/mod.rs index 77905964e64b..169bcdd74b61 100644 --- a/src/expr/src/table_function/mod.rs +++ b/src/expr/src/table_function/mod.rs @@ -27,11 +27,14 @@ use super::Result; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression}; mod generate_series; -use generate_series::*; -mod unnest; -use unnest::*; mod regexp_matches; -use regexp_matches::*; +mod unnest; +mod user_defined; + +use self::generate_series::*; +use self::regexp_matches::*; +use self::unnest::*; +use self::user_defined::*; /// Instance of a table function. /// @@ -63,6 +66,7 @@ pub fn build_from_prost( Unnest => new_unnest(prost, chunk_size), RegexpMatches => new_regexp_matches(prost, chunk_size), Range => new_generate_series::(prost, chunk_size), + Udtf => new_user_defined(prost, chunk_size), Unspecified => unreachable!(), } } diff --git a/src/expr/src/table_function/user_defined.rs b/src/expr/src/table_function/user_defined.rs new file mode 100644 index 000000000000..26b0c090f19b --- /dev/null +++ b/src/expr/src/table_function/user_defined.rs @@ -0,0 +1,112 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow_schema::{Field, Schema, SchemaRef}; +use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk}; +use risingwave_common::bail; +use risingwave_udf::ArrowFlightUdfClient; + +use super::*; + +#[derive(Debug)] +pub struct UserDefinedTableFunction { + children: Vec, + arg_schema: SchemaRef, + return_type: DataType, + client: ArrowFlightUdfClient, + identifier: String, + #[allow(dead_code)] + chunk_size: usize, +} + +#[cfg(not(madsim))] +impl TableFunction for UserDefinedTableFunction { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + fn eval(&self, input: &DataChunk) -> Result> { + let columns: Vec<_> = self + .children + .iter() + .map(|c| c.eval_checked(input).map(|a| a.as_ref().into())) + .try_collect()?; + let opts = + arrow_array::RecordBatchOptions::default().with_row_count(Some(input.cardinality())); + let input = + arrow_array::RecordBatch::try_new_with_options(self.arg_schema.clone(), columns, &opts) + .expect("failed to build record batch"); + let output = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(self.client.call(&self.identifier, input)) + })?; + // TODO: split by chunk_size + Ok(output + .columns() + .iter() + .map(|a| Arc::new(ArrayImpl::from(a))) + .collect()) + } +} + +#[cfg(not(madsim))] +pub fn new_user_defined( + prost: &TableFunctionProst, + chunk_size: usize, +) -> Result { + let Some(udtf) = &prost.udtf else { + bail!("expect UDTF"); + }; + + // connect to UDF service + let arg_schema = Arc::new(Schema::new( + udtf.arg_types + .iter() + .map(|t| Field::new("", DataType::from(t).into(), true)) + .collect(), + )); + let client = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(ArrowFlightUdfClient::connect(&udtf.link)) + })?; + + Ok(UserDefinedTableFunction { + children: prost.args.iter().map(expr_build_from_prost).try_collect()?, + return_type: prost.return_type.as_ref().expect("no return type").into(), + arg_schema, + client, + identifier: udtf.identifier.clone(), + chunk_size, + } + .boxed()) +} + +#[cfg(madsim)] +impl TableFunction for UserDefinedTableFunction { + fn return_type(&self) -> DataType { + panic!("UDF is not supported in simulation yet"); + } + + fn eval(&self, _input: &DataChunk) -> Result> { + panic!("UDF is not supported in simulation yet"); + } +} + +#[cfg(madsim)] +pub fn new_user_defined( + _prost: &TableFunctionProst, + _chunk_size: usize, +) -> Result { + panic!("UDF is not supported in simulation yet"); +} diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c3194bed36bd..95c7cda62e15 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -114,7 +114,15 @@ impl Binder { &inputs.iter().map(|arg| arg.return_type()).collect_vec(), ) { - return Ok(UserDefinedFunction::new(func.clone(), inputs).into()); + use crate::catalog::function_catalog::FunctionKind::*; + match &func.kind { + Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()), + Table { .. } => { + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new_user_defined(func.clone(), inputs).into()); + } + Aggregate => todo!("support UDAF"), + } } self.bind_builtin_scalar_function(function_name.as_str(), inputs) diff --git a/src/frontend/src/catalog/function_catalog.rs b/src/frontend/src/catalog/function_catalog.rs index c4c8967f16d9..2d2de2fd7c22 100644 --- a/src/frontend/src/catalog/function_catalog.rs +++ b/src/frontend/src/catalog/function_catalog.rs @@ -14,6 +14,7 @@ use risingwave_common::catalog::FunctionId; use risingwave_common::types::DataType; +use risingwave_pb::catalog::function::Kind as ProstKind; use risingwave_pb::catalog::Function as ProstFunction; #[derive(Clone, PartialEq, Eq, Hash, Debug)] @@ -21,6 +22,7 @@ pub struct FunctionCatalog { pub id: FunctionId, pub name: String, pub owner: u32, + pub kind: FunctionKind, pub arg_types: Vec, pub return_type: DataType, pub language: String, @@ -28,12 +30,31 @@ pub struct FunctionCatalog { pub link: String, } +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum FunctionKind { + Scalar, + Table, + Aggregate, +} + +impl From<&ProstKind> for FunctionKind { + fn from(prost: &ProstKind) -> Self { + use risingwave_pb::catalog::function::*; + match prost { + Kind::Scalar(ScalarFunction {}) => Self::Scalar, + Kind::Table(TableFunction {}) => Self::Table, + Kind::Aggregate(AggregateFunction {}) => Self::Aggregate, + } + } +} + impl From<&ProstFunction> for FunctionCatalog { fn from(prost: &ProstFunction) -> Self { FunctionCatalog { id: prost.id.into(), name: prost.name.clone(), owner: prost.owner, + kind: prost.kind.as_ref().unwrap().into(), arg_types: prost.arg_types.iter().map(|arg| arg.into()).collect(), return_type: prost.return_type.as_ref().expect("no return type").into(), language: prost.language.clone(), diff --git a/src/frontend/src/expr/expr_rewriter.rs b/src/frontend/src/expr/expr_rewriter.rs index 721960e82bd2..809b9b82bb5c 100644 --- a/src/frontend/src/expr/expr_rewriter.rs +++ b/src/frontend/src/expr/expr_rewriter.rs @@ -75,6 +75,7 @@ pub trait ExprRewriter { args, return_type, function_type, + udtf_catalog, } = table_func; let args = args .into_iter() @@ -84,6 +85,7 @@ pub trait ExprRewriter { args, return_type, function_type, + udtf_catalog, } .into() } diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index d84cb5e6087b..b3fe68404c36 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -13,14 +13,18 @@ // limitations under the License. use std::str::FromStr; +use std::sync::Arc; use itertools::Itertools; use risingwave_common::error::ErrorCode; use risingwave_common::types::{unnested_list_type, DataType, ScalarImpl}; use risingwave_pb::expr::table_function::Type; -use risingwave_pb::expr::TableFunction as TableFunctionProst; +use risingwave_pb::expr::{ + TableFunction as TableFunctionProst, UserDefinedTableFunction as UserDefinedTableFunctionProst, +}; use super::{Expr, ExprImpl, ExprRewriter, RwResult}; +use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind}; /// A table function takes a row as input and returns a table. It is also known as Set-Returning /// Function. @@ -32,14 +36,17 @@ pub struct TableFunction { pub args: Vec, pub return_type: DataType, pub function_type: TableFunctionType, + /// Catalog of user defined table function. + pub udtf_catalog: Option>, } -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TableFunctionType { Generate, Range, Unnest, RegexpMatches, + Udtf, } impl TableFunctionType { @@ -49,6 +56,7 @@ impl TableFunctionType { TableFunctionType::Range => Type::Range, TableFunctionType::Unnest => Type::Unnest, TableFunctionType::RegexpMatches => Type::RegexpMatches, + TableFunctionType::Udtf => Type::Udtf, } } } @@ -60,6 +68,7 @@ impl TableFunctionType { TableFunctionType::Range => "range", TableFunctionType::Unnest => "unnest", TableFunctionType::RegexpMatches => "regexp_matches", + TableFunctionType::Udtf => "udtf", } } } @@ -122,6 +131,7 @@ impl TableFunction { args, return_type: data_type, function_type, + udtf_catalog: None, }) } TableFunctionType::Unnest => { @@ -140,6 +150,7 @@ impl TableFunction { args: vec![expr], return_type: data_type, function_type: TableFunctionType::Unnest, + udtf_catalog: None, }) } else { Err(ErrorCode::BindError( @@ -197,8 +208,24 @@ impl TableFunction { datatype: Box::new(DataType::Varchar), }, function_type: TableFunctionType::RegexpMatches, + udtf_catalog: None, }) } + // not in this path + TableFunctionType::Udtf => unreachable!(), + } + } + + /// Create a user-defined `TableFunction`. + pub fn new_user_defined(catalog: Arc, args: Vec) -> Self { + let FunctionKind::Table = &catalog.kind else { + panic!("not a table function"); + }; + TableFunction { + args, + return_type: catalog.return_type.clone(), + function_type: TableFunctionType::Udtf, + udtf_catalog: Some(catalog), } } @@ -207,6 +234,15 @@ impl TableFunction { function_type: self.function_type.to_protobuf() as i32, args: self.args.iter().map(|c| c.to_expr_proto()).collect_vec(), return_type: Some(self.return_type.to_protobuf()), + udtf: self + .udtf_catalog + .as_ref() + .map(|c| UserDefinedTableFunctionProst { + arg_types: c.arg_types.iter().map(|t| t.to_protobuf()).collect(), + language: c.language.clone(), + link: c.link.clone(), + identifier: c.identifier.clone(), + }), } } diff --git a/src/frontend/src/expr/user_defined_function.rs b/src/frontend/src/expr/user_defined_function.rs index b2a925a13d78..907eac88b73e 100644 --- a/src/frontend/src/expr/user_defined_function.rs +++ b/src/frontend/src/expr/user_defined_function.rs @@ -19,7 +19,7 @@ use risingwave_common::catalog::FunctionId; use risingwave_common::types::DataType; use super::{Expr, ExprImpl}; -use crate::catalog::function_catalog::FunctionCatalog; +use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct UserDefinedFunction { @@ -34,7 +34,7 @@ impl UserDefinedFunction { pub(super) fn from_expr_proto( udf: &risingwave_pb::expr::UserDefinedFunction, - ret_type: DataType, + return_type: DataType, ) -> risingwave_common::error::Result { let args: Vec<_> = udf .get_children() @@ -50,8 +50,9 @@ impl UserDefinedFunction { name: udf.get_name().clone(), // FIXME(yuhao): owner is not in udf proto. owner: u32::MAX - 1, + kind: FunctionKind::Scalar, arg_types, - return_type: ret_type, + return_type, language: udf.get_language().clone(), identifier: udf.get_identifier().clone(), link: udf.get_link().clone(), diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 92ebf2cffdf0..f9606b23cb48 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -16,9 +16,11 @@ use anyhow::anyhow; use itertools::Itertools; use pgwire::pg_response::StatementType; use risingwave_common::catalog::FunctionId; +use risingwave_common::types::DataType; +use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; use risingwave_pb::catalog::Function; use risingwave_sqlparser::ast::{ - CreateFunctionBody, DataType, FunctionDefinition, ObjectName, OperateFunctionArg, + CreateFunctionBody, FunctionDefinition, ObjectName, OperateFunctionArg, }; use risingwave_udf::ArrowFlightUdfClient; @@ -32,7 +34,7 @@ pub async fn handle_create_function( temporary: bool, name: ObjectName, args: Option>, - return_type: Option, + returns: Option, params: CreateFunctionBody, ) -> Result { if or_replace { @@ -75,12 +77,37 @@ pub async fn handle_create_function( ) .into()); }; - let Some(return_type) = return_type else { - return Err( - ErrorCode::InvalidParameterValue("return type must be specified".to_string()).into(), - ) + let return_type; + let kind = match returns { + Some(CreateFunctionReturns::Value(data_type)) => { + return_type = bind_data_type(&data_type)?; + Kind::Scalar(ScalarFunction {}) + } + Some(CreateFunctionReturns::Table(columns)) => { + if columns.len() == 1 { + // return type is the original type for single column + return_type = bind_data_type(&columns[0].data_type)?; + } else { + // return type is a struct for multiple columns + let datatypes = columns + .iter() + .map(|c| bind_data_type(&c.data_type)) + .collect::>>()?; + let names = columns + .iter() + .map(|c| c.name.real_value()) + .collect::>(); + return_type = DataType::new_struct(datatypes, names); + } + Kind::Table(TableFunction {}) + } + None => { + return Err(ErrorCode::InvalidParameterValue( + "return type must be specified".to_string(), + ) + .into()) + } }; - let return_type = bind_data_type(&return_type)?; let mut arg_types = vec![]; for arg in args.unwrap_or_default() { @@ -116,11 +143,24 @@ pub async fn handle_create_function( .map(|t| arrow_schema::Field::new("", t.into(), true)) .collect(), ); - let returns = arrow_schema::Schema::new(vec![arrow_schema::Field::new( - "", - return_type.clone().into(), - true, - )]); + let returns = match kind { + Kind::Scalar(_) => arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "", + return_type.clone().into(), + true, + )]), + Kind::Table(_) => arrow_schema::Schema::new(match &return_type { + DataType::Struct(s) => (s.fields.iter()) + .map(|t| arrow_schema::Field::new("", t.clone().into(), true)) + .collect(), + _ => vec![arrow_schema::Field::new( + "", + return_type.clone().into(), + true, + )], + }), + _ => unreachable!(), + }; client .check(&identifier, &args, &returns) .await @@ -131,6 +171,7 @@ pub async fn handle_create_function( schema_id, database_id, name: function_name, + kind: Some(kind), arg_types: arg_types.into_iter().map(|t| t.into()).collect(), return_type: Some(return_type.into()), language, diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index bffe3e8818d7..4301a78c5129 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -173,7 +173,7 @@ pub async fn handle( temporary, name, args, - return_type, + returns, params, } => { create_function::handle_create_function( @@ -182,7 +182,7 @@ pub async fn handle( temporary, name, args, - return_type, + returns, params, ) .await diff --git a/src/frontend/src/optimizer/plan_node/logical_project_set.rs b/src/frontend/src/optimizer/plan_node/logical_project_set.rs index 52cb7bda9e3e..44a86c36759f 100644 --- a/src/frontend/src/optimizer/plan_node/logical_project_set.rs +++ b/src/frontend/src/optimizer/plan_node/logical_project_set.rs @@ -87,6 +87,7 @@ impl LogicalProjectSet { args, return_type, function_type, + udtf_catalog, } = table_func; let args = args .into_iter() @@ -98,6 +99,7 @@ impl LogicalProjectSet { args, return_type, function_type, + udtf_catalog, } .into() } else { diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 39cd6764198b..d6330e02f285 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -981,7 +981,7 @@ pub enum Statement { temporary: bool, name: ObjectName, args: Option>, - return_type: Option, + returns: Option, /// Optional parameters. params: CreateFunctionBody, }, @@ -1250,7 +1250,7 @@ impl fmt::Display for Statement { temporary, name, args, - return_type, + returns, params, } => { write!( @@ -1262,8 +1262,8 @@ impl fmt::Display for Statement { if let Some(args) = args { write!(f, "({})", display_comma_separated(args))?; } - if let Some(return_type) = return_type { - write!(f, " RETURNS {}", return_type)?; + if let Some(return_type) = returns { + write!(f, " {}", return_type)?; } write!(f, "{params}")?; Ok(()) @@ -2188,6 +2188,41 @@ impl fmt::Display for FunctionDefinition { } } +/// Return types of a function. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum CreateFunctionReturns { + /// RETURNS rettype + Value(DataType), + /// RETURNS TABLE ( column_name column_type [, ...] ) + Table(Vec), +} + +impl fmt::Display for CreateFunctionReturns { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Value(data_type) => write!(f, "RETURNS {}", data_type), + Self::Table(columns) => { + write!(f, "RETURNS TABLE ({})", display_comma_separated(columns)) + } + } + } +} + +/// Table column definition +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct TableColumnDef { + pub name: Ident, + pub data_type: DataType, +} + +impl fmt::Display for TableColumnDef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} {}", self.name, self.data_type) + } +} + /// Postgres specific feature. /// /// See [Postgresdocs](https://www.postgresql.org/docs/15/sql-createfunction.html) diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index ed9893166585..d41d31b85226 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -1785,7 +1785,23 @@ impl Parser { self.expect_token(&Token::RParen)?; let return_type = if self.parse_keyword(Keyword::RETURNS) { - Some(self.parse_data_type()?) + if self.parse_keyword(Keyword::TABLE) { + self.expect_token(&Token::LParen)?; + let mut values = vec![]; + loop { + values.push(self.parse_table_column_def()?); + let comma = self.consume_token(&Token::Comma); + if self.consume_token(&Token::RParen) { + // allow a trailing comma, even though it's not in standard + break; + } else if !comma { + return self.expected("',' or ')'", self.peek_token()); + } + } + Some(CreateFunctionReturns::Table(values)) + } else { + Some(CreateFunctionReturns::Value(self.parse_data_type()?)) + } } else { None }; @@ -1797,11 +1813,18 @@ impl Parser { temporary, name, args, - return_type, + returns: return_type, params, }) } + fn parse_table_column_def(&mut self) -> Result { + Ok(TableColumnDef { + name: self.parse_identifier_non_reserved()?, + data_type: self.parse_data_type()?, + }) + } + fn parse_function_arg(&mut self) -> Result { let mode = if self.parse_keyword(Keyword::IN) { Some(ArgMode::In) diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index db6f06c9cb8a..e13ae03cc04a 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -754,7 +754,7 @@ fn parse_create_function() { OperateFunctionArg::unnamed(DataType::Int), OperateFunctionArg::unnamed(DataType::Int), ]), - return_type: Some(DataType::Int), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), params: CreateFunctionBody { language: Some("SQL".into()), behavior: Some(FunctionBehavior::Immutable), @@ -782,7 +782,7 @@ fn parse_create_function() { default_expr: Some(Expr::Value(Value::Number("1".into()))), } ]), - return_type: Some(DataType::Int), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), params: CreateFunctionBody { language: Some("SQL".into()), behavior: Some(FunctionBehavior::Immutable), @@ -795,6 +795,29 @@ fn parse_create_function() { }, } ); + + let sql = "CREATE FUNCTION unnest(a INT[]) RETURNS TABLE (x INT) LANGUAGE SQL RETURN a"; + assert_eq!( + verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + name: ObjectName(vec![Ident::new_unchecked("unnest")]), + args: Some(vec![OperateFunctionArg::with_name( + "a", + DataType::Array(Box::new(DataType::Int)) + ),]), + returns: Some(CreateFunctionReturns::Table(vec![TableColumnDef { + name: Ident::new_unchecked("x"), + data_type: DataType::Int, + }])), + params: CreateFunctionBody { + language: Some("SQL".into()), + return_: Some(Expr::Identifier("a".into())), + ..Default::default() + }, + } + ); } #[test] diff --git a/src/udf/python/example.py b/src/udf/python/example.py index b7e6ded7df21..366ecda593e2 100644 --- a/src/udf/python/example.py +++ b/src/udf/python/example.py @@ -1,4 +1,5 @@ -from risingwave.udf import udf, UdfServer +from typing import Iterator +from risingwave.udf import udf, udtf, UdfServer import random @@ -19,9 +20,23 @@ def gcd3(x: int, y: int, z: int) -> int: return gcd(gcd(x, y), z) +@udtf(input_types='INT', result_types='INT') +def series(n: int) -> Iterator[int]: + for i in range(n): + yield i + + +@udtf(input_types=['INT'], result_types=['INT', 'VARCHAR']) +def series2(n: int) -> Iterator[tuple[int, str]]: + for i in range(n): + yield i, str(i) + + if __name__ == '__main__': server = UdfServer() server.add_function(random_int) server.add_function(gcd) server.add_function(gcd3) + server.add_function(series) + server.add_function(series2) server.serve() diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index 93d47f29df4d..7134bc17014d 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -39,9 +39,30 @@ def eval_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: return pa.RecordBatch.from_arrays([result], schema=self._result_schema) -class UserDefinedFunctionWrapper(ScalarFunction): +class TableFunction(UserDefinedFunction): """ - Base Wrapper for Python user-defined function. + Base interface for user-defined table function. A user-defined table functions maps zero, one, + or multiple table values to a new table value. + """ + + def eval(self, *args): + """ + Method which defines the logic of the table function. + """ + pass + + def eval_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch: + # only the first row from batch is used + res = self.eval(*[col[0].as_py() for col in batch]) + columns = zip(*res) if len(self._result_schema) > 1 else [res] + arrays = [pa.array(col, type) + for col, type in zip(columns, self._result_schema.types)] + return pa.RecordBatch.from_arrays(arrays, schema=self._result_schema) + + +class UserDefinedScalarFunctionWrapper(ScalarFunction): + """ + Base Wrapper for Python user-defined scalar function. """ _func: Callable @@ -49,7 +70,7 @@ def __init__(self, func, input_types, result_type, name=None): self._func = func self._input_schema = pa.schema(zip( inspect.getfullargspec(func)[0], - [_to_data_type(t) for t in input_types] + [_to_data_type(t) for t in _to_list(input_types)] )) self._result_schema = pa.schema( [('output', _to_data_type(result_type))]) @@ -63,9 +84,35 @@ def eval(self, *args): return self._func(*args) -def _create_udf(f, input_types, result_type, name): - return UserDefinedFunctionWrapper( - f, input_types, result_type, name) +class UserDefinedTableFunctionWrapper(TableFunction): + """ + Base Wrapper for Python user-defined table function. + """ + _func: Callable + + def __init__(self, func, input_types, result_types, name=None): + self._func = func + self._input_schema = pa.schema(zip( + inspect.getfullargspec(func)[0], + [_to_data_type(t) for t in _to_list(input_types)] + )) + self._result_schema = pa.schema( + [('', _to_data_type(t)) for t in _to_list(result_types)]) + self._name = name or ( + func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) + + def __call__(self, *args): + return self._func(*args) + + def eval(self, *args): + return self._func(*args) + + +def _to_list(x): + if isinstance(x, list): + return x + else: + return [x] def udf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], @@ -75,7 +122,17 @@ def udf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType Annotation for creating a user-defined function. """ - return lambda f: _create_udf(f, input_types, result_type, name) + return lambda f: UserDefinedScalarFunctionWrapper(f, input_types, result_type, name) + + +def udtf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], + result_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], + name: Optional[str] = None,) -> Union[Callable, UserDefinedFunction]: + """ + Annotation for creating a user-defined table function. + """ + + return lambda f: UserDefinedTableFunctionWrapper(f, input_types, result_types, name) class UdfServer(pa.flight.FlightServerBase): @@ -93,8 +150,10 @@ def get_flight_info(self, context, descriptor): """Return the result schema of a function.""" udf = self._functions[descriptor.path[0].decode('utf-8')] # return the concatenation of input and output schema - full_schema = udf._input_schema.append(udf._result_schema.field(0)) - return pa.flight.FlightInfo(schema=full_schema, descriptor=descriptor, endpoints=[], total_records=0, total_bytes=0) + full_schema = pa.schema( + list(udf._input_schema) + list(udf._result_schema)) + # we use `total_records` to indicate the number of input arguments + return pa.flight.FlightInfo(schema=full_schema, descriptor=descriptor, endpoints=[], total_records=len(udf._input_schema), total_bytes=0) def add_function(self, udf: UserDefinedFunction): """Add a function to the server.""" diff --git a/src/udf/src/lib.rs b/src/udf/src/lib.rs index 748a0c3224f6..00bc7073d9b9 100644 --- a/src/udf/src/lib.rs +++ b/src/udf/src/lib.rs @@ -44,11 +44,10 @@ impl ArrowFlightUdfClient { // check schema let info = response.into_inner(); + let input_num = info.total_records as usize; let full_schema = Schema::try_from(info) .map_err(|e| FlightError::DecodeError(format!("Error decoding schema: {e}")))?; - // TODO: only support one return value for now - let (input_fields, return_fields) = - full_schema.fields.split_at(full_schema.fields.len() - 1); + let (input_fields, return_fields) = full_schema.fields.split_at(input_num); let actual_input_types: Vec<_> = input_fields.iter().map(|f| f.data_type()).collect(); let actual_result_types: Vec<_> = return_fields.iter().map(|f| f.data_type()).collect(); let expect_input_types: Vec<_> = args.fields.iter().map(|f| f.data_type()).collect();