Skip to content

Commit

Permalink
copy: use string_view to avoid a copy
Browse files Browse the repository at this point in the history
  • Loading branch information
Riolku committed Sep 29, 2023
1 parent 8a32704 commit d4b0f19
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 97 deletions.
2 changes: 1 addition & 1 deletion src/include/common/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class StringCastUtils {
// Converts the character buffer (data, length) to a value of type T by delegating to
// tryCastToNum. Returns the converted value; throws ConversionException with an
// "Invalid number: ..." message when tryCastToNum returns false.
static T castToNum(const char* data, uint64_t length) {
T result;
if (!tryCastToNum(data, length, result)) {
// NOTE(review): this hunk appears to show both sides of the change (1 addition, 1 deletion).
// The line below is the older form: std::string{data} copies up to the first NUL byte and
// ignores `length`, so it is only safe when `data` is NUL-terminated.
throw ConversionException{"Invalid number: " + std::string{data} + "."};
// Newer form: the message is built from exactly `length` bytes, so `data` does not need a
// NUL terminator (e.g. when it points into a std::string_view).
throw ConversionException{"Invalid number: " + std::string{data, length} + "."};
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <string>
#include <string_view>

#include "common/copier_config/copier_config.h"
#include "common/data_chunk/data_chunk.h"
Expand All @@ -28,7 +29,7 @@ class BaseCSVReader {
uint64_t countRows();

protected:
void addValue(common::DataChunk& resultChunk, std::string strVal, common::column_id_t columnIdx,
void addValue(common::DataChunk&, std::string_view, common::column_id_t columnIdx,
std::vector<uint64_t>& escapePositions);
void addRow(common::DataChunk&, common::column_id_t column);

Expand Down Expand Up @@ -63,7 +64,7 @@ class BaseCSVReader {
virtual void handleQuotedNewline() = 0;

private:
void copyStringToVector(common::ValueVector*, std::string);
void copyStringToVector(common::ValueVector*, std::string_view);
//! Called after a row is finished to determine if we should keep processing.
inline bool finishedBlock() {
return mode != ParserMode::PARSING || rowToAdd >= common::DEFAULT_VECTOR_CAPACITY ||
Expand Down
25 changes: 12 additions & 13 deletions src/include/storage/store/table_copy_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,29 @@ struct StructFieldIdxAndValue {
class TableCopyUtils {
public:
static void throwCopyExceptionIfNotOK(const arrow::Status& status);
static std::unique_ptr<common::Value> getVarListValue(const std::string& l, int64_t from,
static std::unique_ptr<common::Value> getVarListValue(std::string_view l, int64_t from,
int64_t to, const common::LogicalType& dataType,
const common::CSVReaderConfig& csvReaderConfig);
static std::unique_ptr<common::Value> getArrowFixedListVal(const std::string& l, int64_t from,
static std::unique_ptr<common::Value> getArrowFixedListVal(std::string_view l, int64_t from,
int64_t to, const common::LogicalType& dataType,
const common::CSVReaderConfig& csvReaderConfig);
static std::unique_ptr<uint8_t[]> getArrowFixedList(const std::string& l, int64_t from,
static std::unique_ptr<uint8_t[]> getArrowFixedList(std::string_view l, int64_t from,
int64_t to, const common::LogicalType& dataType,
const common::CSVReaderConfig& csvReaderConfig);
static std::shared_ptr<arrow::csv::StreamingReader> createRelTableCSVReader(
const std::string& filePath, const common::ReaderConfig& config);
static std::unique_ptr<parquet::arrow::FileReader> createParquetReader(
const std::string& filePath, const common::ReaderConfig& config);

static std::vector<std::pair<int64_t, int64_t>> splitByDelimiter(const std::string& l,
static std::vector<std::pair<int64_t, int64_t>> splitByDelimiter(std::string_view l,
int64_t from, int64_t to, const common::CSVReaderConfig& csvReaderConfig);

static std::shared_ptr<arrow::DataType> toArrowDataType(const common::LogicalType& dataType);

static bool tryCast(const common::LogicalType& targetType, const char* value, uint64_t length);

static std::vector<StructFieldIdxAndValue> parseStructFieldNameAndValues(
common::LogicalType& type, const std::string& structString,
common::LogicalType& type, std::string_view structString,
const common::CSVReaderConfig& csvReaderConfig);

static std::unique_ptr<arrow::PrimitiveArray> createArrowPrimitiveArray(
Expand All @@ -67,20 +67,19 @@ class TableCopyUtils {
uint64_t length);

private:
static std::unique_ptr<common::Value> convertStringToValue(std::string element,
static std::unique_ptr<common::Value> convertStringToValue(std::string_view element,
const common::LogicalType& type, const common::CSVReaderConfig& csvReaderConfig);

static void validateNumElementsInList(
uint64_t numElementsRead, const common::LogicalType& type);
static std::unique_ptr<common::Value> parseVarList(const std::string& l, int64_t from,
int64_t to, const common::LogicalType& dataType,
const common::CSVReaderConfig& csvReaderConfig);
static std::unique_ptr<common::Value> parseMap(const std::string& l, int64_t from, int64_t to,
static std::unique_ptr<common::Value> parseVarList(std::string_view l, int64_t from, int64_t to,
const common::LogicalType& dataType, const common::CSVReaderConfig& csvReaderConfig);
static std::unique_ptr<common::Value> parseMap(std::string_view l, int64_t from, int64_t to,
const common::LogicalType& dataType, const common::CSVReaderConfig& csvReaderConfig);
static std::pair<std::string, std::string> parseMapFields(const std::string& l, int64_t from,
static std::pair<std::string, std::string> parseMapFields(std::string_view l, int64_t from,
int64_t length, const common::CSVReaderConfig& csvReaderConfig);
static std::string parseStructFieldName(const std::string& structString, uint64_t& curPos);
static std::string parseStructFieldValue(const std::string& structString, uint64_t& curPos,
static std::string parseStructFieldName(std::string_view structString, uint64_t& curPos);
static std::string parseStructFieldValue(std::string_view structString, uint64_t& curPos,
const common::CSVReaderConfig& csvReaderConfig);
};

Expand Down
59 changes: 30 additions & 29 deletions src/processor/operator/persistent/reader/csv/base_csv_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ uint64_t BaseCSVReader::countRows() {
goto in_quotes;
}

void BaseCSVReader::addValue(DataChunk& resultChunk, std::string strVal,
void BaseCSVReader::addValue(DataChunk& resultChunk, std::string_view strVal,
const column_id_t columnIdx, std::vector<uint64_t>& escapePositions) {
if (mode == ParserMode::PARSING_HEADER) {
return;
Expand All @@ -163,22 +163,23 @@ void BaseCSVReader::addValue(DataChunk& resultChunk, std::string strVal,
getLineNumber(), expectedNumColumns));
}

ValueVector* destination_vector = resultChunk.getValueVector(columnIdx).get();
// insert the line number into the chunk
if (!escapePositions.empty()) {
// remove escape characters (if any)
std::string oldVal = strVal;
std::string newVal = "";
uint64_t prevPos = 0;
for (auto i = 0u; i < escapePositions.size(); i++) {
auto nextPos = escapePositions[i];
newVal += oldVal.substr(prevPos, nextPos - prevPos);
newVal += strVal.substr(prevPos, nextPos - prevPos);
prevPos = nextPos + 1;
}
newVal += oldVal.substr(prevPos, oldVal.size() - prevPos);
newVal += strVal.substr(prevPos, strVal.size() - prevPos);
escapePositions.clear();
strVal = newVal;
copyStringToVector(destination_vector, std::string_view(newVal.begin(), newVal.end()));
} else {
copyStringToVector(destination_vector, strVal);
}
copyStringToVector(resultChunk.getValueVector(columnIdx).get(), std::move(strVal));
}

void BaseCSVReader::addRow(DataChunk& resultChunk, column_id_t column) {
Expand Down Expand Up @@ -206,7 +207,7 @@ void BaseCSVReader::addRow(DataChunk& resultChunk, column_id_t column) {
}
}

void BaseCSVReader::copyStringToVector(common::ValueVector* vector, std::string strVal) {
void BaseCSVReader::copyStringToVector(common::ValueVector* vector, std::string_view strVal) {
auto& type = vector->dataType;
if (strVal.empty()) {
vector->setNull(rowToAdd, true /* isNull */);
Expand All @@ -217,68 +218,68 @@ void BaseCSVReader::copyStringToVector(common::ValueVector* vector, std::string
switch (type.getLogicalTypeID()) {
case LogicalTypeID::INT64: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<int64_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<int64_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::INT32: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<int32_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<int32_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::INT16: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<int16_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<int16_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::INT8: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<int8_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<int8_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::UINT64: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<uint64_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<uint64_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::UINT32: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<uint32_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<uint32_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::UINT16: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<uint16_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<uint16_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::UINT8: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<uint8_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<uint8_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::FLOAT: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<float_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<float_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::DOUBLE: {
vector->setValue(
rowToAdd, StringCastUtils::castToNum<double_t>(strVal.c_str(), strVal.length()));
rowToAdd, StringCastUtils::castToNum<double_t>(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::BOOL: {
vector->setValue(rowToAdd, StringCastUtils::castToBool(strVal.c_str(), strVal.length()));
vector->setValue(rowToAdd, StringCastUtils::castToBool(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::BLOB: {
if (strVal.length() > BufferPoolConstants::PAGE_4KB_SIZE) {
throw CopyException(
ExceptionMessage::overLargeStringValueException(std::to_string(strVal.length())));
}
auto blobBuffer = std::make_unique<uint8_t[]>(strVal.length());
auto blobLen = Blob::fromString(strVal.c_str(), strVal.length(), blobBuffer.get());
auto blobLen = Blob::fromString(strVal.data(), strVal.length(), blobBuffer.get());
StringVector::addString(
vector, rowToAdd, reinterpret_cast<char*>(blobBuffer.get()), blobLen);
} break;
case LogicalTypeID::STRING: {
StringVector::addString(vector, rowToAdd, strVal.c_str(), strVal.length());
StringVector::addString(vector, rowToAdd, strVal.data(), strVal.length());
} break;
case LogicalTypeID::DATE: {
vector->setValue(rowToAdd, Date::fromCString(strVal.c_str(), strVal.length()));
vector->setValue(rowToAdd, Date::fromCString(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::TIMESTAMP: {
vector->setValue(rowToAdd, Timestamp::fromCString(strVal.c_str(), strVal.length()));
vector->setValue(rowToAdd, Timestamp::fromCString(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::INTERVAL: {
vector->setValue(rowToAdd, Interval::fromCString(strVal.c_str(), strVal.length()));
vector->setValue(rowToAdd, Interval::fromCString(strVal.data(), strVal.length()));
} break;
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST: {
Expand Down Expand Up @@ -306,7 +307,7 @@ void BaseCSVReader::copyStringToVector(common::ValueVector* vector, std::string
for (auto i = 0u; i < UnionType::getNumFields(&type); i++) {
auto internalFieldIdx = UnionType::getInternalFieldIdx(i);
if (storage::TableCopyUtils::tryCast(
*UnionType::getFieldType(&type, i), strVal.c_str(), strVal.length())) {
*UnionType::getFieldType(&type, i), strVal.data(), strVal.length())) {
StructVector::getFieldVector(vector, internalFieldIdx)
->setNull(rowToAdd, false /* isNull */);
copyStringToVector(
Expand Down Expand Up @@ -440,8 +441,8 @@ uint64_t BaseCSVReader::parseCSV(DataChunk& resultChunk) {
// We get here after we have a delimiter.
assert(buffer[position] == csvReaderConfig.delimiter);
// Trim one character if we have quotes.
addValue(resultChunk, std::string(buffer.get() + start, position - start - hasQuotes), column,
escapePositions);
addValue(resultChunk, std::string_view(buffer.get() + start, position - start - hasQuotes),
column, escapePositions);
column++;

// Move past the delimiter.
Expand All @@ -458,8 +459,8 @@ add_row : {
// We get here after we have a newline.
assert(isNewLine(buffer[position]));
bool isCarriageReturn = buffer[position] == '\r';
addValue(resultChunk, std::string(buffer.get() + start, position - start - hasQuotes), column,
escapePositions);
addValue(resultChunk, std::string_view(buffer.get() + start, position - start - hasQuotes),
column, escapePositions);
column++;

addRow(resultChunk, column);
Expand Down Expand Up @@ -571,7 +572,7 @@ add_row : {
// If we were mid-value, add the remaining value to the chunk.
if (position > start) {
// Add remaining value to chunk.
addValue(resultChunk, std::string(buffer.get() + start, position - start - hasQuotes),
addValue(resultChunk, std::string_view(buffer.get() + start, position - start - hasQuotes),
column, escapePositions);
column++;
}
Expand Down
Loading

0 comments on commit d4b0f19

Please sign in to comment.