Skip to content

Commit

Permalink
Merge pull request #1765 from kuzudb/list-null-fix
Browse files Browse the repository at this point in the history
fix concat-null
  • Loading branch information
acquamarin committed Jul 6, 2023
2 parents ca3128e + 4b20699 commit 7845162
Show file tree
Hide file tree
Showing 14 changed files with 98 additions and 119 deletions.
35 changes: 13 additions & 22 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ void ValueVector::copyFromVectorData(
}
}

void ValueVector::copyFromVectorData(
uint64_t posToCopy, const ValueVector* srcVector, uint64_t srcPos) {
setNull(posToCopy, srcVector->isNull(srcPos));
if (!isNull(posToCopy)) {
copyFromVectorData(getData() + posToCopy * getNumBytesPerValue(), srcVector,
srcVector->getData() + srcPos * srcVector->getNumBytesPerValue());
}
}

void ValueVector::resetAuxiliaryBuffer() {
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::STRING: {
Expand Down Expand Up @@ -300,20 +309,12 @@ void ListVector::copyFromVectorData(ValueVector* dstVector, uint8_t* dstData,
auto& srcListEntry = *(common::list_entry_t*)(srcData);
auto& dstListEntry = *(common::list_entry_t*)(dstData);
dstListEntry = addList(dstVector, srcListEntry.size);
auto srcListData = getListValues(srcVector, srcListEntry);
auto srcDataVector = getDataVector(srcVector);
auto dstListData = getListValues(dstVector, dstListEntry);
auto srcPos = srcListEntry.offset;
auto dstDataVector = getDataVector(dstVector);
auto numBytesPerValue = srcDataVector->getNumBytesPerValue();
auto dstPos = dstListEntry.offset;
for (auto i = 0u; i < srcListEntry.size; i++) {
if (srcDataVector->isNull(srcListEntry.offset + i)) {
dstDataVector->setNull(dstListEntry.offset + i, true);
} else {
dstDataVector->setNull(dstListEntry.offset + i, false);
dstDataVector->copyFromVectorData(dstListData, srcDataVector, srcListData);
}
srcListData += numBytesPerValue;
dstListData += numBytesPerValue;
dstDataVector->copyFromVectorData(dstPos++, srcDataVector, srcPos++);
}
}

Expand Down Expand Up @@ -362,17 +363,7 @@ void StructVector::copyFromVectorData(ValueVector* dstVector, const uint8_t* dst
for (auto i = 0u; i < srcFieldVectors.size(); i++) {
auto srcFieldVector = srcFieldVectors[i];
auto dstFieldVector = dstFieldVectors[i];
if (srcFieldVector->isNull(srcPos)) {
dstFieldVector->setNull(dstPos, true /* isNull */);
} else {
dstFieldVector->setNull(dstPos, false /* isNull */);
auto srcFieldVectorData =
srcFieldVector->getData() + srcFieldVector->getNumBytesPerValue() * srcPos;
auto dstFieldVectorData =
dstFieldVector->getData() + dstFieldVector->getNumBytesPerValue() * dstPos;
dstFieldVector->copyFromVectorData(
dstFieldVectorData, srcFieldVector.get(), srcFieldVectorData);
}
dstFieldVector->copyFromVectorData(dstPos, srcFieldVector.get(), srcPos);
}
}

Expand Down
4 changes: 1 addition & 3 deletions src/expression_evaluator/case_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ void CaseExpressionEvaluator::fillEntry(sel_t resultPos, const ValueVector& then
if (thenVector.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST) {
auto srcListEntry = thenVector.getValue<list_entry_t>(thenPos);
list_entry_t resultEntry = ListVector::addList(resultVector.get(), srcListEntry.size);
resultVector->copyFromVectorData(reinterpret_cast<uint8_t*>(&resultEntry), &thenVector,
reinterpret_cast<uint8_t*>(&srcListEntry));
resultVector->setValue(resultPos, resultEntry);
resultVector->copyFromVectorData(resultPos, &thenVector, thenPos);
} else {
auto val = thenVector.getValue<T>(thenPos);
resultVector->setValue<T>(resultPos, val);
Expand Down
15 changes: 5 additions & 10 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "common/types/ku_list.h"
#include "function/list/operations/list_any_value_operation.h"
Expand Down Expand Up @@ -37,21 +38,14 @@ void ListCreationVectorOperation::execFunc(
auto pos = result.state->selVector->selectedPositions[selectedPos];
auto resultEntry = common::ListVector::addList(&result, parameters.size());
result.setValue(pos, resultEntry);
auto resultValues = common::ListVector::getListValues(&result, resultEntry);
auto resultDataVector = common::ListVector::getDataVector(&result);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
auto resultPos = resultEntry.offset;
for (auto i = 0u; i < parameters.size(); i++) {
auto parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
if (parameter->isNull(paramPos)) {
resultDataVector->setNull(resultEntry.offset + i, true);
} else {
resultDataVector->copyFromVectorData(resultValues, parameter.get(),
parameter->getData() + parameter->getNumBytesPerValue() * paramPos);
}
resultValues += numBytesPerValue;
resultDataVector->copyFromVectorData(resultPos++, parameter.get(), paramPos);
}
}
}
Expand Down Expand Up @@ -219,7 +213,8 @@ vector_operation_definitions ListAppendVectorOperation::getDefinitions() {

std::unique_ptr<FunctionBindData> ListPrependVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
if (arguments[0]->dataType != *VarListType::getChildType(&arguments[1]->dataType)) {
if (arguments[0]->getDataType().getLogicalTypeID() != LogicalTypeID::ANY &&
arguments[0]->dataType != *VarListType::getChildType(&arguments[1]->dataType)) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_APPEND_FUNC_NAME, arguments[0]->getDataType(), arguments[1]->getDataType()));
}
Expand Down
24 changes: 5 additions & 19 deletions src/function/vector_struct_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,14 @@ void StructPackVectorOperations::copyParameterValueToStructFieldVector(
// If the parameter is unFlat, then its state must be consistent with the result's state.
// Thus, we don't need to copy values to structFieldVector.
assert(parameter->state->isFlat());
auto srcPos = parameter->state->selVector->selectedPositions[0];
auto srcValue = parameter->getData() + parameter->getNumBytesPerValue() * srcPos;
bool isSrcValueNull = parameter->isNull(srcPos);
auto paramPos = parameter->state->selVector->selectedPositions[0];
if (structField->state->isFlat()) {
auto pos = structField->state->selVector->selectedPositions[0];
if (isSrcValueNull) {
structField->setNull(pos, true /* isNull */);
} else {
structField->copyFromVectorData(
structField->getData() + structField->getNumBytesPerValue() * pos, parameter,
srcValue);
}
structField->copyFromVectorData(pos, parameter, paramPos);
} else {
for (auto j = 0u; j < structField->state->selVector->selectedSize; j++) {
auto pos = structField->state->selVector->selectedPositions[j];
if (isSrcValueNull) {
structField->setNull(pos, true /* isNull */);
} else {
structField->copyFromVectorData(
structField->getData() + structField->getNumBytesPerValue() * pos, parameter,
srcValue);
}
for (auto i = 0u; i < structField->state->selVector->selectedSize; i++) {
auto pos = structField->state->selVector->selectedPositions[i];
structField->copyFromVectorData(pos, parameter, paramPos);
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ class ValueVector {
void setValue(uint32_t pos, T val);
// copyFromRowData assumes rowData is non-NULL.
void copyFromRowData(uint32_t pos, const uint8_t* rowData);
// copyFromVectorData assumes srcVectorData is non-NULL.
// copyToRowData assumes srcVectorData is non-NULL.
void copyToRowData(
uint32_t pos, uint8_t* rowData, InMemOverflowBuffer* rowOverflowBuffer) const;
// copyFromVectorData assumes srcVectorData is non-NULL.
void copyFromVectorData(
uint8_t* dstData, const ValueVector* srcVector, const uint8_t* srcVectorData);
void copyFromVectorData(uint64_t posToCopy, const ValueVector* srcVector, uint64_t srcPos);

inline uint8_t* getData() const { return valueBuffer.get(); }

Expand Down
16 changes: 6 additions & 10 deletions src/include/function/list/operations/base_list_sort_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ struct BaseListSortOperation {
bool nullFirst) {
// TODO(Ziyi) - Replace this sort implementation with radix_sort implementation:
// https://github.com/kuzudb/kuzu/issues/1536.
auto inputValues = common::ListVector::getListValues(&inputVector, input);
auto inputDataVector = common::ListVector::getDataVector(&inputVector);
auto inputPos = input.offset;

// Calculate null count.
auto nullCount = 0;
Expand All @@ -49,32 +49,28 @@ struct BaseListSortOperation {
}

result = common::ListVector::addList(&resultVector, input.size);
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
auto resultPos = result.offset;

// Add nulls first.
if (nullFirst) {
setVectorRangeToNull(*resultDataVector, result.offset, 0, nullCount);
resultValues += numBytesPerValue * nullCount;
resultPos += nullCount;
}

// Add actual data.
for (auto i = 0; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
inputValues += numBytesPerValue;
if (inputDataVector->isNull(inputPos)) {
inputPos++;
continue;
}
resultDataVector->copyFromVectorData(resultValues, inputDataVector, inputValues);
resultValues += numBytesPerValue;
inputValues += numBytesPerValue;
resultDataVector->copyFromVectorData(resultPos++, inputDataVector, inputPos++);
}

// Add nulls in the end.
if (!nullFirst) {
setVectorRangeToNull(
*resultDataVector, result.offset, input.size - nullCount, input.size);
resultValues += numBytesPerValue * nullCount;
}

// Determine the starting and ending position of the data to be sorted.
Expand Down
12 changes: 5 additions & 7 deletions src/include/function/list/operations/list_append_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@ struct ListAppend {
common::list_entry_t& result, common::ValueVector& listVector,
common::ValueVector& valueVector, common::ValueVector& resultVector) {
result = common::ListVector::addList(&resultVector, listEntry.size + 1);
auto listValues = common::ListVector::getListValues(&listVector, listEntry);
auto listDataVector = common::ListVector::getDataVector(&listVector);
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto listPos = listEntry.offset;
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
auto resultPos = result.offset;
for (auto i = 0u; i < listEntry.size; i++) {
resultDataVector->copyFromVectorData(resultValues, listDataVector, listValues);
listValues += numBytesPerValue;
resultValues += numBytesPerValue;
resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++);
}
resultDataVector->copyFromVectorData(
resultValues, &valueVector, reinterpret_cast<uint8_t*>(&value));
resultDataVector->getData() + resultPos * resultDataVector->getNumBytesPerValue(),
&valueVector, reinterpret_cast<uint8_t*>(&value));
}
};

Expand Down
17 changes: 6 additions & 11 deletions src/include/function/list/operations/list_concat_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,17 @@ struct ListConcat {
common::list_entry_t& result, common::ValueVector& leftVector,
common::ValueVector& rightVector, common::ValueVector& resultVector) {
result = common::ListVector::addList(&resultVector, left.size + right.size);
auto leftValues = common::ListVector::getListValues(&leftVector, left);
auto leftDataVector = common::ListVector::getDataVector(&leftVector);
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
auto resultPos = result.offset;
auto leftDataVector = common::ListVector::getDataVector(&leftVector);
auto leftPos = left.offset;
for (auto i = 0u; i < left.size; i++) {
resultDataVector->copyFromVectorData(resultValues, leftDataVector, leftValues);
resultValues += numBytesPerValue;
leftValues += numBytesPerValue;
resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos++);
}
auto rightValues = common::ListVector::getListValues(&rightVector, right);
auto rightDataVector = common::ListVector::getDataVector(&rightVector);
auto rightPos = right.offset;
for (auto i = 0u; i < right.size; i++) {
resultDataVector->copyFromVectorData(resultValues, rightDataVector, rightValues);
resultValues += numBytesPerValue;
rightValues += numBytesPerValue;
resultDataVector->copyFromVectorData(resultPos++, rightDataVector, rightPos++);
}
}
};
Expand Down
15 changes: 6 additions & 9 deletions src/include/function/list/operations/list_prepend_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@ struct ListPrepend {
common::list_entry_t& result, common::ValueVector& valueVector,
common::ValueVector& listVector, common::ValueVector& resultVector) {
result = common::ListVector::addList(&resultVector, listEntry.size + 1);
auto listValues = common::ListVector::getListValues(&listVector, listEntry);
auto listDataVector = common::ListVector::getDataVector(&listVector);
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
resultDataVector->copyFromVectorData(
resultValues, &valueVector, reinterpret_cast<uint8_t*>(&value));
resultValues += numBytesPerValue;
common::ListVector::getListValues(&resultVector, result), &valueVector,
reinterpret_cast<uint8_t*>(&value));
auto resultPos = result.offset + 1;
auto listDataVector = common::ListVector::getDataVector(&listVector);
auto listPos = listEntry.offset;
for (auto i = 0u; i < listEntry.size; i++) {
resultDataVector->copyFromVectorData(resultValues, listDataVector, listValues);
listValues += numBytesPerValue;
resultValues += numBytesPerValue;
resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++);
}
}
};
Expand Down
13 changes: 4 additions & 9 deletions src/include/function/list/operations/list_slice_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,12 @@ struct ListSlice {
int64_t startIdx = (begin == 0) ? 1 : begin;
int64_t endIdx = (end == 0) ? listEntry.size : end;
result = common::ListVector::addList(&resultVector, endIdx - startIdx);
auto srcValues =
common::ListVector::getListValuesWithOffset(&listVector, listEntry, startIdx - 1);
auto dstValues = common::ListVector::getListValues(&resultVector, result);
auto numBytesPerValue =
common::ListVector::getDataVector(&listVector)->getNumBytesPerValue();
auto srcDataVector = common::ListVector::getDataVector(&listVector);
auto srcPos = listEntry.offset + startIdx - 1;
auto dstDataVector = common::ListVector::getDataVector(&resultVector);
for (auto i = startIdx; i < endIdx; i++) {
dstDataVector->copyFromVectorData(dstValues, srcDataVector, srcValues);
srcValues += numBytesPerValue;
dstValues += numBytesPerValue;
auto dstPos = result.offset;
for (auto i = 0u; i < endIdx - startIdx; i++) {
dstDataVector->copyFromVectorData(dstPos++, srcDataVector, srcPos++);
}
}

Expand Down
9 changes: 3 additions & 6 deletions src/include/function/map/operations/map_creation_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,11 @@ struct MapCreation {

static void copyListEntry(common::list_entry_t& resultEntry, common::ValueVector* resultVector,
common::list_entry_t& srcEntry, common::ValueVector* srcVector) {
auto resultValues =
resultVector->getData() + resultVector->getNumBytesPerValue() * resultEntry.offset;
auto srcValues = common::ListVector::getListValues(srcVector, srcEntry);
auto resultPos = resultEntry.offset;
auto srcDataVector = common::ListVector::getDataVector(srcVector);
auto srcPos = srcEntry.offset;
for (auto i = 0u; i < srcEntry.size; i++) {
resultVector->copyFromVectorData(resultValues, srcDataVector, srcValues);
srcValues += srcDataVector->getNumBytesPerValue();
resultValues += resultVector->getNumBytesPerValue();
resultVector->copyFromVectorData(resultPos++, srcDataVector, srcPos++);
}
}
};
Expand Down
8 changes: 3 additions & 5 deletions src/include/function/map/operations/map_extract_operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@ struct MapExtract {
auto mapKeyVector = common::MapVector::getKeyVector(&listVector);
auto mapKeyValues = common::MapVector::getMapKeys(&listVector, listEntry);
auto mapValVector = common::MapVector::getValueVector(&listVector);
auto mapValValues = common::MapVector::getMapValues(&listVector, listEntry);
auto mapValPos = listEntry.offset;
uint8_t comparisonResult;
for (auto i = 0u; i < listEntry.size; i++) {
Equals::operation(*reinterpret_cast<T*>(mapKeyValues), key, comparisonResult,
mapKeyVector, &keyVector);
if (comparisonResult) {
resultEntry = common::ListVector::addList(&resultVector, 1 /* size */);
common::ListVector::getDataVector(&resultVector)
->copyFromVectorData(
common::ListVector::getListValues(&resultVector, resultEntry), mapValVector,
mapValValues);
->copyFromVectorData(resultEntry.offset, mapValVector, mapValPos);
return;
}
mapKeyValues += mapKeyVector->getNumBytesPerValue();
mapValValues += mapValVector->getNumBytesPerValue();
mapValPos++;
}
// If the key is not found, return an empty list.
resultEntry = common::ListVector::addList(&resultVector, 0 /* size */);
Expand Down
10 changes: 3 additions & 7 deletions src/processor/operator/unwind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@ bool Unwind::hasMoreToRead() const {
}

void Unwind::copyTuplesToOutVector(uint64_t startPos, uint64_t endPos) const {
auto listValues = common::ListVector::getListValuesWithOffset(
expressionEvaluator->resultVector.get(), listEntry, startPos);
auto listDataVector =
common::ListVector::getDataVector(expressionEvaluator->resultVector.get());
for (auto pos = startPos; pos < endPos; pos++) {
outValueVector->copyFromVectorData(
outValueVector->getData() + outValueVector->getNumBytesPerValue() * (pos - startPos),
listDataVector, listValues);
listValues += listDataVector->getNumBytesPerValue();
auto listPos = listEntry.offset + startPos;
for (auto i = 0u; i < endPos - startPos; i++) {
outValueVector->copyFromVectorData(i, listDataVector, listPos++);
}
}

Expand Down
Loading

0 comments on commit 7845162

Please sign in to comment.