Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement map functions #1660

Merged
merged 1 commit into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 174 additions & 75 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
auto fields = StructType::getFields(&structVector->dataType);
for (auto i = 0u; i < fields.size(); ++i) {
auto field = fields[i];
auto fieldVector = StructVector::getChildVector(structVector, i);
auto fieldVector = StructVector::getFieldVector(structVector, i);
auto value = fieldVector->getData() + fieldVector->getNumBytesPerValue() * val.pos;
result += castValueToString(*field->getType(), value, fieldVector.get());
result += (fields.size() - 1 == i ? "}" : ",");
Expand All @@ -109,104 +109,203 @@
if (leftVector->dataType != rightVector->dataType || leftEntry.size != rightEntry.size) {
return false;
}
auto leftValues = ListVector::getListValues(leftVector, leftEntry);
auto rightValues = ListVector::getListValues(rightVector, rightEntry);
switch (VarListType::getChildType(&leftVector->dataType)->getLogicalTypeID()) {
case LogicalTypeID::BOOL: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<uint8_t*>(leftValues)[i],
reinterpret_cast<uint8_t*>(rightValues)[i], left, right)) {
auto leftDataVector = ListVector::getDataVector(leftVector);
auto rightDataVector = ListVector::getDataVector(rightVector);
for (auto i = 0u; i < leftEntry.size; i++) {
auto leftPos = leftEntry.offset + i;
auto rightPos = rightEntry.offset + i;
if (leftDataVector->isNull(leftPos) && rightDataVector->isNull(rightPos)) {
continue;

Check warning on line 118 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L118

Added line #L118 was not covered by tests
} else if (leftDataVector->isNull(leftPos) != rightDataVector->isNull(rightPos)) {
return false;
}
switch (leftDataVector->dataType.getPhysicalType()) {
case PhysicalTypeID::BOOL: {
if (!isValueEqual(leftDataVector->getValue<uint8_t>(leftPos),
rightDataVector->getValue<uint8_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INT64: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<int64_t*>(leftValues)[i],
reinterpret_cast<int64_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INT64: {
if (!isValueEqual(leftDataVector->getValue<int64_t>(leftPos),
rightDataVector->getValue<int64_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INT32: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<int32_t*>(leftValues)[i],
reinterpret_cast<int32_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INT32: {
if (!isValueEqual(leftDataVector->getValue<int32_t>(leftPos),
rightDataVector->getValue<int32_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INT16: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<int16_t*>(leftValues)[i],
reinterpret_cast<int16_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INT16: {
if (!isValueEqual(leftDataVector->getValue<int16_t>(leftPos),

Check warning on line 145 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L144-L145

Added lines #L144 - L145 were not covered by tests
rightDataVector->getValue<int16_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::DOUBLE: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<double_t*>(leftValues)[i],
reinterpret_cast<double_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::DOUBLE: {
if (!isValueEqual(leftDataVector->getValue<double_t>(leftPos),
rightDataVector->getValue<double_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::FLOAT: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<float*>(leftValues)[i],
reinterpret_cast<float*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::FLOAT: {
if (!isValueEqual(leftDataVector->getValue<float_t>(leftPos),

Check warning on line 159 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L158-L159

Added lines #L158 - L159 were not covered by tests
rightDataVector->getValue<float_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::STRING: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<ku_string_t*>(leftValues)[i],
reinterpret_cast<ku_string_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::STRING: {
if (!isValueEqual(leftDataVector->getValue<ku_string_t>(leftPos),
rightDataVector->getValue<ku_string_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::DATE: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<date_t*>(leftValues)[i],
reinterpret_cast<date_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INTERVAL: {
if (!isValueEqual(leftDataVector->getValue<interval_t>(leftPos),
rightDataVector->getValue<interval_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::TIMESTAMP: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<timestamp_t*>(leftValues)[i],
reinterpret_cast<timestamp_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::INTERNAL_ID: {
if (!isValueEqual(leftDataVector->getValue<internalID_t>(leftPos),

Check warning on line 180 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L179-L180

Added lines #L179 - L180 were not covered by tests
rightDataVector->getValue<internalID_t>(rightPos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
}
} break;
case LogicalTypeID::INTERVAL: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<interval_t*>(leftValues)[i],
reinterpret_cast<interval_t*>(rightValues)[i], left, right)) {
} break;
case PhysicalTypeID::VAR_LIST: {
if (!isValueEqual(leftDataVector->getValue<list_entry_t>(leftPos),
rightDataVector->getValue<list_entry_t>(rightPos), leftDataVector,
rightDataVector)) {
return false;
}
}
} break;
case LogicalTypeID::VAR_LIST: {
for (auto i = 0u; i < leftEntry.size; i++) {
if (!isValueEqual(reinterpret_cast<list_entry_t*>(leftValues)[i],
reinterpret_cast<list_entry_t*>(rightValues)[i],
ListVector::getDataVector(leftVector),
ListVector::getDataVector(rightVector))) {
} break;
case PhysicalTypeID::STRUCT: {
if (!isValueEqual(leftDataVector->getValue<struct_entry_t>(leftPos),

Check warning on line 194 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L193-L194

Added lines #L193 - L194 were not covered by tests
rightDataVector->getValue<struct_entry_t>(rightPos), leftDataVector,
rightDataVector)) {
return false;
}
} break;
default: {
throw NotImplementedException("TypeUtils::isValueEqual");

Check warning on line 201 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L200-L201

Added lines #L200 - L201 were not covered by tests
}
}
} break;
default: {
throw RuntimeException("Unsupported data type " +
LogicalTypeUtils::dataTypeToString(leftVector->dataType) +
" for TypeUtils::isValueEqual.");
}
return true;
}

template<>
bool TypeUtils::isValueEqual(common::struct_entry_t& leftEntry, common::struct_entry_t& rightEntry,
void* left, void* right) {
auto leftVector = (ValueVector*)left;
auto rightVector = (ValueVector*)right;
if (leftVector->dataType != rightVector->dataType) {
return false;
}
auto leftStructFields = common::StructVector::getFieldVectors(leftVector);
auto rightStructFields = common::StructVector::getFieldVectors(rightVector);
for (auto i = 0u; i < leftStructFields.size(); i++) {
auto leftStructField = leftStructFields[i];
auto rightStructField = rightStructFields[i];
if (leftStructField->isNull(leftEntry.pos) && rightStructField->isNull(rightEntry.pos)) {
continue;
} else if (leftStructField->isNull(leftEntry.pos) !=
rightStructField->isNull(rightEntry.pos)) {
return false;
}
switch (leftStructFields[i]->dataType.getPhysicalType()) {
case PhysicalTypeID::BOOL: {
if (!isValueEqual(leftStructField->getValue<uint8_t>(leftEntry.pos),

Check warning on line 229 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L229

Added line #L229 was not covered by tests
rightStructField->getValue<uint8_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INT64: {
if (!isValueEqual(leftStructField->getValue<int64_t>(leftEntry.pos),
rightStructField->getValue<int64_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INT32: {
if (!isValueEqual(leftStructField->getValue<int32_t>(leftEntry.pos),

Check warning on line 243 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L243

Added line #L243 was not covered by tests
rightStructField->getValue<int32_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INT16: {
if (!isValueEqual(leftStructField->getValue<int16_t>(leftEntry.pos),

Check warning on line 250 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L250

Added line #L250 was not covered by tests
rightStructField->getValue<int16_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::DOUBLE: {
if (!isValueEqual(leftStructField->getValue<double_t>(leftEntry.pos),

Check warning on line 257 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L257

Added line #L257 was not covered by tests
rightStructField->getValue<double_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::FLOAT: {
if (!isValueEqual(leftStructField->getValue<float_t>(leftEntry.pos),

Check warning on line 264 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L264

Added line #L264 was not covered by tests
rightStructField->getValue<float_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::STRING: {
if (!isValueEqual(leftStructField->getValue<ku_string_t>(leftEntry.pos),

Check warning on line 271 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L271

Added line #L271 was not covered by tests
rightStructField->getValue<ku_string_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INTERVAL: {
if (!isValueEqual(leftStructField->getValue<interval_t>(leftEntry.pos),

Check warning on line 278 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L278

Added line #L278 was not covered by tests
rightStructField->getValue<interval_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::INTERNAL_ID: {
if (!isValueEqual(leftStructField->getValue<internalID_t>(leftEntry.pos),

Check warning on line 285 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L285

Added line #L285 was not covered by tests
rightStructField->getValue<internalID_t>(rightEntry.pos), nullptr /* left */,
nullptr /* right */)) {
return false;
}
} break;
case PhysicalTypeID::VAR_LIST: {
if (!isValueEqual(leftStructField->getValue<list_entry_t>(leftEntry.pos),
rightStructField->getValue<list_entry_t>(rightEntry.pos), leftStructField.get(),
rightStructField.get())) {
return false;
}
} break;
case PhysicalTypeID::STRUCT: {
if (!isValueEqual(leftStructField->getValue<struct_entry_t>(leftEntry.pos),
rightStructField->getValue<struct_entry_t>(rightEntry.pos),
leftStructField.get(), rightStructField.get())) {
return false;
}
} break;
default: {
throw NotImplementedException("TypeUtils::isValueEqual");

Check warning on line 306 in src/common/type_utils.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/type_utils.cpp#L305-L306

Added lines #L305 - L306 were not covered by tests
}
}
}
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ StructField::StructField(std::string name, std::unique_ptr<LogicalType> type)
}

bool StructField::operator==(const kuzu::common::StructField& other) const {
return name == other.name && *type == *other.type;
return *type == *other.type;
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
}

std::unique_ptr<StructField> StructField::copy() const {
Expand Down
9 changes: 4 additions & 5 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ void ValueVector::resetAuxiliaryBuffer() {
uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {
switch (type.getPhysicalType()) {
case PhysicalTypeID::STRING: {
return sizeof(common::ku_string_t);
return sizeof(ku_string_t);
}
case PhysicalTypeID::FIXED_LIST: {
return getDataTypeSize(*common::FixedListType::getChildType(&type)) *
common::FixedListType::getNumElementsInList(&type);
return getDataTypeSize(*FixedListType::getChildType(&type)) *
FixedListType::getNumElementsInList(&type);
}
case PhysicalTypeID::STRUCT: {
return sizeof(struct_entry_t);
Expand All @@ -104,8 +104,7 @@ void ValueVector::initializeValueBuffer() {
}
}

void ArrowColumnVector::setArrowColumn(
kuzu::common::ValueVector* vector, std::shared_ptr<arrow::Array> column) {
void ArrowColumnVector::setArrowColumn(ValueVector* vector, std::shared_ptr<arrow::Array> column) {
auto arrowColumnBuffer =
reinterpret_cast<ArrowColumnAuxiliaryBuffer*>(vector->auxiliaryBuffer.get());
arrowColumnBuffer->column = std::move(column);
Expand Down
2 changes: 1 addition & 1 deletion src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void FunctionExpressionEvaluator::resolveResultVector(
if (functionExpression.getFunctionName() == STRUCT_EXTRACT_FUNC_NAME) {
auto& bindData = (function::StructExtractBindData&)*functionExpression.getBindData();
resultVector =
StructVector::getChildVector(children[0]->resultVector.get(), bindData.childIdx);
StructVector::getFieldVector(children[0]->resultVector.get(), bindData.childIdx);
} else {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
}
Expand Down
8 changes: 7 additions & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,13 @@ void BuiltInVectorOperations::registerStructOperations() {
}

void BuiltInVectorOperations::registerMapOperations() {
vectorOperations.insert({MAP_CREATION_FUNC_NAME, MapVectorOperations::getDefinitions()});
vectorOperations.insert(
{MAP_CREATION_FUNC_NAME, MapCreationVectorOperations::getDefinitions()});
vectorOperations.insert({MAP_EXTRACT_FUNC_NAME, MapExtractVectorOperations::getDefinitions()});
vectorOperations.insert({ELEMENT_AT_FUNC_NAME, MapExtractVectorOperations::getDefinitions()});
vectorOperations.insert({CARDINALITY_FUNC_NAME, ListLenVectorOperation::getDefinitions()});
vectorOperations.insert({MAP_KEYS_FUNC_NAME, MapKeysVectorOperations::getDefinitions()});
vectorOperations.insert({MAP_VALUES_FUNC_NAME, MapValuesVectorOperations::getDefinitions()});
}

} // namespace function
Expand Down
Loading
Loading