Skip to content

Commit

Permalink
List ValueVector refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed May 4, 2023
1 parent ecb5d46 commit 48e4fd5
Show file tree
Hide file tree
Showing 60 changed files with 1,177 additions and 721 deletions.
7 changes: 0 additions & 7 deletions src/common/in_mem_overflow_buffer_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ void InMemOverflowBufferUtils::copyString(
dest.set(src);
}

/**
 * Shallow-copies a list's elements into `dst`.
 *
 * Space for `dst.size` elements of the list's child type is first reserved in
 * `inMemOverflowBuffer` via allocateSpaceForList, then `dst.set` is invoked with the
 * source values. Nested elements are not copied recursively here.
 *
 * @param srcValues           raw pointer to the source elements.
 * @param dst                 destination list; its `size` must already be set by the caller,
 *                            because it determines how many bytes are reserved.
 * @param dataType            type of the list (its child type gives the element width).
 * @param inMemOverflowBuffer buffer in which the destination storage is allocated.
 */
void InMemOverflowBufferUtils::copyListNonRecursive(const uint8_t* srcValues, ku_list_t& dst,
    const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer) {
    auto& elementType = *dataType.getChildType();
    const auto numBytesToReserve = dst.size * Types::getDataTypeSize(elementType);
    InMemOverflowBufferUtils::allocateSpaceForList(dst, numBytesToReserve, inMemOverflowBuffer);
    dst.set(srcValues, dataType);
}

void InMemOverflowBufferUtils::copyListRecursiveIfNested(const ku_list_t& src, ku_list_t& dst,
const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer, uint32_t srcStartIdx,
uint32_t srcEndIdx) {
Expand Down
11 changes: 10 additions & 1 deletion src/common/null_mask.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include "common/null_mask.h"

namespace kuzu {
#include <cstring>

namespace kuzu {
namespace common {

void NullMask::setNull(uint32_t pos, bool isNull) {
Expand Down Expand Up @@ -78,5 +79,13 @@ bool NullMask::copyNullMask(const uint64_t* srcNullEntries, uint64_t srcOffset,
return hasNullInSrcNullMask;
}

/**
 * Replaces the null-mask storage with a new buffer of `capacity` 64-bit entries,
 * preserving the entries that already exist.
 *
 * The new buffer comes from std::make_unique<uint64_t[]>, which value-initializes
 * (zeroes) every entry, so entries beyond the old ones start out cleared.
 *
 * @param capacity number of uint64_t entries the mask should hold afterwards.
 */
void NullMask::resize(uint64_t capacity) {
    auto resizedBuffer = std::make_unique<uint64_t[]>(capacity);
    // Never copy more entries than the new buffer can hold (covers a shrinking resize).
    auto numEntriesToCopy = numNullEntries < capacity ? numNullEntries : capacity;
    // memcpy counts BYTES: each entry is a uint64_t, so scale by its size. Passing the bare
    // entry count would silently drop 7/8 of the existing mask.
    memcpy(resizedBuffer.get(), buffer.get(), numEntriesToCopy * sizeof(uint64_t));
    buffer = std::move(resizedBuffer);
    data = buffer.get();
    numNullEntries = capacity;
}

} // namespace common
} // namespace kuzu
131 changes: 116 additions & 15 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "common/exception.h"
#include "common/string_utils.h"
#include "common/vector/value_vector.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -30,40 +31,140 @@ bool TypeUtils::convertToBoolean(const char* data) {
". Input is not equal to True or False (in a case-insensitive manner)");
}

std::string TypeUtils::elementToString(
const DataType& dataType, uint8_t* overflowPtr, uint64_t pos) {
/**
 * Compares two list entries element by element.
 *
 * `left` and `right` are type-erased pointers to the ValueVectors that own the entries.
 * The lists are unequal when the vectors' data types differ or the entries' sizes differ;
 * otherwise every pair of elements is compared according to the child type. Nested lists
 * (VAR_LIST children) recurse using each vector's data vector.
 *
 * @throws RuntimeException when the child type is not one of the handled types.
 */
bool TypeUtils::isListValueEqual(
    common::list_entry_t& leftEntry, void* left, common::list_entry_t& rightEntry, void* right) {
    auto lhsVector = static_cast<ValueVector*>(left);
    auto rhsVector = static_cast<ValueVector*>(right);
    if (lhsVector->dataType != rhsVector->dataType || leftEntry.size != rightEntry.size) {
        return false;
    }
    auto lhsValues = ListVector::getListValues(lhsVector, leftEntry);
    auto rhsValues = ListVector::getListValues(rhsVector, rightEntry);
    const auto numElements = leftEntry.size;

    // Shared loop for all fixed-width child types: the pointers carry the element type.
    auto allElementsEqual = [numElements](auto* lhs, auto* rhs) {
        for (auto i = 0u; i < numElements; i++) {
            if (!TypeUtils::isValueEqual(lhs[i], rhs[i])) {
                return false;
            }
        }
        return true;
    };

    switch (lhsVector->dataType.getChildType()->typeID) {
    case BOOL:
        return allElementsEqual(
            reinterpret_cast<uint8_t*>(lhsValues), reinterpret_cast<uint8_t*>(rhsValues));
    case INT64:
        return allElementsEqual(
            reinterpret_cast<int64_t*>(lhsValues), reinterpret_cast<int64_t*>(rhsValues));
    case DOUBLE:
        return allElementsEqual(
            reinterpret_cast<double_t*>(lhsValues), reinterpret_cast<double_t*>(rhsValues));
    case STRING:
        return allElementsEqual(
            reinterpret_cast<ku_string_t*>(lhsValues), reinterpret_cast<ku_string_t*>(rhsValues));
    case DATE:
        return allElementsEqual(
            reinterpret_cast<date_t*>(lhsValues), reinterpret_cast<date_t*>(rhsValues));
    case TIMESTAMP:
        return allElementsEqual(
            reinterpret_cast<timestamp_t*>(lhsValues), reinterpret_cast<timestamp_t*>(rhsValues));
    case INTERVAL:
        return allElementsEqual(
            reinterpret_cast<interval_t*>(lhsValues), reinterpret_cast<interval_t*>(rhsValues));
    case VAR_LIST: {
        // Children are themselves list entries that live in each vector's data vector.
        auto lhsChildEntries = reinterpret_cast<list_entry_t*>(lhsValues);
        auto rhsChildEntries = reinterpret_cast<list_entry_t*>(rhsValues);
        for (auto i = 0u; i < numElements; i++) {
            if (!isListValueEqual(lhsChildEntries[i], ListVector::getDataVector(lhsVector),
                    rhsChildEntries[i], ListVector::getDataVector(rhsVector))) {
                return false;
            }
        }
        return true;
    }
    default:
        throw RuntimeException("Unsupported data type " +
                               Types::dataTypeToString(lhsVector->dataType) +
                               " for TypeUtils::isValueEqual.");
    }
}

// Renders the element at index `pos` of a raw, contiguous array of values as a string.
// `dataType.typeID` selects how the raw bytes are reinterpreted before delegating to the
// matching TypeUtils::toString overload; for VAR_LIST the element is read as a ku_list_t and
// `dataType` is forwarded so the callee can resolve the child type.
// Throws RuntimeException for any typeID that is not listed below.
//
// NOTE(review): this is a diff rendering. Each case shows the removed statement (reading from
// `overflowPtr`, the parameter name of the former elementToString) immediately followed by its
// replacement (reading from `listValues`); likewise both the old and new error-message
// literals appear in the default branch. Only the `listValues` / "valueToString" lines belong
// to this function's current signature.
std::string TypeUtils::valueToString(const DataType& dataType, uint8_t* listValues, uint64_t pos) {
switch (dataType.typeID) {
case BOOL:
return TypeUtils::toString(((bool*)overflowPtr)[pos]);
return TypeUtils::toString(((bool*)listValues)[pos]);
case INT64:
return TypeUtils::toString(((int64_t*)overflowPtr)[pos]);
return TypeUtils::toString(((int64_t*)listValues)[pos]);
case DOUBLE:
return TypeUtils::toString(((double_t*)overflowPtr)[pos]);
return TypeUtils::toString(((double_t*)listValues)[pos]);
case DATE:
return TypeUtils::toString(((date_t*)overflowPtr)[pos]);
return TypeUtils::toString(((date_t*)listValues)[pos]);
case TIMESTAMP:
return TypeUtils::toString(((timestamp_t*)overflowPtr)[pos]);
return TypeUtils::toString(((timestamp_t*)listValues)[pos]);
case INTERVAL:
return TypeUtils::toString(((interval_t*)overflowPtr)[pos]);
return TypeUtils::toString(((interval_t*)listValues)[pos]);
case STRING:
return TypeUtils::toString(((ku_string_t*)overflowPtr)[pos]);
return TypeUtils::toString(((ku_string_t*)listValues)[pos]);
case VAR_LIST:
return TypeUtils::toString(((ku_list_t*)overflowPtr)[pos], dataType);
return TypeUtils::toString(((ku_list_t*)listValues)[pos], dataType);
default:
throw RuntimeException("Invalid data type " + Types::dataTypeToString(dataType) +
" for TypeUtils::elementToString.");
" for TypeUtils::valueToString.");
}
}

/**
 * Formats a ku_list_t as "[e0,e1,...]".
 *
 * Each element is rendered by valueToString using the list's child type, reading from the
 * list's overflow pointer. An empty list yields "[]"; previously the closing bracket was
 * only appended together with the last element, so an empty list produced a dangling "[".
 *
 * @param val      list whose `size` elements are stored at `overflowPtr`.
 * @param dataType type of the list itself; its child type describes the elements.
 */
std::string TypeUtils::toString(const ku_list_t& val, const DataType& dataType) {
    std::string result = "[";
    auto listValues = reinterpret_cast<uint8_t*>(val.overflowPtr);
    for (auto i = 0u; i < val.size; ++i) {
        if (i > 0) {
            result += ",";
        }
        result += valueToString(*dataType.getChildType(), listValues, i);
    }
    // Close unconditionally so the empty list is well-formed too.
    result += "]";
    return result;
}

std::string TypeUtils::toString(const list_entry_t& val, void* valVector) {
auto listVector = (common::ValueVector*)valVector;
std::string result = "[";
auto values = ListVector::getListValues(listVector, val);
for (auto i = 0u; i < val.size - 1; ++i) {
result += elementToString(
*dataType.getChildType(), reinterpret_cast<uint8_t*>(val.overflowPtr), i) +
result += (listVector->dataType.getChildType()->typeID == VAR_LIST ?
toString(reinterpret_cast<common::list_entry_t*>(values)[i],
ListVector::getDataVector(listVector)) :
valueToString(*listVector->dataType.getChildType(), values, i)) +
",";
}
result += elementToString(*dataType.getChildType(), reinterpret_cast<uint8_t*>(val.overflowPtr),
val.size - 1) +
result += (listVector->dataType.getChildType()->typeID == VAR_LIST ?
toString(reinterpret_cast<common::list_entry_t*>(values)[val.size - 1],
ListVector::getDataVector(listVector)) :
valueToString(*listVector->dataType.getChildType(), values, val.size - 1)) +
"]";
return result;
}
Expand Down
3 changes: 2 additions & 1 deletion src/common/vector/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
add_library(kuzu_common_vector
OBJECT
value_vector.cpp
value_vector_utils.cpp)
value_vector_utils.cpp
auxiliary_buffer.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_common_vector>
Expand Down
53 changes: 53 additions & 0 deletions src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "common/vector/auxiliary_buffer.h"

#include "common/in_mem_overflow_buffer_utils.h"
#include "common/vector/value_vector.h"

namespace kuzu {
namespace common {

/**
 * Writes a string into slot `pos` of a STRING vector.
 *
 * The vector's data is viewed as an array of ku_string_t and the slot at `pos` is filled by
 * InMemOverflowBufferUtils::copyString, which is handed this buffer's in-memory overflow
 * buffer.
 *
 * @param vector vector receiving the string; asserted to be of type STRING.
 * @param pos    index of the slot to fill.
 * @param value  pointer to the source characters.
 * @param len    number of characters to copy.
 */
void StringAuxiliaryBuffer::addString(
    common::ValueVector* vector, uint32_t pos, char* value, uint64_t len) const {
    assert(vector->dataType.typeID == STRING);
    auto* stringSlots = reinterpret_cast<ku_string_t*>(vector->getData());
    InMemOverflowBufferUtils::copyString(value, len, stringSlots[pos], *inMemOverflowBuffer);
}

// Creates an empty list buffer: capacity starts at DEFAULT_VECTOR_CAPACITY, no elements are
// in use yet, and a fresh ValueVector of the element type `dataVectorType` is created to hold
// the list elements.
ListAuxiliaryBuffer::ListAuxiliaryBuffer(
    kuzu::common::DataType& dataVectorType, storage::MemoryManager* memoryManager)
    : capacity{common::DEFAULT_VECTOR_CAPACITY},
      size{0},
      dataVector{std::make_unique<ValueVector>(dataVectorType, memoryManager)} {}

/**
 * Reserves room for a new list of `listSize` elements at the end of the data vector.
 *
 * The returned entry is built from the element count before the reservation and `listSize`
 * (presumably offset and length - confirm against list_entry_t). When the reservation does
 * not fit, capacity is doubled until it does, the data vector's value buffer is reallocated
 * with the existing elements carried over, and its null mask is resized.
 *
 * @param listSize number of elements the new list will contain.
 * @return the entry describing the reserved range.
 */
list_entry_t ListAuxiliaryBuffer::addList(uint64_t listSize) {
    auto listEntry = list_entry_t{size, listSize};
    bool needResizeDataVector = size + listSize > capacity;
    while (size + listSize > capacity) {
        capacity *= 2;
    }
    if (needResizeDataVector) {
        auto numBytesPerElement = dataVector->getNumBytesPerValue();
        auto buffer = std::make_unique<uint8_t[]>(capacity * numBytesPerElement);
        // memcpy(dest, src, n): carry the existing elements over INTO the new buffer. With the
        // arguments the other way round the old data was overwritten with zeroes and then lost.
        memcpy(buffer.get(), dataVector->valueBuffer.get(), size * numBytesPerElement);
        dataVector->valueBuffer = std::move(buffer);
        dataVector->nullMask->resize(capacity);
    }
    size += listSize;
    return listEntry;
}

/**
 * Creates the auxiliary buffer that a vector of the given type needs.
 *
 * STRING gets a StringAuxiliaryBuffer, STRUCT a StructAuxiliaryBuffer, and VAR_LIST a
 * ListAuxiliaryBuffer built for the list's child type.
 *
 * @return the new buffer, or nullptr for every other type.
 */
std::unique_ptr<AuxiliaryBuffer> AuxiliaryBufferFactory::getAuxiliaryBuffer(
    DataType& type, storage::MemoryManager* memoryManager) {
    const auto id = type.typeID;
    if (id == STRING) {
        return std::make_unique<StringAuxiliaryBuffer>(memoryManager);
    }
    if (id == STRUCT) {
        return std::make_unique<StructAuxiliaryBuffer>();
    }
    if (id == VAR_LIST) {
        return std::make_unique<ListAuxiliaryBuffer>(*type.getChildType(), memoryManager);
    }
    return nullptr;
}

} // namespace common
} // namespace kuzu
41 changes: 20 additions & 21 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
#include "common/vector/value_vector.h"

#include "common/in_mem_overflow_buffer_utils.h"
#include "common/vector/auxiliary_buffer.h"
#include "common/vector/value_vector_utils.h"

namespace kuzu {
namespace common {

ValueVector::ValueVector(DataType dataType, storage::MemoryManager* memoryManager)
: dataType{std::move(dataType)} {
valueBuffer = std::make_unique<uint8_t[]>(
Types::getDataTypeSize(this->dataType) * DEFAULT_VECTOR_CAPACITY);
if (needOverflowBuffer()) {
assert(memoryManager != nullptr);
inMemOverflowBuffer = std::make_unique<InMemOverflowBuffer>(memoryManager);
}
// TODO(Ziyi): remove this if/else statement once we removed the ku_list.
numBytesPerValue = this->dataType.typeID == VAR_LIST ? sizeof(common::list_entry_t) :
Types::getDataTypeSize(this->dataType);
valueBuffer = std::make_unique<uint8_t[]>(numBytesPerValue * DEFAULT_VECTOR_CAPACITY);
nullMask = std::make_unique<NullMask>();
numBytesPerValue = Types::getDataTypeSize(this->dataType);
}

void ValueVector::addString(uint32_t pos, char* value, uint64_t len) const {
assert(dataType.typeID == STRING);
auto& entry = ((ku_string_t*)getData())[pos];
InMemOverflowBufferUtils::copyString(value, len, entry, *inMemOverflowBuffer);
auxiliaryBuffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(this->dataType, memoryManager);
}

bool NodeIDVector::discardNull(ValueVector& vector) {
Expand Down Expand Up @@ -51,9 +45,14 @@ void ValueVector::setValue(uint32_t pos, T val) {
((T*)valueBuffer.get())[pos] = val;
}

// Specialization for list entries: views the value buffer as an array of list_entry_t and
// stores `val` in slot `pos`.
template<>
void ValueVector::setValue(uint32_t pos, common::list_entry_t val) {
    auto* listEntries = reinterpret_cast<list_entry_t*>(valueBuffer.get());
    listEntries[pos] = val;
}

// Specialization for std::string: stores the characters of `val` in slot `pos` by delegating
// to an addString helper with the string's data pointer and length.
// NOTE(review): this is a diff rendering - the member call `addString(pos, ...)` is the
// removed line and `StringVector::addString(this, pos, ...)` is its replacement; only one of
// the two is meant to be present.
template<>
void ValueVector::setValue(uint32_t pos, std::string val) {
addString(pos, val.data(), val.length());
StringVector::addString(this, pos, val.data(), val.length());
}

template void ValueVector::setValue<nodeID_t>(uint32_t pos, nodeID_t val);
Expand Down Expand Up @@ -108,18 +107,18 @@ void ValueVector::copyValue(uint8_t* dest, const Value& value) {
memcpy(dest, &value.val.intervalVal, size);
} break;
case STRING: {
InMemOverflowBufferUtils::copyString(
value.strVal.data(), value.strVal.length(), *(ku_string_t*)dest, getOverflowBuffer());
InMemOverflowBufferUtils::copyString(value.strVal.data(), value.strVal.length(),
*(ku_string_t*)dest,
*reinterpret_cast<StringAuxiliaryBuffer*>(auxiliaryBuffer.get())->getOverflowBuffer());
} break;
case VAR_LIST: {
auto& entry = *(ku_list_t*)dest;
auto listEntry = reinterpret_cast<list_entry_t*>(dest);
auto numElements = value.nestedTypeVal.size();
*listEntry = ListVector::addList(this, numElements);
auto elementSize = Types::getDataTypeSize(*dataType.getChildType());
InMemOverflowBufferUtils::allocateSpaceForList(
entry, numElements * elementSize, getOverflowBuffer());
entry.size = numElements;
auto dstElements = ListVector::getListValues(this, *listEntry);
for (auto i = 0u; i < numElements; ++i) {
copyValue((uint8_t*)entry.overflowPtr + i * elementSize, *value.nestedTypeVal[i]);
copyValue(dstElements + i * elementSize, *value.nestedTypeVal[i]);
}
} break;
default:
Expand Down
Loading

0 comments on commit 48e4fd5

Please sign in to comment.