Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add setValue and getValue for ValueVector #1045

Merged
merged 1 commit into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions src/common/include/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,10 @@ class ValueVector {

~ValueVector() = default;

inline void setState(shared_ptr<DataChunkState> state_) { state = move(state_); }
void addString(uint64_t pos, string value) const;
void addString(uint64_t pos, char* value, uint64_t len) const;
inline void setState(shared_ptr<DataChunkState> state_) { state = std::move(state_); }

// Bulk null-mask operations; each one simply forwards to the same-named method on nullMask.
// Forwards to nullMask->setAllNull().
inline void setAllNull() { nullMask->setAllNull(); }

// Forwards to nullMask->setAllNonNull().
inline void setAllNonNull() { nullMask->setAllNonNull(); }

// Forwards to nullMask->setMayContainNulls().
inline void setMayContainNulls() { nullMask->setMayContainNulls(); }
// Note that if this function returns true, there are no nulls. However, if it returns false, it
// doesn't mean there are nulls, i.e., there may or may not be nulls.
Expand All @@ -40,31 +36,35 @@ class ValueVector {
assert(!state->isFlat());
return nullMask->hasNoNullsGuarantee();
}

inline void setRangeNonNull(uint64_t startPos, uint64_t len) {
inline void setRangeNonNull(uint32_t startPos, uint32_t len) {
for (auto i = 0u; i < len; ++i) {
setNull(startPos + i, false);
}
}

// Returns the raw uint64_t data pointer exposed by the null mask (nullMask->getData()).
inline uint64_t* getNullMaskData() { return nullMask->getData(); }

inline void setNull(uint64_t pos, bool isNull) { nullMask->setNull(pos, isNull); }

inline void setNull(uint32_t pos, bool isNull) { nullMask->setNull(pos, isNull); }
// Returns what the null mask reports for position pos. The result is passed through as a uint8_t
// rather than a bool.
inline uint8_t isNull(uint32_t pos) const { return nullMask->isNull(pos); }

// Returns numBytesPerValue, which the constructor sets to Types::getDataTypeSize(dataType).
inline uint32_t getNumBytesPerValue() const { return numBytesPerValue; }

inline node_offset_t readNodeOffset(uint64_t pos) const {
// Reads the element at index pos, treating the value buffer as a contiguous array of T, and
// returns it by value. No bounds or type check is performed here: the caller is responsible for
// choosing a T that matches what was stored and a pos that lies inside the buffer.
template<typename T>
inline T getValue(uint32_t pos) const {
    auto* typedValues = reinterpret_cast<T*>(valueBuffer.get());
    return typedValues[pos];
}
// Stores val at position pos. Only the declaration lives in this header: the generic definition,
// a specialization for string, and the explicit instantiations for the supported value types are
// all in value_vector.cpp, so a call with a type that is not instantiated there is expected to
// fail at link time.
template<typename T>
void setValue(uint32_t pos, T val);

// Returns the raw byte pointer to the start of the value buffer (valueBuffer.get()).
inline uint8_t* getData() const { return valueBuffer.get(); }

inline node_offset_t readNodeOffset(uint32_t pos) const {
assert(dataType.typeID == NODE_ID);
return ((nodeID_t*)values)[pos].offset;
return getValue<nodeID_t>(pos).offset;
}

// Marks this vector as sequential by setting the _isSequential flag to true.
inline void setSequential() { _isSequential = true; }
// Returns the current value of the _isSequential flag.
inline bool isSequential() const { return _isSequential; }

// Returns a reference to the overflow buffer by dereferencing inMemOverflowBuffer without a null
// check. The constructor only allocates that buffer when needOverflowBuffer() is true.
// NOTE(review): callers presumably must only call this on vectors that own an overflow buffer —
// confirm.
inline InMemOverflowBuffer& getOverflowBuffer() const { return *inMemOverflowBuffer; }

inline void resetOverflowBuffer() const {
if (inMemOverflowBuffer) {
inMemOverflowBuffer->resetBuffer();
Expand All @@ -77,9 +77,10 @@ class ValueVector {
dataType.typeID == UNSTRUCTURED;
}

// Writes the string given by (value, len) into position pos. Defined in value_vector.cpp, where it
// asserts that dataType is STRING and delegates the copy into the ku_string_t slot to
// InMemOverflowBufferUtils::copyString, passing this vector's overflow buffer.
void addString(uint32_t pos, char* value, uint64_t len) const;

public:
DataType dataType;
uint8_t* values;
shared_ptr<DataChunkState> state;

private:
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/include/ku_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct ku_list_t {
private:
friend class InMemOverflowBufferUtils;

void set(const vector<uint8_t*>& parameters, DataTypeID childTypeId);
void set(const std::vector<uint8_t*>& parameters, DataTypeID childTypeId);

public:
uint64_t size;
Expand Down
47 changes: 23 additions & 24 deletions src/common/types/include/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
#include <string>
#include <vector>

using namespace std;

namespace kuzu {
namespace common {

Expand Down Expand Up @@ -60,22 +58,23 @@ class DataType {
explicit DataType(DataTypeID typeID) : typeID{typeID}, childType{nullptr} {
assert(typeID != LIST);
}
DataType(DataTypeID typeID, unique_ptr<DataType> childType)
: typeID{typeID}, childType{move(childType)} {
DataType(DataTypeID typeID, std::unique_ptr<DataType> childType)
: typeID{typeID}, childType{std::move(childType)} {
assert(typeID == LIST);
}

DataType(const DataType& other);
DataType(DataType&& other) noexcept : typeID{other.typeID}, childType{move(other.childType)} {}
DataType(DataType&& other) noexcept
: typeID{other.typeID}, childType{std::move(other.childType)} {}

static inline vector<DataTypeID> getNumericalTypeIDs() {
return vector<DataTypeID>{INT64, DOUBLE};
static inline std::vector<DataTypeID> getNumericalTypeIDs() {
return std::vector<DataTypeID>{INT64, DOUBLE};
}
static inline vector<DataTypeID> getNumericalAndUnstructuredTypeIDs() {
return vector<DataTypeID>{INT64, DOUBLE, UNSTRUCTURED};
static inline std::vector<DataTypeID> getNumericalAndUnstructuredTypeIDs() {
return std::vector<DataTypeID>{INT64, DOUBLE, UNSTRUCTURED};
}
static inline vector<DataTypeID> getAllValidTypeIDs() {
return vector<DataTypeID>{
static inline std::vector<DataTypeID> getAllValidTypeIDs() {
return std::vector<DataTypeID>{
NODE_ID, BOOL, INT64, DOUBLE, STRING, UNSTRUCTURED, DATE, TIMESTAMP, INTERVAL, LIST};
}

Expand All @@ -87,39 +86,39 @@ class DataType {

inline DataType& operator=(DataType&& other) noexcept {
typeID = other.typeID;
childType = move(other.childType);
childType = std::move(other.childType);
return *this;
}

public:
DataTypeID typeID;
unique_ptr<DataType> childType;
std::unique_ptr<DataType> childType;

private:
unique_ptr<DataType> copy();
std::unique_ptr<DataType> copy();
};

class Types {
public:
static string dataTypeToString(const DataType& dataType);
static string dataTypeToString(DataTypeID dataTypeID);
static string dataTypesToString(const vector<DataType>& dataTypes);
static string dataTypesToString(const vector<DataTypeID>& dataTypeIDs);
static DataType dataTypeFromString(const string& dataTypeString);
static const uint32_t getDataTypeSize(DataTypeID dataTypeID);
static inline const uint32_t getDataTypeSize(const DataType& dataType) {
static std::string dataTypeToString(const DataType& dataType);
static std::string dataTypeToString(DataTypeID dataTypeID);
static std::string dataTypesToString(const std::vector<DataType>& dataTypes);
static std::string dataTypesToString(const std::vector<DataTypeID>& dataTypeIDs);
static DataType dataTypeFromString(const std::string& dataTypeString);
static uint32_t getDataTypeSize(DataTypeID dataTypeID);
static inline uint32_t getDataTypeSize(const DataType& dataType) {
return getDataTypeSize(dataType.typeID);
}

private:
static DataTypeID dataTypeIDFromString(const string& dataTypeIDString);
static DataTypeID dataTypeIDFromString(const std::string& dataTypeIDString);
};

// RelDirection
enum RelDirection : uint8_t { FWD = 0, BWD = 1 };
const vector<RelDirection> REL_DIRECTIONS = {FWD, BWD};
const std::vector<RelDirection> REL_DIRECTIONS = {FWD, BWD};
RelDirection operator!(RelDirection& direction);
string getRelDirectionAsString(RelDirection relDirection);
std::string getRelDirectionAsString(RelDirection relDirection);

enum class DBFileType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 };

Expand Down
2 changes: 1 addition & 1 deletion src/common/types/ku_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ void ku_list_t::set(const uint8_t* values, const DataType& dataType) const {
size * Types::getDataTypeSize(*dataType.childType));
}

void ku_list_t::set(const vector<uint8_t*>& parameters, DataTypeID childTypeId) {
void ku_list_t::set(const std::vector<uint8_t*>& parameters, DataTypeID childTypeId) {
this->size = parameters.size();
auto numBytesOfListElement = Types::getDataTypeSize(childTypeId);
for (auto i = 0u; i < parameters.size(); i++) {
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ string Types::dataTypesToString(const vector<DataTypeID>& dataTypeIDs) {
return result;
}

const uint32_t Types::getDataTypeSize(DataTypeID dataTypeID) {
uint32_t Types::getDataTypeSize(DataTypeID dataTypeID) {
switch (dataTypeID) {
case NODE_ID:
return sizeof(nodeID_t);
Expand Down
33 changes: 25 additions & 8 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ namespace kuzu {
namespace common {

ValueVector::ValueVector(DataType dataType, MemoryManager* memoryManager)
: dataType{move(dataType)} {
: dataType{std::move(dataType)} {
valueBuffer =
make_unique<uint8_t[]>(Types::getDataTypeSize(this->dataType) * DEFAULT_VECTOR_CAPACITY);
values = valueBuffer.get();
if (needOverflowBuffer()) {
assert(memoryManager != nullptr);
inMemOverflowBuffer = make_unique<InMemOverflowBuffer>(memoryManager);
Expand All @@ -19,17 +18,13 @@ ValueVector::ValueVector(DataType dataType, MemoryManager* memoryManager)
numBytesPerValue = Types::getDataTypeSize(this->dataType);
}

void ValueVector::addString(uint64_t pos, char* value, uint64_t len) const {
void ValueVector::addString(uint32_t pos, char* value, uint64_t len) const {
assert(dataType.typeID == STRING);
auto vectorData = (ku_string_t*)values;
auto vectorData = (ku_string_t*)valueBuffer.get();
auto& result = vectorData[pos];
InMemOverflowBufferUtils::copyString(value, len, result, *inMemOverflowBuffer);
}

void ValueVector::addString(uint64_t pos, string value) const {
addString(pos, value.data(), value.length());
}

bool NodeIDVector::discardNull(ValueVector& vector) {
if (vector.state->isFlat()) {
return !vector.isNull(vector.state->getPositionOfCurrIdx());
Expand Down Expand Up @@ -57,5 +52,27 @@ bool NodeIDVector::discardNull(ValueVector& vector) {
}
}

// Generic setter: treats the value buffer as a contiguous array of T and assigns val to the
// element at index pos. No bounds or type check is performed.
template<typename T>
void ValueVector::setValue(uint32_t pos, T val) {
    auto* typedValues = reinterpret_cast<T*>(valueBuffer.get());
    typedValues[pos] = val;
}

// Specialization for string: instead of assigning into the value buffer directly, the characters
// are handed to addString, which copies them into the ku_string_t slot at pos.
template<>
void ValueVector::setValue(uint32_t pos, string val) {
    const auto numChars = val.length();
    addString(pos, val.data(), numChars);
}

// Explicit instantiations of the generic setValue. Its definition is only visible in this
// translation unit, so every value type that other files pass to setValue has to be listed here.
// The string overload is not listed because it is provided by the explicit specialization above.
template void ValueVector::setValue<nodeID_t>(uint32_t pos, nodeID_t val);
template void ValueVector::setValue<bool>(uint32_t pos, bool val);
template void ValueVector::setValue<int64_t>(uint32_t pos, int64_t val);
template void ValueVector::setValue<hash_t>(uint32_t pos, hash_t val);
template void ValueVector::setValue<double_t>(uint32_t pos, double_t val);
template void ValueVector::setValue<date_t>(uint32_t pos, date_t val);
template void ValueVector::setValue<timestamp_t>(uint32_t pos, timestamp_t val);
template void ValueVector::setValue<interval_t>(uint32_t pos, interval_t val);
template void ValueVector::setValue<ku_string_t>(uint32_t pos, ku_string_t val);
template void ValueVector::setValue<ku_list_t>(uint32_t pos, ku_list_t val);
template void ValueVector::setValue<Value>(uint32_t pos, Value val);

} // namespace common
} // namespace kuzu
18 changes: 9 additions & 9 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@ void ValueVectorUtils::addLiteralToStructuredVector(
}
switch (literal.dataType.typeID) {
case INT64: {
((int64_t*)resultVector.values)[pos] = literal.val.int64Val;
resultVector.setValue(pos, literal.val.int64Val);
} break;
case DOUBLE: {
((double_t*)resultVector.values)[pos] = literal.val.doubleVal;
resultVector.setValue(pos, literal.val.doubleVal);
} break;
case BOOL: {
((bool*)resultVector.values)[pos] = literal.val.booleanVal;
resultVector.setValue(pos, literal.val.booleanVal);
} break;
case DATE: {
((date_t*)resultVector.values)[pos] = literal.val.dateVal;
resultVector.setValue(pos, literal.val.dateVal);
} break;
case TIMESTAMP: {
((timestamp_t*)resultVector.values)[pos] = literal.val.timestampVal;
resultVector.setValue(pos, literal.val.timestampVal);
} break;
case INTERVAL: {
((interval_t*)resultVector.values)[pos] = literal.val.intervalVal;
resultVector.setValue(pos, literal.val.intervalVal);
} break;
case STRING: {
resultVector.addString(pos, literal.strVal);
resultVector.setValue(pos, literal.strVal);
} break;
default:
assert(false);
Expand All @@ -42,14 +42,14 @@ void ValueVectorUtils::addLiteralToStructuredVector(
void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
ValueVector& resultVector, uint64_t pos, const uint8_t* srcData) {
copyNonNullDataWithSameType(resultVector.dataType, srcData,
resultVector.values + pos * resultVector.getNumBytesPerValue(),
resultVector.getData() + pos * resultVector.getNumBytesPerValue(),
resultVector.getOverflowBuffer());
}

void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector& srcVector,
uint64_t pos, uint8_t* dstData, InMemOverflowBuffer& dstOverflowBuffer) {
copyNonNullDataWithSameType(srcVector.dataType,
srcVector.values + pos * srcVector.getNumBytesPerValue(), dstData, dstOverflowBuffer);
srcVector.getData() + pos * srcVector.getNumBytesPerValue(), dstData, dstOverflowBuffer);
}

void ValueVectorUtils::copyNonNullDataWithSameType(const DataType& dataType, const uint8_t* srcData,
Expand Down
3 changes: 2 additions & 1 deletion src/expression_evaluator/literal_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ void LiteralExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager*
}

bool LiteralExpressionEvaluator::select(SelectionVector& selVector) {
assert(resultVector->dataType.typeID == BOOL); // TODO(Guodong): Is this expected here?
auto pos = resultVector->state->getPositionOfCurrIdx();
assert(pos == 0u);
return resultVector->values[pos] == true && (!resultVector->isNull(pos));
return resultVector->getValue<bool>(pos) == true && (!resultVector->isNull(pos));
}

} // namespace evaluator
Expand Down
2 changes: 1 addition & 1 deletion src/expression_evaluator/reference_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace evaluator {

inline static bool isTrue(ValueVector& vector, uint64_t pos) {
assert(vector.dataType.typeID == BOOL);
return !vector.isNull(pos) && ((bool*)vector.values)[pos];
return !vector.isNull(pos) && vector.getValue<bool>(pos);
}

void ReferenceExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager* memoryManager) {
Expand Down
6 changes: 3 additions & 3 deletions src/function/aggregate/include/avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ struct AvgFunction {

static void updateSingleValue(
AvgState* state, ValueVector* input, uint32_t pos, uint64_t multiplicity) {
auto inputValues = (T*)input->values;
T val = input->getValue<T>(pos);
for (auto i = 0u; i < multiplicity; ++i) {
if (state->isNull) {
state->sum = inputValues[pos];
state->sum = val;
state->isNull = false;
} else {
Add::operation(state->sum, inputValues[pos], state->sum);
Add::operation(state->sum, val, state->sum);
}
}
state->count += multiplicity;
Expand Down
8 changes: 4 additions & 4 deletions src/function/aggregate/include/min_max.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ struct MinMaxFunction {

template<class OP>
static void updateSingleValue(MinMaxState* state, ValueVector* input, uint32_t pos) {
auto inputValues = (T*)input->values;
T val = input->getValue<T>(pos);
if (state->isNull) {
state->val = inputValues[pos];
state->val = val;
state->isNull = false;
} else {
uint8_t compare_result;
OP::template operation<T, T>(inputValues[pos], state->val, compare_result);
state->val = compare_result ? inputValues[pos] : state->val;
OP::template operation<T, T>(val, state->val, compare_result);
state->val = compare_result ? val : state->val;
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/function/aggregate/include/sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ struct SumFunction {

static void updateSingleValue(
SumState* state, ValueVector* input, uint32_t pos, uint64_t multiplicity) {
auto inputValues = (T*)input->values;
T val = input->getValue<T>(pos);
for (auto j = 0u; j < multiplicity; ++j) {
if (state->isNull) {
state->sum = inputValues[pos];
state->sum = val;
state->isNull = false;
} else {
Add::operation(state->sum, inputValues[pos], state->sum);
Add::operation(state->sum, val, state->sum);
}
}
}
Expand Down
Loading