Skip to content

Commit

Permalink
Reader function refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Aug 29, 2023
1 parent 440eb0a commit 05af99b
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 213 deletions.
48 changes: 48 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,54 @@ void ValueVector::copyFromVectorData(
}
}

void ValueVector::copyFromValue(uint64_t pos, const Value& value) {
auto dstValue = valueBuffer.get() + pos * numBytesPerValue;
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::INT64: {
memcpy(dstValue, &value.val.int64Val, numBytesPerValue);
} break;
case PhysicalTypeID::INT32: {
memcpy(dstValue, &value.val.int32Val, numBytesPerValue);
} break;
case PhysicalTypeID::INT16: {
memcpy(dstValue, &value.val.int16Val, numBytesPerValue);
} break;
case PhysicalTypeID::DOUBLE: {
memcpy(dstValue, &value.val.doubleVal, numBytesPerValue);
} break;
case PhysicalTypeID::FLOAT: {
memcpy(dstValue, &value.val.floatVal, numBytesPerValue);
} break;
case PhysicalTypeID::BOOL: {
memcpy(dstValue, &value.val.booleanVal, numBytesPerValue);
} break;
case PhysicalTypeID::INTERVAL: {
memcpy(dstValue, &value.val.intervalVal, numBytesPerValue);
} break;
case PhysicalTypeID::STRING: {
StringVector::addString(
this, *(ku_string_t*)dstValue, value.strVal.data(), value.strVal.length());
} break;
case PhysicalTypeID::VAR_LIST: {
auto listEntry = reinterpret_cast<list_entry_t*>(dstValue);
auto numValues = NestedVal::getChildrenSize(&value);
*listEntry = ListVector::addList(this, numValues);
auto dstDataVector = ListVector::getDataVector(this);
for (auto i = 0u; i < numValues; ++i) {
dstDataVector->copyFromValue(listEntry->offset + i, *NestedVal::getChildVal(&value, i));
}
} break;
case PhysicalTypeID::STRUCT: {
auto structFields = StructVector::getFieldVectors(this);
for (auto i = 0u; i < structFields.size(); ++i) {
structFields[i]->copyFromValue(pos, *NestedVal::getChildVal(&value, i));

Check warning on line 166 in src/common/vector/value_vector.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/vector/value_vector.cpp#L164-L166

Added lines #L164 - L166 were not covered by tests
}
} break;
default:
throw NotImplementedException("ValueVector::copyFromValue");

Check warning on line 170 in src/common/vector/value_vector.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/vector/value_vector.cpp#L168-L170

Added lines #L168 - L170 were not covered by tests
}
}

void ValueVector::resetAuxiliaryBuffer() {
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::STRING: {
Expand Down
50 changes: 2 additions & 48 deletions src/expression_evaluator/literal_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,11 @@ bool LiteralExpressionEvaluator::select(SelectionVector& selVector) {
void LiteralExpressionEvaluator::resolveResultVector(
const processor::ResultSet& resultSet, MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(*value->getDataType(), memoryManager);
resultVector->setState(DataChunkState::getSingleValueDataChunkState());
if (value->isNull()) {
resultVector->setNull(0 /* pos */, true);
} else {
copyValueToVector(resultVector->getData(), resultVector.get(), value.get());
}
resultVector->setState(DataChunkState::getSingleValueDataChunkState());
}

void LiteralExpressionEvaluator::copyValueToVector(
uint8_t* dstValue, ValueVector* dstVector, const Value* srcValue) {
auto numBytesPerValue = dstVector->getNumBytesPerValue();
switch (srcValue->getDataType()->getPhysicalType()) {
case PhysicalTypeID::INT64: {
memcpy(dstValue, &srcValue->val.int64Val, numBytesPerValue);
} break;
case PhysicalTypeID::INT32: {
memcpy(dstValue, &srcValue->val.int32Val, numBytesPerValue);
} break;
case PhysicalTypeID::INT16: {
memcpy(dstValue, &srcValue->val.int16Val, numBytesPerValue);
} break;
case PhysicalTypeID::DOUBLE: {
memcpy(dstValue, &srcValue->val.doubleVal, numBytesPerValue);
} break;
case PhysicalTypeID::FLOAT: {
memcpy(dstValue, &srcValue->val.floatVal, numBytesPerValue);
} break;
case PhysicalTypeID::BOOL: {
memcpy(dstValue, &srcValue->val.booleanVal, numBytesPerValue);
} break;
case PhysicalTypeID::INTERVAL: {
memcpy(dstValue, &srcValue->val.intervalVal, numBytesPerValue);
} break;
case PhysicalTypeID::STRING: {
StringVector::addString(
dstVector, *(ku_string_t*)dstValue, srcValue->strVal.data(), srcValue->strVal.length());
} break;
case PhysicalTypeID::VAR_LIST: {
auto listListEntry = reinterpret_cast<list_entry_t*>(dstValue);
auto numValues = NestedVal::getChildrenSize(srcValue);
*listListEntry = ListVector::addList(dstVector, numValues);
auto dstDataVector = ListVector::getDataVector(dstVector);
auto dstElements = ListVector::getListValues(dstVector, *listListEntry);
for (auto i = 0u; i < numValues; ++i) {
copyValueToVector(dstElements + i * dstDataVector->getNumBytesPerValue(), dstDataVector,
NestedVal::getChildVal(srcValue, i));
}
} break;
default:
throw NotImplementedException("Unimplemented setLiteral() for type " +
LogicalTypeUtils::dataTypeToString(dstVector->dataType));
resultVector->copyFromValue(resultVector->state->selVector->selectedPositions[0], *value);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace common {

class FileInfo;

using sel_t = uint16_t;
using sel_t = uint32_t;
using hash_t = uint64_t;
using page_idx_t = uint32_t;
using frame_idx_t = page_idx_t;
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class ValueVector {
uint8_t* dstData, const ValueVector* srcVector, const uint8_t* srcVectorData);
void copyFromVectorData(uint64_t dstPos, const ValueVector* srcVector, uint64_t srcPos);

void copyFromValue(uint64_t pos, const Value& value);

inline uint8_t* getData() const { return valueBuffer.get(); }

inline offset_t readNodeOffset(uint32_t pos) const {
Expand Down
4 changes: 0 additions & 4 deletions src/include/expression_evaluator/literal_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ class LiteralExpressionEvaluator : public ExpressionEvaluator {
void resolveResultVector(
const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) override;

private:
static void copyValueToVector(
uint8_t* dstValue, common::ValueVector* dstVector, const common::Value* srcValue);

private:
std::shared_ptr<common::Value> value;
};
Expand Down
15 changes: 8 additions & 7 deletions src/include/processor/operator/persistent/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class Reader : public PhysicalOperator {
readerInfo)},
sharedState{std::move(sharedState)}, leftNumRows{0}, readFuncData{nullptr} {}

inline void initGlobalStateInternal(ExecutionContext* context) final {
sharedState->validate();
sharedState->countBlocks();
}
void initGlobalStateInternal(ExecutionContext* context) final;

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) final;

inline bool isSource() const final { return true; }

inline std::unique_ptr<PhysicalOperator> clone() final {
Expand All @@ -36,13 +36,14 @@ class Reader : public PhysicalOperator {
bool getNextTuplesInternal(ExecutionContext* context) final;

private:
void getNextNodeGroupInSerial(std::shared_ptr<arrow::Table>& table);
void getNextNodeGroupInParallel(std::shared_ptr<arrow::Table>& table);
void getNextNodeGroupInSerial();
void getNextNodeGroupInParallel();

private:
ReaderInfo readerInfo;
std::shared_ptr<storage::ReaderSharedState> sharedState;
std::vector<std::shared_ptr<arrow::RecordBatch>> leftRecordBatches;
std::vector<common::ValueVector*> vectorsToRead;
std::vector<arrow::ArrayVector> leftArrays;
common::row_idx_t leftNumRows;

// For parallel reading.
Expand Down
42 changes: 23 additions & 19 deletions src/include/storage/copier/reader_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,9 @@ struct ReaderMorsel {
};

struct SerialReaderMorsel : public ReaderMorsel {
SerialReaderMorsel(common::vector_idx_t fileIdx, common::block_idx_t blockIdx,
common::row_idx_t rowIdx, std::shared_ptr<arrow::Table> table)
: ReaderMorsel{fileIdx, blockIdx, rowIdx}, table{std::move(table)} {}

std::shared_ptr<arrow::Table> table;
SerialReaderMorsel(
common::vector_idx_t fileIdx, common::block_idx_t blockIdx, common::row_idx_t rowIdx)
: ReaderMorsel{fileIdx, blockIdx, rowIdx} {}
};

using validate_func_t =
Expand All @@ -84,8 +82,8 @@ using init_reader_data_func_t = std::function<std::unique_ptr<ReaderFunctionData
using count_blocks_func_t =
std::function<std::vector<FileBlocksInfo>(std::vector<std::string>& paths,
common::CSVReaderConfig csvReaderConfig, catalog::TableSchema* tableSchema)>;
using read_rows_func_t = std::function<arrow::RecordBatchVector(
const ReaderFunctionData& functionData, common::block_idx_t blockIdx)>;
using read_rows_func_t = std::function<void(const ReaderFunctionData& functionData,
common::block_idx_t blockIdx, std::vector<common::ValueVector*>)>;

struct ReaderFunctions {
static validate_func_t getValidateFunc(common::CopyDescription::FileType fileType);
Expand Down Expand Up @@ -121,12 +119,12 @@ struct ReaderFunctions {
common::vector_idx_t fileIdx, common::CSVReaderConfig csvReaderConfig,
catalog::TableSchema* tableSchema);

static arrow::RecordBatchVector readRowsFromCSVFile(
const ReaderFunctionData& functionData, common::block_idx_t blockIdx);
static arrow::RecordBatchVector readRowsFromParquetFile(
const ReaderFunctionData& functionData, common::block_idx_t blockIdx);
static arrow::RecordBatchVector readRowsFromNPYFile(
const ReaderFunctionData& functionData, common::block_idx_t blockIdx);
static void readRowsFromCSVFile(const ReaderFunctionData& functionData,
common::block_idx_t blockIdx, std::vector<common::ValueVector*> vectorsToRead);
static void readRowsFromParquetFile(const ReaderFunctionData& functionData,
common::block_idx_t blockIdx, std::vector<common::ValueVector*> vectorsToRead);
static void readRowsFromNPYFile(const ReaderFunctionData& functionData,
common::block_idx_t blockIdx, std::vector<common::ValueVector*> vectorsToRead);
};

class ReaderSharedState {
Expand All @@ -138,7 +136,7 @@ class ReaderSharedState {
catalog::TableSchema* tableSchema)
: fileType{fileType}, filePaths{std::move(filePaths)}, csvReaderConfig{csvReaderConfig},
tableSchema{tableSchema}, numRows{0}, currFileIdx{0}, currBlockIdx{0}, currRowIdx{0},
leftRecordBatches{}, leftNumRows{0} {
leftNumRows{0} {
validateFunc = ReaderFunctions::getValidateFunc(fileType);
initFunc = ReaderFunctions::getInitDataFunc(fileType);
countBlocksFunc = ReaderFunctions::getCountBlocksFunc(fileType);
Expand All @@ -148,18 +146,24 @@ class ReaderSharedState {
void validate();
void countBlocks();

std::unique_ptr<ReaderMorsel> getSerialMorsel();
std::unique_ptr<ReaderMorsel> getSerialMorsel(std::vector<common::ValueVector*> vectorsToRead);
std::unique_ptr<ReaderMorsel> getParallelMorsel();

inline void lock() { mtx.lock(); }
inline void unlock() { mtx.unlock(); }
inline common::row_idx_t& getNumRowsRef() { return std::ref(numRows); }

static void appendToArrayVectors(std::vector<arrow::ArrayVector>& arrayVectors,
std::vector<common::ValueVector*>& vectorsToAppend, uint64_t& numRows);

static void appendArrowArraysToVectors(std::vector<arrow::ArrayVector>& arrayVectorsToAppend,
std::vector<common::ValueVector*> vectors, uint64_t numRowsToAppend);

private:
std::unique_ptr<ReaderMorsel> getMorselOfNextBlock();
static arrow::ArrayVector extractArrayVectorToAppend(
arrow::ArrayVector& arrayVectorToExtract, uint64_t numRowsToAppend);

static std::shared_ptr<arrow::Table> constructTableFromBatches(
std::vector<std::shared_ptr<arrow::RecordBatch>>& recordBatches);
std::unique_ptr<ReaderMorsel> getMorselOfNextBlock();

public:
std::mutex mtx;
Expand All @@ -182,8 +186,8 @@ class ReaderSharedState {
common::block_idx_t currBlockIdx;
common::row_idx_t currRowIdx;

std::vector<std::shared_ptr<arrow::RecordBatch>> leftRecordBatches;
common::row_idx_t leftNumRows;
std::vector<arrow::ArrayVector> leftArrays;
};

} // namespace storage
Expand Down
14 changes: 10 additions & 4 deletions src/include/storage/copier/rel_copier.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ class RelCopier {
static void indexLookup(arrow::Array* pkArray, const common::LogicalType& pkColumnType,
PrimaryKeyIndex* pkIndex, common::offset_t* offsets);

void copyRelColumnsOrCountRelListsSize(common::row_idx_t rowIdx,
arrow::RecordBatch* recordBatch, common::RelDataDirection direction,
void copyRelColumnsOrCountRelListsSize(common::row_idx_t rowIdx, arrow::ArrayVector& arrays,
common::RelDataDirection direction,
const std::vector<std::unique_ptr<arrow::Array>>& pkOffsets);

void copyRelColumns(common::row_idx_t rowIdx, arrow::RecordBatch* recordBatch,
void copyRelColumns(common::row_idx_t rowIdx, arrow::ArrayVector& arrays,
common::RelDataDirection direction,
const std::vector<std::unique_ptr<arrow::Array>>& pkOffsets);
void countRelListsSize(common::RelDataDirection direction,
const std::vector<std::unique_ptr<arrow::Array>>& pkOffsets);
void copyRelLists(common::row_idx_t rowIdx, arrow::RecordBatch* recordBatch,
void copyRelLists(common::row_idx_t rowIdx, arrow::ArrayVector& arrays,
common::RelDataDirection direction,
const std::vector<std::unique_ptr<arrow::Array>>& pkOffsets);
void checkViolationOfRelColumn(
Expand All @@ -84,6 +84,12 @@ class RelCopier {
std::unique_ptr<storage::ReaderFunctionData> readFuncData;
storage::read_rows_func_t readFunc;
storage::init_reader_data_func_t initFunc;

protected:
std::vector<common::ValueVector*> vectorsToRead;

private:
std::vector<std::unique_ptr<common::ValueVector>> arrowVectors;
};

class RelListsCounterAndColumnCopier : public RelCopier {
Expand Down
50 changes: 26 additions & 24 deletions src/processor/operator/persistent/reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,36 @@ using namespace kuzu::storage;
namespace kuzu {
namespace processor {

bool Reader::getNextTuplesInternal(ExecutionContext* context) {
std::shared_ptr<arrow::Table> table = nullptr;
readerInfo.isOrderPreserving ? getNextNodeGroupInSerial(table) :
getNextNodeGroupInParallel(table);
if (table == nullptr) {
return false;
}
for (auto i = 0u; i < readerInfo.dataColumnPoses.size(); i++) {
ArrowColumnVector::setArrowColumn(
resultSet->getValueVector(readerInfo.dataColumnPoses[i]).get(), table->column((int)i));
void Reader::initGlobalStateInternal(ExecutionContext* context) {
sharedState->validate();
sharedState->countBlocks();
}

void Reader::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
vectorsToRead.reserve(readerInfo.dataColumnPoses.size());
for (auto& dataPos : readerInfo.dataColumnPoses) {
auto valueVector = resultSet->getValueVector(dataPos);
vectorsToRead.push_back(resultSet->getValueVector(dataPos).get());
}
return true;
leftArrays.resize(vectorsToRead.size());
}

void Reader::getNextNodeGroupInSerial(std::shared_ptr<arrow::Table>& table) {
auto morsel = sharedState->getSerialMorsel();
bool Reader::getNextTuplesInternal(ExecutionContext* context) {
readerInfo.isOrderPreserving ? getNextNodeGroupInSerial() : getNextNodeGroupInParallel();
return vectorsToRead[0]->state->selVector->selectedSize != 0;
}

void Reader::getNextNodeGroupInSerial() {
auto morsel = sharedState->getSerialMorsel(vectorsToRead);
if (morsel->fileIdx == INVALID_VECTOR_IDX) {
return;
}
auto serialMorsel = reinterpret_cast<SerialReaderMorsel*>(morsel.get());
table = serialMorsel->table;
auto nodeOffsetVector = resultSet->getValueVector(readerInfo.nodeOffsetPos).get();
nodeOffsetVector->setValue(
nodeOffsetVector->state->selVector->selectedPositions[0], morsel->rowIdx);
}

void Reader::getNextNodeGroupInParallel(std::shared_ptr<arrow::Table>& table) {
void Reader::getNextNodeGroupInParallel() {
while (leftNumRows < StorageConstants::NODE_GROUP_SIZE) {
auto morsel = sharedState->getParallelMorsel();
if (morsel->fileIdx == INVALID_VECTOR_IDX) {
Expand All @@ -45,17 +48,16 @@ void Reader::getNextNodeGroupInParallel(std::shared_ptr<arrow::Table>& table) {
readFuncData = readerInfo.initFunc(sharedState->filePaths, morsel->fileIdx,
sharedState->csvReaderConfig, sharedState->tableSchema);
}
auto batchVector = readerInfo.readFunc(*readFuncData, morsel->blockIdx);
for (auto& batch : batchVector) {
leftNumRows += batch->num_rows();
leftRecordBatches.push_back(std::move(batch));
}
readerInfo.readFunc(*readFuncData, morsel->blockIdx, vectorsToRead);
ReaderSharedState::appendToArrayVectors(leftArrays, vectorsToRead, leftNumRows);
}
if (leftNumRows == 0) {
return;
vectorsToRead[0]->state->selVector->selectedSize = 0;
} else {
int64_t numRowsToReturn = std::min(leftNumRows, StorageConstants::NODE_GROUP_SIZE);
ReaderSharedState::appendArrowArraysToVectors(leftArrays, vectorsToRead, numRowsToReturn);
leftNumRows -= numRowsToReturn;
}
table = ReaderSharedState::constructTableFromBatches(leftRecordBatches);
leftNumRows -= table->num_rows();
}

} // namespace processor
Expand Down
Loading

0 comments on commit 05af99b

Please sign in to comment.