Skip to content

Commit

Permalink
List ValueVector refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed May 2, 2023
1 parent ecb5d46 commit 63c3cc0
Show file tree
Hide file tree
Showing 41 changed files with 937 additions and 481 deletions.
11 changes: 10 additions & 1 deletion src/common/null_mask.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include "common/null_mask.h"

namespace kuzu {
#include <cstring>

namespace kuzu {
namespace common {

void NullMask::setNull(uint32_t pos, bool isNull) {
Expand Down Expand Up @@ -78,5 +79,13 @@ bool NullMask::copyNullMask(const uint64_t* srcNullEntries, uint64_t srcOffset,
return hasNullInSrcNullMask;
}

// Replaces the null-entry buffer with a new one holding `capacity` 64-bit entries and carries
// over the entries that were already stored. `data` and `numNullEntries` are updated to match.
void NullMask::resize(uint64_t capacity) {
    // make_unique<T[]> value-initializes the array, so entries past the copied prefix are zero.
    auto resizedBuffer = std::make_unique<uint64_t[]>(capacity);
    // Never copy more entries than the new buffer can hold (covers a shrinking resize).
    auto numEntriesToCopy = capacity < numNullEntries ? capacity : numNullEntries;
    // memcpy takes a byte count, so the entry count must be scaled by the entry width.
    memcpy(resizedBuffer.get(), buffer.get(), numEntriesToCopy * sizeof(uint64_t));
    buffer = std::move(resizedBuffer);
    data = buffer.get();
    numNullEntries = capacity;
}

} // namespace common
} // namespace kuzu
123 changes: 109 additions & 14 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "common/exception.h"
#include "common/string_utils.h"
#include "common/vector/value_vector.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -30,25 +31,102 @@ bool TypeUtils::convertToBoolean(const char* data) {
". Input is not equal to True or False (in a case-insensitive manner)");
}

// Compares two list entries element by element.
// `left` and `right` are type-erased pointers that are cast to ValueVector; each entry is
// resolved against its own vector. Returns false as soon as the vector data types differ, the
// entry sizes differ, or one element pair compares unequal. Throws RuntimeException when a
// non-empty list has a child type that is not handled below.
bool TypeUtils::isListValueEqual(
    common::list_entry_t& leftEntry, void* left, common::list_entry_t& rightEntry, void* right) {
    auto lhsVector = static_cast<ValueVector*>(left);
    auto rhsVector = static_cast<ValueVector*>(right);
    if (leftEntry.size != rightEntry.size || lhsVector->dataType != rhsVector->dataType) {
        return false;
    }
    auto lhsElements = lhsVector->getListElements(leftEntry);
    auto rhsElements = rhsVector->getListElements(rightEntry);
    // Typed comparison of one element pair. The child types are looked up inside the call so that
    // nothing is dereferenced when the lists are empty.
    const auto elementsMatch = [&](auto* lhsTyped, auto* rhsTyped, uint32_t idx) {
        return isValueEqual(lhsTyped[idx], rhsTyped[idx], *lhsVector->dataType.getChildType(),
            *rhsVector->dataType.getChildType());
    };
    for (uint32_t idx = 0; idx < leftEntry.size; ++idx) {
        bool isEqual;
        switch (lhsVector->dataType.getChildType()->typeID) {
        case BOOL:
            isEqual = elementsMatch(reinterpret_cast<uint8_t*>(lhsElements),
                reinterpret_cast<uint8_t*>(rhsElements), idx);
            break;
        case INT64:
            isEqual = elementsMatch(reinterpret_cast<int64_t*>(lhsElements),
                reinterpret_cast<int64_t*>(rhsElements), idx);
            break;
        case DOUBLE:
            isEqual = elementsMatch(reinterpret_cast<double_t*>(lhsElements),
                reinterpret_cast<double_t*>(rhsElements), idx);
            break;
        case STRING:
            isEqual = elementsMatch(reinterpret_cast<ku_string_t*>(lhsElements),
                reinterpret_cast<ku_string_t*>(rhsElements), idx);
            break;
        case DATE:
            isEqual = elementsMatch(reinterpret_cast<date_t*>(lhsElements),
                reinterpret_cast<date_t*>(rhsElements), idx);
            break;
        case TIMESTAMP:
            isEqual = elementsMatch(reinterpret_cast<timestamp_t*>(lhsElements),
                reinterpret_cast<timestamp_t*>(rhsElements), idx);
            break;
        case INTERVAL:
            isEqual = elementsMatch(reinterpret_cast<interval_t*>(lhsElements),
                reinterpret_cast<interval_t*>(rhsElements), idx);
            break;
        case VAR_LIST:
            // Nested lists: the elements are themselves list entries of the child data vectors.
            isEqual = isListValueEqual(reinterpret_cast<list_entry_t*>(lhsElements)[idx],
                lhsVector->getDataVector(), reinterpret_cast<list_entry_t*>(rhsElements)[idx],
                rhsVector->getDataVector());
            break;
        default:
            throw RuntimeException("Unsupported data type " +
                                   Types::dataTypeToString(lhsVector->dataType) +
                                   " for TypeUtils::isValueEqual.");
        }
        if (!isEqual) {
            return false;
        }
    }
    return true;
}

std::string TypeUtils::elementToString(
const DataType& dataType, uint8_t* overflowPtr, uint64_t pos) {
const DataType& dataType, uint8_t* listElements, uint64_t pos) {
switch (dataType.typeID) {
case BOOL:
return TypeUtils::toString(((bool*)overflowPtr)[pos]);
return TypeUtils::toString(((bool*)listElements)[pos]);
case INT64:
return TypeUtils::toString(((int64_t*)overflowPtr)[pos]);
return TypeUtils::toString(((int64_t*)listElements)[pos]);
case DOUBLE:
return TypeUtils::toString(((double_t*)overflowPtr)[pos]);
return TypeUtils::toString(((double_t*)listElements)[pos]);
case DATE:
return TypeUtils::toString(((date_t*)overflowPtr)[pos]);
return TypeUtils::toString(((date_t*)listElements)[pos]);
case TIMESTAMP:
return TypeUtils::toString(((timestamp_t*)overflowPtr)[pos]);
return TypeUtils::toString(((timestamp_t*)listElements)[pos]);
case INTERVAL:
return TypeUtils::toString(((interval_t*)overflowPtr)[pos]);
return TypeUtils::toString(((interval_t*)listElements)[pos]);
case STRING:
return TypeUtils::toString(((ku_string_t*)overflowPtr)[pos]);
return TypeUtils::toString(((ku_string_t*)listElements)[pos]);
case VAR_LIST:
return TypeUtils::toString(((ku_list_t*)overflowPtr)[pos], dataType);
return TypeUtils::toString(((ku_list_t*)listElements)[pos], dataType);
default:
throw RuntimeException("Invalid data type " + Types::dataTypeToString(dataType) +
" for TypeUtils::elementToString.");
Expand All @@ -57,14 +135,31 @@ std::string TypeUtils::elementToString(

// Renders a ku_list_t as "[e0,e1,...]". Elements are read from the memory that
// `val.overflowPtr` refers to and formatted with the child type of `dataType`.
// An empty list renders as "[]".
std::string TypeUtils::toString(const ku_list_t& val, const DataType& dataType) {
    std::string result = "[";
    for (auto i = 0u; i < val.size; ++i) {
        // The separator goes before every element except the first, so the closing bracket can be
        // appended unconditionally and empty lists are still well-formed.
        if (i > 0) {
            result += ",";
        }
        result += elementToString(
            *dataType.getChildType(), reinterpret_cast<uint8_t*>(val.overflowPtr), i);
    }
    result += "]";
    return result;
}

// Renders a list entry as "[e0,e1,...]".
// `valVector` is a type-erased pointer that is cast to the ValueVector the entry belongs to.
// When the child type is VAR_LIST, each element is itself a list entry and is rendered
// recursively against the vector's data vector; otherwise elementToString formats it.
// An empty entry renders as "[]".
std::string TypeUtils::toString(const list_entry_t& val, void* valVector) {
    auto listVector = (common::ValueVector*)valVector;
    std::string result = "[";
    auto elements = listVector->getListElements(val);
    // Iterate over [0, size) directly: computing `size - 1` on an unsigned size would wrap around
    // for an empty list.
    for (auto i = 0u; i < val.size; ++i) {
        if (i > 0) {
            result += ",";
        }
        result += listVector->dataType.getChildType()->typeID == VAR_LIST ?
                      toString(reinterpret_cast<common::list_entry_t*>(elements)[i],
                          listVector->getDataVector()) :
                      elementToString(*listVector->dataType.getChildType(), elements, i);
    }
    result += "]";
    return result;
}

Expand Down
37 changes: 34 additions & 3 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,40 @@ namespace common {

// Builds a vector with room for DEFAULT_VECTOR_CAPACITY values of `dataType`.
// VAR_LIST vectors store one list_entry_t per value and own a child vector (`dataVector`) of the
// list's child type. A memory manager is required whenever needOverflowBuffer() is true.
ValueVector::ValueVector(DataType dataType, storage::MemoryManager* memoryManager)
    : dataType{std::move(dataType)} {
    // The per-value width is decided once, before allocating: list vectors hold list_entry_t
    // slots rather than values of Types::getDataTypeSize() bytes. It must not be reassigned
    // afterwards, or it would disagree with the buffer allocated below.
    numBytesPerValue = this->dataType.typeID == VAR_LIST ? sizeof(common::list_entry_t) :
                                                           Types::getDataTypeSize(this->dataType);
    valueBuffer = std::make_unique<uint8_t[]>(numBytesPerValue * DEFAULT_VECTOR_CAPACITY);
    if (needOverflowBuffer()) {
        assert(memoryManager != nullptr);
        inMemOverflowBuffer = std::make_unique<InMemOverflowBuffer>(memoryManager);
    }
    nullMask = std::make_unique<NullMask>();
    if (this->dataType.typeID == VAR_LIST) {
        dataVector = std::make_unique<ValueVector>(*this->dataType.getChildType(), memoryManager);
    }
}

// Replaces the value buffer with one sized for `newSize` values, preserving the first `oldSize`
// values, and resizes the null mask with the same `newSize`.
void ValueVector::resizeValueBuffer(uint64_t oldSize, uint64_t newSize) {
    auto buffer = std::make_unique<uint8_t[]>(newSize * numBytesPerValue);
    // Destination first, source second: the existing values are copied INTO the new buffer.
    memcpy(buffer.get(), valueBuffer.get(), oldSize * numBytesPerValue);
    valueBuffer = std::move(buffer);
    nullMask->resize(newSize);
}

// Reserves `listSize` consecutive element slots for a new list and returns the entry describing
// them (offset = current `size`, size = `listSize`). When the reservation does not fit,
// `capacity` is doubled until it does and the data vector's value buffer is resized once.
list_entry_t ValueVector::createNewListEntry(uint64_t listSize) {
    assert(dataType.typeID == VAR_LIST);
    list_entry_t newEntry;
    newEntry.offset = size;
    newEntry.size = listSize;
    if (size + listSize > capacity) {
        while (size + listSize > capacity) {
            capacity *= 2;
        }
        // `size` still holds the pre-reservation count here, i.e. the number of values to keep.
        dataVector->resizeValueBuffer(size, capacity);
    }
    size += listSize;
    return newEntry;
}

void ValueVector::addString(uint32_t pos, char* value, uint64_t len) const {
Expand Down Expand Up @@ -51,6 +77,11 @@ void ValueVector::setValue(uint32_t pos, T val) {
((T*)valueBuffer.get())[pos] = val;
}

// Specialization for list entries: stores `val` in slot `pos` of the value buffer, viewed as an
// array of list_entry_t.
template<>
void ValueVector::setValue(uint32_t pos, list_entry_t val) {
    auto entries = reinterpret_cast<list_entry_t*>(valueBuffer.get());
    entries[pos] = val;
}

template<>
void ValueVector::setValue(uint32_t pos, std::string val) {
addString(pos, val.data(), val.length());
Expand Down
127 changes: 124 additions & 3 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
copyNonNullDataWithSameTypeIntoPos(*childVector, pos, srcData);
srcData += childVector->getNumBytesPerValue();
}
} else if (resultVector.dataType.typeID == VAR_LIST) {
copyKuListToVector(resultVector, pos, *reinterpret_cast<const ku_list_t*>(srcData));
} else {
copyNonNullDataWithSameType(resultVector.dataType, srcData,
resultVector.getData() + pos * resultVector.getNumBytesPerValue(),
Expand All @@ -26,23 +28,142 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
copyNonNullDataWithSameTypeOutFromPos(*childVector, pos, dstData, dstOverflowBuffer);
dstData += childVector->getNumBytesPerValue();
}
} else if (srcVector.dataType.typeID == VAR_LIST) {
auto kuList = ValueVectorUtils::convertListEntryToKuList(srcVector, pos, dstOverflowBuffer);
memcpy(dstData, &kuList, sizeof(kuList));
} else {
copyNonNullDataWithSameType(srcVector.dataType,
srcVector.getData() + pos * srcVector.getNumBytesPerValue(), dstData,
dstOverflowBuffer);
}
}

// Copies every element of `srcList` (an entry of `srcListVector`) into `dstList` (an entry of
// `dstListVector`), starting at element offset `dstListOffsetToCopy` of the destination list.
// Each element is copied through copyListElement; both cursors advance by the source data
// vector's bytes-per-value.
void ValueVectorUtils::copyListRecursively(common::list_entry_t& dstList,
    common::ValueVector& dstListVector, common::offset_t dstListOffsetToCopy,
    const common::list_entry_t& srcList, const common::ValueVector& srcListVector) {
    auto srcCursor = srcListVector.getListElements(srcList);
    auto dstCursor = dstListVector.getListElementsWithOffset(dstList, dstListOffsetToCopy);
    const auto stride = srcListVector.getDataVector()->getNumBytesPerValue();
    for (auto i = 0u; i < srcList.size; ++i, srcCursor += stride, dstCursor += stride) {
        copyListElement(dstCursor, dstListVector, srcCursor, srcListVector);
    }
}

// Writes one element (`element`, owned by `elementVector`) into `listEntry` of `listVector` at
// position `elementOffsetInList`. The element type must equal the list's data-vector type.
// A nested list gets a fresh entry in the list's data vector and is copied recursively; any
// other type goes through copyNonNullDataWithSameType.
void ValueVectorUtils::copyElementToListWithOffset(common::list_entry_t& listEntry,
    ValueVector& listVector, offset_t elementOffsetInList, uint8_t* element,
    ValueVector& elementVector) {
    auto listDataVector = listVector.getDataVector();
    auto dstElement = listVector.getListElementsWithOffset(listEntry, elementOffsetInList);
    assert(listDataVector->dataType == elementVector.dataType);
    if (elementVector.dataType.typeID != VAR_LIST) {
        copyNonNullDataWithSameType(
            listDataVector->dataType, element, dstElement, listDataVector->getOverflowBuffer());
        return;
    }
    auto& nestedSrc = *reinterpret_cast<list_entry_t*>(element);
    auto& nestedDst = *reinterpret_cast<list_entry_t*>(dstElement);
    nestedDst = listDataVector->createNewListEntry(nestedSrc.size);
    copyListRecursively(nestedDst, *listDataVector, nestedSrc, elementVector);
}

// Copies a single list element from `srcElement` to `dstElement`, dispatching on the type of the
// source list's data vector: nested lists are re-created in the destination data vector and
// copied recursively, strings are copied via InMemOverflowBufferUtils::copyString using the
// destination data vector's overflow buffer, and everything else is a raw byte copy.
void ValueVectorUtils::copyListElement(uint8_t* dstElement, common::ValueVector& dstListVector,
    const uint8_t* srcElement, const common::ValueVector& srcListVector) {
    auto srcDataVector = srcListVector.getDataVector();
    auto dstDataVector = dstListVector.getDataVector();
    switch (srcDataVector->dataType.typeID) {
    case common::VAR_LIST: {
        auto srcList = reinterpret_cast<const common::list_entry_t*>(srcElement);
        auto dstList = reinterpret_cast<common::list_entry_t*>(dstElement);
        *dstList = dstDataVector->createNewListEntry(srcList->size);
        copyListRecursively(*dstList, *dstDataVector, *srcList, *srcDataVector);
    } break;
    case common::STRING: {
        common::InMemOverflowBufferUtils::copyString(*(common::ku_string_t*)srcElement,
            *(common::ku_string_t*)dstElement, dstDataVector->getOverflowBuffer());
    } break;
    default: {
        memcpy(dstElement, srcElement, srcDataVector->getNumBytesPerValue());
    }
    }
}

void ValueVectorUtils::copyElementOutFromListWithOffset(const list_entry_t& listEntry,
const ValueVector& listVector, offset_t elementOffsetInList, uint8_t* element,
ValueVector& elementVector) {
auto listDataVector = listVector.getDataVector();
auto srcElement = listVector.getListElementsWithOffset(listEntry, elementOffsetInList);
if (elementVector.dataType.typeID == VAR_LIST) {
auto srcList = reinterpret_cast<list_entry_t*>(srcElement);
auto dstList = reinterpret_cast<list_entry_t*>(element);
*dstList = listDataVector->createNewListEntry(srcList->size);
copyListRecursively(*dstList, elementVector, *srcList, *listVector.getDataVector());
} else {
copyNonNullDataWithSameType(
listDataVector->dataType, srcElement, element, elementVector.getOverflowBuffer());
}
}

// Copies one non-null value of `dataType` from `srcData` to `dstData`. STRUCT is not accepted.
// Strings and ku_list_t lists are copied through InMemOverflowBufferUtils with
// `inMemOverflowBuffer`; every other type is copied as Types::getDataTypeSize(dataType) raw bytes.
void ValueVectorUtils::copyNonNullDataWithSameType(const DataType& dataType, const uint8_t* srcData,
    uint8_t* dstData, InMemOverflowBuffer& inMemOverflowBuffer) {
    assert(dataType.typeID != STRUCT);
    switch (dataType.typeID) {
    case STRING: {
        InMemOverflowBufferUtils::copyString(
            *(ku_string_t*)srcData, *(ku_string_t*)dstData, inMemOverflowBuffer);
    } break;
    case VAR_LIST: {
        InMemOverflowBufferUtils::copyListRecursiveIfNested(
            *(ku_list_t*)srcData, *(ku_list_t*)dstData, dataType, inMemOverflowBuffer);
    } break;
    default: {
        memcpy(dstData, srcData, Types::getDataTypeSize(dataType));
    }
    }
}

// Converts the list stored at `pos` of `srcVector` into a ku_list_t whose element storage is
// allocated from `dstOverflowBuffer`. Nested lists are converted recursively, one child entry at
// a time; other element types are copied as raw bytes, and string elements are additionally
// re-copied with InMemOverflowBufferUtils::copyString into `dstOverflowBuffer`.
ku_list_t ValueVectorUtils::convertListEntryToKuList(
    const ValueVector& srcVector, uint64_t pos, InMemOverflowBuffer& dstOverflowBuffer) {
    auto srcEntry = srcVector.getValue<list_entry_t>(pos);
    auto srcElements = srcVector.getListElements(srcEntry);
    ku_list_t result;
    result.size = srcEntry.size;
    InMemOverflowBufferUtils::allocateSpaceForList(result,
        Types::getDataTypeSize(*srcVector.dataType.getChildType()) * result.size,
        dstOverflowBuffer);
    auto childVector = srcVector.getDataVector();
    if (childVector->dataType.typeID == VAR_LIST) {
        auto dstLists = reinterpret_cast<ku_list_t*>(result.overflowPtr);
        for (auto i = 0u; i < result.size; i++) {
            dstLists[i] =
                convertListEntryToKuList(*childVector, srcEntry.offset + i, dstOverflowBuffer);
        }
    } else {
        memcpy(reinterpret_cast<uint8_t*>(result.overflowPtr), srcElements,
            childVector->getNumBytesPerValue() * srcEntry.size);
        if (childVector->dataType.typeID == STRING) {
            // The byte copy above duplicated the ku_string_t structs; each string is then copied
            // again through copyString with the destination overflow buffer.
            auto srcStrings = reinterpret_cast<ku_string_t*>(srcElements);
            auto dstStrings = reinterpret_cast<ku_string_t*>(result.overflowPtr);
            for (auto i = 0u; i < result.size; i++) {
                InMemOverflowBufferUtils::copyString(
                    srcStrings[i], dstStrings[i], dstOverflowBuffer);
            }
        }
    }
    return result;
}

// Stores `srcList` at position `pos` of the list vector `dstVector`: a new list entry is created
// and written at `pos`, then the elements are copied into it. Nested lists recurse into the data
// vector at consecutive positions starting at the new entry's offset; other types are copied as
// raw bytes, with string elements re-copied through the data vector's overflow buffer.
void ValueVectorUtils::copyKuListToVector(
    ValueVector& dstVector, uint64_t pos, const ku_list_t& srcList) {
    auto srcElements = reinterpret_cast<uint8_t*>(srcList.overflowPtr);
    auto newEntry = dstVector.createNewListEntry(srcList.size);
    dstVector.setValue<list_entry_t>(pos, newEntry);
    if (dstVector.dataType.getChildType()->typeID == VAR_LIST) {
        auto nestedLists = reinterpret_cast<ku_list_t*>(srcList.overflowPtr);
        for (auto i = 0u; i < srcList.size; i++) {
            ValueVectorUtils::copyKuListToVector(
                *dstVector.getDataVector(), newEntry.offset + i, nestedLists[i]);
        }
        return;
    }
    auto childVector = dstVector.getDataVector();
    auto dstElements = dstVector.getListElements(newEntry);
    memcpy(dstElements, srcElements, srcList.size * childVector->getNumBytesPerValue());
    if (childVector->dataType.getTypeID() == STRING) {
        auto srcStrings = reinterpret_cast<ku_string_t*>(srcElements);
        auto dstStrings = reinterpret_cast<ku_string_t*>(dstElements);
        for (auto i = 0u; i < srcList.size; i++) {
            InMemOverflowBufferUtils::copyString(
                srcStrings[i], dstStrings[i], childVector->getOverflowBuffer());
        }
    }
}
Loading

0 comments on commit 63c3cc0

Please sign in to comment.