Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support null values in list #1524

Merged
merged 1 commit into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "common/types/value.h"

#include "common/null_bytes.h"
#include "common/string_utils.h"

namespace kuzu {
Expand Down Expand Up @@ -327,12 +328,19 @@ Value::Value(DataType dataType) : dataType{std::move(dataType)}, isNull_{true} {
std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(ku_list_t& list) const {
std::vector<std::unique_ptr<Value>> listResultValue;
auto numBytesPerElement = Types::getDataTypeSize(*dataType.getChildType());
auto listNullBytes = reinterpret_cast<uint8_t*>(list.overflowPtr);
auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size);
auto listValues = listNullBytes + numBytesForNullValues;
for (auto i = 0; i < list.size; i++) {
auto childValue =
std::make_unique<Value>(Value::createDefaultValue(*dataType.getChildType()));
childValue->copyValueFrom(
reinterpret_cast<uint8_t*>(list.overflowPtr + i * numBytesPerElement));
listResultValue.emplace_back(std::move(childValue));
if (NullBuffer::isNull(listNullBytes, i)) {
childValue->setNull();
} else {
childValue->copyValueFrom(listValues);
}
listResultValue.push_back(std::move(childValue));
listValues += numBytesPerElement;
}
return listResultValue;
}
Expand Down
42 changes: 31 additions & 11 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common/vector/value_vector_utils.h"

#include "common/in_mem_overflow_buffer_utils.h"
#include "common/null_bytes.h"

using namespace kuzu;
using namespace common;
Expand All @@ -16,13 +17,19 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
} break;
case VAR_LIST: {
auto srcKuList = *(ku_list_t*)srcData;
auto srcListValues = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size);
auto dstListEntry = ListVector::addList(&resultVector, srcKuList.size);
resultVector.setValue<list_entry_t>(pos, dstListEntry);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
for (auto i = 0u; i < srcKuList.size; i++) {
copyNonNullDataWithSameTypeIntoPos(
*resultDataVector, dstListEntry.offset + i, srcListValues);
auto dstListValuePos = dstListEntry.offset + i;
if (NullBuffer::isNull(srcNullBytes, i)) {
resultDataVector->setNull(dstListValuePos, true);
} else {
copyNonNullDataWithSameTypeIntoPos(
*resultDataVector, dstListValuePos, srcListValues);
}
srcListValues += Types::getDataTypeSize(resultDataVector->dataType);
}
} break;
Expand All @@ -49,12 +56,20 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
ku_list_t dstList;
dstList.size = srcListEntry.size;
InMemOverflowBufferUtils::allocateSpaceForList(dstList,
Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size, dstOverflowBuffer);
Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size +
NullBuffer::getNumBytesForNullValues(dstList.size),
dstOverflowBuffer);
auto dstListNullBytes = reinterpret_cast<uint8_t*>(dstList.overflowPtr);
NullBuffer::initNullBytes(dstListNullBytes, dstList.size);
auto dstListValues = dstListNullBytes + NullBuffer::getNumBytesForNullValues(dstList.size);
for (auto i = 0u; i < srcListEntry.size; i++) {
copyNonNullDataWithSameTypeOutFromPos(*srcListDataVector, srcListEntry.offset + i,
reinterpret_cast<uint8_t*>(dstList.overflowPtr) +
i * Types::getDataTypeSize(srcListDataVector->dataType),
dstOverflowBuffer);
if (srcListDataVector->isNull(srcListEntry.offset + i)) {
NullBuffer::setNull(dstListNullBytes, i);
} else {
copyNonNullDataWithSameTypeOutFromPos(
*srcListDataVector, srcListEntry.offset + i, dstListValues, dstOverflowBuffer);
}
dstListValues += Types::getDataTypeSize(srcListDataVector->dataType);
}
memcpy(dstData, &dstList, sizeof(dstList));
} break;
Expand All @@ -74,11 +89,16 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect
auto dstList = reinterpret_cast<common::list_entry_t*>(dstValue);
*dstList = ListVector::addList(&dstVector, srcList->size);
auto srcValues = ListVector::getListValues(&srcVector, *srcList);
auto srcDataVector = ListVector::getDataVector(&srcVector);
auto dstValues = ListVector::getListValues(&dstVector, *dstList);
auto numBytesPerValue = ListVector::getDataVector(&srcVector)->getNumBytesPerValue();
auto dstDataVector = ListVector::getDataVector(&dstVector);
auto numBytesPerValue = srcDataVector->getNumBytesPerValue();
for (auto i = 0u; i < srcList->size; i++) {
copyValue(dstValues, *ListVector::getDataVector(&dstVector), srcValues,
*ListVector::getDataVector(&srcVector));
if (srcDataVector->isNull(srcList->offset + i)) {
dstDataVector->setNull(dstList->offset + i, true);
} else {
copyValue(dstValues, *dstDataVector, srcValues, *srcDataVector);
}
srcValues += numBytesPerValue;
dstValues += numBytesPerValue;
}
Expand Down
43 changes: 31 additions & 12 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression_binder.h"
#include "common/types/ku_list.h"
#include "common/vector/value_vector_utils.h"
#include "function/list/operations/list_append_operation.h"
Expand All @@ -24,7 +25,7 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg(

void ListCreationVectorOperation::execFunc(
const std::vector<std::shared_ptr<ValueVector>>& parameters, ValueVector& result) {
assert(!parameters.empty() && result.dataType.typeID == VAR_LIST);
assert(result.dataType.typeID == VAR_LIST);
common::StringVector::resetOverflowBuffer(&result);
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
Expand All @@ -34,30 +35,48 @@ void ListCreationVectorOperation::execFunc(
auto resultValues = common::ListVector::getListValues(&result, resultEntry);
auto resultDataVector = common::ListVector::getDataVector(&result);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
for (auto& parameter : parameters) {
for (auto i = 0u; i < parameters.size(); i++) {
auto parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
parameter->getData() + parameter->getNumBytesPerValue() * paramPos, *parameter);
if (parameter->isNull(paramPos)) {
resultDataVector->setNull(resultEntry.offset + i, true);
} else {
common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
parameter->getData() + parameter->getNumBytesPerValue() * paramPos, *parameter);
}
resultValues += numBytesPerValue;
}
}
}

std::unique_ptr<FunctionBindData> ListCreationVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
if (arguments.empty()) {
throw BinderException(
"Cannot resolve child data type for " + LIST_CREATION_FUNC_NAME + ".");
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be INT64[] (default type).
auto resultType = DataType{std::make_unique<DataType>(INT64)};
for (auto i = 0u; i < arguments.size(); i++) {
if (arguments[i]->getDataType().typeID != common::ANY) {
resultType = DataType{std::make_unique<DataType>(arguments[i]->getDataType())};
break;
}
}
for (auto i = 1u; i < arguments.size(); i++) {
if (arguments[i]->getDataType() != arguments[0]->getDataType()) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), arguments[i]->getDataType()));
// Cast parameters with ANY dataType to resultChildType.
for (auto i = 0u; i < arguments.size(); i++) {
auto parameterType = arguments[i]->getDataType();
if (parameterType != *resultType.getChildType()) {
if (parameterType.typeID == common::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[i], *resultType.getChildType());
} else {
throw BinderException(
getListFunctionIncompatibleChildrenTypeErrorMsg(LIST_CREATION_FUNC_NAME,
arguments[0]->getDataType(), arguments[i]->getDataType()));
}
}
}
auto resultType = DataType(std::make_unique<DataType>(arguments[0]->getDataType()));
return std::make_unique<FunctionBindData>(resultType);
}

Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ExpressionBinder {

std::shared_ptr<Expression> bindExpression(const parser::ParsedExpression& parsedExpression);

static void resolveAnyDataType(Expression& expression, const common::DataType& targetType);

private:
std::shared_ptr<Expression> bindBooleanExpression(
const parser::ParsedExpression& parsedExpression);
Expand Down Expand Up @@ -91,7 +93,6 @@ class ExpressionBinder {
const std::shared_ptr<Expression>& expression, const common::DataType& targetType);
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, common::DataTypeID targetTypeID);
static void resolveAnyDataType(Expression& expression, const common::DataType& targetType);
static std::shared_ptr<Expression> implicitCast(
const std::shared_ptr<Expression>& expression, const common::DataType& targetType);

Expand Down
33 changes: 33 additions & 0 deletions src/include/common/null_bytes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include <cstdint>

namespace kuzu {
namespace common {

class NullBuffer {

public:
constexpr static const uint64_t NUM_NULL_MASKS_PER_BYTE = 8;

static inline bool isNull(const uint8_t* nullBytes, uint64_t valueIdx) {
return nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] &
(1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE));
}

static inline void setNull(uint8_t* nullBytes, uint64_t valueIdx) {
nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] |=
(1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE));
}

static inline uint64_t getNumBytesForNullValues(uint64_t numValues) {
return (numValues + NUM_NULL_MASKS_PER_BYTE - 1) / NUM_NULL_MASKS_PER_BYTE;
}

static inline void initNullBytes(uint8_t* nullBytes, uint64_t numValues) {
memset(nullBytes, 0 /* value */, getNumBytesForNullValues(numValues));
}
};

} // namespace common
} // namespace kuzu
6 changes: 0 additions & 6 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ class FactorizedTableSchema {
columns[idx]->setMayContainsNullsToTrue();
}

static inline uint32_t getNumBytesForNullBuffer(uint32_t numColumns) {
return (numColumns >> 3) + ((numColumns & 7) != 0); // &7 is the same as %8;
}

inline bool isEmpty() const { return columns.empty(); }

bool operator==(const FactorizedTableSchema& other) const;
Expand Down Expand Up @@ -264,8 +260,6 @@ class FactorizedTable {
int64_t findValueInFlatColumn(ft_col_idx_t colIdx, int64_t value) const;

private:
static bool isNull(const uint8_t* nullMapBuffer, ft_col_idx_t idx);
void setNull(uint8_t* nullBuffer, ft_col_idx_t idx);
void setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx);

uint64_t computeNumTuplesToAppend(
Expand Down
25 changes: 7 additions & 18 deletions src/processor/result/factorized_table.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "processor/result/factorized_table.h"

#include "common/exception.h"
#include "common/null_bytes.h"
#include "common/vector/value_vector_utils.h"

using namespace kuzu::common;
Expand All @@ -27,7 +28,7 @@ void FactorizedTableSchema::appendColumn(std::unique_ptr<ColumnSchema> column) {
columns.push_back(std::move(column));
colOffsets.push_back(
colOffsets.empty() ? 0 : colOffsets.back() + getColumn(columns.size() - 2)->getNumBytes());
numBytesForNullMapPerTuple = getNumBytesForNullBuffer(getNumColumns());
numBytesForNullMapPerTuple = NullBuffer::getNumBytesForNullValues(getNumColumns());
numBytesPerTuple = numBytesForDataPerTuple + numBytesForNullMapPerTuple;
}

Expand Down Expand Up @@ -262,19 +263,19 @@ bool FactorizedTable::isOverflowColNull(
if (tableSchema->getColumn(colIdx)->hasNoNullGuarantee()) {
return false;
}
return isNull(nullBuffer, tupleIdx);
return NullBuffer::isNull(nullBuffer, tupleIdx);
}

bool FactorizedTable::isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx_t colIdx) const {
assert(colIdx < tableSchema->getNumColumns());
if (tableSchema->getColumn(colIdx)->hasNoNullGuarantee()) {
return false;
}
return isNull(nullBuffer, colIdx);
return NullBuffer::isNull(nullBuffer, colIdx);
}

void FactorizedTable::setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx) {
setNull(nullBuffer, colIdx);
NullBuffer::setNull(nullBuffer, colIdx);
tableSchema->setMayContainsNullsToTrue(colIdx);
}

Expand Down Expand Up @@ -337,21 +338,9 @@ void FactorizedTable::clear() {
inMemOverflowBuffer->resetBuffer();
}

bool FactorizedTable::isNull(const uint8_t* nullMapBuffer, ft_col_idx_t idx) {
uint32_t nullMapIdx = idx >> 3;
uint8_t nullMapMask = 0x1 << (idx & 7); // note: &7 is the same as %8
return nullMapBuffer[nullMapIdx] & nullMapMask;
}

void FactorizedTable::setNull(uint8_t* nullBuffer, ft_col_idx_t idx) {
uint64_t nullMapIdx = idx >> 3;
uint8_t nullMapMask = 0x1 << (idx & 7); // note: &7 is the same as %8
nullBuffer[nullMapIdx] |= nullMapMask;
}

void FactorizedTable::setOverflowColNull(
uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx) {
setNull(nullBuffer, tupleIdx);
NullBuffer::setNull(nullBuffer, tupleIdx);
tableSchema->setMayContainsNullsToTrue(colIdx);
}

Expand Down Expand Up @@ -511,7 +500,7 @@ overflow_value_t FactorizedTable::appendVectorToUnflatTupleBlocks(
auto numBytesPerValue = Types::getDataTypeSize(vector.dataType);
auto numBytesForData = numBytesPerValue * numFlatTuplesInVector;
auto overflowBlockBuffer = allocateUnflatTupleBlock(
numBytesForData + FactorizedTableSchema::getNumBytesForNullBuffer(numFlatTuplesInVector));
numBytesForData + NullBuffer::getNumBytesForNullValues(numFlatTuplesInVector));
if (vector.state->selVector->isUnfiltered()) {
if (vector.hasNoNullsGuarantee()) {
auto dstDataBuffer = overflowBlockBuffer;
Expand Down
19 changes: 12 additions & 7 deletions src/storage/storage_structure/disk_overflow_file.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "storage/storage_structure/disk_overflow_file.h"

#include "common/in_mem_overflow_buffer_utils.h"
#include "common/null_bytes.h"
#include "common/string_utils.h"
#include "common/type_utils.h"

Expand Down Expand Up @@ -240,28 +241,32 @@ void DiskOverflowFile::setListRecursiveIfNestedWithoutLock(
nextBytePosToWriteTo, BufferPoolConstants::PAGE_4KB_SIZE);
diskDstList.size = inMemSrcList.size;
// Copy non-overflow part for elements in the list.
// TODO(Ziyi): Current storage design doesn't support nulls within a list, so we can't read the
// nullBits from factorizedTable to InMemLists.
auto listValues = reinterpret_cast<uint8_t*>(inMemSrcList.overflowPtr) +
NullBuffer::getNumBytesForNullValues(inMemSrcList.size);
memcpy(updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage,
(uint8_t*)inMemSrcList.overflowPtr, inMemSrcList.size * elementSize);
listValues, inMemSrcList.size * elementSize);
nextBytePosToWriteTo += inMemSrcList.size * elementSize;
TypeUtils::encodeOverflowPtr(diskDstList.overflowPtr,
updatedPageInfoAndWALPageFrame.originalPageIdx, updatedPageInfoAndWALPageFrame.posInPage);
StorageStructureUtils::unpinWALPageAndReleaseOriginalPageLock(
updatedPageInfoAndWALPageFrame, *fileHandle, *bufferManager, *wal);
if (dataType.getChildType()->typeID == STRING) {
// Copy overflow for string elements in the list.
auto dstListElements = (ku_string_t*)(updatedPageInfoAndWALPageFrame.frame +
updatedPageInfoAndWALPageFrame.posInPage);
auto dstListElements = reinterpret_cast<ku_string_t*>(
updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage);
for (auto i = 0u; i < diskDstList.size; i++) {
auto kuString = ((ku_string_t*)inMemSrcList.overflowPtr)[i];
auto kuString = ((ku_string_t*)listValues)[i];
setStringOverflowWithoutLock(
(const char*)kuString.overflowPtr, kuString.len, dstListElements[i]);
}
} else if (dataType.getChildType()->typeID == VAR_LIST) {
// Recursively copy overflow for list elements in the list.
auto dstListElements = (ku_list_t*)(updatedPageInfoAndWALPageFrame.frame +
updatedPageInfoAndWALPageFrame.posInPage);
auto dstListElements = reinterpret_cast<ku_list_t*>(
updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage);
for (auto i = 0u; i < diskDstList.size; i++) {
setListRecursiveIfNestedWithoutLock(((ku_list_t*)inMemSrcList.overflowPtr)[i],
setListRecursiveIfNestedWithoutLock((reinterpret_cast<ku_list_t*>(listValues))[i],
dstListElements[i], *dataType.getChildType());
}
}
Expand Down
7 changes: 0 additions & 7 deletions test/binder/binder_error_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,6 @@ TEST_F(BinderErrorTest, RenamePropertyDuplicateName) {
ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str());
}

TEST_F(BinderErrorTest, EmptyList) {
std::string expectedException =
"Binder exception: Cannot resolve child data type for LIST_CREATION.";
auto input = "RETURN []";
ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str());
}

TEST_F(BinderErrorTest, InvalidFixedListChildType) {
std::string expectedException =
"Binder exception: The child type of a fixed list must be a numeric type. Given: STRING.";
Expand Down
Loading