Skip to content

Commit

Permalink
Implement map functions
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 12, 2023
1 parent 139dbe0 commit ff6c3db
Show file tree
Hide file tree
Showing 24 changed files with 671 additions and 280 deletions.
249 changes: 174 additions & 75 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ std::string TypeUtils::toString(const struct_entry_t& val, void* valVector) {
auto fields = StructType::getFields(&structVector->dataType);
for (auto i = 0u; i < fields.size(); ++i) {
auto field = fields[i];
auto fieldVector = StructVector::getChildVector(structVector, i);
auto fieldVector = StructVector::getFieldVector(structVector, i);
auto value = fieldVector->getData() + fieldVector->getNumBytesPerValue() * val.pos;
result += castValueToString(*field->getType(), value, fieldVector.get());
result += (fields.size() - 1 == i ? "}" : ",");
Expand All @@ -109,104 +109,203 @@ bool TypeUtils::isValueEqual(
if (leftVector->dataType != rightVector->dataType || leftEntry.size != rightEntry.size) {
return false;
}
auto leftValues = ListVector::getListValues(leftVector, leftEntry);
auto rightValues = ListVector::getListValues(rightVector, rightEntry);
switch (VarListType::getChildType(&leftVector->dataType)->getLogicalTypeID()) {
case LogicalTypeID::BOOL: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<uint8_t*>(leftValues)[i],
reinterpret_cast<uint8_t*>(rightValues)[i], left, right)) {
auto leftDataVector = ListVector::getDataVector(leftVector);
auto rightDataVector = ListVector::getDataVector(rightVector);
for (auto i = 0u; i < leftEntry.size; i++) {
auto leftPos = leftEntry.offset + i;
auto rightPos = rightEntry.offset + i;
if (leftDataVector->isNull(leftPos) && rightDataVector->isNull(rightPos)) {
continue;

Check warning on line 118 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L118

Added line #L118 was not covered by tests
} else if (leftDataVector->isNull(leftPos) != rightDataVector->isNull(rightPos)) {
return false;
}
switch (leftDataVector->dataType.getPhysicalType()) {
case PhysicalTypeID::BOOL: {
if (!isValueEqual(leftDataVector->getValue<uint8_t>(leftPos),
rightDataVector->getValue<uint8_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INT64: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<int64_t*>(leftValues)[i],
reinterpret_cast<int64_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INT64: {
if (!isValueEqual(leftDataVector->getValue<int64_t>(leftPos),
rightDataVector->getValue<int64_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INT32: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<int32_t*>(leftValues)[i],
reinterpret_cast<int32_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INT32: {
if (!isValueEqual(leftDataVector->getValue<int32_t>(leftPos),
rightDataVector->getValue<int32_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INT16: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<int16_t*>(leftValues)[i],
reinterpret_cast<int16_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INT16: {
if (!isValueEqual(leftDataVector->getValue<int16_t>(leftPos),

Check warning on line 145 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L144-L145

Added lines #L144 - L145 were not covered by tests
rightDataVector->getValue<int16_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::DOUBLE: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<double_t*>(leftValues)[i],
reinterpret_cast<double_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::DOUBLE: {
if (!isValueEqual(leftDataVector->getValue<double_t>(leftPos),
rightDataVector->getValue<double_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::FLOAT: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<float*>(leftValues)[i],
reinterpret_cast<float*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::FLOAT: {
if (!isValueEqual(leftDataVector->getValue<float_t>(leftPos),

Check warning on line 159 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L158-L159

Added lines #L158 - L159 were not covered by tests
rightDataVector->getValue<float_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::STRING: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<ku_string_t*>(leftValues)[i],
reinterpret_cast<ku_string_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::STRING: {
if (!isValueEqual(leftDataVector->getValue<ku_string_t>(leftPos),
rightDataVector->getValue<ku_string_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::DATE: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<date_t*>(leftValues)[i],
reinterpret_cast<date_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INTERVAL: {
if (!isValueEqual(leftDataVector->getValue<interval_t>(leftPos),
rightDataVector->getValue<interval_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::TIMESTAMP: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<timestamp_t*>(leftValues)[i],
reinterpret_cast<timestamp_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INTERNAL_ID: {
if (!isValueEqual(leftDataVector->getValue<internalID_t>(leftPos),

Check warning on line 180 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L179-L180

Added lines #L179 - L180 were not covered by tests
rightDataVector->getValue<internalID_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INTERVAL: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<interval_t*>(leftValues)[i],
reinterpret_cast<interval_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::VAR_LIST: {
if (!isValueEqual(leftDataVector->getValue<list_entry_t>(leftPos),
rightDataVector->getValue<list_entry_t>(rightPos), leftDataVector,
rightDataVector)) {
return false;
}
}
} break;
case LogicalTypeID::VAR_LIST: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<list_entry_t*>(leftValues)[i],
reinterpret_cast<list_entry_t*>(rightValues)[i],
ListVector::getDataVector(leftVector),
ListVector::getDataVector(rightVector))) {
} break;
case PhysicalTypeID::STRUCT: {
if (!isValueEqual(leftDataVector->getValue<struct_entry_t>(leftPos),

Check warning on line 194 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L193-L194

Added lines #L193 - L194 were not covered by tests
rightDataVector->getValue<struct_entry_t>(rightPos), leftDataVector,
rightDataVector)) {
return false;
}
} break;
default: {
throw NotImplementedException("TypeUtils::isValueEqual");

Check warning on line 201 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L200-L201

Added lines #L200 - L201 were not covered by tests
}
}
} break;
default: {
throw RuntimeException("Unsupported data type " +
LogicalTypeUtils::dataTypeToString(leftVector->dataType) +
" for TypeUtils::isValueEqual.");
}
return true;
}

template<>
bool TypeUtils::isValueEqual(common::struct_entry_t& leftEntry, common::struct_entry_t& rightEntry,
void* left, void* right) {
auto leftVector = (ValueVector*)left;
auto rightVector = (ValueVector*)right;
if (leftVector->dataType != rightVector->dataType) {
return false;
}
auto leftStructFields = common::StructVector::getFieldVectors(leftVector);
auto rightStructFields = common::StructVector::getFieldVectors(rightVector);
for (auto i = 0u; i < leftStructFields.size(); i++) {
auto leftStructField = leftStructFields[i];
auto rightStructField = rightStructFields[i];
if (leftStructField->isNull(leftEntry.pos) && rightStructField->isNull(rightEntry.pos)) {
continue;
} else if (leftStructField->isNull(leftEntry.pos) !=
rightStructField->isNull(rightEntry.pos)) {
return false;
}
switch (leftStructFields[i]->dataType.getPhysicalType()) {
case PhysicalTypeID::BOOL: {
if (!isValueEqual(leftStructField->getValue<uint8_t>(leftEntry.pos),

Check warning on line 229 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L229

Added line #L229 was not covered by tests
rightStructField->getValue<uint8_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INT64: {
if (!isValueEqual(leftStructField->getValue<int64_t>(leftEntry.pos),
rightStructField->getValue<int64_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INT32: {
if (!isValueEqual(leftStructField->getValue<int32_t>(leftEntry.pos),

Check warning on line 243 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L243

Added line #L243 was not covered by tests
rightStructField->getValue<int32_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INT16: {
if (!isValueEqual(leftStructField->getValue<int16_t>(leftEntry.pos),

Check warning on line 250 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L250

Added line #L250 was not covered by tests
rightStructField->getValue<int16_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::DOUBLE: {
if (!isValueEqual(leftStructField->getValue<double_t>(leftEntry.pos),

Check warning on line 257 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L257

Added line #L257 was not covered by tests
rightStructField->getValue<double_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::FLOAT: {
if (!isValueEqual(leftStructField->getValue<float_t>(leftEntry.pos),

Check warning on line 264 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L264

Added line #L264 was not covered by tests
rightStructField->getValue<float_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::STRING: {
if (!isValueEqual(leftStructField->getValue<ku_string_t>(leftEntry.pos),

Check warning on line 271 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L271

Added line #L271 was not covered by tests
rightStructField->getValue<ku_string_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INTERVAL: {
if (!isValueEqual(leftStructField->getValue<interval_t>(leftEntry.pos),

Check warning on line 278 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L278

Added line #L278 was not covered by tests
rightStructField->getValue<interval_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INTERNAL_ID: {
if (!isValueEqual(leftStructField->getValue<internalID_t>(leftEntry.pos),

Check warning on line 285 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L285

Added line #L285 was not covered by tests
rightStructField->getValue<internalID_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::VAR_LIST: {
if (!isValueEqual(leftStructField->getValue<list_entry_t>(leftEntry.pos),
rightStructField->getValue<list_entry_t>(rightEntry.pos), leftStructField.get(),
rightStructField.get())) {
return false;
}
} break;
case PhysicalTypeID::STRUCT: {
if (!isValueEqual(leftStructField->getValue<struct_entry_t>(leftEntry.pos),
rightStructField->getValue<struct_entry_t>(rightEntry.pos),
leftStructField.get(), rightStructField.get())) {
return false;
}
} break;
default: {
throw NotImplementedException("TypeUtils::isValueEqual");

Check warning on line 306 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L305-L306

Added lines #L305 - L306 were not covered by tests
}
}
}
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ StructField::StructField(std::string name, std::unique_ptr<LogicalType> type)
}

bool StructField::operator==(const kuzu::common::StructField& other) const {
return name == other.name && *type == *other.type;
return *type == *other.type;
}

std::unique_ptr<StructField> StructField::copy() const {
Expand Down
9 changes: 4 additions & 5 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ void ValueVector::resetAuxiliaryBuffer() {
uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {
switch (type.getPhysicalType()) {
case PhysicalTypeID::STRING: {
return sizeof(common::ku_string_t);
return sizeof(ku_string_t);
}
case PhysicalTypeID::FIXED_LIST: {
return getDataTypeSize(*common::FixedListType::getChildType(&type)) *
common::FixedListType::getNumElementsInList(&type);
return getDataTypeSize(*FixedListType::getChildType(&type)) *
FixedListType::getNumElementsInList(&type);
}
case PhysicalTypeID::STRUCT: {
return sizeof(struct_entry_t);
Expand All @@ -104,8 +104,7 @@ void ValueVector::initializeValueBuffer() {
}
}

void ArrowColumnVector::setArrowColumn(
kuzu::common::ValueVector* vector, std::shared_ptr<arrow::Array> column) {
void ArrowColumnVector::setArrowColumn(ValueVector* vector, std::shared_ptr<arrow::Array> column) {
auto arrowColumnBuffer =
reinterpret_cast<ArrowColumnAuxiliaryBuffer*>(vector->auxiliaryBuffer.get());
arrowColumnBuffer->column = std::move(column);
Expand Down
2 changes: 1 addition & 1 deletion src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void FunctionExpressionEvaluator::resolveResultVector(
if (functionExpression.getFunctionName() == STRUCT_EXTRACT_FUNC_NAME) {
auto& bindData = (function::StructExtractBindData&)*functionExpression.getBindData();
resultVector =
StructVector::getChildVector(children[0]->resultVector.get(), bindData.childIdx);
StructVector::getFieldVector(children[0]->resultVector.get(), bindData.childIdx);
} else {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
}
Expand Down
8 changes: 7 additions & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,13 @@ void BuiltInVectorOperations::registerStructOperations() {
}

void BuiltInVectorOperations::registerMapOperations() {
vectorOperations.insert({MAP_CREATION_FUNC_NAME, MapVectorOperations::getDefinitions()});
vectorOperations.insert(
{MAP_CREATION_FUNC_NAME, MapCreationVectorOperations::getDefinitions()});
vectorOperations.insert({MAP_EXTRACT_FUNC_NAME, MapExtractVectorOperations::getDefinitions()});
vectorOperations.insert({ELEMENT_AT_FUNC_NAME, MapExtractVectorOperations::getDefinitions()});
vectorOperations.insert({CARDINALITY_FUNC_NAME, ListLenVectorOperation::getDefinitions()});
vectorOperations.insert({MAP_KEYS_FUNC_NAME, MapKeysVectorOperations::getDefinitions()});
vectorOperations.insert({MAP_VALUES_FUNC_NAME, MapValuesVectorOperations::getDefinitions()});
}

} // namespace function
Expand Down
Loading

0 comments on commit ff6c3db

Please sign in to comment.