Skip to content

Commit

Permalink
finish implementation of cast VarList to VarList
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Nov 10, 2023
1 parent 4dff299 commit 9a73d5e
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 62 deletions.
6 changes: 2 additions & 4 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ void FunctionExpressionEvaluator::evaluate() {
child->evaluate();
}
auto expr = reinterpret_cast<binder::ScalarFunctionExpression*>(expression.get());
if (expr->getFunctionName() == CAST_FUNC_NAME &&
parameters[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) {
execFunc(parameters, *resultVector,
reinterpret_cast<function::StringCastFunctionBindData*>(expr->getBindData()));
if (expr->getFunctionName() == CAST_FUNC_NAME) {
execFunc(parameters, *resultVector, expr->getBindData());
return;
}
if (execFunc != nullptr) {
Expand Down
129 changes: 90 additions & 39 deletions src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ static void castFixedListToString(
resultVector.setValue(resultPos, result);
}

template<typename /*EXECUTOR*/ = UnaryFunctionExecutor>
static void fixedListCastExecFunction(const std::vector<std::shared_ptr<ValueVector>>& params,
ValueVector& result, void* /*dataPtr*/ = nullptr) {
assert(params.size() == 1);
KU_ASSERT(params.size() == 1);
auto param = params[0];
if (param->state->isFlat()) {
castFixedListToString(*param, param->state->selVector->selectedPositions[0], result,
Expand All @@ -53,10 +54,25 @@ static void fixedListCastExecFunction(const std::vector<std::shared_ptr<ValueVec
}
}

template<typename T>
// Specialization of fixedListCastExecFunction for CastChildFunctionExecutor.
// Instead of working on the selected positions of the parameter vector itself, it walks the
// child data vector of the (list-typed) input and passes every child position to
// castFixedListToString, writing into the same position of the result's child data vector.
// params must contain exactly one vector (asserted); dataPtr is unused.
template<>
void fixedListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result,
void* /*dataPtr*/) {
KU_ASSERT(params.size() == 1);

auto inputVector = params[0].get();
// Size of the input's child data vector; this bounds the loop below.
auto childNum = common::ListVector::getDataVectorSize(inputVector);
auto inputChildVector = (common::ListVector::getDataVector(inputVector));
auto resultChildVector = (common::ListVector::getDataVector(&result));
// NOTE(review): assumes the result's child data vector already has room for childNum
// entries -- nothing is allocated here; TODO confirm against the caller.
// NOTE(review): no null check is done here, unlike the STRING -> FIXED_LIST child
// specialization; presumably castFixedListToString handles null entries itself -- TODO confirm.
for (auto i = 0u; i < childNum; i++) {
castFixedListToString(*inputChildVector, i, *resultChildVector, i);
}
}

template<typename /*EXECUTOR*/ = UnaryFunctionExecutor>
static void StringtoFixedListCastExecFunction(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
assert(params.size() == 1);
KU_ASSERT(params.size() == 1);
auto param = params[0];
auto csvReaderConfig = &reinterpret_cast<StringCastFunctionBindData*>(dataPtr)->csvConfig;
if (param->state->isFlat()) {
Expand Down Expand Up @@ -90,16 +106,33 @@ static void StringtoFixedListCastExecFunction(
// Specialization of StringtoFixedListCastExecFunction for CastChildFunctionExecutor.
// Both the single input vector and the result must be VAR_LIST vectors (asserted). The cast is
// applied to the entries of their child data vectors: each non-null string at child position i
// is handed to CastString::castToFixedList, targeting position i of the result's child vector.
// dataPtr is reinterpreted as StringCastFunctionBindData* and dereferenced without a null
// check, so the caller must always supply that bind data.
template<>
void StringtoFixedListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1 &&
params[0]->dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST &&
result.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST);
// CSV reader configuration taken from the bind data and forwarded to the string parser.
auto csvReaderConfig = &reinterpret_cast<StringCastFunctionBindData*>(dataPtr)->csvConfig;

auto inputVector = params[0].get();
// Size of the input's child data vector; this bounds the loop below.
auto childNum = common::ListVector::getDataVectorSize(inputVector);
auto inputChildVector = (common::ListVector::getDataVector(inputVector));
auto resultChildVector = (common::ListVector::getDataVector(&result));
// NOTE(review): assumes the result's child data vector already has room for childNum
// entries -- nothing is allocated here; TODO confirm against the caller.
for (auto i = 0u; i < childNum; i++) {
// Copy the null flag first; only non-null entries are parsed.
resultChildVector->setNull(i, inputChildVector->isNull(i));
if (!resultChildVector->isNull(i)) {
CastString::castToFixedList(
inputChildVector->getValue<ku_string_t>(i), resultChildVector, i, csvReaderConfig);
}
}
}

template<typename EXECUTOR = UnaryFunctionExecutor>
static void varListCastExecFunction(const std::vector<std::shared_ptr<ValueVector>>& params,
ValueVector& result, void* dataPtr) {
assert(params.size() == 1);
template<typename /*EXECUTOR*/ = UnaryFunctionExecutor>
static void varListCastExecFunction(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
result.resetAuxiliaryBuffer();
auto inputVector = params[0];
scalar_exec_func func = CastFunction::bindCastFunction<CastChildFunctionExecutor>("CAST",
inputVector->dataType.getLogicalTypeID(), result.dataType.getLogicalTypeID())
VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID(),
VarListType::getChildType(&result.dataType)->getLogicalTypeID())
->execFunc;
for (auto i = 0u; i < inputVector->state->selVector->selectedSize; i++) {
auto pos = inputVector->state->selVector->selectedPositions[i];
Expand All @@ -116,10 +149,30 @@ static void varListCastExecFunction(const std::vector<std::shared_ptr<ValueVecto
template<>
void varListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
auto inputChildVector = params[0];

// TODO: Kebing finish this one
KU_ASSERT(params.size() == 1 &&
params[0]->dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST &&
result.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST);
result.resetAuxiliaryBuffer();

auto inputVector = params[0].get();
auto childNum = common::ListVector::getDataVectorSize(inputVector);
auto inputChildVector = (common::ListVector::getSharedDataVector(inputVector));
auto resultChildVector = (common::ListVector::getDataVector(&result));
scalar_exec_func func = CastFunction::bindCastFunction<CastChildFunctionExecutor>("CAST",
VarListType::getChildType(&inputChildVector->dataType)->getLogicalTypeID(),
VarListType::getChildType(&resultChildVector->dataType)->getLogicalTypeID())
->execFunc;
for (auto i = 0u; i < childNum; i++) {
resultChildVector->setNull(i, inputChildVector->isNull(i));
if (!resultChildVector->isNull(i)) {
// cast position i in child data vector
auto input_list_entry = inputChildVector->getValue<list_entry_t>(i);
auto result_list_entry = ListVector::addList(resultChildVector, input_list_entry.size);
resultChildVector->setValue<list_entry_t>(i, result_list_entry);
}
}
std::vector<std::shared_ptr<ValueVector>> childParams{inputChildVector};
func(childParams, *resultChildVector, dataPtr);
}

bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType& dstType) {
Expand Down Expand Up @@ -165,12 +218,12 @@ static std::unique_ptr<ScalarFunction> bindCastFromStringFunction(
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, date_t, CastString, EXECUTOR>;
} break;
case LogicalTypeID::TIMESTAMP: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, timestamp_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, timestamp_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::INTERVAL: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, interval_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, interval_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::BLOB: {
execFunc =
Expand All @@ -185,16 +238,16 @@ static std::unique_ptr<ScalarFunction> bindCastFromStringFunction(
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, bool, CastString, EXECUTOR>;
} break;
case LogicalTypeID::DOUBLE: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, double_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, double_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::FLOAT: {
execFunc =
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, float, CastString, EXECUTOR>;
} break;
case LogicalTypeID::INT128: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, int128_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, int128_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
Expand All @@ -214,35 +267,35 @@ static std::unique_ptr<ScalarFunction> bindCastFromStringFunction(
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, int8_t, CastString, EXECUTOR>;
} break;
case LogicalTypeID::UINT64: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, uint64_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint64_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::UINT32: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, uint32_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint32_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::UINT16: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, uint16_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint16_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::UINT8: {
execFunc =
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint8_t, CastString, EXECUTOR>;
} break;
case LogicalTypeID::VAR_LIST: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, list_entry_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, list_entry_t,
CastString, EXECUTOR>;
} break;
case LogicalTypeID::FIXED_LIST: {
execFunc = StringtoFixedListCastExecFunction<EXECUTOR>;
} break;
case LogicalTypeID::MAP: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, map_entry_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, map_entry_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::STRUCT: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, struct_entry_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, struct_entry_t,
CastString, EXECUTOR>;
} break;
case LogicalTypeID::UNION: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, union_entry_t,
Expand Down Expand Up @@ -402,7 +455,7 @@ static std::unique_ptr<ScalarFunction> bindCastToStringFunction(
EXECUTOR>;
} break;
case LogicalTypeID::FIXED_LIST: {
func = fixedListCastExecFunction;
func = fixedListCastExecFunction<EXECUTOR>;
} break;
case LogicalTypeID::MAP: {
func =
Expand Down Expand Up @@ -473,7 +526,7 @@ static std::unique_ptr<ScalarFunction> bindCastToNumericFunction(
functionName, std::vector<LogicalTypeID>{sourceTypeID}, targetTypeID, func);
}

template<typename DST_TYPE, typename EXECUTOR = UnaryFunctionExecutor>
template<typename /*DST_TYPE*/, typename EXECUTOR = UnaryFunctionExecutor>
static std::unique_ptr<ScalarFunction> bindCastBetweenNested(
const std::string& functionName, LogicalTypeID sourceTypeID, LogicalTypeID targetTypeID) {
scalar_exec_func func;
Expand All @@ -495,7 +548,8 @@ static std::unique_ptr<ScalarFunction> bindCastToTimestampFunction(
scalar_exec_func func;
switch (sourceTypeID) {
case LogicalTypeID::DATE: {
func = ScalarFunction::UnaryExecFunction<date_t, timestamp_t, CastDateToTimestamp, EXECUTOR>;
func =
ScalarFunction::UnaryExecFunction<date_t, timestamp_t, CastDateToTimestamp, EXECUTOR>;
} break;
default:
throw ConversionException{stringFormat("Unsupported casting function from {} to TIMESTAMP.",
Expand Down Expand Up @@ -810,10 +864,7 @@ std::unique_ptr<FunctionBindData> CastAnyFunction::bindFunc(
func->execFunc =
CastFunction::bindCastFunction(func->name, inputTypeID, outputType->getLogicalTypeID())
->execFunc;
if (inputTypeID == LogicalTypeID::STRING) {
return std::make_unique<function::StringCastFunctionBindData>(*outputType);
}
return std::make_unique<function::FunctionBindData>(*outputType);
return std::make_unique<function::StringCastFunctionBindData>(*outputType);
}

function_set CastAnyFunction::getFunctionSet() {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/vector/auxiliary_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ListAuxiliaryBuffer : public AuxiliaryBuffer {
dataVector = std::move(vector);
}
inline ValueVector* getDataVector() const { return dataVector.get(); }
inline std::shared_ptr<ValueVector> getSharedDataVector() const { return dataVector; }

list_entry_t addList(uint64_t listSize);

Expand Down
5 changes: 5 additions & 0 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ class ListVector {
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->getDataVector();
}
static inline std::shared_ptr<ValueVector> getSharedDataVector(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->getSharedDataVector();
}
static inline uint64_t getDataVectorSize(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())->getSize();
Expand Down
9 changes: 4 additions & 5 deletions src/include/function/scalar_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ struct ScalarFunction : public BaseScalarFunction {
const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
EXECUTOR::template executeSwitch<
OPERAND_TYPE, RESULT_TYPE, FUNC, UnaryCastStringFunctionWrapper>
(*params[0], result, dataPtr);
EXECUTOR::template executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC,
UnaryCastStringFunctionWrapper>(*params[0], result, dataPtr);
}

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC,
Expand All @@ -127,8 +126,8 @@ struct ScalarFunction : public BaseScalarFunction {
const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result, void* /*dataPtr*/ = nullptr) {
KU_ASSERT(params.size() == 1);
EXECUTOR::template executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC,
UnaryCastFunctionWrapper>(*params[0], result, nullptr /* dataPtr */);
EXECUTOR::template executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC, UnaryCastFunctionWrapper>(
*params[0], result, nullptr /* dataPtr */);
}

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC,
Expand Down
11 changes: 6 additions & 5 deletions src/include/function/unary_function_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,16 @@ struct UnaryUDFFunctionWrapper {

struct CastChildFunctionExecutor {
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC, typename OP_WRAPPER>
static void executeSwitch(common::ValueVector& operand, common::ValueVector& result,
void* dataPtr) {
static void executeSwitch(
common::ValueVector& operand, common::ValueVector& result, void* dataPtr) {
// this vector is of var list type and the child vector is of non-nested types then cast
KU_ASSERT(operand.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST);
KU_ASSERT(operand.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST &&
result.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST);
result.resetAuxiliaryBuffer();
auto childNum = common::ListVector::getDataVectorSize(&operand);
auto inputChildVector = (common::ListVector::getDataVector(&operand));
auto resultChildVector = (common::ListVector::getDataVector(&result));
for (auto i = 0u; i < childNum; i++) {
auto inputChildVector = (common::ListVector::getDataVector(&operand));
auto resultChildVector = (common::ListVector::getDataVector(&operand));
resultChildVector->setNull(i, inputChildVector->isNull(i));
if (!resultChildVector->isNull(i)) {
// cast position i in child data vector
Expand Down
23 changes: 23 additions & 0 deletions test/test_files/tinysnb/cast/cast_to_nested_types.test
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ False|-4325||18446744073709551616.000000| dfsa
---- 2
[3324.123047,342423.437500,432.122986]
[1.000000,4231.000000,432.122986]
-STATEMENT RETURN cast("[423, 321, 423]", "INT64[3]"), cast(null, "INT64[5]"), cast("[432.43214]", "FLOAT[1]"), cast("[4, -5]", "double[2]"), cast("[4234, 42312, 432, 1321]", "INT32[4]"), cast("[-32768]", "INT16[1]")
---- 1
[423,321,423]||[432.432129]|[4.000000,-5.000000]|[4234,42312,432,1321]|[-32768]
-STATEMENT Return cast(cast(-4324324, "int128"), "int64")
---- 1
-4324324

-LOG CastToMap
-STATEMENT RETURN cast(" { c= {a = 3423 }, b = { g = 3421 } } ", "MAP(STRING, MAP(STRING, INT16))"), cast("{}", "MAP(STRING, MAP(STRING, INT16))"), cast("{d = {}}", "MAP(STRING, MAP(STRING, INT16))");
Expand Down Expand Up @@ -168,3 +174,20 @@ True|2019-03-19|-12.343200|32768|
1|-2147483648|1970-01-01 10:00:00.004666|-32769.000000|fsdxcv
0|0|2014-05-12 12:11:59|4324254534123134124032.000000|fsaf
False|-4325|14|18446744073709551616.000000| dfsa

-LOG CastVarListToVarList
-STATEMENT RETURN cast([321, 432], "DOUBLE[]"), cast([321, 432], "FLOAT[]"), cast([321, 432], "INT128[]"), cast([321, 432], "INT64[]"), cast([321, 432], "INT32[]"), cast([321, 432], "INT16[]"), cast([-1, -43], "INT8[]"), cast([0, 23], "UINT8[]"), cast([0, 23], "UINT16[]"), cast([0, 23], "UINT32[]"), cast([0, 23], "UINT64[]"), cast([5435234412435123, -432425341231], "STRING[]");
---- 1
[321.000000,432.000000]|[321.000000,432.000000]|[321,432]|[321,432]|[321,432]|[321,432]|[-1,-43]|[0,23]|[0,23]|[0,23]|[0,23]|[5435234412435123,-432425341231]
-STATEMENT RETURN cast([], "UINT64[]"), cast([NULL,], "UINT64[]"), cast(NULL, "UINT64[]"), cast([NULL, 432124, 0, NULL], "UINT64[]");
---- 1
[]|[,]||[,432124,0,]

-LOG CastNestedVarListToNestedVarList
-STATEMENT RETURN cast([[4324.2312, 432.321, 43242.543], [31214.59,4132.72], NULL, [NULL,,4324.32]], "INT64[][]");
---- 1
[[4324,432,43243],[31215,4133],,[,,4324]]
-STATEMENT RETURN cast(["[123, 3234]", "[124, 3241]", NULL, "[0, -4324234]"], "INT64[2][]"), cast(cast(["[123, 3234]", "[124, 3241]", NULL, "[0, -4324234]"], "DOUBLE[2][]"), "STRING[]");
---- 1
[[123,3234],[124,3241],,[0,-4324234]]|[[123.000000,3234.000000],[124.000000,3241.000000],,[0.000000,-4324234.000000]]

Loading

0 comments on commit 9a73d5e

Please sign in to comment.