Skip to content

Commit

Permalink
Implement map-literal
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 8, 2023
1 parent 20d696a commit efabd0c
Show file tree
Hide file tree
Showing 37 changed files with 621 additions and 431 deletions.
36 changes: 23 additions & 13 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ bool LogicalType::operator==(const LogicalType& other) const {
if (typeID != other.typeID) {
return false;
}
switch (other.typeID) {
case LogicalTypeID::VAR_LIST:
switch (other.getPhysicalType()) {
case PhysicalTypeID::VAR_LIST:
return *reinterpret_cast<VarListTypeInfo*>(extraTypeInfo.get()) ==
*reinterpret_cast<VarListTypeInfo*>(other.extraTypeInfo.get());
case LogicalTypeID::FIXED_LIST:
case PhysicalTypeID::FIXED_LIST:
return *reinterpret_cast<FixedListTypeInfo*>(extraTypeInfo.get()) ==
*reinterpret_cast<FixedListTypeInfo*>(other.extraTypeInfo.get());
case LogicalTypeID::STRUCT:
case PhysicalTypeID::STRUCT:
return *reinterpret_cast<StructTypeInfo*>(extraTypeInfo.get()) ==
*reinterpret_cast<StructTypeInfo*>(other.extraTypeInfo.get());
default:
Expand Down Expand Up @@ -228,6 +228,7 @@ void LogicalType::setPhysicalType() {
physicalType = PhysicalTypeID::STRING;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST: {
physicalType = PhysicalTypeID::VAR_LIST;
} break;
Expand Down Expand Up @@ -322,6 +323,12 @@ LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataType

std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
switch (dataType.typeID) {
case LogicalTypeID::MAP: {
auto structType = common::VarListType::getChildType(&dataType);
auto fieldTypes = common::StructType::getFieldTypes(structType);
return "MAP(" + dataTypeToString(*fieldTypes[0]) + ": " + dataTypeToString(*fieldTypes[1]) +
")";

Check warning on line 330 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L329-L330

Added lines #L329 - L330 were not covered by tests
}
case LogicalTypeID::VAR_LIST: {
auto varListTypeInfo = reinterpret_cast<VarListTypeInfo*>(dataType.extraTypeInfo.get());
return dataTypeToString(*varListTypeInfo->getChildType()) + "[]";
Expand Down Expand Up @@ -407,6 +414,8 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
return "STRUCT";
case LogicalTypeID::SERIAL:
return "SERIAL";
case LogicalTypeID::MAP:
return "MAP";

Check warning on line 418 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L418

Added line #L418 was not covered by tests
default:
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
Expand Down Expand Up @@ -484,7 +493,8 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
}

std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidLogicTypeIDs() {
// TODO(Ziyi): Add FIX_LIST,STRUCT type to allValidTypeID when we support functions on VAR_LIST.
// TODO(Ziyi): Add FIX_LIST,STRUCT,MAP type to allValidTypeID when we support functions on
// FIXED_LIST,STRUCT,MAP.
return std::vector<LogicalTypeID>{LogicalTypeID::INTERNAL_ID, LogicalTypeID::BOOL,
LogicalTypeID::INT64, LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::DOUBLE,
LogicalTypeID::STRING, LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP,
Expand Down Expand Up @@ -578,16 +588,16 @@ uint64_t SerDeser::deserializeValue(
template<>
uint64_t SerDeser::serializeValue(const LogicalType& value, FileInfo* fileInfo, uint64_t offset) {
offset = SerDeser::serializeValue(value.getLogicalTypeID(), fileInfo, offset);
switch (value.getLogicalTypeID()) {
case LogicalTypeID::VAR_LIST: {
switch (value.getPhysicalType()) {
case PhysicalTypeID::VAR_LIST: {
auto varListTypeInfo = reinterpret_cast<VarListTypeInfo*>(value.extraTypeInfo.get());
offset = serializeValue(*varListTypeInfo, fileInfo, offset);
} break;
case LogicalTypeID::FIXED_LIST: {
case PhysicalTypeID::FIXED_LIST: {
auto fixedListTypeInfo = reinterpret_cast<FixedListTypeInfo*>(value.extraTypeInfo.get());
offset = serializeValue(*fixedListTypeInfo, fileInfo, offset);
} break;
case LogicalTypeID::STRUCT: {
case PhysicalTypeID::STRUCT: {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(value.extraTypeInfo.get());
offset = serializeValue(*structTypeInfo, fileInfo, offset);
} break;
Expand All @@ -601,19 +611,19 @@ template<>
uint64_t SerDeser::deserializeValue(LogicalType& value, FileInfo* fileInfo, uint64_t offset) {
offset = SerDeser::deserializeValue(value.typeID, fileInfo, offset);
value.setPhysicalType();
switch (value.getLogicalTypeID()) {
case LogicalTypeID::VAR_LIST: {
switch (value.getPhysicalType()) {
case PhysicalTypeID::VAR_LIST: {
value.extraTypeInfo = std::make_unique<VarListTypeInfo>();
offset = deserializeValue(
*reinterpret_cast<VarListTypeInfo*>(value.extraTypeInfo.get()), fileInfo, offset);

} break;
case LogicalTypeID::FIXED_LIST: {
case PhysicalTypeID::FIXED_LIST: {
value.extraTypeInfo = std::make_unique<FixedListTypeInfo>();
offset = deserializeValue(
*reinterpret_cast<FixedListTypeInfo*>(value.extraTypeInfo.get()), fileInfo, offset);
} break;
case LogicalTypeID::STRUCT: {
case PhysicalTypeID::STRUCT: {
value.extraTypeInfo = std::make_unique<StructTypeInfo>();
offset = deserializeValue(
*reinterpret_cast<StructTypeInfo*>(value.extraTypeInfo.get()), fileInfo, offset);
Expand Down
13 changes: 12 additions & 1 deletion src/common/types/value.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "common/types/value.h"

#include "common/null_buffer.h"
#include "common/string_utils.h"
#include "storage/storage_utils.h"

namespace kuzu {
Expand Down Expand Up @@ -66,6 +65,7 @@ Value Value::createDefaultValue(const LogicalType& dataType) {
case LogicalTypeID::FLOAT:
return Value((float_t)0);
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::STRUCT:
Expand Down Expand Up @@ -282,6 +282,17 @@ std::string Value::toString() const {
return TypeUtils::toString(val.internalIDVal);
case LogicalTypeID::STRING:
return strVal;
case LogicalTypeID::MAP: {
std::string result = "{";
for (auto i = 0u; i < nestedTypeVal.size(); ++i) {
auto structVal = nestedTypeVal[i].get();
result += structVal->nestedTypeVal[0]->toString();
result += "=";
result += structVal->nestedTypeVal[1]->toString();
result += (i == nestedTypeVal.size() - 1 ? "}" : ", ");
}
return result;
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST: {
Expand Down
34 changes: 22 additions & 12 deletions src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,39 @@ list_entry_t ListAuxiliaryBuffer::addList(uint64_t listSize) {
}
auto numBytesPerElement = dataVector->getNumBytesPerValue();
if (needResizeDataVector) {
auto buffer = std::make_unique<uint8_t[]>(capacity * numBytesPerElement);
memcpy(buffer.get(), dataVector->valueBuffer.get(), size * numBytesPerElement);
dataVector->valueBuffer = std::move(buffer);
dataVector->nullMask->resize(capacity);
resizeDataVector(dataVector.get());
}
size += listSize;
return listEntry;
}

void ListAuxiliaryBuffer::resizeDataVector(ValueVector* dataVector) {
// If the dataVector is a struct vector, we need to resize its field vectors.
if (dataVector->dataType.getPhysicalType() == PhysicalTypeID::STRUCT) {
auto fieldVectors = StructVector::getFieldVectors(dataVector);
for (auto& fieldVector : fieldVectors) {
resizeDataVector(fieldVector.get());

Check warning on line 49 in src/common/vector/auxiliary_buffer.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/vector/auxiliary_buffer.cpp#L47-L49

Added lines #L47 - L49 were not covered by tests
}
} else {

Check warning on line 51 in src/common/vector/auxiliary_buffer.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/vector/auxiliary_buffer.cpp#L51

Added line #L51 was not covered by tests
auto buffer = std::make_unique<uint8_t[]>(capacity * dataVector->getNumBytesPerValue());
memcpy(
buffer.get(), dataVector->valueBuffer.get(), size * dataVector->getNumBytesPerValue());
dataVector->valueBuffer = std::move(buffer);
dataVector->nullMask->resize(capacity);
}
}

std::unique_ptr<AuxiliaryBuffer> AuxiliaryBufferFactory::getAuxiliaryBuffer(
LogicalType& type, storage::MemoryManager* memoryManager) {
switch (type.getLogicalTypeID()) {
case LogicalTypeID::STRING:
switch (type.getPhysicalType()) {
case PhysicalTypeID::STRING:
return std::make_unique<StringAuxiliaryBuffer>(memoryManager);
case LogicalTypeID::STRUCT:
case PhysicalTypeID::STRUCT:
return std::make_unique<StructAuxiliaryBuffer>(type, memoryManager);
case LogicalTypeID::RECURSIVE_REL:
return std::make_unique<ListAuxiliaryBuffer>(
common::LogicalType(common::LogicalTypeID::INTERNAL_ID), memoryManager);
case LogicalTypeID::VAR_LIST:
case PhysicalTypeID::VAR_LIST:
return std::make_unique<ListAuxiliaryBuffer>(
*VarListType::getChildType(&type), memoryManager);
case LogicalTypeID::ARROW_COLUMN:
case PhysicalTypeID::ARROW_COLUMN:
return std::make_unique<ArrowColumnAuxiliaryBuffer>();
default:
return nullptr;
Expand Down
21 changes: 10 additions & 11 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryMan
void ValueVector::setState(std::shared_ptr<DataChunkState> state) {
this->state = state;
if (dataType.getLogicalTypeID() == LogicalTypeID::STRUCT) {
auto childrenVectors = StructVector::getChildrenVectors(this);
auto childrenVectors = StructVector::getFieldVectors(this);
for (auto& childVector : childrenVectors) {
childVector->setState(state);
}
Expand Down Expand Up @@ -57,12 +57,12 @@ void ValueVector::setValue(uint32_t pos, std::string val) {
}

void ValueVector::resetAuxiliaryBuffer() {
switch (dataType.getLogicalTypeID()) {
case LogicalTypeID::STRING: {
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::STRING: {
reinterpret_cast<StringAuxiliaryBuffer*>(auxiliaryBuffer.get())->resetOverflowBuffer();
return;
}
case LogicalTypeID::VAR_LIST: {
case PhysicalTypeID::VAR_LIST: {
reinterpret_cast<ListAuxiliaryBuffer*>(auxiliaryBuffer.get())->resetSize();
return;
}
Expand All @@ -72,22 +72,21 @@ void ValueVector::resetAuxiliaryBuffer() {
}

uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {
switch (type.getLogicalTypeID()) {
case common::LogicalTypeID::STRING: {
switch (type.getPhysicalType()) {
case PhysicalTypeID::STRING: {
return sizeof(common::ku_string_t);
}
case common::LogicalTypeID::FIXED_LIST: {
case PhysicalTypeID::FIXED_LIST: {
return getDataTypeSize(*common::FixedListType::getChildType(&type)) *
common::FixedListType::getNumElementsInList(&type);
}
case LogicalTypeID::STRUCT: {
case PhysicalTypeID::STRUCT: {
return sizeof(struct_entry_t);
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
case PhysicalTypeID::VAR_LIST: {
return sizeof(list_entry_t);
}
case LogicalTypeID::ARROW_COLUMN: {
case PhysicalTypeID::ARROW_COLUMN: {
return 0;
}
default: {
Expand Down
35 changes: 16 additions & 19 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ using namespace common;

void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
ValueVector& resultVector, uint64_t pos, const uint8_t* srcData) {
switch (resultVector.dataType.getLogicalTypeID()) {
case LogicalTypeID::STRUCT: {
auto structFields = StructVector::getChildrenVectors(&resultVector);
switch (resultVector.dataType.getPhysicalType()) {
case PhysicalTypeID::STRUCT: {
auto structFields = StructVector::getFieldVectors(&resultVector);
auto structNullBytes = srcData;
auto structValues =
structNullBytes + NullBuffer::getNumBytesForNullValues(structFields.size());
Expand All @@ -25,8 +25,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
case PhysicalTypeID::VAR_LIST: {
auto srcKuList = *(ku_list_t*)srcData;
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size);
Expand All @@ -46,7 +45,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
srcListValues += numBytesPerValue;
}
} break;
case LogicalTypeID::STRING: {
case PhysicalTypeID::STRING: {
auto dstData = resultVector.getData() +
pos * processor::FactorizedTable::getDataTypeSize(resultVector.dataType);
InMemOverflowBufferUtils::copyString(*(ku_string_t*)srcData, *(ku_string_t*)dstData,
Expand All @@ -61,11 +60,11 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(

void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector& srcVector,
uint64_t pos, uint8_t* dstData, InMemOverflowBuffer& dstOverflowBuffer) {
switch (srcVector.dataType.getLogicalTypeID()) {
case LogicalTypeID::STRUCT: {
switch (srcVector.dataType.getPhysicalType()) {
case PhysicalTypeID::STRUCT: {
// The storage structure of STRUCT type in factorizedTable is:
// [NULLBYTES, FIELD1, FIELD2, ...]
auto structFields = StructVector::getChildrenVectors(&srcVector);
auto structFields = StructVector::getFieldVectors(&srcVector);
NullBuffer::initNullBytes(dstData, structFields.size());
auto structNullBytes = dstData;
auto structValues =
Expand All @@ -81,8 +80,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
case PhysicalTypeID::VAR_LIST: {
auto srcListEntry = srcVector.getValue<list_entry_t>(pos);
auto srcListDataVector = common::ListVector::getDataVector(&srcVector);
ku_list_t dstList;
Expand All @@ -108,7 +106,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
}
memcpy(dstData, &dstList, sizeof(dstList));
} break;
case LogicalTypeID::STRING: {
case PhysicalTypeID::STRING: {
auto srcData = srcVector.getData() +
pos * processor::FactorizedTable::getDataTypeSize(srcVector.dataType);
InMemOverflowBufferUtils::copyString(
Expand All @@ -123,9 +121,8 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&

void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVector,
const uint8_t* srcValue, const common::ValueVector& srcVector) {
switch (srcVector.dataType.getLogicalTypeID()) {
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
switch (srcVector.dataType.getPhysicalType()) {
case PhysicalTypeID::VAR_LIST: {
auto srcList = reinterpret_cast<const common::list_entry_t*>(srcValue);
auto dstList = reinterpret_cast<common::list_entry_t*>(dstValue);
*dstList = ListVector::addList(&dstVector, srcList->size);
Expand All @@ -144,9 +141,9 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect
dstValues += numBytesPerValue;
}
} break;
case LogicalTypeID::STRUCT: {
auto srcFields = common::StructVector::getChildrenVectors(&srcVector);
auto dstFields = common::StructVector::getChildrenVectors(&dstVector);
case PhysicalTypeID::STRUCT: {
auto srcFields = common::StructVector::getFieldVectors(&srcVector);
auto dstFields = common::StructVector::getFieldVectors(&dstVector);
auto srcPos = *(int64_t*)srcValue;
auto dstPos = *(int64_t*)dstValue;
for (auto i = 0u; i < srcFields.size(); i++) {
Expand All @@ -160,7 +157,7 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect
}
}
} break;
case LogicalTypeID::STRING: {
case PhysicalTypeID::STRING: {
common::InMemOverflowBufferUtils::copyString(*(common::ku_string_t*)srcValue,
*(common::ku_string_t*)dstValue, *StringVector::getInMemOverflowBuffer(&dstVector));
} break;
Expand Down
3 changes: 2 additions & 1 deletion src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ add_library(kuzu_function
vector_null_operations.cpp
vector_string_operations.cpp
vector_timestamp_operations.cpp
vector_struct_operations.cpp)
vector_struct_operations.cpp
vector_map_operation.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function>
Expand Down
10 changes: 8 additions & 2 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "function/date/vector_date_operations.h"
#include "function/interval/vector_interval_operations.h"
#include "function/list/vector_list_operations.h"
#include "function/map/vector_map_operations.h"
#include "function/schema/vector_offset_operations.h"
#include "function/string/vector_string_operations.h"
#include "function/struct/vector_struct_operations.h"
Expand All @@ -25,7 +26,8 @@ void BuiltInVectorOperations::registerVectorOperations() {
registerStringOperations();
registerCastOperations();
registerListOperations();
registerStructOperation();
registerStructOperations();
registerMapOperations();
// register internal offset operation
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
}
Expand Down Expand Up @@ -478,11 +480,15 @@ void BuiltInVectorOperations::registerListOperations() {
{LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerStructOperation() {
void BuiltInVectorOperations::registerStructOperations() {
vectorOperations.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorOperations::getDefinitions()});
vectorOperations.insert(
{STRUCT_EXTRACT_FUNC_NAME, StructExtractVectorOperations::getDefinitions()});
}

void BuiltInVectorOperations::registerMapOperations() {
vectorOperations.insert({MAP_CREATION_FUNC_NAME, MapVectorOperations::getDefinitions()});
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit efabd0c

Please sign in to comment.