Skip to content

Commit

Permalink
Merge pull request #2307 from kuzudb/list-fix-null
Browse files Browse the repository at this point in the history
Fix list extract with null
  • Loading branch information
acquamarin committed Oct 31, 2023
2 parents 5db9186 + 6c8662b commit 2021979
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 60 deletions.
76 changes: 42 additions & 34 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,78 +139,86 @@ function_set SizeFunction::getFunctionSet() {
return result;
}

// Exec-function adapter for list extraction. Expects exactly two argument vectors in `params`
// and hands them, together with `result`, to BinaryFunctionExecutor::executeListExtract.
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void BinaryExecListExtractFunction(
    const std::vector<std::shared_ptr<common::ValueVector>>& params, common::ValueVector& result) {
    assert(params.size() == 2);
    // Name the two operands up front so the forwarding call reads left/right explicitly.
    common::ValueVector& leftOperand = *params[0];
    common::ValueVector& rightOperand = *params[1];
    BinaryFunctionExecutor::executeListExtract<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC>(
        leftOperand, rightOperand, result);
}

std::unique_ptr<FunctionBindData> ListExtractFunction::bindFunc(
const binder::expression_vector& arguments, Function* function) {
auto resultType = VarListType::getChildType(&arguments[0]->dataType);
auto scalarFunction = reinterpret_cast<ScalarFunction*>(function);
switch (resultType->getPhysicalType()) {
case PhysicalTypeID::BOOL: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, uint8_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, uint8_t, ListExtract>;
} break;
case PhysicalTypeID::INT64: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, int64_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, int64_t, ListExtract>;
} break;
case PhysicalTypeID::INT32: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, int32_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, int32_t, ListExtract>;
} break;
case PhysicalTypeID::INT16: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, int16_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, int16_t, ListExtract>;
} break;
case PhysicalTypeID::INT8: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, int8_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, int8_t, ListExtract>;
} break;
case PhysicalTypeID::UINT64: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, uint64_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, uint64_t, ListExtract>;
} break;
case PhysicalTypeID::UINT32: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, uint32_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, uint32_t, ListExtract>;
} break;
case PhysicalTypeID::UINT16: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, uint16_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, uint16_t, ListExtract>;
} break;
case PhysicalTypeID::UINT8: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, uint8_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, uint8_t, ListExtract>;
} break;
case PhysicalTypeID::INT128: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, int128_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, int128_t, ListExtract>;
} break;
case PhysicalTypeID::DOUBLE: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, double_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, double_t, ListExtract>;
} break;
case PhysicalTypeID::FLOAT: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, float_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, float_t, ListExtract>;
} break;
case PhysicalTypeID::INTERVAL: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, interval_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, interval_t, ListExtract>;
} break;
case PhysicalTypeID::STRING: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, ku_string_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, ku_string_t, ListExtract>;
} break;
case PhysicalTypeID::VAR_LIST: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, list_entry_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, list_entry_t, ListExtract>;
} break;
case PhysicalTypeID::STRUCT: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, struct_entry_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, struct_entry_t, ListExtract>;
} break;
case PhysicalTypeID::INTERNAL_ID: {
scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t,
int64_t, internalID_t, ListExtract>;
scalarFunction->execFunc =
BinaryExecListExtractFunction<list_entry_t, int64_t, internalID_t, ListExtract>;
} break;
default: {
throw NotImplementedException("ListExtractFunction::bindFunc");
Expand Down
30 changes: 23 additions & 7 deletions src/include/function/binary_function_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ struct BinaryFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/,
common::ValueVector* /*resultValueVector*/, void* /*dataPtr*/) {

common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* /*dataPtr*/) {
OP::operation(left, right, result);
}
};
Expand All @@ -28,16 +27,26 @@ struct BinaryListStructFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector, void* /*dataPtr*/) {
common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* /*dataPtr*/) {
OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector);
}
};

struct BinaryListExtractFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* resultValueVector, uint64_t resultPos, void* /*dataPtr*/) {
OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector,
resultPos);
}
};

struct BinaryStringFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/,
common::ValueVector* resultValueVector, void* /*dataPtr*/) {
common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* /*dataPtr*/) {
OP::operation(left, right, result, *resultValueVector);
}
};
Expand All @@ -46,7 +55,7 @@ struct BinaryComparisonFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* leftValueVector, common::ValueVector* rightValueVector,
common::ValueVector* /*resultValueVector*/, void* /*dataPtr*/) {
common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* /*dataPtr*/) {
OP::operation(left, right, result, leftValueVector, rightValueVector);
}
};
Expand All @@ -55,7 +64,7 @@ struct BinaryUDFFunctionWrapper {
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename OP>
static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result,
common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/,
common::ValueVector* /*resultValueVector*/, void* dataPtr) {
common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* dataPtr) {
OP::operation(left, right, result, dataPtr);
}
};
Expand All @@ -69,7 +78,7 @@ struct BinaryFunctionExecutor {
OP_WRAPPER::template operation<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC>(
((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos],
((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector,
dataPtr);
resPos, dataPtr);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC,
Expand Down Expand Up @@ -253,6 +262,13 @@ struct BinaryFunctionExecutor {
left, right, result, nullptr /* dataPtr */);
}

// Runs FUNC over `left` and `right` into `result` through executeSwitch, using
// BinaryListExtractFunctionWrapper as the operation wrapper and no extra data pointer.
template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeListExtract(
    common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
    using ExtractWrapper = BinaryListExtractFunctionWrapper;
    executeSwitch<LEFT_TYPE, RIGHT_TYPE, RESULT_TYPE, FUNC, ExtractWrapper>(
        left, right, result, nullptr /* dataPtr */);
}

template<typename LEFT_TYPE, typename RIGHT_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeComparison(
common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) {
Expand Down
24 changes: 8 additions & 16 deletions src/include/function/list/functions/list_extract_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,12 @@ namespace function {

struct ListExtract {
public:
template<typename T>
static inline void setValue(T& src, T& dest, common::ValueVector& /*resultValueVector*/) {
dest = src;
}

// Note: this function takes in a 1-based position (The index of the first value in the list
// is 1).
template<typename T>
static inline void operation(common::list_entry_t& listEntry, int64_t pos, T& result,
common::ValueVector& listVector, common::ValueVector& /*posVector*/,
common::ValueVector& resultVector) {
common::ValueVector& resultVector, uint64_t resPos) {
if (pos == 0) {
throw common::RuntimeException("List extract takes 1-based position.");
}
Expand All @@ -35,10 +30,13 @@ struct ListExtract {
return; // TODO(Xiyang/Ziyi): we should fix when extracting last element of list.
}
auto listDataVector = common::ListVector::getDataVector(&listVector);
auto listValues =
common::ListVector::getListValuesWithOffset(&listVector, listEntry, upos - 1);
resultVector.copyFromVectorData(
reinterpret_cast<uint8_t*>(&result), listDataVector, listValues);
resultVector.setNull(resPos, listDataVector->isNull(listEntry.offset + upos - 1));
if (!resultVector.isNull(resPos)) {
auto listValues =
common::ListVector::getListValuesWithOffset(&listVector, listEntry, upos - 1);
resultVector.copyFromVectorData(
reinterpret_cast<uint8_t*>(&result), listDataVector, listValues);
}
}

static inline void operation(
Expand All @@ -51,11 +49,5 @@ struct ListExtract {
}
};

template<>
inline void ListExtract::setValue(
common::ku_string_t& src, common::ku_string_t& dest, common::ValueVector& resultValueVector) {
common::StringVector::addString(&resultValueVector, dest, src);
}

} // namespace function
} // namespace kuzu
4 changes: 2 additions & 2 deletions src/include/function/schema/label_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ namespace function {
struct Label {
static inline void operation(common::internalID_t& left, common::list_entry_t& right,
common::ku_string_t& result, common::ValueVector& leftVector,
common::ValueVector& rightVector, common::ValueVector& resultVector) {
common::ValueVector& rightVector, common::ValueVector& resultVector, uint64_t resPos) {
assert(left.tableID < right.size);
ListExtract::operation(right, left.tableID + 1 /* listExtract requires 1-based index */,
result, rightVector, leftVector, resultVector);
result, rightVector, leftVector, resultVector, resPos);
}
};

Expand Down
2 changes: 1 addition & 1 deletion src/include/function/schema/vector_label_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct LabelFunction {
static void execFunction(const std::vector<std::shared_ptr<common::ValueVector>>& params,
common::ValueVector& result) {
assert(params.size() == 2);
BinaryFunctionExecutor::executeListStruct<common::internalID_t, common::list_entry_t,
BinaryFunctionExecutor::executeListExtract<common::internalID_t, common::list_entry_t,
common::ku_string_t, Label>(*params[0], *params[1], result);
}
};
Expand Down
14 changes: 14 additions & 0 deletions test/test_files/tinysnb/function/list.test
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,20 @@ n
sdwe
ad

-LOG ListExtractWithNull
-STATEMENT RETURN list_extract([1,3,null,null,2],3)
---- 1


-LOG ExtractNullList
-STATEMENT RETURN list_extract(null,1)
---- 1


-LOG ListExtractNullPos
-STATEMENT RETURN LIST_EXTRACT([3,4,5],NULL)
---- 1


-LOG SliceUTF8String
-STATEMENT Return '这是一个中文句子'[2:5]
Expand Down

0 comments on commit 2021979

Please sign in to comment.