Skip to content

Commit

Permalink
finish cast between fixed list and refactor code;
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Nov 20, 2023
1 parent b7874af commit 0ecc9e3
Show file tree
Hide file tree
Showing 7 changed files with 517 additions and 316 deletions.
3 changes: 2 additions & 1 deletion src/function/cast/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_function_cast
OBJECT
cast_rdf_variant.cpp)
cast_rdf_variant.cpp
cast_fixed_list.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function_cast>
Expand Down
389 changes: 389 additions & 0 deletions src/function/cast/cast_fixed_list.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,389 @@
#include "function/cast/functions/cast_fixed_list.h"

#include "common/exception/conversion.h"
#include "common/type_utils.h"
#include "function/cast/functions/cast_from_string_functions.h"
#include "function/cast/functions/cast_functions.h"

namespace kuzu {
namespace function {

bool CastFixedListHelper::containsListToFixedList(
const LogicalType* srcType, const LogicalType* dstType) {
if (srcType->getLogicalTypeID() == LogicalTypeID::VAR_LIST &&
dstType->getLogicalTypeID() == LogicalTypeID::FIXED_LIST) {
return true;
}

while (srcType->getLogicalTypeID() == dstType->getLogicalTypeID()) {
switch (srcType->getPhysicalType()) {
case PhysicalTypeID::VAR_LIST: {
return containsListToFixedList(
VarListType::getChildType(srcType), VarListType::getChildType(dstType));
}
case PhysicalTypeID::STRUCT: {
auto srcFieldTypes = StructType::getFieldTypes(srcType);
auto dstFieldTypes = StructType::getFieldTypes(dstType);
if (srcFieldTypes.size() != dstFieldTypes.size()) {
throw ConversionException{
stringFormat("Unsupported casting function from {} to {}.", srcType->toString(),
dstType->toString())};
}

auto result = false;
std::vector<struct_field_idx_t> fields;
for (auto i = 0u; i < srcFieldTypes.size(); i++) {
if (containsListToFixedList(srcFieldTypes[i], dstFieldTypes[i])) {
return true;
}
}
}
default:
return false;
}
}
return false;
}

void CastFixedListHelper::validateListEntry(
ValueVector* inputVector, LogicalType* resultType, uint64_t pos) {
if (inputVector->isNull(pos)) {
return;
}
auto inputTypeID = inputVector->dataType.getPhysicalType();

switch (resultType->getPhysicalType()) {
case PhysicalTypeID::FIXED_LIST: {
if (inputTypeID == PhysicalTypeID::VAR_LIST) {
auto listEntry = inputVector->getValue<list_entry_t>(pos);
if (listEntry.size != FixedListType::getNumValuesInList(resultType)) {
throw ConversionException{stringFormat(
"Unsupported casting VAR_LIST with incorrect list entry to FIXED_LIST. "
"Expected: {}, Actual: {}.",
FixedListType::getNumValuesInList(resultType),
inputVector->getValue<list_entry_t>(pos).size)};
}

auto inputChildVector = ListVector::getDataVector(inputVector);
for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) {
if (inputChildVector->isNull(i)) {
throw ConversionException("Cast failed. NULL is not allowed for FIXED_LIST.");
}
}
}
} break;
case PhysicalTypeID::VAR_LIST: {
if (inputTypeID == PhysicalTypeID::VAR_LIST) {
auto listEntry = inputVector->getValue<list_entry_t>(pos);
for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) {
validateListEntry(ListVector::getDataVector(inputVector),
VarListType::getChildType(resultType), i);
}
}
} break;
case PhysicalTypeID::STRUCT: {
if (inputTypeID == PhysicalTypeID::STRUCT) {
auto fieldVectors = StructVector::getFieldVectors(inputVector);
auto fieldTypes = StructType::getFieldTypes(resultType);

auto structEntry = inputVector->getValue<struct_entry_t>(pos);
for (auto i = 0u; i < fieldVectors.size(); i++) {
validateListEntry(fieldVectors[i].get(), fieldTypes[i], structEntry.pos);
}
}
} break;
default: {
return;
}
}
}

static void CastFixedListToString(
ValueVector& param, uint64_t pos, ValueVector& resultVector, uint64_t resultPos) {
resultVector.setNull(resultPos, param.isNull(pos));
if (param.isNull(pos)) {
return;
}
std::string result = "[";
auto numValuesPerList = FixedListType::getNumValuesInList(&param.dataType);
auto childType = FixedListType::getChildType(&param.dataType);
auto values = param.getData() + pos * param.getNumBytesPerValue();
for (auto i = 0u; i < numValuesPerList - 1; ++i) {
// Note: FixedList can only store numeric types and doesn't allow nulls.
result += TypeUtils::castValueToString(*childType, values, nullptr /* vector */);
result += ",";
values += PhysicalTypeUtils::getFixedTypeSize(childType->getPhysicalType());
}
result += TypeUtils::castValueToString(*childType, values, nullptr /* vector */);
result += "]";
resultVector.setValue(resultPos, result);
}

template<>
void CastFixedList::fixedListToStringCastExecFunction<UnaryFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result,
void* /*dataPtr*/) {
KU_ASSERT(params.size() == 1);
auto param = params[0];
if (param->state->isFlat()) {
CastFixedListToString(*param, param->state->selVector->selectedPositions[0], result,
result.state->selVector->selectedPositions[0]);
} else if (param->state->selVector->isUnfiltered()) {
for (auto i = 0u; i < param->state->selVector->selectedSize; i++) {
CastFixedListToString(*param, i, result, i);
}
} else {
for (auto i = 0u; i < param->state->selVector->selectedSize; i++) {
CastFixedListToString(*param, param->state->selVector->selectedPositions[i], result,
result.state->selVector->selectedPositions[i]);
}
}
}

template<>
void CastFixedList::fixedListToStringCastExecFunction<CastFixedListToListFunctionExecutor>(

Check warning on line 144 in src/function/cast/cast_fixed_list.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/cast/cast_fixed_list.cpp#L144

Added line #L144 was not covered by tests
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_UNREACHABLE;
}

template<>
void CastFixedList::fixedListToStringCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);

auto inputVector = params[0].get();
auto numOfEntries = reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries;
for (auto i = 0u; i < numOfEntries; i++) {
CastFixedListToString(*inputVector, i, result, i);
}
}

template<>
void CastFixedList::stringtoFixedListCastExecFunction<UnaryFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
auto param = params[0];
auto csvReaderConfig = &reinterpret_cast<CastFunctionBindData*>(dataPtr)->csvConfig;
if (param->state->isFlat()) {
auto inputPos = param->state->selVector->selectedPositions[0];
auto resultPos = result.state->selVector->selectedPositions[0];
result.setNull(resultPos, param->isNull(inputPos));
if (!result.isNull(inputPos)) {
CastString::castToFixedList(
param->getValue<ku_string_t>(inputPos), &result, resultPos, csvReaderConfig);
}
} else if (param->state->selVector->isUnfiltered()) {
for (auto i = 0u; i < param->state->selVector->selectedSize; i++) {
result.setNull(i, param->isNull(i));
if (!result.isNull(i)) {
CastString::castToFixedList(
param->getValue<ku_string_t>(i), &result, i, csvReaderConfig);
}
}
} else {
for (auto i = 0u; i < param->state->selVector->selectedSize; i++) {
auto pos = param->state->selVector->selectedPositions[i];
result.setNull(pos, param->isNull(pos));
if (!result.isNull(pos)) {
CastString::castToFixedList(
param->getValue<ku_string_t>(pos), &result, pos, csvReaderConfig);
}
}
}
}

template<>
void CastFixedList::stringtoFixedListCastExecFunction<CastFixedListToListFunctionExecutor>(

Check warning on line 196 in src/function/cast/cast_fixed_list.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/cast/cast_fixed_list.cpp#L196

Added line #L196 was not covered by tests
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_UNREACHABLE;
}

template<>
void CastFixedList::stringtoFixedListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
auto numOfEntries = reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries;
auto csvReaderConfig = &reinterpret_cast<CastFunctionBindData*>(dataPtr)->csvConfig;

auto inputVector = params[0].get();
for (auto i = 0u; i < numOfEntries; i++) {
result.setNull(i, inputVector->isNull(i));
if (!result.isNull(i)) {
CastString::castToFixedList(
inputVector->getValue<ku_string_t>(i), &result, i, csvReaderConfig);
}
}
}

template<>
void CastFixedList::listToFixedListCastExecFunction<UnaryFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
auto inputVector = params[0];

for (auto i = 0u; i < inputVector->state->selVector->selectedSize; i++) {
auto pos = inputVector->state->selVector->selectedPositions[i];
CastFixedListHelper::validateListEntry(inputVector.get(), &result.dataType, pos);
}

auto numOfEntries = inputVector->state->selVector
->selectedPositions[inputVector->state->selVector->selectedSize - 1] +
1;
reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries = numOfEntries;
listToFixedListCastExecFunction<CastChildFunctionExecutor>(params, result, dataPtr);
}

template<>
void CastFixedList::listToFixedListCastExecFunction<CastFixedListToListFunctionExecutor>(

Check warning on line 237 in src/function/cast/cast_fixed_list.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/cast/cast_fixed_list.cpp#L237

Added line #L237 was not covered by tests
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_UNREACHABLE;
}

using scalar_cast_func = std::function<void(void*, uint64_t, void*, uint64_t, void*)>;

template<typename DST_TYPE, typename OP>
static void getFixedListChildFuncHelper(scalar_cast_func& func, LogicalTypeID inputTypeID) {
switch (inputTypeID) {
case LogicalTypeID::STRING: {
func = UnaryCastStringFunctionWrapper::operation<ku_string_t, DST_TYPE, CastString>;
} break;
case LogicalTypeID::INT128: {
func = UnaryFunctionWrapper::operation<int128_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::INT64: {
func = UnaryFunctionWrapper::operation<int64_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::INT32: {
func = UnaryFunctionWrapper::operation<int32_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::INT16: {
func = UnaryFunctionWrapper::operation<int16_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::INT8: {
func = UnaryFunctionWrapper::operation<int8_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::UINT8: {
func = UnaryFunctionWrapper::operation<uint8_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::UINT16: {
func = UnaryFunctionWrapper::operation<uint16_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::UINT32: {
func = UnaryFunctionWrapper::operation<uint32_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::UINT64: {
func = UnaryFunctionWrapper::operation<uint64_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::FLOAT: {
func = UnaryFunctionWrapper::operation<float_t, DST_TYPE, OP>;
} break;
case LogicalTypeID::DOUBLE: {
func = UnaryFunctionWrapper::operation<double_t, DST_TYPE, OP>;
} break;
default: {
throw ConversionException{
stringFormat("Unsupported casting function from {} to numerical type.",
LogicalTypeUtils::toString(inputTypeID))};
}
}
}

static void getFixedListChildCastFunc(
scalar_cast_func& func, LogicalTypeID inputType, LogicalTypeID resultType) {
// only support limited Fixed List Types
switch (resultType) {
case LogicalTypeID::INT64: {
return getFixedListChildFuncHelper<int64_t, CastToInt64>(func, inputType);
}
case LogicalTypeID::INT32: {
return getFixedListChildFuncHelper<int32_t, CastToInt32>(func, inputType);
}
case LogicalTypeID::INT16: {
return getFixedListChildFuncHelper<int16_t, CastToInt16>(func, inputType);
}
case LogicalTypeID::DOUBLE: {
return getFixedListChildFuncHelper<double_t, CastToDouble>(func, inputType);
}
case LogicalTypeID::FLOAT: {
return getFixedListChildFuncHelper<float_t, CastToFloat>(func, inputType);
}
default: {
throw RuntimeException("Unsupported FIXED_LIST type: Function::getFixedListChildCastFunc");
}
}
}

template<>
void CastFixedList::listToFixedListCastExecFunction<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
auto inputVector = params[0];
auto numOfEntries = reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries;

auto inputChildId = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID();
auto outputChildId = FixedListType::getChildType(&result.dataType)->getLogicalTypeID();
auto numValuesPerList = FixedListType::getNumValuesInList(&result.dataType);
scalar_cast_func func;
getFixedListChildCastFunc(func, inputChildId, outputChildId);

result.setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfEntries);
auto inputChildVector = ListVector::getDataVector(inputVector.get());
for (auto i = 0u; i < numOfEntries; i++) {
if (!result.isNull(i)) {
auto listEntry = inputVector->getValue<list_entry_t>(i);
if (listEntry.size == numValuesPerList) {
for (auto j = 0u; j < listEntry.size; j++) {
func((void*)(inputChildVector), listEntry.offset + j, (void*)(&result),
i * numValuesPerList + j, nullptr);
}
}
}
}
}

template<>
void CastFixedList::castBetweenFixedListExecFunc<UnaryFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
auto inputVector = params[0];
auto numOfEntries = inputVector->state->selVector
->selectedPositions[inputVector->state->selVector->selectedSize - 1] +
1;
reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries = numOfEntries;
castBetweenFixedListExecFunc<CastChildFunctionExecutor>(params, result, dataPtr);
}

template<>
void CastFixedList::castBetweenFixedListExecFunc<CastFixedListToListFunctionExecutor>(

Check warning on line 355 in src/function/cast/cast_fixed_list.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/cast/cast_fixed_list.cpp#L355

Added line #L355 was not covered by tests
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_UNREACHABLE;
}

template<>
void CastFixedList::castBetweenFixedListExecFunc<CastChildFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
auto inputVector = params[0];
auto numOfEntries = reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries;

auto inputChildId = FixedListType::getChildType(&inputVector->dataType)->getLogicalTypeID();
auto outputChildId = FixedListType::getChildType(&result.dataType)->getLogicalTypeID();
auto numValuesPerList = FixedListType::getNumValuesInList(&result.dataType);
if (FixedListType::getNumValuesInList(&inputVector->dataType) != numValuesPerList) {
throw ConversionException(stringFormat("Unsupported casting function from {} to {}.",
inputVector->dataType.toString(), result.dataType.toString()));
}

scalar_cast_func func;
getFixedListChildCastFunc(func, inputChildId, outputChildId);

result.setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfEntries);
for (auto i = 0u; i < numOfEntries; i++) {
if (!result.isNull(i)) {
for (auto j = 0u; j < numValuesPerList; j++) {
func((void*)(inputVector.get()), i * numValuesPerList + j, (void*)(&result),
i * numValuesPerList + j, nullptr);
}
}
}
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit 0ecc9e3

Please sign in to comment.