diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index e36facb21d..0774d6f055 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -1,5 +1,6 @@ #include "common/types/value.h" +#include "common/null_bytes.h" #include "common/string_utils.h" namespace kuzu { @@ -327,12 +328,19 @@ Value::Value(DataType dataType) : dataType{std::move(dataType)}, isNull_{true} { std::vector> Value::convertKUVarListToVector(ku_list_t& list) const { std::vector> listResultValue; auto numBytesPerElement = Types::getDataTypeSize(*dataType.getChildType()); + auto listNullBytes = reinterpret_cast(list.overflowPtr); + auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size); + auto listValues = listNullBytes + numBytesForNullValues; for (auto i = 0; i < list.size; i++) { auto childValue = std::make_unique(Value::createDefaultValue(*dataType.getChildType())); - childValue->copyValueFrom( - reinterpret_cast(list.overflowPtr + i * numBytesPerElement)); - listResultValue.emplace_back(std::move(childValue)); + if (NullBuffer::isNull(listNullBytes, i)) { + childValue->setNull(); + } else { + childValue->copyValueFrom(listValues); + } + listResultValue.push_back(std::move(childValue)); + listValues += numBytesPerElement; } return listResultValue; } diff --git a/src/common/vector/value_vector_utils.cpp b/src/common/vector/value_vector_utils.cpp index f2080a2ac7..7aed7d9e77 100644 --- a/src/common/vector/value_vector_utils.cpp +++ b/src/common/vector/value_vector_utils.cpp @@ -1,6 +1,7 @@ #include "common/vector/value_vector_utils.h" #include "common/in_mem_overflow_buffer_utils.h" +#include "common/null_bytes.h" using namespace kuzu; using namespace common; @@ -16,13 +17,19 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos( } break; case VAR_LIST: { auto srcKuList = *(ku_list_t*)srcData; - auto srcListValues = reinterpret_cast(srcKuList.overflowPtr); + auto srcNullBytes = reinterpret_cast(srcKuList.overflowPtr); + auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size); auto dstListEntry = ListVector::addList(&resultVector, srcKuList.size); resultVector.setValue(pos, dstListEntry); auto resultDataVector = common::ListVector::getDataVector(&resultVector); for (auto i = 0u; i < srcKuList.size; i++) { - copyNonNullDataWithSameTypeIntoPos( - *resultDataVector, dstListEntry.offset + i, srcListValues); + auto dstListValuePos = dstListEntry.offset + i; + if (NullBuffer::isNull(srcNullBytes, i)) { + resultDataVector->setNull(dstListValuePos, true); + } else { + copyNonNullDataWithSameTypeIntoPos( + *resultDataVector, dstListValuePos, srcListValues); + } srcListValues += Types::getDataTypeSize(resultDataVector->dataType); } } break; @@ -49,12 +56,20 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector& ku_list_t dstList; dstList.size = srcListEntry.size; InMemOverflowBufferUtils::allocateSpaceForList(dstList, - Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size, dstOverflowBuffer); + Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size + + NullBuffer::getNumBytesForNullValues(dstList.size), + dstOverflowBuffer); + auto dstListNullBytes = reinterpret_cast(dstList.overflowPtr); + NullBuffer::initNullBytes(dstListNullBytes, dstList.size); + auto dstListValues = dstListNullBytes + NullBuffer::getNumBytesForNullValues(dstList.size); for (auto i = 0u; i < srcListEntry.size; i++) { - copyNonNullDataWithSameTypeOutFromPos(*srcListDataVector, srcListEntry.offset + i, - reinterpret_cast(dstList.overflowPtr) + - i * Types::getDataTypeSize(srcListDataVector->dataType), - dstOverflowBuffer); + if (srcListDataVector->isNull(srcListEntry.offset + i)) { + NullBuffer::setNull(dstListNullBytes, i); + } else { + copyNonNullDataWithSameTypeOutFromPos( + *srcListDataVector, srcListEntry.offset + i, dstListValues, dstOverflowBuffer); + } + dstListValues += Types::getDataTypeSize(srcListDataVector->dataType); } memcpy(dstData, &dstList, sizeof(dstList)); } break; @@ -74,11 +89,16 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect auto dstList = reinterpret_cast(dstValue); *dstList = ListVector::addList(&dstVector, srcList->size); auto srcValues = ListVector::getListValues(&srcVector, *srcList); + auto srcDataVector = ListVector::getDataVector(&srcVector); auto dstValues = ListVector::getListValues(&dstVector, *dstList); - auto numBytesPerValue = ListVector::getDataVector(&srcVector)->getNumBytesPerValue(); + auto dstDataVector = ListVector::getDataVector(&dstVector); + auto numBytesPerValue = srcDataVector->getNumBytesPerValue(); for (auto i = 0u; i < srcList->size; i++) { - copyValue(dstValues, *ListVector::getDataVector(&dstVector), srcValues, - *ListVector::getDataVector(&srcVector)); + if (srcDataVector->isNull(srcList->offset + i)) { + dstDataVector->setNull(dstList->offset + i, true); + } else { + copyValue(dstValues, *dstDataVector, srcValues, *srcDataVector); + } srcValues += numBytesPerValue; dstValues += numBytesPerValue; } diff --git a/src/function/vector_list_operation.cpp b/src/function/vector_list_operation.cpp index 0558326a48..1411b833c4 100644 --- a/src/function/vector_list_operation.cpp +++ b/src/function/vector_list_operation.cpp @@ -1,3 +1,4 @@ +#include "binder/expression_binder.h" #include "common/types/ku_list.h" #include "common/vector/value_vector_utils.h" #include "function/list/operations/list_append_operation.h" @@ -24,7 +25,7 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg( void ListCreationVectorOperation::execFunc( const std::vector>& parameters, ValueVector& result) { - assert(!parameters.empty() && result.dataType.typeID == VAR_LIST); + assert(result.dataType.typeID == VAR_LIST); common::StringVector::resetOverflowBuffer(&result); for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize; ++selectedPos) { @@ -34,12 +35,17 @@ void ListCreationVectorOperation::execFunc( auto resultValues = common::ListVector::getListValues(&result, resultEntry); auto resultDataVector = common::ListVector::getDataVector(&result); auto numBytesPerValue = resultDataVector->getNumBytesPerValue(); - for (auto& parameter : parameters) { + for (auto i = 0u; i < parameters.size(); i++) { + auto parameter = parameters[i]; auto paramPos = parameter->state->isFlat() ? parameter->state->selVector->selectedPositions[0] : pos; - common::ValueVectorUtils::copyValue(resultValues, *resultDataVector, - parameter->getData() + parameter->getNumBytesPerValue() * paramPos, *parameter); + if (parameter->isNull(paramPos)) { + resultDataVector->setNull(resultEntry.offset + i, true); + } else { + common::ValueVectorUtils::copyValue(resultValues, *resultDataVector, + parameter->getData() + parameter->getNumBytesPerValue() * paramPos, *parameter); + } resultValues += numBytesPerValue; } } @@ -47,17 +53,30 @@ void ListCreationVectorOperation::execFunc( std::unique_ptr ListCreationVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { - if (arguments.empty()) { - throw BinderException( - "Cannot resolve child data type for " + LIST_CREATION_FUNC_NAME + "."); + // ListCreation requires all parameters to have the same type or be ANY type. The result type of + // listCreation can be determined by the first non-ANY type parameter. If all parameters have + // dataType ANY, then the resultType will be INT64[] (default type). + auto resultType = DataType{std::make_unique(INT64)}; + for (auto i = 0u; i < arguments.size(); i++) { + if (arguments[i]->getDataType().typeID != common::ANY) { + resultType = DataType{std::make_unique(arguments[i]->getDataType())}; + break; + } } - for (auto i = 1u; i < arguments.size(); i++) { - if (arguments[i]->getDataType() != arguments[0]->getDataType()) { - throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg( - LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), arguments[i]->getDataType())); + // Cast parameters with ANY dataType to resultChildType. + for (auto i = 0u; i < arguments.size(); i++) { + auto parameterType = arguments[i]->getDataType(); + if (parameterType != *resultType.getChildType()) { + if (parameterType.typeID == common::ANY) { + binder::ExpressionBinder::resolveAnyDataType( + *arguments[i], *resultType.getChildType()); + } else { + throw BinderException( + getListFunctionIncompatibleChildrenTypeErrorMsg(LIST_CREATION_FUNC_NAME, + arguments[0]->getDataType(), arguments[i]->getDataType())); + } } } - auto resultType = DataType(std::make_unique(arguments[0]->getDataType())); return std::make_unique(resultType); } diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index bb4a4715ed..3d520e077d 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -19,6 +19,8 @@ class ExpressionBinder { std::shared_ptr bindExpression(const parser::ParsedExpression& parsedExpression); + static void resolveAnyDataType(Expression& expression, const common::DataType& targetType); + private: std::shared_ptr bindBooleanExpression( const parser::ParsedExpression& parsedExpression); @@ -91,7 +93,6 @@ class ExpressionBinder { const std::shared_ptr& expression, const common::DataType& targetType); static std::shared_ptr implicitCastIfNecessary( const std::shared_ptr& expression, common::DataTypeID targetTypeID); - static void resolveAnyDataType(Expression& expression, const common::DataType& targetType); static std::shared_ptr implicitCast( const std::shared_ptr& expression, const common::DataType& targetType); diff --git a/src/include/common/null_bytes.h b/src/include/common/null_bytes.h new file mode 100644 index 0000000000..f182f508cd --- /dev/null +++ b/src/include/common/null_bytes.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +namespace kuzu { +namespace common { + +class NullBuffer { + +public: + constexpr static const uint64_t NUM_NULL_MASKS_PER_BYTE = 8; + + static inline bool isNull(const uint8_t* nullBytes, uint64_t valueIdx) { + return nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] & + (1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE)); + } + + static inline void setNull(uint8_t* nullBytes, uint64_t valueIdx) { + nullBytes[valueIdx / NUM_NULL_MASKS_PER_BYTE] |= + (1 << (valueIdx % NUM_NULL_MASKS_PER_BYTE)); + } + + static inline uint64_t getNumBytesForNullValues(uint64_t numValues) { + return (numValues + NUM_NULL_MASKS_PER_BYTE - 1) / NUM_NULL_MASKS_PER_BYTE; + } + + static inline void initNullBytes(uint8_t* nullBytes, uint64_t numValues) { + memset(nullBytes, 0 /* value */, getNumBytesForNullValues(numValues)); + } +}; + +} // namespace common +} // namespace kuzu diff --git a/src/include/processor/result/factorized_table.h b/src/include/processor/result/factorized_table.h index 0af4a74e7c..3dde3ccbbe 100644 --- a/src/include/processor/result/factorized_table.h +++ b/src/include/processor/result/factorized_table.h @@ -149,10 +149,6 @@ class FactorizedTableSchema { columns[idx]->setMayContainsNullsToTrue(); } - static inline uint32_t getNumBytesForNullBuffer(uint32_t numColumns) { - return (numColumns >> 3) + ((numColumns & 7) != 0); // &7 is the same as %8; - } - inline bool isEmpty() const { return columns.empty(); } bool operator==(const FactorizedTableSchema& other) const; @@ -264,8 +260,6 @@ class FactorizedTable { int64_t findValueInFlatColumn(ft_col_idx_t colIdx, int64_t value) const; private: - static bool isNull(const uint8_t* nullMapBuffer, ft_col_idx_t idx); - void setNull(uint8_t* nullBuffer, ft_col_idx_t idx); void setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx); uint64_t computeNumTuplesToAppend( diff --git a/src/processor/result/factorized_table.cpp b/src/processor/result/factorized_table.cpp index 2ce8b23c2f..2fb1646ef0 100644 --- a/src/processor/result/factorized_table.cpp +++ b/src/processor/result/factorized_table.cpp @@ -1,6 +1,7 @@ #include "processor/result/factorized_table.h" #include "common/exception.h" +#include "common/null_bytes.h" #include "common/vector/value_vector_utils.h" using namespace kuzu::common; @@ -27,7 +28,7 @@ void FactorizedTableSchema::appendColumn(std::unique_ptr column) { columns.push_back(std::move(column)); colOffsets.push_back( colOffsets.empty() ? 0 : colOffsets.back() + getColumn(columns.size() - 2)->getNumBytes()); - numBytesForNullMapPerTuple = getNumBytesForNullBuffer(getNumColumns()); + numBytesForNullMapPerTuple = NullBuffer::getNumBytesForNullValues(getNumColumns()); numBytesPerTuple = numBytesForDataPerTuple + numBytesForNullMapPerTuple; } @@ -262,7 +263,7 @@ bool FactorizedTable::isOverflowColNull( if (tableSchema->getColumn(colIdx)->hasNoNullGuarantee()) { return false; } - return isNull(nullBuffer, tupleIdx); + return NullBuffer::isNull(nullBuffer, tupleIdx); } bool FactorizedTable::isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx_t colIdx) const { @@ -270,11 +271,11 @@ bool FactorizedTable::isNonOverflowColNull(const uint8_t* nullBuffer, ft_col_idx if (tableSchema->getColumn(colIdx)->hasNoNullGuarantee()) { return false; } - return isNull(nullBuffer, colIdx); + return NullBuffer::isNull(nullBuffer, colIdx); } void FactorizedTable::setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx) { - setNull(nullBuffer, colIdx); + NullBuffer::setNull(nullBuffer, colIdx); tableSchema->setMayContainsNullsToTrue(colIdx); } @@ -337,21 +338,9 @@ void FactorizedTable::clear() { inMemOverflowBuffer->resetBuffer(); } -bool FactorizedTable::isNull(const uint8_t* nullMapBuffer, ft_col_idx_t idx) { - uint32_t nullMapIdx = idx >> 3; - uint8_t nullMapMask = 0x1 << (idx & 7); // note: &7 is the same as %8 - return nullMapBuffer[nullMapIdx] & nullMapMask; -} - -void FactorizedTable::setNull(uint8_t* nullBuffer, ft_col_idx_t idx) { - uint64_t nullMapIdx = idx >> 3; - uint8_t nullMapMask = 0x1 << (idx & 7); // note: &7 is the same as %8 - nullBuffer[nullMapIdx] |= nullMapMask; -} - void FactorizedTable::setOverflowColNull( uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx) { - setNull(nullBuffer, tupleIdx); + NullBuffer::setNull(nullBuffer, tupleIdx); tableSchema->setMayContainsNullsToTrue(colIdx); } @@ -511,7 +500,7 @@ overflow_value_t FactorizedTable::appendVectorToUnflatTupleBlocks( auto numBytesPerValue = Types::getDataTypeSize(vector.dataType); auto numBytesForData = numBytesPerValue * numFlatTuplesInVector; auto overflowBlockBuffer = allocateUnflatTupleBlock( - numBytesForData + FactorizedTableSchema::getNumBytesForNullBuffer(numFlatTuplesInVector)); + numBytesForData + NullBuffer::getNumBytesForNullValues(numFlatTuplesInVector)); if (vector.state->selVector->isUnfiltered()) { if (vector.hasNoNullsGuarantee()) { auto dstDataBuffer = overflowBlockBuffer; diff --git a/src/storage/storage_structure/disk_overflow_file.cpp b/src/storage/storage_structure/disk_overflow_file.cpp index 4030a3de2b..aab1efd2d9 100644 --- a/src/storage/storage_structure/disk_overflow_file.cpp +++ b/src/storage/storage_structure/disk_overflow_file.cpp @@ -1,6 +1,7 @@ #include "storage/storage_structure/disk_overflow_file.h" #include "common/in_mem_overflow_buffer_utils.h" +#include "common/null_bytes.h" #include "common/string_utils.h" #include "common/type_utils.h" @@ -240,8 +241,12 @@ void DiskOverflowFile::setListRecursiveIfNestedWithoutLock( nextBytePosToWriteTo, BufferPoolConstants::PAGE_4KB_SIZE); diskDstList.size = inMemSrcList.size; // Copy non-overflow part for elements in the list. + // TODO(Ziyi): Current storage design doesn't support nulls within a list, so we can't read the + // nullBits from factorizedTable to InMemLists. + auto listValues = reinterpret_cast(inMemSrcList.overflowPtr) + + NullBuffer::getNumBytesForNullValues(inMemSrcList.size); memcpy(updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage, - (uint8_t*)inMemSrcList.overflowPtr, inMemSrcList.size * elementSize); + listValues, inMemSrcList.size * elementSize); nextBytePosToWriteTo += inMemSrcList.size * elementSize; TypeUtils::encodeOverflowPtr(diskDstList.overflowPtr, updatedPageInfoAndWALPageFrame.originalPageIdx, updatedPageInfoAndWALPageFrame.posInPage); @@ -249,19 +254,19 @@ void DiskOverflowFile::setListRecursiveIfNestedWithoutLock( updatedPageInfoAndWALPageFrame, *fileHandle, *bufferManager, *wal); if (dataType.getChildType()->typeID == STRING) { // Copy overflow for string elements in the list. - auto dstListElements = (ku_string_t*)(updatedPageInfoAndWALPageFrame.frame + - updatedPageInfoAndWALPageFrame.posInPage); + auto dstListElements = reinterpret_cast( + updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage); for (auto i = 0u; i < diskDstList.size; i++) { - auto kuString = ((ku_string_t*)inMemSrcList.overflowPtr)[i]; + auto kuString = ((ku_string_t*)listValues)[i]; setStringOverflowWithoutLock( (const char*)kuString.overflowPtr, kuString.len, dstListElements[i]); } } else if (dataType.getChildType()->typeID == VAR_LIST) { // Recursively copy overflow for list elements in the list. - auto dstListElements = (ku_list_t*)(updatedPageInfoAndWALPageFrame.frame + - updatedPageInfoAndWALPageFrame.posInPage); + auto dstListElements = reinterpret_cast( + updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage); for (auto i = 0u; i < diskDstList.size; i++) { - setListRecursiveIfNestedWithoutLock(((ku_list_t*)inMemSrcList.overflowPtr)[i], + setListRecursiveIfNestedWithoutLock((reinterpret_cast(listValues))[i], dstListElements[i], *dataType.getChildType()); } } diff --git a/test/binder/binder_error_test.cpp b/test/binder/binder_error_test.cpp index 94310eaf2b..163eb38cd6 100644 --- a/test/binder/binder_error_test.cpp +++ b/test/binder/binder_error_test.cpp @@ -468,13 +468,6 @@ TEST_F(BinderErrorTest, RenamePropertyDuplicateName) { ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str()); } -TEST_F(BinderErrorTest, EmptyList) { - std::string expectedException = - "Binder exception: Cannot resolve child data type for LIST_CREATION."; - auto input = "RETURN []"; - ASSERT_STREQ(expectedException.c_str(), getBindingError(input).c_str()); -} - TEST_F(BinderErrorTest, InvalidFixedListChildType) { std::string expectedException = "Binder exception: The child type of a fixed list must be a numeric type. Given: STRING."; diff --git a/test/test_files/tinysnb/projection/single_label.test b/test/test_files/tinysnb/projection/single_label.test index 7280b6f9a1..05b3f8fd2c 100644 --- a/test/test_files/tinysnb/projection/single_label.test +++ b/test/test_files/tinysnb/projection/single_label.test @@ -132,6 +132,26 @@ False [[10]] [[7],[10],[6,7]] +-NAME ListWithNullValues +-QUERY RETURN [5, null, 100, null, null] +---- 1 +[5,,100,,] + +-NAME NestedListWithNullValues +-QUERY RETURN [[78, null, 100, null, null], null, [null, 5, null, 100, null]] +---- 1 +[[78,,100,,],,[,5,,100,]] + +-NAME EmptyList +-QUERY RETURN [] +---- 1 +[] + +-NAME NestedEmptyList +-QUERY RETURN [[],[]] +---- 1 +[[],[]] + -NAME CrossProductReturn -QUERY MATCH (a:organisation), (b:organisation) RETURN a.orgCode = b.orgCode ---- 9