Skip to content

Commit

Permalink
Add null to list values
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed May 9, 2023
1 parent 6f72c7a commit 057c219
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 65 deletions.
14 changes: 11 additions & 3 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "common/types/value.h"

#include "common/null_bytes.h"
#include "common/string_utils.h"

namespace kuzu {
Expand Down Expand Up @@ -327,12 +328,19 @@ Value::Value(DataType dataType) : dataType{std::move(dataType)}, isNull_{true} {
std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(ku_list_t& list) const {
std::vector<std::unique_ptr<Value>> listResultValue;
auto numBytesPerElement = Types::getDataTypeSize(*dataType.getChildType());
auto listNullBytes = reinterpret_cast<uint8_t*>(list.overflowPtr);
auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size);
auto listValues = listNullBytes + numBytesForNullValues;
for (auto i = 0; i < list.size; i++) {
auto childValue =
std::make_unique<Value>(Value::createDefaultValue(*dataType.getChildType()));
childValue->copyValueFrom(
reinterpret_cast<uint8_t*>(list.overflowPtr + i * numBytesPerElement));
listResultValue.emplace_back(std::move(childValue));
if (NullBuffer::isNull(listNullBytes, i)) {
childValue->setNull();
} else {
childValue->copyValueFrom(listValues);
}
listResultValue.push_back(std::move(childValue));
listValues += numBytesPerElement;
}
return listResultValue;
}
Expand Down
42 changes: 31 additions & 11 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common/vector/value_vector_utils.h"

#include "common/in_mem_overflow_buffer_utils.h"
#include "common/null_bytes.h"

using namespace kuzu;
using namespace common;
Expand All @@ -16,13 +17,19 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
} break;
case VAR_LIST: {
auto srcKuList = *(ku_list_t*)srcData;
auto srcListValues = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size);
auto dstListEntry = ListVector::addList(&resultVector, srcKuList.size);
resultVector.setValue<list_entry_t>(pos, dstListEntry);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
for (auto i = 0u; i < srcKuList.size; i++) {
copyNonNullDataWithSameTypeIntoPos(
*resultDataVector, dstListEntry.offset + i, srcListValues);
auto dstListValuePos = dstListEntry.offset + i;
if (NullBuffer::isNull(srcNullBytes, i)) {
resultDataVector->setNull(dstListValuePos, true);
} else {
copyNonNullDataWithSameTypeIntoPos(
*resultDataVector, dstListValuePos, srcListValues);
}
srcListValues += Types::getDataTypeSize(resultDataVector->dataType);
}
} break;
Expand All @@ -49,12 +56,20 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
ku_list_t dstList;
dstList.size = srcListEntry.size;
InMemOverflowBufferUtils::allocateSpaceForList(dstList,
Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size, dstOverflowBuffer);
Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size +
NullBuffer::getNumBytesForNullValues(dstList.size),
dstOverflowBuffer);
auto dstListNullBytes = reinterpret_cast<uint8_t*>(dstList.overflowPtr);
NullBuffer::initNullBytes(dstListNullBytes, dstList.size);
auto dstListValues = dstListNullBytes + NullBuffer::getNumBytesForNullValues(dstList.size);
for (auto i = 0u; i < srcListEntry.size; i++) {
copyNonNullDataWithSameTypeOutFromPos(*srcListDataVector, srcListEntry.offset + i,
reinterpret_cast<uint8_t*>(dstList.overflowPtr) +
i * Types::getDataTypeSize(srcListDataVector->dataType),
dstOverflowBuffer);
if (srcListDataVector->isNull(srcListEntry.offset + i)) {
NullBuffer::setNull(dstListNullBytes, i);
} else {
copyNonNullDataWithSameTypeOutFromPos(
*srcListDataVector, srcListEntry.offset + i, dstListValues, dstOverflowBuffer);
}
dstListValues += Types::getDataTypeSize(srcListDataVector->dataType);
}
memcpy(dstData, &dstList, sizeof(dstList));
} break;
Expand All @@ -74,11 +89,16 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect
auto dstList = reinterpret_cast<common::list_entry_t*>(dstValue);
*dstList = ListVector::addList(&dstVector, srcList->size);
auto srcValues = ListVector::getListValues(&srcVector, *srcList);
auto srcDataVector = ListVector::getDataVector(&srcVector);
auto dstValues = ListVector::getListValues(&dstVector, *dstList);
auto numBytesPerValue = ListVector::getDataVector(&srcVector)->getNumBytesPerValue();
auto dstDataVector = ListVector::getDataVector(&dstVector);
auto numBytesPerValue = srcDataVector->getNumBytesPerValue();
for (auto i = 0u; i < srcList->size; i++) {
copyValue(dstValues, *ListVector::getDataVector(&dstVector), srcValues,
*ListVector::getDataVector(&srcVector));
if (srcDataVector->isNull(srcList->offset + i)) {
dstDataVector->setNull(dstList->offset + i, true);
} else {
copyValue(dstValues, *dstDataVector, srcValues, *srcDataVector);
}
srcValues += numBytesPerValue;
dstValues += numBytesPerValue;
}
Expand Down
43 changes: 31 additions & 12 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression_binder.h"
#include "common/types/ku_list.h"
#include "common/vector/value_vector_utils.h"
#include "function/list/operations/list_append_operation.h"
Expand All @@ -24,7 +25,7 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg(

void ListCreationVectorOperation::execFunc(
const std::vector<std::shared_ptr<ValueVector>>& parameters, ValueVector& result) {
assert(!parameters.empty() && result.dataType.typeID == VAR_LIST);
assert(result.dataType.typeID == VAR_LIST);
common::StringVector::resetOverflowBuffer(&result);
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
Expand All @@ -34,30 +35,48 @@ void ListCreationVectorOperation::execFunc(
auto resultValues = common::ListVector::getListValues(&result, resultEntry);
auto resultDataVector = common::ListVector::getDataVector(&result);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
for (auto& parameter : parameters) {
for (auto i = 0u; i < parameters.size(); i++) {
auto parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
parameter->getData() + parameter->getNumBytesPerValue() * paramPos, *parameter);
if (parameter->isNull(paramPos)) {
resultDataVector->setNull(resultEntry.offset + i, true);
} else {
common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
parameter->getData() + parameter->getNumBytesPerValue() * paramPos, *parameter);
}
resultValues += numBytesPerValue;
}
}
}

std::unique_ptr<FunctionBindData> ListCreationVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
if (arguments.empty()) {
throw BinderException(
"Cannot resolve child data type for " + LIST_CREATION_FUNC_NAME + ".");
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be INT64[] (default type).
auto resultType = DataType{std::make_unique<DataType>(INT64)};
for (auto i = 0u; i < arguments.size(); i++) {
if (arguments[i]->getDataType().typeID != common::ANY) {
resultType = DataType{std::make_unique<DataType>(arguments[i]->getDataType())};
break;
}
}
for (auto i = 1u; i < arguments.size(); i++) {
if (arguments[i]->getDataType() != arguments[0]->getDataType()) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), arguments[i]->getDataType()));
// Cast parameters with ANY dataType to resultChildType.
for (auto i = 0u; i < arguments.size(); i++) {
auto parameterType = arguments[i]->getDataType();
if (parameterType != *resultType.getChildType()) {
if (parameterType.typeID == common::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[i], *resultType.getChildType());
} else {
throw BinderException(
getListFunctionIncompatibleChildrenTypeErrorMsg(LIST_CREATION_FUNC_NAME,
arguments[0]->getDataType(), arguments[i]->getDataType()));
}
}
}
auto resultType = DataType(std::make_unique<DataType>(arguments[0]->getDataType()));
return std::make_unique<FunctionBindData>(resultType);
}

Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ExpressionBinder {

std::shared_ptr<Expression> bindExpression(const parser::ParsedExpression& parsedExpression);

static void resolveAnyDataType(Expression& expression, const common::DataType& targetType);

private:
std::shared_ptr<Expression> bindBooleanExpression(
const parser::ParsedExpression& parsedExpression);
Expand Down Expand Up @@ -91,7 +93,6 @@ class ExpressionBinder {
const std::shared_ptr<Expression>& expression, const common::DataType& targetType);
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, common::DataTypeID targetTypeID);
static void resolveAnyDataType(Expression& expression, const common::DataType& targetType);
static std::shared_ptr<Expression> implicitCast(
const std::shared_ptr<Expression>& expression, const common::DataType& targetType);

Expand Down
33 changes: 33 additions & 0 deletions src/include/common/null_bytes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include <cstdint>

namespace kuzu {
namespace common {

class NullBuffer {

public:
constexpr static const uint64_t NUM_NULL_MASKS_PER_BYTE = 8;

static inline bool isNull(const uint8_t* nullBytes, uint64_t valueIdx) {
return nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] &
(1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE));
}

static inline void setNull(uint8_t* nullBytes, uint64_t valueIdx) {
nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] |=
(1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE));
}

static inline uint64_t getNumBytesForNullValues(uint64_t numValues) {
return (numValues + NUM_NULL_MASKS_PER_BYTE - 1) / NUM_NULL_MASKS_PER_BYTE;
}

static inline void initNullBytes(uint8_t* nullBytes, uint64_t numValues) {
memset(nullBytes, 0 /* value */, getNumBytesForNullValues(numValues));
}
};

} // namespace common
} // namespace kuzu
6 changes: 0 additions & 6 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ class FactorizedTableSchema {
columns[idx]->setMayContainsNullsToTrue();
}

static inline uint32_t getNumBytesForNullBuffer(uint32_t numColumns) {
return (numColumns >> 3) + ((numColumns & 7) != 0); // &7 is the same as %8;
}

inline bool isEmpty() const { return columns.empty(); }

bool operator==(const FactorizedTableSchema& other) const;
Expand Down Expand Up @@ -264,8 +260,6 @@ class FactorizedTable {
int64_t findValueInFlatColumn(ft_col_idx_t colIdx, int64_t value) const;

private:
static bool isNull(const uint8_t* nullMapBuffer, ft_col_idx_t idx);
void setNull(uint8_t* nullBuffer, ft_col_idx_t idx);
void setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx);

uint64_t computeNumTuplesToAppend(
Expand Down
25 changes: 7 additions & 18 deletions src/processor/result/factorized_table.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "processor/result/factorized_table.h"

#include "common/exception.h"
#include "common/null_bytes.h"
#include "common/vector/value_vector_utils.h"

using namespace kuzu::common;
Expand All @@ -27,7 +28,7 @@ void FactorizedTableSchema::appendColumn(std::unique_ptr<ColumnSchema> column) {
columns.push_back(std::move(column));
colOffsets.push_back(
colOffsets.empty() ? 0 : colOffsets.back() + getColumn(columns.size() - 2)->getNumBytes());
numBytesForNullMapPerTuple = getNumBytesForNullBuffer(getNumColumns());
numBytesForNullMapPerTuple = NullBuffer::getNumBytesForNullValues(getNumColumns());
numBytesPerTuple = numBytesForDataPerTuple + numBytesForNullMapPerTuple;
}

Expand Down Expand Up @@ -262,19 +263,19 @@ bool FactorizedTable::isOverflowColNull(
if (tableSchema->getColumn(colIdx)->hasNoNullGuarantee()) {
return false;
}
return isNull(nullBuffer, tupleIdx);
return NullBuffer::isNull(nullBuffer, tupleIdx);
}

bool FactorizedTable::isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx_t colIdx) const {
assert(colIdx < tableSchema->getNumColumns());
if (tableSchema->getColumn(colIdx)->hasNoNullGuarantee()) {
return false;
}
return isNull(nullBuffer, colIdx);
return NullBuffer::isNull(nullBuffer, colIdx);
}

void FactorizedTable::setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx) {
setNull(nullBuffer, colIdx);
NullBuffer::setNull(nullBuffer, colIdx);
tableSchema->setMayContainsNullsToTrue(colIdx);
}

Expand Down Expand Up @@ -337,21 +338,9 @@ void FactorizedTable::clear() {
inMemOverflowBuffer->resetBuffer();
}

bool FactorizedTable::isNull(const uint8_t* nullMapBuffer, ft_col_idx_t idx) {
uint32_t nullMapIdx = idx >> 3;
uint8_t nullMapMask = 0x1 << (idx & 7); // note: &7 is the same as %8
return nullMapBuffer[nullMapIdx] & nullMapMask;
}

void FactorizedTable::setNull(uint8_t* nullBuffer, ft_col_idx_t idx) {
uint64_t nullMapIdx = idx >> 3;
uint8_t nullMapMask = 0x1 << (idx & 7); // note: &7 is the same as %8
nullBuffer[nullMapIdx] |= nullMapMask;
}

void FactorizedTable::setOverflowColNull(
uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx) {
setNull(nullBuffer, tupleIdx);
NullBuffer::setNull(nullBuffer, tupleIdx);
tableSchema->setMayContainsNullsToTrue(colIdx);
}

Expand Down Expand Up @@ -511,7 +500,7 @@ overflow_value_t FactorizedTable::appendVectorToUnflatTupleBlocks(
auto numBytesPerValue = Types::getDataTypeSize(vector.dataType);
auto numBytesForData = numBytesPerValue * numFlatTuplesInVector;
auto overflowBlockBuffer = allocateUnflatTupleBlock(
numBytesForData + FactorizedTableSchema::getNumBytesForNullBuffer(numFlatTuplesInVector));
numBytesForData + NullBuffer::getNumBytesForNullValues(numFlatTuplesInVector));
if (vector.state->selVector->isUnfiltered()) {
if (vector.hasNoNullsGuarantee()) {
auto dstDataBuffer = overflowBlockBuffer;
Expand Down
19 changes: 12 additions & 7 deletions src/storage/storage_structure/disk_overflow_file.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "storage/storage_structure/disk_overflow_file.h"

#include "common/in_mem_overflow_buffer_utils.h"
#include "common/null_bytes.h"
#include "common/string_utils.h"
#include "common/type_utils.h"

Expand Down Expand Up @@ -240,28 +241,32 @@ void DiskOverflowFile::setListRecursiveIfNestedWithoutLock(
nextBytePosToWriteTo, BufferPoolConstants::PAGE_4KB_SIZE);
diskDstList.size = inMemSrcList.size;
// Copy non-overflow part for elements in the list.
// TODO(Ziyi): Current storage design doesn't support nulls within a list, so we can't read the
// nullBits from factorizedTable to InMemLists.
auto listValues = reinterpret_cast<uint8_t*>(inMemSrcList.overflowPtr) +
NullBuffer::getNumBytesForNullValues(inMemSrcList.size);
memcpy(updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage,
(uint8_t*)inMemSrcList.overflowPtr, inMemSrcList.size * elementSize);
listValues, inMemSrcList.size * elementSize);
nextBytePosToWriteTo += inMemSrcList.size * elementSize;
TypeUtils::encodeOverflowPtr(diskDstList.overflowPtr,
updatedPageInfoAndWALPageFrame.originalPageIdx, updatedPageInfoAndWALPageFrame.posInPage);
StorageStructureUtils::unpinWALPageAndReleaseOriginalPageLock(
updatedPageInfoAndWALPageFrame, *fileHandle, *bufferManager, *wal);
if (dataType.getChildType()->typeID == STRING) {
// Copy overflow for string elements in the list.
auto dstListElements = (ku_string_t*)(updatedPageInfoAndWALPageFrame.frame +
updatedPageInfoAndWALPageFrame.posInPage);
auto dstListElements = reinterpret_cast<ku_string_t*>(
updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage);
for (auto i = 0u; i < diskDstList.size; i++) {
auto kuString = ((ku_string_t*)inMemSrcList.overflowPtr)[i];
auto kuString = ((ku_string_t*)listValues)[i];
setStringOverflowWithoutLock(
(const char*)kuString.overflowPtr, kuString.len, dstListElements[i]);
}
} else if (dataType.getChildType()->typeID == VAR_LIST) {
// Recursively copy overflow for list elements in the list.
auto dstListElements = (ku_list_t*)(updatedPageInfoAndWALPageFrame.frame +
updatedPageInfoAndWALPageFrame.posInPage);
auto dstListElements = reinterpret_cast<ku_list_t*>(
updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage);
for (auto i = 0u; i < diskDstList.size; i++) {
setListRecursiveIfNestedWithoutLock(((ku_list_t*)inMemSrcList.overflowPtr)[i],
setListRecursiveIfNestedWithoutLock((reinterpret_cast<ku_list_t*>(listValues))[i],
dstListElements[i], *dataType.getChildType());
}
}
Expand Down
7 changes: 0 additions & 7 deletions test/binder/binder_error_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,6 @@ TEST_F(BinderErrorTest, RenamePropertyDuplicateName) {
ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str());
}

TEST_F(BinderErrorTest, EmptyList) {
std::string expectedException =
"Binder exception: Cannot resolve child data type for LIST_CREATION.";
auto input = "RETURN []";
ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str());
}

TEST_F(BinderErrorTest, InvalidFixedListChildType) {
std::string expectedException =
"Binder exception: The child type of a fixed list must be a numeric type. Given: STRING.";
Expand Down
Loading

0 comments on commit 057c219

Please sign in to comment.