Skip to content

Commit

Permalink
finish implemention of cast VarList to VarList
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Nov 13, 2023
1 parent 63a2932 commit c143192
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 81 deletions.
6 changes: 2 additions & 4 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ void FunctionExpressionEvaluator::evaluate() {
child->evaluate();
}
auto expr = reinterpret_cast<binder::ScalarFunctionExpression*>(expression.get());
if (expr->getFunctionName() == CAST_FUNC_NAME &&
parameters[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) {
execFunc(parameters, *resultVector,
reinterpret_cast<function::StringCastFunctionBindData*>(expr->getBindData()));
if (expr->getFunctionName() == CAST_FUNC_NAME) {
execFunc(parameters, *resultVector, expr->getBindData());
return;
}
if (execFunc != nullptr) {
Expand Down
153 changes: 97 additions & 56 deletions src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ static void castFixedListToString(
resultVector.setValue(resultPos, result);
}

template<typename /*EXECUTOR*/ = UnaryFunctionExecutor>
static void fixedListCastExecFunction(const std::vector<std::shared_ptr<ValueVector>>& params,
ValueVector& result, void* /*dataPtr*/ = nullptr) {
assert(params.size() == 1);
KU_ASSERT(params.size() == 1);
auto param = params[0];
if (param->state->isFlat()) {
castFixedListToString(*param, param->state->selVector->selectedPositions[0], result,
Expand All @@ -53,10 +54,25 @@ static void fixedListCastExecFunction(const std::vector<std::shared_ptr<ValueVec
}
}

template<typename T>
template<>
void fixedListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result,
void* /*dataPtr*/) {
KU_ASSERT(params.size() == 1);

auto inputVector = params[0].get();
auto childNum = ListVector::getDataVectorSize(inputVector);
auto inputChildVector = (ListVector::getDataVector(inputVector));
auto resultChildVector = (ListVector::getDataVector(&result));
for (auto i = 0u; i < childNum; i++) {
castFixedListToString(*inputChildVector, i, *resultChildVector, i);
}
}

template<typename /*EXECUTOR*/ = UnaryFunctionExecutor>
static void StringtoFixedListCastExecFunction(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
assert(params.size() == 1);
KU_ASSERT(params.size() == 1);
auto param = params[0];
auto csvReaderConfig = &reinterpret_cast<StringCastFunctionBindData*>(dataPtr)->csvConfig;
if (param->state->isFlat()) {
Expand Down Expand Up @@ -90,36 +106,61 @@ static void StringtoFixedListCastExecFunction(
template<>
void StringtoFixedListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1 &&
params[0]->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST &&
result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
auto csvReaderConfig = &reinterpret_cast<StringCastFunctionBindData*>(dataPtr)->csvConfig;

}

template<typename EXECUTOR = UnaryFunctionExecutor>
static void varListCastExecFunction(const std::vector<std::shared_ptr<ValueVector>>& params,
ValueVector& result, void* dataPtr) {
assert(params.size() == 1);
auto inputVector = params[0];
scalar_exec_func func = CastFunction::bindCastFunction<CastChildFunctionExecutor>("CAST",
inputVector->dataType.getLogicalTypeID(), result.dataType.getLogicalTypeID())
->execFunc;
for (auto i = 0u; i < inputVector->state->selVector->selectedSize; i++) {
auto pos = inputVector->state->selVector->selectedPositions[i];
result.setNull(pos, inputVector->isNull(pos));
if (!result.isNull(pos)) {
auto input_list_entry = inputVector->getValue<list_entry_t>(pos);
auto result_list_entry = ListVector::addList(&result, input_list_entry.size);
result.setValue<list_entry_t>(pos, result_list_entry);
auto inputVector = params[0].get();
auto childNum = ListVector::getDataVectorSize(inputVector);
auto inputChildVector = (ListVector::getDataVector(inputVector));
auto resultChildVector = (ListVector::getDataVector(&result));
for (auto i = 0u; i < childNum; i++) {
resultChildVector->setNull(i, inputChildVector->isNull(i));
if (!resultChildVector->isNull(i)) {
CastString::castToFixedList(
inputChildVector->getValue<ku_string_t>(i), resultChildVector, i, csvReaderConfig);
}
}
func(params, result, dataPtr);
}

template<>
void varListCastExecFunction<CastChildFunctionExecutor>(
static void varListCastExecFunction(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
auto inputChildVector = params[0];

// TODO: Kebing finish this one
result.resetAuxiliaryBuffer();
auto inputVector = params[0];
auto resultVector = &result;

auto childNum = ListVector::getDataVectorSize(inputVector.get());
ListVector::addList(&result, childNum);
auto listEntrySize = inputVector->state->selVector
->selectedPositions[inputVector->state->selVector->selectedSize - 1] +
1;
memcpy(resultVector->getData(), inputVector->getData(),
listEntrySize * resultVector->getNumBytesPerValue());
resultVector->setNullFromBits(inputVector->getNullMaskData(), 0, 0, listEntrySize);

// resolve to the lowest level dataVector
auto inputChildTypeID = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID();
auto resultChildTypeID = VarListType::getChildType(&resultVector->dataType)->getLogicalTypeID();
while (inputChildTypeID == LogicalTypeID::VAR_LIST &&
resultChildTypeID == LogicalTypeID::VAR_LIST) {
inputVector = ListVector::getSharedDataVector(inputVector.get());
resultVector = ListVector::getDataVector(resultVector);
inputChildTypeID = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID();
resultChildTypeID = VarListType::getChildType(&resultVector->dataType)->getLogicalTypeID();

// copy NULL musk and list entry
memcpy(resultVector->getData(), inputVector->getData(), childNum * resultVector->getNumBytesPerValue());
resultVector->setNullFromBits(inputVector->getNullMaskData(), 0, 0, childNum);
childNum = ListVector::getDataVectorSize(inputVector.get());
ListVector::addList(resultVector, childNum);
}
scalar_exec_func func = CastFunction::bindCastFunction<CastChildFunctionExecutor>(
"CAST", inputChildTypeID, resultChildTypeID)
->execFunc;
std::vector<std::shared_ptr<ValueVector>> childParams{inputVector};
func(childParams, *resultVector, dataPtr);
}

bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType& dstType) {
Expand Down Expand Up @@ -165,12 +206,12 @@ static std::unique_ptr<ScalarFunction> bindCastFromStringFunction(
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, date_t, CastString, EXECUTOR>;
} break;
case LogicalTypeID::TIMESTAMP: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, timestamp_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, timestamp_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::INTERVAL: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, interval_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, interval_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::BLOB: {
execFunc =
Expand All @@ -185,16 +226,16 @@ static std::unique_ptr<ScalarFunction> bindCastFromStringFunction(
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, bool, CastString, EXECUTOR>;
} break;
case LogicalTypeID::DOUBLE: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, double_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, double_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::FLOAT: {
execFunc =
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, float, CastString, EXECUTOR>;
} break;
case LogicalTypeID::INT128: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, int128_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, int128_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
Expand All @@ -214,35 +255,35 @@ static std::unique_ptr<ScalarFunction> bindCastFromStringFunction(
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, int8_t, CastString, EXECUTOR>;
} break;
case LogicalTypeID::UINT64: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, uint64_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint64_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::UINT32: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, uint32_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint32_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::UINT16: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, uint16_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint16_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::UINT8: {
execFunc =
ScalarFunction::UnaryCastStringExecFunction<ku_string_t, uint8_t, CastString, EXECUTOR>;
} break;
case LogicalTypeID::VAR_LIST: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, list_entry_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, list_entry_t,
CastString, EXECUTOR>;
} break;
case LogicalTypeID::FIXED_LIST: {
execFunc = StringtoFixedListCastExecFunction<EXECUTOR>;
} break;
case LogicalTypeID::MAP: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, map_entry_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, map_entry_t, CastString,
EXECUTOR>;
} break;
case LogicalTypeID::STRUCT: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<
ku_string_t, struct_entry_t, CastString, EXECUTOR>;
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, struct_entry_t,
CastString, EXECUTOR>;
} break;
case LogicalTypeID::UNION: {
execFunc = ScalarFunction::UnaryCastStringExecFunction<ku_string_t, union_entry_t,
Expand Down Expand Up @@ -402,7 +443,7 @@ static std::unique_ptr<ScalarFunction> bindCastToStringFunction(
EXECUTOR>;
} break;
case LogicalTypeID::FIXED_LIST: {
func = fixedListCastExecFunction;
func = fixedListCastExecFunction<EXECUTOR>;
} break;
case LogicalTypeID::MAP: {
func =
Expand Down Expand Up @@ -473,17 +514,20 @@ static std::unique_ptr<ScalarFunction> bindCastToNumericFunction(
functionName, std::vector<LogicalTypeID>{sourceTypeID}, targetTypeID, func);
}

template<typename DST_TYPE, typename EXECUTOR = UnaryFunctionExecutor>
template<typename /*DST_TYPE*/>
static std::unique_ptr<ScalarFunction> bindCastBetweenNested(
const std::string& functionName, LogicalTypeID sourceTypeID, LogicalTypeID targetTypeID) {
scalar_exec_func func;
switch (sourceTypeID) {
case LogicalTypeID::VAR_LIST: {
func = varListCastExecFunction<EXECUTOR>;
func = varListCastExecFunction;
} break;
default:
// lcov_excl_start
// TODO(kebing): implement more
throw ConversionException{stringFormat("Unsupported casting function from {} to {}.",
LogicalTypeUtils::toString(sourceTypeID), LogicalTypeUtils::toString(targetTypeID))};
// lcov_excl_end
}
return std::make_unique<ScalarFunction>(
functionName, std::vector<LogicalTypeID>{sourceTypeID}, targetTypeID, func);
Expand All @@ -495,7 +539,8 @@ static std::unique_ptr<ScalarFunction> bindCastToTimestampFunction(
scalar_exec_func func;
switch (sourceTypeID) {
case LogicalTypeID::DATE: {
func = ScalarFunction::UnaryExecFunction<date_t, timestamp_t, CastDateToTimestamp, EXECUTOR>;
func =
ScalarFunction::UnaryExecFunction<date_t, timestamp_t, CastDateToTimestamp, EXECUTOR>;
} break;
default:
throw ConversionException{stringFormat("Unsupported casting function from {} to TIMESTAMP.",
Expand Down Expand Up @@ -570,8 +615,7 @@ std::unique_ptr<ScalarFunction> CastFunction::bindCastFunction(
return bindCastToTimestampFunction<EXECUTOR>(functionName, sourceTypeID);
}
case LogicalTypeID::VAR_LIST: {
return bindCastBetweenNested<list_entry_t, EXECUTOR>(
functionName, sourceTypeID, targetTypeID);
return bindCastBetweenNested<list_entry_t>(functionName, sourceTypeID, targetTypeID);
}
default: {
throw ConversionException{stringFormat("Unsupported casting function from {} to {}.",
Expand Down Expand Up @@ -810,10 +854,7 @@ std::unique_ptr<FunctionBindData> CastAnyFunction::bindFunc(
func->execFunc =
CastFunction::bindCastFunction(func->name, inputTypeID, outputType->getLogicalTypeID())
->execFunc;
if (inputTypeID == LogicalTypeID::STRING) {
return std::make_unique<function::StringCastFunctionBindData>(*outputType);
}
return std::make_unique<function::FunctionBindData>(*outputType);
return std::make_unique<function::StringCastFunctionBindData>(*outputType);
}

function_set CastAnyFunction::getFunctionSet() {
Expand Down
2 changes: 1 addition & 1 deletion src/include/common/null_mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class NullMask {

// const because updates to the data must set mayContainNulls if any value
// becomes non-null
// Modifying the underlying data shuld be done with setNull or copyFromNullData
// Modifying the underlying data should be done with setNull or copyFromNullData
inline const uint64_t* getData() { return data; }

static inline uint64_t getNumNullEntries(uint64_t numNullBits) {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/vector/auxiliary_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ListAuxiliaryBuffer : public AuxiliaryBuffer {
dataVector = std::move(vector);
}
inline ValueVector* getDataVector() const { return dataVector.get(); }
inline std::shared_ptr<ValueVector> getSharedDataVector() const { return dataVector; }

list_entry_t addList(uint64_t listSize);

Expand Down
5 changes: 5 additions & 0 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ class ListVector {
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->getDataVector();
}
static inline std::shared_ptr<ValueVector> getSharedDataVector(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->getSharedDataVector();
}
static inline uint64_t getDataVectorSize(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())->getSize();
Expand Down
9 changes: 4 additions & 5 deletions src/include/function/scalar_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ struct ScalarFunction : public BaseScalarFunction {
const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
EXECUTOR::template executeSwitch<
OPERAND_TYPE, RESULT_TYPE, FUNC, UnaryCastStringFunctionWrapper>
(*params[0], result, dataPtr);
EXECUTOR::template executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC,
UnaryCastStringFunctionWrapper>(*params[0], result, dataPtr);
}

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC,
Expand All @@ -127,8 +126,8 @@ struct ScalarFunction : public BaseScalarFunction {
const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result, void* /*dataPtr*/ = nullptr) {
KU_ASSERT(params.size() == 1);
EXECUTOR::template executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC,
UnaryCastFunctionWrapper>(*params[0], result, nullptr /* dataPtr */);
EXECUTOR::template executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC, UnaryCastFunctionWrapper>(
*params[0], result, nullptr /* dataPtr */);
}

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC,
Expand Down
12 changes: 6 additions & 6 deletions src/include/function/unary_function_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ struct UnaryUDFFunctionWrapper {

struct CastChildFunctionExecutor {
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC, typename OP_WRAPPER>
static void executeSwitch(common::ValueVector& operand, common::ValueVector& result,
void* dataPtr) {
static void executeSwitch(
common::ValueVector& operand, common::ValueVector& result, void* dataPtr) {
// this vector is of var list type and the child vector is of non-nested types then cast
KU_ASSERT(operand.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST);
result.resetAuxiliaryBuffer();
KU_ASSERT(operand.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST &&
result.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST);
auto childNum = common::ListVector::getDataVectorSize(&operand);
auto inputChildVector = common::ListVector::getDataVector(&operand);
auto resultChildVector = (common::ListVector::getDataVector(&result));
for (auto i = 0u; i < childNum; i++) {
auto inputChildVector = (common::ListVector::getDataVector(&operand));
auto resultChildVector = (common::ListVector::getDataVector(&operand));
resultChildVector->setNull(i, inputChildVector->isNull(i));
if (!resultChildVector->isNull(i)) {
// cast position i in child data vector
Expand Down
11 changes: 11 additions & 0 deletions test/test_files/tinysnb/cast/cast_error.test
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,14 @@ Conversion exception: Unsupported casting function from REL to TIMESTAMP.
-STATEMENT MATCH (:person)-[e:studyAt*1..3]->(:organisation) return cast(e, "INT64");
---- error
Conversion exception: Unsupported casting function from RECURSIVE_REL to INT64.

-LOG InvalidVarListToVarList
-STATEMENT RETURN cast([31231], "INT64[][]");
---- error
Conversion exception: Unsupported casting function from INT64 to VAR_LIST.
-STATEMENT RETURN cast([-1], "UINT8[]");
---- error
Overflow exception: Value -1 is not within UINT8 range
-STATEMENT RETURN cast([[1, 1]], "UINT8[]");
---- error
Conversion exception: Unsupported casting function from VAR_LIST to UINT8.
Loading

0 comments on commit c143192

Please sign in to comment.