Skip to content

Commit

Permalink
Merge pull request #3057 from kuzudb/fixed-list-rework
Browse files Browse the repository at this point in the history
Rework FIXED_LIST
  • Loading branch information
manh9203 authored Mar 18, 2024
2 parents bd963c1 + 35b9438 commit e7c6d73
Show file tree
Hide file tree
Showing 68 changed files with 576 additions and 1,273 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
24 changes: 5 additions & 19 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "common/string_utils.h"
#include "function/table_functions.h"
#include "main/client_context.h"
#include "storage/storage_utils.h"

using namespace kuzu::catalog;
using namespace kuzu::common;
Expand Down Expand Up @@ -113,27 +112,14 @@ std::shared_ptr<Expression> Binder::createVariable(

std::unique_ptr<LogicalType> Binder::bindDataType(const std::string& dataType) {
auto boundType = LogicalTypeUtils::dataTypeFromString(dataType);
if (boundType.getLogicalTypeID() == LogicalTypeID::FIXED_LIST) {
auto validNumericTypes = LogicalTypeUtils::getNumericalLogicalTypeIDs();
auto childType = FixedListType::getChildType(&boundType);
auto numElementsInList = FixedListType::getNumValuesInList(&boundType);
if (find(validNumericTypes.begin(), validNumericTypes.end(),
childType->getLogicalTypeID()) == validNumericTypes.end()) {
throw BinderException("The child type of a fixed list must be a numeric type. Given: " +
childType->toString() + ".");
}
if (numElementsInList == 0) {
if (boundType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
auto numElementsInArray = ArrayType::getNumElements(&boundType);
if (numElementsInArray == 0) {
// Note: the parser already guarantees that the number of elements is a non-negative
// number. However, we still need to check whether the number of elements is 0.
throw BinderException(
"The number of elements in a fixed list must be greater than 0. Given: " +
std::to_string(numElementsInList) + ".");
}
auto numElementsPerPage = storage::PageUtils::getNumElementsInAPage(
storage::StorageUtils::getDataTypeSize(boundType), true /* hasNull */);
if (numElementsPerPage == 0) {
throw BinderException(stringFormat("Cannot store a fixed list of size {} in a page.",
storage::StorageUtils::getDataTypeSize(boundType)));
"The number of elements in an array must be greater than 0. Given: " +
std::to_string(numElementsInArray) + ".");
}
}
return std::make_unique<LogicalType>(boundType);
Expand Down
16 changes: 8 additions & 8 deletions src/c_api/data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct CAPIHelper {
} // namespace kuzu::common

kuzu_logical_type* kuzu_data_type_create(
kuzu_data_type_id id, kuzu_logical_type* child_type, uint64_t fixed_num_elements_in_list) {
kuzu_data_type_id id, kuzu_logical_type* child_type, uint64_t num_elements_in_array) {
auto* c_data_type = (kuzu_logical_type*)malloc(sizeof(kuzu_logical_type));
uint8_t data_type_id_u8 = id;
LogicalType* data_type;
Expand All @@ -23,10 +23,10 @@ kuzu_logical_type* kuzu_data_type_create(
} else {
auto child_type_pty =
std::make_unique<LogicalType>(*static_cast<LogicalType*>(child_type->_data_type));
auto extraTypeInfo = fixed_num_elements_in_list > 0 ?
std::make_unique<FixedListTypeInfo>(
std::move(child_type_pty), fixed_num_elements_in_list) :
std::make_unique<VarListTypeInfo>(std::move(child_type_pty));
auto extraTypeInfo =
num_elements_in_array > 0 ?
std::make_unique<ArrayTypeInfo>(std::move(child_type_pty), num_elements_in_array) :
std::make_unique<VarListTypeInfo>(std::move(child_type_pty));
data_type = CAPIHelper::createLogicalType(logicalTypeID, std::move(extraTypeInfo));
}
c_data_type->_data_type = data_type;
Expand Down Expand Up @@ -60,10 +60,10 @@ kuzu_data_type_id kuzu_data_type_get_id(kuzu_logical_type* data_type) {
return static_cast<kuzu_data_type_id>(data_type_id_u8);
}

uint64_t kuzu_data_type_get_fixed_num_elements_in_list(kuzu_logical_type* data_type) {
uint64_t kuzu_data_type_get_num_elements_in_array(kuzu_logical_type* data_type) {
auto parent_type = static_cast<LogicalType*>(data_type->_data_type);
if (parent_type->getLogicalTypeID() != LogicalTypeID::FIXED_LIST) {
if (parent_type->getLogicalTypeID() != LogicalTypeID::ARRAY) {
return 0;
}
return FixedListType::getNumValuesInList(static_cast<LogicalType*>(data_type->_data_type));
return ArrayType::getNumElements(static_cast<LogicalType*>(data_type->_data_type));
}
6 changes: 3 additions & 3 deletions src/common/arrow/arrow_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ void ArrowConverter::setArrowFormat(
child.children[0]->name = "l";
setArrowFormat(rootHolder, **child.children, *typeInfo.childrenTypesInfo[0]);
} break;
case LogicalTypeID::FIXED_LIST: {
auto numValuesPerList = "+w:" + std::to_string(typeInfo.numValuesPerList);
child.format = copyName(rootHolder, numValuesPerList);
case LogicalTypeID::ARRAY: {
auto numValuesPerArray = "+w:" + std::to_string(typeInfo.fixedNumValues);
child.format = copyName(rootHolder, numValuesPerArray);
child.n_children = 1;
rootHolder.nestedChildren.emplace_back();
rootHolder.nestedChildren.back().resize(1);
Expand Down
51 changes: 33 additions & 18 deletions src/common/arrow/arrow_row_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void ArrowRowBatch::templateInitializeVector<LogicalTypeID::VAR_LIST>(
}

template<>
void ArrowRowBatch::templateInitializeVector<LogicalTypeID::FIXED_LIST>(
void ArrowRowBatch::templateInitializeVector<LogicalTypeID::ARRAY>(
ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) {
initializeNullBits(vector->validity, capacity);
KU_ASSERT(typeInfo.childrenTypesInfo.size() == 1);
Expand Down Expand Up @@ -171,8 +171,8 @@ std::unique_ptr<ArrowVector> ArrowRowBatch::createVector(
case LogicalTypeID::VAR_LIST: {
templateInitializeVector<LogicalTypeID::VAR_LIST>(result.get(), typeInfo, capacity);
} break;
case LogicalTypeID::FIXED_LIST: {
templateInitializeVector<LogicalTypeID::FIXED_LIST>(result.get(), typeInfo, capacity);
case LogicalTypeID::ARRAY: {
templateInitializeVector<LogicalTypeID::ARRAY>(result.get(), typeInfo, capacity);
} break;
case LogicalTypeID::STRUCT: {
templateInitializeVector<LogicalTypeID::STRUCT>(result.get(), typeInfo, capacity);
Expand Down Expand Up @@ -277,7 +277,12 @@ void ArrowRowBatch::templateCopyNonNullValue<LogicalTypeID::VAR_LIST>(
for (auto i = currentNumBytesForChildValidity; i < numBytesForChildValidity; i++) {
vector->childData[0]->validity.data()[i] = 0xFF; // Init each value to be valid (as 1).
}
if (typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::VAR_LIST) {
// If vector->childData[0] is a VAR_LIST, its data buffer will be resized when we add a new
// value into it
// If vector->childData[0] is an ARRAY, its data buffer is supposed to be empty,
// so we don't resize it here
if (typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::VAR_LIST &&
typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::ARRAY) {
vector->childData[0]->data.resize(
numChildElements * storage::StorageUtils::getDataTypeSize(
LogicalType{typeInfo.childrenTypesInfo[0]->typeID}));
Expand All @@ -289,21 +294,28 @@ void ArrowRowBatch::templateCopyNonNullValue<LogicalTypeID::VAR_LIST>(
}

template<>
void ArrowRowBatch::templateCopyNonNullValue<LogicalTypeID::FIXED_LIST>(
void ArrowRowBatch::templateCopyNonNullValue<LogicalTypeID::ARRAY>(
ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) {
auto numValuesPerList = value->childrenSize;
auto numValuesInChild = numValuesPerList * (pos + 1);
auto numElements = value->childrenSize;
auto numChildElements = numElements * (pos + 1);
auto currentNumBytesForChildValidity = vector->childData[0]->validity.size();
auto numBytesForChildValidity = getNumBytesForBits(numValuesInChild);
auto numBytesForChildValidity = getNumBytesForBits(numChildElements);
vector->childData[0]->validity.resize(numBytesForChildValidity);
// Initialize validity mask which is used to mark each value is valid (non-null) or not (null).
for (auto i = currentNumBytesForChildValidity; i < numBytesForChildValidity; i++) {
vector->childData[0]->validity.data()[i] = 0xFF; // Init each value to be valid (as 1).
}
vector->childData[0]->data.resize(
numValuesInChild *
storage::StorageUtils::getDataTypeSize(LogicalType{typeInfo.childrenTypesInfo[0]->typeID}));
for (auto i = 0u; i < numValuesPerList; i++) {
// If vector->childData[0] is a VAR_LIST, its data buffer will be resized when we add a new
// value into it
// If vector->childData[0] is an ARRAY, its data buffer is supposed to be empty,
// so we don't resize it here
if (typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::VAR_LIST &&
typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::ARRAY) {
vector->childData[0]->data.resize(
numChildElements * storage::StorageUtils::getDataTypeSize(
LogicalType{typeInfo.childrenTypesInfo[0]->typeID}));
}
for (auto i = 0u; i < numElements; i++) {
appendValue(
vector->childData[0].get(), *typeInfo.childrenTypesInfo[0], value->children[i].get());
}
Expand Down Expand Up @@ -433,8 +445,8 @@ void ArrowRowBatch::copyNonNullValue(
case LogicalTypeID::VAR_LIST: {
templateCopyNonNullValue<LogicalTypeID::VAR_LIST>(vector, typeInfo, value, pos);
} break;
case LogicalTypeID::FIXED_LIST: {
templateCopyNonNullValue<LogicalTypeID::FIXED_LIST>(vector, typeInfo, value, pos);
case LogicalTypeID::ARRAY: {
templateCopyNonNullValue<LogicalTypeID::ARRAY>(vector, typeInfo, value, pos);
} break;
case LogicalTypeID::STRUCT: {
templateCopyNonNullValue<LogicalTypeID::STRUCT>(vector, typeInfo, value, pos);
Expand Down Expand Up @@ -480,7 +492,7 @@ void ArrowRowBatch::templateCopyNullValue<LogicalTypeID::VAR_LIST>(
}

template<>
void ArrowRowBatch::templateCopyNullValue<LogicalTypeID::FIXED_LIST>(
void ArrowRowBatch::templateCopyNullValue<LogicalTypeID::ARRAY>(
ArrowVector* vector, std::int64_t pos) {
setBitToZero(vector->validity.data(), pos);
vector->numNulls++;
Expand Down Expand Up @@ -559,6 +571,9 @@ void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_
case LogicalTypeID::VAR_LIST: {
templateCopyNullValue<LogicalTypeID::VAR_LIST>(vector, pos);
} break;
case LogicalTypeID::ARRAY: {
templateCopyNullValue<LogicalTypeID::ARRAY>(vector, pos);
} break;
case LogicalTypeID::INTERNAL_ID: {
templateCopyNullValue<LogicalTypeID::INTERNAL_ID>(vector, pos);
} break;
Expand Down Expand Up @@ -637,7 +652,7 @@ ArrowArray* ArrowRowBatch::templateCreateArray<LogicalTypeID::VAR_LIST>(
}

template<>
ArrowArray* ArrowRowBatch::templateCreateArray<LogicalTypeID::FIXED_LIST>(
ArrowArray* ArrowRowBatch::templateCreateArray<LogicalTypeID::ARRAY>(
ArrowVector& vector, const main::DataTypeInfo& typeInfo) {
auto result = createArrayFromVector(vector);
vector.childPointers.resize(1);
Expand Down Expand Up @@ -756,8 +771,8 @@ ArrowArray* ArrowRowBatch::convertVectorToArray(
case LogicalTypeID::VAR_LIST: {
return templateCreateArray<LogicalTypeID::VAR_LIST>(vector, typeInfo);
}
case LogicalTypeID::FIXED_LIST: {
return templateCreateArray<LogicalTypeID::FIXED_LIST>(vector, typeInfo);
case LogicalTypeID::ARRAY: {
return templateCreateArray<LogicalTypeID::ARRAY>(vector, typeInfo);
}
case LogicalTypeID::STRUCT: {
return templateCreateArray<LogicalTypeID::STRUCT>(vector, typeInfo);
Expand Down
19 changes: 1 addition & 18 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ static std::string entryToString(
return TypeUtils::toString(*reinterpret_cast<const ku_string_t*>(value));
case LogicalTypeID::INTERNAL_ID:
return TypeUtils::toString(*reinterpret_cast<const internalID_t*>(value));
case LogicalTypeID::FIXED_LIST:
return TypeUtils::fixedListToString(value, dataType, valueVector);
case LogicalTypeID::ARRAY:
case LogicalTypeID::VAR_LIST:
return TypeUtils::toString(*reinterpret_cast<const list_entry_t*>(value), valueVector);
case LogicalTypeID::MAP:
Expand Down Expand Up @@ -81,22 +80,6 @@ static std::string entryToString(sel_t pos, ValueVector* vector) {
vector->dataType, vector->getData() + vector->getNumBytesPerValue() * pos, vector);
}

// Renders a fixed-size list stored inline at `val` as a bracketed,
// comma-separated string such as "[a,b,c]".
//   val         - pointer to the first element; it is advanced element by element below.
//   type        - the list's logical type; the element count and child type are read from it.
//   dummyVector - forwarded unchanged to entryToString for every element.
// Returns the assembled string.
std::string TypeUtils::fixedListToString(
const uint8_t* val, const LogicalType& type, ValueVector* dummyVector) {
std::string result = "[";
auto numValuesPerList = FixedListType::getNumValuesInList(&type);
auto childType = FixedListType::getChildType(&type);
// Emit every element except the last, each followed by a comma.
// NOTE(review): the bound `numValuesPerList - 1` presumably relies on the list never being
// empty; if the count is an unsigned type, a count of 0 would wrap around - confirm that
// callers guarantee at least one element.
for (auto i = 0u; i < numValuesPerList - 1; ++i) {
// Note: FixedList can only store numeric types and doesn't allow nulls.
result += entryToString(*childType, val, dummyVector);
result += ",";
// Step to the next element using the fixed size of the child's physical type.
val += PhysicalTypeUtils::getFixedTypeSize(childType->getPhysicalType());
}
// The last element is appended without a trailing comma, then the list is closed.
result += entryToString(*childType, val, dummyVector);
result += "]";
return result;
}

template<>
std::string TypeUtils::toString(const int128_t& val, void* /*valueVector*/) {
return Int128_t::ToString(val);
Expand Down
Loading

0 comments on commit e7c6d73

Please sign in to comment.