From 79596995ec27879e92c35c03acbe2ba4e3ab796f Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Tue, 16 May 2023 12:55:16 -0400 Subject: [PATCH] Add physical type --- src/binder/bind/bind_copy.cpp | 7 +- src/binder/bind/bind_ddl.cpp | 38 +- src/binder/bind/bind_graph_pattern.cpp | 19 +- src/binder/bind/bind_projection_clause.cpp | 17 +- src/binder/bind/bind_reading_clause.cpp | 7 +- src/binder/bind/bind_updating_clause.cpp | 12 +- .../bind_boolean_expression.cpp | 4 +- .../bind_expression/bind_case_expression.cpp | 2 +- .../bind_comparison_expression.cpp | 4 +- .../bind_function_expression.cpp | 48 +- .../bind_null_operator_expression.cpp | 3 +- .../bind_property_expression.cpp | 12 +- src/binder/binder.cpp | 8 +- src/binder/expression/expression.cpp | 4 +- src/binder/expression_binder.cpp | 47 +- src/c_api/data_type.cpp | 51 +- src/c_api/query_result.cpp | 2 +- src/c_api/value.cpp | 12 +- src/catalog/catalog.cpp | 9 +- src/common/arrow/arrow_converter.cpp | 28 +- src/common/arrow/arrow_row_batch.cpp | 289 +++++---- src/common/file_utils.cpp | 2 +- src/common/in_mem_overflow_buffer_utils.cpp | 33 - src/common/string_utils.cpp | 7 +- src/common/type_utils.cpp | 93 ++- src/common/types/ku_list.cpp | 10 +- src/common/types/types.cpp | 601 +++++++++--------- src/common/types/value.cpp | 222 +++---- src/common/vector/auxiliary_buffer.cpp | 25 +- src/common/vector/value_vector.cpp | 33 +- src/common/vector/value_vector_utils.cpp | 58 +- src/expression_evaluator/case_evaluator.cpp | 50 +- .../function_evaluator.cpp | 4 +- .../literal_evaluator.cpp | 31 +- .../reference_evaluator.cpp | 2 +- src/function/aggregate_function.cpp | 73 ++- src/function/built_in_aggregate_functions.cpp | 64 +- src/function/built_in_vector_operations.cpp | 102 +-- src/function/vector_arithmetic_operations.cpp | 169 ++--- src/function/vector_boolean_operations.cpp | 10 +- src/function/vector_cast_operations.cpp | 132 ++-- src/function/vector_date_operations.cpp | 56 +- src/function/vector_hash_operations.cpp | 38 +- src/function/vector_list_operation.cpp | 303 +++++---- src/function/vector_string_operations.cpp | 65 +- src/function/vector_struct_operations.cpp | 23 +- src/function/vector_timestamp_operations.cpp | 10 +- src/include/binder/binder.h | 10 +- src/include/binder/ddl/bound_add_property.h | 7 +- .../binder/expression/case_expression.h | 2 +- .../existential_subquery_expression.h | 2 +- src/include/binder/expression/expression.h | 14 +- .../binder/expression/literal_expression.h | 4 +- .../binder/expression/node_expression.h | 4 +- .../binder/expression/node_rel_expression.h | 4 +- .../binder/expression/parameter_expression.h | 6 +- .../binder/expression/property_expression.h | 2 +- .../binder/expression/rel_expression.h | 4 +- .../binder/expression/variable_expression.h | 3 +- src/include/binder/expression_binder.h | 15 +- src/include/c_api/kuzu.h | 1 - src/include/catalog/catalog.h | 2 +- src/include/catalog/catalog_structs.h | 10 +- src/include/common/arrow/arrow_row_batch.h | 8 +- .../common/in_mem_overflow_buffer_utils.h | 32 - src/include/common/ser_deser.h | 4 +- src/include/common/string_utils.h | 3 +- src/include/common/type_utils.h | 6 +- src/include/common/types/ku_list.h | 4 +- src/include/common/types/types.h | 190 ++++-- src/include/common/types/value.h | 68 +- src/include/common/vector/auxiliary_buffer.h | 6 +- src/include/common/vector/value_vector.h | 34 +- .../common/vector/value_vector_utils.h | 2 +- .../function/aggregate/aggregate_function.h | 37 +- .../aggregate/built_in_aggregate_functions.h | 8 +- src/include/function/aggregate/collect.h | 8 +- .../arithmetic/vector_arithmetic_operations.h | 54 +- .../boolean/boolean_operation_executor.h | 8 +- .../function/built_in_vector_operations.h | 31 +- .../function/cast/vector_cast_operations.h | 51 +- .../comparison/vector_comparison_operations.h | 134 ++-- src/include/function/function_definition.h | 20 +- src/include/function/hash/hash_operations.h | 10 - .../interval/vector_interval_operations.h | 3 +- .../list/operations/list_position_operation.h | 4 +- .../function/list/vector_list_operations.h | 30 +- .../function/null/null_operation_executor.h | 2 +- .../schema/vector_offset_operations.h | 4 +- .../string/vector_string_operations.h | 3 +- .../struct/vector_struct_operations.h | 2 +- src/include/function/vector_operations.h | 16 +- src/include/main/query_result.h | 10 +- .../logical_operator/logical_add_property.h | 6 +- .../operator/aggregate/aggregate_hash_table.h | 16 +- .../operator/ddl/add_node_property.h | 2 +- .../processor/operator/ddl/add_property.h | 4 +- .../processor/operator/ddl/add_rel_property.h | 2 +- .../operator/hash_join/hash_join_build.h | 8 +- .../operator/order_by/key_block_merger.h | 3 +- .../processor/operator/order_by/order_by.h | 8 +- .../operator/order_by/order_by_key_encoder.h | 8 +- .../processor/operator/result_collector.h | 4 +- src/include/processor/operator/unwind.h | 4 +- .../processor/result/factorized_table.h | 5 +- src/include/storage/copier/npy_reader.h | 7 +- .../storage/copier/rel_copy_executor.h | 8 +- .../storage/copier/table_copy_executor.h | 8 +- .../in_mem_storage_structure/in_mem_column.h | 7 +- .../in_mem_column_chunk.h | 10 +- .../in_mem_storage_structure/in_mem_lists.h | 29 +- src/include/storage/index/hash_index.h | 47 +- .../storage/index/hash_index_builder.h | 38 +- src/include/storage/index/hash_index_header.h | 15 +- src/include/storage/index/hash_index_utils.h | 16 +- .../storage/storage_structure/column.h | 71 ++- .../storage_structure/disk_overflow_file.h | 11 +- .../storage/storage_structure/in_mem_file.h | 8 +- .../storage/storage_structure/lists/lists.h | 56 +- .../lists/lists_update_store.h | 4 +- .../storage_structure/storage_structure.h | 6 +- src/include/storage/storage_utils.h | 2 + src/main/connection.cpp | 11 +- src/main/query_result.cpp | 77 +-- src/planner/projection_planner.cpp | 3 +- src/planner/query_planner.cpp | 8 +- src/processor/mapper/map_expressions_scan.cpp | 2 +- src/processor/mapper/map_hash_join.cpp | 5 +- src/processor/mapper/map_order_by.cpp | 4 +- src/processor/mapper/map_unwind.cpp | 3 +- src/processor/mapper/plan_mapper.cpp | 2 +- .../aggregate/aggregate_hash_table.cpp | 59 +- .../operator/aggregate/hash_aggregate.cpp | 4 +- .../operator/aggregate/simple_aggregate.cpp | 2 +- .../operator/hash_join/hash_join_build.cpp | 8 +- .../operator/hash_join/hash_join_probe.cpp | 7 +- src/processor/operator/order_by/order_by.cpp | 5 +- .../order_by/order_by_key_encoder.cpp | 41 +- .../recursive_extend/recursive_join.cpp | 12 +- src/processor/operator/result_collector.cpp | 2 +- src/processor/operator/semi_masker.cpp | 2 +- src/processor/operator/update/create.cpp | 2 +- src/processor/processor.cpp | 7 +- src/processor/result/factorized_table.cpp | 55 +- src/storage/copier/node_copier.cpp | 8 +- src/storage/copier/node_copy_executor.cpp | 2 +- src/storage/copier/npy_reader.cpp | 27 +- src/storage/copier/rel_copy_executor.cpp | 130 ++-- src/storage/copier/table_copy_executor.cpp | 115 ++-- .../in_mem_column.cpp | 23 +- .../in_mem_column_chunk.cpp | 68 +- .../in_mem_storage_structure/in_mem_lists.cpp | 76 +-- src/storage/index/hash_index.cpp | 38 +- src/storage/index/hash_index_builder.cpp | 8 +- src/storage/index/hash_index_utils.cpp | 36 +- src/storage/storage_structure/column.cpp | 7 +- .../storage_structure/disk_overflow_file.cpp | 39 +- src/storage/storage_structure/in_mem_file.cpp | 61 +- .../lists/lists_update_store.cpp | 7 +- .../storage_structure/storage_structure.cpp | 2 +- src/storage/storage_utils.cpp | 28 + src/storage/store/node_table.cpp | 4 +- src/storage/store/rel_table.cpp | 4 +- src/storage/wal_replayer_utils.cpp | 18 +- test/c_api/data_type_test.cpp | 81 +-- test/c_api/flat_tuple_test.cpp | 6 +- test/c_api/query_result_test.cpp | 12 +- test/c_api/value_test.cpp | 59 +- test/copy/copy_lists_test.cpp | 3 +- test/copy/npy_reader_test.cpp | 22 +- test/graph_test/graph_test.cpp | 7 +- test/include/graph_test/graph_test.h | 4 +- .../order_by/key_block_merger_test.cpp | 104 +-- .../order_by/order_by_key_encoder_test.cpp | 60 +- test/processor/order_by/radix_sort_test.cpp | 88 +-- test/runner/e2e_ddl_test.cpp | 14 +- test/storage/node_insertion_deletion_test.cpp | 6 +- test/transaction/transaction_test.cpp | 11 +- .../include/py_query_result_converter.h | 6 +- tools/python_api/src_cpp/py_query_result.cpp | 41 +- .../src_cpp/py_query_result_converter.cpp | 49 +- 181 files changed, 3042 insertions(+), 2687 deletions(-) diff --git a/src/binder/bind/bind_copy.cpp b/src/binder/bind/bind_copy.cpp index 6588970262..c630929752 100644 --- a/src/binder/bind/bind_copy.cpp +++ b/src/binder/bind/bind_copy.cpp @@ -86,15 +86,16 @@ CSVReaderConfig Binder::bindParsingOptions( auto boundCopyOptionExpression = expressionBinder.bindExpression(*copyOptionExpression); assert(boundCopyOptionExpression->expressionType = LITERAL); if (copyOptionName == "HEADER") { - if (boundCopyOptionExpression->dataType.typeID != BOOL) { + if (boundCopyOptionExpression->dataType.getLogicalTypeID() != LogicalTypeID::BOOL) { throw BinderException( "The value type of parsing csv option " + copyOptionName + " must be boolean."); } csvReaderConfig.hasHeader = ((LiteralExpression&)(*boundCopyOptionExpression)).value->getValue(); - } else if (boundCopyOptionExpression->dataType.typeID == STRING && + } else if (boundCopyOptionExpression->dataType.getLogicalTypeID() == + LogicalTypeID::STRING && isValidStringParsingOption) { - if (boundCopyOptionExpression->dataType.typeID != STRING) { + if (boundCopyOptionExpression->dataType.getLogicalTypeID() != LogicalTypeID::STRING) { throw BinderException( "The value type of parsing csv option " + copyOptionName + " must be string."); } diff --git a/src/binder/bind/bind_ddl.cpp b/src/binder/bind/bind_ddl.cpp index 71b859f168..8ad6a14801 100644 --- a/src/binder/bind/bind_ddl.cpp +++ b/src/binder/bind/bind_ddl.cpp @@ -33,7 +33,8 @@ std::unique_ptr Binder::bindCreateNodeTableClause( auto primaryKeyIdx = bindPrimaryKey( createNodeTableClause.getPKColName(), createNodeTableClause.getPropertyNameDataTypes()); for (auto i = 0u; i < boundProperties.size(); ++i) { - if (boundProperties[i].dataType.typeID == SERIAL && primaryKeyIdx != i) { + if (boundProperties[i].dataType.getLogicalTypeID() == LogicalTypeID::SERIAL && + primaryKeyIdx != i) { throw BinderException("Serial property in node table must be the primary key."); } } @@ -49,7 +50,7 @@ std::unique_ptr Binder::bindCreateRelTableClause( } auto boundProperties = bindProperties(createRelClause.getPropertyNameDataTypes()); for (auto& boundProperty : boundProperties) { - if (boundProperty.dataType.typeID == SERIAL) { + if (boundProperty.dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) { throw BinderException("Serial property is not supported in rel table."); } } @@ -93,7 +94,7 @@ std::unique_ptr Binder::bindAddPropertyClause(const parser::Stat if (catalogContent->getTableSchema(tableID)->containProperty(addProperty.getPropertyName())) { throw BinderException("Property: " + addProperty.getPropertyName() + " already exists."); } - if (dataType.typeID == SERIAL) { + if (dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) { throw BinderException("Serial property in node table must be the primary key."); } auto defaultVal = ExpressionBinder::implicitCastIfNecessary( @@ -173,10 +174,10 @@ uint32_t Binder::bindPrimaryKey(const std::string& pkColName, auto primaryKey = propertyNameDataTypes[primaryKeyIdx]; StringUtils::toUpper(primaryKey.second); // We only support INT64, STRING and SERIAL column as the primary key. - switch (Types::dataTypeFromString(primaryKey.second).typeID) { - case common::INT64: - case common::STRING: - case common::SERIAL: + switch (LogicalTypeUtils::dataTypeFromString(primaryKey.second).getLogicalTypeID()) { + case common::LogicalTypeID::INT64: + case common::LogicalTypeID::STRING: + case common::LogicalTypeID::SERIAL: break; default: throw BinderException( @@ -195,30 +196,31 @@ property_id_t Binder::bindPropertyName(TableSchema* tableSchema, const std::stri tableSchema->tableName + " table doesn't have property: " + propertyName + "."); } -DataType Binder::bindDataType(const std::string& dataType) { - auto boundType = Types::dataTypeFromString(dataType); - if (boundType.typeID == common::FIXED_LIST) { - auto validNumericTypes = common::DataType::getNumericalTypeIDs(); - auto fixedListTypeInfo = reinterpret_cast(boundType.getExtraTypeInfo()); +LogicalType Binder::bindDataType(const std::string& dataType) { + auto boundType = LogicalTypeUtils::dataTypeFromString(dataType); + if (boundType.getLogicalTypeID() == common::LogicalTypeID::FIXED_LIST) { + auto validNumericTypes = common::LogicalType::getNumericalLogicalTypeIDs(); + auto childType = common::FixedListType::getChildType(&boundType); + auto numElementsInList = common::FixedListType::getNumElementsInList(&boundType); if (find(validNumericTypes.begin(), validNumericTypes.end(), - boundType.getChildType()->typeID) == validNumericTypes.end()) { + childType->getLogicalTypeID()) == validNumericTypes.end()) { throw common::BinderException( "The child type of a fixed list must be a numeric type. Given: " + - common::Types::dataTypeToString(*boundType.getChildType()) + "."); + common::LogicalTypeUtils::dataTypeToString(*childType) + "."); } - if (fixedListTypeInfo->getFixedNumElementsInList() == 0) { + if (numElementsInList == 0) { // Note: the parser already guarantees that the number of elements is a non-negative // number. However, we still need to check whether the number of elements is 0. throw common::BinderException( "The number of elements in a fixed list must be greater than 0. Given: " + - std::to_string(fixedListTypeInfo->getFixedNumElementsInList()) + "."); + std::to_string(numElementsInList) + "."); } auto numElementsPerPage = storage::PageUtils::getNumElementsInAPage( - Types::getDataTypeSize(boundType), true /* hasNull */); + storage::StorageUtils::getDataTypeSize(boundType), true /* hasNull */); if (numElementsPerPage == 0) { throw common::BinderException( StringUtils::string_format("Cannot store a fixed list of size {} in a page.", - Types::getDataTypeSize(boundType))); + storage::StorageUtils::getDataTypeSize(boundType))); } } return boundType; diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index 02c485fdcd..47d3ba2411 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -110,7 +110,7 @@ void Binder::bindQueryRel(const RelPattern& relPattern, auto parsedName = relPattern.getVariableName(); if (variablesInScope.contains(parsedName)) { auto prevVariable = variablesInScope.at(parsedName); - ExpressionBinder::validateExpectedDataType(*prevVariable, REL); + ExpressionBinder::validateExpectedDataType(*prevVariable, LogicalTypeID::REL); throw BinderException("Bind relationship " + parsedName + " to relationship with same name is not supported."); } @@ -131,8 +131,11 @@ void Binder::bindQueryRel(const RelPattern& relPattern, // bind variable length auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern); auto isVariableLength = !(lowerBound == 1 && upperBound == 1); - auto dataType = isVariableLength ? common::DataType(std::make_unique(INTERNAL_ID)) : - common::DataType(common::REL); + auto dataType = isVariableLength ? + common::LogicalType(LogicalTypeID::VAR_LIST, + std::make_unique( + std::make_unique(LogicalTypeID::INTERNAL_ID))) : + common::LogicalType(common::LogicalTypeID::REL); auto queryRel = make_shared(std::move(dataType), getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode, relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound); @@ -186,7 +189,7 @@ std::shared_ptr Binder::bindQueryNode( std::shared_ptr queryNode; if (variablesInScope.contains(parsedName)) { // bind to node in scope auto prevVariable = variablesInScope.at(parsedName); - ExpressionBinder::validateExpectedDataType(*prevVariable, NODE); + ExpressionBinder::validateExpectedDataType(*prevVariable, LogicalTypeID::NODE); queryNode = static_pointer_cast(prevVariable); // E.g. MATCH (a:person) MATCH (a:organisation) // We bind to single node a with both labels @@ -236,10 +239,10 @@ std::shared_ptr Binder::createQueryNode(const NodePattern& nodeP } std::vector Binder::bindTableIDs( - const std::vector& tableNames, DataTypeID nodeOrRelType) { + const std::vector& tableNames, LogicalTypeID nodeOrRelType) { std::unordered_set tableIDs; switch (nodeOrRelType) { - case NODE: { + case LogicalTypeID::NODE: { if (tableNames.empty()) { for (auto tableID : catalog.getReadOnlyVersion()->getNodeTableIDs()) { tableIDs.insert(tableID); @@ -251,7 +254,7 @@ std::vector Binder::bindTableIDs( } } break; - case REL: { + case LogicalTypeID::REL: { if (tableNames.empty()) { for (auto tableID : catalog.getReadOnlyVersion()->getRelTableIDs()) { tableIDs.insert(tableID); @@ -263,7 +266,7 @@ std::vector Binder::bindTableIDs( } break; default: throw NotImplementedException( - "bindTableIDs(" + Types::dataTypeToString(nodeOrRelType) + ")."); + "bindTableIDs(" + LogicalTypeUtils::dataTypeToString(nodeOrRelType) + ")."); } auto result = std::vector{tableIDs.begin(), tableIDs.end()}; std::sort(result.begin(), result.end()); diff --git a/src/binder/bind/bind_projection_clause.cpp b/src/binder/bind/bind_projection_clause.cpp index 1205830faf..982431bf41 100644 --- a/src/binder/bind/bind_projection_clause.cpp +++ b/src/binder/bind/bind_projection_clause.cpp @@ -32,7 +32,8 @@ std::unique_ptr Binder::bindReturnClause(const ReturnClause& auto statementResult = std::make_unique(); for (auto& expression : boundProjectionExpressions) { auto dataType = expression->getDataType(); - if (dataType.typeID == common::NODE || dataType.typeID == common::REL) { + if (dataType.getLogicalTypeID() == common::LogicalTypeID::NODE || + dataType.getLogicalTypeID() == common::LogicalTypeID::REL) { statementResult->addColumn(expression, rewriteNodeOrRelExpression(*expression)); } else { statementResult->addColumn(expression, expression_vector{expression}); @@ -67,10 +68,10 @@ expression_vector Binder::bindProjectionExpressions( } expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) { - if (expression.dataType.typeID == common::NODE) { + if (expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) { return rewriteNodeExpression(expression); } else { - assert(expression.dataType.typeID == common::REL); + assert(expression.dataType.getLogicalTypeID() == common::LogicalTypeID::REL); return rewriteRelExpression(expression); } } @@ -138,7 +139,8 @@ expression_vector Binder::bindOrderByExpressions( expression_vector boundOrderByExpressions; for (auto& expression : orderByExpressions) { auto boundExpression = expressionBinder.bindExpression(*expression); - if (boundExpression->dataType.typeID == NODE || boundExpression->dataType.typeID == REL) { + if (boundExpression->dataType.getLogicalTypeID() == LogicalTypeID::NODE || + boundExpression->dataType.getLogicalTypeID() == LogicalTypeID::REL) { throw BinderException("Cannot order by " + boundExpression->toString() + ". Order by node or rel is not supported."); } @@ -153,7 +155,8 @@ uint64_t Binder::bindSkipLimitExpression(const ParsedExpression& expression) { // We currently do not support the number of rows to skip/limit written as an expression (eg. // SKIP 3 + 2 is not supported). if (expression.getExpressionType() != LITERAL || - ((LiteralExpression&)(*boundExpression)).getDataType().typeID != INT64) { + ((LiteralExpression&)(*boundExpression)).getDataType().getLogicalTypeID() != + LogicalTypeID::INT64) { throw BinderException("The number of rows to skip/limit must be a non-negative integer."); } return ((LiteralExpression&)(*boundExpression)).value->getValue(); @@ -169,8 +172,8 @@ void Binder::addExpressionsToScope(const expression_vector& projectionExpression void Binder::resolveAnyDataTypeWithDefaultType(const expression_vector& expressions) { for (auto& expression : expressions) { - if (expression->dataType.typeID == ANY) { - ExpressionBinder::implicitCastIfNecessary(expression, STRING); + if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { + ExpressionBinder::implicitCastIfNecessary(expression, LogicalTypeID::STRING); } } } diff --git a/src/binder/bind/bind_reading_clause.cpp b/src/binder/bind/bind_reading_clause.cpp index 5e07c24b32..5055524263 100644 --- a/src/binder/bind/bind_reading_clause.cpp +++ b/src/binder/bind/bind_reading_clause.cpp @@ -50,9 +50,10 @@ std::unique_ptr Binder::bindMatchClause(const ReadingClause& std::unique_ptr Binder::bindUnwindClause(const ReadingClause& readingClause) { auto& unwindClause = (UnwindClause&)readingClause; auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression()); - boundExpression = ExpressionBinder::implicitCastIfNecessary(boundExpression, VAR_LIST); - auto aliasExpression = - createVariable(unwindClause.getAlias(), *boundExpression->dataType.getChildType()); + boundExpression = + ExpressionBinder::implicitCastIfNecessary(boundExpression, LogicalTypeID::VAR_LIST); + auto aliasExpression = createVariable( + unwindClause.getAlias(), *common::VarListType::getChildType(&boundExpression->dataType)); return make_unique(std::move(boundExpression), std::move(aliasExpression)); } diff --git a/src/binder/bind/bind_updating_clause.cpp b/src/binder/bind/bind_updating_clause.cpp index 15273f7614..513f7cf45c 100644 --- a/src/binder/bind/bind_updating_clause.cpp +++ b/src/binder/bind/bind_updating_clause.cpp @@ -113,12 +113,12 @@ std::unique_ptr Binder::bindSetClause(const UpdatingClause& for (auto i = 0u; i < setClause.getNumSetItems(); ++i) { auto setItem = setClause.getSetItem(i); auto nodeOrRel = expressionBinder.bindExpression(*setItem.first->getChild(0)); - switch (nodeOrRel->dataType.typeID) { - case DataTypeID::NODE: { + switch (nodeOrRel->dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { auto node = static_pointer_cast(nodeOrRel); boundSetClause->addSetNodeProperty(bindSetNodeProperty(node, setItem)); } break; - case DataTypeID::REL: { + case LogicalTypeID::REL: { auto rel = static_pointer_cast(nodeOrRel); boundSetClause->addSetRelProperty(bindSetRelProperty(rel, setItem)); } break; @@ -162,12 +162,12 @@ std::unique_ptr Binder::bindDeleteClause( auto boundDeleteClause = std::make_unique(); for (auto i = 0u; i < deleteClause.getNumExpressions(); ++i) { auto nodeOrRel = expressionBinder.bindExpression(*deleteClause.getExpression(i)); - switch (nodeOrRel->dataType.typeID) { - case DataTypeID::NODE: { + switch (nodeOrRel->dataType.getLogicalTypeID()) { + case LogicalTypeID::NODE: { auto deleteNode = bindDeleteNode(static_pointer_cast(nodeOrRel)); boundDeleteClause->addDeleteNode(std::move(deleteNode)); } break; - case DataTypeID::REL: { + case LogicalTypeID::REL: { auto deleteRel = bindDeleteRel(static_pointer_cast(nodeOrRel)); boundDeleteClause->addDeleteRel(std::move(deleteRel)); } break; diff --git a/src/binder/bind_expression/bind_boolean_expression.cpp b/src/binder/bind_expression/bind_boolean_expression.cpp index 36ef4b1281..16cb09eed8 100644 --- a/src/binder/bind_expression/bind_boolean_expression.cpp +++ b/src/binder/bind_expression/bind_boolean_expression.cpp @@ -21,14 +21,14 @@ std::shared_ptr ExpressionBinder::bindBooleanExpression( ExpressionType expressionType, const expression_vector& children) { expression_vector childrenAfterCast; for (auto& child : children) { - childrenAfterCast.push_back(implicitCastIfNecessary(child, BOOL)); + childrenAfterCast.push_back(implicitCastIfNecessary(child, LogicalTypeID::BOOL)); } auto functionName = expressionTypeToString(expressionType); auto execFunc = function::VectorBooleanOperations::bindExecFunction(expressionType, childrenAfterCast); auto selectFunc = function::VectorBooleanOperations::bindSelectFunction(expressionType, childrenAfterCast); - auto bindData = std::make_unique(DataType(BOOL)); + auto bindData = std::make_unique(LogicalType(LogicalTypeID::BOOL)); auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast); return make_shared(functionName, expressionType, std::move(bindData), diff --git a/src/binder/bind_expression/bind_case_expression.cpp b/src/binder/bind_expression/bind_case_expression.cpp index bfbd564fda..1c520ed3ab 100644 --- a/src/binder/bind_expression/bind_case_expression.cpp +++ b/src/binder/bind_expression/bind_case_expression.cpp @@ -42,7 +42,7 @@ std::shared_ptr ExpressionBinder::bindCaseExpression( for (auto i = 0u; i < parsedCaseExpression.getNumCaseAlternative(); ++i) { auto caseAlternative = parsedCaseExpression.getCaseAlternative(i); auto boundWhen = bindExpression(*caseAlternative->whenExpression); - boundWhen = implicitCastIfNecessary(boundWhen, common::BOOL); + boundWhen = implicitCastIfNecessary(boundWhen, common::LogicalTypeID::BOOL); auto boundThen = bindExpression(*caseAlternative->thenExpression); boundThen = implicitCastIfNecessary(boundThen, outDataType); boundCaseExpression->addCaseAlternative(boundWhen, boundThen); diff --git a/src/binder/bind_expression/bind_comparison_expression.cpp b/src/binder/bind_expression/bind_comparison_expression.cpp index b5fa610eab..53d4104719 100644 --- a/src/binder/bind_expression/bind_comparison_expression.cpp +++ b/src/binder/bind_expression/bind_comparison_expression.cpp @@ -21,7 +21,7 @@ std::shared_ptr ExpressionBinder::bindComparisonExpression( common::ExpressionType expressionType, const expression_vector& children) { auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions(); auto functionName = expressionTypeToString(expressionType); - std::vector childrenTypes; + std::vector childrenTypes; for (auto& child : children) { childrenTypes.push_back(child->dataType); } @@ -32,7 +32,7 @@ std::shared_ptr ExpressionBinder::bindComparisonExpression( implicitCastIfNecessary(children[i], function->parameterTypeIDs[i])); } auto bindData = - std::make_unique(common::DataType(function->returnTypeID)); + std::make_unique(common::LogicalType(function->returnTypeID)); auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast); return make_shared(functionName, expressionType, std::move(bindData), diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 3d44a0cbdd..09a5ca459d 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -46,7 +46,7 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( const expression_vector& children, const std::string& functionName) { auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions(); - std::vector childrenTypes; + std::vector childrenTypes; for (auto& child : children) { childrenTypes.push_back(child->dataType); } @@ -64,7 +64,8 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( if (function->bindFunc) { bindData = function->bindFunc(childrenAfterCast, function); } else { - bindData = std::make_unique(DataType(function->returnTypeID)); + bindData = + std::make_unique(LogicalType(function->returnTypeID)); } auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast); @@ -76,13 +77,14 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) { auto builtInFunctions = binder->catalog.getBuiltInAggregateFunction(); - std::vector childrenTypes; + std::vector childrenTypes; expression_vector children; for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { auto child = bindExpression(*parsedExpression.getChild(i)); // rewrite aggregate on node or rel as aggregate on their internal IDs. // e.g. COUNT(a) -> COUNT(a._id) - if (child->dataType.typeID == NODE || child->dataType.typeID == REL) { + if (child->dataType.getLogicalTypeID() == LogicalTypeID::NODE || + child->dataType.getLogicalTypeID() == LogicalTypeID::REL) { child = bindInternalIDExpression(*child); } childrenTypes.push_back(child->dataType); @@ -98,7 +100,8 @@ std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( if (function->bindFunc) { bindData = function->bindFunc(children, function); } else { - bindData = std::make_unique(DataType(function->returnTypeID)); + bindData = + std::make_unique(LogicalType(function->returnTypeID)); } return make_shared(functionName, std::move(bindData), std::move(children), function->aggregateFunction->clone(), uniqueExpressionName); @@ -123,17 +126,18 @@ std::shared_ptr ExpressionBinder::staticEvaluate( std::shared_ptr ExpressionBinder::bindInternalIDExpression( const ParsedExpression& parsedExpression) { auto child = bindExpression(*parsedExpression.getChild(0)); - validateExpectedDataType(*child, std::unordered_set{NODE, REL}); + validateExpectedDataType( + *child, std::unordered_set{LogicalTypeID::NODE, LogicalTypeID::REL}); return bindInternalIDExpression(*child); } std::shared_ptr ExpressionBinder::bindInternalIDExpression( const Expression& expression) { - if (expression.dataType.typeID == NODE) { + if (expression.dataType.getLogicalTypeID() == LogicalTypeID::NODE) { auto& node = (NodeExpression&)expression; return node.getInternalIDProperty(); } else { - assert(expression.dataType.typeID == REL); + assert(expression.dataType.getLogicalTypeID() == LogicalTypeID::REL); return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX); } } @@ -145,8 +149,8 @@ std::unique_ptr ExpressionBinder::createInternalNodeIDExpression( for (auto tableID : node.getTableIDs()) { propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID}); } - auto result = std::make_unique(DataType(INTERNAL_ID), INTERNAL_ID_SUFFIX, - node, std::move(propertyIDPerTable), false /* isPrimaryKey */); + auto result = std::make_unique(LogicalType(LogicalTypeID::INTERNAL_ID), + INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable), false /* isPrimaryKey */); return result; } @@ -154,10 +158,10 @@ std::shared_ptr ExpressionBinder::bindLabelFunction( const ParsedExpression& parsedExpression) { // bind child node auto child = bindExpression(*parsedExpression.getChild(0)); - if (child->dataType.typeID == common::NODE) { + if (child->dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) { return bindNodeLabelFunction(*child); } else { - assert(child->dataType.typeID == common::REL); + assert(child->dataType.getLogicalTypeID() == common::LogicalTypeID::REL); return bindRelLabelFunction(*child); } } @@ -189,11 +193,15 @@ std::shared_ptr ExpressionBinder::bindNodeLabelFunction(const Expres auto nodeTableIDs = catalogContent->getNodeTableIDs(); expression_vector children; children.push_back(node.getInternalIDProperty()); - auto labelsValue = std::make_unique(DataType(std::make_unique(STRING)), - populateLabelValues(nodeTableIDs, *catalogContent)); + auto labelsValue = + std::make_unique(LogicalType(LogicalTypeID::VAR_LIST, + std::make_unique( + std::make_unique(LogicalTypeID::STRING))), + populateLabelValues(nodeTableIDs, *catalogContent)); children.push_back(createLiteralExpression(std::move(labelsValue))); auto execFunc = function::LabelVectorOperation::execFunction; - auto bindData = std::make_unique(DataType(STRING)); + auto bindData = + std::make_unique(LogicalType(LogicalTypeID::STRING)); auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children); return make_shared(LABEL_FUNC_NAME, FUNCTION, std::move(bindData), std::move(children), execFunc, nullptr, uniqueExpressionName); @@ -209,11 +217,15 @@ std::shared_ptr ExpressionBinder::bindRelLabelFunction(const Express auto relTableIDs = catalogContent->getRelTableIDs(); expression_vector children; children.push_back(rel.getInternalIDProperty()); - auto labelsValue = std::make_unique(DataType(std::make_unique(STRING)), - populateLabelValues(relTableIDs, *catalogContent)); + auto labelsValue = + std::make_unique(LogicalType(LogicalTypeID::VAR_LIST, + std::make_unique( + std::make_unique(LogicalTypeID::STRING))), + populateLabelValues(relTableIDs, *catalogContent)); children.push_back(createLiteralExpression(std::move(labelsValue))); auto execFunc = function::LabelVectorOperation::execFunction; - auto bindData = std::make_unique(DataType(STRING)); + auto bindData = + std::make_unique(LogicalType(LogicalTypeID::STRING)); auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children); return make_shared(LABEL_FUNC_NAME, FUNCTION, std::move(bindData), std::move(children), execFunc, nullptr, uniqueExpressionName); diff --git a/src/binder/bind_expression/bind_null_operator_expression.cpp b/src/binder/bind_expression/bind_null_operator_expression.cpp index 245fdcc41e..0cfadc248a 100644 --- a/src/binder/bind_expression/bind_null_operator_expression.cpp +++ b/src/binder/bind_expression/bind_null_operator_expression.cpp @@ -17,7 +17,8 @@ std::shared_ptr ExpressionBinder::bindNullOperatorExpression( auto functionName = expressionTypeToString(expressionType); auto execFunc = function::VectorNullOperations::bindExecFunction(expressionType, children); auto selectFunc = function::VectorNullOperations::bindSelectFunction(expressionType, children); - auto bindData = std::make_unique(common::DataType(common::BOOL)); + auto bindData = std::make_unique( + common::LogicalType(common::LogicalTypeID::BOOL)); auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children); return make_shared(functionName, expressionType, std::move(bindData), std::move(children), std::move(execFunc), std::move(selectFunc), uniqueExpressionName); diff --git a/src/binder/bind_expression/bind_property_expression.cpp b/src/binder/bind_expression/bind_property_expression.cpp index b04f59dfb1..f6cc774ffc 100644 --- a/src/binder/bind_expression/bind_property_expression.cpp +++ b/src/binder/bind_expression/bind_property_expression.cpp @@ -21,13 +21,15 @@ std::shared_ptr ExpressionBinder::bindPropertyExpression( propertyName + " is reserved for system usage. External access is not allowed."); } auto child = bindExpression(*parsedExpression.getChild(0)); - validateExpectedDataType(*child, std::unordered_set{NODE, REL, STRUCT}); - if (NODE == child->dataType.typeID) { + validateExpectedDataType(*child, std::unordered_set{LogicalTypeID::NODE, + LogicalTypeID::REL, LogicalTypeID::STRUCT}); + auto childTypeID = child->dataType.getLogicalTypeID(); + if (LogicalTypeID::NODE == childTypeID) { return bindNodePropertyExpression(*child, propertyName); - } else if (common::REL == child->dataType.typeID) { + } else if (LogicalTypeID::REL == childTypeID) { return bindRelPropertyExpression(*child, propertyName); } else { - assert(common::STRUCT == child->dataType.typeID); + assert(LogicalTypeID::STRUCT == childTypeID); auto stringValue = std::make_unique(propertyName); return bindScalarFunctionExpression( expression_vector{child, createLiteralExpression(std::move(stringValue))}, @@ -46,7 +48,7 @@ std::shared_ptr ExpressionBinder::bindNodePropertyExpression( } static void validatePropertiesWithSameDataType(const std::vector& properties, - const DataType& dataType, const std::string& propertyName, const std::string& variableName) { + const LogicalType& dataType, const std::string& propertyName, const std::string& variableName) { for (auto& property : properties) { if (property.dataType != dataType) { throw BinderException( diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp index 50406c48f5..595c9eabf3 100644 --- a/src/binder/binder.cpp +++ b/src/binder/binder.cpp @@ -46,7 +46,7 @@ std::unique_ptr Binder::bind(const Statement& statement) { std::shared_ptr Binder::bindWhereExpression(const ParsedExpression& parsedExpression) { auto whereExpression = expressionBinder.bindExpression(parsedExpression); - ExpressionBinder::implicitCastIfNecessary(whereExpression, BOOL); + ExpressionBinder::implicitCastIfNecessary(whereExpression, LogicalTypeID::BOOL); return whereExpression; } @@ -65,7 +65,7 @@ table_id_t Binder::bindNodeTableID(const std::string& tableName) const { } std::shared_ptr Binder::createVariable( - const std::string& name, const DataType& dataType) { + const std::string& name, const LogicalType& dataType) { if (variablesInScope.contains(name)) { throw BinderException("Variable " + name + " already exists."); } @@ -124,8 +124,8 @@ void Binder::validateUnionColumnsOfTheSameType( // Check whether the dataTypes in union expressions are exactly the same in each single // query. for (auto j = 0u; j < expressionsToProject.size(); j++) { - ExpressionBinder::validateExpectedDataType( - *expressionsToProjectToCheck[j], expressionsToProject[j]->dataType.typeID); + ExpressionBinder::validateExpectedDataType(*expressionsToProjectToCheck[j], + expressionsToProject[j]->dataType.getLogicalTypeID()); } } } diff --git a/src/binder/expression/expression.cpp b/src/binder/expression/expression.cpp index d2e6a18c51..c83050becb 100644 --- a/src/binder/expression/expression.cpp +++ b/src/binder/expression/expression.cpp @@ -80,9 +80,9 @@ bool Expression::hasSubExpressionOfType( } bool ExpressionUtil::allExpressionsHaveDataType( - expression_vector& expressions, DataTypeID dataTypeID) { + expression_vector& expressions, LogicalTypeID dataTypeID) { for (auto& expression : expressions) { - if (expression->dataType.typeID != dataTypeID) { + if (expression->dataType.getLogicalTypeID() != dataTypeID) { return false; } } diff --git a/src/binder/expression_binder.cpp b/src/binder/expression_binder.cpp index d73434f831..a078c3aed8 100644 --- a/src/binder/expression_binder.cpp +++ b/src/binder/expression_binder.cpp @@ -51,11 +51,11 @@ std::shared_ptr ExpressionBinder::bindExpression( } std::shared_ptr ExpressionBinder::implicitCastIfNecessary( - const std::shared_ptr& expression, const DataType& targetType) { - if (targetType.typeID == ANY || expression->dataType == targetType) { + const std::shared_ptr& expression, const LogicalType& targetType) { + if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) { return expression; } - if (expression->dataType.typeID == ANY) { + if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { resolveAnyDataType(*expression, targetType); return expression; } @@ -63,25 +63,26 @@ std::shared_ptr ExpressionBinder::implicitCastIfNecessary( } std::shared_ptr ExpressionBinder::implicitCastIfNecessary( - const std::shared_ptr& expression, DataTypeID targetTypeID) { - if (targetTypeID == ANY || expression->dataType.typeID == targetTypeID) { + const std::shared_ptr& expression, LogicalTypeID targetTypeID) { + if (targetTypeID == LogicalTypeID::ANY || + expression->dataType.getLogicalTypeID() == targetTypeID) { return expression; } - if (expression->dataType.typeID == ANY) { - if (targetTypeID == VAR_LIST) { + if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) { + if (targetTypeID == LogicalTypeID::VAR_LIST) { // e.g. len($1) we cannot infer the child type for $1. throw BinderException("Cannot resolve recursive data type for expression " + expression->toString() + "."); } - resolveAnyDataType(*expression, DataType(targetTypeID)); + resolveAnyDataType(*expression, LogicalType(targetTypeID)); return expression; } - assert(targetTypeID != VAR_LIST); - return implicitCast(expression, DataType(targetTypeID)); + assert(targetTypeID != LogicalTypeID::VAR_LIST); + return implicitCast(expression, LogicalType(targetTypeID)); } std::shared_ptr ExpressionBinder::implicitCast( - const std::shared_ptr& expression, const common::DataType& targetType) { + const std::shared_ptr& expression, const common::LogicalType& targetType) { if (VectorCastOperations::hasImplicitCast(expression->dataType, targetType)) { auto functionName = VectorCastOperations::bindImplicitCastFuncName(targetType); auto children = expression_vector{expression}; @@ -90,17 +91,18 @@ std::shared_ptr ExpressionBinder::implicitCast( return std::make_shared(functionName, FUNCTION, std::move(bindData), std::move(children), VectorCastOperations::bindImplicitCastFunc( - expression->dataType.typeID, targetType.typeID), + expression->dataType.getLogicalTypeID(), targetType.getLogicalTypeID()), nullptr /* selectFunc */, std::move(uniqueName)); } else { - throw common::BinderException("Expression " + expression->toString() + " has data type " + - common::Types::dataTypeToString(expression->dataType) + - " but expect " + common::Types::dataTypeToString(targetType) + - ". Implicit cast is not supported."); + throw common::BinderException( + "Expression " + expression->toString() + " has data type " + + common::LogicalTypeUtils::dataTypeToString(expression->dataType) + " but expect " + + common::LogicalTypeUtils::dataTypeToString(targetType) + + ". Implicit cast is not supported."); } } -void ExpressionBinder::resolveAnyDataType(Expression& expression, const DataType& targetType) { +void ExpressionBinder::resolveAnyDataType(Expression& expression, const LogicalType& targetType) { if (expression.expressionType == PARAMETER) { // expression is parameter ((ParameterExpression&)expression).setDataType(targetType); } else { // expression is null literal @@ -110,13 +112,14 @@ void ExpressionBinder::resolveAnyDataType(Expression& expression, const DataType } void ExpressionBinder::validateExpectedDataType( - const Expression& expression, const std::unordered_set& targets) { + const Expression& expression, const std::unordered_set& targets) { auto dataType = expression.dataType; - if (!targets.contains(dataType.typeID)) { - std::vector targetsVec{targets.begin(), targets.end()}; + if (!targets.contains(dataType.getLogicalTypeID())) { + std::vector targetsVec{targets.begin(), targets.end()}; throw BinderException(expression.toString() + " has data type " + - Types::dataTypeToString(dataType.typeID) + ". " + - Types::dataTypesToString(targetsVec) + " was expected."); + LogicalTypeUtils::dataTypeToString(dataType.getLogicalTypeID()) + + ". " + LogicalTypeUtils::dataTypesToString(targetsVec) + + " was expected."); } } diff --git a/src/c_api/data_type.cpp b/src/c_api/data_type.cpp index bf94185b91..0665d49b92 100644 --- a/src/c_api/data_type.cpp +++ b/src/c_api/data_type.cpp @@ -9,15 +9,18 @@ kuzu_data_type* kuzu_data_type_create( kuzu_data_type_id id, kuzu_data_type* child_type, uint64_t fixed_num_elements_in_list) { auto* c_data_type = (kuzu_data_type*)malloc(sizeof(kuzu_data_type)); uint8_t data_type_id_u8 = id; - DataType* data_type; + LogicalType* data_type; + auto logicalTypeID = static_cast(data_type_id_u8); if (child_type == nullptr) { - data_type = new DataType(static_cast(data_type_id_u8)); + data_type = new LogicalType(logicalTypeID); } else { auto child_type_pty = - std::make_unique(*static_cast(child_type->_data_type)); - data_type = fixed_num_elements_in_list > 0 ? - new DataType(std::move(child_type_pty), fixed_num_elements_in_list) : - new DataType(std::move(child_type_pty)); + std::make_unique(*static_cast(child_type->_data_type)); + auto extraTypeInfo = fixed_num_elements_in_list > 0 ? + std::make_unique( + std::move(child_type_pty), fixed_num_elements_in_list) : + std::make_unique(std::move(child_type_pty)); + data_type = new LogicalType(logicalTypeID, std::move(extraTypeInfo)); } c_data_type->_data_type = data_type; return c_data_type; @@ -25,7 +28,7 @@ kuzu_data_type* kuzu_data_type_create( kuzu_data_type* kuzu_data_type_clone(kuzu_data_type* data_type) { auto* c_data_type = (kuzu_data_type*)malloc(sizeof(kuzu_data_type)); - c_data_type->_data_type = new DataType(*static_cast(data_type->_data_type)); + c_data_type->_data_type = new LogicalType(*static_cast(data_type->_data_type)); return c_data_type; } @@ -34,46 +37,26 @@ void kuzu_data_type_destroy(kuzu_data_type* data_type) { return; } if (data_type->_data_type != nullptr) { - delete static_cast(data_type->_data_type); + delete static_cast(data_type->_data_type); } free(data_type); } bool kuzu_data_type_equals(kuzu_data_type* data_type1, kuzu_data_type* data_type2) { - return *static_cast(data_type1->_data_type) == - *static_cast(data_type2->_data_type); + return *static_cast(data_type1->_data_type) == + *static_cast(data_type2->_data_type); } kuzu_data_type_id kuzu_data_type_get_id(kuzu_data_type* data_type) { auto data_type_id_u8 = - static_cast(static_cast(data_type->_data_type)->getTypeID()); + static_cast(static_cast(data_type->_data_type)->getLogicalTypeID()); return static_cast(data_type_id_u8); } -kuzu_data_type* kuzu_data_type_get_child_type(kuzu_data_type* data_type) { - auto parent_type = static_cast(data_type->_data_type); - if (parent_type->getTypeID() != DataTypeID::FIXED_LIST && - parent_type->getTypeID() != DataTypeID::VAR_LIST) { - return nullptr; - } - auto child_type = static_cast(data_type->_data_type)->getChildType(); - if (child_type == nullptr) { - return nullptr; - } - auto* child_type_c = (kuzu_data_type*)malloc(sizeof(kuzu_data_type)); - child_type_c->_data_type = new DataType(*child_type); - return child_type_c; -} - uint64_t kuzu_data_type_get_fixed_num_elements_in_list(kuzu_data_type* data_type) { - auto parent_type = static_cast(data_type->_data_type); - if (parent_type->getTypeID() != DataTypeID::FIXED_LIST) { - return 0; - } - auto extra_info = static_cast(data_type->_data_type)->getExtraTypeInfo(); - if (extra_info == nullptr) { + auto parent_type = static_cast(data_type->_data_type); + if (parent_type->getLogicalTypeID() != LogicalTypeID::FIXED_LIST) { return 0; } - auto fixed_list_info = dynamic_cast(extra_info); - return fixed_list_info->getFixedNumElementsInList(); + return FixedListType::getNumElementsInList(static_cast(data_type->_data_type)); } diff --git a/src/c_api/query_result.cpp b/src/c_api/query_result.cpp index a8cb7f1d42..25652e7da0 100644 --- a/src/c_api/query_result.cpp +++ b/src/c_api/query_result.cpp @@ -53,7 +53,7 @@ kuzu_data_type* kuzu_query_result_get_column_data_type( } auto column_data_type = column_data_types[index]; auto* column_data_type_c = (kuzu_data_type*)malloc(sizeof(kuzu_data_type)); - column_data_type_c->_data_type = new DataType(column_data_type); + column_data_type_c->_data_type = new LogicalType(column_data_type); return column_data_type_c; } diff --git a/src/c_api/value.cpp b/src/c_api/value.cpp index 5491515e1b..edbeabc879 100644 --- a/src/c_api/value.cpp +++ b/src/c_api/value.cpp @@ -17,7 +17,7 @@ kuzu_value* kuzu_value_create_null() { kuzu_value* kuzu_value_create_null_with_data_type(kuzu_data_type* data_type) { auto* c_value = (kuzu_value*)calloc(1, sizeof(kuzu_value)); c_value->_value = - new Value(Value::createNullValue(*static_cast(data_type->_data_type))); + new Value(Value::createNullValue(*static_cast(data_type->_data_type))); return c_value; } @@ -32,7 +32,7 @@ void kuzu_value_set_null(kuzu_value* value, bool is_null) { kuzu_value* kuzu_value_create_default(kuzu_data_type* data_type) { auto* c_value = (kuzu_value*)calloc(1, sizeof(kuzu_value)); c_value->_value = - new Value(Value::createDefaultValue(*static_cast(data_type->_data_type))); + new Value(Value::createDefaultValue(*static_cast(data_type->_data_type))); return c_value; } @@ -161,15 +161,13 @@ kuzu_value* kuzu_value_get_list_element(kuzu_value* value, uint64_t index) { uint64_t kuzu_value_get_struct_num_fields(kuzu_value* value) { auto val = static_cast(value->_value); auto data_type = val->getDataType(); - auto struct_type_info = reinterpret_cast(data_type.getExtraTypeInfo()); - return struct_type_info->getStructFields().size(); + return StructType::getNumFields(&data_type); } char* kuzu_value_get_struct_field_name(kuzu_value* value, uint64_t index) { auto val = static_cast(value->_value); auto data_type = val->getDataType(); - auto struct_type_info = reinterpret_cast(data_type.getExtraTypeInfo()); - auto struct_field_name = struct_type_info->getStructFields()[index]->getName(); + auto struct_field_name = StructType::getStructFields(&data_type)[index]->getName(); auto* c_struct_field_name = (char*)malloc(sizeof(char) * (struct_field_name.size() + 1)); strcpy(c_struct_field_name, struct_field_name.c_str()); return c_struct_field_name; @@ -181,7 +179,7 @@ kuzu_value* kuzu_value_get_struct_field_value(kuzu_value* value, uint64_t index) kuzu_data_type* kuzu_value_get_data_type(kuzu_value* value) { auto* c_data_type = (kuzu_data_type*)malloc(sizeof(kuzu_data_type)); - c_data_type->_data_type = new DataType(static_cast(value->_value)->getDataType()); + c_data_type->_data_type = new LogicalType(static_cast(value->_value)->getDataType()); return c_data_type; } diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 41755bacee..18c7aa9610 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -39,7 +39,7 @@ template<> uint64_t SerDeser::serializeValue( const Property& value, FileInfo* fileInfo, uint64_t offset) { offset = SerDeser::serializeValue(value.name, fileInfo, offset); - offset = SerDeser::serializeValue(value.dataType, fileInfo, offset); + offset = SerDeser::serializeValue(value.dataType, fileInfo, offset); offset = SerDeser::serializeValue(value.propertyID, fileInfo, offset); return SerDeser::serializeValue(value.tableID, fileInfo, offset); } @@ -48,7 +48,7 @@ template<> uint64_t SerDeser::deserializeValue( Property& value, FileInfo* fileInfo, uint64_t offset) { offset = SerDeser::deserializeValue(value.name, fileInfo, offset); - offset = SerDeser::deserializeValue(value.dataType, fileInfo, offset); + offset = SerDeser::deserializeValue(value.dataType, fileInfo, offset); offset = SerDeser::deserializeValue(value.propertyID, fileInfo, offset); return SerDeser::deserializeValue(value.tableID, fileInfo, offset); } @@ -208,7 +208,8 @@ table_id_t CatalogContent::addRelTableSchema(std::string tableName, RelMultiplic table_id_t tableID = assignNextTableID(); nodeTableSchemas[srcTableID]->addFwdRelTableID(tableID); nodeTableSchemas[dstTableID]->addBwdRelTableID(tableID); - auto relInternalIDProperty = Property(INTERNAL_ID_SUFFIX, DataType{INTERNAL_ID}); + auto relInternalIDProperty = + Property(INTERNAL_ID_SUFFIX, LogicalType{LogicalTypeID::INTERNAL_ID}); properties.insert(properties.begin(), relInternalIDProperty); for (auto i = 0u; i < properties.size(); ++i) { properties[i].propertyID = i; @@ -410,7 +411,7 @@ void Catalog::renameTable(table_id_t tableID, std::string newName) { catalogContentForWriteTrx->renameTable(tableID, std::move(newName)); } -void Catalog::addProperty(table_id_t tableID, std::string propertyName, DataType dataType) { +void Catalog::addProperty(table_id_t tableID, std::string propertyName, LogicalType dataType) { initCatalogContentForWriteTrxIfNecessary(); catalogContentForWriteTrx->getTableSchema(tableID)->addProperty( propertyName, std::move(dataType)); diff --git a/src/common/arrow/arrow_converter.cpp b/src/common/arrow/arrow_converter.cpp index f8c7198abe..83ad6270b1 100644 --- a/src/common/arrow/arrow_converter.cpp +++ b/src/common/arrow/arrow_converter.cpp @@ -58,34 +58,34 @@ void ArrowConverter::setArrowFormatForStruct( void ArrowConverter::setArrowFormat( ArrowSchemaHolder& rootHolder, ArrowSchema& child, const main::DataTypeInfo& typeInfo) { switch (typeInfo.typeID) { - case DataTypeID::BOOL: { + case LogicalTypeID::BOOL: { child.format = "b"; } break; - case DataTypeID::INT64: { + case LogicalTypeID::INT64: { child.format = "l"; } break; - case DataTypeID::INT32: { + case LogicalTypeID::INT32: { child.format = "i"; } break; - case DataTypeID::INT16: { + case LogicalTypeID::INT16: { child.format = "s"; } break; - case DataTypeID::DOUBLE: { + case LogicalTypeID::DOUBLE: { child.format = "g"; } break; - case DataTypeID::DATE: { + case LogicalTypeID::DATE: { child.format = "tdD"; } break; - case DataTypeID::TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { child.format = "tsu:"; } break; - case DataTypeID::INTERVAL: { + case LogicalTypeID::INTERVAL: { child.format = "tDm"; } break; - case DataTypeID::STRING: { + case LogicalTypeID::STRING: { child.format = "u"; } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { child.format = "+l"; child.n_children = 1; rootHolder.nestedChildren.emplace_back(); @@ -97,14 +97,14 @@ void ArrowConverter::setArrowFormat( child.children[0]->name = "l"; setArrowFormat(rootHolder, **child.children, *typeInfo.childrenTypesInfo[0]); } break; - case DataTypeID::INTERNAL_ID: - case DataTypeID::NODE: - case DataTypeID::REL: { + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: { setArrowFormatForStruct(rootHolder, child, typeInfo); } break; default: throw InternalException( - "Unsupported Arrow type " + Types::dataTypeToString(typeInfo.typeID)); + "Unsupported Arrow type " + LogicalTypeUtils::dataTypeToString(typeInfo.typeID)); } } diff --git a/src/common/arrow/arrow_row_batch.cpp b/src/common/arrow/arrow_row_batch.cpp index 185bf1c2d3..af6272c234 100644 --- a/src/common/arrow/arrow_row_batch.cpp +++ b/src/common/arrow/arrow_row_batch.cpp @@ -1,6 +1,7 @@ #include "common/arrow/arrow_row_batch.h" #include "common/types/value.h" +#include "storage/storage_utils.h" namespace kuzu { namespace common { @@ -15,22 +16,23 @@ ArrowRowBatch::ArrowRowBatch( } } -template +// TODO(Ziyi): use physical type instead of logical type here. +template void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeNullBits(vector->validity, capacity); - vector->data.reserve(Types::getDataTypeSize(DT) * capacity); + vector->data.reserve(storage::StorageUtils::getDataTypeSize(LogicalType{DT}) * capacity); } template<> -void ArrowRowBatch::templateInitializeVector( +void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeNullBits(vector->validity, capacity); vector->data.reserve(getNumBytesForBits(capacity)); } template<> -void ArrowRowBatch::templateInitializeVector( +void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeNullBits(vector->validity, capacity); // Initialize offsets and string values buffer. @@ -40,7 +42,7 @@ void ArrowRowBatch::templateInitializeVector( } template<> -void ArrowRowBatch::templateInitializeVector( +void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeNullBits(vector->validity, capacity); assert(typeInfo.childrenTypesInfo.size() == 1); @@ -62,19 +64,19 @@ void ArrowRowBatch::initializeStructVector( } template<> -void ArrowRowBatch::templateInitializeVector( +void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeStructVector(vector, typeInfo, capacity); } template<> -void ArrowRowBatch::templateInitializeVector( +void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeStructVector(vector, typeInfo, capacity); } template<> -void ArrowRowBatch::templateInitializeVector( +void ArrowRowBatch::templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity) { initializeStructVector(vector, typeInfo, capacity); } @@ -83,48 +85,49 @@ std::unique_ptr ArrowRowBatch::createVector( const main::DataTypeInfo& typeInfo, std::int64_t capacity) { auto result = std::make_unique(); switch (typeInfo.typeID) { - case BOOL: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::BOOL: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case INT64: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::INT64: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case INT32: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::INT32: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case INT16: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::INT16: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case DOUBLE: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::DOUBLE: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case DATE: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::DATE: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case TIMESTAMP: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::TIMESTAMP: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case INTERVAL: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::INTERVAL: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case STRING: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::STRING: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case VAR_LIST: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::VAR_LIST: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case INTERNAL_ID: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::INTERNAL_ID: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case NODE: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::NODE: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; - case REL: { - templateInitializeVector(result.get(), typeInfo, capacity); + case LogicalTypeID::REL: { + templateInitializeVector(result.get(), typeInfo, capacity); } break; default: { - throw RuntimeException( - "Invalid data type " + Types::dataTypeToString(typeInfo.typeID) + " for arrow export."); + throw RuntimeException("Invalid data type " + + LogicalTypeUtils::dataTypeToString(typeInfo.typeID) + + " for arrow export."); } } return std::move(result); @@ -157,15 +160,15 @@ void ArrowRowBatch::appendValue( vector->numValues++; } -template +template void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { - auto valSize = Types::getDataTypeSize(DT); + auto valSize = storage::StorageUtils::getDataTypeSize(LogicalType{DT}); std::memcpy(vector->data.data() + pos * valSize, &value->val, valSize); } template<> -void ArrowRowBatch::templateCopyNonNullValue( +void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { if (value->val.booleanVal) { setBitToOne(vector->data.data(), pos); @@ -175,7 +178,7 @@ void ArrowRowBatch::templateCopyNonNullValue( } template<> -void ArrowRowBatch::templateCopyNonNullValue( +void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { auto offsets = (std::uint32_t*)vector->data.data(); auto strLength = value->strVal.length(); @@ -185,7 +188,7 @@ void ArrowRowBatch::templateCopyNonNullValue( } template<> -void ArrowRowBatch::templateCopyNonNullValue( +void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { vector->data.resize((pos + 2) * sizeof(std::uint32_t)); auto offsets = (std::uint32_t*)vector->data.data(); @@ -199,9 +202,10 @@ void ArrowRowBatch::templateCopyNonNullValue( for (auto i = currentNumBytesForChildValidity; i < numBytesForChildValidity; i++) { vector->childData[0]->validity.data()[i] = 0xFF; // Init each value to be valid (as 1). } - if (typeInfo.childrenTypesInfo[0]->typeID != VAR_LIST) { + if (typeInfo.childrenTypesInfo[0]->typeID != LogicalTypeID::VAR_LIST) { vector->childData[0]->data.resize( - numChildElements * Types::getDataTypeSize(typeInfo.childrenTypesInfo[0]->typeID)); + numChildElements * storage::StorageUtils::getDataTypeSize( + LogicalType{typeInfo.childrenTypesInfo[0]->typeID})); } for (auto i = 0u; i < numElements; i++) { appendValue(vector->childData[0].get(), *typeInfo.childrenTypesInfo[0], @@ -210,7 +214,7 @@ void ArrowRowBatch::templateCopyNonNullValue( } template<> -void ArrowRowBatch::templateCopyNonNullValue( +void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { auto nodeID = value->getValue(); Value offsetVal((std::int64_t)nodeID.offset); @@ -220,7 +224,7 @@ void ArrowRowBatch::templateCopyNonNullValue( } template<> -void ArrowRowBatch::templateCopyNonNullValue( +void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { appendValue( vector->childData[0].get(), *typeInfo.childrenTypesInfo[0], value->nodeVal->getNodeIDVal()); @@ -235,7 +239,7 @@ void ArrowRowBatch::templateCopyNonNullValue( } template<> -void ArrowRowBatch::templateCopyNonNullValue( +void ArrowRowBatch::templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { appendValue(vector->childData[0].get(), *typeInfo.childrenTypesInfo[0], value->relVal->getSrcNodeIDVal()); @@ -252,53 +256,54 @@ void ArrowRowBatch::templateCopyNonNullValue( void ArrowRowBatch::copyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos) { switch (typeInfo.typeID) { - case BOOL: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::BOOL: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case INT64: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::INT64: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case INT32: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::INT32: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case INT16: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::INT16: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case DOUBLE: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::DOUBLE: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case DATE: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::DATE: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case TIMESTAMP: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::TIMESTAMP: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case INTERVAL: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::INTERVAL: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case STRING: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::STRING: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case VAR_LIST: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::VAR_LIST: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case INTERNAL_ID: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::INTERNAL_ID: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case NODE: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::NODE: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; - case REL: { - templateCopyNonNullValue(vector, typeInfo, value, pos); + case LogicalTypeID::REL: { + templateCopyNonNullValue(vector, typeInfo, value, pos); } break; default: { - throw RuntimeException( - "Invalid data type " + Types::dataTypeToString(value->dataType) + " for arrow export."); + throw RuntimeException("Invalid data type " + + LogicalTypeUtils::dataTypeToString(value->dataType) + + " for arrow export."); } } } -template +template void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) { // TODO(Guodong): make this as a function. setBitToZero(vector->validity.data(), pos); @@ -306,7 +311,8 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) } template<> -void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) { +void ArrowRowBatch::templateCopyNullValue( + ArrowVector* vector, std::int64_t pos) { auto offsets = (std::uint32_t*)vector->data.data(); offsets[pos + 1] = offsets[pos]; setBitToZero(vector->validity.data(), pos); @@ -314,7 +320,8 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int6 } template<> -void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::int64_t pos) { +void ArrowRowBatch::templateCopyNullValue( + ArrowVector* vector, std::int64_t pos) { auto offsets = (std::uint32_t*)vector->data.data(); offsets[pos + 1] = offsets[pos]; setBitToZero(vector->validity.data(), pos); @@ -322,49 +329,50 @@ void ArrowRowBatch::templateCopyNullValue(ArrowVector* vector, std::in } void ArrowRowBatch::copyNullValue(ArrowVector* vector, Value* value, std::int64_t pos) { - switch (value->dataType.typeID) { - case BOOL: { - templateCopyNullValue(vector, pos); + switch (value->dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + templateCopyNullValue(vector, pos); } break; - case INT64: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::INT64: { + templateCopyNullValue(vector, pos); } break; - case INT32: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::INT32: { + templateCopyNullValue(vector, pos); } break; - case INT16: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::INT16: { + templateCopyNullValue(vector, pos); } break; - case DOUBLE: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::DOUBLE: { + templateCopyNullValue(vector, pos); } break; - case DATE: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::DATE: { + templateCopyNullValue(vector, pos); } break; - case TIMESTAMP: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::TIMESTAMP: { + templateCopyNullValue(vector, pos); } break; - case INTERVAL: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::INTERVAL: { + templateCopyNullValue(vector, pos); } break; - case STRING: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::STRING: { + templateCopyNullValue(vector, pos); } break; - case VAR_LIST: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::VAR_LIST: { + templateCopyNullValue(vector, pos); } break; - case INTERNAL_ID: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::INTERNAL_ID: { + templateCopyNullValue(vector, pos); } break; - case NODE: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::NODE: { + templateCopyNullValue(vector, pos); } break; - case REL: { - templateCopyNullValue(vector, pos); + case LogicalTypeID::REL: { + templateCopyNullValue(vector, pos); } break; default: { - throw RuntimeException( - "Invalid data type " + Types::dataTypeToString(value->dataType) + " for arrow export."); + throw RuntimeException("Invalid data type " + + LogicalTypeUtils::dataTypeToString(value->dataType) + + " for arrow export."); } } } @@ -394,7 +402,7 @@ static std::unique_ptr createArrayFromVector(ArrowVector& vector) { return std::move(result); } -template +template ArrowArray* ArrowRowBatch::templateCreateArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { auto result = createArrayFromVector(vector); @@ -403,7 +411,7 @@ ArrowArray* ArrowRowBatch::templateCreateArray( } template<> -ArrowArray* ArrowRowBatch::templateCreateArray( +ArrowArray* ArrowRowBatch::templateCreateArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { auto result = createArrayFromVector(vector); result->n_buffers = 3; @@ -413,7 +421,7 @@ ArrowArray* ArrowRowBatch::templateCreateArray( } template<> -ArrowArray* ArrowRowBatch::templateCreateArray( +ArrowArray* ArrowRowBatch::templateCreateArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { auto result = createArrayFromVector(vector); vector.childPointers.resize(1); @@ -441,19 +449,19 @@ ArrowArray* ArrowRowBatch::convertStructVectorToArray( } template<> -ArrowArray* ArrowRowBatch::templateCreateArray( +ArrowArray* ArrowRowBatch::templateCreateArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { return convertStructVectorToArray(vector, typeInfo); } template<> -ArrowArray* ArrowRowBatch::templateCreateArray( +ArrowArray* ArrowRowBatch::templateCreateArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { return convertStructVectorToArray(vector, typeInfo); } template<> -ArrowArray* ArrowRowBatch::templateCreateArray( +ArrowArray* ArrowRowBatch::templateCreateArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { return convertStructVectorToArray(vector, typeInfo); } @@ -461,48 +469,49 @@ ArrowArray* ArrowRowBatch::templateCreateArray( ArrowArray* ArrowRowBatch::convertVectorToArray( ArrowVector& vector, const main::DataTypeInfo& typeInfo) { switch (typeInfo.typeID) { - case BOOL: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::BOOL: { + return templateCreateArray(vector, typeInfo); } - case INT64: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::INT64: { + return templateCreateArray(vector, typeInfo); } - case INT32: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::INT32: { + return templateCreateArray(vector, typeInfo); } - case INT16: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::INT16: { + return templateCreateArray(vector, typeInfo); } - case DOUBLE: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::DOUBLE: { + return templateCreateArray(vector, typeInfo); } - case DATE: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::DATE: { + return templateCreateArray(vector, typeInfo); } - case TIMESTAMP: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::TIMESTAMP: { + return templateCreateArray(vector, typeInfo); } - case INTERVAL: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::INTERVAL: { + return templateCreateArray(vector, typeInfo); } - case STRING: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::STRING: { + return templateCreateArray(vector, typeInfo); } - case VAR_LIST: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::VAR_LIST: { + return templateCreateArray(vector, typeInfo); } - case INTERNAL_ID: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::INTERNAL_ID: { + return templateCreateArray(vector, typeInfo); } - case NODE: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::NODE: { + return templateCreateArray(vector, typeInfo); } - case REL: { - return templateCreateArray(vector, typeInfo); + case LogicalTypeID::REL: { + return templateCreateArray(vector, typeInfo); } default: { - throw RuntimeException( - "Invalid data type " + Types::dataTypeToString(typeInfo.typeID) + " for arrow export."); + throw RuntimeException("Invalid data type " + + LogicalTypeUtils::dataTypeToString(typeInfo.typeID) + + " for arrow export."); } } } diff --git a/src/common/file_utils.cpp b/src/common/file_utils.cpp index 4574fd06dd..89c9ca2545 100644 --- a/src/common/file_utils.cpp +++ b/src/common/file_utils.cpp @@ -2,7 +2,6 @@ #include "common/exception.h" #include "common/string_utils.h" -#include "common/utils.h" namespace kuzu { namespace common { @@ -61,6 +60,7 @@ void FileUtils::overwriteFile(const std::string& from, const std::string& to) { void FileUtils::readFromFile( FileInfo* fileInfo, void* buffer, uint64_t numBytes, uint64_t position) { auto numBytesRead = pread(fileInfo->fd, buffer, numBytes, position); + if (numBytesRead != numBytes && getFileSize(fileInfo->fd) != position + numBytesRead) { throw Exception( StringUtils::string_format("Cannot read from file: {} fileDescriptor: {} " diff --git a/src/common/in_mem_overflow_buffer_utils.cpp b/src/common/in_mem_overflow_buffer_utils.cpp index 680f939b3e..9fcd8bef57 100644 --- a/src/common/in_mem_overflow_buffer_utils.cpp +++ b/src/common/in_mem_overflow_buffer_utils.cpp @@ -15,38 +15,5 @@ void InMemOverflowBufferUtils::copyString( dest.set(src); } -void InMemOverflowBufferUtils::copyListRecursiveIfNested(const ku_list_t& src, ku_list_t& dst, - const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer, uint32_t srcStartIdx, - uint32_t srcEndIdx) { - if (src.size == 0) { - dst.size = 0; - return; - } - if (srcEndIdx == UINT32_MAX) { - srcEndIdx = src.size - 1; - } - assert(srcEndIdx < src.size); - auto numElements = srcEndIdx - srcStartIdx + 1; - auto elementSize = Types::getDataTypeSize(*dataType.getChildType()); - InMemOverflowBufferUtils::allocateSpaceForList( - dst, numElements * elementSize, inMemOverflowBuffer); - memcpy((uint8_t*)dst.overflowPtr, (uint8_t*)src.overflowPtr + srcStartIdx * elementSize, - numElements * elementSize); - dst.size = numElements; - if (dataType.getChildType()->typeID == STRING) { - for (auto i = 0u; i < dst.size; i++) { - InMemOverflowBufferUtils::copyString(((ku_string_t*)src.overflowPtr)[i + srcStartIdx], - ((ku_string_t*)dst.overflowPtr)[i], inMemOverflowBuffer); - } - } - if (dataType.getChildType()->typeID == VAR_LIST) { - for (auto i = 0u; i < dst.size; i++) { - InMemOverflowBufferUtils::copyListRecursiveIfNested( - ((ku_list_t*)src.overflowPtr)[i + srcStartIdx], ((ku_list_t*)dst.overflowPtr)[i], - *dataType.getChildType(), inMemOverflowBuffer); - } - } -} - } // namespace common } // namespace kuzu diff --git a/src/common/string_utils.cpp b/src/common/string_utils.cpp index 505178d8ac..e5eae61e8f 100644 --- a/src/common/string_utils.cpp +++ b/src/common/string_utils.cpp @@ -6,12 +6,15 @@ namespace kuzu { namespace common { std::vector StringUtils::split( - const std::string& input, const std::string& delimiter) { + const std::string& input, const std::string& delimiter, bool ignoreEmptyStringParts) { auto result = std::vector(); auto prevPos = 0u; auto currentPos = input.find(delimiter, prevPos); while (currentPos != std::string::npos) { - result.push_back(input.substr(prevPos, currentPos - prevPos)); + auto stringPart = input.substr(prevPos, currentPos - prevPos); + if (!ignoreEmptyStringParts || !stringPart.empty()) { + result.push_back(input.substr(prevPos, currentPos - prevPos)); + } prevPos = currentPos + 1; currentPos = input.find(delimiter, prevPos); } diff --git a/src/common/type_utils.cpp b/src/common/type_utils.cpp index d45e279260..ee14ddeba5 100644 --- a/src/common/type_utils.cpp +++ b/src/common/type_utils.cpp @@ -27,40 +27,40 @@ bool TypeUtils::convertToBoolean(const char* data) { return false; } throw ConversionException( - prefixConversionExceptionMessage(data, BOOL) + + prefixConversionExceptionMessage(data, LogicalTypeID::BOOL) + ". Input is not equal to True or False (in a case-insensitive manner)"); } std::string TypeUtils::listValueToString( - const DataType& dataType, uint8_t* listValues, uint64_t pos) { - switch (dataType.typeID) { - case BOOL: + const LogicalType& dataType, uint8_t* listValues, uint64_t pos) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: return TypeUtils::toString(((bool*)listValues)[pos]); - case INT64: + case LogicalTypeID::INT64: return TypeUtils::toString(((int64_t*)listValues)[pos]); - case DOUBLE: + case LogicalTypeID::DOUBLE: return TypeUtils::toString(((double_t*)listValues)[pos]); - case DATE: + case LogicalTypeID::DATE: return TypeUtils::toString(((date_t*)listValues)[pos]); - case TIMESTAMP: + case LogicalTypeID::TIMESTAMP: return TypeUtils::toString(((timestamp_t*)listValues)[pos]); - case INTERVAL: + case LogicalTypeID::INTERVAL: return TypeUtils::toString(((interval_t*)listValues)[pos]); - case STRING: + case LogicalTypeID::STRING: return TypeUtils::toString(((ku_string_t*)listValues)[pos]); - case VAR_LIST: + case LogicalTypeID::VAR_LIST: return TypeUtils::toString(((ku_list_t*)listValues)[pos], dataType); default: - throw RuntimeException("Invalid data type " + Types::dataTypeToString(dataType) + + throw RuntimeException("Invalid data type " + LogicalTypeUtils::dataTypeToString(dataType) + " for TypeUtils::listValueToString."); } } -std::string TypeUtils::toString(const ku_list_t& val, const DataType& dataType) { +std::string TypeUtils::toString(const ku_list_t& val, const LogicalType& dataType) { std::string result = "["; for (auto i = 0u; i < val.size; ++i) { result += listValueToString( - *dataType.getChildType(), reinterpret_cast(val.overflowPtr), i); + *VarListType::getChildType(&dataType), reinterpret_cast(val.overflowPtr), i); result += (i == val.size - 1 ? "]" : ","); } return result; @@ -70,25 +70,26 @@ std::string TypeUtils::toString(const list_entry_t& val, void* valVector) { auto listVector = (common::ValueVector*)valVector; std::string result = "["; auto values = ListVector::getListValues(listVector, val); + auto childType = VarListType::getChildType(&listVector->dataType); for (auto i = 0u; i < val.size - 1; ++i) { - result += (listVector->dataType.getChildType()->typeID == VAR_LIST ? + result += (childType->getLogicalTypeID() == LogicalTypeID::VAR_LIST ? toString(reinterpret_cast(values)[i], ListVector::getDataVector(listVector)) : - listValueToString(*listVector->dataType.getChildType(), values, i)) + + listValueToString(*childType, values, i)) + ","; } - result += - (listVector->dataType.getChildType()->typeID == VAR_LIST ? - toString(reinterpret_cast(values)[val.size - 1], - ListVector::getDataVector(listVector)) : - listValueToString(*listVector->dataType.getChildType(), values, val.size - 1)) + - "]"; + result += (childType->getLogicalTypeID() == LogicalTypeID::VAR_LIST ? + toString(reinterpret_cast(values)[val.size - 1], + ListVector::getDataVector(listVector)) : + listValueToString(*childType, values, val.size - 1)) + + "]"; return result; } -std::string TypeUtils::prefixConversionExceptionMessage(const char* data, DataTypeID dataTypeID) { +std::string TypeUtils::prefixConversionExceptionMessage( + const char* data, LogicalTypeID dataTypeID) { return "Cannot convert string " + std::string(data) + " to " + - Types::dataTypeToString(dataTypeID) + "."; + LogicalTypeUtils::dataTypeToString(dataTypeID) + "."; } template<> @@ -101,8 +102,8 @@ bool TypeUtils::isValueEqual( } auto leftValues = ListVector::getListValues(leftVector, leftEntry); auto rightValues = ListVector::getListValues(rightVector, rightEntry); - switch (leftVector->dataType.getChildType()->typeID) { - case BOOL: { + switch (VarListType::getChildType(&leftVector->dataType)->getLogicalTypeID()) { + case LogicalTypeID::BOOL: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -110,7 +111,7 @@ bool TypeUtils::isValueEqual( } } } break; - case INT64: { + case LogicalTypeID::INT64: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -118,7 +119,23 @@ bool TypeUtils::isValueEqual( } } } break; - case DOUBLE: { + case LogicalTypeID::INT32: { + for (auto i = 0u; i < leftEntry.size; i++) { + if (!isValueEqual(reinterpret_cast(leftValues)[i], + reinterpret_cast(rightValues)[i], left, right)) { + return false; + } + } + } break; + case LogicalTypeID::INT16: { + for (auto i = 0u; i < leftEntry.size; i++) { + if (!isValueEqual(reinterpret_cast(leftValues)[i], + reinterpret_cast(rightValues)[i], left, right)) { + return false; + } + } + } break; + case LogicalTypeID::DOUBLE: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -126,7 +143,15 @@ bool TypeUtils::isValueEqual( } } } break; - case STRING: { + case LogicalTypeID::FLOAT: { + for (auto i = 0u; i < leftEntry.size; i++) { + if (!isValueEqual(reinterpret_cast(leftValues)[i], + reinterpret_cast(rightValues)[i], left, right)) { + return false; + } + } + } break; + case LogicalTypeID::STRING: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -134,7 +159,7 @@ bool TypeUtils::isValueEqual( } } } break; - case DATE: { + case LogicalTypeID::DATE: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -142,7 +167,7 @@ bool TypeUtils::isValueEqual( } } } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -150,7 +175,7 @@ bool TypeUtils::isValueEqual( } } } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], left, right)) { @@ -158,7 +183,7 @@ bool TypeUtils::isValueEqual( } } } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { for (auto i = 0u; i < leftEntry.size; i++) { if (!isValueEqual(reinterpret_cast(leftValues)[i], reinterpret_cast(rightValues)[i], @@ -170,7 +195,7 @@ bool TypeUtils::isValueEqual( } break; default: { throw RuntimeException("Unsupported data type " + - Types::dataTypeToString(leftVector->dataType) + + LogicalTypeUtils::dataTypeToString(leftVector->dataType) + " for TypeUtils::isValueEqual."); } } diff --git a/src/common/types/ku_list.cpp b/src/common/types/ku_list.cpp index 3152d3fb88..696d2c04fa 100644 --- a/src/common/types/ku_list.cpp +++ b/src/common/types/ku_list.cpp @@ -1,18 +1,18 @@ #include "common/types/ku_list.h" -#include +#include "storage/storage_utils.h" namespace kuzu { namespace common { -void ku_list_t::set(const uint8_t* values, const DataType& dataType) const { +void ku_list_t::set(const uint8_t* values, const LogicalType& dataType) const { memcpy(reinterpret_cast(overflowPtr), values, - size * Types::getDataTypeSize(*dataType.getChildType())); + size * storage::StorageUtils::getDataTypeSize(*VarListType::getChildType(&dataType))); } -void ku_list_t::set(const std::vector& parameters, DataTypeID childTypeId) { +void ku_list_t::set(const std::vector& parameters, LogicalTypeID childTypeId) { this->size = parameters.size(); - auto numBytesOfListElement = Types::getDataTypeSize(childTypeId); + auto numBytesOfListElement = storage::StorageUtils::getDataTypeSize(LogicalType{childTypeId}); for (auto i = 0u; i < parameters.size(); i++) { memcpy(reinterpret_cast(this->overflowPtr) + (i * numBytesOfListElement), parameters[i], numBytesOfListElement); diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index a18f45357e..4de507dc49 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -1,116 +1,43 @@ #include "common/types/types.h" -#include #include #include "common/exception.h" -#include "common/null_buffer.h" #include "common/ser_deser.h" #include "common/types/types_include.h" namespace kuzu { namespace common { -template<> -uint64_t SerDeser::serializeValue( - const VarListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { - return SerDeser::serializeValue(*value.getChildType(), fileInfo, offset); -} - -template<> -uint64_t SerDeser::deserializeValue(VarListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { - value.childType = std::make_unique(); - offset = SerDeser::deserializeValue(*value.getChildType(), fileInfo, offset); - return offset; -} - -template<> -uint64_t SerDeser::serializeValue( - const FixedListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { - offset = SerDeser::serializeValue(*value.getChildType(), fileInfo, offset); - return SerDeser::serializeValue(value.getFixedNumElementsInList(), fileInfo, offset); -} - -template<> -uint64_t SerDeser::deserializeValue(FixedListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { - value.childType = std::make_unique(); - offset = SerDeser::deserializeValue(*value.getChildType(), fileInfo, offset); - offset = SerDeser::deserializeValue(value.fixedNumElementsInList, fileInfo, offset); - return offset; -} - -template<> -uint64_t SerDeser::serializeValue( - const StructTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { - return serializeVector(value.fields, fileInfo, offset); -} - -template<> -uint64_t SerDeser::deserializeValue(StructTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { - return deserializeVector(value.fields, fileInfo, offset); -} - -template<> -uint64_t SerDeser::serializeValue( - const std::unique_ptr& value, FileInfo* fileInfo, uint64_t offset) { - offset = serializeValue(value->name, fileInfo, offset); - return serializeValue(*value->getType(), fileInfo, offset); -} - -template<> -uint64_t SerDeser::deserializeValue( - std::unique_ptr& value, FileInfo* fileInfo, uint64_t offset) { - value = std::make_unique(); - offset = deserializeValue(value->name, fileInfo, offset); - return deserializeValue(*value->type, fileInfo, offset); -} - -template<> -uint64_t SerDeser::serializeValue(const DataType& value, FileInfo* fileInfo, uint64_t offset) { - offset = SerDeser::serializeValue(value.typeID, fileInfo, offset); - switch (value.getTypeID()) { - case VAR_LIST: { - auto varListTypeInfo = reinterpret_cast(value.getExtraTypeInfo()); - offset = serializeValue(*varListTypeInfo, fileInfo, offset); - } break; - case FIXED_LIST: { - auto fixedListTypeInfo = reinterpret_cast(value.getExtraTypeInfo()); - offset = serializeValue(*fixedListTypeInfo, fileInfo, offset); - } break; - case STRUCT: { - auto structTypeInfo = reinterpret_cast(value.getExtraTypeInfo()); - offset = serializeValue(*structTypeInfo, fileInfo, offset); - } break; - default: - break; - } - return offset; -} - -template<> -uint64_t SerDeser::deserializeValue(DataType& value, FileInfo* fileInfo, uint64_t offset) { - offset = SerDeser::deserializeValue(value.typeID, fileInfo, offset); - switch (value.getTypeID()) { - case VAR_LIST: { - value.extraTypeInfo = std::make_unique(); - offset = deserializeValue( - *reinterpret_cast(value.getExtraTypeInfo()), fileInfo, offset); - - } break; - case FIXED_LIST: { - value.extraTypeInfo = std::make_unique(); - offset = deserializeValue( - *reinterpret_cast(value.getExtraTypeInfo()), fileInfo, offset); - } break; - case STRUCT: { - value.extraTypeInfo = std::make_unique(); - offset = deserializeValue( - *reinterpret_cast(value.getExtraTypeInfo()), fileInfo, offset); - } break; +std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType) { + switch (physicalType) { + case PhysicalTypeID::BOOL: + return "BOOL"; + case PhysicalTypeID::INT64: + return "INT64"; + case PhysicalTypeID::INT32: + return "INT32"; + case PhysicalTypeID::INT16: + return "INT16"; + case PhysicalTypeID::DOUBLE: + return "DOUBLE"; + case PhysicalTypeID::FLOAT: + return "FLOAT"; + case PhysicalTypeID::INTERVAL: + return "INTERVAL"; + case PhysicalTypeID::FIXED_LIST: + return "FIXED_LIST"; + case PhysicalTypeID::INTERNAL_ID: + return "INTERNAL_ID"; + case PhysicalTypeID::STRING: + return "STRING"; + case PhysicalTypeID::STRUCT: + return "STRUCT"; + case PhysicalTypeID::VAR_LIST: + return "VAR_LIST"; default: - break; + throw common::NotImplementedException{"Unrecognized physicalType."}; } - return offset; } bool VarListTypeInfo::operator==(const kuzu::common::VarListTypeInfo& other) const { @@ -152,8 +79,8 @@ struct_field_idx_t StructTypeInfo::getStructFieldIdx(std::string fieldName) cons return INVALID_STRUCT_FIELD_IDX; } -std::vector StructTypeInfo::getChildrenTypes() const { - std::vector childrenTypesToReturn{fields.size()}; +std::vector StructTypeInfo::getChildrenTypes() const { + std::vector childrenTypesToReturn{fields.size()}; for (auto i = 0u; i < fields.size(); i++) { childrenTypesToReturn[i] = fields[i]->getType(); } @@ -196,51 +123,59 @@ std::unique_ptr StructTypeInfo::copy() const { return std::make_unique(std::move(structFields)); } -DataType::DataType(const DataType& other) { +LogicalType::LogicalType(const LogicalType& other) { typeID = other.typeID; + physicalType = other.physicalType; if (other.extraTypeInfo != nullptr) { - extraTypeInfo = other.getExtraTypeInfo()->copy(); + extraTypeInfo = other.extraTypeInfo->copy(); } } -DataType::DataType(DataType&& other) noexcept - : typeID{other.typeID}, extraTypeInfo{std::move(other.extraTypeInfo)} {} +LogicalType::LogicalType(LogicalType&& other) noexcept + : typeID{other.typeID}, physicalType{other.physicalType}, extraTypeInfo{ + std::move(other.extraTypeInfo)} {} -std::vector DataType::getNumericalTypeIDs() { - return std::vector{INT64, INT32, INT16, DOUBLE, FLOAT}; +std::vector LogicalType::getNumericalLogicalTypeIDs() { + return std::vector{LogicalTypeID::INT64, LogicalTypeID::INT32, + LogicalTypeID::INT16, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT}; } -std::vector DataType::getAllValidComparableTypes() { - return std::vector{ - BOOL, INT64, INT32, INT16, DOUBLE, FLOAT, DATE, TIMESTAMP, INTERVAL, STRING}; +std::vector LogicalType::getAllValidComparableLogicalTypes() { + return std::vector{LogicalTypeID::BOOL, LogicalTypeID::INT64, + LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT, + LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL, + LogicalTypeID::STRING}; } -std::vector DataType::getAllValidTypeIDs() { +std::vector LogicalType::getAllValidLogicTypeIDs() { // TODO(Ziyi): Add FIX_LIST type to allValidTypeID when we support functions on VAR_LIST. - return std::vector{INTERNAL_ID, BOOL, INT64, INT32, INT16, DOUBLE, STRING, DATE, - TIMESTAMP, INTERVAL, VAR_LIST, FLOAT}; + return std::vector{LogicalTypeID::INTERNAL_ID, LogicalTypeID::BOOL, + LogicalTypeID::INT64, LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::DOUBLE, + LogicalTypeID::STRING, LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP, + LogicalTypeID::INTERVAL, LogicalTypeID::VAR_LIST, LogicalTypeID::FLOAT}; } -DataType& DataType::operator=(const DataType& other) { +LogicalType& LogicalType::operator=(const LogicalType& other) { typeID = other.typeID; + physicalType = other.physicalType; if (other.extraTypeInfo != nullptr) { extraTypeInfo = other.extraTypeInfo->copy(); } return *this; } -bool DataType::operator==(const DataType& other) const { +bool LogicalType::operator==(const LogicalType& other) const { if (typeID != other.typeID) { return false; } switch (other.typeID) { - case VAR_LIST: + case LogicalTypeID::VAR_LIST: return *reinterpret_cast(extraTypeInfo.get()) == *reinterpret_cast(other.extraTypeInfo.get()); - case FIXED_LIST: + case LogicalTypeID::FIXED_LIST: return *reinterpret_cast(extraTypeInfo.get()) == *reinterpret_cast(other.extraTypeInfo.get()); - case STRUCT: + case LogicalTypeID::STRUCT: return *reinterpret_cast(extraTypeInfo.get()) == *reinterpret_cast(other.extraTypeInfo.get()); default: @@ -248,48 +183,36 @@ bool DataType::operator==(const DataType& other) const { } } -bool DataType::operator!=(const DataType& other) const { +bool LogicalType::operator!=(const LogicalType& other) const { return !((*this) == other); } -DataType& DataType::operator=(DataType&& other) noexcept { +LogicalType& LogicalType::operator=(LogicalType&& other) noexcept { typeID = other.typeID; + physicalType = other.physicalType; extraTypeInfo = std::move(other.extraTypeInfo); return *this; } -std::unique_ptr DataType::copy() { - auto dataType = std::make_unique(typeID); +std::unique_ptr LogicalType::copy() { + auto dataType = std::make_unique(typeID); if (extraTypeInfo != nullptr) { dataType->extraTypeInfo = extraTypeInfo->copy(); } return dataType; } -ExtraTypeInfo* DataType::getExtraTypeInfo() const { - return extraTypeInfo.get(); -} - -DataTypeID DataType::getTypeID() const { - return typeID; -} - -DataType* DataType::getChildType() const { - assert(typeID == FIXED_LIST || typeID == VAR_LIST); - return reinterpret_cast(extraTypeInfo.get())->getChildType(); -} - -DataType Types::dataTypeFromString(const std::string& dataTypeString) { - DataType dataType; +LogicalType LogicalTypeUtils::dataTypeFromString(const std::string& dataTypeString) { + LogicalType dataType; if (dataTypeString.ends_with("[]")) { - dataType.typeID = VAR_LIST; - dataType.extraTypeInfo = std::make_unique(std::make_unique( + dataType.typeID = LogicalTypeID::VAR_LIST; + dataType.extraTypeInfo = std::make_unique(std::make_unique( dataTypeFromString(dataTypeString.substr(0, dataTypeString.size() - 2)))); } else if (dataTypeString.ends_with("]")) { - dataType.typeID = FIXED_LIST; + dataType.typeID = LogicalTypeID::FIXED_LIST; auto leftBracketPos = dataTypeString.find('['); auto rightBracketPos = dataTypeString.find(']'); - auto childType = std::make_unique( + auto childType = std::make_unique( dataTypeFromString(dataTypeString.substr(0, leftBracketPos))); auto fixedNumElementsInList = std::strtoll( dataTypeString.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1).c_str(), @@ -297,76 +220,77 @@ DataType Types::dataTypeFromString(const std::string& dataTypeString) { dataType.extraTypeInfo = std::make_unique(std::move(childType), fixedNumElementsInList); } else if (dataTypeString.starts_with("STRUCT")) { - dataType.typeID = STRUCT; + dataType.typeID = LogicalTypeID::STRUCT; auto leftBracketPos = dataTypeString.find('('); auto rightBracketPos = dataTypeString.find(')'); if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) { throw Exception("Cannot parse struct type: " + dataTypeString); } - std::istringstream iss{ - dataTypeString.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1)}; - std::vector> childrenTypes; - std::string field, fieldName, fieldType; - while (getline(iss, field, ',')) { - std::istringstream fieldStream{field}; - // The first word is the field name. - fieldStream >> fieldName; - // The second word is the field type. - fieldStream >> fieldType; - childrenTypes.emplace_back(std::make_unique( - fieldName, std::make_unique(dataTypeFromString(fieldType)))); + // Remove the leading and trailing brackets. + auto structTypeStr = + dataTypeString.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1); + auto structFieldsStr = common::StringUtils::split(structTypeStr, ","); + std::vector> structFields; + for (auto& structField : structFieldsStr) { + auto structFieldParts = common::StringUtils::split(structField, " "); + if (structFieldParts.size() != 2) { + throw RuntimeException("Cannot parse struct type: " + dataTypeString); + } + structFields.emplace_back(std::make_unique(structFieldParts[0], + std::make_unique(dataTypeFromString(structFieldParts[1])))); } - dataType.extraTypeInfo = std::make_unique(std::move(childrenTypes)); + dataType.extraTypeInfo = std::make_unique(std::move(structFields)); } else { dataType.typeID = dataTypeIDFromString(dataTypeString); } + dataType.setPhysicalType(); return dataType; } -DataTypeID Types::dataTypeIDFromString(const std::string& dataTypeIDString) { +LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataTypeIDString) { if ("INTERNAL_ID" == dataTypeIDString) { - return INTERNAL_ID; + return LogicalTypeID::INTERNAL_ID; } else if ("INT64" == dataTypeIDString) { - return INT64; + return LogicalTypeID::INT64; } else if ("INT32" == dataTypeIDString) { - return INT32; + return LogicalTypeID::INT32; } else if ("INT16" == dataTypeIDString) { - return INT16; + return LogicalTypeID::INT16; } else if ("INT" == dataTypeIDString) { - return INT32; + return LogicalTypeID::INT32; } else if ("DOUBLE" == dataTypeIDString) { - return DOUBLE; + return LogicalTypeID::DOUBLE; } else if ("FLOAT" == dataTypeIDString) { - return FLOAT; + return LogicalTypeID::FLOAT; } else if ("BOOLEAN" == dataTypeIDString) { - return BOOL; + return LogicalTypeID::BOOL; } else if ("STRING" == dataTypeIDString) { - return STRING; + return LogicalTypeID::STRING; } else if ("DATE" == dataTypeIDString) { - return DATE; + return LogicalTypeID::DATE; } else if ("TIMESTAMP" == dataTypeIDString) { - return TIMESTAMP; + return LogicalTypeID::TIMESTAMP; } else if ("INTERVAL" == dataTypeIDString) { - return INTERVAL; + return LogicalTypeID::INTERVAL; } else if ("SERIAL" == dataTypeIDString) { - return SERIAL; + return LogicalTypeID::SERIAL; } else { - throw InternalException("Cannot parse dataTypeID: " + dataTypeIDString); + throw NotImplementedException("Cannot parse dataTypeID: " + dataTypeIDString); } } -std::string Types::dataTypeToString(const DataType& dataType) { +std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) { switch (dataType.typeID) { - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { auto varListTypeInfo = reinterpret_cast(dataType.extraTypeInfo.get()); return dataTypeToString(*varListTypeInfo->getChildType()) + "[]"; } - case FIXED_LIST: { + case LogicalTypeID::FIXED_LIST: { auto fixedListTypeInfo = reinterpret_cast(dataType.extraTypeInfo.get()); return dataTypeToString(*fixedListTypeInfo->getChildType()) + "[" + - std::to_string(fixedListTypeInfo->getFixedNumElementsInList()) + "]"; + std::to_string(fixedListTypeInfo->getNumElementsInList()) + "]"; } - case STRUCT: { + case LogicalTypeID::STRUCT: { auto structTypeInfo = reinterpret_cast(dataType.extraTypeInfo.get()); std::string dataTypeStr = dataTypeToString(dataType.typeID) + "("; auto numFields = structTypeInfo->getChildrenTypes().size(); @@ -377,173 +301,176 @@ std::string Types::dataTypeToString(const DataType& dataType) { dataTypeStr += dataTypeToString(*structTypeInfo->getChildrenTypes()[numFields - 1]); return dataTypeStr + ")"; } - case ANY: - case NODE: - case REL: - case INTERNAL_ID: - case BOOL: - case INT64: - case INT32: - case INT16: - case DOUBLE: - case FLOAT: - case DATE: - case TIMESTAMP: - case INTERVAL: - case STRING: + case LogicalTypeID::ANY: + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::INTERNAL_ID: + case LogicalTypeID::BOOL: + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INTERVAL: + case LogicalTypeID::STRING: return dataTypeToString(dataType.typeID); default: - throw InternalException("Unsupported DataType: " + Types::dataTypeToString(dataType) + "."); + throw NotImplementedException( + "Unsupported DataType: " + LogicalTypeUtils::dataTypeToString(dataType) + "."); } } -std::string Types::dataTypeToString(DataTypeID dataTypeID) { +std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) { switch (dataTypeID) { - case ANY: + case LogicalTypeID::ANY: return "ANY"; - case NODE: + case LogicalTypeID::NODE: return "NODE"; - case REL: + case LogicalTypeID::REL: return "REL"; - case INTERNAL_ID: + case LogicalTypeID::INTERNAL_ID: return "INTERNAL_ID"; - case BOOL: + case LogicalTypeID::BOOL: return "BOOL"; - case INT64: + case LogicalTypeID::INT64: return "INT64"; - case INT32: + case LogicalTypeID::INT32: return "INT32"; - case INT16: + case LogicalTypeID::INT16: return "INT16"; - case DOUBLE: + case LogicalTypeID::DOUBLE: return "DOUBLE"; - case FLOAT: + case LogicalTypeID::FLOAT: return "FLOAT"; - case DATE: + case LogicalTypeID::DATE: return "DATE"; - case TIMESTAMP: + case LogicalTypeID::TIMESTAMP: return "TIMESTAMP"; - case INTERVAL: + case LogicalTypeID::INTERVAL: return "INTERVAL"; - case STRING: + case LogicalTypeID::STRING: return "STRING"; - case VAR_LIST: + case LogicalTypeID::VAR_LIST: return "VAR_LIST"; - case FIXED_LIST: + case LogicalTypeID::FIXED_LIST: return "FIXED_LIST"; - case STRUCT: + case LogicalTypeID::STRUCT: return "STRUCT"; - case SERIAL: + case LogicalTypeID::SERIAL: return "SERIAL"; default: - throw InternalException( - "Unsupported DataType: " + Types::dataTypeToString(dataTypeID) + "."); + throw NotImplementedException( + "Unsupported DataType: " + LogicalTypeUtils::dataTypeToString(dataTypeID) + "."); } } -std::string Types::dataTypesToString(const std::vector& dataTypes) { - std::vector dataTypeIDs; +std::string LogicalTypeUtils::dataTypesToString(const std::vector& dataTypes) { + std::vector dataTypeIDs; for (auto& dataType : dataTypes) { dataTypeIDs.push_back(dataType.typeID); } return dataTypesToString(dataTypeIDs); } -std::string Types::dataTypesToString(const std::vector& dataTypeIDs) { +std::string LogicalTypeUtils::dataTypesToString(const std::vector& dataTypeIDs) { if (dataTypeIDs.empty()) { return {""}; } - std::string result = "(" + Types::dataTypeToString(dataTypeIDs[0]); + std::string result = "(" + LogicalTypeUtils::dataTypeToString(dataTypeIDs[0]); for (auto i = 1u; i < dataTypeIDs.size(); ++i) { - result += "," + Types::dataTypeToString(dataTypeIDs[i]); + result += "," + LogicalTypeUtils::dataTypeToString(dataTypeIDs[i]); } result += ")"; return result; } -uint32_t Types::getDataTypeSize(DataTypeID dataTypeID) { - switch (dataTypeID) { - case INTERNAL_ID: - return sizeof(internalID_t); - case BOOL: - return sizeof(uint8_t); - case SERIAL: - case INT64: +uint32_t LogicalTypeUtils::getFixedTypeSize(kuzu::common::PhysicalTypeID physicalType) { + switch (physicalType) { + case PhysicalTypeID::BOOL: + return sizeof(bool); + case PhysicalTypeID::INT64: return sizeof(int64_t); - case INT32: + case PhysicalTypeID::INT32: return sizeof(int32_t); - case INT16: + case PhysicalTypeID::INT16: return sizeof(int16_t); - case DOUBLE: + case PhysicalTypeID::DOUBLE: return sizeof(double_t); - case FLOAT: + case PhysicalTypeID::FLOAT: return sizeof(float_t); - case DATE: - return sizeof(date_t); - case TIMESTAMP: - return sizeof(timestamp_t); - case INTERVAL: + case PhysicalTypeID::INTERVAL: return sizeof(interval_t); - case STRING: - return sizeof(ku_string_t); - case VAR_LIST: - return sizeof(ku_list_t); + case PhysicalTypeID::INTERNAL_ID: + return sizeof(internalID_t); default: - throw InternalException( - "Cannot infer the size of dataTypeID: " + dataTypeToString(dataTypeID) + "."); + throw NotImplementedException{"Cannot infer the size of a variable dataType."}; } } -// This function returns the size of the dataType when stored in a row layout. (e.g. -// factorizedTable). -uint32_t Types::getDataTypeSize(const DataType& dataType) { +bool LogicalTypeUtils::isNumerical(const kuzu::common::LogicalType& dataType) { switch (dataType.typeID) { - case FIXED_LIST: { - auto fixedListTypeInfo = reinterpret_cast(dataType.extraTypeInfo.get()); - return getDataTypeSize(*fixedListTypeInfo->getChildType()) * - fixedListTypeInfo->getFixedNumElementsInList(); - } - case STRUCT: { - auto structTypeInfo = reinterpret_cast(dataType.extraTypeInfo.get()); - uint32_t size = 0; - for (auto& childType : structTypeInfo->getChildrenTypes()) { - size += getDataTypeSize(*childType); - } - size += NullBuffer::getNumBytesForNullValues(structTypeInfo->getChildrenNames().size()); - return size; - } - case INTERNAL_ID: - case BOOL: - case SERIAL: - case INT64: - case INT32: - case INT16: - case DOUBLE: - case FLOAT: - case DATE: - case TIMESTAMP: - case INTERVAL: - case STRING: - case VAR_LIST: { - return getDataTypeSize(dataType.typeID); - } - default: { - throw InternalException( - "Cannot infer the size of dataTypeID: " + dataTypeToString(dataType.typeID) + "."); - } + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + return true; + default: + return false; } } -bool Types::isNumerical(const kuzu::common::DataType& dataType) { - switch (dataType.typeID) { - case INT64: - case INT32: - case INT16: - case DOUBLE: - case FLOAT: - return true; +void LogicalType::setPhysicalType() { + switch (typeID) { + case LogicalTypeID::ANY: { + physicalType = PhysicalTypeID::ANY; + } break; + case LogicalTypeID::BOOL: { + physicalType = PhysicalTypeID::BOOL; + } break; + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::SERIAL: + case LogicalTypeID::INT64: { + physicalType = PhysicalTypeID::INT64; + } break; + case LogicalTypeID::DATE: + case LogicalTypeID::INT32: { + physicalType = PhysicalTypeID::INT32; + } break; + case LogicalTypeID::INT16: { + physicalType = PhysicalTypeID::INT16; + } break; + case LogicalTypeID::DOUBLE: { + physicalType = PhysicalTypeID::DOUBLE; + } break; + case LogicalTypeID::FLOAT: { + physicalType = PhysicalTypeID::FLOAT; + } break; + case LogicalTypeID::INTERVAL: { + physicalType = PhysicalTypeID::INTERVAL; + } break; + case LogicalTypeID::FIXED_LIST: { + physicalType = PhysicalTypeID::FIXED_LIST; + } break; + case LogicalTypeID::INTERNAL_ID: { + physicalType = PhysicalTypeID::INTERNAL_ID; + } break; + case LogicalTypeID::STRING: { + physicalType = PhysicalTypeID::STRING; + } break; + case LogicalTypeID::VAR_LIST: { + physicalType = PhysicalTypeID::VAR_LIST; + } break; + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::STRUCT: { + physicalType = PhysicalTypeID::STRUCT; + } break; default: - return false; + throw NotImplementedException{ + "Unsupported LogicalType: " + LogicalTypeUtils::dataTypeToString(typeID) + "."}; } } @@ -555,5 +482,109 @@ std::string getRelDataDirectionAsString(RelDataDirection direction) { return (FWD == direction) ? "forward" : "backward"; } +// Specialized Ser/Deser functions for logical dataTypes. +template<> +uint64_t SerDeser::serializeValue( + const VarListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { + return SerDeser::serializeValue(*value.getChildType(), fileInfo, offset); +} + +template<> +uint64_t SerDeser::deserializeValue(VarListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { + value.childType = std::make_unique(); + offset = SerDeser::deserializeValue(*value.getChildType(), fileInfo, offset); + return offset; +} + +template<> +uint64_t SerDeser::serializeValue( + const FixedListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { + offset = SerDeser::serializeValue(*value.getChildType(), fileInfo, offset); + return SerDeser::serializeValue(value.getNumElementsInList(), fileInfo, offset); +} + +template<> +uint64_t SerDeser::deserializeValue(FixedListTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { + value.childType = std::make_unique(); + offset = SerDeser::deserializeValue(*value.getChildType(), fileInfo, offset); + offset = SerDeser::deserializeValue(value.fixedNumElementsInList, fileInfo, offset); + return offset; +} + +template<> +uint64_t SerDeser::serializeValue( + const StructTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { + return serializeVector(value.fields, fileInfo, offset); +} + +template<> +uint64_t SerDeser::deserializeValue(StructTypeInfo& value, FileInfo* fileInfo, uint64_t offset) { + return deserializeVector(value.fields, fileInfo, offset); +} + +template<> +uint64_t SerDeser::serializeValue( + const std::unique_ptr& value, FileInfo* fileInfo, uint64_t offset) { + offset = serializeValue(value->name, fileInfo, offset); + return serializeValue(*value->getType(), fileInfo, offset); +} + +template<> +uint64_t SerDeser::deserializeValue( + std::unique_ptr& value, FileInfo* fileInfo, uint64_t offset) { + value = std::make_unique(); + offset = deserializeValue(value->name, fileInfo, offset); + return deserializeValue(*value->type, fileInfo, offset); +} + +template<> +uint64_t SerDeser::serializeValue(const LogicalType& value, FileInfo* fileInfo, uint64_t offset) { + offset = SerDeser::serializeValue(value.getLogicalTypeID(), fileInfo, offset); + switch (value.getLogicalTypeID()) { + case LogicalTypeID::VAR_LIST: { + auto varListTypeInfo = reinterpret_cast(value.extraTypeInfo.get()); + offset = serializeValue(*varListTypeInfo, fileInfo, offset); + } break; + case LogicalTypeID::FIXED_LIST: { + auto fixedListTypeInfo = reinterpret_cast(value.extraTypeInfo.get()); + offset = serializeValue(*fixedListTypeInfo, fileInfo, offset); + } break; + case LogicalTypeID::STRUCT: { + auto structTypeInfo = reinterpret_cast(value.extraTypeInfo.get()); + offset = serializeValue(*structTypeInfo, fileInfo, offset); + } break; + default: + break; + } + return offset; +} + +template<> +uint64_t SerDeser::deserializeValue(LogicalType& value, FileInfo* fileInfo, uint64_t offset) { + offset = SerDeser::deserializeValue(value.typeID, fileInfo, offset); + value.setPhysicalType(); + switch (value.getLogicalTypeID()) { + case LogicalTypeID::VAR_LIST: { + value.extraTypeInfo = std::make_unique(); + offset = deserializeValue( + *reinterpret_cast(value.extraTypeInfo.get()), fileInfo, offset); + + } break; + case LogicalTypeID::FIXED_LIST: { + value.extraTypeInfo = std::make_unique(); + offset = deserializeValue( + *reinterpret_cast(value.extraTypeInfo.get()), fileInfo, offset); + } break; + case LogicalTypeID::STRUCT: { + value.extraTypeInfo = std::make_unique(); + offset = deserializeValue( + *reinterpret_cast(value.extraTypeInfo.get()), fileInfo, offset); + } break; + default: + break; + } + return offset; +} + } // namespace common } // namespace kuzu diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index 44446b8104..e7df9b42a0 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -2,16 +2,17 @@ #include "common/null_buffer.h" #include "common/string_utils.h" +#include "storage/storage_utils.h" namespace kuzu { namespace common { -void Value::setDataType(const DataType& dataType_) { - assert(dataType.typeID == ANY); +void Value::setDataType(const LogicalType& dataType_) { + assert(dataType.getLogicalTypeID() == LogicalTypeID::ANY); dataType = dataType_; } -DataType Value::getDataType() const { +LogicalType Value::getDataType() const { return dataType; } @@ -35,106 +36,107 @@ Value Value::createNullValue() { return {}; } -Value Value::createNullValue(DataType dataType) { +Value Value::createNullValue(LogicalType dataType) { return Value(std::move(dataType)); } -Value Value::createDefaultValue(const DataType& dataType) { - switch (dataType.typeID) { - case INT64: +Value Value::createDefaultValue(const LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: return Value((int64_t)0); - case INT32: + case LogicalTypeID::INT32: return Value((int32_t)0); - case INT16: + case LogicalTypeID::INT16: return Value((int16_t)0); - case BOOL: + case LogicalTypeID::BOOL: return Value(true); - case DOUBLE: + case LogicalTypeID::DOUBLE: return Value((double_t)0); - case DATE: + case LogicalTypeID::DATE: return Value(date_t()); - case TIMESTAMP: + case LogicalTypeID::TIMESTAMP: return Value(timestamp_t()); - case INTERVAL: + case LogicalTypeID::INTERVAL: return Value(interval_t()); - case INTERNAL_ID: + case LogicalTypeID::INTERNAL_ID: return Value(nodeID_t()); - case STRING: + case LogicalTypeID::STRING: return Value(std::string("")); - case FLOAT: + case LogicalTypeID::FLOAT: return Value((float_t)0); - case VAR_LIST: - case FIXED_LIST: - case STRUCT: + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::FIXED_LIST: + case LogicalTypeID::STRUCT: return Value(dataType, std::vector>{}); default: - throw RuntimeException("Data type " + Types::dataTypeToString(dataType) + + throw RuntimeException("Data type " + LogicalTypeUtils::dataTypeToString(dataType) + " is not supported for Value::createDefaultValue"); } } -Value::Value(bool val_) : dataType{BOOL}, isNull_{false} { +Value::Value(bool val_) : dataType{LogicalTypeID::BOOL}, isNull_{false} { val.booleanVal = val_; } -Value::Value(int16_t val_) : dataType{INT16}, isNull_{false} { +Value::Value(int16_t val_) : dataType{LogicalTypeID::INT16}, isNull_{false} { val.int16Val = val_; } -Value::Value(int32_t val_) : dataType{INT32}, isNull_{false} { +Value::Value(int32_t val_) : dataType{LogicalTypeID::INT32}, isNull_{false} { val.int32Val = val_; } -Value::Value(int64_t val_) : dataType{INT64}, isNull_{false} { +Value::Value(int64_t val_) : dataType{LogicalTypeID::INT64}, isNull_{false} { val.int64Val = val_; } -Value::Value(float_t val_) : dataType{FLOAT}, isNull_{false} { +Value::Value(float_t val_) : dataType{LogicalTypeID::FLOAT}, isNull_{false} { val.floatVal = val_; } -Value::Value(double val_) : dataType{DOUBLE}, isNull_{false} { +Value::Value(double val_) : dataType{LogicalTypeID::DOUBLE}, isNull_{false} { val.doubleVal = val_; } -Value::Value(date_t val_) : dataType{DATE}, isNull_{false} { +Value::Value(date_t val_) : dataType{LogicalTypeID::DATE}, isNull_{false} { val.dateVal = val_; } -Value::Value(kuzu::common::timestamp_t val_) : dataType{TIMESTAMP}, isNull_{false} { +Value::Value(kuzu::common::timestamp_t val_) : dataType{LogicalTypeID::TIMESTAMP}, isNull_{false} { val.timestampVal = val_; } -Value::Value(kuzu::common::interval_t val_) : dataType{INTERVAL}, isNull_{false} { +Value::Value(kuzu::common::interval_t val_) : dataType{LogicalTypeID::INTERVAL}, isNull_{false} { val.intervalVal = val_; } -Value::Value(kuzu::common::internalID_t val_) : dataType{INTERNAL_ID}, isNull_{false} { +Value::Value(kuzu::common::internalID_t val_) + : dataType{LogicalTypeID::INTERNAL_ID}, isNull_{false} { val.internalIDVal = val_; } -Value::Value(const char* val_) : dataType{STRING}, isNull_{false} { +Value::Value(const char* val_) : dataType{LogicalTypeID::STRING}, isNull_{false} { strVal = std::string(val_); } -Value::Value(const std::string& val_) : dataType{STRING}, isNull_{false} { +Value::Value(const std::string& val_) : dataType{LogicalTypeID::STRING}, isNull_{false} { strVal = val_; } -Value::Value(DataType dataType, std::vector> vals) +Value::Value(LogicalType dataType, std::vector> vals) : dataType{std::move(dataType)}, isNull_{false} { nestedTypeVal = std::move(vals); } -Value::Value(std::unique_ptr val_) : dataType{NODE}, isNull_{false} { +Value::Value(std::unique_ptr val_) : dataType{LogicalTypeID::NODE}, isNull_{false} { nodeVal = std::move(val_); } -Value::Value(std::unique_ptr val_) : dataType{REL}, isNull_{false} { +Value::Value(std::unique_ptr val_) : dataType{LogicalTypeID::REL}, isNull_{false} { relVal = std::move(val_); } -Value::Value(DataType dataType, const uint8_t* val_) +Value::Value(LogicalType dataType, const uint8_t* val_) : dataType{std::move(dataType)}, isNull_{false} { copyValueFrom(val_); } @@ -144,52 +146,52 @@ Value::Value(const Value& other) : dataType{other.dataType}, isNull_{other.isNul } void Value::copyValueFrom(const uint8_t* value) { - switch (dataType.typeID) { - case INT64: { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: { val.int64Val = *((int64_t*)value); } break; - case INT32: { + case LogicalTypeID::INT32: { val.int32Val = *((int32_t*)value); } break; - case INT16: { + case LogicalTypeID::INT16: { val.int16Val = *((int16_t*)value); } break; - case BOOL: { + case LogicalTypeID::BOOL: { val.booleanVal = *((bool*)value); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { val.doubleVal = *((double*)value); } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { val.floatVal = *((float_t*)value); } break; - case DATE: { + case LogicalTypeID::DATE: { val.dateVal = *((date_t*)value); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { val.timestampVal = *((timestamp_t*)value); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { val.intervalVal = *((interval_t*)value); } break; - case INTERNAL_ID: { + case LogicalTypeID::INTERNAL_ID: { val.internalIDVal = *((nodeID_t*)value); } break; - case STRING: { + case LogicalTypeID::STRING: { strVal = ((ku_string_t*)value)->getAsString(); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { nestedTypeVal = convertKUVarListToVector(*(ku_list_t*)value); } break; - case FIXED_LIST: { + case LogicalTypeID::FIXED_LIST: { nestedTypeVal = convertKUFixedListToVector(value); } break; - case STRUCT: { + case LogicalTypeID::STRUCT: { nestedTypeVal = convertKUStructToVector(value); } break; default: - throw RuntimeException( - "Data type " + Types::dataTypeToString(dataType) + " is not supported for Value::set"); + throw RuntimeException("Data type " + LogicalTypeUtils::dataTypeToString(dataType) + + " is not supported for Value::set"); } } @@ -200,56 +202,57 @@ void Value::copyValueFrom(const Value& other) { } isNull_ = false; assert(dataType == other.dataType); - switch (dataType.typeID) { - case BOOL: { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { val.booleanVal = other.val.booleanVal; } break; - case INT64: { + case LogicalTypeID::INT64: { val.int64Val = other.val.int64Val; } break; - case INT32: { + case LogicalTypeID::INT32: { val.int32Val = other.val.int32Val; } break; - case INT16: { + case LogicalTypeID::INT16: { val.int16Val = other.val.int16Val; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { val.doubleVal = other.val.doubleVal; } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { val.floatVal = other.val.floatVal; } break; - case DATE: { + case LogicalTypeID::DATE: { val.dateVal = other.val.dateVal; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { val.timestampVal = other.val.timestampVal; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { val.intervalVal = other.val.intervalVal; } break; - case INTERNAL_ID: { + case LogicalTypeID::INTERNAL_ID: { val.internalIDVal = other.val.internalIDVal; } break; - case STRING: { + case LogicalTypeID::STRING: { strVal = other.strVal; } break; - case VAR_LIST: - case FIXED_LIST: - case STRUCT: { + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::FIXED_LIST: + case LogicalTypeID::STRUCT: { for (auto& value : other.nestedTypeVal) { nestedTypeVal.push_back(value->copy()); } } break; - case NODE: { + case LogicalTypeID::NODE: { nodeVal = other.nodeVal->copy(); } break; - case REL: { + case LogicalTypeID::REL: { relVal = other.relVal->copy(); } break; default: throw NotImplementedException("Value::Value(const Value&) for type " + - Types::dataTypeToString(dataType) + " is not implemented."); + LogicalTypeUtils::dataTypeToString(dataType) + + " is not implemented."); } } @@ -261,31 +264,31 @@ std::string Value::toString() const { if (isNull_) { return ""; } - switch (dataType.typeID) { - case BOOL: + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: return TypeUtils::toString(val.booleanVal); - case INT64: + case LogicalTypeID::INT64: return TypeUtils::toString(val.int64Val); - case INT32: + case LogicalTypeID::INT32: return TypeUtils::toString(val.int32Val); - case INT16: + case LogicalTypeID::INT16: return TypeUtils::toString(val.int16Val); - case DOUBLE: + case LogicalTypeID::DOUBLE: return TypeUtils::toString(val.doubleVal); - case FLOAT: + case LogicalTypeID::FLOAT: return TypeUtils::toString(val.floatVal); - case DATE: + case LogicalTypeID::DATE: return TypeUtils::toString(val.dateVal); - case TIMESTAMP: + case LogicalTypeID::TIMESTAMP: return TypeUtils::toString(val.timestampVal); - case INTERVAL: + case LogicalTypeID::INTERVAL: return TypeUtils::toString(val.intervalVal); - case INTERNAL_ID: + case LogicalTypeID::INTERNAL_ID: return TypeUtils::toString(val.internalIDVal); - case STRING: + case LogicalTypeID::STRING: return strVal; - case VAR_LIST: - case FIXED_LIST: { + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::FIXED_LIST: { std::string result = "["; for (auto i = 0u; i < nestedTypeVal.size(); ++i) { result += nestedTypeVal[i]->toString(); @@ -296,10 +299,9 @@ std::string Value::toString() const { result += "]"; return result; } - case STRUCT: { + case LogicalTypeID::STRUCT: { std::string result = "{"; - auto structTypeInfo = reinterpret_cast(dataType.getExtraTypeInfo()); - auto childrenNames = structTypeInfo->getChildrenNames(); + auto childrenNames = common::StructType::getStructFieldNames(&dataType); for (auto i = 0u; i < nestedTypeVal.size(); ++i) { result += childrenNames[i]; result += ": "; @@ -311,29 +313,30 @@ std::string Value::toString() const { result += "}"; return result; } - case NODE: + case LogicalTypeID::NODE: return nodeVal->toString(); - case REL: + case LogicalTypeID::REL: return relVal->toString(); default: throw NotImplementedException("Value::toString for type " + - Types::dataTypeToString(dataType) + " is not implemented."); + LogicalTypeUtils::dataTypeToString(dataType) + + " is not implemented."); } } -Value::Value() : dataType{ANY}, isNull_{true} {} +Value::Value() : dataType{LogicalTypeID::ANY}, isNull_{true} {} -Value::Value(DataType dataType) : dataType{std::move(dataType)}, isNull_{true} {} +Value::Value(LogicalType dataType) : dataType{std::move(dataType)}, isNull_{true} {} std::vector> Value::convertKUVarListToVector(ku_list_t& list) const { std::vector> listResultValue; - auto numBytesPerElement = Types::getDataTypeSize(*dataType.getChildType()); + auto childType = VarListType::getChildType(&dataType); + auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(*childType); auto listNullBytes = reinterpret_cast(list.overflowPtr); auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size); auto listValues = listNullBytes + numBytesForNullValues; for (auto i = 0; i < list.size; i++) { - auto childValue = - std::make_unique(Value::createDefaultValue(*dataType.getChildType())); + auto childValue = std::make_unique(Value::createDefaultValue(*childType)); if (NullBuffer::isNull(listNullBytes, i)) { childValue->setNull(); } else { @@ -347,24 +350,24 @@ std::vector> Value::convertKUVarListToVector(ku_list_t& l std::vector> Value::convertKUFixedListToVector( const uint8_t* fixedList) const { - auto fixedListTypeInfo = reinterpret_cast(dataType.getExtraTypeInfo()); - std::vector> fixedListResultVal{ - fixedListTypeInfo->getFixedNumElementsInList()}; - auto numBytesPerElement = Types::getDataTypeSize(*dataType.getChildType()); - switch (dataType.getChildType()->typeID) { - case common::DataTypeID::INT64: { + auto numElementsInList = FixedListType::getNumElementsInList(&dataType); + std::vector> fixedListResultVal{numElementsInList}; + auto childType = FixedListType::getChildType(&dataType); + auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(*childType); + switch (childType->getLogicalTypeID()) { + case common::LogicalTypeID::INT64: { putValuesIntoVector(fixedListResultVal, fixedList, numBytesPerElement); } break; - case common::DataTypeID::INT32: { + case common::LogicalTypeID::INT32: { putValuesIntoVector(fixedListResultVal, fixedList, numBytesPerElement); } break; - case common::DataTypeID::INT16: { + case common::LogicalTypeID::INT16: { putValuesIntoVector(fixedListResultVal, fixedList, numBytesPerElement); } break; - case common::DataTypeID::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { putValuesIntoVector(fixedListResultVal, fixedList, numBytesPerElement); } break; - case common::DataTypeID::FLOAT: { + case common::LogicalTypeID::FLOAT: { putValuesIntoVector(fixedListResultVal, fixedList, numBytesPerElement); } break; default: @@ -374,9 +377,8 @@ std::vector> Value::convertKUFixedListToVector( } std::vector> Value::convertKUStructToVector(const uint8_t* kuStruct) const { - auto structTypeInfo = reinterpret_cast(dataType.getExtraTypeInfo()); std::vector> structVal; - auto childrenTypes = structTypeInfo->getChildrenTypes(); + auto childrenTypes = StructType::getStructFieldTypes(&dataType); auto numFields = childrenTypes.size(); auto structNullValues = kuStruct; auto structValues = structNullValues + NullBuffer::getNumBytesForNullValues(numFields); @@ -388,7 +390,7 @@ std::vector> Value::convertKUStructToVector(const uint8_t childValue->copyValueFrom(structValues); } structVal.emplace_back(std::move(childValue)); - structValues += Types::getDataTypeSize(*childrenTypes[i]); + structValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[i]); } return structVal; } diff --git a/src/common/vector/auxiliary_buffer.cpp b/src/common/vector/auxiliary_buffer.cpp index fc39cef395..96c1e27350 100644 --- a/src/common/vector/auxiliary_buffer.cpp +++ b/src/common/vector/auxiliary_buffer.cpp @@ -8,22 +8,22 @@ namespace common { void StringAuxiliaryBuffer::addString( common::ValueVector* vector, uint32_t pos, char* value, uint64_t len) const { - assert(vector->dataType.typeID == STRING); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::STRING); auto& entry = ((ku_string_t*)vector->getData())[pos]; InMemOverflowBufferUtils::copyString(value, len, entry, *inMemOverflowBuffer); } StructAuxiliaryBuffer::StructAuxiliaryBuffer( - const DataType& type, storage::MemoryManager* memoryManager) { - auto structTypeInfo = reinterpret_cast(type.getExtraTypeInfo()); - childrenVectors.reserve(structTypeInfo->getChildrenTypes().size()); - for (auto structFieldType : structTypeInfo->getChildrenTypes()) { + const LogicalType& type, storage::MemoryManager* memoryManager) { + auto structFieldTypes = StructType::getStructFieldTypes(&type); + childrenVectors.reserve(structFieldTypes.size()); + for (auto structFieldType : structFieldTypes) { childrenVectors.push_back(std::make_shared(*structFieldType, memoryManager)); } } ListAuxiliaryBuffer::ListAuxiliaryBuffer( - const DataType& dataVectorType, storage::MemoryManager* memoryManager) + const LogicalType& dataVectorType, storage::MemoryManager* memoryManager) : capacity{common::DEFAULT_VECTOR_CAPACITY}, size{0}, dataVector{std::make_unique( dataVectorType, memoryManager)} {} @@ -45,14 +45,15 @@ list_entry_t ListAuxiliaryBuffer::addList(uint64_t listSize) { } std::unique_ptr AuxiliaryBufferFactory::getAuxiliaryBuffer( - DataType& type, storage::MemoryManager* memoryManager) { - switch (type.typeID) { - case STRING: + LogicalType& type, storage::MemoryManager* memoryManager) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::STRING: return std::make_unique(memoryManager); - case STRUCT: + case LogicalTypeID::STRUCT: return std::make_unique(type, memoryManager); - case VAR_LIST: - return std::make_unique(*type.getChildType(), memoryManager); + case LogicalTypeID::VAR_LIST: + return std::make_unique( + *VarListType::getChildType(&type), memoryManager); default: return nullptr; } diff --git a/src/common/vector/value_vector.cpp b/src/common/vector/value_vector.cpp index 44e804163c..36f342cdb7 100644 --- a/src/common/vector/value_vector.cpp +++ b/src/common/vector/value_vector.cpp @@ -5,9 +5,9 @@ namespace kuzu { namespace common { -ValueVector::ValueVector(DataType dataType, storage::MemoryManager* memoryManager) +ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager) : dataType{std::move(dataType)} { - setNumBytesPerValue(); + numBytesPerValue = getDataTypeSize(this->dataType); initializeValueBuffer(); nullMask = std::make_unique(); auxiliaryBuffer = AuxiliaryBufferFactory::getAuxiliaryBuffer(this->dataType, memoryManager); @@ -15,7 +15,7 @@ ValueVector::ValueVector(DataType dataType, storage::MemoryManager* memoryManage void ValueVector::setState(std::shared_ptr state) { this->state = state; - if (dataType.typeID == STRUCT) { + if (dataType.getLogicalTypeID() == LogicalTypeID::STRUCT) { auto childrenVectors = StructVector::getChildrenVectors(this); for (auto childVector : childrenVectors) { childVector->setState(state); @@ -56,23 +56,30 @@ void ValueVector::setValue(uint32_t pos, std::string val) { StringVector::addString(this, pos, val.data(), val.length()); } -void ValueVector::setNumBytesPerValue() { - switch (dataType.typeID) { - case STRUCT: { - numBytesPerValue = sizeof(struct_entry_t); - } break; - case VAR_LIST: { - numBytesPerValue = sizeof(list_entry_t); - } break; +uint32_t ValueVector::getDataTypeSize(const LogicalType& type) { + switch (type.getLogicalTypeID()) { + case common::LogicalTypeID::STRING: { + return sizeof(common::ku_string_t); + } + case common::LogicalTypeID::FIXED_LIST: { + return getDataTypeSize(*common::FixedListType::getChildType(&type)) * + common::FixedListType::getNumElementsInList(&type); + } + case LogicalTypeID::STRUCT: { + return sizeof(struct_entry_t); + } + case LogicalTypeID::VAR_LIST: { + return sizeof(list_entry_t); + } default: { - numBytesPerValue = Types::getDataTypeSize(dataType); + return LogicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); } } } void ValueVector::initializeValueBuffer() { valueBuffer = std::make_unique(numBytesPerValue * DEFAULT_VECTOR_CAPACITY); - if (dataType.typeID == STRUCT) { + if (dataType.getLogicalTypeID() == LogicalTypeID::STRUCT) { // For struct valueVectors, each struct_entry_t stores its current position in the // valueVector. StructVector::initializeEntries(this); diff --git a/src/common/vector/value_vector_utils.cpp b/src/common/vector/value_vector_utils.cpp index 7aadde0ab1..92bcd7c9cd 100644 --- a/src/common/vector/value_vector_utils.cpp +++ b/src/common/vector/value_vector_utils.cpp @@ -2,14 +2,15 @@ #include "common/in_mem_overflow_buffer_utils.h" #include "common/null_buffer.h" +#include "processor/result/factorized_table.h" using namespace kuzu; using namespace common; void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos( ValueVector& resultVector, uint64_t pos, const uint8_t* srcData) { - switch (resultVector.dataType.typeID) { - case STRUCT: { + switch (resultVector.dataType.getLogicalTypeID()) { + case LogicalTypeID::STRUCT: { auto structFields = StructVector::getChildrenVectors(&resultVector); auto structNullBytes = srcData; auto structValues = @@ -21,16 +22,18 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos( } else { copyNonNullDataWithSameTypeIntoPos(*structField, pos, structValues); } - structValues += Types::getDataTypeSize(structField->dataType); + structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType); } } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { auto srcKuList = *(ku_list_t*)srcData; auto srcNullBytes = reinterpret_cast(srcKuList.overflowPtr); auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size); auto dstListEntry = ListVector::addList(&resultVector, srcKuList.size); resultVector.setValue(pos, dstListEntry); auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto numBytesPerValue = + processor::FactorizedTable::getDataTypeSize(resultDataVector->dataType); for (auto i = 0u; i < srcKuList.size; i++) { auto dstListValuePos = dstListEntry.offset + i; if (NullBuffer::isNull(srcNullBytes, i)) { @@ -39,12 +42,13 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos( copyNonNullDataWithSameTypeIntoPos( *resultDataVector, dstListValuePos, srcListValues); } - srcListValues += Types::getDataTypeSize(resultDataVector->dataType); + srcListValues += numBytesPerValue; } } break; default: { copyNonNullDataWithSameType(resultVector.dataType, srcData, - resultVector.getData() + pos * Types::getDataTypeSize(resultVector.dataType), + resultVector.getData() + + pos * processor::FactorizedTable::getDataTypeSize(resultVector.dataType), *StringVector::getInMemOverflowBuffer(&resultVector)); } } @@ -52,8 +56,8 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos( void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector& srcVector, uint64_t pos, uint8_t* dstData, InMemOverflowBuffer& dstOverflowBuffer) { - switch (srcVector.dataType.typeID) { - case STRUCT: { + switch (srcVector.dataType.getLogicalTypeID()) { + case LogicalTypeID::STRUCT: { // The storage structure of STRUCT type in factorizedTable is: // [NULLBYTES, FIELD1, FIELD2, ...] auto structFields = StructVector::getChildrenVectors(&srcVector); @@ -69,18 +73,20 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector& copyNonNullDataWithSameTypeOutFromPos( *structField, pos, structValues, dstOverflowBuffer); } - structValues += Types::getDataTypeSize(structField->dataType); + structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType); } } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { auto srcListEntry = srcVector.getValue(pos); auto srcListDataVector = common::ListVector::getDataVector(&srcVector); ku_list_t dstList; dstList.size = srcListEntry.size; - InMemOverflowBufferUtils::allocateSpaceForList(dstList, - Types::getDataTypeSize(srcListDataVector->dataType) * dstList.size + - NullBuffer::getNumBytesForNullValues(dstList.size), - dstOverflowBuffer); + auto dstListOverflowSize = + processor::FactorizedTable::getDataTypeSize(srcListDataVector->dataType) * + dstList.size + + NullBuffer::getNumBytesForNullValues(dstList.size); + dstList.overflowPtr = + reinterpret_cast(dstOverflowBuffer.allocateSpace(dstListOverflowSize)); auto dstListNullBytes = reinterpret_cast(dstList.overflowPtr); NullBuffer::initNullBytes(dstListNullBytes, dstList.size); auto dstListValues = dstListNullBytes + NullBuffer::getNumBytesForNullValues(dstList.size); @@ -91,22 +97,24 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector& copyNonNullDataWithSameTypeOutFromPos( *srcListDataVector, srcListEntry.offset + i, dstListValues, dstOverflowBuffer); } - dstListValues += Types::getDataTypeSize(srcListDataVector->dataType); + dstListValues += + processor::FactorizedTable::getDataTypeSize(srcListDataVector->dataType); } memcpy(dstData, &dstList, sizeof(dstList)); } break; default: { copyNonNullDataWithSameType(srcVector.dataType, - srcVector.getData() + pos * Types::getDataTypeSize(srcVector.dataType), dstData, - dstOverflowBuffer); + srcVector.getData() + + pos * processor::FactorizedTable::getDataTypeSize(srcVector.dataType), + dstData, dstOverflowBuffer); } } } void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVector, const uint8_t* srcValue, const common::ValueVector& srcVector) { - switch (srcVector.dataType.typeID) { - case VAR_LIST: { + switch (srcVector.dataType.getLogicalTypeID()) { + case LogicalTypeID::VAR_LIST: { auto srcList = reinterpret_cast(srcValue); auto dstList = reinterpret_cast(dstValue); *dstList = ListVector::addList(&dstVector, srcList->size); @@ -125,7 +133,7 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect dstValues += numBytesPerValue; } } break; - case STRUCT: { + case LogicalTypeID::STRUCT: { auto srcFields = common::StructVector::getChildrenVectors(&srcVector); auto dstFields = common::StructVector::getChildrenVectors(&dstVector); auto srcPos = *(int64_t*)srcValue; @@ -141,7 +149,7 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect } } } break; - case STRING: { + case LogicalTypeID::STRING: { common::InMemOverflowBufferUtils::copyString(*(common::ku_string_t*)srcValue, *(common::ku_string_t*)dstValue, *StringVector::getInMemOverflowBuffer(&dstVector)); } break; @@ -151,12 +159,12 @@ void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVect } } -void ValueVectorUtils::copyNonNullDataWithSameType(const DataType& dataType, const uint8_t* srcData, - uint8_t* dstData, InMemOverflowBuffer& inMemOverflowBuffer) { - if (dataType.typeID == STRING) { +void ValueVectorUtils::copyNonNullDataWithSameType(const LogicalType& dataType, + const uint8_t* srcData, uint8_t* dstData, InMemOverflowBuffer& inMemOverflowBuffer) { + if (dataType.getLogicalTypeID() == LogicalTypeID::STRING) { InMemOverflowBufferUtils::copyString( *(ku_string_t*)srcData, *(ku_string_t*)dstData, inMemOverflowBuffer); } else { - memcpy(dstData, srcData, Types::getDataTypeSize(dataType)); + memcpy(dstData, srcData, processor::FactorizedTable::getDataTypeSize(dataType)); } } diff --git a/src/expression_evaluator/case_evaluator.cpp b/src/expression_evaluator/case_evaluator.cpp index 652c38d925..484c77f77f 100644 --- a/src/expression_evaluator/case_evaluator.cpp +++ b/src/expression_evaluator/case_evaluator.cpp @@ -92,7 +92,7 @@ void CaseExpressionEvaluator::fillEntry(sel_t resultPos, const ValueVector& then if (thenVector.isNull(thenPos)) { resultVector->setNull(resultPos, true); } else { - if (thenVector.dataType.typeID == common::VAR_LIST) { + if (thenVector.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST) { auto srcListEntry = thenVector.getValue(thenPos); list_entry_t resultEntry = ListVector::addList(resultVector.get(), srcListEntry.size); common::ValueVectorUtils::copyValue(reinterpret_cast(&resultEntry), @@ -106,69 +106,69 @@ void CaseExpressionEvaluator::fillEntry(sel_t resultPos, const ValueVector& then } void CaseExpressionEvaluator::fillAllSwitch(const ValueVector& thenVector) { - auto typeID = resultVector->dataType.typeID; - switch (typeID) { - case BOOL: { + switch (resultVector->dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { fillAll(thenVector); } break; - case INT64: { + case LogicalTypeID::INT64: { fillAll(thenVector); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { fillAll(thenVector); } break; - case DATE: { + case LogicalTypeID::DATE: { fillAll(thenVector); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { fillAll(thenVector); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { fillAll(thenVector); } break; - case STRING: { + case LogicalTypeID::STRING: { fillAll(thenVector); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { fillAll(thenVector); } break; default: - throw NotImplementedException( - "Unimplemented type " + Types::dataTypeToString(typeID) + " for case expression."); + throw NotImplementedException("Unimplemented type " + + LogicalTypeUtils::dataTypeToString(resultVector->dataType) + + " for case expression."); } } void CaseExpressionEvaluator::fillSelectedSwitch( const SelectionVector& selVector, const ValueVector& thenVector) { - auto typeID = resultVector->dataType.typeID; - switch (typeID) { - case BOOL: { + switch (resultVector->dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { fillSelected(selVector, thenVector); } break; - case INT64: { + case LogicalTypeID::INT64: { fillSelected(selVector, thenVector); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { fillSelected(selVector, thenVector); } break; - case DATE: { + case LogicalTypeID::DATE: { fillSelected(selVector, thenVector); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { fillSelected(selVector, thenVector); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { fillSelected(selVector, thenVector); } break; - case STRING: { + case LogicalTypeID::STRING: { fillSelected(selVector, thenVector); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { fillSelected(selVector, thenVector); } break; default: - throw NotImplementedException( - "Unimplemented type " + Types::dataTypeToString(typeID) + " for case expression."); + throw NotImplementedException("Unimplemented type " + + LogicalTypeUtils::dataTypeToString(resultVector->dataType) + + " for case expression."); } } diff --git a/src/expression_evaluator/function_evaluator.cpp b/src/expression_evaluator/function_evaluator.cpp index 1d058bdea5..93f8412710 100644 --- a/src/expression_evaluator/function_evaluator.cpp +++ b/src/expression_evaluator/function_evaluator.cpp @@ -13,7 +13,7 @@ namespace evaluator { void FunctionExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager* memoryManager) { BaseExpressionEvaluator::init(resultSet, memoryManager); execFunc = ((binder::ScalarFunctionExpression&)*expression).execFunc; - if (expression->dataType.typeID == BOOL) { + if (expression->dataType.getLogicalTypeID() == LogicalTypeID::BOOL) { selectFunc = ((binder::ScalarFunctionExpression&)*expression).selectFunc; } for (auto& child : children) { @@ -35,7 +35,7 @@ bool FunctionExpressionEvaluator::select(SelectionVector& selVector) { // Temporary code path for function whose return type is BOOL but select interface is not // implemented (e.g. list_contains). We should remove this if statement eventually. if (selectFunc == nullptr) { - assert(resultVector->dataType.typeID == BOOL); + assert(resultVector->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); execFunc(parameters, *resultVector); auto numSelectedValues = 0u; for (auto i = 0u; i < resultVector->state->selVector->selectedSize; ++i) { diff --git a/src/expression_evaluator/literal_evaluator.cpp b/src/expression_evaluator/literal_evaluator.cpp index f6b9629abb..eb0cc20a28 100644 --- a/src/expression_evaluator/literal_evaluator.cpp +++ b/src/expression_evaluator/literal_evaluator.cpp @@ -10,7 +10,7 @@ namespace kuzu { namespace evaluator { bool LiteralExpressionEvaluator::select(SelectionVector& selVector) { - assert(resultVector->dataType.typeID == BOOL); + assert(resultVector->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); auto pos = resultVector->state->selVector->selectedPositions[0]; assert(pos == 0u); return resultVector->getValue(pos) && (!resultVector->isNull(pos)); @@ -30,40 +30,40 @@ void LiteralExpressionEvaluator::resolveResultVector( void LiteralExpressionEvaluator::copyValueToVector( uint8_t* dstValue, common::ValueVector* dstVector, const common::Value* srcValue) { auto numBytesPerValue = dstVector->getNumBytesPerValue(); - switch (srcValue->getDataType().typeID) { - case common::INT64: { + switch (srcValue->getDataType().getLogicalTypeID()) { + case common::LogicalTypeID::INT64: { memcpy(dstValue, &srcValue->val.int64Val, numBytesPerValue); } break; - case common::INT32: { + case common::LogicalTypeID::INT32: { memcpy(dstValue, &srcValue->val.int32Val, numBytesPerValue); } break; - case common::INT16: { + case common::LogicalTypeID::INT16: { memcpy(dstValue, &srcValue->val.int16Val, numBytesPerValue); } break; - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { memcpy(dstValue, &srcValue->val.doubleVal, numBytesPerValue); } break; - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { memcpy(dstValue, &srcValue->val.floatVal, numBytesPerValue); } break; - case common::BOOL: { + case common::LogicalTypeID::BOOL: { memcpy(dstValue, &srcValue->val.booleanVal, numBytesPerValue); } break; - case common::DATE: { + case common::LogicalTypeID::DATE: { memcpy(dstValue, &srcValue->val.dateVal, numBytesPerValue); } break; - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { memcpy(dstValue, &srcValue->val.timestampVal, numBytesPerValue); } break; - case common::INTERVAL: { + case common::LogicalTypeID::INTERVAL: { memcpy(dstValue, &srcValue->val.intervalVal, numBytesPerValue); } break; - case common::STRING: { + case common::LogicalTypeID::STRING: { common::InMemOverflowBufferUtils::copyString(srcValue->strVal.data(), srcValue->strVal.length(), *(common::ku_string_t*)dstValue, *common::StringVector::getInMemOverflowBuffer(dstVector)); } break; - case common::VAR_LIST: { + case common::LogicalTypeID::VAR_LIST: { auto listListEntry = reinterpret_cast(dstValue); auto numValues = srcValue->nestedTypeVal.size(); *listListEntry = common::ListVector::addList(dstVector, numValues); @@ -75,8 +75,9 @@ void LiteralExpressionEvaluator::copyValueToVector( } } break; default: - throw common::NotImplementedException("Unimplemented setLiteral() for type " + - common::Types::dataTypeToString(dstVector->dataType)); + throw common::NotImplementedException( + "Unimplemented setLiteral() for type " + + common::LogicalTypeUtils::dataTypeToString(dstVector->dataType)); } } diff --git a/src/expression_evaluator/reference_evaluator.cpp b/src/expression_evaluator/reference_evaluator.cpp index fe22747526..242b3a16a6 100644 --- a/src/expression_evaluator/reference_evaluator.cpp +++ b/src/expression_evaluator/reference_evaluator.cpp @@ -6,7 +6,7 @@ namespace kuzu { namespace evaluator { inline static bool isTrue(ValueVector& vector, uint64_t pos) { - assert(vector.dataType.typeID == BOOL); + assert(vector.dataType.getLogicalTypeID() == LogicalTypeID::BOOL); return !vector.isNull(pos) && vector.getValue(pos); } diff --git a/src/function/aggregate_function.cpp b/src/function/aggregate_function.cpp index 4ed9700eaa..1289671c86 100644 --- a/src/function/aggregate_function.cpp +++ b/src/function/aggregate_function.cpp @@ -16,86 +16,88 @@ namespace function { std::unique_ptr AggregateFunctionUtil::getCountStarFunction() { return std::make_unique(CountStarFunction::initialize, CountStarFunction::updateAll, CountStarFunction::updatePos, CountStarFunction::combine, - CountStarFunction::finalize, DataType(ANY) /* dummy input data type */); + CountStarFunction::finalize, LogicalType() /* dummy input data type */); } std::unique_ptr AggregateFunctionUtil::getCountFunction( - const DataType& inputType, bool isDistinct) { + const LogicalType& inputType, bool isDistinct) { return std::make_unique(CountFunction::initialize, CountFunction::updateAll, CountFunction::updatePos, CountFunction::combine, CountFunction::finalize, inputType, isDistinct); } std::unique_ptr AggregateFunctionUtil::getAvgFunction( - const DataType& inputType, bool isDistinct) { - switch (inputType.typeID) { - case INT64: + const LogicalType& inputType, bool isDistinct) { + switch (inputType.getLogicalTypeID()) { + case LogicalTypeID::INT64: return std::make_unique(AvgFunction::initialize, AvgFunction::updateAll, AvgFunction::updatePos, AvgFunction::combine, AvgFunction::finalize, inputType, isDistinct); - case INT32: + case LogicalTypeID::INT32: return std::make_unique(AvgFunction::initialize, AvgFunction::updateAll, AvgFunction::updatePos, AvgFunction::combine, AvgFunction::finalize, inputType, isDistinct); - case INT16: + case LogicalTypeID::INT16: return std::make_unique(AvgFunction::initialize, AvgFunction::updateAll, AvgFunction::updatePos, AvgFunction::combine, AvgFunction::finalize, inputType, isDistinct); - case DOUBLE: + case LogicalTypeID::DOUBLE: return std::make_unique(AvgFunction::initialize, AvgFunction::updateAll, AvgFunction::updatePos, AvgFunction::combine, AvgFunction::finalize, inputType, isDistinct); - case FLOAT: + case LogicalTypeID::FLOAT: return std::make_unique(AvgFunction::initialize, AvgFunction::updateAll, AvgFunction::updatePos, AvgFunction::combine, AvgFunction::finalize, inputType, isDistinct); default: - throw RuntimeException("Unsupported input data type " + Types::dataTypeToString(inputType) + + throw RuntimeException("Unsupported input data type " + + LogicalTypeUtils::dataTypeToString(inputType) + " for AggregateFunctionUtil::getAvgFunction."); } } std::unique_ptr AggregateFunctionUtil::getSumFunction( - const DataType& inputType, bool isDistinct) { - switch (inputType.typeID) { - case INT64: + const LogicalType& inputType, bool isDistinct) { + switch (inputType.getLogicalTypeID()) { + case LogicalTypeID::INT64: return std::make_unique(SumFunction::initialize, SumFunction::updateAll, SumFunction::updatePos, SumFunction::combine, SumFunction::finalize, inputType, isDistinct); - case INT32: + case LogicalTypeID::INT32: return std::make_unique(SumFunction::initialize, SumFunction::updateAll, SumFunction::updatePos, SumFunction::combine, SumFunction::finalize, inputType, isDistinct); - case INT16: + case LogicalTypeID::INT16: return std::make_unique(SumFunction::initialize, SumFunction::updateAll, SumFunction::updatePos, SumFunction::combine, SumFunction::finalize, inputType, isDistinct); - case DOUBLE: + case LogicalTypeID::DOUBLE: return std::make_unique(SumFunction::initialize, SumFunction::updateAll, SumFunction::updatePos, SumFunction::combine, SumFunction::finalize, inputType, isDistinct); - case FLOAT: + case LogicalTypeID::FLOAT: return std::make_unique(SumFunction::initialize, SumFunction::updateAll, SumFunction::updatePos, SumFunction::combine, SumFunction::finalize, inputType, isDistinct); default: - throw RuntimeException("Unsupported input data type " + Types::dataTypeToString(inputType) + + throw RuntimeException("Unsupported input data type " + + LogicalTypeUtils::dataTypeToString(inputType) + " for AggregateFunctionUtil::getSumFunction."); } } std::unique_ptr AggregateFunctionUtil::getMinFunction( - const DataType& inputType, bool isDistinct) { + const LogicalType& inputType, bool isDistinct) { return getMinMaxFunction(inputType, isDistinct); } std::unique_ptr AggregateFunctionUtil::getMaxFunction( - const DataType& inputType, bool isDistinct) { + const LogicalType& inputType, bool isDistinct) { return getMinMaxFunction(inputType, isDistinct); } std::unique_ptr AggregateFunctionUtil::getCollectFunction( - const common::DataType& inputType, bool isDistinct) { + const common::LogicalType& inputType, bool isDistinct) { return std::make_unique(CollectFunction::initialize, CollectFunction::updateAll, CollectFunction::updatePos, CollectFunction::combine, CollectFunction::finalize, inputType, isDistinct); @@ -103,67 +105,68 @@ std::unique_ptr AggregateFunctionUtil::getCollectFunction( template std::unique_ptr AggregateFunctionUtil::getMinMaxFunction( - const DataType& inputType, bool isDistinct) { - switch (inputType.typeID) { - case BOOL: + const LogicalType& inputType, bool isDistinct) { + switch (inputType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case INT64: + case LogicalTypeID::INT64: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case INT32: + case LogicalTypeID::INT32: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case INT16: + case LogicalTypeID::INT16: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case DOUBLE: + case LogicalTypeID::DOUBLE: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case FLOAT: + case LogicalTypeID::FLOAT: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case DATE: + case LogicalTypeID::DATE: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case TIMESTAMP: + case LogicalTypeID::TIMESTAMP: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case INTERVAL: + case LogicalTypeID::INTERVAL: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case STRING: + case LogicalTypeID::STRING: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); - case INTERNAL_ID: + case LogicalTypeID::INTERNAL_ID: return std::make_unique(MinMaxFunction::initialize, MinMaxFunction::updateAll, MinMaxFunction::updatePos, MinMaxFunction::combine, MinMaxFunction::finalize, inputType, isDistinct); default: - throw RuntimeException("Unsupported input data type " + Types::dataTypeToString(inputType) + + throw RuntimeException("Unsupported input data type " + + LogicalTypeUtils::dataTypeToString(inputType) + " for AggregateFunctionUtil::getMinMaxFunction."); } } diff --git a/src/function/built_in_aggregate_functions.cpp b/src/function/built_in_aggregate_functions.cpp index e40e3f55b6..b33090da30 100644 --- a/src/function/built_in_aggregate_functions.cpp +++ b/src/function/built_in_aggregate_functions.cpp @@ -8,7 +8,7 @@ namespace kuzu { namespace function { AggregateFunctionDefinition* BuiltInAggregateFunctions::matchFunction( - const std::string& name, const std::vector& inputTypes, bool isDistinct) { + const std::string& name, const std::vector& inputTypes, bool isDistinct) { auto& functionDefinitions = aggregateFunctions.at(name); std::vector candidateFunctions; for (auto& functionDefinition : functionDefinitions) { @@ -31,16 +31,16 @@ std::vector BuiltInAggregateFunctions::getFunctionNames() { return result; } -uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector& inputTypes, +uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector& inputTypes, bool isDistinct, AggregateFunctionDefinition* function) { if (inputTypes.size() != function->parameterTypeIDs.size() || isDistinct != function->isDistinct) { return UINT32_MAX; } for (auto i = 0u; i < inputTypes.size(); ++i) { - if (function->parameterTypeIDs[i] == ANY) { + if (function->parameterTypeIDs[i] == LogicalTypeID::ANY) { continue; - } else if (inputTypes[i].typeID != function->parameterTypeIDs[i]) { + } else if (inputTypes[i].getLogicalTypeID() != function->parameterTypeIDs[i]) { return UINT32_MAX; } } @@ -49,7 +49,7 @@ uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector& void BuiltInAggregateFunctions::validateNonEmptyCandidateFunctions( std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes, bool isDistinct) { + const std::vector& inputTypes, bool isDistinct) { if (candidateFunctions.empty()) { std::string supportedInputsString; for (auto& functionDefinition : aggregateFunctions.at(name)) { @@ -60,8 +60,8 @@ void BuiltInAggregateFunctions::validateNonEmptyCandidateFunctions( } throw BinderException("Cannot match a built-in function for given function " + name + (isDistinct ? "DISTINCT " : "") + - Types::dataTypesToString(inputTypes) + ". Supported inputs are\n" + - supportedInputsString); + LogicalTypeUtils::dataTypesToString(inputTypes) + + ". Supported inputs are\n" + supportedInputsString); } } @@ -78,18 +78,24 @@ void BuiltInAggregateFunctions::registerAggregateFunctions() { void BuiltInAggregateFunctions::registerCountStar() { std::vector> definitions; definitions.push_back(std::make_unique(COUNT_STAR_FUNC_NAME, - std::vector{}, INT64, AggregateFunctionUtil::getCountStarFunction(), false)); + std::vector{}, LogicalTypeID::INT64, + AggregateFunctionUtil::getCountStarFunction(), false)); aggregateFunctions.insert({COUNT_STAR_FUNC_NAME, std::move(definitions)}); } void BuiltInAggregateFunctions::registerCount() { std::vector> definitions; - for (auto& typeID : DataType::getAllValidTypeIDs()) { - auto inputType = - (typeID == VAR_LIST ? DataType(std::make_unique(ANY)) : DataType(typeID)); + LogicalType inputType; + for (auto& typeID : LogicalType::getAllValidLogicTypeIDs()) { + if (typeID == LogicalTypeID::VAR_LIST) { + inputType = LogicalType( + typeID, std::make_unique(std::make_unique())); + } else { + inputType = LogicalType(typeID); + } for (auto isDistinct : std::vector{true, false}) { definitions.push_back(std::make_unique(COUNT_FUNC_NAME, - std::vector{typeID}, INT64, + std::vector{typeID}, LogicalTypeID::INT64, AggregateFunctionUtil::getCountFunction(inputType, isDistinct), isDistinct)); } } @@ -98,11 +104,12 @@ void BuiltInAggregateFunctions::registerCount() { void BuiltInAggregateFunctions::registerSum() { std::vector> definitions; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { for (auto isDistinct : std::vector{true, false}) { definitions.push_back(std::make_unique(SUM_FUNC_NAME, - std::vector{typeID}, typeID, - AggregateFunctionUtil::getSumFunction(DataType(typeID), isDistinct), isDistinct)); + std::vector{typeID}, typeID, + AggregateFunctionUtil::getSumFunction(LogicalType(typeID), isDistinct), + isDistinct)); } } aggregateFunctions.insert({SUM_FUNC_NAME, std::move(definitions)}); @@ -110,11 +117,12 @@ void BuiltInAggregateFunctions::registerSum() { void BuiltInAggregateFunctions::registerAvg() { std::vector> definitions; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { for (auto isDistinct : std::vector{true, false}) { definitions.push_back(std::make_unique(AVG_FUNC_NAME, - std::vector{typeID}, DOUBLE, - AggregateFunctionUtil::getAvgFunction(DataType(typeID), isDistinct), isDistinct)); + std::vector{typeID}, LogicalTypeID::DOUBLE, + AggregateFunctionUtil::getAvgFunction(LogicalType(typeID), isDistinct), + isDistinct)); } } aggregateFunctions.insert({AVG_FUNC_NAME, std::move(definitions)}); @@ -122,11 +130,12 @@ void BuiltInAggregateFunctions::registerAvg() { void BuiltInAggregateFunctions::registerMin() { std::vector> definitions; - for (auto typeID : DataType::getAllValidComparableTypes()) { + for (auto typeID : LogicalType::getAllValidComparableLogicalTypes()) { for (auto isDistinct : std::vector{true, false}) { definitions.push_back(std::make_unique(MIN_FUNC_NAME, - std::vector{typeID}, typeID, - AggregateFunctionUtil::getMinFunction(DataType(typeID), isDistinct), isDistinct)); + std::vector{typeID}, typeID, + AggregateFunctionUtil::getMinFunction(LogicalType(typeID), isDistinct), + isDistinct)); } } aggregateFunctions.insert({MIN_FUNC_NAME, std::move(definitions)}); @@ -134,11 +143,12 @@ void BuiltInAggregateFunctions::registerMin() { void BuiltInAggregateFunctions::registerMax() { std::vector> definitions; - for (auto typeID : DataType::getAllValidComparableTypes()) { + for (auto typeID : LogicalType::getAllValidComparableLogicalTypes()) { for (auto isDistinct : std::vector{true, false}) { definitions.push_back(std::make_unique(MAX_FUNC_NAME, - std::vector{typeID}, typeID, - AggregateFunctionUtil::getMaxFunction(DataType(typeID), isDistinct), isDistinct)); + std::vector{typeID}, typeID, + AggregateFunctionUtil::getMaxFunction(LogicalType(typeID), isDistinct), + isDistinct)); } } aggregateFunctions.insert({MAX_FUNC_NAME, std::move(definitions)}); @@ -148,9 +158,9 @@ void BuiltInAggregateFunctions::registerCollect() { std::vector> definitions; for (auto isDistinct : std::vector{true, false}) { definitions.push_back(std::make_unique(COLLECT_FUNC_NAME, - std::vector{ANY}, VAR_LIST, - AggregateFunctionUtil::getCollectFunction(DataType(ANY), isDistinct), isDistinct, - CollectFunction::bindFunc)); + std::vector{LogicalTypeID::ANY}, LogicalTypeID::VAR_LIST, + AggregateFunctionUtil::getCollectFunction(LogicalType(LogicalTypeID::ANY), isDistinct), + isDistinct, CollectFunction::bindFunc)); } aggregateFunctions.insert({COLLECT_FUNC_NAME, std::move(definitions)}); } diff --git a/src/function/built_in_vector_operations.cpp b/src/function/built_in_vector_operations.cpp index ce2c317665..65b81ae018 100644 --- a/src/function/built_in_vector_operations.cpp +++ b/src/function/built_in_vector_operations.cpp @@ -35,14 +35,15 @@ bool BuiltInVectorOperations::canApplyStaticEvaluation( const std::string& functionName, const binder::expression_vector& children) { if ((functionName == CAST_TO_DATE_FUNC_NAME || functionName == CAST_TO_TIMESTAMP_FUNC_NAME || functionName == CAST_TO_INTERVAL_FUNC_NAME) && - children[0]->expressionType == LITERAL && children[0]->dataType.typeID == STRING) { + children[0]->expressionType == LITERAL && + children[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) { return true; // bind as literal } return false; } VectorOperationDefinition* BuiltInVectorOperations::matchFunction( - const std::string& name, const std::vector& inputTypes) { + const std::string& name, const std::vector& inputTypes) { auto& functionDefinitions = vectorOperations.at(name); bool isOverload = functionDefinitions.size() > 1; std::vector candidateFunctions; @@ -75,27 +76,28 @@ std::vector BuiltInVectorOperations::getFunctionNames() { return result; } -uint32_t BuiltInVectorOperations::getCastCost(DataTypeID inputTypeID, DataTypeID targetTypeID) { +uint32_t BuiltInVectorOperations::getCastCost( + LogicalTypeID inputTypeID, LogicalTypeID targetTypeID) { if (inputTypeID == targetTypeID) { return 0; } else { - if (targetTypeID == ANY) { + if (targetTypeID == LogicalTypeID::ANY) { // Any inputTypeID can match to type ANY return 0; } switch (inputTypeID) { - case common::ANY: + case common::LogicalTypeID::ANY: // ANY type can be any type return 0; - case common::INT64: + case common::LogicalTypeID::INT64: return castInt64(targetTypeID); - case common::INT32: + case common::LogicalTypeID::INT32: return castInt32(targetTypeID); - case common::INT16: + case common::LogicalTypeID::INT16: return castInt16(targetTypeID); - case common::DOUBLE: + case common::LogicalTypeID::DOUBLE: return castDouble(targetTypeID); - case common::FLOAT: + case common::LogicalTypeID::FLOAT: return castFloat(targetTypeID); default: return UINT32_MAX; @@ -104,32 +106,32 @@ uint32_t BuiltInVectorOperations::getCastCost(DataTypeID inputTypeID, DataTypeID } uint32_t BuiltInVectorOperations::getCastCost( - const DataType& inputType, const DataType& targetType) { + const LogicalType& inputType, const LogicalType& targetType) { if (inputType == targetType) { return 0; } else { - switch (inputType.typeID) { - case common::FIXED_LIST: - case common::VAR_LIST: + switch (inputType.getLogicalTypeID()) { + case common::LogicalTypeID::FIXED_LIST: + case common::LogicalTypeID::VAR_LIST: return UINT32_MAX; default: - return getCastCost(inputType.typeID, targetType.typeID); + return getCastCost(inputType.getLogicalTypeID(), targetType.getLogicalTypeID()); } } } -uint32_t BuiltInVectorOperations::getTargetTypeCost(common::DataTypeID typeID) { +uint32_t BuiltInVectorOperations::getTargetTypeCost(common::LogicalTypeID typeID) { switch (typeID) { - case common::INT32: { + case common::LogicalTypeID::INT32: { return 103; } - case common::INT64: { + case common::LogicalTypeID::INT64: { return 101; } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return 110; } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return 102; } default: { @@ -138,49 +140,49 @@ uint32_t BuiltInVectorOperations::getTargetTypeCost(common::DataTypeID typeID) { } } -uint32_t BuiltInVectorOperations::castInt64(common::DataTypeID targetTypeID) { +uint32_t BuiltInVectorOperations::castInt64(common::LogicalTypeID targetTypeID) { switch (targetTypeID) { - case common::FLOAT: - case common::DOUBLE: + case common::LogicalTypeID::FLOAT: + case common::LogicalTypeID::DOUBLE: return getTargetTypeCost(targetTypeID); default: return UINT32_MAX; } } -uint32_t BuiltInVectorOperations::castInt32(common::DataTypeID targetTypeID) { +uint32_t BuiltInVectorOperations::castInt32(common::LogicalTypeID targetTypeID) { switch (targetTypeID) { - case common::INT64: - case common::FLOAT: - case common::DOUBLE: + case common::LogicalTypeID::INT64: + case common::LogicalTypeID::FLOAT: + case common::LogicalTypeID::DOUBLE: return getTargetTypeCost(targetTypeID); default: return UINT32_MAX; } } -uint32_t BuiltInVectorOperations::castInt16(common::DataTypeID targetTypeID) { +uint32_t BuiltInVectorOperations::castInt16(common::LogicalTypeID targetTypeID) { switch (targetTypeID) { - case common::INT32: - case common::INT64: - case common::FLOAT: - case common::DOUBLE: + case common::LogicalTypeID::INT32: + case common::LogicalTypeID::INT64: + case common::LogicalTypeID::FLOAT: + case common::LogicalTypeID::DOUBLE: return getTargetTypeCost(targetTypeID); default: return UINT32_MAX; } } -uint32_t BuiltInVectorOperations::castDouble(common::DataTypeID targetTypeID) { +uint32_t BuiltInVectorOperations::castDouble(common::LogicalTypeID targetTypeID) { switch (targetTypeID) { default: return UINT32_MAX; } } -uint32_t BuiltInVectorOperations::castFloat(common::DataTypeID targetTypeID) { +uint32_t BuiltInVectorOperations::castFloat(common::LogicalTypeID targetTypeID) { switch (targetTypeID) { - case common::DOUBLE: + case common::LogicalTypeID::DOUBLE: return getTargetTypeCost(targetTypeID); default: return UINT32_MAX; @@ -195,7 +197,7 @@ VectorOperationDefinition* BuiltInVectorOperations::getBestMatch( VectorOperationDefinition* result = nullptr; auto cost = UINT32_MAX; for (auto& function : functions) { - std::unordered_set distinctParameterTypes; + std::unordered_set distinctParameterTypes; for (auto& parameterTypeID : function->parameterTypeIDs) { if (!distinctParameterTypes.contains(parameterTypeID)) { distinctParameterTypes.insert(parameterTypeID); @@ -210,8 +212,8 @@ VectorOperationDefinition* BuiltInVectorOperations::getBestMatch( return result; } -uint32_t BuiltInVectorOperations::getFunctionCost( - const std::vector& inputTypes, VectorOperationDefinition* function, bool isOverload) { +uint32_t BuiltInVectorOperations::getFunctionCost(const std::vector& inputTypes, + VectorOperationDefinition* function, bool isOverload) { if (function->isVarLength) { assert(function->parameterTypeIDs.size() == 1); return matchVarLengthParameters(inputTypes, function->parameterTypeIDs[0], isOverload); @@ -220,14 +222,14 @@ uint32_t BuiltInVectorOperations::getFunctionCost( } } -uint32_t BuiltInVectorOperations::matchParameters(const std::vector& inputTypes, - const std::vector& targetTypeIDs, bool isOverload) { +uint32_t BuiltInVectorOperations::matchParameters(const std::vector& inputTypes, + const std::vector& targetTypeIDs, bool isOverload) { if (inputTypes.size() != targetTypeIDs.size()) { return UINT32_MAX; } auto cost = 0u; for (auto i = 0u; i < inputTypes.size(); ++i) { - auto castCost = getCastCost(inputTypes[i].typeID, targetTypeIDs[i]); + auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]); if (castCost == UINT32_MAX) { return UINT32_MAX; } @@ -237,10 +239,10 @@ uint32_t BuiltInVectorOperations::matchParameters(const std::vector& i } uint32_t BuiltInVectorOperations::matchVarLengthParameters( - const std::vector& inputTypes, DataTypeID targetTypeID, bool isOverload) { + const std::vector& inputTypes, LogicalTypeID targetTypeID, bool isOverload) { auto cost = 0u; for (auto& inputType : inputTypes) { - auto castCost = getCastCost(inputType.typeID, targetTypeID); + auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID); if (castCost == UINT32_MAX) { return UINT32_MAX; } @@ -251,15 +253,15 @@ uint32_t BuiltInVectorOperations::matchVarLengthParameters( void BuiltInVectorOperations::validateNonEmptyCandidateFunctions( std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes) { + const std::vector& inputTypes) { if (candidateFunctions.empty()) { std::string supportedInputsString; for (auto& functionDefinition : vectorOperations.at(name)) { supportedInputsString += functionDefinition->signatureToString() + "\n"; } throw BinderException("Cannot match a built-in function for given function " + name + - Types::dataTypesToString(inputTypes) + ". Supported inputs are\n" + - supportedInputsString); + LogicalTypeUtils::dataTypesToString(inputTypes) + + ". Supported inputs are\n" + supportedInputsString); } } @@ -454,10 +456,10 @@ void BuiltInVectorOperations::registerListOperations() { void BuiltInVectorOperations::registerInternalIDOperation() { std::vector> definitions; - definitions.push_back(make_unique( - ID_FUNC_NAME, std::vector{NODE}, INTERNAL_ID, nullptr)); - definitions.push_back(make_unique( - ID_FUNC_NAME, std::vector{REL}, INTERNAL_ID, nullptr)); + definitions.push_back(make_unique(ID_FUNC_NAME, + std::vector{LogicalTypeID::NODE}, LogicalTypeID::INTERNAL_ID, nullptr)); + definitions.push_back(make_unique(ID_FUNC_NAME, + std::vector{LogicalTypeID::REL}, LogicalTypeID::INTERNAL_ID, nullptr)); vectorOperations.insert({ID_FUNC_NAME, std::move(definitions)}); } diff --git a/src/function/vector_arithmetic_operations.cpp b/src/function/vector_arithmetic_operations.cpp index a45fc16ab9..51cb66c4d2 100644 --- a/src/function/vector_arithmetic_operations.cpp +++ b/src/function/vector_arithmetic_operations.cpp @@ -7,71 +7,74 @@ namespace function { std::vector> AddVectorOperation::getDefinitions() { std::vector> result; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getBinaryDefinition(ADD_FUNC_NAME, typeID)); } // interval + interval → interval result.push_back(getBinaryDefinition( - ADD_FUNC_NAME, INTERVAL, INTERVAL)); + ADD_FUNC_NAME, LogicalTypeID::INTERVAL, LogicalTypeID::INTERVAL)); // date + int → date - result.push_back( - make_unique(ADD_FUNC_NAME, std::vector{DATE, INT64}, - DATE, BinaryExecFunction)); + result.push_back(make_unique(ADD_FUNC_NAME, + std::vector{LogicalTypeID::DATE, LogicalTypeID::INT64}, LogicalTypeID::DATE, + BinaryExecFunction)); // int + date → date - result.push_back( - make_unique(ADD_FUNC_NAME, std::vector{INT64, DATE}, - DATE, BinaryExecFunction)); + result.push_back(make_unique(ADD_FUNC_NAME, + std::vector{LogicalTypeID::INT64, LogicalTypeID::DATE}, LogicalTypeID::DATE, + BinaryExecFunction)); // date + interval → date result.push_back(make_unique(ADD_FUNC_NAME, - std::vector{DATE, INTERVAL}, DATE, - BinaryExecFunction)); + std::vector{LogicalTypeID::DATE, LogicalTypeID::INTERVAL}, + LogicalTypeID::DATE, BinaryExecFunction)); // interval + date → date result.push_back(make_unique(ADD_FUNC_NAME, - std::vector{INTERVAL, DATE}, DATE, - BinaryExecFunction)); + std::vector{LogicalTypeID::INTERVAL, LogicalTypeID::DATE}, + LogicalTypeID::DATE, BinaryExecFunction)); // timestamp + interval → timestamp result.push_back(make_unique(ADD_FUNC_NAME, - std::vector{TIMESTAMP, INTERVAL}, TIMESTAMP, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL}, + LogicalTypeID::TIMESTAMP, BinaryExecFunction)); // interval + timestamp → timestamp result.push_back(make_unique(ADD_FUNC_NAME, - std::vector{INTERVAL, TIMESTAMP}, TIMESTAMP, + std::vector{LogicalTypeID::INTERVAL, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, BinaryExecFunction)); return result; } std::vector> SubtractVectorOperation::getDefinitions() { std::vector> result; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getBinaryDefinition(SUBTRACT_FUNC_NAME, typeID)); } // date - date → int64 - result.push_back( - getBinaryDefinition(SUBTRACT_FUNC_NAME, DATE, INT64)); + result.push_back(getBinaryDefinition( + SUBTRACT_FUNC_NAME, LogicalTypeID::DATE, LogicalTypeID::INT64)); // date - integer → date result.push_back(make_unique(SUBTRACT_FUNC_NAME, - std::vector{DATE, INT64}, DATE, + std::vector{LogicalTypeID::DATE, LogicalTypeID::INT64}, LogicalTypeID::DATE, BinaryExecFunction)); // date - interval → date result.push_back(make_unique(SUBTRACT_FUNC_NAME, - std::vector{DATE, INTERVAL}, DATE, - BinaryExecFunction)); + std::vector{LogicalTypeID::DATE, LogicalTypeID::INTERVAL}, + LogicalTypeID::DATE, BinaryExecFunction)); // timestamp - timestamp → interval result.push_back(getBinaryDefinition( - SUBTRACT_FUNC_NAME, TIMESTAMP, INTERVAL)); + SUBTRACT_FUNC_NAME, LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL)); // timestamp - interval → timestamp result.push_back(make_unique(SUBTRACT_FUNC_NAME, - std::vector{TIMESTAMP, INTERVAL}, TIMESTAMP, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL}, + LogicalTypeID::TIMESTAMP, BinaryExecFunction)); // interval - interval → interval result.push_back(getBinaryDefinition( - SUBTRACT_FUNC_NAME, INTERVAL, INTERVAL)); + SUBTRACT_FUNC_NAME, LogicalTypeID::INTERVAL, LogicalTypeID::INTERVAL)); return result; } std::vector> MultiplyVectorOperation::getDefinitions() { std::vector> result; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getBinaryDefinition(MULTIPLY_FUNC_NAME, typeID)); } return result; @@ -79,19 +82,20 @@ std::vector> MultiplyVectorOperation: std::vector> DivideVectorOperation::getDefinitions() { std::vector> result; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getBinaryDefinition(DIVIDE_FUNC_NAME, typeID)); } // interval / int → interval result.push_back(make_unique(DIVIDE_FUNC_NAME, - std::vector{INTERVAL, INT64}, INTERVAL, + std::vector{LogicalTypeID::INTERVAL, LogicalTypeID::INT64}, + LogicalTypeID::INTERVAL, BinaryExecFunction)); return result; } std::vector> ModuloVectorOperation::getDefinitions() { std::vector> result; - for (auto typeID : DataType::getNumericalTypeIDs()) { + for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getBinaryDefinition(MODULO_FUNC_NAME, typeID)); } return result; @@ -100,14 +104,14 @@ std::vector> ModuloVectorOperation::g std::vector> PowerVectorOperation::getDefinitions() { std::vector> result; // double_t ^ double_t -> double_t - result.push_back( - getBinaryDefinition(POWER_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getBinaryDefinition( + POWER_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> NegateVectorOperation::getDefinitions() { std::vector> result; - for (auto& typeID : DataType::getNumericalTypeIDs()) { + for (auto& typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getUnaryDefinition(NEGATE_FUNC_NAME, typeID)); } return result; @@ -115,7 +119,7 @@ std::vector> NegateVectorOperation::g std::vector> AbsVectorOperation::getDefinitions() { std::vector> result; - for (auto& typeID : DataType::getNumericalTypeIDs()) { + for (auto& typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getUnaryDefinition(ABS_FUNC_NAME, typeID)); } return result; @@ -123,7 +127,7 @@ std::vector> AbsVectorOperation::getD std::vector> FloorVectorOperation::getDefinitions() { std::vector> result; - for (auto& typeID : DataType::getNumericalTypeIDs()) { + for (auto& typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getUnaryDefinition(FLOOR_FUNC_NAME, typeID)); } return result; @@ -131,7 +135,7 @@ std::vector> FloorVectorOperation::ge std::vector> CeilVectorOperation::getDefinitions() { std::vector> result; - for (auto& typeID : DataType::getNumericalTypeIDs()) { + for (auto& typeID : LogicalType::getNumericalLogicalTypeIDs()) { result.push_back(getUnaryDefinition(CEIL_FUNC_NAME, typeID)); } return result; @@ -139,161 +143,177 @@ std::vector> CeilVectorOperation::get std::vector> SinVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(SIN_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + SIN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> CosVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(COS_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + COS_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> TanVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(TAN_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + TAN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> CotVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(COT_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + COT_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> AsinVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(ASIN_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + ASIN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> AcosVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(ACOS_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + ACOS_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> AtanVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(ATAN_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + ATAN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> FactorialVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - make_unique(FACTORIAL_FUNC_NAME, std::vector{INT64}, - INT64, UnaryExecFunction)); + result.push_back(make_unique(FACTORIAL_FUNC_NAME, + std::vector{LogicalTypeID::INT64}, LogicalTypeID::INT64, + UnaryExecFunction)); return result; } std::vector> SqrtVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(SQRT_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + SQRT_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> CbrtVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(CBRT_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + CBRT_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> GammaVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getUnaryDefinition(GAMMA_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + GAMMA_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> LgammaVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getUnaryDefinition(LGAMMA_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + LGAMMA_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> LnVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(LN_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + LN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> LogVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(LOG_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + LOG_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> Log2VectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(LOG2_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + LOG2_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> DegreesVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getUnaryDefinition(DEGREES_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + DEGREES_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> RadiansVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getUnaryDefinition(RADIANS_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + RADIANS_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> EvenVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(EVEN_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getUnaryDefinition( + EVEN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> SignVectorOperation::getDefinitions() { std::vector> result; - result.push_back(getUnaryDefinition(SIGN_FUNC_NAME, INT64, INT64)); - result.push_back(getUnaryDefinition(SIGN_FUNC_NAME, DOUBLE, INT64)); - result.push_back(getUnaryDefinition(SIGN_FUNC_NAME, FLOAT, INT64)); + result.push_back(getUnaryDefinition( + SIGN_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT64)); + result.push_back(getUnaryDefinition( + SIGN_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::INT64)); + result.push_back(getUnaryDefinition( + SIGN_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::INT64)); return result; } std::vector> Atan2VectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getBinaryDefinition(ATAN2_FUNC_NAME, DOUBLE, DOUBLE)); + result.push_back(getBinaryDefinition( + ATAN2_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::DOUBLE)); return result; } std::vector> RoundVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(ROUND_FUNC_NAME, - std::vector{DOUBLE, INT64}, DOUBLE, - BinaryExecFunction)); + std::vector{LogicalTypeID::DOUBLE, LogicalTypeID::INT64}, + LogicalTypeID::DOUBLE, BinaryExecFunction)); return result; } std::vector> BitwiseXorVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getBinaryDefinition(BITWISE_XOR_FUNC_NAME, INT64, INT64)); + result.push_back(getBinaryDefinition( + BITWISE_XOR_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT64)); return result; } std::vector> BitwiseAndVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getBinaryDefinition(BITWISE_AND_FUNC_NAME, INT64, INT64)); + result.push_back(getBinaryDefinition( + BITWISE_AND_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT64)); return result; } std::vector> BitwiseOrVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - getBinaryDefinition(BITWISE_OR_FUNC_NAME, INT64, INT64)); + result.push_back(getBinaryDefinition( + BITWISE_OR_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT64)); return result; } @@ -301,7 +321,7 @@ std::vector> BitShiftLeftVectorOperation::getDefinitions() { std::vector> result; result.push_back(getBinaryDefinition( - BITSHIFT_LEFT_FUNC_NAME, INT64, INT64)); + BITSHIFT_LEFT_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT64)); return result; } @@ -309,14 +329,15 @@ std::vector> BitShiftRightVectorOperation::getDefinitions() { std::vector> result; result.push_back(getBinaryDefinition( - BITSHIFT_RIGHT_FUNC_NAME, INT64, INT64)); + BITSHIFT_RIGHT_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT64)); return result; } std::vector> PiVectorOperation::getDefinitions() { std::vector> result; - result.push_back(make_unique(PI_FUNC_NAME, std::vector{}, - DOUBLE, ConstExecFunction)); + result.push_back( + make_unique(PI_FUNC_NAME, std::vector{}, + LogicalTypeID::DOUBLE, ConstExecFunction)); return result; } diff --git a/src/function/vector_boolean_operations.cpp b/src/function/vector_boolean_operations.cpp index 36d58a692f..aafa5b7ef0 100644 --- a/src/function/vector_boolean_operations.cpp +++ b/src/function/vector_boolean_operations.cpp @@ -32,7 +32,8 @@ scalar_exec_func VectorBooleanOperations::bindBinaryExecFunction( assert(children.size() == 2); auto leftType = children[0]->dataType; auto rightType = children[1]->dataType; - assert(leftType.typeID == BOOL && rightType.typeID == BOOL); + assert(leftType.getLogicalTypeID() == LogicalTypeID::BOOL && + rightType.getLogicalTypeID() == LogicalTypeID::BOOL); switch (expressionType) { case AND: { return BinaryBooleanExecFunction; @@ -54,7 +55,8 @@ scalar_select_func VectorBooleanOperations::bindBinarySelectFunction( assert(children.size() == 2); auto leftType = children[0]->dataType; auto rightType = children[1]->dataType; - assert(leftType.typeID == BOOL && rightType.typeID == BOOL); + assert(leftType.getLogicalTypeID() == LogicalTypeID::BOOL && + rightType.getLogicalTypeID() == LogicalTypeID::BOOL); switch (expressionType) { case AND: { return BinaryBooleanSelectFunction; @@ -73,7 +75,7 @@ scalar_select_func VectorBooleanOperations::bindBinarySelectFunction( scalar_exec_func VectorBooleanOperations::bindUnaryExecFunction( ExpressionType expressionType, const binder::expression_vector& children) { - assert(children.size() == 1 && children[0]->dataType.typeID == BOOL); + assert(children.size() == 1 && children[0]->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); switch (expressionType) { case NOT: { return UnaryBooleanExecFunction; @@ -86,7 +88,7 @@ scalar_exec_func VectorBooleanOperations::bindUnaryExecFunction( scalar_select_func VectorBooleanOperations::bindUnarySelectFunction( ExpressionType expressionType, const binder::expression_vector& children) { - assert(children.size() == 1 && children[0]->dataType.typeID == BOOL); + assert(children.size() == 1 && children[0]->dataType.getLogicalTypeID() == LogicalTypeID::BOOL); switch (expressionType) { case NOT: { return UnaryBooleanSelectFunction; diff --git a/src/function/vector_cast_operations.cpp b/src/function/vector_cast_operations.cpp index 1cc29d5813..4555b715ba 100644 --- a/src/function/vector_cast_operations.cpp +++ b/src/function/vector_cast_operations.cpp @@ -9,17 +9,17 @@ namespace kuzu { namespace function { bool VectorCastOperations::hasImplicitCast( - const common::DataType& srcType, const common::DataType& dstType) { + const common::LogicalType& srcType, const common::LogicalType& dstType) { // We allow cast between any numerical types - if (Types::isNumerical(srcType) && Types::isNumerical(dstType)) { + if (LogicalTypeUtils::isNumerical(srcType) && LogicalTypeUtils::isNumerical(dstType)) { return true; } - switch (srcType.typeID) { - case common::STRING: { - switch (dstType.typeID) { - case common::DATE: - case common::TIMESTAMP: - case common::INTERVAL: + switch (srcType.getLogicalTypeID()) { + case common::LogicalTypeID::STRING: { + switch (dstType.getLogicalTypeID()) { + case common::LogicalTypeID::DATE: + case common::LogicalTypeID::TIMESTAMP: + case common::LogicalTypeID::INTERVAL: return true; default: return false; @@ -30,25 +30,25 @@ bool VectorCastOperations::hasImplicitCast( } } -std::string VectorCastOperations::bindImplicitCastFuncName(const common::DataType& dstType) { - switch (dstType.typeID) { - case common::INT16: +std::string VectorCastOperations::bindImplicitCastFuncName(const common::LogicalType& dstType) { + switch (dstType.getLogicalTypeID()) { + case common::LogicalTypeID::INT16: return CAST_TO_INT16_FUNC_NAME; - case common::INT32: + case common::LogicalTypeID::INT32: return CAST_TO_INT32_FUNC_NAME; - case common::INT64: + case common::LogicalTypeID::INT64: return CAST_TO_INT64_FUNC_NAME; - case common::FLOAT: + case common::LogicalTypeID::FLOAT: return CAST_TO_FLOAT_FUNC_NAME; - case common::DOUBLE: + case common::LogicalTypeID::DOUBLE: return CAST_TO_DOUBLE_FUNC_NAME; - case common::DATE: + case common::LogicalTypeID::DATE: return CAST_TO_DATE_FUNC_NAME; - case common::TIMESTAMP: + case common::LogicalTypeID::TIMESTAMP: return CAST_TO_TIMESTAMP_FUNC_NAME; - case common::INTERVAL: + case common::LogicalTypeID::INTERVAL: return CAST_TO_INTERVAL_FUNC_NAME; - case common::STRING: + case common::LogicalTypeID::STRING: return CAST_TO_STRING_FUNC_NAME; default: throw common::NotImplementedException("bindImplicitCastFuncName()"); @@ -56,43 +56,43 @@ std::string VectorCastOperations::bindImplicitCastFuncName(const common::DataTyp } scalar_exec_func VectorCastOperations::bindImplicitCastFunc( - common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID) { + common::LogicalTypeID sourceTypeID, common::LogicalTypeID targetTypeID) { switch (targetTypeID) { - case common::INT16: { + case common::LogicalTypeID::INT16: { return bindImplicitNumericalCastFunc(sourceTypeID); } - case common::INT32: { + case common::LogicalTypeID::INT32: { return bindImplicitNumericalCastFunc(sourceTypeID); } - case common::INT64: { + case common::LogicalTypeID::INT64: { return bindImplicitNumericalCastFunc(sourceTypeID); } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return bindImplicitNumericalCastFunc(sourceTypeID); } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return bindImplicitNumericalCastFunc(sourceTypeID); } - case common::DATE: { - assert(sourceTypeID == common::STRING); + case common::LogicalTypeID::DATE: { + assert(sourceTypeID == common::LogicalTypeID::STRING); return VectorOperations::UnaryExecFunction; } - case common::TIMESTAMP: { - assert(sourceTypeID == common::STRING); + case common::LogicalTypeID::TIMESTAMP: { + assert(sourceTypeID == common::LogicalTypeID::STRING); return VectorOperations::UnaryExecFunction; } - case common::INTERVAL: { - assert(sourceTypeID == common::STRING); + case common::LogicalTypeID::INTERVAL: { + assert(sourceTypeID == common::LogicalTypeID::STRING); return VectorOperations::UnaryExecFunction; } default: - throw common::NotImplementedException("Unimplemented casting operation from " + - common::Types::dataTypeToString(sourceTypeID) + - " to " + - common::Types::dataTypeToString(targetTypeID) + "."); + throw common::NotImplementedException( + "Unimplemented casting operation from " + + common::LogicalTypeUtils::dataTypeToString(sourceTypeID) + " to " + + common::LogicalTypeUtils::dataTypeToString(targetTypeID) + "."); } } @@ -100,7 +100,7 @@ std::vector> CastToDateVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(CAST_TO_DATE_FUNC_NAME, - std::vector{STRING}, DATE, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::DATE, UnaryExecFunction)); return result; } @@ -109,7 +109,7 @@ std::vector> CastToTimestampVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(CAST_TO_TIMESTAMP_FUNC_NAME, - std::vector{STRING}, TIMESTAMP, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::TIMESTAMP, UnaryExecFunction)); return result; } @@ -118,7 +118,7 @@ std::vector> CastToIntervalVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(CAST_TO_INTERVAL_FUNC_NAME, - std::vector{STRING}, INTERVAL, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::INTERVAL, UnaryExecFunction)); return result; } @@ -127,28 +127,28 @@ std::vector> CastToStringVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{BOOL}, STRING, + std::vector{LogicalTypeID::BOOL}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{INT64}, STRING, + std::vector{LogicalTypeID::INT64}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{DOUBLE}, STRING, + std::vector{LogicalTypeID::DOUBLE}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{DATE}, STRING, + std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{TIMESTAMP}, STRING, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{INTERVAL}, STRING, + std::vector{LogicalTypeID::INTERVAL}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{STRING}, STRING, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::STRING, UnaryCastExecFunction)); result.push_back(make_unique(CAST_TO_STRING_FUNC_NAME, - std::vector{VAR_LIST}, STRING, + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::STRING, UnaryCastExecFunction)); return result; } @@ -157,13 +157,13 @@ std::vector> CastToDoubleVectorOperation::getDefinitions() { std::vector> result; result.push_back(bindVectorOperation( - CAST_TO_DOUBLE_FUNC_NAME, INT16, DOUBLE)); + CAST_TO_DOUBLE_FUNC_NAME, LogicalTypeID::INT16, LogicalTypeID::DOUBLE)); result.push_back(bindVectorOperation( - CAST_TO_DOUBLE_FUNC_NAME, INT32, DOUBLE)); + CAST_TO_DOUBLE_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::DOUBLE)); result.push_back(bindVectorOperation( - CAST_TO_DOUBLE_FUNC_NAME, INT64, DOUBLE)); + CAST_TO_DOUBLE_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::DOUBLE)); result.push_back(bindVectorOperation( - CAST_TO_DOUBLE_FUNC_NAME, FLOAT, DOUBLE)); + CAST_TO_DOUBLE_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::DOUBLE)); return result; } @@ -171,14 +171,14 @@ std::vector> CastToFloatVectorOperation::getDefinitions() { std::vector> result; result.push_back(bindVectorOperation( - CAST_TO_FLOAT_FUNC_NAME, INT16, FLOAT)); + CAST_TO_FLOAT_FUNC_NAME, LogicalTypeID::INT16, LogicalTypeID::FLOAT)); result.push_back(bindVectorOperation( - CAST_TO_FLOAT_FUNC_NAME, INT32, FLOAT)); + CAST_TO_FLOAT_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::FLOAT)); result.push_back(bindVectorOperation( - CAST_TO_FLOAT_FUNC_NAME, INT64, FLOAT)); + CAST_TO_FLOAT_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::FLOAT)); // down cast result.push_back(bindVectorOperation( - CAST_TO_FLOAT_FUNC_NAME, DOUBLE, FLOAT)); + CAST_TO_FLOAT_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT)); return result; } @@ -186,14 +186,14 @@ std::vector> CastToInt64VectorOperation::getDefinitions() { std::vector> result; result.push_back(bindVectorOperation( - CAST_TO_INT64_FUNC_NAME, INT16, INT64)); + CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT16, LogicalTypeID::INT64)); result.push_back(bindVectorOperation( - CAST_TO_INT64_FUNC_NAME, INT32, INT64)); + CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::INT64)); // down cast result.push_back(bindVectorOperation( - CAST_TO_INT64_FUNC_NAME, FLOAT, INT64)); + CAST_TO_INT64_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::INT64)); result.push_back(bindVectorOperation( - CAST_TO_INT64_FUNC_NAME, DOUBLE, INT64)); + CAST_TO_INT64_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::INT64)); return result; } @@ -201,14 +201,14 @@ std::vector> CastToInt32VectorOperation::getDefinitions() { std::vector> result; result.push_back(bindVectorOperation( - CAST_TO_INT32_FUNC_NAME, INT16, INT32)); + CAST_TO_INT32_FUNC_NAME, LogicalTypeID::INT16, LogicalTypeID::INT32)); // down cast result.push_back(bindVectorOperation( - CAST_TO_INT32_FUNC_NAME, INT64, INT32)); + CAST_TO_INT32_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT32)); result.push_back(bindVectorOperation( - CAST_TO_INT32_FUNC_NAME, FLOAT, INT32)); + CAST_TO_INT32_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::INT32)); result.push_back(bindVectorOperation( - CAST_TO_INT32_FUNC_NAME, DOUBLE, INT32)); + CAST_TO_INT32_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::INT32)); return result; } @@ -217,13 +217,13 @@ CastToInt16VectorOperation::getDefinitions() { std::vector> result; // down cast result.push_back(bindVectorOperation( - CAST_TO_INT16_FUNC_NAME, INT32, INT16)); + CAST_TO_INT16_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::INT16)); result.push_back(bindVectorOperation( - CAST_TO_INT16_FUNC_NAME, INT64, INT16)); + CAST_TO_INT16_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT16)); result.push_back(bindVectorOperation( - CAST_TO_INT16_FUNC_NAME, FLOAT, INT16)); + CAST_TO_INT16_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::INT16)); result.push_back(bindVectorOperation( - CAST_TO_INT16_FUNC_NAME, DOUBLE, INT16)); + CAST_TO_INT16_FUNC_NAME, LogicalTypeID::DOUBLE, LogicalTypeID::INT16)); return result; } diff --git a/src/function/vector_date_operations.cpp b/src/function/vector_date_operations.cpp index 43c2045ad8..e8085e37a6 100644 --- a/src/function/vector_date_operations.cpp +++ b/src/function/vector_date_operations.cpp @@ -10,13 +10,16 @@ namespace function { std::vector> DatePartVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(DATE_PART_FUNC_NAME, - std::vector{STRING, DATE}, INT64, + std::vector{LogicalTypeID::STRING, LogicalTypeID::DATE}, + LogicalTypeID::INT64, BinaryExecFunction)); result.push_back(make_unique(DATE_PART_FUNC_NAME, - std::vector{STRING, TIMESTAMP}, INT64, + std::vector{LogicalTypeID::STRING, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::INT64, BinaryExecFunction)); result.push_back(make_unique(DATE_PART_FUNC_NAME, - std::vector{STRING, INTERVAL}, INT64, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INTERVAL}, + LogicalTypeID::INT64, BinaryExecFunction)); return result; } @@ -24,21 +27,22 @@ std::vector> DatePartVectorOperation: std::vector> DateTruncVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(DATE_TRUNC_FUNC_NAME, - std::vector{STRING, DATE}, DATE, + std::vector{LogicalTypeID::STRING, LogicalTypeID::DATE}, LogicalTypeID::DATE, BinaryExecFunction)); result.push_back(make_unique(DATE_TRUNC_FUNC_NAME, - std::vector{STRING, TIMESTAMP}, TIMESTAMP, + std::vector{LogicalTypeID::STRING, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, BinaryExecFunction)); return result; } std::vector> DayNameVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - make_unique(DAYNAME_FUNC_NAME, std::vector{DATE}, - STRING, UnaryExecFunction)); result.push_back(make_unique(DAYNAME_FUNC_NAME, - std::vector{TIMESTAMP}, STRING, + std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, + UnaryExecFunction)); + result.push_back(make_unique(DAYNAME_FUNC_NAME, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, UnaryExecFunction)); return result; } @@ -46,32 +50,34 @@ std::vector> DayNameVectorOperation:: std::vector> GreatestVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(GREATEST_FUNC_NAME, - std::vector{DATE, DATE}, DATE, + std::vector{LogicalTypeID::DATE, LogicalTypeID::DATE}, LogicalTypeID::DATE, BinaryExecFunction)); result.push_back(make_unique(GREATEST_FUNC_NAME, - std::vector{TIMESTAMP, TIMESTAMP}, TIMESTAMP, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, BinaryExecFunction)); return result; } std::vector> LastDayVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - make_unique(LAST_DAY_FUNC_NAME, std::vector{DATE}, - DATE, UnaryExecFunction)); result.push_back(make_unique(LAST_DAY_FUNC_NAME, - std::vector{TIMESTAMP}, DATE, + std::vector{LogicalTypeID::DATE}, LogicalTypeID::DATE, + UnaryExecFunction)); + result.push_back(make_unique(LAST_DAY_FUNC_NAME, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::DATE, UnaryExecFunction)); return result; } std::vector> LeastVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - make_unique(LEAST_FUNC_NAME, std::vector{DATE, DATE}, - DATE, BinaryExecFunction)); result.push_back(make_unique(LEAST_FUNC_NAME, - std::vector{TIMESTAMP, TIMESTAMP}, TIMESTAMP, + std::vector{LogicalTypeID::DATE, LogicalTypeID::DATE}, LogicalTypeID::DATE, + BinaryExecFunction)); + result.push_back(make_unique(LEAST_FUNC_NAME, + std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP}, + LogicalTypeID::TIMESTAMP, BinaryExecFunction)); return result; } @@ -79,18 +85,20 @@ std::vector> LeastVectorOperation::ge std::vector> MakeDateVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(MAKE_DATE_FUNC_NAME, - std::vector{INT64, INT64, INT64}, DATE, + std::vector{ + LogicalTypeID::INT64, LogicalTypeID::INT64, LogicalTypeID::INT64}, + LogicalTypeID::DATE, TernaryExecFunction)); return result; } std::vector> MonthNameVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - make_unique(MONTHNAME_FUNC_NAME, std::vector{DATE}, - STRING, UnaryExecFunction)); result.push_back(make_unique(MONTHNAME_FUNC_NAME, - std::vector{TIMESTAMP}, STRING, + std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, + UnaryExecFunction)); + result.push_back(make_unique(MONTHNAME_FUNC_NAME, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, UnaryExecFunction)); return result; } diff --git a/src/function/vector_hash_operations.cpp b/src/function/vector_hash_operations.cpp index d0ec652232..b355f986d6 100644 --- a/src/function/vector_hash_operations.cpp +++ b/src/function/vector_hash_operations.cpp @@ -1,7 +1,6 @@ #include "function/hash/vector_hash_operations.h" #include "function/binary_operation_executor.h" -#include "function/unary_operation_executor.h" using namespace kuzu::common; @@ -10,52 +9,47 @@ namespace function { void VectorHashOperations::computeHash(ValueVector* operand, ValueVector* result) { result->state = operand->state; - assert(result->dataType.typeID == INT64); - switch (operand->dataType.typeID) { - case INTERNAL_ID: { + assert(result->dataType.getLogicalTypeID() == LogicalTypeID::INT64); + switch (operand->dataType.getPhysicalType()) { + case PhysicalTypeID::INTERNAL_ID: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case BOOL: { + case PhysicalTypeID::BOOL: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case INT64: { + case PhysicalTypeID::INT64: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case INT32: { + case PhysicalTypeID::INT32: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case INT16: { + case PhysicalTypeID::INT16: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case DOUBLE: { + case PhysicalTypeID::DOUBLE: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case FLOAT: { + case PhysicalTypeID::FLOAT: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case STRING: { + case PhysicalTypeID::STRING: { UnaryHashOperationExecutor::execute(*operand, *result); } break; - case DATE: { - UnaryHashOperationExecutor::execute(*operand, *result); - } break; - case TIMESTAMP: { - UnaryHashOperationExecutor::execute(*operand, *result); - } break; - case INTERVAL: { + case PhysicalTypeID::INTERVAL: { UnaryHashOperationExecutor::execute(*operand, *result); } break; default: { throw RuntimeException( - "Cannot hash data type " + Types::dataTypeToString(operand->dataType.typeID)); + "Cannot hash data type " + + LogicalTypeUtils::dataTypeToString(operand->dataType.getLogicalTypeID())); } } } void VectorHashOperations::combineHash(ValueVector* left, ValueVector* right, ValueVector* result) { - assert(left->dataType.typeID == INT64); - assert(left->dataType.typeID == right->dataType.typeID); - assert(left->dataType.typeID == result->dataType.typeID); + assert(left->dataType.getLogicalTypeID() == LogicalTypeID::INT64); + assert(left->dataType.getLogicalTypeID() == right->dataType.getLogicalTypeID()); + assert(left->dataType.getLogicalTypeID() == result->dataType.getLogicalTypeID()); // TODO(Xiyang/Guodong): we should resolve result state of hash vector at compile time. result->state = !right->state->isFlat() ? right->state : left->state; BinaryOperationExecutor::execute( diff --git a/src/function/vector_list_operation.cpp b/src/function/vector_list_operation.cpp index fbdf6bdb5e..81b375ed3b 100644 --- a/src/function/vector_list_operation.cpp +++ b/src/function/vector_list_operation.cpp @@ -23,15 +23,15 @@ namespace kuzu { namespace function { static std::string getListFunctionIncompatibleChildrenTypeErrorMsg( - const std::string& functionName, const DataType& left, const DataType& right) { + const std::string& functionName, const LogicalType& left, const LogicalType& right) { return std::string("Cannot bind " + functionName + " with parameter type " + - Types::dataTypeToString(left) + " and " + Types::dataTypeToString(right) + - "."); + LogicalTypeUtils::dataTypeToString(left) + " and " + + LogicalTypeUtils::dataTypeToString(right) + "."); } void ListCreationVectorOperation::execFunc( const std::vector>& parameters, ValueVector& result) { - assert(result.dataType.typeID == VAR_LIST); + assert(result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); common::StringVector::resetOverflowBuffer(&result); for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize; ++selectedPos) { @@ -62,20 +62,24 @@ std::unique_ptr ListCreationVectorOperation::bindFunc( // ListCreation requires all parameters to have the same type or be ANY type. The result type of // listCreation can be determined by the first non-ANY type parameter. If all parameters have // dataType ANY, then the resultType will be INT64[] (default type). - auto resultType = DataType{std::make_unique(INT64)}; + auto varListTypeInfo = + std::make_unique(std::make_unique(LogicalTypeID::INT64)); + auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)}; for (auto i = 0u; i < arguments.size(); i++) { - if (arguments[i]->getDataType().typeID != common::ANY) { - resultType = DataType{std::make_unique(arguments[i]->getDataType())}; + if (arguments[i]->getDataType().getLogicalTypeID() != common::LogicalTypeID::ANY) { + varListTypeInfo = std::make_unique( + std::make_unique(arguments[i]->getDataType())); + resultType = LogicalType{common::LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)}; break; } } + auto resultChildType = VarListType::getChildType(&resultType); // Cast parameters with ANY dataType to resultChildType. for (auto i = 0u; i < arguments.size(); i++) { auto parameterType = arguments[i]->getDataType(); - if (parameterType != *resultType.getChildType()) { - if (parameterType.typeID == common::ANY) { - binder::ExpressionBinder::resolveAnyDataType( - *arguments[i], *resultType.getChildType()); + if (parameterType != *resultChildType) { + if (parameterType.getLogicalTypeID() == common::LogicalTypeID::ANY) { + binder::ExpressionBinder::resolveAnyDataType(*arguments[i], *resultChildType); } else { throw BinderException( getListFunctionIncompatibleChildrenTypeErrorMsg(LIST_CREATION_FUNC_NAME, @@ -90,8 +94,8 @@ std::vector> ListCreationVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_CREATION_FUNC_NAME, - std::vector{ANY}, VAR_LIST, execFunc, nullptr, bindFunc, - true /* isVarLength */)); + std::vector{LogicalTypeID::ANY}, LogicalTypeID::VAR_LIST, execFunc, nullptr, + bindFunc, true /* isVarLength */)); return result; } @@ -99,44 +103,45 @@ std::vector> ListLenVectorOperation:: std::vector> result; auto execFunc = UnaryExecFunction; result.push_back(std::make_unique(LIST_LEN_FUNC_NAME, - std::vector{VAR_LIST}, INT64, execFunc, true /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, execFunc, + true /* isVarlength*/)); return result; } std::unique_ptr ListExtractVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { - auto resultType = *arguments[0]->getDataType().getChildType(); + auto resultType = VarListType::getChildType(&arguments[0]->dataType); auto vectorOperationDefinition = reinterpret_cast(definition); - switch (resultType.typeID) { - case BOOL: { + switch (resultType->getLogicalTypeID()) { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case INT64: { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; @@ -144,17 +149,18 @@ std::unique_ptr ListExtractVectorOperation::bindFunc( throw common::NotImplementedException("ListExtractVectorOperation::bindFunc"); } } - return std::make_unique(resultType); + return std::make_unique(*resultType); } std::vector> ListExtractVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_EXTRACT_FUNC_NAME, - std::vector{VAR_LIST, INT64}, ANY, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::INT64}, + LogicalTypeID::ANY, nullptr, nullptr, bindFunc, false /* isVarlength*/)); result.push_back(std::make_unique(LIST_EXTRACT_FUNC_NAME, - std::vector{STRING, INT64}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, BinaryExecFunction, false /* isVarlength */)); return result; @@ -166,8 +172,8 @@ ListConcatVectorOperation::getDefinitions() { auto execFunc = BinaryListExecFunction; result.push_back(std::make_unique(LIST_CONCAT_FUNC_NAME, - std::vector{VAR_LIST, VAR_LIST}, VAR_LIST, execFunc, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::VAR_LIST}, + LogicalTypeID::VAR_LIST, execFunc, nullptr, bindFunc, false /* isVarlength*/)); return result; } @@ -182,42 +188,42 @@ std::unique_ptr ListConcatVectorOperation::bindFunc( std::unique_ptr ListAppendVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { - if (*arguments[0]->getDataType().getChildType() != arguments[1]->getDataType()) { + if (*VarListType::getChildType(&arguments[0]->dataType) != arguments[1]->getDataType()) { throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg( LIST_APPEND_FUNC_NAME, arguments[0]->getDataType(), arguments[1]->getDataType())); } auto resultType = arguments[0]->getDataType(); auto vectorOperationDefinition = reinterpret_cast(definition); - switch (arguments[1]->getDataType().typeID) { - case INT64: { + switch (arguments[1]->getDataType().getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; @@ -232,49 +238,49 @@ std::vector> ListAppendVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_APPEND_FUNC_NAME, - std::vector{VAR_LIST, ANY}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::ANY}, + LogicalTypeID::VAR_LIST, nullptr, nullptr, bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListPrependVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { - if (arguments[0]->dataType != *arguments[1]->dataType.getChildType()) { + if (arguments[0]->dataType != *VarListType::getChildType(&arguments[1]->dataType)) { throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg( LIST_APPEND_FUNC_NAME, arguments[0]->getDataType(), arguments[1]->getDataType())); } auto resultType = arguments[1]->getDataType(); auto vectorOperationDefinition = reinterpret_cast(definition); - switch (arguments[0]->getDataType().getTypeID()) { - case INT64: { + switch (arguments[0]->getDataType().getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { vectorOperationDefinition->execFunc = BinaryListExecFunction; } break; @@ -289,32 +295,36 @@ std::vector> ListPrependVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_PREPEND_FUNC_NAME, - std::vector{ANY, VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength */)); + std::vector{LogicalTypeID::ANY, LogicalTypeID::VAR_LIST}, + LogicalTypeID::VAR_LIST, nullptr, nullptr, bindFunc, false /* isVarlength */)); return result; } std::vector> ListPositionVectorOperation::getDefinitions() { return getBinaryListOperationDefinitions( - LIST_POSITION_FUNC_NAME, INT64); + LIST_POSITION_FUNC_NAME, LogicalTypeID::INT64); } std::vector> ListContainsVectorOperation::getDefinitions() { return getBinaryListOperationDefinitions( - LIST_CONTAINS_FUNC_NAME, BOOL); + LIST_CONTAINS_FUNC_NAME, LogicalTypeID::BOOL); } std::vector> ListSliceVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_SLICE_FUNC_NAME, - std::vector{VAR_LIST, INT64, INT64}, VAR_LIST, + std::vector{ + LogicalTypeID::VAR_LIST, LogicalTypeID::INT64, LogicalTypeID::INT64}, + LogicalTypeID::VAR_LIST, TernaryListExecFunction, nullptr, bindFunc, false /* isVarlength*/)); result.push_back(std::make_unique(LIST_SLICE_FUNC_NAME, - std::vector{STRING, INT64, INT64}, STRING, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::INT64, LogicalTypeID::INT64}, + LogicalTypeID::STRING, TernaryListExecFunction, false /* isVarlength */)); @@ -329,49 +339,50 @@ std::unique_ptr ListSliceVectorOperation::bindFunc( std::vector> ListSortVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_SORT_FUNC_NAME, - std::vector{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::VAR_LIST, nullptr, + nullptr, bindFunc, false /* isVarlength*/)); result.push_back(std::make_unique(LIST_SORT_FUNC_NAME, - std::vector{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::STRING}, + LogicalTypeID::VAR_LIST, nullptr, nullptr, bindFunc, false /* isVarlength*/)); result.push_back(std::make_unique(LIST_SORT_FUNC_NAME, - std::vector{VAR_LIST, STRING, STRING}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{ + LogicalTypeID::VAR_LIST, LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::VAR_LIST, nullptr, nullptr, bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListSortVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { auto vectorOperationDefinition = reinterpret_cast(definition); - switch (arguments[0]->dataType.getChildType()->getTypeID()) { - case INT64: { + switch (VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case INT32: { + case LogicalTypeID::INT32: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case INT16: { + case LogicalTypeID::INT16: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; default: { @@ -401,46 +412,46 @@ std::vector> ListReverseSortVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_REVERSE_SORT_FUNC_NAME, - std::vector{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::VAR_LIST, nullptr, + nullptr, bindFunc, false /* isVarlength*/)); result.push_back(std::make_unique(LIST_REVERSE_SORT_FUNC_NAME, - std::vector{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::STRING}, + LogicalTypeID::VAR_LIST, nullptr, nullptr, bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListReverseSortVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { auto vectorOperationDefinition = reinterpret_cast(definition); - switch (arguments[0]->dataType.getChildType()->getTypeID()) { - case INT64: { + switch (VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case INT32: { + case LogicalTypeID::INT32: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case INT16: { + case LogicalTypeID::INT16: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = getExecFunction(arguments); } break; default: { @@ -466,81 +477,93 @@ scalar_exec_func ListReverseSortVectorOperation::getExecFunction( std::vector> ListSumVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_SUM_FUNC_NAME, - std::vector{VAR_LIST}, INT64, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, nullptr, nullptr, + bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListSumVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { auto vectorOperationDefinition = reinterpret_cast(definition); - auto resultType = *arguments[0]->getDataType().getChildType(); - switch (resultType.getTypeID()) { - case INT64: { + auto resultType = VarListType::getChildType(&arguments[0]->dataType); + switch (resultType->getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case DOUBLE: { + case LogicalTypeID::INT32: { + vectorOperationDefinition->execFunc = + UnaryListExecFunction; + } break; + case LogicalTypeID::INT16: { + vectorOperationDefinition->execFunc = + UnaryListExecFunction; + } break; + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; + case LogicalTypeID::FLOAT: { + vectorOperationDefinition->execFunc = + UnaryListExecFunction; + } break; default: { throw common::NotImplementedException("ListSumVectorOperation::bindFunc"); } } - return std::make_unique(resultType); + return std::make_unique(*resultType); } std::vector> ListDistinctVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_DISTINCT_FUNC_NAME, - std::vector{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::VAR_LIST, nullptr, + nullptr, bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListDistinctVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { auto vectorOperationDefinition = reinterpret_cast(definition); - switch (arguments[0]->dataType.getChildType()->getTypeID()) { - case INT64: { + switch (VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case INT32: { + case LogicalTypeID::INT32: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case INT16: { + case LogicalTypeID::INT16: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; @@ -555,52 +578,52 @@ std::vector> ListUniqueVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_UNIQUE_FUNC_NAME, - std::vector{VAR_LIST}, INT64, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, nullptr, nullptr, + bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListUniqueVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { auto vectorOperationDefinition = reinterpret_cast(definition); - switch (arguments[0]->dataType.getChildType()->getTypeID()) { - case INT64: { + switch (common::VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case INT32: { + case LogicalTypeID::INT32: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case INT16: { + case LogicalTypeID::INT16: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = UnaryListExecFunction>; } break; @@ -608,64 +631,64 @@ std::unique_ptr ListUniqueVectorOperation::bindFunc( throw common::NotImplementedException("ListUniqueVectorOperation::bindFunc"); } } - return std::make_unique(DataType(INT64)); + return std::make_unique(LogicalType(LogicalTypeID::INT64)); } std::vector> ListAnyValueVectorOperation::getDefinitions() { std::vector> result; result.push_back(std::make_unique(LIST_ANY_VALUE_FUNC_NAME, - std::vector{VAR_LIST}, ANY, nullptr, nullptr, bindFunc, - false /* isVarlength*/)); + std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::ANY, nullptr, nullptr, + bindFunc, false /* isVarlength*/)); return result; } std::unique_ptr ListAnyValueVectorOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { auto vectorOperationDefinition = reinterpret_cast(definition); - auto resultType = *arguments[0]->getDataType().getChildType(); - switch (resultType.typeID) { - case INT64: { + auto resultType = VarListType::getChildType(&arguments[0]->dataType); + switch (resultType->getLogicalTypeID()) { + case LogicalTypeID::INT64: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case INT32: { + case LogicalTypeID::INT32: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case INT16: { + case LogicalTypeID::INT16: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case BOOL: { + case LogicalTypeID::BOOL: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case STRING: { + case LogicalTypeID::STRING: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case DATE: { + case LogicalTypeID::DATE: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { vectorOperationDefinition->execFunc = UnaryListExecFunction; } break; @@ -673,7 +696,7 @@ std::unique_ptr ListAnyValueVectorOperation::bindFunc( throw common::NotImplementedException("ListAnyValueVectorOperation::bindFunc"); } } - return std::make_unique(resultType); + return std::make_unique(*resultType); } } // namespace function diff --git a/src/function/vector_string_operations.cpp b/src/function/vector_string_operations.cpp index a25f529e39..b39fcbf05f 100644 --- a/src/function/vector_string_operations.cpp +++ b/src/function/vector_string_operations.cpp @@ -27,7 +27,8 @@ std::vector> ArrayExtractVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(ARRAY_EXTRACT_FUNC_NAME, - std::vector{STRING, INT64}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, BinaryExecFunction, false /* isVarLength */)); return definitions; @@ -36,7 +37,8 @@ ArrayExtractVectorOperation::getDefinitions() { std::vector> ConcatVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(CONCAT_FUNC_NAME, - std::vector{STRING, STRING}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::STRING, BinaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -45,7 +47,8 @@ std::vector> ConcatVectorOperation::g std::vector> ContainsVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(CONTAINS_FUNC_NAME, - std::vector{STRING, STRING}, BOOL, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, BinaryExecFunction, BinarySelectFunction, false /* isVarLength */)); @@ -55,7 +58,8 @@ std::vector> ContainsVectorOperation: std::vector> EndsWithVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(ENDS_WITH_FUNC_NAME, - std::vector{STRING, STRING}, BOOL, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, BinaryExecFunction, BinarySelectFunction, false /* isVarLength */)); @@ -65,7 +69,8 @@ std::vector> EndsWithVectorOperation: std::vector> LeftVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(LEFT_FUNC_NAME, - std::vector{STRING, INT64}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, BinaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -74,7 +79,7 @@ std::vector> LeftVectorOperation::get std::vector> LengthVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(LENGTH_FUNC_NAME, - std::vector{STRING}, INT64, + std::vector{LogicalTypeID::STRING}, LogicalTypeID::INT64, UnaryExecFunction, false /* isVarLength */)); return definitions; } @@ -82,7 +87,9 @@ std::vector> LengthVectorOperation::g std::vector> LpadVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(LPAD_FUNC_NAME, - std::vector{STRING, INT64, STRING}, STRING, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::INT64, LogicalTypeID::STRING}, + LogicalTypeID::STRING, TernaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -91,7 +98,8 @@ std::vector> LpadVectorOperation::get std::vector> RepeatVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(REPEAT_FUNC_NAME, - std::vector{STRING, INT64}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, BinaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -100,7 +108,8 @@ std::vector> RepeatVectorOperation::g std::vector> RightVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(RIGHT_FUNC_NAME, - std::vector{STRING, INT64}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, BinaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -109,7 +118,9 @@ std::vector> RightVectorOperation::ge std::vector> RpadVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(RPAD_FUNC_NAME, - std::vector{STRING, INT64, STRING}, STRING, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::INT64, LogicalTypeID::STRING}, + LogicalTypeID::STRING, TernaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -119,7 +130,8 @@ std::vector> StartsWithVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(STARTS_WITH_FUNC_NAME, - std::vector{STRING, STRING}, BOOL, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, BinaryExecFunction, BinarySelectFunction, false /* isVarLength */)); @@ -129,7 +141,9 @@ StartsWithVectorOperation::getDefinitions() { std::vector> SubStrVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(SUBSTRING_FUNC_NAME, - std::vector{STRING, INT64, INT64}, STRING, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::INT64, LogicalTypeID::INT64}, + LogicalTypeID::STRING, TernaryStringExecFunction, false /* isVarLength */)); return definitions; @@ -139,7 +153,8 @@ std::vector> RegexpFullMatchVectorOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(REGEXP_FULL_MATCH_FUNC_NAME, - std::vector{STRING, STRING}, BOOL, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, BinaryExecFunction, BinarySelectFunction, false /* isVarLength */)); @@ -149,7 +164,8 @@ RegexpFullMatchVectorOperation::getDefinitions() { std::vector> RegexpMatchesOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(REGEXP_MATCHES_FUNC_NAME, - std::vector{STRING, STRING}, BOOL, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::BOOL, BinaryExecFunction, BinarySelectFunction, false /* isVarLength */)); @@ -161,7 +177,9 @@ std::vector> RegexpReplaceOperation:: // Todo: Implement a function with modifiers // regexp_replace(string, regex, replacement, modifiers) definitions.emplace_back(make_unique(REGEXP_REPLACE_FUNC_NAME, - std::vector{STRING, STRING, STRING}, STRING, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::STRING, TernaryStringExecFunction, false /* isVarLength */)); @@ -171,11 +189,14 @@ std::vector> RegexpReplaceOperation:: std::vector> RegexpExtractOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(REGEXP_EXTRACT_FUNC_NAME, - std::vector{STRING, STRING}, STRING, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::STRING, BinaryStringExecFunction, false /* isVarLength */)); definitions.emplace_back(make_unique(REGEXP_EXTRACT_FUNC_NAME, - std::vector{STRING, STRING, INT64}, STRING, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::STRING, TernaryStringExecFunction, false /* isVarLength */)); @@ -186,12 +207,15 @@ std::vector> RegexpExtractAllOperation::getDefinitions() { std::vector> definitions; definitions.emplace_back(make_unique(REGEXP_EXTRACT_FUNC_NAME, - std::vector{STRING, STRING}, VAR_LIST, + std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING}, + LogicalTypeID::VAR_LIST, BinaryStringExecFunction, nullptr, bindFunc, false /* isVarLength */)); definitions.emplace_back(make_unique(REGEXP_EXTRACT_FUNC_NAME, - std::vector{STRING, STRING, INT64}, VAR_LIST, + std::vector{ + LogicalTypeID::STRING, LogicalTypeID::STRING, LogicalTypeID::INT64}, + LogicalTypeID::VAR_LIST, TernaryStringExecFunction, nullptr, bindFunc, false /* isVarLength */)); @@ -200,7 +224,8 @@ RegexpExtractAllOperation::getDefinitions() { std::unique_ptr RegexpExtractAllOperation::bindFunc( const binder::expression_vector& arguments, FunctionDefinition* definition) { - return std::make_unique(DataType(std::make_unique(STRING))); + return std::make_unique(LogicalType(LogicalTypeID::VAR_LIST, + std::make_unique(std::make_unique(LogicalTypeID::STRING)))); } } // namespace function diff --git a/src/function/vector_struct_operations.cpp b/src/function/vector_struct_operations.cpp index 34da1bc347..add8543b09 100644 --- a/src/function/vector_struct_operations.cpp +++ b/src/function/vector_struct_operations.cpp @@ -11,8 +11,8 @@ std::vector> StructPackVectorOperations::getDefinitions() { std::vector> definitions; definitions.push_back(make_unique(common::STRUCT_PACK_FUNC_NAME, - std::vector{common::ANY}, common::STRUCT, execFunc, nullptr, bindFunc, - true /* isVarLength */)); + std::vector{common::LogicalTypeID::ANY}, + common::LogicalTypeID::STRUCT, execFunc, nullptr, bindFunc, true /* isVarLength */)); return definitions; } @@ -20,14 +20,15 @@ std::unique_ptr StructPackVectorOperations::bindFunc( const binder::expression_vector& arguments, kuzu::function::FunctionDefinition* definition) { std::vector> fields; for (auto& argument : arguments) { - if (argument->getDataType().typeID == common::ANY) { + if (argument->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) { binder::ExpressionBinder::resolveAnyDataType( - *argument, common::DataType{common::INT64}); + *argument, common::LogicalType{common::LogicalTypeID::INT64}); } fields.emplace_back(std::make_unique( argument->getAlias(), argument->getDataType().copy())); } - auto resultType = common::DataType(std::move(fields)); + auto resultType = common::LogicalType( + common::LogicalTypeID::STRUCT, std::make_unique(std::move(fields))); return std::make_unique(resultType); } @@ -81,28 +82,28 @@ std::vector> StructExtractVectorOperations::getDefinitions() { std::vector> definitions; definitions.push_back(make_unique(common::STRUCT_EXTRACT_FUNC_NAME, - std::vector{common::STRUCT, common::STRING}, common::ANY, execFunc, - nullptr, bindFunc, false /* isVarLength */)); + std::vector{ + common::LogicalTypeID::STRUCT, common::LogicalTypeID::STRING}, + common::LogicalTypeID::ANY, execFunc, nullptr, bindFunc, false /* isVarLength */)); return definitions; } std::unique_ptr StructExtractVectorOperations::bindFunc( const binder::expression_vector& arguments, kuzu::function::FunctionDefinition* definition) { auto structType = arguments[0]->getDataType(); - auto typeInfo = reinterpret_cast(structType.getExtraTypeInfo()); if (arguments[1]->expressionType != common::LITERAL) { throw common::BinderException("Key name for struct_extract must be STRING literal."); } auto key = ((binder::LiteralExpression&)*arguments[1]).getValue()->getValue(); common::StringUtils::toUpper(key); - assert(definition->returnTypeID == common::ANY); - auto childIdx = typeInfo->getStructFieldIdx(key); + assert(definition->returnTypeID == common::LogicalTypeID::ANY); + auto childIdx = common::StructType::getStructFieldIdx(&structType, key); if (childIdx == common::INVALID_STRUCT_FIELD_IDX) { throw common::BinderException( common::StringUtils::string_format("Invalid struct field name: {}.", key)); } return std::make_unique( - *typeInfo->getChildrenTypes()[childIdx], childIdx); + *(common::StructType::getStructFieldTypes(&structType))[childIdx], childIdx); } } // namespace function diff --git a/src/function/vector_timestamp_operations.cpp b/src/function/vector_timestamp_operations.cpp index a6d8458bc9..d6bfab1df4 100644 --- a/src/function/vector_timestamp_operations.cpp +++ b/src/function/vector_timestamp_operations.cpp @@ -10,16 +10,16 @@ namespace function { std::vector> CenturyVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(CENTURY_FUNC_NAME, - std::vector{TIMESTAMP}, INT64, + std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::INT64, UnaryExecFunction)); return result; } std::vector> EpochMsVectorOperation::getDefinitions() { std::vector> result; - result.push_back( - make_unique(EPOCH_MS_FUNC_NAME, std::vector{INT64}, - TIMESTAMP, UnaryExecFunction)); + result.push_back(make_unique(EPOCH_MS_FUNC_NAME, + std::vector{LogicalTypeID::INT64}, LogicalTypeID::TIMESTAMP, + UnaryExecFunction)); return result; } @@ -27,7 +27,7 @@ std::vector> ToTimestampVectorOperation::getDefinitions() { std::vector> result; result.push_back(make_unique(TO_TIMESTAMP_FUNC_NAME, - std::vector{INT64}, TIMESTAMP, + std::vector{LogicalTypeID::INT64}, LogicalTypeID::TIMESTAMP, UnaryExecFunction)); return result; } diff --git a/src/include/binder/binder.h b/src/include/binder/binder.h index 5fefaae58f..3a0b9e52a4 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -36,7 +36,7 @@ class Binder { common::table_id_t bindNodeTableID(const std::string& tableName) const; std::shared_ptr createVariable( - const std::string& name, const common::DataType& dataType); + const std::string& name, const common::LogicalType& dataType); /*** bind DDL ***/ std::unique_ptr bindCreateNodeTableClause(const parser::Statement& statement); @@ -53,7 +53,7 @@ class Binder { std::vector> propertyNameDataTypes); common::property_id_t bindPropertyName( catalog::NodeTableSchema::TableSchema* tableSchema, const std::string& propertyName); - common::DataType bindDataType(const std::string& dataType); + common::LogicalType bindDataType(const std::string& dataType); /*** bind copy csv ***/ std::unique_ptr bindCopyClause(const parser::Statement& statement); @@ -144,14 +144,14 @@ class Binder { std::shared_ptr createQueryNode(const parser::NodePattern& nodePattern); inline std::vector bindNodeTableIDs( const std::vector& tableNames) { - return bindTableIDs(tableNames, common::NODE); + return bindTableIDs(tableNames, common::LogicalTypeID::NODE); } inline std::vector bindRelTableIDs( const std::vector& tableNames) { - return bindTableIDs(tableNames, common::REL); + return bindTableIDs(tableNames, common::LogicalTypeID::REL); } std::vector bindTableIDs( - const std::vector& tableNames, common::DataTypeID nodeOrRelType); + const std::vector& tableNames, common::LogicalTypeID nodeOrRelType); /*** validations ***/ // E.g. Optional MATCH (a) RETURN a.age diff --git a/src/include/binder/ddl/bound_add_property.h b/src/include/binder/ddl/bound_add_property.h index 81a07c2737..8faa5a2627 100644 --- a/src/include/binder/ddl/bound_add_property.h +++ b/src/include/binder/ddl/bound_add_property.h @@ -8,7 +8,8 @@ namespace binder { class BoundAddProperty : public BoundDDL { public: explicit BoundAddProperty(common::table_id_t tableID, std::string propertyName, - common::DataType dataType, std::shared_ptr defaultValue, std::string tableName) + common::LogicalType dataType, std::shared_ptr defaultValue, + std::string tableName) : BoundDDL{common::StatementType::ADD_PROPERTY, std::move(tableName)}, tableID{tableID}, propertyName{std::move(propertyName)}, dataType{std::move(dataType)}, defaultValue{std::move(defaultValue)} {} @@ -17,14 +18,14 @@ class BoundAddProperty : public BoundDDL { inline std::string getPropertyName() const { return propertyName; } - inline common::DataType getDataType() const { return dataType; } + inline common::LogicalType getDataType() const { return dataType; } inline std::shared_ptr getDefaultValue() const { return defaultValue; } private: common::table_id_t tableID; std::string propertyName; - common::DataType dataType; + common::LogicalType dataType; std::shared_ptr defaultValue; }; diff --git a/src/include/binder/expression/case_expression.h b/src/include/binder/expression/case_expression.h index 87b60ce229..ff92e86f21 100644 --- a/src/include/binder/expression/case_expression.h +++ b/src/include/binder/expression/case_expression.h @@ -16,7 +16,7 @@ struct CaseAlternative { class CaseExpression : public Expression { public: - CaseExpression(common::DataType dataType, std::shared_ptr elseExpression, + CaseExpression(common::LogicalType dataType, std::shared_ptr elseExpression, const std::string& name) : Expression{common::CASE_ELSE, std::move(dataType), name}, elseExpression{std::move( elseExpression)} {} diff --git a/src/include/binder/expression/existential_subquery_expression.h b/src/include/binder/expression/existential_subquery_expression.h index a5ede853b7..ea6e823e82 100644 --- a/src/include/binder/expression/existential_subquery_expression.h +++ b/src/include/binder/expression/existential_subquery_expression.h @@ -10,7 +10,7 @@ class ExistentialSubqueryExpression : public Expression { public: ExistentialSubqueryExpression(std::unique_ptr queryGraphCollection, std::string uniqueName, std::string rawName) - : Expression{common::EXISTENTIAL_SUBQUERY, common::DataType(common::BOOL), + : Expression{common::EXISTENTIAL_SUBQUERY, common::LogicalType(common::LogicalTypeID::BOOL), std::move(uniqueName)}, queryGraphCollection{std::move(queryGraphCollection)}, rawName{std::move(rawName)} {} diff --git a/src/include/binder/expression/expression.h b/src/include/binder/expression/expression.h index 21b8ddf2fa..31c86d1151 100644 --- a/src/include/binder/expression/expression.h +++ b/src/include/binder/expression/expression.h @@ -27,27 +27,27 @@ using expression_map = class Expression : public std::enable_shared_from_this { public: - Expression(common::ExpressionType expressionType, common::DataType dataType, + Expression(common::ExpressionType expressionType, common::LogicalType dataType, expression_vector children, std::string uniqueName) : expressionType{expressionType}, dataType{std::move(dataType)}, uniqueName{std::move(uniqueName)}, children{std::move(children)} {} // Create binary expression. - Expression(common::ExpressionType expressionType, common::DataType dataType, + Expression(common::ExpressionType expressionType, common::LogicalType dataType, const std::shared_ptr& left, const std::shared_ptr& right, std::string uniqueName) : Expression{expressionType, std::move(dataType), expression_vector{left, right}, std::move(uniqueName)} {} // Create unary expression. - Expression(common::ExpressionType expressionType, common::DataType dataType, + Expression(common::ExpressionType expressionType, common::LogicalType dataType, const std::shared_ptr& child, std::string uniqueName) : Expression{expressionType, std::move(dataType), expression_vector{child}, std::move(uniqueName)} {} // Create leaf expression Expression( - common::ExpressionType expressionType, common::DataType dataType, std::string uniqueName) + common::ExpressionType expressionType, common::LogicalType dataType, std::string uniqueName) : Expression{ expressionType, std::move(dataType), expression_vector{}, std::move(uniqueName)} {} @@ -61,7 +61,7 @@ class Expression : public std::enable_shared_from_this { return uniqueName; } - inline common::DataType getDataType() const { return dataType; } + inline common::LogicalType getDataType() const { return dataType; } inline bool hasAlias() const { return !alias.empty(); } @@ -106,7 +106,7 @@ class Expression : public std::enable_shared_from_this { public: common::ExpressionType expressionType; - common::DataType dataType; + common::LogicalType dataType; protected: // Name that serves as the unique identifier. @@ -131,7 +131,7 @@ struct ExpressionEquality { class ExpressionUtil { public: static bool allExpressionsHaveDataType( - expression_vector& expressions, common::DataTypeID dataTypeID); + expression_vector& expressions, common::LogicalTypeID dataTypeID); static uint32_t find(Expression* target, expression_vector expressions); diff --git a/src/include/binder/expression/literal_expression.h b/src/include/binder/expression/literal_expression.h index b228d62b21..7fad0d3124 100644 --- a/src/include/binder/expression/literal_expression.h +++ b/src/include/binder/expression/literal_expression.h @@ -13,8 +13,8 @@ class LiteralExpression : public Expression { inline bool isNull() const { return value->isNull(); } - inline void setDataType(const common::DataType& targetType) { - assert(dataType.typeID == common::ANY && isNull()); + inline void setDataType(const common::LogicalType& targetType) { + assert(dataType.getLogicalTypeID() == common::LogicalTypeID::ANY && isNull()); dataType = targetType; value->setDataType(targetType); } diff --git a/src/include/binder/expression/node_expression.h b/src/include/binder/expression/node_expression.h index 4456997630..9338502414 100644 --- a/src/include/binder/expression/node_expression.h +++ b/src/include/binder/expression/node_expression.h @@ -10,8 +10,8 @@ class NodeExpression : public NodeOrRelExpression { public: NodeExpression( std::string uniqueName, std::string variableName, std::vector tableIDs) - : NodeOrRelExpression{common::DataType(common::NODE), std::move(uniqueName), - std::move(variableName), std::move(tableIDs)} {} + : NodeOrRelExpression{common::LogicalType(common::LogicalTypeID::NODE), + std::move(uniqueName), std::move(variableName), std::move(tableIDs)} {} inline void setInternalIDProperty(std::unique_ptr expression) { internalIDExpression = std::move(expression); diff --git a/src/include/binder/expression/node_rel_expression.h b/src/include/binder/expression/node_rel_expression.h index 902fc1dfd9..2b2c7f3c28 100644 --- a/src/include/binder/expression/node_rel_expression.h +++ b/src/include/binder/expression/node_rel_expression.h @@ -9,8 +9,8 @@ namespace binder { class NodeOrRelExpression : public Expression { public: - NodeOrRelExpression(common::DataType dataType, std::string uniqueName, std::string variableName, - std::vector tableIDs) + NodeOrRelExpression(common::LogicalType dataType, std::string uniqueName, + std::string variableName, std::vector tableIDs) : Expression{common::VARIABLE, std::move(dataType), std::move(uniqueName)}, variableName(std::move(variableName)), tableIDs{std::move(tableIDs)} {} virtual ~NodeOrRelExpression() override = default; diff --git a/src/include/binder/expression/parameter_expression.h b/src/include/binder/expression/parameter_expression.h index 606a9f902d..6ad577d6aa 100644 --- a/src/include/binder/expression/parameter_expression.h +++ b/src/include/binder/expression/parameter_expression.h @@ -10,12 +10,12 @@ class ParameterExpression : public Expression { public: explicit ParameterExpression( const std::string& parameterName, std::shared_ptr value) - : Expression{common::PARAMETER, common::DataType(common::ANY), + : Expression{common::PARAMETER, common::LogicalType(common::LogicalTypeID::ANY), createUniqueName(parameterName)}, parameterName(parameterName), value{std::move(value)} {} - inline void setDataType(const common::DataType& targetType) { - assert(dataType.typeID == common::ANY); + inline void setDataType(const common::LogicalType& targetType) { + assert(dataType.getLogicalTypeID() == common::LogicalTypeID::ANY); dataType = targetType; value->setDataType(targetType); } diff --git a/src/include/binder/expression/property_expression.h b/src/include/binder/expression/property_expression.h index 71640e16b7..17ad802ca9 100644 --- a/src/include/binder/expression/property_expression.h +++ b/src/include/binder/expression/property_expression.h @@ -8,7 +8,7 @@ namespace binder { class PropertyExpression : public Expression { public: - PropertyExpression(common::DataType dataType, const std::string& propertyName, + PropertyExpression(common::LogicalType dataType, const std::string& propertyName, const Expression& nodeOrRel, std::unordered_map propertyIDPerTable, bool isPrimaryKey_) diff --git a/src/include/binder/expression/rel_expression.h b/src/include/binder/expression/rel_expression.h index fd25fe78b2..eef1b3c29a 100644 --- a/src/include/binder/expression/rel_expression.h +++ b/src/include/binder/expression/rel_expression.h @@ -9,11 +9,11 @@ namespace binder { class RelExpression : public NodeOrRelExpression { public: - RelExpression(common::DataType dataType, std::string uniqueName, std::string variableName, + RelExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName, std::vector tableIDs, std::shared_ptr srcNode, std::shared_ptr dstNode, bool directed, common::QueryRelType relType, uint64_t lowerBound, uint64_t upperBound) - : NodeOrRelExpression{dataType, std::move(uniqueName), std::move(variableName), + : NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName), std::move(tableIDs)}, srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directed{directed}, relType{relType}, lowerBound{lowerBound}, upperBound{upperBound} {} diff --git a/src/include/binder/expression/variable_expression.h b/src/include/binder/expression/variable_expression.h index d6627ffa6d..c1d89bf3aa 100644 --- a/src/include/binder/expression/variable_expression.h +++ b/src/include/binder/expression/variable_expression.h @@ -7,7 +7,8 @@ namespace binder { class VariableExpression : public Expression { public: - VariableExpression(common::DataType dataType, std::string uniqueName, std::string variableName) + VariableExpression( + common::LogicalType dataType, std::string uniqueName, std::string variableName) : Expression{common::VARIABLE, dataType, std::move(uniqueName)}, variableName{std::move( variableName)} {} diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index e2a3322179..ebfd813a01 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -19,7 +19,7 @@ class ExpressionBinder { std::shared_ptr bindExpression(const parser::ParsedExpression& parsedExpression); - static void resolveAnyDataType(Expression& expression, const common::DataType& targetType); + static void resolveAnyDataType(Expression& expression, const common::LogicalType& targetType); private: std::shared_ptr bindBooleanExpression( @@ -89,18 +89,19 @@ class ExpressionBinder { // not specify its child type. // For the rest, i.e. set clause binding, we cast with data type. For example, a.list = $1. static std::shared_ptr implicitCastIfNecessary( - const std::shared_ptr& expression, const common::DataType& targetType); + const std::shared_ptr& expression, const common::LogicalType& targetType); static std::shared_ptr implicitCastIfNecessary( - const std::shared_ptr& expression, common::DataTypeID targetTypeID); + const std::shared_ptr& expression, common::LogicalTypeID targetTypeID); static std::shared_ptr implicitCast( - const std::shared_ptr& expression, const common::DataType& targetType); + const std::shared_ptr& expression, const common::LogicalType& targetType); /****** validation *****/ - static void validateExpectedDataType(const Expression& expression, common::DataTypeID target) { - validateExpectedDataType(expression, std::unordered_set{target}); + static void validateExpectedDataType( + const Expression& expression, common::LogicalTypeID target) { + validateExpectedDataType(expression, std::unordered_set{target}); } static void validateExpectedDataType( - const Expression& expression, const std::unordered_set& targets); + const Expression& expression, const std::unordered_set& targets); // E.g. SUM(SUM(a.age)) is not allowed static void validateAggregationExpressionIsNotNested(const Expression& expression); diff --git a/src/include/c_api/kuzu.h b/src/include/c_api/kuzu.h index 051d5e6d49..207895bfb1 100644 --- a/src/include/c_api/kuzu.h +++ b/src/include/c_api/kuzu.h @@ -172,7 +172,6 @@ KUZU_C_API kuzu_data_type* kuzu_data_type_clone(kuzu_data_type* data_type); KUZU_C_API void kuzu_data_type_destroy(kuzu_data_type* data_type); KUZU_C_API bool kuzu_data_type_equals(kuzu_data_type* data_type1, kuzu_data_type* data_type2); KUZU_C_API kuzu_data_type_id kuzu_data_type_get_id(kuzu_data_type* data_type); -KUZU_C_API kuzu_data_type* kuzu_data_type_get_child_type(kuzu_data_type* data_type); KUZU_C_API uint64_t kuzu_data_type_get_fixed_num_elements_in_list(kuzu_data_type* data_type); // Value diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index a582f5a34a..de0c205c87 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -204,7 +204,7 @@ class Catalog { void renameTable(common::table_id_t tableID, std::string newName); void addProperty( - common::table_id_t tableID, std::string propertyName, common::DataType dataType); + common::table_id_t tableID, std::string propertyName, common::LogicalType dataType); void dropProperty(common::table_id_t tableID, common::property_id_t propertyID); diff --git a/src/include/catalog/catalog_structs.h b/src/include/catalog/catalog_structs.h index 9f5e75bc5a..d3d5d10976 100644 --- a/src/include/catalog/catalog_structs.h +++ b/src/include/catalog/catalog_structs.h @@ -21,17 +21,17 @@ struct Property { static constexpr std::string_view REL_TO_PROPERTY_NAME = "_TO_"; // This constructor is needed for ser/deser functions - Property() : Property{"", common::DataType{}} {}; - Property(std::string name, common::DataType dataType) + Property() : Property{"", common::LogicalType{}} {}; + Property(std::string name, common::LogicalType dataType) : Property{std::move(name), std::move(dataType), common::INVALID_PROPERTY_ID, common::INVALID_TABLE_ID} {} - Property(std::string name, common::DataType dataType, common::property_id_t propertyID, + Property(std::string name, common::LogicalType dataType, common::property_id_t propertyID, common::table_id_t tableID) : name{std::move(name)}, dataType{std::move(dataType)}, propertyID{propertyID}, tableID{tableID} {} std::string name; - common::DataType dataType; + common::LogicalType dataType; common::property_id_t propertyID; common::table_id_t tableID; }; @@ -65,7 +65,7 @@ struct TableSchema { [&propertyName](const Property& property) { return property.name == propertyName; }); } - inline void addProperty(std::string propertyName, common::DataType dataType) { + inline void addProperty(std::string propertyName, common::LogicalType dataType) { properties.emplace_back( std::move(propertyName), std::move(dataType), increaseNextPropertyID(), tableID); } diff --git a/src/include/common/arrow/arrow_row_batch.h b/src/include/common/arrow/arrow_row_batch.h index ea39418670..78c0f908d8 100644 --- a/src/include/common/arrow/arrow_row_batch.h +++ b/src/include/common/arrow/arrow_row_batch.h @@ -69,15 +69,15 @@ class ArrowRowBatch { ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos); static void copyNullValue(ArrowVector* vector, Value* value, std::int64_t pos); - template + template static void templateInitializeVector( ArrowVector* vector, const main::DataTypeInfo& typeInfo, std::int64_t capacity); - template + template static void templateCopyNonNullValue( ArrowVector* vector, const main::DataTypeInfo& typeInfo, Value* value, std::int64_t pos); - template + template static void templateCopyNullValue(ArrowVector* vector, std::int64_t pos); - template + template static ArrowArray* templateCreateArray(ArrowVector& vector, const main::DataTypeInfo& typeInfo); ArrowArray toArray(); diff --git a/src/include/common/in_mem_overflow_buffer_utils.h b/src/include/common/in_mem_overflow_buffer_utils.h index 5e7f27b70e..d618682f08 100644 --- a/src/include/common/in_mem_overflow_buffer_utils.h +++ b/src/include/common/in_mem_overflow_buffer_utils.h @@ -22,39 +22,7 @@ class InMemOverflowBufferUtils { const char* src, uint64_t len, ku_string_t& dest, InMemOverflowBuffer& inMemOverflowBuffer); static void copyString( const ku_string_t& src, ku_string_t& dest, InMemOverflowBuffer& inMemOverflowBuffer); - - static void copyListRecursiveIfNested(const ku_list_t& src, ku_list_t& dst, - const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer, - uint32_t srcStartIdx = 0, uint32_t srcEndIdx = UINT32_MAX); - - template - static inline void setListElement(ku_list_t& result, uint64_t elementPos, T& element, - const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer) { - reinterpret_cast(result.overflowPtr)[elementPos] = element; - } - - static inline void allocateSpaceForList( - ku_list_t& list, uint64_t numBytes, InMemOverflowBuffer& buffer) { - list.overflowPtr = reinterpret_cast(buffer.allocateSpace(numBytes)); - } }; -template<> -inline void InMemOverflowBufferUtils::setListElement(ku_list_t& result, uint64_t elementPos, - ku_string_t& element, const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer) { - ku_string_t elementToAppend; - InMemOverflowBufferUtils::copyString(element, elementToAppend, inMemOverflowBuffer); - reinterpret_cast(result.overflowPtr)[elementPos] = elementToAppend; -} - -template<> -inline void InMemOverflowBufferUtils::setListElement(ku_list_t& result, uint64_t elementPos, - ku_list_t& element, const DataType& dataType, InMemOverflowBuffer& inMemOverflowBuffer) { - ku_list_t elementToAppend; - InMemOverflowBufferUtils::copyListRecursiveIfNested( - element, elementToAppend, *dataType.getChildType(), inMemOverflowBuffer); - reinterpret_cast(result.overflowPtr)[elementPos] = elementToAppend; -} - } // namespace common } // namespace kuzu diff --git a/src/include/common/ser_deser.h b/src/include/common/ser_deser.h index accb1bdfd5..d99fe18394 100644 --- a/src/include/common/ser_deser.h +++ b/src/include/common/ser_deser.h @@ -101,12 +101,12 @@ class SerDeser { }; template<> -uint64_t SerDeser::serializeValue(const DataType& value, FileInfo* fileInfo, uint64_t offset); +uint64_t SerDeser::serializeValue(const LogicalType& value, FileInfo* fileInfo, uint64_t offset); template<> uint64_t SerDeser::serializeValue(const std::string& value, FileInfo* fileInfo, uint64_t offset); template<> -uint64_t SerDeser::deserializeValue(DataType& value, FileInfo* fileInfo, uint64_t offset); +uint64_t SerDeser::deserializeValue(LogicalType& value, FileInfo* fileInfo, uint64_t offset); template<> uint64_t SerDeser::deserializeValue(std::string& value, FileInfo* fileInfo, uint64_t offset); diff --git a/src/include/common/string_utils.h b/src/include/common/string_utils.h index d6e75d4bcc..e5bf9d0e5a 100644 --- a/src/include/common/string_utils.h +++ b/src/include/common/string_utils.h @@ -19,7 +19,8 @@ class StringUtils { return fmt::format(fmt::runtime(format), args...); } - static std::vector split(const std::string& input, const std::string& delimiter); + static std::vector split( + const std::string& input, const std::string& delimiter, bool ignoreEmptyStringParts = true); static void toUpper(std::string& input) { std::transform(input.begin(), input.end(), input.begin(), ::toupper); diff --git a/src/include/common/type_utils.h b/src/include/common/type_utils.h index 929c45f7dd..44ed4622e3 100644 --- a/src/include/common/type_utils.h +++ b/src/include/common/type_utils.h @@ -27,7 +27,7 @@ class TypeUtils { static inline std::string toString(const interval_t& val) { return Interval::toString(val); } static inline std::string toString(const ku_string_t& val) { return val.getAsString(); } static inline std::string toString(const std::string& val) { return val; } - static std::string toString(const ku_list_t& val, const DataType& dataType); + static std::string toString(const ku_list_t& val, const LogicalType& dataType); static std::string toString(const list_entry_t& val, void* valVector); static inline void encodeOverflowPtr( @@ -63,9 +63,9 @@ class TypeUtils { private: static std::string listValueToString( - const DataType& dataType, uint8_t* listValues, uint64_t pos); + const LogicalType& dataType, uint8_t* listValues, uint64_t pos); - static std::string prefixConversionExceptionMessage(const char* data, DataTypeID dataTypeID); + static std::string prefixConversionExceptionMessage(const char* data, LogicalTypeID dataTypeID); }; template<> diff --git a/src/include/common/types/ku_list.h b/src/include/common/types/ku_list.h index 74f4448182..d608c78140 100644 --- a/src/include/common/types/ku_list.h +++ b/src/include/common/types/ku_list.h @@ -11,12 +11,12 @@ struct ku_list_t { ku_list_t() : size{0}, overflowPtr{0} {} ku_list_t(uint64_t size, uint64_t overflowPtr) : size{size}, overflowPtr{overflowPtr} {} - void set(const uint8_t* values, const DataType& dataType) const; + void set(const uint8_t* values, const LogicalType& dataType) const; private: friend class InMemOverflowBufferUtils; - void set(const std::vector& parameters, DataTypeID childTypeId); + void set(const std::vector& parameters, LogicalTypeID childTypeId); public: uint64_t size; diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index 6279679e66..a30cf9231e 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -52,16 +52,7 @@ struct list_entry_t { list_entry_t(common::offset_t offset, uint64_t size) : offset{offset}, size{size} {} }; -KUZU_API enum DataTypeID : uint8_t { - // NOTE: Not all data types can be used in processor. For example, ANY should be resolved during - // query compilation. Similarly logical data types should also only be used in compilation. - // Some use cases are as follows. - // - differentiate whether is a variable refers to node table or rel table - // - bind (static evaluate) functions work on node/rel table. - // E.g. ID(a "datatype:NODE") -> node ID property "datatype:NODE_ID" - - // logical types - +KUZU_API enum class LogicalTypeID : uint8_t { ANY = 0, NODE = 10, REL = 11, @@ -69,8 +60,6 @@ KUZU_API enum DataTypeID : uint8_t { // incremented by 1 starting from 0. SERIAL = 12, - // physical types - // fixed size types BOOL = 22, INT64 = 23, @@ -91,7 +80,30 @@ KUZU_API enum DataTypeID : uint8_t { STRUCT = 53, }; -class DataType; +enum class PhysicalTypeID : uint8_t { + // Fixed size types. + ANY = 0, + BOOL = 1, + INT64 = 2, + INT32 = 3, + INT16 = 4, + DOUBLE = 5, + FLOAT = 6, + INTERVAL = 7, + INTERNAL_ID = 9, + + // Variable size types. + STRING = 20, + FIXED_LIST = 21, + VAR_LIST = 22, + STRUCT = 23, +}; + +struct PhysicalTypeUtils { + static std::string physicalTypeToString(PhysicalTypeID physicalType); +}; + +class LogicalType; class ExtraTypeInfo { public: @@ -104,14 +116,14 @@ class VarListTypeInfo : public ExtraTypeInfo { public: VarListTypeInfo() = default; - explicit VarListTypeInfo(std::unique_ptr childType) + explicit VarListTypeInfo(std::unique_ptr childType) : childType{std::move(childType)} {} - inline DataType* getChildType() const { return childType.get(); } + inline LogicalType* getChildType() const { return childType.get(); } bool operator==(const VarListTypeInfo& other) const; std::unique_ptr copy() const override; protected: - std::unique_ptr childType; + std::unique_ptr childType; }; class FixedListTypeInfo : public VarListTypeInfo { @@ -119,9 +131,10 @@ class FixedListTypeInfo : public VarListTypeInfo { public: FixedListTypeInfo() = default; - explicit FixedListTypeInfo(std::unique_ptr childType, uint64_t fixedNumElementsInList) + explicit FixedListTypeInfo( + std::unique_ptr childType, uint64_t fixedNumElementsInList) : VarListTypeInfo{std::move(childType)}, fixedNumElementsInList{fixedNumElementsInList} {} - inline uint64_t getFixedNumElementsInList() const { return fixedNumElementsInList; } + inline uint64_t getNumElementsInList() const { return fixedNumElementsInList; } bool operator==(const FixedListTypeInfo& other) const; std::unique_ptr copy() const override; @@ -133,21 +146,21 @@ class StructField { friend class SerDeser; public: - StructField() : type{std::make_unique()} {} - StructField(std::string name, std::unique_ptr type) + StructField() : type{std::make_unique()} {} + StructField(std::string name, std::unique_ptr type) : name{std::move(name)}, type{std::move(type)} { // Note: struct field name is case-insensitive. StringUtils::toUpper(this->name); } inline bool operator!=(const StructField& other) const { return !(*this == other); } inline std::string getName() const { return name; } - inline DataType* getType() const { return type.get(); } + inline LogicalType* getType() const { return type.get(); } bool operator==(const StructField& other) const; std::unique_ptr copy() const; private: std::string name; - std::unique_ptr type; + std::unique_ptr type; }; class StructTypeInfo : public ExtraTypeInfo { @@ -158,7 +171,7 @@ class StructTypeInfo : public ExtraTypeInfo { explicit StructTypeInfo(std::vector> fields); struct_field_idx_t getStructFieldIdx(std::string fieldName) const; - std::vector getChildrenTypes() const; + std::vector getChildrenTypes() const; std::vector getChildrenNames() const; std::vector getStructFields() const; @@ -171,60 +184,117 @@ class StructTypeInfo : public ExtraTypeInfo { std::unordered_map fieldNameToIdxMap; }; -class DataType { +class LogicalType { + friend class SerDeser; + friend class LogicalTypeUtils; + friend class StructType; + friend class VarListType; + friend class FixedListType; + public: - KUZU_API DataType() : typeID{ANY}, extraTypeInfo{nullptr} {}; - KUZU_API explicit DataType(DataTypeID typeID) : typeID{typeID}, extraTypeInfo{nullptr} {}; - KUZU_API DataType(std::unique_ptr childType) - : typeID{VAR_LIST}, extraTypeInfo{std::make_unique(std::move(childType))} { - } - KUZU_API DataType(std::unique_ptr childType, uint64_t fixedNumElementsInList) - : typeID{FIXED_LIST}, extraTypeInfo{std::make_unique( - std::move(childType), fixedNumElementsInList)} {} - KUZU_API DataType(std::vector> childrenTypes) - : typeID{STRUCT}, extraTypeInfo{ - std::make_unique(std::move(childrenTypes))} {}; - KUZU_API DataType(const DataType& other); - KUZU_API DataType(DataType&& other) noexcept; + KUZU_API LogicalType() : typeID{LogicalTypeID::ANY}, extraTypeInfo{nullptr} {}; + KUZU_API explicit LogicalType(LogicalTypeID typeID) : typeID{typeID}, extraTypeInfo{nullptr} { + setPhysicalType(); + }; + KUZU_API LogicalType(LogicalTypeID typeID, std::unique_ptr extraTypeInfo) + : typeID{typeID}, extraTypeInfo{std::move(extraTypeInfo)} { + setPhysicalType(); + }; + KUZU_API LogicalType(const LogicalType& other); + KUZU_API LogicalType(LogicalType&& other) noexcept; - static std::vector getNumericalTypeIDs(); - static std::vector getAllValidComparableTypes(); - static std::vector getAllValidTypeIDs(); + static std::vector getNumericalLogicalTypeIDs(); + static std::vector getAllValidComparableLogicalTypes(); + static std::vector getAllValidLogicTypeIDs(); - KUZU_API DataType& operator=(const DataType& other); + KUZU_API LogicalType& operator=(const LogicalType& other); - KUZU_API bool operator==(const DataType& other) const; + KUZU_API bool operator==(const LogicalType& other) const; - KUZU_API bool operator!=(const DataType& other) const; + KUZU_API bool operator!=(const LogicalType& other) const; - KUZU_API DataType& operator=(DataType&& other) noexcept; + KUZU_API LogicalType& operator=(LogicalType&& other) noexcept; - KUZU_API DataTypeID getTypeID() const; + KUZU_API inline LogicalTypeID getLogicalTypeID() const { return typeID; } - DataType* getChildType() const; + inline PhysicalTypeID getPhysicalType() const { return physicalType; } - std::unique_ptr copy(); + std::unique_ptr copy(); - ExtraTypeInfo* getExtraTypeInfo() const; +private: + void setPhysicalType(); -public: - DataTypeID typeID; +private: + LogicalTypeID typeID; + PhysicalTypeID physicalType; std::unique_ptr extraTypeInfo; }; -class Types { +struct VarListType { + static inline LogicalType* getChildType(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::VAR_LIST); + auto varListTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return varListTypeInfo->getChildType(); + } +}; + +struct FixedListType { + static inline LogicalType* getChildType(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::FIXED_LIST); + auto fixedListTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return fixedListTypeInfo->getChildType(); + } + + static inline uint64_t getNumElementsInList(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::FIXED_LIST); + auto fixedListTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return fixedListTypeInfo->getNumElementsInList(); + } +}; + +struct StructType { + static inline std::vector getStructFieldTypes(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::STRUCT); + auto structTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return structTypeInfo->getChildrenTypes(); + } + + static inline std::vector getStructFieldNames(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::STRUCT); + auto structTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return structTypeInfo->getChildrenNames(); + } + + static inline uint64_t getNumFields(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::STRUCT); + return getStructFieldTypes(type).size(); + } + + static inline std::vector getStructFields(const LogicalType* type) { + assert(type->getLogicalTypeID() == LogicalTypeID::STRUCT); + auto structTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return structTypeInfo->getStructFields(); + } + + static inline struct_field_idx_t getStructFieldIdx(const LogicalType* type, std::string& key) { + assert(type->getLogicalTypeID() == LogicalTypeID::STRUCT); + auto structTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return structTypeInfo->getStructFieldIdx(key); + } +}; + +class LogicalTypeUtils { public: - KUZU_API static std::string dataTypeToString(const DataType& dataType); - KUZU_API static std::string dataTypeToString(DataTypeID dataTypeID); - static std::string dataTypesToString(const std::vector& dataTypes); - static std::string dataTypesToString(const std::vector& dataTypeIDs); - KUZU_API static DataType dataTypeFromString(const std::string& dataTypeString); - static uint32_t getDataTypeSize(DataTypeID dataTypeID); - static uint32_t getDataTypeSize(const DataType& dataType); - static bool isNumerical(const DataType& dataType); + KUZU_API static std::string dataTypeToString(const LogicalType& dataType); + KUZU_API static std::string dataTypeToString(LogicalTypeID dataTypeID); + static std::string dataTypesToString(const std::vector& dataTypes); + static std::string dataTypesToString(const std::vector& dataTypeIDs); + KUZU_API static LogicalType dataTypeFromString(const std::string& dataTypeString); + static uint32_t getFixedTypeSize(kuzu::common::PhysicalTypeID physicalType); + static bool isNumerical(const LogicalType& dataType); private: - static DataTypeID dataTypeIDFromString(const std::string& dataTypeIDString); + static LogicalTypeID dataTypeIDFromString(const std::string& dataTypeIDString); }; // RelDataDirection diff --git a/src/include/common/types/value.h b/src/include/common/types/value.h index 2ded05a759..0eb9a4bcfb 100644 --- a/src/include/common/types/value.h +++ b/src/include/common/types/value.h @@ -21,12 +21,12 @@ class Value { * @param dataType the type of the NULL value. * @return a NULL value of the given type. */ - KUZU_API static Value createNullValue(DataType dataType); + KUZU_API static Value createNullValue(LogicalType dataType); /** * @param dataType the type of the non-NULL value. * @return a default non-NULL value of the given type. */ - KUZU_API static Value createDefaultValue(const DataType& dataType); + KUZU_API static Value createDefaultValue(const LogicalType& dataType); /** * @param val_ the boolean value to set. * @return a Value with BOOL type and val_ value. @@ -86,7 +86,7 @@ class Value { * @param vals the list value to set. * @return a Value with dataType type and vals value. */ - KUZU_API explicit Value(DataType dataType, std::vector> vals); + KUZU_API explicit Value(LogicalType dataType, std::vector> vals); /** * @param val_ the string value to set. * @return a Value with STRING type and val_ value. @@ -106,7 +106,7 @@ class Value { * @param val_ the value to set. * @return a Value with dataType type and val_ value. */ - KUZU_API explicit Value(DataType dataType, const uint8_t* val_); + KUZU_API explicit Value(LogicalType dataType, const uint8_t* val_); /** * @param other the value to copy from. * @return a Value with the same value as other. @@ -116,11 +116,11 @@ class Value { * @brief Sets the data type of the Value. * @param dataType_ the data type to set to. */ - KUZU_API void setDataType(const DataType& dataType_); + KUZU_API void setDataType(const LogicalType& dataType_); /** * @return the dataType of the value. */ - KUZU_API DataType getDataType() const; + KUZU_API LogicalType getDataType() const; /** * @brief Sets the null flag of the Value. * @param flag null value flag to set. @@ -182,7 +182,7 @@ class Value { private: Value(); - explicit Value(DataType dataType); + explicit Value(LogicalType dataType); template static inline void putValuesIntoVector(std::vector>& fixedListResultVal, @@ -198,7 +198,7 @@ class Value { std::vector> convertKUStructToVector(const uint8_t* kuStruct) const; public: - DataType dataType; + LogicalType dataType; bool isNull_; union Val { @@ -350,7 +350,7 @@ class RelVal { */ KUZU_API template<> inline bool Value::getValue() const { - assert(dataType.getTypeID() == BOOL); + assert(dataType.getLogicalTypeID() == LogicalTypeID::BOOL); return val.booleanVal; } @@ -359,7 +359,7 @@ inline bool Value::getValue() const { */ KUZU_API template<> inline int16_t Value::getValue() const { - assert(dataType.getTypeID() == INT16); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INT16); return val.int16Val; } @@ -368,7 +368,7 @@ inline int16_t Value::getValue() const { */ KUZU_API template<> inline int32_t Value::getValue() const { - assert(dataType.getTypeID() == INT32); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INT32); return val.int32Val; } @@ -377,7 +377,7 @@ inline int32_t Value::getValue() const { */ KUZU_API template<> inline int64_t Value::getValue() const { - assert(dataType.getTypeID() == INT64); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INT64); return val.int64Val; } @@ -386,7 +386,7 @@ inline int64_t Value::getValue() const { */ KUZU_API template<> inline float Value::getValue() const { - assert(dataType.getTypeID() == FLOAT); + assert(dataType.getLogicalTypeID() == LogicalTypeID::FLOAT); return val.floatVal; } @@ -395,7 +395,7 @@ inline float Value::getValue() const { */ KUZU_API template<> inline double Value::getValue() const { - assert(dataType.getTypeID() == DOUBLE); + assert(dataType.getLogicalTypeID() == LogicalTypeID::DOUBLE); return val.doubleVal; } @@ -404,7 +404,7 @@ inline double Value::getValue() const { */ KUZU_API template<> inline date_t Value::getValue() const { - assert(dataType.getTypeID() == DATE); + assert(dataType.getLogicalTypeID() == LogicalTypeID::DATE); return val.dateVal; } @@ -413,7 +413,7 @@ inline date_t Value::getValue() const { */ KUZU_API template<> inline timestamp_t Value::getValue() const { - assert(dataType.getTypeID() == TIMESTAMP); + assert(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP); return val.timestampVal; } @@ -422,7 +422,7 @@ inline timestamp_t Value::getValue() const { */ KUZU_API template<> inline interval_t Value::getValue() const { - assert(dataType.getTypeID() == INTERVAL); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INTERVAL); return val.intervalVal; } @@ -431,7 +431,7 @@ inline interval_t Value::getValue() const { */ KUZU_API template<> inline internalID_t Value::getValue() const { - assert(dataType.getTypeID() == INTERNAL_ID); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); return val.internalIDVal; } @@ -440,7 +440,7 @@ inline internalID_t Value::getValue() const { */ KUZU_API template<> inline std::string Value::getValue() const { - assert(dataType.getTypeID() == STRING); + assert(dataType.getLogicalTypeID() == LogicalTypeID::STRING); return strVal; } @@ -449,7 +449,7 @@ inline std::string Value::getValue() const { */ KUZU_API template<> inline NodeVal Value::getValue() const { - assert(dataType.getTypeID() == NODE); + assert(dataType.getLogicalTypeID() == LogicalTypeID::NODE); return *nodeVal; } @@ -458,7 +458,7 @@ inline NodeVal Value::getValue() const { */ KUZU_API template<> inline RelVal Value::getValue() const { - assert(dataType.getTypeID() == REL); + assert(dataType.getLogicalTypeID() == LogicalTypeID::REL); return *relVal; } @@ -467,7 +467,7 @@ inline RelVal Value::getValue() const { */ KUZU_API template<> inline bool& Value::getValueReference() { - assert(dataType.getTypeID() == BOOL); + assert(dataType.getLogicalTypeID() == LogicalTypeID::BOOL); return val.booleanVal; } @@ -476,7 +476,7 @@ inline bool& Value::getValueReference() { */ KUZU_API template<> inline int16_t& Value::getValueReference() { - assert(dataType.getTypeID() == INT16); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INT16); return val.int16Val; } @@ -485,7 +485,7 @@ inline int16_t& Value::getValueReference() { */ KUZU_API template<> inline int32_t& Value::getValueReference() { - assert(dataType.getTypeID() == INT32); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INT32); return val.int32Val; } @@ -494,7 +494,7 @@ inline int32_t& Value::getValueReference() { */ KUZU_API template<> inline int64_t& Value::getValueReference() { - assert(dataType.getTypeID() == INT64); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INT64); return val.int64Val; } @@ -503,7 +503,7 @@ inline int64_t& Value::getValueReference() { */ KUZU_API template<> inline float_t& Value::getValueReference() { - assert(dataType.getTypeID() == FLOAT); + assert(dataType.getLogicalTypeID() == LogicalTypeID::FLOAT); return val.floatVal; } @@ -512,7 +512,7 @@ inline float_t& Value::getValueReference() { */ KUZU_API template<> inline double_t& Value::getValueReference() { - assert(dataType.getTypeID() == DOUBLE); + assert(dataType.getLogicalTypeID() == LogicalTypeID::DOUBLE); return val.doubleVal; } @@ -521,7 +521,7 @@ inline double_t& Value::getValueReference() { */ KUZU_API template<> inline date_t& Value::getValueReference() { - assert(dataType.getTypeID() == DATE); + assert(dataType.getLogicalTypeID() == LogicalTypeID::DATE); return val.dateVal; } @@ -530,7 +530,7 @@ inline date_t& Value::getValueReference() { */ KUZU_API template<> inline timestamp_t& Value::getValueReference() { - assert(dataType.getTypeID() == TIMESTAMP); + assert(dataType.getLogicalTypeID() == LogicalTypeID::TIMESTAMP); return val.timestampVal; } @@ -539,7 +539,7 @@ inline timestamp_t& Value::getValueReference() { */ KUZU_API template<> inline interval_t& Value::getValueReference() { - assert(dataType.getTypeID() == INTERVAL); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INTERVAL); return val.intervalVal; } @@ -548,7 +548,7 @@ inline interval_t& Value::getValueReference() { */ KUZU_API template<> inline nodeID_t& Value::getValueReference() { - assert(dataType.getTypeID() == INTERNAL_ID); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); return val.internalIDVal; } @@ -557,7 +557,7 @@ inline nodeID_t& Value::getValueReference() { */ KUZU_API template<> inline std::string& Value::getValueReference() { - assert(dataType.getTypeID() == STRING); + assert(dataType.getLogicalTypeID() == LogicalTypeID::STRING); return strVal; } @@ -566,7 +566,7 @@ inline std::string& Value::getValueReference() { */ KUZU_API template<> inline NodeVal& Value::getValueReference() { - assert(dataType.getTypeID() == NODE); + assert(dataType.getLogicalTypeID() == LogicalTypeID::NODE); return *nodeVal; } @@ -575,7 +575,7 @@ inline NodeVal& Value::getValueReference() { */ KUZU_API template<> inline RelVal& Value::getValueReference() { - assert(dataType.getTypeID() == REL); + assert(dataType.getLogicalTypeID() == LogicalTypeID::REL); return *relVal; } diff --git a/src/include/common/vector/auxiliary_buffer.h b/src/include/common/vector/auxiliary_buffer.h index c2b3d94270..0b89447321 100644 --- a/src/include/common/vector/auxiliary_buffer.h +++ b/src/include/common/vector/auxiliary_buffer.h @@ -29,7 +29,7 @@ class StringAuxiliaryBuffer : public AuxiliaryBuffer { class StructAuxiliaryBuffer : public AuxiliaryBuffer { public: - StructAuxiliaryBuffer(const DataType& type, storage::MemoryManager* memoryManager); + StructAuxiliaryBuffer(const LogicalType& type, storage::MemoryManager* memoryManager); inline void referenceChildVector( vector_idx_t idx, std::shared_ptr vectorToReference) { @@ -52,7 +52,7 @@ class StructAuxiliaryBuffer : public AuxiliaryBuffer { // contiguous subsequence of elements in this vector. class ListAuxiliaryBuffer : public AuxiliaryBuffer { public: - ListAuxiliaryBuffer(const DataType& dataVectorType, storage::MemoryManager* memoryManager); + ListAuxiliaryBuffer(const LogicalType& dataVectorType, storage::MemoryManager* memoryManager); inline ValueVector* getDataVector() const { return dataVector.get(); } @@ -69,7 +69,7 @@ class ListAuxiliaryBuffer : public AuxiliaryBuffer { class AuxiliaryBufferFactory { public: static std::unique_ptr getAuxiliaryBuffer( - DataType& type, storage::MemoryManager* memoryManager); + LogicalType& type, storage::MemoryManager* memoryManager); }; } // namespace common diff --git a/src/include/common/vector/value_vector.h b/src/include/common/vector/value_vector.h index 28ef3e0a00..682096e1fb 100644 --- a/src/include/common/vector/value_vector.h +++ b/src/include/common/vector/value_vector.h @@ -20,10 +20,10 @@ class ValueVector { friend class StringVector; public: - explicit ValueVector(DataType dataType, storage::MemoryManager* memoryManager = nullptr); - explicit ValueVector(DataTypeID dataTypeID, storage::MemoryManager* memoryManager = nullptr) - : ValueVector(DataType(dataTypeID), memoryManager) { - assert(dataTypeID != VAR_LIST); + explicit ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager = nullptr); + explicit ValueVector(LogicalTypeID dataTypeID, storage::MemoryManager* memoryManager = nullptr) + : ValueVector(LogicalType(dataTypeID), memoryManager) { + assert(dataTypeID != LogicalTypeID::VAR_LIST); } ~ValueVector() = default; @@ -57,7 +57,7 @@ class ValueVector { inline uint8_t* getData() const { return valueBuffer.get(); } inline offset_t readNodeOffset(uint32_t pos) const { - assert(dataType.typeID == INTERNAL_ID); + assert(dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); return getValue(pos).offset; } @@ -65,11 +65,11 @@ class ValueVector { inline bool isSequential() const { return _isSequential; } private: - void setNumBytesPerValue(); + uint32_t getDataTypeSize(const LogicalType& type); void initializeValueBuffer(); public: - DataType dataType; + LogicalType dataType; std::shared_ptr state; private: @@ -83,14 +83,14 @@ class ValueVector { class StringVector { public: static inline InMemOverflowBuffer* getInMemOverflowBuffer(ValueVector* vector) { - return vector->dataType.typeID == STRING ? + return vector->dataType.getLogicalTypeID() == LogicalTypeID::STRING ? reinterpret_cast(vector->auxiliaryBuffer.get()) ->getOverflowBuffer() : nullptr; } static inline void resetOverflowBuffer(ValueVector* vector) { - if (vector->dataType.typeID == STRING) { + if (vector->dataType.getLogicalTypeID() == LogicalTypeID::STRING) { reinterpret_cast(vector->auxiliaryBuffer.get()) ->resetOverflowBuffer(); } @@ -106,28 +106,28 @@ class StringVector { class ListVector { public: static inline ValueVector* getDataVector(const ValueVector* vector) { - assert(vector->dataType.typeID == VAR_LIST); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); return reinterpret_cast(vector->auxiliaryBuffer.get()) ->getDataVector(); } static inline uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) { - assert(vector->dataType.typeID == VAR_LIST); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); auto dataVector = getDataVector(vector); return dataVector->getData() + dataVector->getNumBytesPerValue() * listEntry.offset; } static inline uint8_t* getListValuesWithOffset(const ValueVector* vector, const list_entry_t& listEntry, common::offset_t elementOffsetInList) { - assert(vector->dataType.typeID == VAR_LIST); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); return getListValues(vector, listEntry) + elementOffsetInList * getDataVector(vector)->getNumBytesPerValue(); } static inline list_entry_t addList(ValueVector* vector, uint64_t listSize) { - assert(vector->dataType.typeID == VAR_LIST); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); return reinterpret_cast(vector->auxiliaryBuffer.get()) ->addList(listSize); } static inline void resetListAuxiliaryBuffer(ValueVector* vector) { - assert(vector->dataType.typeID == VAR_LIST); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); reinterpret_cast(vector->auxiliaryBuffer.get())->resetSize(); } }; @@ -153,9 +153,9 @@ class StructVector { } static inline void initializeEntries(ValueVector* vector) { - std::iota((struct_entry_t*)vector->getData(), - (struct_entry_t*)(vector->getData() + - vector->getNumBytesPerValue() * DEFAULT_VECTOR_CAPACITY), + std::iota(reinterpret_cast(vector->getData()), + reinterpret_cast( + vector->getData() + vector->getNumBytesPerValue() * DEFAULT_VECTOR_CAPACITY), 0); } }; diff --git a/src/include/common/vector/value_vector_utils.h b/src/include/common/vector/value_vector_utils.h index e5d986a3ed..c179097cab 100644 --- a/src/include/common/vector/value_vector_utils.h +++ b/src/include/common/vector/value_vector_utils.h @@ -17,7 +17,7 @@ class ValueVectorUtils { const uint8_t* srcValue, const common::ValueVector& srcVector); private: - static void copyNonNullDataWithSameType(const DataType& dataType, const uint8_t* srcData, + static void copyNonNullDataWithSameType(const LogicalType& dataType, const uint8_t* srcData, uint8_t* dstData, InMemOverflowBuffer& inMemOverflowBuffer); }; diff --git a/src/include/function/aggregate/aggregate_function.h b/src/include/function/aggregate/aggregate_function.h index 0b422cc9fd..009b754322 100644 --- a/src/include/function/aggregate/aggregate_function.h +++ b/src/include/function/aggregate/aggregate_function.h @@ -14,15 +14,16 @@ class AggregateFunction; struct AggregateFunctionDefinition : public FunctionDefinition { - AggregateFunctionDefinition(std::string name, std::vector parameterTypeIDs, - common::DataTypeID returnTypeID, std::unique_ptr aggregateFunction, - bool isDistinct) + AggregateFunctionDefinition(std::string name, + std::vector parameterTypeIDs, common::LogicalTypeID returnTypeID, + std::unique_ptr aggregateFunction, bool isDistinct) : FunctionDefinition{std::move(name), std::move(parameterTypeIDs), returnTypeID}, aggregateFunction{std::move(aggregateFunction)}, isDistinct{isDistinct} {} - AggregateFunctionDefinition(std::string name, std::vector parameterTypeIDs, - common::DataTypeID returnTypeID, std::unique_ptr aggregateFunction, - bool isDistinct, scalar_bind_func bindFunc) + AggregateFunctionDefinition(std::string name, + std::vector parameterTypeIDs, common::LogicalTypeID returnTypeID, + std::unique_ptr aggregateFunction, bool isDistinct, + scalar_bind_func bindFunc) : FunctionDefinition{std::move(name), std::move(parameterTypeIDs), returnTypeID, std::move(bindFunc)}, aggregateFunction{std::move(aggregateFunction)}, isDistinct{isDistinct} {} @@ -54,7 +55,7 @@ class AggregateFunction { AggregateFunction(aggr_initialize_function_t initializeFunc, aggr_update_all_function_t updateAllFunc, aggr_update_pos_function_t updatePosFunc, aggr_combine_function_t combineFunc, aggr_finalize_function_t finalizeFunc, - common::DataType inputDataType, bool isDistinct = false) + common::LogicalType inputDataType, bool isDistinct = false) : initializeFunc{std::move(initializeFunc)}, updateAllFunc{std::move(updateAllFunc)}, updatePosFunc{std::move(updatePosFunc)}, combineFunc{std::move(combineFunc)}, finalizeFunc{std::move(finalizeFunc)}, inputDataType{std::move(inputDataType)}, @@ -89,9 +90,11 @@ class AggregateFunction { inline void finalizeState(uint8_t* state) { return finalizeFunc(state); } - inline common::DataType getInputDataType() const { return inputDataType; } + inline common::LogicalType getInputDataType() const { return inputDataType; } - inline void setInputDataType(common::DataType dataType) { inputDataType = std::move(dataType); } + inline void setInputDataType(common::LogicalType dataType) { + inputDataType = std::move(dataType); + } inline bool isFunctionDistinct() const { return isDistinct; } @@ -107,7 +110,7 @@ class AggregateFunction { aggr_combine_function_t combineFunc; aggr_finalize_function_t finalizeFunc; - common::DataType inputDataType; + common::LogicalType inputDataType; bool isDistinct; std::unique_ptr initialNullAggregateState; @@ -118,22 +121,22 @@ class AggregateFunctionUtil { public: static std::unique_ptr getCountStarFunction(); static std::unique_ptr getCountFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); static std::unique_ptr getAvgFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); static std::unique_ptr getSumFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); static std::unique_ptr getMinFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); static std::unique_ptr getMaxFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); static std::unique_ptr getCollectFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); private: template static std::unique_ptr getMinMaxFunction( - const common::DataType& inputType, bool isDistinct); + const common::LogicalType& inputType, bool isDistinct); }; } // namespace function diff --git a/src/include/function/aggregate/built_in_aggregate_functions.h b/src/include/function/aggregate/built_in_aggregate_functions.h index 04dfe9f990..e3ae46cda4 100644 --- a/src/include/function/aggregate/built_in_aggregate_functions.h +++ b/src/include/function/aggregate/built_in_aggregate_functions.h @@ -14,18 +14,18 @@ class BuiltInAggregateFunctions { return aggregateFunctions.contains(name); } - AggregateFunctionDefinition* matchFunction( - const std::string& name, const std::vector& inputTypes, bool isDistinct); + AggregateFunctionDefinition* matchFunction(const std::string& name, + const std::vector& inputTypes, bool isDistinct); std::vector getFunctionNames(); private: - uint32_t getFunctionCost(const std::vector& inputTypes, bool isDistinct, + uint32_t getFunctionCost(const std::vector& inputTypes, bool isDistinct, AggregateFunctionDefinition* function); void validateNonEmptyCandidateFunctions( std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes, bool isDistinct); + const std::vector& inputTypes, bool isDistinct); void registerAggregateFunctions(); void registerCountStar(); diff --git a/src/include/function/aggregate/collect.h b/src/include/function/aggregate/collect.h index 1b6ee8bec0..4561346918 100644 --- a/src/include/function/aggregate/collect.h +++ b/src/include/function/aggregate/collect.h @@ -58,12 +58,12 @@ struct CollectFunction { } static void initCollectStateIfNecessary( - CollectState* state, storage::MemoryManager* memoryManager, common::DataType& dataType) { + CollectState* state, storage::MemoryManager* memoryManager, common::LogicalType& dataType) { if (state->factorizedTable == nullptr) { auto tableSchema = std::make_unique(); tableSchema->appendColumn( std::make_unique(false /* isUnflat */, - 0 /* dataChunkPos */, common::Types::getDataTypeSize(dataType))); + 0 /* dataChunkPos */, storage::StorageUtils::getDataTypeSize(dataType))); state->factorizedTable = std::make_unique(memoryManager, std::move(tableSchema)); } @@ -102,8 +102,10 @@ struct CollectFunction { assert(arguments.size() == 1); auto aggFuncDefinition = reinterpret_cast(definition); aggFuncDefinition->aggregateFunction->setInputDataType(arguments[0]->dataType); + auto varListTypeInfo = std::make_unique( + std::make_unique(arguments[0]->dataType)); auto returnType = - common::DataType(std::make_unique(arguments[0]->dataType)); + common::LogicalType(common::LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)); return std::make_unique(returnType); } }; diff --git a/src/include/function/arithmetic/vector_arithmetic_operations.h b/src/include/function/arithmetic/vector_arithmetic_operations.h index 408242d739..9e2d6c2c3a 100644 --- a/src/include/function/arithmetic/vector_arithmetic_operations.h +++ b/src/include/function/arithmetic/vector_arithmetic_operations.h @@ -11,87 +11,89 @@ class VectorArithmeticOperations : public VectorOperations { public: template static std::unique_ptr getUnaryDefinition( - std::string name, common::DataTypeID operandTypeID) { + std::string name, common::LogicalTypeID operandTypeID) { return std::make_unique(std::move(name), - std::vector{operandTypeID}, operandTypeID, + std::vector{operandTypeID}, operandTypeID, getUnaryExecFunc(operandTypeID)); } template static std::unique_ptr getUnaryDefinition( - std::string name, common::DataTypeID operandTypeID, common::DataTypeID resultTypeID) { + std::string name, common::LogicalTypeID operandTypeID, common::LogicalTypeID resultTypeID) { return std::make_unique(std::move(name), - std::vector{operandTypeID}, resultTypeID, + std::vector{operandTypeID}, resultTypeID, VectorArithmeticOperations::UnaryExecFunction); } template static inline std::unique_ptr getBinaryDefinition( - std::string name, common::DataTypeID operandTypeID) { + std::string name, common::LogicalTypeID operandTypeID) { return std::make_unique(std::move(name), - std::vector{operandTypeID, operandTypeID}, operandTypeID, + std::vector{operandTypeID, operandTypeID}, operandTypeID, getBinaryExecFunc(operandTypeID)); } template static inline std::unique_ptr getBinaryDefinition( - std::string name, common::DataTypeID operandTypeID, common::DataTypeID resultTypeID) { + std::string name, common::LogicalTypeID operandTypeID, common::LogicalTypeID resultTypeID) { return std::make_unique(std::move(name), - std::vector{operandTypeID, operandTypeID}, resultTypeID, + std::vector{operandTypeID, operandTypeID}, resultTypeID, VectorArithmeticOperations::BinaryExecFunction); } private: template - static scalar_exec_func getUnaryExecFunc(common::DataTypeID operandTypeID) { + static scalar_exec_func getUnaryExecFunc(common::LogicalTypeID operandTypeID) { switch (operandTypeID) { - case common::INT64: { + case common::LogicalTypeID::INT64: { return VectorArithmeticOperations::UnaryExecFunction; } - case common::INT32: { + case common::LogicalTypeID::INT32: { return VectorArithmeticOperations::UnaryExecFunction; } - case common::INT16: { + case common::LogicalTypeID::INT16: { return VectorArithmeticOperations::UnaryExecFunction; } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return VectorArithmeticOperations::UnaryExecFunction; } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return VectorArithmeticOperations::UnaryExecFunction; ; } default: - throw common::RuntimeException("Invalid input data types(" + - common::Types::dataTypeToString(operandTypeID) + - ") for getUnaryExecFunc."); + throw common::RuntimeException( + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(operandTypeID) + + ") for getUnaryExecFunc."); } } template - static scalar_exec_func getBinaryExecFunc(common::DataTypeID operandTypeID) { + static scalar_exec_func getBinaryExecFunc(common::LogicalTypeID operandTypeID) { switch (operandTypeID) { - case common::INT64: { + case common::LogicalTypeID::INT64: { return VectorArithmeticOperations::BinaryExecFunction; } - case common::INT32: { + case common::LogicalTypeID::INT32: { return VectorArithmeticOperations::BinaryExecFunction; } - case common::INT16: { + case common::LogicalTypeID::INT16: { return VectorArithmeticOperations::BinaryExecFunction; } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return VectorArithmeticOperations::BinaryExecFunction; } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return VectorArithmeticOperations::BinaryExecFunction; } default: - throw common::RuntimeException("Invalid input data types(" + - common::Types::dataTypeToString(operandTypeID) + - ") for getUnaryExecFunc."); + throw common::RuntimeException( + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(operandTypeID) + + ") for getUnaryExecFunc."); } } }; diff --git a/src/include/function/boolean/boolean_operation_executor.h b/src/include/function/boolean/boolean_operation_executor.h index 4f5b3d6b06..93af2eaaa0 100644 --- a/src/include/function/boolean/boolean_operation_executor.h +++ b/src/include/function/boolean/boolean_operation_executor.h @@ -130,8 +130,9 @@ struct BinaryBooleanOperationExecutor { template static void execute( common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { - assert(left.dataType.typeID == common::BOOL && right.dataType.typeID == common::BOOL && - result.dataType.typeID == common::BOOL); + assert(left.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL && + right.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL && + result.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL); if (left.state->isFlat() && right.state->isFlat()) { executeBothFlat(left, right, result); } else if (left.state->isFlat() && !right.state->isFlat()) { @@ -232,7 +233,8 @@ struct BinaryBooleanOperationExecutor { template static bool select( common::ValueVector& left, common::ValueVector& right, common::SelectionVector& selVector) { - assert(left.dataType.typeID == common::BOOL && right.dataType.typeID == common::BOOL); + assert(left.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL && + right.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL); if (left.state->isFlat() && right.state->isFlat()) { return selectBothFlat(left, right); } else if (left.state->isFlat() && !right.state->isFlat()) { diff --git a/src/include/function/built_in_vector_operations.h b/src/include/function/built_in_vector_operations.h index 3a31f42dd2..951e741a36 100644 --- a/src/include/function/built_in_vector_operations.h +++ b/src/include/function/built_in_vector_operations.h @@ -22,40 +22,41 @@ class BuiltInVectorOperations { const std::string& functionName, const binder::expression_vector& children); VectorOperationDefinition* matchFunction( - const std::string& name, const std::vector& inputTypes); + const std::string& name, const std::vector& inputTypes); std::vector getFunctionNames(); - static uint32_t getCastCost(common::DataTypeID inputTypeID, common::DataTypeID targetTypeID); + static uint32_t getCastCost( + common::LogicalTypeID inputTypeID, common::LogicalTypeID targetTypeID); static uint32_t getCastCost( - const common::DataType& inputType, const common::DataType& targetType); + const common::LogicalType& inputType, const common::LogicalType& targetType); private: - static uint32_t getTargetTypeCost(common::DataTypeID typeID); + static uint32_t getTargetTypeCost(common::LogicalTypeID typeID); - static uint32_t castInt64(common::DataTypeID targetTypeID); + static uint32_t castInt64(common::LogicalTypeID targetTypeID); - static uint32_t castInt32(common::DataTypeID targetTypeID); + static uint32_t castInt32(common::LogicalTypeID targetTypeID); - static uint32_t castInt16(common::DataTypeID targetTypeID); + static uint32_t castInt16(common::LogicalTypeID targetTypeID); - static uint32_t castDouble(common::DataTypeID targetTypeID); + static uint32_t castDouble(common::LogicalTypeID targetTypeID); - static uint32_t castFloat(common::DataTypeID targetTypeID); + static uint32_t castFloat(common::LogicalTypeID targetTypeID); VectorOperationDefinition* getBestMatch(std::vector& functions); - uint32_t getFunctionCost(const std::vector& inputTypes, + uint32_t getFunctionCost(const std::vector& inputTypes, VectorOperationDefinition* function, bool isOverload); - uint32_t matchParameters(const std::vector& inputTypes, - const std::vector& targetTypeIDs, bool isOverload); - uint32_t matchVarLengthParameters(const std::vector& inputTypes, - common::DataTypeID targetTypeID, bool isOverload); + uint32_t matchParameters(const std::vector& inputTypes, + const std::vector& targetTypeIDs, bool isOverload); + uint32_t matchVarLengthParameters(const std::vector& inputTypes, + common::LogicalTypeID targetTypeID, bool isOverload); void validateNonEmptyCandidateFunctions( std::vector& candidateFunctions, const std::string& name, - const std::vector& inputTypes); + const std::vector& inputTypes); void registerVectorOperations(); void registerComparisonOperations(); diff --git a/src/include/function/cast/vector_cast_operations.h b/src/include/function/cast/vector_cast_operations.h index 6409614d7c..6f4495cc32 100644 --- a/src/include/function/cast/vector_cast_operations.h +++ b/src/include/function/cast/vector_cast_operations.h @@ -15,17 +15,18 @@ class VectorCastOperations : public VectorOperations { public: // This function is only used by expression binder when implicit cast is needed. // The expression binder should consider reusing the existing matchFunction() API. - static bool hasImplicitCast(const common::DataType& srcType, const common::DataType& dstType); - static std::string bindImplicitCastFuncName(const common::DataType& dstType); + static bool hasImplicitCast( + const common::LogicalType& srcType, const common::LogicalType& dstType); + static std::string bindImplicitCastFuncName(const common::LogicalType& dstType); static scalar_exec_func bindImplicitCastFunc( - common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID); + common::LogicalTypeID sourceTypeID, common::LogicalTypeID targetTypeID); template inline static std::unique_ptr bindVectorOperation( - const std::string& funcName, common::DataTypeID sourceTypeID, - common::DataTypeID targetTypeID) { + const std::string& funcName, common::LogicalTypeID sourceTypeID, + common::LogicalTypeID targetTypeID) { return std::make_unique(funcName, - std::vector{sourceTypeID}, targetTypeID, + std::vector{sourceTypeID}, targetTypeID, VectorOperations::UnaryExecFunction); } @@ -39,44 +40,44 @@ class VectorCastOperations : public VectorOperations { private: template - static scalar_exec_func bindImplicitNumericalCastFunc(common::DataTypeID srcTypeID) { + static scalar_exec_func bindImplicitNumericalCastFunc(common::LogicalTypeID srcTypeID) { switch (srcTypeID) { - case common::INT16: + case common::LogicalTypeID::INT16: return VectorOperations::UnaryExecFunction; - case common::INT32: + case common::LogicalTypeID::INT32: return VectorOperations::UnaryExecFunction; - case common::INT64: + case common::LogicalTypeID::INT64: return VectorOperations::UnaryExecFunction; - case common::FLOAT: + case common::LogicalTypeID::FLOAT: return VectorOperations::UnaryExecFunction; - case common::DOUBLE: + case common::LogicalTypeID::DOUBLE: return VectorOperations::UnaryExecFunction; default: - throw common::NotImplementedException("Unimplemented casting operation from " + - common::Types::dataTypeToString(srcTypeID) + - " to numeric."); + throw common::NotImplementedException( + "Unimplemented casting operation from " + + common::LogicalTypeUtils::dataTypeToString(srcTypeID) + " to numeric."); } } template - static scalar_exec_func bindImplicitStringCastFunc(common::DataTypeID srcTypeID) { + static scalar_exec_func bindImplicitStringCastFunc(common::LogicalTypeID srcTypeID) { switch (srcTypeID) { - case common::INT64: + case common::LogicalTypeID::INT64: return UnaryCastExecFunction; - case common::DOUBLE: + case common::LogicalTypeID::DOUBLE: return UnaryCastExecFunction; - case common::DATE: + case common::LogicalTypeID::DATE: return UnaryCastExecFunction; - case common::TIMESTAMP: + case common::LogicalTypeID::TIMESTAMP: return UnaryCastExecFunction; - case common::INTERVAL: + case common::LogicalTypeID::INTERVAL: return UnaryCastExecFunction; - case common::VAR_LIST: + case common::LogicalTypeID::VAR_LIST: return UnaryCastExecFunction; default: - throw common::NotImplementedException("Unimplemented casting operation from " + - common::Types::dataTypeToString(srcTypeID) + - " to string."); + throw common::NotImplementedException( + "Unimplemented casting operation from " + + common::LogicalTypeUtils::dataTypeToString(srcTypeID) + " to string."); } } }; diff --git a/src/include/function/comparison/vector_comparison_operations.h b/src/include/function/comparison/vector_comparison_operations.h index adf270603e..cb1b12b959 100644 --- a/src/include/function/comparison/vector_comparison_operations.h +++ b/src/include/function/comparison/vector_comparison_operations.h @@ -14,166 +14,178 @@ class VectorComparisonOperations : public VectorOperations { static std::vector> getDefinitions( const std::string& name) { std::vector> definitions; - for (auto& numericTypeID : common::DataType::getNumericalTypeIDs()) { + for (auto& numericTypeID : common::LogicalType::getNumericalLogicalTypeIDs()) { definitions.push_back(getDefinition(name, numericTypeID, numericTypeID)); } - for (auto& typeID : std::vector{common::BOOL, common::STRING, - common::INTERNAL_ID, common::DATE, common::TIMESTAMP, common::INTERVAL}) { + for (auto& typeID : std::vector{common::LogicalTypeID::BOOL, + common::LogicalTypeID::STRING, common::LogicalTypeID::INTERNAL_ID, + common::LogicalTypeID::DATE, common::LogicalTypeID::TIMESTAMP, + common::LogicalTypeID::INTERVAL}) { definitions.push_back(getDefinition(name, typeID, typeID)); } - definitions.push_back(getDefinition(name, common::DATE, common::TIMESTAMP)); - definitions.push_back(getDefinition(name, common::TIMESTAMP, common::DATE)); + definitions.push_back(getDefinition( + name, common::LogicalTypeID::DATE, common::LogicalTypeID::TIMESTAMP)); + definitions.push_back(getDefinition( + name, common::LogicalTypeID::TIMESTAMP, common::LogicalTypeID::DATE)); return definitions; } private: template - static inline std::unique_ptr getDefinition( - const std::string& name, common::DataTypeID leftTypeID, common::DataTypeID rightTypeID) { + static inline std::unique_ptr getDefinition(const std::string& name, + common::LogicalTypeID leftTypeID, common::LogicalTypeID rightTypeID) { auto execFunc = getExecFunc(leftTypeID, rightTypeID); auto selectFunc = getSelectFunc(leftTypeID, rightTypeID); return std::make_unique(name, - std::vector{leftTypeID, rightTypeID}, common::BOOL, execFunc, - selectFunc); + std::vector{leftTypeID, rightTypeID}, + common::LogicalTypeID::BOOL, execFunc, selectFunc); } template static scalar_exec_func getExecFunc( - common::DataTypeID leftTypeID, common::DataTypeID rightTypeID) { + common::LogicalTypeID leftTypeID, common::LogicalTypeID rightTypeID) { switch (leftTypeID) { - case common::INT64: { + case common::LogicalTypeID::INT64: { return BinaryExecFunction; } - case common::INT32: { + case common::LogicalTypeID::INT32: { return BinaryExecFunction; } - case common::INT16: { + case common::LogicalTypeID::INT16: { return BinaryExecFunction; } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return BinaryExecFunction; } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return BinaryExecFunction; } - case common::BOOL: { - assert(rightTypeID == common::BOOL); + case common::LogicalTypeID::BOOL: { + assert(rightTypeID == common::LogicalTypeID::BOOL); return BinaryExecFunction; } - case common::STRING: { - assert(rightTypeID == common::STRING); + case common::LogicalTypeID::STRING: { + assert(rightTypeID == common::LogicalTypeID::STRING); return BinaryExecFunction; } - case common::INTERNAL_ID: { - assert(rightTypeID == common::INTERNAL_ID); + case common::LogicalTypeID::INTERNAL_ID: { + assert(rightTypeID == common::LogicalTypeID::INTERNAL_ID); return BinaryExecFunction; } - case common::DATE: { + case common::LogicalTypeID::DATE: { switch (rightTypeID) { - case common::DATE: { + case common::LogicalTypeID::DATE: { return BinaryExecFunction; } - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { return BinaryExecFunction; } default: throw common::RuntimeException( - "Invalid input data types(" + common::Types::dataTypeToString(leftTypeID) + - "," + common::Types::dataTypeToString(rightTypeID) + ") for getExecFunc."); + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(leftTypeID) + "," + + common::LogicalTypeUtils::dataTypeToString(rightTypeID) + ") for getExecFunc."); } } - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { switch (rightTypeID) { - case common::DATE: { + case common::LogicalTypeID::DATE: { return BinaryExecFunction; } - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { return BinaryExecFunction; } default: throw common::RuntimeException( - "Invalid input data types(" + common::Types::dataTypeToString(leftTypeID) + - "," + common::Types::dataTypeToString(rightTypeID) + ") for getExecFunc."); + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(leftTypeID) + "," + + common::LogicalTypeUtils::dataTypeToString(rightTypeID) + ") for getExecFunc."); } } - case common::INTERVAL: { - assert(rightTypeID == common::INTERVAL); + case common::LogicalTypeID::INTERVAL: { + assert(rightTypeID == common::LogicalTypeID::INTERVAL); return BinaryExecFunction; } default: throw common::RuntimeException( - "Invalid input data types(" + common::Types::dataTypeToString(leftTypeID) + "," + - common::Types::dataTypeToString(rightTypeID) + ") for getExecFunc."); + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(leftTypeID) + "," + + common::LogicalTypeUtils::dataTypeToString(rightTypeID) + ") for getExecFunc."); } } template static scalar_select_func getSelectFunc( - common::DataTypeID leftTypeID, common::DataTypeID rightTypeID) { + common::LogicalTypeID leftTypeID, common::LogicalTypeID rightTypeID) { switch (leftTypeID) { - case common::INT64: { + case common::LogicalTypeID::INT64: { return BinarySelectFunction; } - case common::INT32: { + case common::LogicalTypeID::INT32: { return BinarySelectFunction; } - case common::INT16: { + case common::LogicalTypeID::INT16: { return BinarySelectFunction; } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return BinarySelectFunction; } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return BinarySelectFunction; } - case common::BOOL: { - assert(rightTypeID == common::BOOL); + case common::LogicalTypeID::BOOL: { + assert(rightTypeID == common::LogicalTypeID::BOOL); return BinarySelectFunction; } - case common::STRING: { - assert(rightTypeID == common::STRING); + case common::LogicalTypeID::STRING: { + assert(rightTypeID == common::LogicalTypeID::STRING); return BinarySelectFunction; } - case common::INTERNAL_ID: { - assert(rightTypeID == common::INTERNAL_ID); + case common::LogicalTypeID::INTERNAL_ID: { + assert(rightTypeID == common::LogicalTypeID::INTERNAL_ID); return BinarySelectFunction; } - case common::DATE: { + case common::LogicalTypeID::DATE: { switch (rightTypeID) { - case common::DATE: { + case common::LogicalTypeID::DATE: { return BinarySelectFunction; } - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { return BinarySelectFunction; } default: throw common::RuntimeException( - "Invalid input data types(" + common::Types::dataTypeToString(leftTypeID) + - "," + common::Types::dataTypeToString(rightTypeID) + ") for getSelectFunc."); + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(leftTypeID) + "," + + common::LogicalTypeUtils::dataTypeToString(rightTypeID) + + ") for getSelectFunc."); } } - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { switch (rightTypeID) { - case common::DATE: { + case common::LogicalTypeID::DATE: { return BinarySelectFunction; } - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { return BinarySelectFunction; } default: throw common::RuntimeException( - "Invalid input data types(" + common::Types::dataTypeToString(leftTypeID) + - "," + common::Types::dataTypeToString(rightTypeID) + ") for getSelectFunc."); + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(leftTypeID) + "," + + common::LogicalTypeUtils::dataTypeToString(rightTypeID) + + ") for getSelectFunc."); } } - case common::INTERVAL: { - assert(rightTypeID == common::INTERVAL); + case common::LogicalTypeID::INTERVAL: { + assert(rightTypeID == common::LogicalTypeID::INTERVAL); return BinarySelectFunction; } default: throw common::RuntimeException( - "Invalid input data types(" + common::Types::dataTypeToString(leftTypeID) + "," + - common::Types::dataTypeToString(rightTypeID) + ") for getSelectFunc."); + "Invalid input data types(" + + common::LogicalTypeUtils::dataTypeToString(leftTypeID) + "," + + common::LogicalTypeUtils::dataTypeToString(rightTypeID) + ") for getSelectFunc."); } } }; diff --git a/src/include/function/function_definition.h b/src/include/function/function_definition.h index 95aeb6217c..0f0491dbd3 100644 --- a/src/include/function/function_definition.h +++ b/src/include/function/function_definition.h @@ -7,9 +7,9 @@ namespace kuzu { namespace function { struct FunctionBindData { - common::DataType resultType; + common::LogicalType resultType; - explicit FunctionBindData(common::DataType dataType) : resultType{std::move(dataType)} {} + explicit FunctionBindData(common::LogicalType dataType) : resultType{std::move(dataType)} {} virtual ~FunctionBindData() = default; }; @@ -20,24 +20,24 @@ using scalar_bind_func = std::function( const binder::expression_vector&, FunctionDefinition* definition)>; struct FunctionDefinition { - FunctionDefinition(std::string name, std::vector parameterTypeIDs, - common::DataTypeID returnTypeID) + FunctionDefinition(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID) : name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)}, returnTypeID{ returnTypeID} {} - FunctionDefinition(std::string name, std::vector parameterTypeIDs, - common::DataTypeID returnTypeID, scalar_bind_func bindFunc) + FunctionDefinition(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_bind_func bindFunc) : name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)}, returnTypeID{returnTypeID}, bindFunc{std::move(bindFunc)} {} inline std::string signatureToString() const { - std::string result = common::Types::dataTypesToString(parameterTypeIDs); - result += " -> " + common::Types::dataTypeToString(returnTypeID); + std::string result = common::LogicalTypeUtils::dataTypesToString(parameterTypeIDs); + result += " -> " + common::LogicalTypeUtils::dataTypeToString(returnTypeID); return result; } std::string name; - std::vector parameterTypeIDs; - common::DataTypeID returnTypeID; + std::vector parameterTypeIDs; + common::LogicalTypeID returnTypeID; // This function is used to bind parameter/return types for functions with nested dataType. scalar_bind_func bindFunc; }; diff --git a/src/include/function/hash/hash_operations.h b/src/include/function/hash/hash_operations.h index d7d8104f9b..6c24135d2e 100644 --- a/src/include/function/hash/hash_operations.h +++ b/src/include/function/hash/hash_operations.h @@ -99,16 +99,6 @@ inline void Hash::operation(const common::ku_string_t& key, common::hash_t& resu result = std::hash()(key.getAsString()); } -template<> -inline void Hash::operation(const common::date_t& key, common::hash_t& result) { - result = murmurhash64(key.days); -} - -template<> -inline void Hash::operation(const common::timestamp_t& key, common::hash_t& result) { - result = murmurhash64(key.value); -} - template<> inline void Hash::operation(const common::interval_t& key, common::hash_t& result) { result = combineHashScalar(murmurhash64(key.months), diff --git a/src/include/function/interval/vector_interval_operations.h b/src/include/function/interval/vector_interval_operations.h index d2399cbcbe..524c0561e1 100644 --- a/src/include/function/interval/vector_interval_operations.h +++ b/src/include/function/interval/vector_interval_operations.h @@ -13,7 +13,8 @@ class VectorIntervalOperations : public VectorOperations { getUnaryIntervalFunctionDefintion(std::string funcName) { std::vector> result; result.push_back(std::make_unique(funcName, - std::vector{common::INT64}, common::INTERVAL, + std::vector{common::LogicalTypeID::INT64}, + common::LogicalTypeID::INTERVAL, UnaryExecFunction)); return result; } diff --git a/src/include/function/list/operations/list_position_operation.h b/src/include/function/list/operations/list_position_operation.h index 85ae0eb594..055f483579 100644 --- a/src/include/function/list/operations/list_position_operation.h +++ b/src/include/function/list/operations/list_position_operation.h @@ -16,7 +16,7 @@ struct ListPosition { static inline void operation(common::list_entry_t& list, T& element, int64_t& result, common::ValueVector& listVector, common::ValueVector& elementVector, common::ValueVector& resultVector) { - if (*listVector.dataType.getChildType() != elementVector.dataType) { + if (*common::VarListType::getChildType(&listVector.dataType) != elementVector.dataType) { result = 0; return; } @@ -37,7 +37,7 @@ template<> void ListPosition::operation(common::list_entry_t& list, common::list_entry_t& element, int64_t& result, common::ValueVector& listVector, common::ValueVector& elementVector, common::ValueVector& resultVector) { - if (*listVector.dataType.getChildType() != elementVector.dataType) { + if (*common::VarListType::getChildType(&listVector.dataType) != elementVector.dataType) { result = 0; return; } diff --git a/src/include/function/list/vector_list_operations.h b/src/include/function/list/vector_list_operations.h index f6fddca479..f4f2fd1ee4 100644 --- a/src/include/function/list/vector_list_operations.h +++ b/src/include/function/list/vector_list_operations.h @@ -27,42 +27,44 @@ struct VectorListOperations : public VectorOperations { template static std::vector> - getBinaryListOperationDefinitions(std::string funcName, common::DataTypeID resultTypeID) { + getBinaryListOperationDefinitions(std::string funcName, common::LogicalTypeID resultTypeID) { std::vector> result; scalar_exec_func execFunc; - for (auto& rightTypeID : std::vector{common::BOOL, common::INT64, - common::DOUBLE, common::STRING, common::DATE, common::TIMESTAMP, common::INTERVAL, - common::VAR_LIST}) { + for (auto& rightTypeID : std::vector{common::LogicalTypeID::BOOL, + common::LogicalTypeID::INT64, common::LogicalTypeID::DOUBLE, + common::LogicalTypeID::STRING, common::LogicalTypeID::DATE, + common::LogicalTypeID::TIMESTAMP, common::LogicalTypeID::INTERVAL, + common::LogicalTypeID::VAR_LIST}) { switch (rightTypeID) { - case common::BOOL: { + case common::LogicalTypeID::BOOL: { execFunc = BinaryListExecFunction; } break; - case common::INT64: { + case common::LogicalTypeID::INT64: { execFunc = BinaryListExecFunction; } break; - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { execFunc = BinaryListExecFunction; } break; - case common::STRING: { + case common::LogicalTypeID::STRING: { execFunc = BinaryListExecFunction; } break; - case common::DATE: { + case common::LogicalTypeID::DATE: { execFunc = BinaryListExecFunction; } break; - case common::TIMESTAMP: { + case common::LogicalTypeID::TIMESTAMP: { execFunc = BinaryListExecFunction; } break; - case common::INTERVAL: { + case common::LogicalTypeID::INTERVAL: { execFunc = BinaryListExecFunction; } break; - case common::VAR_LIST: { + case common::LogicalTypeID::VAR_LIST: { execFunc = BinaryListExecFunction; } break; @@ -71,8 +73,8 @@ struct VectorListOperations : public VectorOperations { } } result.push_back(make_unique(funcName, - std::vector{common::VAR_LIST, rightTypeID}, resultTypeID, - execFunc, nullptr, false /* isVarlength*/)); + std::vector{common::LogicalTypeID::VAR_LIST, rightTypeID}, + resultTypeID, execFunc, nullptr, false /* isVarlength*/)); } return result; } diff --git a/src/include/function/null/null_operation_executor.h b/src/include/function/null/null_operation_executor.h index 7c73905026..06e5584853 100644 --- a/src/include/function/null/null_operation_executor.h +++ b/src/include/function/null/null_operation_executor.h @@ -9,7 +9,7 @@ struct NullOperationExecutor { template static void execute(common::ValueVector& operand, common::ValueVector& result) { - assert(result.dataType.typeID == common::BOOL); + assert(result.dataType.getLogicalTypeID() == common::LogicalTypeID::BOOL); auto resultValues = (uint8_t*)result.getData(); if (operand.state->isFlat()) { auto pos = operand.state->selVector->selectedPositions[0]; diff --git a/src/include/function/schema/vector_offset_operations.h b/src/include/function/schema/vector_offset_operations.h index 9f3ec423da..fff01d7e2b 100644 --- a/src/include/function/schema/vector_offset_operations.h +++ b/src/include/function/schema/vector_offset_operations.h @@ -17,8 +17,8 @@ struct OffsetVectorOperation : public VectorOperations { static std::vector> getDefinitions() { std::vector> definitions; definitions.push_back(make_unique(common::OFFSET_FUNC_NAME, - std::vector{common::INTERNAL_ID}, common::INT64, - OffsetVectorOperation::execFunction)); + std::vector{common::LogicalTypeID::INTERNAL_ID}, + common::LogicalTypeID::INT64, OffsetVectorOperation::execFunction)); return definitions; } }; diff --git a/src/include/function/string/vector_string_operations.h b/src/include/function/string/vector_string_operations.h index 4e28a860c4..a31a6f1555 100644 --- a/src/include/function/string/vector_string_operations.h +++ b/src/include/function/string/vector_string_operations.h @@ -44,7 +44,8 @@ struct VectorStringOperations : public VectorOperations { getUnaryStrFunctionDefintion(std::string funcName) { std::vector> definitions; definitions.emplace_back(std::make_unique(funcName, - std::vector{common::STRING}, common::STRING, + std::vector{common::LogicalTypeID::STRING}, + common::LogicalTypeID::STRING, UnaryStringExecFunction, false /* isVarLength */)); return definitions; diff --git a/src/include/function/struct/vector_struct_operations.h b/src/include/function/struct/vector_struct_operations.h index 53a1d3aeb3..97cef9bc84 100644 --- a/src/include/function/struct/vector_struct_operations.h +++ b/src/include/function/struct/vector_struct_operations.h @@ -19,7 +19,7 @@ struct StructPackVectorOperations : public VectorOperations { struct StructExtractBindData : public FunctionBindData { common::vector_idx_t childIdx; - StructExtractBindData(common::DataType dataType, common::vector_idx_t childIdx) + StructExtractBindData(common::LogicalType dataType, common::vector_idx_t childIdx) : FunctionBindData{std::move(dataType)}, childIdx{childIdx} {} }; diff --git a/src/include/function/vector_operations.h b/src/include/function/vector_operations.h index 563e38785e..6a2782375d 100644 --- a/src/include/function/vector_operations.h +++ b/src/include/function/vector_operations.h @@ -20,21 +20,21 @@ using scalar_select_func = std::function parameterTypeIDs, - common::DataTypeID returnTypeID, scalar_exec_func execFunc, bool isVarLength = false) + VectorOperationDefinition(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_exec_func execFunc, bool isVarLength = false) : VectorOperationDefinition{std::move(name), std::move(parameterTypeIDs), returnTypeID, std::move(execFunc), nullptr, isVarLength} {} - VectorOperationDefinition(std::string name, std::vector parameterTypeIDs, - common::DataTypeID returnTypeID, scalar_exec_func execFunc, scalar_select_func selectFunc, - bool isVarLength = false) + VectorOperationDefinition(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_exec_func execFunc, + scalar_select_func selectFunc, bool isVarLength = false) : FunctionDefinition{std::move(name), std::move(parameterTypeIDs), returnTypeID}, execFunc{std::move(execFunc)}, selectFunc(std::move(selectFunc)), isVarLength{isVarLength} {} - VectorOperationDefinition(std::string name, std::vector parameterTypeIDs, - common::DataTypeID returnTypeID, scalar_exec_func execFunc, scalar_select_func selectFunc, - scalar_bind_func bindFunc, bool isVarLength = false) + VectorOperationDefinition(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_exec_func execFunc, + scalar_select_func selectFunc, scalar_bind_func bindFunc, bool isVarLength = false) : FunctionDefinition{std::move(name), std::move(parameterTypeIDs), returnTypeID, std::move(bindFunc)}, execFunc{std::move(execFunc)}, diff --git a/src/include/main/query_result.h b/src/include/main/query_result.h index 668a8e0104..8164cd76a1 100644 --- a/src/include/main/query_result.h +++ b/src/include/main/query_result.h @@ -11,15 +11,15 @@ namespace main { struct DataTypeInfo { public: - DataTypeInfo(common::DataTypeID typeID, std::string name) + DataTypeInfo(common::LogicalTypeID typeID, std::string name) : typeID{typeID}, name{std::move(name)} {} - common::DataTypeID typeID; + common::LogicalTypeID typeID; std::string name; std::vector> childrenTypesInfo; static std::unique_ptr getInfoForDataType( - const common::DataType& type, const std::string& name); + const common::LogicalType& type, const std::string& name); }; /** @@ -58,7 +58,7 @@ class QueryResult { /** * @return dataType of each column in query result. */ - KUZU_API std::vector getColumnDataTypes(); + KUZU_API std::vector getColumnDataTypes(); /** * @return num of tuples in query result. */ @@ -110,7 +110,7 @@ class QueryResult { // header information std::vector columnNames; - std::vector columnDataTypes; + std::vector columnDataTypes; // data std::shared_ptr factorizedTable; std::unique_ptr iterator; diff --git a/src/include/planner/logical_plan/logical_operator/logical_add_property.h b/src/include/planner/logical_plan/logical_operator/logical_add_property.h index 1c98e744ec..2aae996602 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_add_property.h +++ b/src/include/planner/logical_plan/logical_operator/logical_add_property.h @@ -8,7 +8,7 @@ namespace planner { class LogicalAddProperty : public LogicalDDL { public: explicit LogicalAddProperty(common::table_id_t tableID, std::string propertyName, - common::DataType dataType, std::shared_ptr defaultValue, + common::LogicalType dataType, std::shared_ptr defaultValue, std::string tableName, std::shared_ptr outputExpression) : LogicalDDL{LogicalOperatorType::ADD_PROPERTY, std::move(tableName), std::move(outputExpression)}, @@ -19,7 +19,7 @@ class LogicalAddProperty : public LogicalDDL { inline std::string getPropertyName() const { return propertyName; } - inline common::DataType getDataType() const { return dataType; } + inline common::LogicalType getDataType() const { return dataType; } inline std::shared_ptr getDefaultValue() const { return defaultValue; } @@ -31,7 +31,7 @@ class LogicalAddProperty : public LogicalDDL { private: common::table_id_t tableID; std::string propertyName; - common::DataType dataType; + common::LogicalType dataType; std::shared_ptr defaultValue; }; diff --git a/src/include/processor/operator/aggregate/aggregate_hash_table.h b/src/include/processor/operator/aggregate/aggregate_hash_table.h index f4cf5b8482..af53394533 100644 --- a/src/include/processor/operator/aggregate/aggregate_hash_table.h +++ b/src/include/processor/operator/aggregate/aggregate_hash_table.h @@ -42,15 +42,15 @@ class AggregateHashTable : public BaseHashTable { public: // Used by distinct aggregate hash table only. AggregateHashTable(storage::MemoryManager& memoryManager, - const std::vector& keysDataTypes, + const std::vector& keysDataTypes, const std::vector>& aggregateFunctions, uint64_t numEntriesToAllocate) - : AggregateHashTable(memoryManager, keysDataTypes, std::vector(), + : AggregateHashTable(memoryManager, keysDataTypes, std::vector(), aggregateFunctions, numEntriesToAllocate) {} AggregateHashTable(storage::MemoryManager& memoryManager, - std::vector keysDataTypes, - std::vector payloadsDataTypes, + std::vector keysDataTypes, + std::vector payloadsDataTypes, const std::vector>& aggregateFunctions, uint64_t numEntriesToAllocate); @@ -181,7 +181,7 @@ class AggregateHashTable : public BaseHashTable { template static bool compareEntryWithKeys(const uint8_t* keyValue, const uint8_t* entry); - static compare_function_t getCompareEntryWithKeysFunc(common::DataTypeID typeId); + static compare_function_t getCompareEntryWithKeysFunc(common::LogicalTypeID typeId); void updateNullAggVectorState( const std::vector& groupByFlatHashKeyVectors, @@ -218,8 +218,8 @@ class AggregateHashTable : public BaseHashTable { common::ValueVector* aggVector, uint64_t multiplicity, uint32_t aggStateOffset); private: - std::vector keyDataTypes; - std::vector dependentKeyDataTypes; + std::vector keyDataTypes; + std::vector dependentKeyDataTypes; std::vector> aggregateFunctions; //! special handling of distinct aggregate @@ -249,7 +249,7 @@ class AggregateHashTableUtils { public: static std::vector> createDistinctHashTables( storage::MemoryManager& memoryManager, - const std::vector& groupByKeyDataTypes, + const std::vector& groupByKeyDataTypes, const std::vector>& aggregateFunctions); }; diff --git a/src/include/processor/operator/ddl/add_node_property.h b/src/include/processor/operator/ddl/add_node_property.h index b829edb7fc..a4a59eba2b 100644 --- a/src/include/processor/operator/ddl/add_node_property.h +++ b/src/include/processor/operator/ddl/add_node_property.h @@ -8,7 +8,7 @@ namespace processor { class AddNodeProperty : public AddProperty { public: AddNodeProperty(catalog::Catalog* catalog, common::table_id_t tableID, std::string propertyName, - common::DataType dataType, + common::LogicalType dataType, std::unique_ptr expressionEvaluator, storage::StorageManager& storageManager, const DataPos& outputPos, uint32_t id, const std::string& paramsString) diff --git a/src/include/processor/operator/ddl/add_property.h b/src/include/processor/operator/ddl/add_property.h index 78786253d8..c77caab6ee 100644 --- a/src/include/processor/operator/ddl/add_property.h +++ b/src/include/processor/operator/ddl/add_property.h @@ -10,7 +10,7 @@ namespace processor { class AddProperty : public DDL { public: AddProperty(catalog::Catalog* catalog, common::table_id_t tableID, std::string propertyName, - common::DataType dataType, + common::LogicalType dataType, std::unique_ptr expressionEvaluator, storage::StorageManager& storageManager, const DataPos& outputPos, uint32_t id, const std::string& paramsString) @@ -38,7 +38,7 @@ class AddProperty : public DDL { protected: common::table_id_t tableID; std::string propertyName; - common::DataType dataType; + common::LogicalType dataType; std::unique_ptr expressionEvaluator; storage::StorageManager& storageManager; }; diff --git a/src/include/processor/operator/ddl/add_rel_property.h b/src/include/processor/operator/ddl/add_rel_property.h index 9b8fae62f6..692196a3e7 100644 --- a/src/include/processor/operator/ddl/add_rel_property.h +++ b/src/include/processor/operator/ddl/add_rel_property.h @@ -10,7 +10,7 @@ class AddRelProperty; class AddRelProperty : public AddProperty { public: AddRelProperty(catalog::Catalog* catalog, common::table_id_t tableID, std::string propertyName, - common::DataType dataType, + common::LogicalType dataType, std::unique_ptr expressionEvaluator, storage::StorageManager& storageManager, const DataPos& outputPos, uint32_t id, const std::string& paramsString) diff --git a/src/include/processor/operator/hash_join/hash_join_build.h b/src/include/processor/operator/hash_join/hash_join_build.h index 7c0860b7e8..24e4cc2103 100644 --- a/src/include/processor/operator/hash_join/hash_join_build.h +++ b/src/include/processor/operator/hash_join/hash_join_build.h @@ -36,8 +36,8 @@ class HashJoinSharedState { struct BuildDataInfo { public: - BuildDataInfo(std::vector> keysPosAndType, - std::vector> payloadsPosAndType, + BuildDataInfo(std::vector> keysPosAndType, + std::vector> payloadsPosAndType, std::vector isPayloadsFlat, std::vector isPayloadsInKeyChunk) : keysPosAndType{std::move(keysPosAndType)}, payloadsPosAndType{std::move( payloadsPosAndType)}, @@ -51,8 +51,8 @@ struct BuildDataInfo { inline uint32_t getNumKeys() const { return keysPosAndType.size(); } public: - std::vector> keysPosAndType; - std::vector> payloadsPosAndType; + std::vector> keysPosAndType; + std::vector> payloadsPosAndType; std::vector isPayloadsFlat; std::vector isPayloadsInKeyChunk; }; diff --git a/src/include/processor/operator/order_by/key_block_merger.h b/src/include/processor/operator/order_by/key_block_merger.h index f6196e42ae..e6b6787f26 100644 --- a/src/include/processor/operator/order_by/key_block_merger.h +++ b/src/include/processor/operator/order_by/key_block_merger.h @@ -17,7 +17,8 @@ struct StrKeyColInfo { isAscOrder{isAscOrder} {} inline uint32_t getEncodingSize() const { - return OrderByKeyEncoder::getEncodingSize(common::DataType(common::STRING)); + return OrderByKeyEncoder::getEncodingSize( + common::LogicalType(common::LogicalTypeID::STRING)); } uint32_t colOffsetInFT; diff --git a/src/include/processor/operator/order_by/order_by.h b/src/include/processor/operator/order_by/order_by.h index 0081cfee9a..5017b68dcf 100644 --- a/src/include/processor/operator/order_by/order_by.h +++ b/src/include/processor/operator/order_by/order_by.h @@ -73,8 +73,8 @@ class SharedFactorizedTablesAndSortedKeyBlocks { struct OrderByDataInfo { public: - OrderByDataInfo(std::vector> keysPosAndType, - std::vector> payloadsPosAndType, + OrderByDataInfo(std::vector> keysPosAndType, + std::vector> payloadsPosAndType, std::vector isPayloadFlat, std::vector isAscOrder, bool mayContainUnflatKey) : keysPosAndType{std::move(keysPosAndType)}, payloadsPosAndType{std::move( payloadsPosAndType)}, @@ -86,8 +86,8 @@ struct OrderByDataInfo { other.isAscOrder, other.mayContainUnflatKey} {} public: - std::vector> keysPosAndType; - std::vector> payloadsPosAndType; + std::vector> keysPosAndType; + std::vector> payloadsPosAndType; std::vector isPayloadFlat; std::vector isAscOrder; // TODO(Ziyi): We should figure out unflat keys in a more general way. diff --git a/src/include/processor/operator/order_by/order_by_key_encoder.h b/src/include/processor/operator/order_by/order_by_key_encoder.h index 05907e6f6e..3918fe398d 100644 --- a/src/include/processor/operator/order_by/order_by_key_encoder.h +++ b/src/include/processor/operator/order_by/order_by_key_encoder.h @@ -84,7 +84,7 @@ class OrderByKeyEncoder { static uint32_t getNumBytesPerTuple( const std::vector>& keyVectors); - static uint32_t getEncodingSize(const common::DataType& dataType); + static uint32_t getEncodingSize(const common::LogicalType& dataType); void encodeKeys(); @@ -101,8 +101,8 @@ class OrderByKeyEncoder { assert(false); } - void flipBytesIfNecessary( - uint32_t keyColIdx, uint8_t* tuplePtr, uint32_t numEntriesToEncode, common::DataType& type); + void flipBytesIfNecessary(uint32_t keyColIdx, uint8_t* tuplePtr, uint32_t numEntriesToEncode, + common::LogicalType& type); void encodeFlatVector(common::ValueVector* vector, uint8_t* tuplePtr, uint32_t keyColIdx); @@ -116,7 +116,7 @@ class OrderByKeyEncoder { void allocateMemoryIfFull(); - static encode_function_t getEncodingFunction(common::DataTypeID typeId); + static encode_function_t getEncodingFunction(common::PhysicalTypeID physicalType); private: storage::MemoryManager* memoryManager; diff --git a/src/include/processor/operator/result_collector.h b/src/include/processor/operator/result_collector.h index 274ada99c7..0676996300 100644 --- a/src/include/processor/operator/result_collector.h +++ b/src/include/processor/operator/result_collector.h @@ -44,7 +44,7 @@ class FTableSharedState { class ResultCollector : public Sink { public: ResultCollector(std::unique_ptr resultSetDescriptor, - std::vector> payloadsPosAndType, + std::vector> payloadsPosAndType, std::vector isPayloadFlat, std::shared_ptr sharedState, std::unique_ptr child, uint32_t id, const std::string& paramsString) : Sink{std::move(resultSetDescriptor), PhysicalOperatorType::RESULT_COLLECTOR, @@ -72,7 +72,7 @@ class ResultCollector : public Sink { std::unique_ptr populateTableSchema(); private: - std::vector> payloadsPosAndType; + std::vector> payloadsPosAndType; std::vector isPayloadFlat; std::vector vectorsToCollect; std::shared_ptr sharedState; diff --git a/src/include/processor/operator/unwind.h b/src/include/processor/operator/unwind.h index ce52b0d2aa..abd49a3307 100644 --- a/src/include/processor/operator/unwind.h +++ b/src/include/processor/operator/unwind.h @@ -11,7 +11,7 @@ namespace processor { class Unwind : public PhysicalOperator { public: - Unwind(common::DataType outDataType, DataPos outDataPos, + Unwind(common::LogicalType outDataType, DataPos outDataPos, std::unique_ptr expressionEvaluator, std::unique_ptr child, uint32_t id, const std::string& paramsString) : PhysicalOperator{PhysicalOperatorType::UNWIND, std::move(child), id, paramsString}, @@ -31,7 +31,7 @@ class Unwind : public PhysicalOperator { bool hasMoreToRead() const; void copyTuplesToOutVector(uint64_t startPos, uint64_t endPos) const; - common::DataType outDataType; + common::LogicalType outDataType; DataPos outDataPos; std::unique_ptr expressionEvaluator; diff --git a/src/include/processor/result/factorized_table.h b/src/include/processor/result/factorized_table.h index 3dde3ccbbe..f318f47c96 100644 --- a/src/include/processor/result/factorized_table.h +++ b/src/include/processor/result/factorized_table.h @@ -255,9 +255,10 @@ class FactorizedTable { // inside overflowFileOfInMemList. void copyToInMemList(ft_col_idx_t colIdx, std::vector& tupleIdxesToRead, uint8_t* data, common::NullMask* nullMask, uint64_t startElemPosInList, - storage::DiskOverflowFile* overflowFileOfInMemList, const common::DataType& type) const; + storage::DiskOverflowFile* overflowFileOfInMemList, const common::LogicalType& type) const; void clear(); int64_t findValueInFlatColumn(ft_col_idx_t colIdx, int64_t value) const; + static uint32_t getDataTypeSize(const common::LogicalType& type); private: void setOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t colIdx, ft_tuple_idx_t tupleIdx); @@ -310,7 +311,7 @@ class FactorizedTable { readFlatColToFlatVector(tuplesToRead, colIdx, vector) : readFlatColToUnflatVector(tuplesToRead, colIdx, vector, numTuplesToRead); } - static void copyOverflowIfNecessary(uint8_t* dst, uint8_t* src, const common::DataType& type, + static void copyOverflowIfNecessary(uint8_t* dst, uint8_t* src, const common::LogicalType& type, storage::DiskOverflowFile* diskOverflowFile); private: diff --git a/src/include/storage/copier/npy_reader.h b/src/include/storage/copier/npy_reader.h index 56d1b94d30..75009fcd52 100644 --- a/src/include/storage/copier/npy_reader.h +++ b/src/include/storage/copier/npy_reader.h @@ -26,11 +26,12 @@ class NpyReader { inline size_t getNumRows() const { return shape[0]; } // Used in tests only. - inline common::DataTypeID getType() const { return type; } + inline common::LogicalTypeID getType() const { return type; } inline std::vector const& getShape() const { return shape; } inline size_t getNumDimensions() const { return shape.size(); } - void validate(common::DataType& type_, common::offset_t numRows, const std::string& tableName); + void validate( + common::LogicalType& type_, common::offset_t numRows, const std::string& tableName); private: void parseHeader(); @@ -43,7 +44,7 @@ class NpyReader { void* mmapRegion; size_t dataOffset; std::vector shape; - common::DataTypeID type; + common::LogicalTypeID type; }; } // namespace storage diff --git a/src/include/storage/copier/rel_copy_executor.h b/src/include/storage/copier/rel_copy_executor.h index aa8a8b8abd..7e4bfa8c81 100644 --- a/src/include/storage/copier/rel_copy_executor.h +++ b/src/include/storage/copier/rel_copy_executor.h @@ -70,7 +70,7 @@ class RelCopyExecutor : public TableCopyExecutor { template static void inferTableIDsAndOffsets(const std::vector>& batchColumns, - std::vector& nodeIDs, std::vector& nodeIDTypes, + std::vector& nodeIDs, std::vector& nodeIDTypes, const std::map& pkIndexes, transaction::Transaction* transaction, int64_t blockOffset, int64_t& colIndex); @@ -92,7 +92,7 @@ class RelCopyExecutor : public TableCopyExecutor { InMemOverflowFile* unorderedOverflowFile, InMemOverflowFile* orderedOverflowFile); static void copyListOverflowFromUnorderedToOrderedPages(common::ku_list_t* kuList, - const common::DataType& dataType, PageByteCursor& unorderedOverflowCursor, + const common::LogicalType& dataType, PageByteCursor& unorderedOverflowCursor, PageByteCursor& orderedOverflowCursor, InMemOverflowFile* unorderedOverflowFile, InMemOverflowFile* orderedOverflowFile); @@ -107,11 +107,11 @@ class RelCopyExecutor : public TableCopyExecutor { RelCopyExecutor* copier, const std::vector>& batchColumns, const std::string& filePath); - static void sortOverflowValuesOfPropertyColumnTask(const common::DataType& dataType, + static void sortOverflowValuesOfPropertyColumnTask(const common::LogicalType& dataType, common::offset_t offsetStart, common::offset_t offsetEnd, InMemColumnChunk* propertyColumn, InMemOverflowFile* unorderedInMemOverflowFile, InMemOverflowFile* orderedInMemOverflowFile); - static void sortOverflowValuesOfPropertyListsTask(const common::DataType& dataType, + static void sortOverflowValuesOfPropertyListsTask(const common::LogicalType& dataType, common::offset_t offsetStart, common::offset_t offsetEnd, InMemAdjLists* adjLists, InMemLists* propertyLists, InMemOverflowFile* unorderedInMemOverflowFile, InMemOverflowFile* orderedInMemOverflowFile); diff --git a/src/include/storage/copier/table_copy_executor.h b/src/include/storage/copier/table_copy_executor.h index e281fbfa02..e5b3366b0f 100644 --- a/src/include/storage/copier/table_copy_executor.h +++ b/src/include/storage/copier/table_copy_executor.h @@ -43,10 +43,10 @@ class TableCopyExecutor { static void throwCopyExceptionIfNotOK(const arrow::Status& status); static std::unique_ptr getArrowVarList(const std::string& l, int64_t from, - int64_t to, const common::DataType& dataType, + int64_t to, const common::LogicalType& dataType, const common::CopyDescription& copyDescription); static std::unique_ptr getArrowFixedList(const std::string& l, int64_t from, - int64_t to, const common::DataType& dataType, + int64_t to, const common::LogicalType& dataType, const common::CopyDescription& copyDescription); static std::shared_ptr createCSVReader(const std::string& filePath, common::CSVReaderConfig* csvReaderConfig, catalog::TableSchema* tableSchema); @@ -77,11 +77,11 @@ class TableCopyExecutor { tablesStatistics->setNumTuplesForTable(tableSchema->tableID, numRows); } - static std::shared_ptr toArrowDataType(const common::DataType& dataType); + static std::shared_ptr toArrowDataType(const common::LogicalType& dataType); private: static std::unique_ptr convertStringToValue(std::string element, - const common::DataType& type, const common::CopyDescription& copyDescription); + const common::LogicalType& type, const common::CopyDescription& copyDescription); protected: std::shared_ptr logger; diff --git a/src/include/storage/in_mem_storage_structure/in_mem_column.h b/src/include/storage/in_mem_storage_structure/in_mem_column.h index 4b0f0b7ce1..bba388af62 100644 --- a/src/include/storage/in_mem_storage_structure/in_mem_column.h +++ b/src/include/storage/in_mem_storage_structure/in_mem_column.h @@ -6,16 +6,17 @@ namespace kuzu { namespace storage { // TODO(GUODONG): Currently, we have both InMemNodeColumn and InMemColumn. This is a temporary // solution for now to allow gradual refactorings. Eventually, we should only have InMemColumn. + class InMemColumn { public: - InMemColumn(std::string filePath, common::DataType dataType, bool requireNullBits = true); + InMemColumn(std::string filePath, common::LogicalType dataType, bool requireNullBits = true); // Encode and flush null bits. void saveToFile(); void flushChunk(InMemColumnChunk* chunk); - inline common::DataType getDataType() { return dataType; } + inline common::LogicalType getDataType() { return dataType; } inline InMemOverflowFile* getInMemOverflowFile() { return inMemOverflowFile.get(); } inline uint16_t getNumBytesForValue() const { return numBytesForValue; } @@ -23,7 +24,7 @@ class InMemColumn { std::string filePath; uint16_t numBytesForValue; std::unique_ptr fileHandle; - common::DataType dataType; + common::LogicalType dataType; std::unique_ptr nullColumn; std::unique_ptr inMemOverflowFile; std::vector> childColumns; diff --git a/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h b/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h index bf067cd390..f8fc16b3b1 100644 --- a/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h +++ b/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h @@ -13,10 +13,10 @@ namespace storage { class InMemColumnChunk { public: - InMemColumnChunk(common::DataType dataType, common::offset_t startNodeOffset, + InMemColumnChunk(common::LogicalType dataType, common::offset_t startNodeOffset, common::offset_t endNodeOffset, bool requireNullBits = true); - inline common::DataType getDataType() const { return dataType; } + inline common::LogicalType getDataType() const { return dataType; } template inline T getValue(common::offset_t pos) const { @@ -94,10 +94,10 @@ class InMemColumnChunk { ((T*)buffer.get())[pos] = val; } - static uint32_t getDataTypeSizeInColumn(common::DataType& dataType); + static uint32_t getDataTypeSizeInColumn(common::LogicalType& dataType); private: - common::DataType dataType; + common::LogicalType dataType; common::offset_t startNodeOffset; std::uint64_t numBytesPerValue; std::uint64_t numBytes; @@ -113,7 +113,7 @@ class InMemStructColumnChunk { static void setValueToStructColumnField(InMemColumnChunk* chunk, common::offset_t pos, common::field_idx_t structFieldIdx, const std::string& structFieldValue, - const common::DataType& dataType); + const common::LogicalType& dataType); }; template<> diff --git a/src/include/storage/in_mem_storage_structure/in_mem_lists.h b/src/include/storage/in_mem_storage_structure/in_mem_lists.h index 92d5a8aa84..03f36b61f8 100644 --- a/src/include/storage/in_mem_storage_structure/in_mem_lists.h +++ b/src/include/storage/in_mem_storage_structure/in_mem_lists.h @@ -14,7 +14,7 @@ class AdjLists; using fill_in_mem_lists_function_t = std::function; + common::offset_t nodeOffset, uint64_t posInList, const common::LogicalType& dataType)>; class InMemListsUtils { public: @@ -32,7 +32,7 @@ class InMemListsUtils { class InMemLists { public: - InMemLists(std::string fName, common::DataType dataType, uint64_t numBytesForElement, + InMemLists(std::string fName, common::LogicalType dataType, uint64_t numBytesForElement, uint64_t numNodes, std::shared_ptr listHeadersBuilder) : InMemLists{std::move(fName), std::move(dataType), numBytesForElement, numNodes} { this->listHeadersBuilder = std::move(listHeadersBuilder); @@ -57,7 +57,7 @@ class InMemLists { common::offset_t nodeOffset, bool hasNULLBytes); protected: - InMemLists(std::string fName, common::DataType dataType, uint64_t numBytesForElement, + InMemLists(std::string fName, common::LogicalType dataType, uint64_t numBytesForElement, uint64_t numNodes); private: @@ -66,23 +66,23 @@ class InMemLists { static inline void fillInMemListsWithNonOverflowValFunc(InMemLists* inMemLists, uint8_t* defaultVal, PageByteCursor& pageByteCursor, common::offset_t nodeOffset, - uint64_t posInList, const common::DataType& dataType) { + uint64_t posInList, const common::LogicalType& dataType) { inMemLists->setElement(nodeOffset, posInList, defaultVal); } static void fillInMemListsWithStrValFunc(InMemLists* inMemLists, uint8_t* defaultVal, PageByteCursor& pageByteCursor, common::offset_t nodeOffset, uint64_t posInList, - const common::DataType& dataType); + const common::LogicalType& dataType); static void fillInMemListsWithListValFunc(InMemLists* inMemLists, uint8_t* defaultVal, PageByteCursor& pageByteCursor, common::offset_t nodeOffset, uint64_t posInList, - const common::DataType& dataType); - static fill_in_mem_lists_function_t getFillInMemListsFunc(const common::DataType& dataType); + const common::LogicalType& dataType); + static fill_in_mem_lists_function_t getFillInMemListsFunc(const common::LogicalType& dataType); public: std::unique_ptr inMemFile; protected: std::string fName; - common::DataType dataType; + common::LogicalType dataType; uint64_t numBytesForElement; std::unique_ptr listsMetadataBuilder; std::shared_ptr listHeadersBuilder; @@ -92,13 +92,13 @@ class InMemRelIDLists : public InMemLists { public: InMemRelIDLists(std::string fName, uint64_t numNodes, std::shared_ptr listHeadersBuilder) - : InMemLists{std::move(fName), common::DataType{common::INTERNAL_ID}, + : InMemLists{std::move(fName), common::LogicalType{common::LogicalTypeID::INTERNAL_ID}, sizeof(common::offset_t), numNodes, std::move(listHeadersBuilder)} {} }; class InMemListsWithOverflow : public InMemLists { protected: - InMemListsWithOverflow(std::string fName, common::DataType dataType, uint64_t numNodes, + InMemListsWithOverflow(std::string fName, common::LogicalType dataType, uint64_t numNodes, std::shared_ptr listHeadersBuilder); InMemOverflowFile* getInMemOverflowFile() override { return overflowInMemFile.get(); } @@ -111,7 +111,7 @@ class InMemListsWithOverflow : public InMemLists { class InMemAdjLists : public InMemLists { public: InMemAdjLists(std::string fName, uint64_t numNodes) - : InMemLists{std::move(fName), common::DataType(common::INTERNAL_ID), + : InMemLists{std::move(fName), common::LogicalType(common::LogicalTypeID::INTERNAL_ID), sizeof(common::offset_t), numNodes} { listHeadersBuilder = make_shared(this->fName, numNodes); }; @@ -132,13 +132,14 @@ class InMemStringLists : public InMemListsWithOverflow { public: InMemStringLists(std::string fName, uint64_t numNodes, std::shared_ptr listHeadersBuilder) - : InMemListsWithOverflow{std::move(fName), common::DataType(common::STRING), numNodes, + : InMemListsWithOverflow{std::move(fName), + common::LogicalType(common::LogicalTypeID::STRING), numNodes, std::move(listHeadersBuilder)} {}; }; class InMemListLists : public InMemListsWithOverflow { public: - InMemListLists(std::string fName, common::DataType dataType, uint64_t numNodes, + InMemListLists(std::string fName, common::LogicalType dataType, uint64_t numNodes, std::shared_ptr listHeadersBuilder) : InMemListsWithOverflow{ std::move(fName), std::move(dataType), numNodes, std::move(listHeadersBuilder)} {}; @@ -147,7 +148,7 @@ class InMemListLists : public InMemListsWithOverflow { class InMemListsFactory { public: static std::unique_ptr getInMemPropertyLists(const std::string& fName, - const common::DataType& dataType, uint64_t numNodes, + const common::LogicalType& dataType, uint64_t numNodes, std::shared_ptr listHeadersBuilder = nullptr); }; diff --git a/src/include/storage/index/hash_index.h b/src/include/storage/index/hash_index.h index 0b526d90fc..252c62e904 100644 --- a/src/include/storage/index/hash_index.h +++ b/src/include/storage/index/hash_index.h @@ -35,7 +35,7 @@ class TemplatedHashIndexLocalStorage { // and deletions are very small, thus they can be kept in memory. class HashIndexLocalStorage { public: - explicit HashIndexLocalStorage(common::DataType keyDataType) + explicit HashIndexLocalStorage(common::LogicalType keyDataType) : keyDataType{std::move(keyDataType)} {} // Currently, we assume that reads(lookup) and writes(delete/insert) of the local storage will // never happen concurrently. Thus, lookup requires no local storage lock. Writes are @@ -54,7 +54,7 @@ class HashIndexLocalStorage { std::shared_mutex localStorageSharedMutex; private: - common::DataType keyDataType; + common::LogicalType keyDataType; TemplatedHashIndexLocalStorage templatedLocalStorageForInt; TemplatedHashIndexLocalStorage templatedLocalStorageForString; }; @@ -82,7 +82,7 @@ class HashIndex : public BaseHashIndex { public: HashIndex(const StorageStructureIDAndFName& storageStructureIDAndFName, - const common::DataType& keyDataType, BufferManager& bufferManager, WAL* wal); + const common::LogicalType& keyDataType, BufferManager& bufferManager, WAL* wal); public: bool lookupInternal( @@ -148,9 +148,9 @@ class PrimaryKeyIndex { public: PrimaryKeyIndex(const StorageStructureIDAndFName& storageStructureIDAndFName, - const common::DataType& keyDataType, BufferManager& bufferManager, WAL* wal) - : keyDataTypeID{keyDataType.typeID} { - if (keyDataTypeID == common::INT64) { + const common::LogicalType& keyDataType, BufferManager& bufferManager, WAL* wal) + : keyDataTypeID{keyDataType.getLogicalTypeID()} { + if (keyDataTypeID == common::LogicalTypeID::INT64) { hashIndexForInt64 = std::make_unique>( storageStructureIDAndFName, keyDataType, bufferManager, wal); } else { @@ -169,59 +169,62 @@ class PrimaryKeyIndex { // These two lookups are used by InMemRelCSVCopier. inline bool lookup( transaction::Transaction* transaction, int64_t key, common::offset_t& result) { - assert(keyDataTypeID == common::INT64); + assert(keyDataTypeID == common::LogicalTypeID::INT64); return hashIndexForInt64->lookupInternal( transaction, reinterpret_cast(&key), result); } inline bool lookup( transaction::Transaction* transaction, const char* key, common::offset_t& result) { - assert(keyDataTypeID == common::STRING); + assert(keyDataTypeID == common::LogicalTypeID::STRING); return hashIndexForString->lookupInternal( transaction, reinterpret_cast(key), result); } inline void checkpointInMemoryIfNecessary() { - keyDataTypeID == common::INT64 ? hashIndexForInt64->checkpointInMemoryIfNecessary() : - hashIndexForString->checkpointInMemoryIfNecessary(); + keyDataTypeID == common::LogicalTypeID::INT64 ? + hashIndexForInt64->checkpointInMemoryIfNecessary() : + hashIndexForString->checkpointInMemoryIfNecessary(); } inline void rollbackInMemoryIfNecessary() { - keyDataTypeID == common::INT64 ? hashIndexForInt64->rollbackInMemoryIfNecessary() : - hashIndexForString->rollbackInMemoryIfNecessary(); + keyDataTypeID == common::LogicalTypeID::INT64 ? + hashIndexForInt64->rollbackInMemoryIfNecessary() : + hashIndexForString->rollbackInMemoryIfNecessary(); } inline void prepareCommitOrRollbackIfNecessary(bool isCommit) { - return keyDataTypeID == common::INT64 ? + return keyDataTypeID == common::LogicalTypeID::INT64 ? hashIndexForInt64->prepareCommitOrRollbackIfNecessary(isCommit) : hashIndexForString->prepareCommitOrRollbackIfNecessary(isCommit); } inline BMFileHandle* getFileHandle() { - return keyDataTypeID == common::INT64 ? hashIndexForInt64->getFileHandle() : - hashIndexForString->getFileHandle(); + return keyDataTypeID == common::LogicalTypeID::INT64 ? hashIndexForInt64->getFileHandle() : + hashIndexForString->getFileHandle(); } inline DiskOverflowFile* getDiskOverflowFile() { - return keyDataTypeID == common::STRING ? hashIndexForString->diskOverflowFile.get() : - nullptr; + return keyDataTypeID == common::LogicalTypeID::STRING ? + hashIndexForString->diskOverflowFile.get() : + nullptr; } private: inline void deleteKey(int64_t key) { - assert(keyDataTypeID == common::INT64); + assert(keyDataTypeID == common::LogicalTypeID::INT64); hashIndexForInt64->deleteInternal(reinterpret_cast(&key)); } inline void deleteKey(const char* key) { - assert(keyDataTypeID == common::STRING); + assert(keyDataTypeID == common::LogicalTypeID::STRING); hashIndexForString->deleteInternal(reinterpret_cast(key)); } inline bool insert(int64_t key, common::offset_t value) { - assert(keyDataTypeID == common::INT64); + assert(keyDataTypeID == common::LogicalTypeID::INT64); return hashIndexForInt64->insertInternal(reinterpret_cast(&key), value); } inline bool insert(const char* key, common::offset_t value) { - assert(keyDataTypeID == common::STRING); + assert(keyDataTypeID == common::LogicalTypeID::STRING); return hashIndexForString->insertInternal(reinterpret_cast(key), value); } private: - common::DataTypeID keyDataTypeID; + common::LogicalTypeID keyDataTypeID; std::unique_ptr> hashIndexForInt64; std::unique_ptr> hashIndexForString; }; diff --git a/src/include/storage/index/hash_index_builder.h b/src/include/storage/index/hash_index_builder.h index 06903ea44a..4f471f2037 100644 --- a/src/include/storage/index/hash_index_builder.h +++ b/src/include/storage/index/hash_index_builder.h @@ -50,8 +50,8 @@ struct SlotInfo { class BaseHashIndex { public: - explicit BaseHashIndex(const common::DataType& keyDataType) { - keyHashFunc = HashIndexUtils::initializeHashFunc(keyDataType.typeID); + explicit BaseHashIndex(const common::LogicalType& keyDataType) { + keyHashFunc = HashIndexUtils::initializeHashFunc(keyDataType.getLogicalTypeID()); } virtual ~BaseHashIndex() = default; @@ -74,7 +74,7 @@ template class HashIndexBuilder : public BaseHashIndex { public: - HashIndexBuilder(const std::string& fName, const common::DataType& keyDataType); + HashIndexBuilder(const std::string& fName, const common::LogicalType& keyDataType); public: // Reserves space for at least the specified number of elements. @@ -133,32 +133,33 @@ class HashIndexBuilder : public BaseHashIndex { class PrimaryKeyIndexBuilder { public: - PrimaryKeyIndexBuilder(const std::string& fName, const common::DataType& keyDataType) - : keyDataTypeID{keyDataType.typeID} { + PrimaryKeyIndexBuilder(const std::string& fName, const common::LogicalType& keyDataType) + : keyDataTypeID{keyDataType.getLogicalTypeID()} { switch (keyDataTypeID) { - case common::INT64: { + case common::LogicalTypeID::INT64: { hashIndexBuilderForInt64 = std::make_unique>(fName, keyDataType); } break; - case common::STRING: { + case common::LogicalTypeID::STRING: { hashIndexBuilderForString = std::make_unique>(fName, keyDataType); } break; default: { - throw common::Exception( - "Unsupported data type for primary key index: " + std::to_string(keyDataTypeID)); + throw common::Exception("Unsupported data type for primary key index: " + + common::LogicalTypeUtils::dataTypeToString(keyDataTypeID)); } } } inline void bulkReserve(uint32_t numEntries) { - keyDataTypeID == common::INT64 ? hashIndexBuilderForInt64->bulkReserve(numEntries) : - hashIndexBuilderForString->bulkReserve(numEntries); + keyDataTypeID == common::LogicalTypeID::INT64 ? + hashIndexBuilderForInt64->bulkReserve(numEntries) : + hashIndexBuilderForString->bulkReserve(numEntries); } // Note: append assumes that bulkRserve has been called before it and the index has reserved // enough space already. inline void append(int64_t key, common::offset_t value) { - auto retVal = keyDataTypeID == common::INT64 ? + auto retVal = keyDataTypeID == common::LogicalTypeID::INT64 ? hashIndexBuilderForInt64->append(key, value) : hashIndexBuilderForString->append(key, value); if (!retVal) { @@ -167,7 +168,7 @@ class PrimaryKeyIndexBuilder { } } inline void append(const char* key, common::offset_t value) { - auto retVal = keyDataTypeID == common::INT64 ? + auto retVal = keyDataTypeID == common::LogicalTypeID::INT64 ? hashIndexBuilderForInt64->append(key, value) : hashIndexBuilderForString->append(key, value); if (!retVal) { @@ -176,18 +177,19 @@ class PrimaryKeyIndexBuilder { } } inline bool lookup(int64_t key, common::offset_t& result) { - return keyDataTypeID == common::INT64 ? hashIndexBuilderForInt64->lookup(key, result) : - hashIndexBuilderForString->lookup(key, result); + return keyDataTypeID == common::LogicalTypeID::INT64 ? + hashIndexBuilderForInt64->lookup(key, result) : + hashIndexBuilderForString->lookup(key, result); } // Non-thread safe. This should only be called in the copyCSV and never be called in parallel. inline void flush() { - keyDataTypeID == common::INT64 ? hashIndexBuilderForInt64->flush() : - hashIndexBuilderForString->flush(); + keyDataTypeID == common::LogicalTypeID::INT64 ? hashIndexBuilderForInt64->flush() : + hashIndexBuilderForString->flush(); } private: - common::DataTypeID keyDataTypeID; + common::LogicalTypeID keyDataTypeID; std::unique_ptr> hashIndexBuilderForInt64; std::unique_ptr> hashIndexBuilderForString; }; diff --git a/src/include/storage/index/hash_index_header.h b/src/include/storage/index/hash_index_header.h index 8b6bf0b3e9..d3f168bebb 100644 --- a/src/include/storage/index/hash_index_header.h +++ b/src/include/storage/index/hash_index_header.h @@ -2,21 +2,24 @@ #include "common/types/types.h" #include "hash_index_slot.h" +#include "storage/storage_utils.h" namespace kuzu { namespace storage { class HashIndexHeader { public: - explicit HashIndexHeader(common::DataTypeID keyDataTypeID) + explicit HashIndexHeader(common::LogicalTypeID keyDataTypeID) : currentLevel{1}, levelHashMask{1}, higherLevelHashMask{3}, nextSplitSlotId{0}, - numEntries{0}, numBytesPerKey{common::Types::getDataTypeSize(keyDataTypeID)}, - numBytesPerEntry{ - (uint32_t)(common::Types::getDataTypeSize(keyDataTypeID) + sizeof(common::offset_t))}, + numEntries{0}, numBytesPerKey{storage::StorageUtils::getDataTypeSize( + common::LogicalType{keyDataTypeID})}, + numBytesPerEntry{(uint32_t)( + storage::StorageUtils::getDataTypeSize(common::LogicalType{keyDataTypeID}) + + sizeof(common::offset_t))}, keyDataTypeID{keyDataTypeID} {} // Used for element initialization in disk array only. - HashIndexHeader() : HashIndexHeader(common::STRING) {} + HashIndexHeader() : HashIndexHeader(common::LogicalTypeID::STRING) {} inline void incrementLevel() { currentLevel++; @@ -40,7 +43,7 @@ class HashIndexHeader { uint64_t numEntries; uint32_t numBytesPerKey; uint32_t numBytesPerEntry; - common::DataTypeID keyDataTypeID; + common::LogicalTypeID keyDataTypeID; }; } // namespace storage diff --git a/src/include/storage/index/hash_index_utils.h b/src/include/storage/index/hash_index_utils.h index 13c9852af6..afac444419 100644 --- a/src/include/storage/index/hash_index_utils.h +++ b/src/include/storage/index/hash_index_utils.h @@ -15,8 +15,10 @@ using hash_function_t = std::function; using equals_function_t = std::function; -static const uint32_t NUM_BYTES_FOR_INT64_KEY = common::Types::getDataTypeSize(common::INT64); -static const uint32_t NUM_BYTES_FOR_STRING_KEY = common::Types::getDataTypeSize(common::STRING); +static const uint32_t NUM_BYTES_FOR_INT64_KEY = + storage::StorageUtils::getDataTypeSize(common::LogicalType{common::LogicalTypeID::INT64}); +static const uint32_t NUM_BYTES_FOR_STRING_KEY = + storage::StorageUtils::getDataTypeSize(common::LogicalType{common::LogicalTypeID::STRING}); using in_mem_insert_function_t = std::function; @@ -25,8 +27,8 @@ using in_mem_equals_function_t = class InMemHashIndexUtils { public: - static in_mem_equals_function_t initializeEqualsFunc(common::DataTypeID dataTypeID); - static in_mem_insert_function_t initializeInsertFunc(common::DataTypeID dataTypeID); + static in_mem_equals_function_t initializeEqualsFunc(common::LogicalTypeID dataTypeID); + static in_mem_insert_function_t initializeInsertFunc(common::LogicalTypeID dataTypeID); private: // InsertFunc @@ -64,7 +66,7 @@ class HashIndexUtils { memcpy(entry, &kuString, NUM_BYTES_FOR_STRING_KEY); memcpy(entry + NUM_BYTES_FOR_STRING_KEY, &offset, sizeof(common::offset_t)); } - static insert_function_t initializeInsertFunc(common::DataTypeID dataTypeID); + static insert_function_t initializeInsertFunc(common::LogicalTypeID dataTypeID); // HashFunc inline static common::hash_t hashFuncForInt64(const uint8_t* key) { @@ -77,7 +79,7 @@ class HashIndexUtils { function::operation::Hash::operation(std::string((char*)key), hash); return hash; } - static hash_function_t initializeHashFunc(common::DataTypeID dataTypeID); + static hash_function_t initializeHashFunc(common::LogicalTypeID dataTypeID); // EqualsFunc static bool isStringPrefixAndLenEquals( @@ -88,7 +90,7 @@ class HashIndexUtils { } static bool equalsFuncForString(transaction::TransactionType trxType, const uint8_t* keyToLookup, const uint8_t* keyInEntry, DiskOverflowFile* diskOverflowFile); - static equals_function_t initializeEqualsFunc(common::DataTypeID dataTypeID); + static equals_function_t initializeEqualsFunc(common::LogicalTypeID dataTypeID); }; } // namespace storage } // namespace kuzu diff --git a/src/include/storage/storage_structure/column.h b/src/include/storage/storage_structure/column.h index 309e5e53f9..abc3799ca5 100644 --- a/src/include/storage/storage_structure/column.h +++ b/src/include/storage/storage_structure/column.h @@ -19,15 +19,16 @@ class NullColumn; class Column : public BaseColumnOrList { public: // Currently extended by SERIAL column. - explicit Column(const common::DataType& dataType) : BaseColumnOrList{dataType} {}; + explicit Column(const common::LogicalType& dataType) : BaseColumnOrList{dataType} {}; - Column(const StorageStructureIDAndFName& structureIDAndFName, const common::DataType& dataType, - BufferManager* bufferManager, WAL* wal) - : Column(structureIDAndFName, dataType, common::Types::getDataTypeSize(dataType), + Column(const StorageStructureIDAndFName& structureIDAndFName, + const common::LogicalType& dataType, BufferManager* bufferManager, WAL* wal) + : Column(structureIDAndFName, dataType, storage::StorageUtils::getDataTypeSize(dataType), bufferManager, wal){}; - Column(const StorageStructureIDAndFName& structureIDAndFName, const common::DataType& dataType, - size_t elementSize, BufferManager* bufferManager, WAL* wal, bool requireNullBits = true); + Column(const StorageStructureIDAndFName& structureIDAndFName, + const common::LogicalType& dataType, size_t elementSize, BufferManager* bufferManager, + WAL* wal, bool requireNullBits = true); // Expose for feature store virtual void batchLookup(const common::offset_t* nodeOffsets, size_t size, uint8_t* result); @@ -83,8 +84,8 @@ class NullColumn : public Column { public: NullColumn(const StorageStructureIDAndFName& structureIDAndFName, BufferManager* bufferManager, WAL* wal) - : Column{structureIDAndFName, common::DataType(common::BOOL), sizeof(bool), bufferManager, - wal, false /* requireNullBits */} { + : Column{structureIDAndFName, common::LogicalType(common::LogicalTypeID::BOOL), + sizeof(bool), bufferManager, wal, false /* requireNullBits */} { readDataFunc = NullColumn::readNullsFromPage; } @@ -103,7 +104,7 @@ class NullColumn : public Column { class PropertyColumnWithOverflow : public Column { public: PropertyColumnWithOverflow(const StorageStructureIDAndFName& structureIDAndFNameOfMainColumn, - const common::DataType& dataType, BufferManager* bufferManager, WAL* wal) + const common::LogicalType& dataType, BufferManager* bufferManager, WAL* wal) : Column{structureIDAndFNameOfMainColumn, dataType, bufferManager, wal} { diskOverflowFile = std::make_unique(structureIDAndFNameOfMainColumn, bufferManager, wal); @@ -116,7 +117,7 @@ class PropertyColumnWithOverflow : public Column { class StringPropertyColumn : public PropertyColumnWithOverflow { public: StringPropertyColumn(const StorageStructureIDAndFName& structureIDAndFNameOfMainColumn, - const common::DataType& dataType, BufferManager* bufferManager, WAL* wal) + const common::LogicalType& dataType, BufferManager* bufferManager, WAL* wal) : PropertyColumnWithOverflow{ structureIDAndFNameOfMainColumn, dataType, bufferManager, wal} { writeDataFunc = StringPropertyColumn::writeStringToPage; @@ -149,7 +150,7 @@ class StringPropertyColumn : public PropertyColumnWithOverflow { class ListPropertyColumn : public PropertyColumnWithOverflow { public: ListPropertyColumn(const StorageStructureIDAndFName& structureIDAndFNameOfMainColumn, - const common::DataType& dataType, BufferManager* bufferManager, WAL* wal) + const common::LogicalType& dataType, BufferManager* bufferManager, WAL* wal) : PropertyColumnWithOverflow{ structureIDAndFNameOfMainColumn, dataType, bufferManager, wal} { readDataFunc = ListPropertyColumn::readListsFromPage; @@ -169,7 +170,7 @@ class ListPropertyColumn : public PropertyColumnWithOverflow { class StructPropertyColumn : public Column { public: StructPropertyColumn(const StorageStructureIDAndFName& structureIDAndFName, - const common::DataType& dataType, BufferManager* bufferManager, WAL* wal); + const common::LogicalType& dataType, BufferManager* bufferManager, WAL* wal); void read(transaction::Transaction* transaction, common::ValueVector* nodeIDVector, common::ValueVector* resultVector) final; @@ -182,7 +183,7 @@ class InternalIDColumn : public Column { public: InternalIDColumn(const StorageStructureIDAndFName& structureIDAndFName, BufferManager* bufferManager, WAL* wal) - : Column{structureIDAndFName, common::DataType(common::INTERNAL_ID), + : Column{structureIDAndFName, common::LogicalType(common::LogicalTypeID::INTERNAL_ID), sizeof(common::offset_t), bufferManager, wal, true /* requireNullBits */} { readDataFunc = InternalIDColumn::readInternalIDsFromPage; writeDataFunc = InternalIDColumn::writeInternalIDToPage; @@ -198,7 +199,7 @@ class InternalIDColumn : public Column { class SerialColumn : public Column { public: - SerialColumn() : Column{common::DataType{common::SERIAL}} {} + SerialColumn() : Column{common::LogicalType{common::LogicalTypeID::SERIAL}} {} void read(transaction::Transaction* transaction, common::ValueVector* nodeIDVector, common::ValueVector* resultVector) final; @@ -207,31 +208,31 @@ class SerialColumn : public Column { class ColumnFactory { public: static std::unique_ptr getColumn(const StorageStructureIDAndFName& structureIDAndFName, - const common::DataType& dataType, BufferManager* bufferManager, WAL* wal) { - switch (dataType.typeID) { - case common::INT64: - case common::INT32: - case common::INT16: - case common::DOUBLE: - case common::FLOAT: - case common::BOOL: - case common::DATE: - case common::TIMESTAMP: - case common::INTERVAL: - case common::FIXED_LIST: - return std::make_unique(structureIDAndFName, dataType, bufferManager, wal); - case common::STRING: + const common::LogicalType& logicalType, BufferManager* bufferManager, WAL* wal) { + switch (logicalType.getLogicalTypeID()) { + case common::LogicalTypeID::INT64: + case common::LogicalTypeID::INT32: + case common::LogicalTypeID::INT16: + case common::LogicalTypeID::DOUBLE: + case common::LogicalTypeID::FLOAT: + case common::LogicalTypeID::BOOL: + case common::LogicalTypeID::DATE: + case common::LogicalTypeID::TIMESTAMP: + case common::LogicalTypeID::INTERVAL: + case common::LogicalTypeID::FIXED_LIST: + return std::make_unique(structureIDAndFName, logicalType, bufferManager, wal); + case common::LogicalTypeID::STRING: return std::make_unique( - structureIDAndFName, dataType, bufferManager, wal); - case common::VAR_LIST: + structureIDAndFName, logicalType, bufferManager, wal); + case common::LogicalTypeID::VAR_LIST: return std::make_unique( - structureIDAndFName, dataType, bufferManager, wal); - case common::INTERNAL_ID: + structureIDAndFName, logicalType, bufferManager, wal); + case common::LogicalTypeID::INTERNAL_ID: return std::make_unique(structureIDAndFName, bufferManager, wal); - case common::STRUCT: + case common::LogicalTypeID::STRUCT: return std::make_unique( - structureIDAndFName, dataType, bufferManager, wal); - case common::SERIAL: + structureIDAndFName, logicalType, bufferManager, wal); + case common::LogicalTypeID::SERIAL: return std::make_unique(); default: throw common::StorageException("Invalid type for property column creation."); diff --git a/src/include/storage/storage_structure/disk_overflow_file.h b/src/include/storage/storage_structure/disk_overflow_file.h index bc5dd45db2..d907139e39 100644 --- a/src/include/storage/storage_structure/disk_overflow_file.h +++ b/src/include/storage/storage_structure/disk_overflow_file.h @@ -40,7 +40,8 @@ class DiskOverflowFile : public StorageStructure { inline void scanSingleStringOverflow( transaction::TransactionType trxType, common::ValueVector& vector, uint64_t vectorPos) { - assert(vector.dataType.typeID == common::STRING && !vector.isNull(vectorPos)); + assert(vector.dataType.getLogicalTypeID() == common::LogicalTypeID::STRING && + !vector.isNull(vectorPos)); auto& kuString = ((common::ku_string_t*)vector.getData())[vectorPos]; lookupString(trxType, kuString, *common::StringVector::getInMemOverflowBuffer(&vector)); } @@ -49,13 +50,13 @@ class DiskOverflowFile : public StorageStructure { common::ValueVector* vector, uint64_t pos); std::string readString(transaction::TransactionType trxType, const common::ku_string_t& str); std::vector> readList(transaction::TransactionType trxType, - const common::ku_list_t& listVal, const common::DataType& dataType); + const common::ku_list_t& listVal, const common::LogicalType& dataType); common::ku_string_t writeString(const char* rawString); void writeStringOverflowAndUpdateOverflowPtr( const common::ku_string_t& strToWriteFrom, common::ku_string_t& strToWriteTo); void writeListOverflowAndUpdateOverflowPtr(const common::ku_list_t& listToWriteFrom, - common::ku_list_t& listToWriteTo, const common::DataType& valueType); + common::ku_list_t& listToWriteTo, const common::LogicalType& valueType); inline void resetNextBytePosToWriteTo(uint64_t nextBytePosToWriteTo_) { nextBytePosToWriteTo = nextBytePosToWriteTo_; @@ -79,9 +80,9 @@ class DiskOverflowFile : public StorageStructure { void setStringOverflowWithoutLock( const char* inMemSrcStr, uint64_t len, common::ku_string_t& diskDstString); void setListRecursiveIfNestedWithoutLock(const common::ku_list_t& inMemSrcList, - common::ku_list_t& diskDstList, const common::DataType& dataType); + common::ku_list_t& diskDstList, const common::LogicalType& dataType); void logNewOverflowFileNextBytePosRecordIfNecessaryWithoutLock(); - void readValuesInList(transaction::TransactionType trxType, const common::DataType& dataType, + void readValuesInList(transaction::TransactionType trxType, const common::LogicalType& dataType, std::vector>& retValues, uint32_t numBytesOfSingleValue, uint64_t numValuesInList, PageByteCursor& cursor, uint8_t* frame); void pinOverflowPageCache(BMFileHandle* bmFileHandleToPin, common::page_idx_t pageIdxToPin, diff --git a/src/include/storage/storage_structure/in_mem_file.h b/src/include/storage/storage_structure/in_mem_file.h index 15fef6a671..688fdce199 100644 --- a/src/include/storage/storage_structure/in_mem_file.h +++ b/src/include/storage/storage_structure/in_mem_file.h @@ -72,9 +72,9 @@ class InMemOverflowFile : public InMemFile { PageByteCursor& overflowCursor, uint8_t* srcOverflow, common::ku_string_t* dstKUString); void copyListOverflowFromFile(InMemOverflowFile* srcInMemOverflowFile, const PageByteCursor& srcOverflowCursor, PageByteCursor& dstOverflowCursor, - common::ku_list_t* dstKUList, common::DataType* listChildDataType); + common::ku_list_t* dstKUList, common::LogicalType* listChildDataType); void copyListOverflowToFile(PageByteCursor& pageByteCursor, common::ku_list_t* srcKUList, - common::DataType* childDataType); + common::LogicalType* childDataType); std::string readString(common::ku_string_t* strInInMemOvfFile); @@ -83,12 +83,12 @@ class InMemOverflowFile : public InMemFile { void copyFixedSizedValuesInList(const common::Value& listVal, PageByteCursor& overflowCursor, uint64_t numBytesOfListElement); - template + template void copyVarSizedValuesInList(common::ku_list_t& resultKUList, const common::Value& listVal, PageByteCursor& overflowCursor, uint64_t numBytesOfListElement); void resetElementsOverflowPtrIfNecessary(PageByteCursor& pageByteCursor, - common::DataType* elementType, uint64_t numElementsToReset, uint8_t* elementsToReset); + common::LogicalType* elementType, uint64_t numElementsToReset, uint8_t* elementsToReset); private: // These two fields (currentPageIdxToAppend, currentOffsetInPageToAppend) are used when diff --git a/src/include/storage/storage_structure/lists/lists.h b/src/include/storage/storage_structure/lists/lists.h index 7526a302ae..40be17cea5 100644 --- a/src/include/storage/storage_structure/lists/lists.h +++ b/src/include/storage/storage_structure/lists/lists.h @@ -45,7 +45,7 @@ class Lists : public BaseColumnOrList { public: Lists(const StorageStructureIDAndFName& storageStructureIDAndFName, - const common::DataType& dataType, const size_t& elementSize, + const common::LogicalType& dataType, const size_t& elementSize, std::shared_ptr headers, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) : Lists{storageStructureIDAndFName, dataType, elementSize, std::move(headers), @@ -101,7 +101,7 @@ class Lists : public BaseColumnOrList { protected: virtual inline DiskOverflowFile* getDiskOverflowFileIfExists() { return nullptr; } Lists(const StorageStructureIDAndFName& storageStructureIDAndFName, - const common::DataType& dataType, const size_t& elementSize, + const common::LogicalType& dataType, const size_t& elementSize, std::shared_ptr headers, BufferManager* bufferManager, bool hasNULLBytes, WAL* wal, ListsUpdatesStore* listsUpdatesStore) : BaseColumnOrList{storageStructureIDAndFName, dataType, elementSize, bufferManager, @@ -131,10 +131,11 @@ class Lists : public BaseColumnOrList { class PropertyListsWithOverflow : public Lists { public: PropertyListsWithOverflow(const StorageStructureIDAndFName& storageStructureIDAndFName, - const common::DataType& dataType, std::shared_ptr headers, + const common::LogicalType& dataType, std::shared_ptr headers, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) - : Lists{storageStructureIDAndFName, dataType, common::Types::getDataTypeSize(dataType), - std::move(headers), bufferManager, wal, listsUpdatesStore}, + : Lists{storageStructureIDAndFName, dataType, + storage::StorageUtils::getDataTypeSize(dataType), std::move(headers), bufferManager, + wal, listsUpdatesStore}, diskOverflowFile{storageStructureIDAndFName, bufferManager, wal} {} private: @@ -150,8 +151,9 @@ class StringPropertyLists : public PropertyListsWithOverflow { StringPropertyLists(const StorageStructureIDAndFName& storageStructureIDAndFName, const std::shared_ptr& headers, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) - : PropertyListsWithOverflow{storageStructureIDAndFName, common::DataType{common::STRING}, - headers, bufferManager, wal, listsUpdatesStore} {}; + : PropertyListsWithOverflow{storageStructureIDAndFName, + common::LogicalType{common::LogicalTypeID::STRING}, headers, bufferManager, wal, + listsUpdatesStore} {}; private: void readFromList(common::ValueVector* valueVector, ListHandle& listHandle) override; @@ -161,7 +163,7 @@ class ListPropertyLists : public PropertyListsWithOverflow { public: ListPropertyLists(const StorageStructureIDAndFName& storageStructureIDAndFName, - const common::DataType& dataType, const std::shared_ptr& headers, + const common::LogicalType& dataType, const std::shared_ptr& headers, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) : PropertyListsWithOverflow{storageStructureIDAndFName, dataType, headers, bufferManager, wal, listsUpdatesStore} {}; @@ -178,7 +180,7 @@ class AdjLists : public Lists { AdjLists(const StorageStructureIDAndFName& storageStructureIDAndFName, common::table_id_t nbrTableID, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) - : Lists{storageStructureIDAndFName, common::DataType(common::INTERNAL_ID), + : Lists{storageStructureIDAndFName, common::LogicalType(common::LogicalTypeID::INTERNAL_ID), sizeof(common::offset_t), std::make_shared(storageStructureIDAndFName, bufferManager, wal), bufferManager, false /* hasNullBytes */, wal, listsUpdatesStore}, @@ -218,7 +220,7 @@ class RelIDList : public Lists { RelIDList(const StorageStructureIDAndFName& storageStructureIDAndFName, std::shared_ptr headers, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) - : Lists{storageStructureIDAndFName, common::DataType{common::INTERNAL_ID}, + : Lists{storageStructureIDAndFName, common::LogicalType{common::LogicalTypeID::INTERNAL_ID}, sizeof(common::offset_t), std::move(headers), bufferManager, wal, listsUpdatesStore} { } @@ -245,30 +247,30 @@ class ListsFactory { public: static std::unique_ptr getLists(const StorageStructureIDAndFName& structureIDAndFName, - const common::DataType& dataType, const std::shared_ptr& adjListsHeaders, + const common::LogicalType& dataType, const std::shared_ptr& adjListsHeaders, BufferManager* bufferManager, WAL* wal, ListsUpdatesStore* listsUpdatesStore) { assert(listsUpdatesStore != nullptr); - switch (dataType.typeID) { - case common::INT64: - case common::INT32: - case common::INT16: - case common::DOUBLE: - case common::FLOAT: - case common::BOOL: - case common::DATE: - case common::TIMESTAMP: - case common::INTERVAL: - case common::FIXED_LIST: + switch (dataType.getLogicalTypeID()) { + case common::LogicalTypeID::INT64: + case common::LogicalTypeID::INT32: + case common::LogicalTypeID::INT16: + case common::LogicalTypeID::DOUBLE: + case common::LogicalTypeID::FLOAT: + case common::LogicalTypeID::BOOL: + case common::LogicalTypeID::DATE: + case common::LogicalTypeID::TIMESTAMP: + case common::LogicalTypeID::INTERVAL: + case common::LogicalTypeID::FIXED_LIST: return std::make_unique(structureIDAndFName, dataType, - common::Types::getDataTypeSize(dataType), adjListsHeaders, bufferManager, wal, - listsUpdatesStore); - case common::STRING: + storage::StorageUtils::getDataTypeSize(dataType), adjListsHeaders, bufferManager, + wal, listsUpdatesStore); + case common::LogicalTypeID::STRING: return std::make_unique( structureIDAndFName, adjListsHeaders, bufferManager, wal, listsUpdatesStore); - case common::VAR_LIST: + case common::LogicalTypeID::VAR_LIST: return std::make_unique(structureIDAndFName, dataType, adjListsHeaders, bufferManager, wal, listsUpdatesStore); - case common::INTERNAL_ID: + case common::LogicalTypeID::INTERNAL_ID: return std::make_unique( structureIDAndFName, adjListsHeaders, bufferManager, wal, listsUpdatesStore); default: diff --git a/src/include/storage/storage_structure/lists/lists_update_store.h b/src/include/storage/storage_structure/lists/lists_update_store.h index 2372fda46a..b45c017571 100644 --- a/src/include/storage/storage_structure/lists/lists_update_store.h +++ b/src/include/storage/storage_structure/lists/lists_update_store.h @@ -96,7 +96,7 @@ class ListsUpdatesStore { void readInsertedRelsToList(ListFileID& listFileID, std::vector tupleIdxes, InMemList& inMemList, uint64_t numElementsInPersistentStore, DiskOverflowFile* diskOverflowFile, - common::DataType dataType); + common::LogicalType dataType); // If this is a one-to-one relTable, all properties are stored in columns. // In this case, the listsUpdatesStore should not store the insert rels in FT. @@ -126,7 +126,7 @@ class ListsUpdatesStore { common::ValueVector* propertyVector, list_offset_t startListOffset); void readPropertyUpdateToInMemList(ListFileID& listFileID, processor::ft_tuple_idx_t ftTupleIdx, - InMemList& inMemList, uint64_t posToWriteToInMemList, const common::DataType& dataType, + InMemList& inMemList, uint64_t posToWriteToInMemList, const common::LogicalType& dataType, DiskOverflowFile* overflowFileOfInMemList); void initNewlyAddedNodes(common::nodeID_t& nodeID); diff --git a/src/include/storage/storage_structure/storage_structure.h b/src/include/storage/storage_structure/storage_structure.h index 487a684d2f..4afac65d76 100644 --- a/src/include/storage/storage_structure/storage_structure.h +++ b/src/include/storage/storage_structure/storage_structure.h @@ -71,7 +71,7 @@ class StorageStructure { class BaseColumnOrList : public StorageStructure { public: - explicit BaseColumnOrList(common::DataType dataType) : dataType{std::move(dataType)} {} + explicit BaseColumnOrList(common::LogicalType dataType) : dataType{std::move(dataType)} {} // Maps the position of element in page to its byte offset in page. // TODO(Everyone): we should slowly get rid of this function. @@ -89,7 +89,7 @@ class BaseColumnOrList : public StorageStructure { } BaseColumnOrList(const StorageStructureIDAndFName& storageStructureIDAndFName, - common::DataType dataType, const size_t& elementSize, BufferManager* bufferManager, + common::LogicalType dataType, const size_t& elementSize, BufferManager* bufferManager, bool hasInlineNullBytes, WAL* wal); void readBySequentialCopy(transaction::Transaction* transaction, common::ValueVector* vector, @@ -117,7 +117,7 @@ class BaseColumnOrList : public StorageStructure { uint16_t pagePosOfFirstElement, uint64_t numValuesToRead); public: - common::DataType dataType; + common::LogicalType dataType; size_t elementSize; uint32_t numElementsPerPage; }; diff --git a/src/include/storage/storage_utils.h b/src/include/storage/storage_utils.h index 5b92eadd02..e8602be321 100644 --- a/src/include/storage/storage_utils.h +++ b/src/include/storage/storage_utils.h @@ -269,6 +269,8 @@ class StorageUtils { uint64_t numNodesInTable, const std::string& directory, common::RelDataDirection relDirection); + static uint32_t getDataTypeSize(const common::LogicalType& type); + private: static std::string appendSuffixOrInsertBeforeWALSuffix( std::string fileName, std::string suffix); diff --git a/src/main/connection.cpp b/src/main/connection.cpp index b771cbff1f..a2ae5c300a 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -236,7 +236,8 @@ std::string Connection::getNodePropertyNames(const std::string& tableName) { auto primaryKeyPropertyID = catalog->getReadOnlyVersion()->getNodeTableSchema(tableID)->getPrimaryKey().propertyID; for (auto& property : catalog->getReadOnlyVersion()->getAllNodeProperties(tableID)) { - result += "\t" + property.name + " " + Types::dataTypeToString(property.dataType); + result += + "\t" + property.name + " " + LogicalTypeUtils::dataTypeToString(property.dataType); result += property.propertyID == primaryKeyPropertyID ? "(PRIMARY KEY)\n" : "\n"; } return result; @@ -262,7 +263,8 @@ std::string Connection::getRelPropertyNames(const std::string& relTableName) { if (catalog::TableSchema::isReservedPropertyName(property.name)) { continue; } - result += "\t" + property.name + " " + Types::dataTypeToString(property.dataType) + "\n"; + result += "\t" + property.name + " " + + LogicalTypeUtils::dataTypeToString(property.dataType) + "\n"; } return result; } @@ -301,8 +303,9 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, auto expectParam = parameterMap.at(name); if (expectParam->dataType != value->getDataType()) { throw Exception("Parameter " + name + " has data type " + - Types::dataTypeToString(value->getDataType()) + " but expect " + - Types::dataTypeToString(expectParam->dataType) + "."); + LogicalTypeUtils::dataTypeToString(value->getDataType()) + + " but expect " + + LogicalTypeUtils::dataTypeToString(expectParam->dataType) + "."); } parameterMap.at(name)->copyValueFrom(*value); } diff --git a/src/main/query_result.cpp b/src/main/query_result.cpp index c34201c2b9..1a0ec8c509 100644 --- a/src/main/query_result.cpp +++ b/src/main/query_result.cpp @@ -15,18 +15,18 @@ namespace kuzu { namespace main { std::unique_ptr DataTypeInfo::getInfoForDataType( - const DataType& type, const std::string& name) { - auto columnTypeInfo = std::make_unique(type.typeID, name); - switch (type.typeID) { - case common::INTERNAL_ID: { + const LogicalType& type, const std::string& name) { + auto columnTypeInfo = std::make_unique(type.getLogicalTypeID(), name); + switch (type.getLogicalTypeID()) { + case common::LogicalTypeID::INTERNAL_ID: { columnTypeInfo->childrenTypesInfo.push_back( - std::make_unique(common::INT64, "offset")); + std::make_unique(common::LogicalTypeID::INT64, "offset")); columnTypeInfo->childrenTypesInfo.push_back( - std::make_unique(common::INT64, "tableID")); + std::make_unique(common::LogicalTypeID::INT64, "tableID")); } break; - case common::VAR_LIST: { + case common::LogicalTypeID::VAR_LIST: { auto parentTypeInfo = columnTypeInfo.get(); - auto childType = type.getChildType(); + auto childType = VarListType::getChildType(&type); parentTypeInfo->childrenTypesInfo.push_back(getInfoForDataType(*childType, "")); } break; default: { @@ -61,7 +61,7 @@ std::vector QueryResult::getColumnNames() { return columnNames; } -std::vector QueryResult::getColumnDataTypes() { +std::vector QueryResult::getColumnDataTypes() { return columnDataTypes; } @@ -81,22 +81,22 @@ std::vector> QueryResult::getColumnTypesInfo() { std::vector> result; for (auto i = 0u; i < columnDataTypes.size(); i++) { auto columnTypeInfo = DataTypeInfo::getInfoForDataType(columnDataTypes[i], columnNames[i]); - if (columnTypeInfo->typeID == common::NODE) { + if (columnTypeInfo->typeID == common::LogicalTypeID::NODE) { auto value = tuple->getValue(i)->nodeVal.get(); - columnTypeInfo->childrenTypesInfo.push_back( - DataTypeInfo::getInfoForDataType(DataType(common::INTERNAL_ID), "_id")); - columnTypeInfo->childrenTypesInfo.push_back( - DataTypeInfo::getInfoForDataType(DataType(common::STRING), "_label")); + columnTypeInfo->childrenTypesInfo.push_back(DataTypeInfo::getInfoForDataType( + LogicalType(common::LogicalTypeID::INTERNAL_ID), "_id")); + columnTypeInfo->childrenTypesInfo.push_back(DataTypeInfo::getInfoForDataType( + LogicalType(common::LogicalTypeID::STRING), "_label")); for (auto& [name, val] : value->getProperties()) { columnTypeInfo->childrenTypesInfo.push_back( DataTypeInfo::getInfoForDataType(val->dataType, name)); } - } else if (columnTypeInfo->typeID == common::REL) { + } else if (columnTypeInfo->typeID == common::LogicalTypeID::REL) { auto value = tuple->getValue(i)->relVal.get(); - columnTypeInfo->childrenTypesInfo.push_back( - DataTypeInfo::getInfoForDataType(DataType(common::INTERNAL_ID), "_src")); - columnTypeInfo->childrenTypesInfo.push_back( - DataTypeInfo::getInfoForDataType(DataType(common::INTERNAL_ID), "_dst")); + columnTypeInfo->childrenTypesInfo.push_back(DataTypeInfo::getInfoForDataType( + LogicalType(common::LogicalTypeID::INTERNAL_ID), "_src")); + columnTypeInfo->childrenTypesInfo.push_back(DataTypeInfo::getInfoForDataType( + LogicalType(common::LogicalTypeID::INTERNAL_ID), "_dst")); for (auto& [name, val] : value->getProperties()) { columnTypeInfo->childrenTypesInfo.push_back( DataTypeInfo::getInfoForDataType(val->dataType, name)); @@ -122,16 +122,18 @@ void QueryResult::initResultTableAndIterator( columnNames.push_back(columnName); auto expressionsToCollect = expressionToCollectPerColumn[i]; std::unique_ptr value; - if (columnType.typeID == common::NODE) { + if (columnType.getLogicalTypeID() == common::LogicalTypeID::NODE) { // first expression is node ID. - assert(expressionsToCollect[0]->dataType.typeID == common::INTERNAL_ID); - auto nodeIDVal = - std::make_unique(Value::createDefaultValue(DataType(INTERNAL_ID))); + assert(expressionsToCollect[0]->dataType.getLogicalTypeID() == + common::LogicalTypeID::INTERNAL_ID); + auto nodeIDVal = std::make_unique( + Value::createDefaultValue(LogicalType(LogicalTypeID::INTERNAL_ID))); valuesToCollect.push_back(nodeIDVal.get()); // second expression is node label function. - assert(expressionsToCollect[1]->dataType.typeID == common::STRING); - auto labelNameVal = - std::make_unique(Value::createDefaultValue(DataType(STRING))); + assert(expressionsToCollect[1]->dataType.getLogicalTypeID() == + common::LogicalTypeID::STRING); + auto labelNameVal = std::make_unique( + Value::createDefaultValue(LogicalType(LogicalTypeID::STRING))); valuesToCollect.push_back(labelNameVal.get()); auto nodeVal = std::make_unique(std::move(nodeIDVal), std::move(labelNameVal)); for (auto j = 2u; j < expressionsToCollect.size(); ++j) { @@ -143,20 +145,22 @@ void QueryResult::initResultTableAndIterator( nodeVal->addProperty(property->getPropertyName(), std::move(propertyValue)); } value = std::make_unique(std::move(nodeVal)); - } else if (columnType.typeID == common::REL) { + } else if (columnType.getLogicalTypeID() == common::LogicalTypeID::REL) { // first expression is src node ID. - assert(expressionsToCollect[0]->dataType.typeID == common::INTERNAL_ID); - auto srcNodeIDVal = - std::make_unique(Value::createDefaultValue(DataType(INTERNAL_ID))); + assert(expressionsToCollect[0]->dataType.getLogicalTypeID() == + common::LogicalTypeID::INTERNAL_ID); + auto srcNodeIDVal = std::make_unique( + Value::createDefaultValue(LogicalType(LogicalTypeID::INTERNAL_ID))); valuesToCollect.push_back(srcNodeIDVal.get()); // second expression is dst node ID. - assert(expressionsToCollect[1]->dataType.typeID == common::INTERNAL_ID); - auto dstNodeIDVal = - std::make_unique(Value::createDefaultValue(DataType(INTERNAL_ID))); + assert(expressionsToCollect[1]->dataType.getLogicalTypeID() == + common::LogicalTypeID::INTERNAL_ID); + auto dstNodeIDVal = std::make_unique( + Value::createDefaultValue(LogicalType(LogicalTypeID::INTERNAL_ID))); valuesToCollect.push_back(dstNodeIDVal.get()); // third expression is rel label function. - auto labelNameVal = - std::make_unique(Value::createDefaultValue(DataType(STRING))); + auto labelNameVal = std::make_unique( + Value::createDefaultValue(LogicalType(LogicalTypeID::STRING))); valuesToCollect.push_back(labelNameVal.get()); auto relVal = std::make_unique( std::move(srcNodeIDVal), std::move(dstNodeIDVal), std::move(labelNameVal)); @@ -229,7 +233,8 @@ void QueryResult::writeToCSV( for (auto idx = 0ul; idx < nextTuple->len(); idx++) { std::string resultVal = nextTuple->getValue(idx)->toString(); bool isStringList = false; - if (Types::dataTypeToString(nextTuple->getValue(idx)->getDataType()) == "STRING[]") { + if (LogicalTypeUtils::dataTypeToString(nextTuple->getValue(idx)->getDataType()) == + "STRING[]") { isStringList = true; } bool surroundQuotes = false; diff --git a/src/planner/projection_planner.cpp b/src/planner/projection_planner.cpp index 3621b3da8a..9d6c4993a3 100644 --- a/src/planner/projection_planner.cpp +++ b/src/planner/projection_planner.cpp @@ -188,7 +188,8 @@ expression_vector ProjectionPlanner::rewriteExpressionsToProject( const expression_vector& expressionsToProject, const Schema& schema) { expression_vector result; for (auto& expression : expressionsToProject) { - if (expression->dataType.typeID == NODE || expression->dataType.typeID == REL) { + if (expression->dataType.getLogicalTypeID() == LogicalTypeID::NODE || + expression->dataType.getLogicalTypeID() == LogicalTypeID::REL) { for (auto& property : rewriteVariableAsAllPropertiesInScope(*expression, schema)) { result.push_back(property); } diff --git a/src/planner/query_planner.cpp b/src/planner/query_planner.cpp index 137876ccf4..580b0cbca4 100644 --- a/src/planner/query_planner.cpp +++ b/src/planner/query_planner.cpp @@ -164,7 +164,7 @@ static expression_vector getCorrelatedExpressions( static expression_vector getJoinNodeIDs(expression_vector& expressions) { expression_vector joinNodeIDs; for (auto& expression : expressions) { - if (expression->dataType.typeID == INTERNAL_ID) { + if (expression->dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID) { joinNodeIDs.push_back(expression); } } @@ -178,7 +178,8 @@ void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphColle if (correlatedExpressions.empty()) { throw NotImplementedException("Optional match is disconnected with previous MATCH clause."); } - if (ExpressionUtil::allExpressionsHaveDataType(correlatedExpressions, INTERNAL_ID)) { + if (ExpressionUtil::allExpressionsHaveDataType( + correlatedExpressions, LogicalTypeID::INTERNAL_ID)) { auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions); // When correlated variables are all NODE IDs, the subquery can be un-nested as left join. // Join nodes are scanned twice in both outer and inner. However, we make sure inner table @@ -233,7 +234,8 @@ void QueryPlanner::planExistsSubquery( if (correlatedExpressions.empty()) { throw NotImplementedException("Subquery is disconnected with outer query."); } - if (ExpressionUtil::allExpressionsHaveDataType(correlatedExpressions, INTERNAL_ID)) { + if (ExpressionUtil::allExpressionsHaveDataType( + correlatedExpressions, LogicalTypeID::INTERNAL_ID)) { auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions); // Unnest as mark join. See planOptionalMatch for unnesting logic. auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs); diff --git a/src/processor/mapper/map_expressions_scan.cpp b/src/processor/mapper/map_expressions_scan.cpp index 5e81654171..7b5777d85d 100644 --- a/src/processor/mapper/map_expressions_scan.cpp +++ b/src/processor/mapper/map_expressions_scan.cpp @@ -24,7 +24,7 @@ std::unique_ptr PlanMapper::mapLogicalExpressionsScanToPhysica for (auto& expression : expressions) { tableSchema->appendColumn( std::make_unique(false, 0 /* all expressions are in the same datachunk */, - Types::getDataTypeSize(expression->dataType))); + FactorizedTable::getDataTypeSize(expression->dataType))); auto expressionEvaluator = expressionMapper.mapExpression(expression, *inSchema); // expression can be evaluated statically and does not require an actual resultset to init expressionEvaluator->init(ResultSet(0) /* dummy resultset */, memoryManager); diff --git a/src/processor/mapper/map_hash_join.cpp b/src/processor/mapper/map_hash_join.cpp index 61647d9d89..5ed8544039 100644 --- a/src/processor/mapper/map_hash_join.cpp +++ b/src/processor/mapper/map_hash_join.cpp @@ -11,14 +11,15 @@ namespace processor { BuildDataInfo PlanMapper::generateBuildDataInfo(const Schema& buildSideSchema, const expression_vector& keys, const expression_vector& payloads) { - std::vector> buildKeysPosAndType, buildPayloadsPosAndTypes; + std::vector> buildKeysPosAndType, + buildPayloadsPosAndTypes; std::vector isBuildPayloadsFlat, isBuildPayloadsInKeyChunk; std::vector isBuildDataChunkContainKeys(buildSideSchema.getNumGroups(), false); std::unordered_set joinKeyNames; for (auto& key : keys) { auto buildSideKeyPos = DataPos(buildSideSchema.getExpressionPos(*key)); isBuildDataChunkContainKeys[buildSideKeyPos.dataChunkPos] = true; - buildKeysPosAndType.emplace_back(buildSideKeyPos, common::INTERNAL_ID); + buildKeysPosAndType.emplace_back(buildSideKeyPos, common::LogicalTypeID::INTERNAL_ID); joinKeyNames.insert(key->getUniqueName()); } for (auto& payload : payloads) { diff --git a/src/processor/mapper/map_order_by.cpp b/src/processor/mapper/map_order_by.cpp index 7ca2bebc28..fe58fa4fae 100644 --- a/src/processor/mapper/map_order_by.cpp +++ b/src/processor/mapper/map_order_by.cpp @@ -16,11 +16,11 @@ std::unique_ptr PlanMapper::mapLogicalOrderByToPhysical( auto inSchema = logicalOrderBy.getChild(0)->getSchema(); auto prevOperator = mapLogicalOperatorToPhysical(logicalOrderBy.getChild(0)); auto paramsString = logicalOrderBy.getExpressionsForPrinting(); - std::vector> keysPosAndType; + std::vector> keysPosAndType; for (auto& expression : logicalOrderBy.getExpressionsToOrderBy()) { keysPosAndType.emplace_back(inSchema->getExpressionPos(*expression), expression->dataType); } - std::vector> payloadsPosAndType; + std::vector> payloadsPosAndType; std::vector isPayloadFlat; std::vector outVectorPos; for (auto& expression : logicalOrderBy.getExpressionsToMaterialize()) { diff --git a/src/processor/mapper/map_unwind.cpp b/src/processor/mapper/map_unwind.cpp index 0fb2d08b8e..d785b435da 100644 --- a/src/processor/mapper/map_unwind.cpp +++ b/src/processor/mapper/map_unwind.cpp @@ -16,7 +16,8 @@ std::unique_ptr PlanMapper::mapLogicalUnwindToPhysical( auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); auto dataPos = DataPos(outSchema->getExpressionPos(*unwind->getAliasExpression())); auto expressionEvaluator = expressionMapper.mapExpression(unwind->getExpression(), *inSchema); - return std::make_unique(*unwind->getExpression()->getDataType().getChildType(), dataPos, + return std::make_unique( + *common::VarListType::getChildType(&unwind->getExpression()->dataType), dataPos, std::move(expressionEvaluator), std::move(prevOperator), getOperatorID(), unwind->getExpressionsForPrinting()); } diff --git a/src/processor/mapper/plan_mapper.cpp b/src/processor/mapper/plan_mapper.cpp index 6c06988ef0..46def7148c 100644 --- a/src/processor/mapper/plan_mapper.cpp +++ b/src/processor/mapper/plan_mapper.cpp @@ -147,7 +147,7 @@ std::unique_ptr PlanMapper::mapLogicalOperatorToPhysical( std::unique_ptr PlanMapper::appendResultCollector( const binder::expression_vector& expressionsToCollect, const Schema& schema, std::unique_ptr prevOperator) { - std::vector> payloadsPosAndType; + std::vector> payloadsPosAndType; std::vector isPayloadFlat; for (auto& expression : expressionsToCollect) { auto expressionName = expression->getUniqueName(); diff --git a/src/processor/operator/aggregate/aggregate_hash_table.cpp b/src/processor/operator/aggregate/aggregate_hash_table.cpp index 5d62949630..17937ef566 100644 --- a/src/processor/operator/aggregate/aggregate_hash_table.cpp +++ b/src/processor/operator/aggregate/aggregate_hash_table.cpp @@ -14,7 +14,7 @@ namespace kuzu { namespace processor { AggregateHashTable::AggregateHashTable(MemoryManager& memoryManager, - std::vector keyDataTypes, std::vector dependentKeyDataTypes, + std::vector keyDataTypes, std::vector dependentKeyDataTypes, const std::vector>& aggregateFunctions, uint64_t numEntriesToAllocate) : BaseHashTable{memoryManager}, keyDataTypes{std::move(keyDataTypes)}, @@ -53,8 +53,10 @@ bool AggregateHashTable::isAggregateValueDistinctForGroupByKeys( } else { VectorHashOperations::computeHash(groupByFlatKeyVectors[0], hashVector.get()); computeAndCombineVecHash(groupByFlatKeyVectors, 1 /* startVecIdx */); - auto tmpHashResultVector = std::make_unique(INT64, &memoryManager); - auto tmpHashCombineResultVector = std::make_unique(INT64, &memoryManager); + auto tmpHashResultVector = + std::make_unique(LogicalTypeID::INT64, &memoryManager); + auto tmpHashCombineResultVector = + std::make_unique(LogicalTypeID::INT64, &memoryManager); VectorHashOperations::computeHash(aggregateVector, tmpHashResultVector.get()); VectorHashOperations::combineHash( hashVector.get(), tmpHashResultVector.get(), tmpHashCombineResultVector.get()); @@ -139,18 +141,18 @@ void AggregateHashTable::initializeFT( compareFuncs.resize(aggStateColIdxInFT); auto colIdx = 0u; for (auto& dataType : keyDataTypes) { - auto size = Types::getDataTypeSize(dataType); + auto size = FactorizedTable::getDataTypeSize(dataType); tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); - hasStrCol = hasStrCol || dataType.typeID == STRING; - compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.typeID); + hasStrCol = hasStrCol || dataType.getLogicalTypeID() == LogicalTypeID::STRING; + compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.getLogicalTypeID()); numBytesForKeys += size; colIdx++; } for (auto& dataType : dependentKeyDataTypes) { - auto size = Types::getDataTypeSize(dataType); + auto size = FactorizedTable::getDataTypeSize(dataType); tableSchema->appendColumn(std::make_unique(isUnflat, dataChunkPos, size)); - hasStrCol = hasStrCol || dataType.typeID == STRING; - compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.typeID); + hasStrCol = hasStrCol || dataType.getLogicalTypeID() == LogicalTypeID::STRING; + compareFuncs[colIdx] = getCompareEntryWithKeysFunc(dataType.getLogicalTypeID()); numBytesForDependentKeys += size; colIdx++; } @@ -192,7 +194,7 @@ void AggregateHashTable::initializeHashTable(uint64_t numEntriesToAllocate) { void AggregateHashTable::initializeTmpVectors() { hashState = std::make_shared(); hashState->currIdx = 0; - hashVector = std::make_unique(INT64, &memoryManager); + hashVector = std::make_unique(LogicalTypeID::INT64, &memoryManager); hashVector->state = hashState; hashSlotsToUpdateAggState = std::make_unique(DEFAULT_VECTOR_CAPACITY); tmpValueIdxes = std::make_unique(DEFAULT_VECTOR_CAPACITY); @@ -395,8 +397,10 @@ void AggregateHashTable::computeAndCombineVecHash( const std::vector& groupByHashKeyVectors, uint32_t startVecIdx) { for (; startVecIdx < groupByHashKeyVectors.size(); startVecIdx++) { auto keyVector = groupByHashKeyVectors[startVecIdx]; - auto tmpHashResultVector = std::make_unique(INT64, &memoryManager); - auto tmpHashCombineResultVector = std::make_unique(INT64, &memoryManager); + auto tmpHashResultVector = + std::make_unique(LogicalTypeID::INT64, &memoryManager); + auto tmpHashCombineResultVector = + std::make_unique(LogicalTypeID::INT64, &memoryManager); VectorHashOperations::computeHash(keyVector, tmpHashResultVector.get()); VectorHashOperations::combineHash( hashVector.get(), tmpHashResultVector.get(), tmpHashCombineResultVector.get()); @@ -668,43 +672,44 @@ bool AggregateHashTable::compareEntryWithKeys(const uint8_t* keyValue, const uin return result != 0; } -compare_function_t AggregateHashTable::getCompareEntryWithKeysFunc(DataTypeID typeId) { +compare_function_t AggregateHashTable::getCompareEntryWithKeysFunc(LogicalTypeID typeId) { switch (typeId) { - case INTERNAL_ID: { + case LogicalTypeID::INTERNAL_ID: { return compareEntryWithKeys; } - case BOOL: { + case LogicalTypeID::BOOL: { return compareEntryWithKeys; } - case INT64: { + case LogicalTypeID::INT64: { return compareEntryWithKeys; } - case INT32: { + case LogicalTypeID::INT32: { return compareEntryWithKeys; } - case INT16: { + case LogicalTypeID::INT16: { return compareEntryWithKeys; } - case DOUBLE: { + case LogicalTypeID::DOUBLE: { return compareEntryWithKeys; } - case FLOAT: { + case LogicalTypeID::FLOAT: { return compareEntryWithKeys; } - case STRING: { + case LogicalTypeID::STRING: { return compareEntryWithKeys; } - case DATE: { + case LogicalTypeID::DATE: { return compareEntryWithKeys; } - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { return compareEntryWithKeys; } - case INTERVAL: { + case LogicalTypeID::INTERVAL: { return compareEntryWithKeys; } default: { - throw RuntimeException("Cannot compare data type " + Types::dataTypeToString(typeId)); + throw RuntimeException( + "Cannot compare data type " + LogicalTypeUtils::dataTypeToString(typeId)); } } } @@ -881,12 +886,12 @@ void AggregateHashTable::updateBothUnflatDifferentDCAggVectorState( } std::vector> AggregateHashTableUtils::createDistinctHashTables( - MemoryManager& memoryManager, const std::vector& groupByKeyDataTypes, + MemoryManager& memoryManager, const std::vector& groupByKeyDataTypes, const std::vector>& aggregateFunctions) { std::vector> distinctHTs; for (auto& aggregateFunction : aggregateFunctions) { if (aggregateFunction->isFunctionDistinct()) { - std::vector distinctKeysDataTypes(groupByKeyDataTypes.size() + 1); + std::vector distinctKeysDataTypes(groupByKeyDataTypes.size() + 1); for (auto i = 0u; i < groupByKeyDataTypes.size(); i++) { distinctKeysDataTypes[i] = groupByKeyDataTypes[i]; } diff --git a/src/processor/operator/aggregate/hash_aggregate.cpp b/src/processor/operator/aggregate/hash_aggregate.cpp index e6b89ab71d..19d1a734bb 100644 --- a/src/processor/operator/aggregate/hash_aggregate.cpp +++ b/src/processor/operator/aggregate/hash_aggregate.cpp @@ -49,7 +49,7 @@ std::pair HashAggregateSharedState::getNextRangeToRead() { void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { BaseAggregate::initLocalStateInternal(resultSet, context); - std::vector keyDataTypes; + std::vector keyDataTypes; for (auto& pos : flatKeysPos) { auto vector = resultSet->getValueVector(pos).get(); flatKeyVectors.push_back(vector); @@ -60,7 +60,7 @@ void HashAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionContex unFlatKeyVectors.push_back(vector); keyDataTypes.push_back(vector->dataType); } - std::vector payloadDataTypes; + std::vector payloadDataTypes; for (auto& pos : dependentKeysPos) { auto vector = resultSet->getValueVector(pos).get(); dependentKeyVectors.push_back(vector); diff --git a/src/processor/operator/aggregate/simple_aggregate.cpp b/src/processor/operator/aggregate/simple_aggregate.cpp index ff358261a9..9fa46c3600 100644 --- a/src/processor/operator/aggregate/simple_aggregate.cpp +++ b/src/processor/operator/aggregate/simple_aggregate.cpp @@ -48,7 +48,7 @@ void SimpleAggregate::initLocalStateInternal(ResultSet* resultSet, ExecutionCont localAggregateStates.push_back(aggregateFunction->createInitialNullAggregateState()); } distinctHashTables = AggregateHashTableUtils::createDistinctHashTables( - *context->memoryManager, std::vector{}, this->aggregateFunctions); + *context->memoryManager, std::vector{}, this->aggregateFunctions); } void SimpleAggregate::executeInternal(ExecutionContext* context) { diff --git a/src/processor/operator/hash_join/hash_join_build.cpp b/src/processor/operator/hash_join/hash_join_build.cpp index 12e0281acb..97650aba24 100644 --- a/src/processor/operator/hash_join/hash_join_build.cpp +++ b/src/processor/operator/hash_join/hash_join_build.cpp @@ -35,25 +35,25 @@ std::unique_ptr HashJoinBuild::populateTableSchema() { std::unique_ptr tableSchema = std::make_unique(); for (auto& [pos, dataType] : buildDataInfo.keysPosAndType) { tableSchema->appendColumn(std::make_unique( - false /* is flat */, pos.dataChunkPos, Types::getDataTypeSize(dataType))); + false /* is flat */, pos.dataChunkPos, FactorizedTable::getDataTypeSize(dataType))); } for (auto i = 0u; i < buildDataInfo.payloadsPosAndType.size(); ++i) { auto [pos, dataType] = buildDataInfo.payloadsPosAndType[i]; if (buildDataInfo.isPayloadsInKeyChunk[i]) { tableSchema->appendColumn(std::make_unique( - false /* is flat */, pos.dataChunkPos, Types::getDataTypeSize(dataType))); + false /* is flat */, pos.dataChunkPos, FactorizedTable::getDataTypeSize(dataType))); } else { auto isVectorFlat = buildDataInfo.isPayloadsFlat[i]; tableSchema->appendColumn( std::make_unique(!isVectorFlat, pos.dataChunkPos, - isVectorFlat ? Types::getDataTypeSize(dataType) : + isVectorFlat ? FactorizedTable::getDataTypeSize(dataType) : (uint32_t)sizeof(overflow_value_t))); } } // The prev pointer column. tableSchema->appendColumn(std::make_unique(false /* is flat */, UINT32_MAX /* For now, we just put UINT32_MAX for prev pointer */, - Types::getDataTypeSize(INT64))); + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::INT64}))); return tableSchema; } diff --git a/src/processor/operator/hash_join/hash_join_probe.cpp b/src/processor/operator/hash_join/hash_join_probe.cpp index 2e9ed7f6bc..c737641753 100644 --- a/src/processor/operator/hash_join/hash_join_probe.cpp +++ b/src/processor/operator/hash_join/hash_join_probe.cpp @@ -26,10 +26,11 @@ void HashJoinProbe::initLocalStateInternal(ResultSet* resultSet, ExecutionContex columnIdxsToReadFrom.resize(probeDataInfo.getNumPayloads()); iota( columnIdxsToReadFrom.begin(), columnIdxsToReadFrom.end(), probeDataInfo.keysDataPos.size()); - hashVector = std::make_unique(common::INT64, context->memoryManager); + hashVector = + std::make_unique(common::LogicalTypeID::INT64, context->memoryManager); if (keyVectors.size() > 1) { - tmpHashVector = - std::make_unique(common::INT64, context->memoryManager); + tmpHashVector = std::make_unique( + common::LogicalTypeID::INT64, context->memoryManager); } } diff --git a/src/processor/operator/order_by/order_by.cpp b/src/processor/operator/order_by/order_by.cpp index fe8cbef240..351f1ecad5 100644 --- a/src/processor/operator/order_by/order_by.cpp +++ b/src/processor/operator/order_by/order_by.cpp @@ -38,7 +38,8 @@ std::unique_ptr OrderBy::populateTableSchema() { auto [dataPos, dataType] = orderByDataInfo.payloadsPosAndType[i]; bool isUnflat = !orderByDataInfo.isPayloadFlat[i] && !orderByDataInfo.mayContainUnflatKey; tableSchema->appendColumn(std::make_unique(isUnflat, dataPos.dataChunkPos, - isUnflat ? (uint32_t)sizeof(overflow_value_t) : Types::getDataTypeSize(dataType))); + isUnflat ? (uint32_t)sizeof(overflow_value_t) : + FactorizedTable::getDataTypeSize(dataType))); } return tableSchema; } @@ -49,7 +50,7 @@ void OrderBy::initGlobalStateInternal(kuzu::processor::ExecutionContext* context auto tableSchema = populateTableSchema(); for (auto i = 0u; i < orderByDataInfo.keysPosAndType.size(); ++i) { auto [dataPos, dataType] = orderByDataInfo.keysPosAndType[i]; - if (STRING == dataType.typeID) { + if (LogicalTypeID::STRING == dataType.getLogicalTypeID()) { // If this is a string column, we need to find the factorizedTable offset for this // column. auto factorizedTableColIdx = 0ul; diff --git a/src/processor/operator/order_by/order_by_key_encoder.cpp b/src/processor/operator/order_by/order_by_key_encoder.cpp index b75453aa28..6863dc5a70 100644 --- a/src/processor/operator/order_by/order_by_key_encoder.cpp +++ b/src/processor/operator/order_by/order_by_key_encoder.cpp @@ -32,7 +32,7 @@ OrderByKeyEncoder::OrderByKeyEncoder(std::vector& orderByVectors, } encodeFunctions.resize(orderByVectors.size()); for (auto i = 0u; i < orderByVectors.size(); i++) { - encodeFunctions[i] = getEncodingFunction(orderByVectors[i]->dataType.typeID); + encodeFunctions[i] = getEncodingFunction(orderByVectors[i]->dataType.getPhysicalType()); } } @@ -67,19 +67,19 @@ uint32_t OrderByKeyEncoder::getNumBytesPerTuple(const std::vector& return result; } -uint32_t OrderByKeyEncoder::getEncodingSize(const DataType& dataType) { +uint32_t OrderByKeyEncoder::getEncodingSize(const LogicalType& dataType) { // Add one more byte for null flag. - switch (dataType.typeID) { - case STRING: + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::STRING: // 1 byte for null flag + 1 byte to indicate long/short string + 12 bytes for string prefix return 2 + ku_string_t::SHORT_STR_LENGTH; default: - return 1 + Types::getDataTypeSize(dataType); + return 1 + storage::StorageUtils::getDataTypeSize(dataType); } } void OrderByKeyEncoder::flipBytesIfNecessary( - uint32_t keyColIdx, uint8_t* tuplePtr, uint32_t numEntriesToEncode, DataType& type) { + uint32_t keyColIdx, uint8_t* tuplePtr, uint32_t numEntriesToEncode, LogicalType& type) { if (!isAscOrder[keyColIdx]) { auto encodingSize = getEncodingSize(type); // If the current column is in desc order, flip all bytes. @@ -197,40 +197,35 @@ void OrderByKeyEncoder::allocateMemoryIfFull() { } } -encode_function_t OrderByKeyEncoder::getEncodingFunction(DataTypeID typeId) { - switch (typeId) { - case BOOL: { +encode_function_t OrderByKeyEncoder::getEncodingFunction(PhysicalTypeID physicalType) { + switch (physicalType) { + case PhysicalTypeID::BOOL: { return encodeTemplate; } - case INT64: { + case PhysicalTypeID::INT64: { return encodeTemplate; } - case INT32: { + case PhysicalTypeID::INT32: { return encodeTemplate; } - case INT16: { + case PhysicalTypeID::INT16: { return encodeTemplate; } - case DOUBLE: { + case PhysicalTypeID::DOUBLE: { return encodeTemplate; } - case FLOAT: { + case PhysicalTypeID::FLOAT: { return encodeTemplate; } - case STRING: { + case PhysicalTypeID::STRING: { return encodeTemplate; } - case DATE: { - return encodeTemplate; - } - case TIMESTAMP: { - return encodeTemplate; - } - case INTERVAL: { + case PhysicalTypeID::INTERVAL: { return encodeTemplate; } default: { - throw RuntimeException("Cannot encode data type " + Types::dataTypeToString(typeId)); + throw RuntimeException("Cannot encode data with physical type: " + + common::PhysicalTypeUtils::physicalTypeToString(physicalType)); } } } diff --git a/src/processor/operator/recursive_extend/recursive_join.cpp b/src/processor/operator/recursive_extend/recursive_join.cpp index 463d577bca..10dcd25df7 100644 --- a/src/processor/operator/recursive_extend/recursive_join.cpp +++ b/src/processor/operator/recursive_extend/recursive_join.cpp @@ -102,9 +102,11 @@ static std::unique_ptr populateResultSetWithTwoDataChunks() { auto resultSet = std::make_unique(2); auto dataChunk0 = std::make_shared(1); dataChunk0->state = common::DataChunkState::getSingleValueDataChunkState(); - dataChunk0->insert(0, std::make_shared(common::INTERNAL_ID, nullptr)); + dataChunk0->insert( + 0, std::make_shared(common::LogicalTypeID::INTERNAL_ID, nullptr)); auto dataChunk1 = std::make_shared(1); - dataChunk1->insert(0, std::make_shared(common::INTERNAL_ID, nullptr)); + dataChunk1->insert( + 0, std::make_shared(common::LogicalTypeID::INTERNAL_ID, nullptr)); resultSet->insert(0, std::move(dataChunk0)); resultSet->insert(1, std::move(dataChunk1)); return resultSet; @@ -115,8 +117,10 @@ static std::unique_ptr populateResultSetWithOneDataChunk() { auto resultSet = std::make_unique(1); auto dataChunk0 = std::make_shared(2); dataChunk0->state = common::DataChunkState::getSingleValueDataChunkState(); - dataChunk0->insert(0, std::make_shared(common::INTERNAL_ID, nullptr)); - dataChunk0->insert(1, std::make_shared(common::INTERNAL_ID, nullptr)); + dataChunk0->insert( + 0, std::make_shared(common::LogicalTypeID::INTERNAL_ID, nullptr)); + dataChunk0->insert( + 1, std::make_shared(common::LogicalTypeID::INTERNAL_ID, nullptr)); resultSet->insert(0, std::move(dataChunk0)); return resultSet; } diff --git a/src/processor/operator/result_collector.cpp b/src/processor/operator/result_collector.cpp index 2fa9731e33..278d4b7280 100644 --- a/src/processor/operator/result_collector.cpp +++ b/src/processor/operator/result_collector.cpp @@ -54,7 +54,7 @@ std::unique_ptr ResultCollector::populateTableSchema() { auto [dataPos, dataType] = payloadsPosAndType[i]; tableSchema->appendColumn( std::make_unique(!isPayloadFlat[i], dataPos.dataChunkPos, - isPayloadFlat[i] ? Types::getDataTypeSize(dataType) : + isPayloadFlat[i] ? FactorizedTable::getDataTypeSize(dataType) : (uint32_t)sizeof(overflow_value_t))); } return tableSchema; diff --git a/src/processor/operator/semi_masker.cpp b/src/processor/operator/semi_masker.cpp index 9f690c28c0..fe7004fcc6 100644 --- a/src/processor/operator/semi_masker.cpp +++ b/src/processor/operator/semi_masker.cpp @@ -18,7 +18,7 @@ void BaseSemiMasker::initGlobalStateInternal(ExecutionContext* context) { void BaseSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) { keyValueVector = resultSet->getValueVector(keyDataPos); - assert(keyValueVector->dataType.typeID == INTERNAL_ID); + assert(keyValueVector->dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); for (auto& [table, masks] : masksPerTable) { for (auto& maskWithIdx : masks) { maskWithIdx.first->init(transaction); diff --git a/src/processor/operator/update/create.cpp b/src/processor/operator/update/create.cpp index f920797784..feeaae1d49 100644 --- a/src/processor/operator/update/create.cpp +++ b/src/processor/operator/update/create.cpp @@ -60,7 +60,7 @@ bool CreateRel::getNextTuplesInternal(ExecutionContext* context) { // Rel ID is our interval property, so we overwrite relID=$expr with system ID. if (j == createRelInfo->relIDEvaluatorIdx) { auto relIDVector = evaluator->resultVector; - assert(relIDVector->dataType.typeID == INTERNAL_ID && + assert(relIDVector->dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID && relIDVector->state->selVector->selectedPositions[0] == 0); relIDVector->setValue(0, relsStatistics.getNextRelOffset( transaction, createRelInfo->table->getRelTableID())); diff --git a/src/processor/processor.cpp b/src/processor/processor.cpp index dcc5b5458a..ea73ac95b1 100644 --- a/src/processor/processor.cpp +++ b/src/processor/processor.cpp @@ -86,11 +86,12 @@ void QueryProcessor::decomposePlanIntoTasks( std::shared_ptr QueryProcessor::getFactorizedTableForOutputMsg( std::string& outputMsg, MemoryManager* memoryManager) { auto ftTableSchema = std::make_unique(); - ftTableSchema->appendColumn(std::make_unique( - false /* flat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); + ftTableSchema->appendColumn( + std::make_unique(false /* flat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::STRING}))); auto factorizedTable = std::make_shared(memoryManager, std::move(ftTableSchema)); - auto outputMsgVector = std::make_shared(STRING, memoryManager); + auto outputMsgVector = std::make_shared(LogicalTypeID::STRING, memoryManager); auto outputMsgChunk = std::make_shared(1 /* numValueVectors */); outputMsgChunk->insert(0 /* pos */, outputMsgVector); ku_string_t outputKUStr = ku_string_t(); diff --git a/src/processor/result/factorized_table.cpp b/src/processor/result/factorized_table.cpp index fc07d6d988..29057c9363 100644 --- a/src/processor/result/factorized_table.cpp +++ b/src/processor/result/factorized_table.cpp @@ -282,11 +282,12 @@ void FactorizedTable::setNonOverflowColNull(uint8_t* nullBuffer, ft_col_idx_t co void FactorizedTable::copyToInMemList(ft_col_idx_t colIdx, std::vector& tupleIdxesToRead, uint8_t* data, NullMask* nullMask, uint64_t startElemPosInList, DiskOverflowFile* overflowFileOfInMemList, - const DataType& type) const { + const LogicalType& type) const { auto column = tableSchema->getColumn(colIdx); assert(column->isFlat() == true); - auto numBytesPerValue = - type.typeID == INTERNAL_ID ? sizeof(offset_t) : Types::getDataTypeSize(type); + auto numBytesPerValue = type.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID ? + sizeof(offset_t) : + getDataTypeSize(type); auto colOffset = tableSchema->getColOffset(colIdx); auto listToFill = data + startElemPosInList * numBytesPerValue; for (auto i = 0u; i < tupleIdxesToRead.size(); i++) { @@ -330,6 +331,33 @@ int64_t FactorizedTable::findValueInFlatColumn(ft_col_idx_t colIdx, int64_t valu return -1; } +uint32_t FactorizedTable::getDataTypeSize(const common::LogicalType& type) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::STRING: { + return sizeof(ku_string_t); + } + case LogicalTypeID::FIXED_LIST: { + return getDataTypeSize(*FixedListType::getChildType(&type)) * + FixedListType::getNumElementsInList(&type); + } + case LogicalTypeID::VAR_LIST: { + return sizeof(ku_list_t); + } + case LogicalTypeID::STRUCT: { + uint32_t size = 0; + auto structFieldsTypes = StructType::getStructFieldTypes(&type); + for (auto structFieldType : structFieldsTypes) { + size += getDataTypeSize(*structFieldType); + } + size += NullBuffer::getNumBytesForNullValues(structFieldsTypes.size()); + return size; + } + default: { + return LogicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); + } + } +} + void FactorizedTable::clear() { numTuples = 0; flatTupleBlockCollection = std::make_unique( @@ -350,11 +378,11 @@ uint64_t FactorizedTable::computeNumTuplesToAppend( auto unflatDataChunkPos = -1ul; auto numTuplesToAppend = 1ul; for (auto i = 0u; i < vectorsToAppend.size(); i++) { - // If the caller tries to append an unflat vector to a flat column in the factorizedTable, - // the factorizedTable needs to flatten that vector. + // If the caller tries to append an unflat vector to a flat column in the + // factorizedTable, the factorizedTable needs to flatten that vector. if (tableSchema->getColumn(i)->isFlat() && !vectorsToAppend[i]->state->isFlat()) { - // The caller is not allowed to append multiple unflat columns from different datachunks - // to multiple flat columns in the factorizedTable. + // The caller is not allowed to append multiple unflat columns from different + // datachunks to multiple flat columns in the factorizedTable. if (unflatDataChunkPos != -1 && tableSchema->getColumn(i)->getDataChunkPos() != unflatDataChunkPos) { assert(false); @@ -497,7 +525,7 @@ overflow_value_t FactorizedTable::appendVectorToUnflatTupleBlocks( const ValueVector& vector, ft_col_idx_t colIdx) { assert(!vector.state->isFlat()); auto numFlatTuplesInVector = vector.state->selVector->selectedSize; - auto numBytesPerValue = Types::getDataTypeSize(vector.dataType); + auto numBytesPerValue = getDataTypeSize(vector.dataType); auto numBytesForData = numBytesPerValue * numFlatTuplesInVector; auto overflowBlockBuffer = allocateUnflatTupleBlock( numBytesForData + NullBuffer::getNumBytesForNullValues(numFlatTuplesInVector)); @@ -645,16 +673,16 @@ void FactorizedTable::readFlatColToUnflatVector(uint8_t** tuplesToRead, ft_col_i } void FactorizedTable::copyOverflowIfNecessary( - uint8_t* dst, uint8_t* src, const DataType& type, DiskOverflowFile* diskOverflowFile) { - switch (type.typeID) { - case STRING: { + uint8_t* dst, uint8_t* src, const LogicalType& type, DiskOverflowFile* diskOverflowFile) { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::STRING: { ku_string_t* stringToWriteFrom = (ku_string_t*)src; if (!ku_string_t::isShortString(stringToWriteFrom->len)) { diskOverflowFile->writeStringOverflowAndUpdateOverflowPtr( *stringToWriteFrom, *(ku_string_t*)dst); } } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { diskOverflowFile->writeListOverflowAndUpdateOverflowPtr( *(ku_list_t*)src, *(ku_list_t*)dst, type); } break; @@ -707,7 +735,8 @@ void FlatTupleIterator::readUnflatColToFlatTuple(ft_col_idx_t colIdx, uint8_t* v auto overflowValue = (overflow_value_t*)(valueBuffer + factorizedTable.getTableSchema()->getColOffset(colIdx)); auto columnInFactorizedTable = factorizedTable.getTableSchema()->getColumn(colIdx); - auto tupleSizeInOverflowBuffer = Types::getDataTypeSize(values[colIdx]->getDataType()); + auto tupleSizeInOverflowBuffer = + FactorizedTable::getDataTypeSize(values[colIdx]->getDataType()); valueBuffer = overflowValue->value + tupleSizeInOverflowBuffer * diff --git a/src/storage/copier/node_copier.cpp b/src/storage/copier/node_copier.cpp index 3263a71d33..5d319178dd 100644 --- a/src/storage/copier/node_copier.cpp +++ b/src/storage/copier/node_copier.cpp @@ -64,7 +64,7 @@ NodeCopier::NodeCopier(const std::string& directory, pkColumnID} { for (auto i = 0u; i < schema->properties.size(); i++) { auto property = schema->properties[i]; - if (property.dataType.typeID == common::SERIAL) { + if (property.dataType.getLogicalTypeID() == common::LogicalTypeID::SERIAL) { // Skip SERIAL, as it is not physically stored. continue; } @@ -100,11 +100,11 @@ void NodeCopier::populatePKIndex(InMemColumnChunk* chunk, InMemOverflowFile* ove } } // No nulls, so we can populate the index with actual values. - switch (chunk->getDataType().typeID) { - case INT64: { + switch (chunk->getDataType().getLogicalTypeID()) { + case LogicalTypeID::INT64: { appendToPKIndex(chunk, startOffset, numValues); } break; - case STRING: { + case LogicalTypeID::STRING: { appendToPKIndex( chunk, startOffset, numValues, overflowFile); } break; diff --git a/src/storage/copier/node_copy_executor.cpp b/src/storage/copier/node_copy_executor.cpp index 5358490a9b..a5c12914fd 100644 --- a/src/storage/copier/node_copy_executor.cpp +++ b/src/storage/copier/node_copy_executor.cpp @@ -30,7 +30,7 @@ void NodeCopyExecutor::populateColumns(processor::ExecutionContext* executionCon logger->info("Populating properties"); auto primaryKey = reinterpret_cast(tableSchema)->getPrimaryKey(); std::unique_ptr pkIndex; - if (primaryKey.dataType.typeID != common::SERIAL) { + if (primaryKey.dataType.getLogicalTypeID() != common::LogicalTypeID::SERIAL) { pkIndex = std::make_unique( StorageUtils::getNodeIndexFName( this->outputDirectory, tableSchema->tableID, common::DBFileType::WAL_VERSION), diff --git a/src/storage/copier/npy_reader.cpp b/src/storage/copier/npy_reader.cpp index 0b79035e40..d67cfbdee8 100644 --- a/src/storage/copier/npy_reader.cpp +++ b/src/storage/copier/npy_reader.cpp @@ -9,6 +9,7 @@ #include "common/string_utils.h" #include "common/utils.h" #include "pyparse.h" +#include "storage/storage_utils.h" using namespace kuzu::common; @@ -47,8 +48,9 @@ uint8_t* NpyReader::getPointerToRow(size_t row) const { if (row >= getNumRows()) { return nullptr; } - return (uint8_t*)((char*)mmapRegion + dataOffset + - row * getNumElementsPerRow() * common::Types::getDataTypeSize(type)); + return ( + uint8_t*)((char*)mmapRegion + dataOffset + + row * getNumElementsPerRow() * StorageUtils::getDataTypeSize(LogicalType{type})); } void NpyReader::parseHeader() { @@ -119,21 +121,21 @@ void NpyReader::parseType(std::string descr) { descr = descr.substr(1); } if (descr == "f8") { - type = DOUBLE; + type = LogicalTypeID::DOUBLE; } else if (descr == "f4") { - type = FLOAT; + type = LogicalTypeID::FLOAT; } else if (descr == "i8") { - type = INT64; + type = LogicalTypeID::INT64; } else if (descr == "i4") { - type = INT32; + type = LogicalTypeID::INT32; } else if (descr == "i2") { - type = INT16; + type = LogicalTypeID::INT16; } else { throw CopyException("Unsupported data type: " + descr); } } -void NpyReader::validate(DataType& type_, offset_t numRows, const std::string& tableName) { +void NpyReader::validate(LogicalType& type_, offset_t numRows, const std::string& tableName) { auto numNodesInFile = getNumRows(); if (numNodesInFile == 0) { throw CopyException( @@ -143,7 +145,7 @@ void NpyReader::validate(DataType& type_, offset_t numRows, const std::string& t throw CopyException("Number of rows in npy files is not equal to each other."); } // TODO(Guodong): Set npy reader data type to FIXED_LIST, so we can simplify checks here. - if (type_.typeID == this->type) { + if (type_.getLogicalTypeID() == this->type) { if (getNumElementsPerRow() != 1) { throw CopyException( StringUtils::string_format("Cannot copy a vector property in npy file {} to a " @@ -151,14 +153,13 @@ void NpyReader::validate(DataType& type_, offset_t numRows, const std::string& t filePath, tableName)); } return; - } else if (type_.typeID == DataTypeID::FIXED_LIST) { - if (this->type != type_.getChildType()->typeID) { + } else if (type_.getLogicalTypeID() == LogicalTypeID::FIXED_LIST) { + if (this->type != FixedListType::getChildType(&type_)->getLogicalTypeID()) { throw CopyException(StringUtils::string_format("The type of npy file {} does not " "match the type defined in table {}.", filePath, tableName)); } - auto fixedListInfo = reinterpret_cast(type_.getExtraTypeInfo()); - if (getNumElementsPerRow() != fixedListInfo->getFixedNumElementsInList()) { + if (getNumElementsPerRow() != FixedListType::getNumElementsInList(&type_)) { throw CopyException( StringUtils::string_format("The shape of {} does not match the length of the " "fixed list property in table " diff --git a/src/storage/copier/rel_copy_executor.cpp b/src/storage/copier/rel_copy_executor.cpp index e132c5104f..bf2d36dd5f 100644 --- a/src/storage/copier/rel_copy_executor.cpp +++ b/src/storage/copier/rel_copy_executor.cpp @@ -50,7 +50,8 @@ void RelCopyExecutor::initializeColumnsAndLists() { } } for (auto& property : tableSchema->properties) { - if (property.dataType.typeID == VAR_LIST || property.dataType.typeID == STRING) { + if (property.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST || + property.dataType.getLogicalTypeID() == LogicalTypeID::STRING) { overflowFilePerPropertyID[property.propertyID] = std::make_unique(); } } @@ -96,9 +97,9 @@ void RelCopyExecutor::initializeColumns(RelDataDirection relDirection) { adjColumnsPerDirection[relDirection] = std::make_unique( StorageUtils::getAdjColumnFName( outputDirectory, tableSchema->tableID, relDirection, DBFileType::WAL_VERSION), - DataType(INTERNAL_ID)); - adjColumnChunksPerDirection[relDirection] = - std::make_unique(DataType(INTERNAL_ID), 0, numNodes - 1); + LogicalType(LogicalTypeID::INTERNAL_ID)); + adjColumnChunksPerDirection[relDirection] = std::make_unique( + LogicalType(LogicalTypeID::INTERNAL_ID), 0, numNodes - 1); std::unordered_map> propertyColumns; std::unordered_map> propertyColumnChunks; for (auto i = 0u; i < tableSchema->getNumProperties(); ++i) { @@ -172,7 +173,7 @@ void RelCopyExecutor::initListsMetadata() { for (auto& property : tableSchema->properties) { taskScheduler.scheduleTask(CopyTaskFactory::createCopyTask( calculateListsMetadataAndAllocateInMemListPagesTask, numNodes, - Types::getDataTypeSize(property.dataType), listSizes, + storage::StorageUtils::getDataTypeSize(property.dataType), listSizes, adjLists->getListHeadersBuilder().get(), propertyListsPerDirection[relDirection][property.propertyID].get(), true /*hasNULLBytes*/, logger)); @@ -292,7 +293,8 @@ void RelCopyExecutor::sortAndCopyOverflowValues() { auto numBuckets = numNodes / 256; numBuckets += (numNodes % 256 != 0); for (auto& property : tableSchema->properties) { - if (property.dataType.typeID == STRING || property.dataType.typeID == VAR_LIST) { + if (property.dataType.getLogicalTypeID() == LogicalTypeID::STRING || + property.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST) { offset_t offsetStart = 0, offsetEnd = 0; for (auto bucketIdx = 0u; bucketIdx < numBuckets; bucketIdx++) { offsetStart = offsetEnd; @@ -322,7 +324,8 @@ void RelCopyExecutor::sortAndCopyOverflowValues() { numBuckets += (numNodes % 256 != 0); // TODO(Semih): Schedule one at a time. for (auto& property : tableSchema->properties) { - if (property.dataType.typeID == STRING || property.dataType.typeID == VAR_LIST) { + if (property.dataType.getLogicalTypeID() == LogicalTypeID::STRING || + property.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST) { offset_t offsetStart = 0, offsetEnd = 0; for (auto bucketIdx = 0u; bucketIdx < numBuckets; bucketIdx++) { offsetStart = offsetEnd; @@ -348,7 +351,7 @@ void RelCopyExecutor::sortAndCopyOverflowValues() { template void RelCopyExecutor::inferTableIDsAndOffsets(const std::vector>& batchColumns, - std::vector& nodeIDs, std::vector& nodeIDTypes, + std::vector& nodeIDs, std::vector& nodeIDTypes, const std::map& pkIndexes, Transaction* transaction, int64_t blockOffset, int64_t& colIndex) { for (auto& relDirection : REL_DIRECTIONS) { @@ -358,19 +361,19 @@ void RelCopyExecutor::inferTableIDsAndOffsets(const std::vectorGetScalar(blockOffset)->get()->ToString(); auto keyStr = keyToken.c_str(); ++colIndex; - switch (nodeIDTypes[relDirection].typeID) { - case SERIAL: { + switch (nodeIDTypes[relDirection].getLogicalTypeID()) { + case LogicalTypeID::SERIAL: { auto key = TypeUtils::convertStringToNumber(keyStr); nodeIDs[relDirection].offset = key; } break; - case INT64: { + case LogicalTypeID::INT64: { auto key = TypeUtils::convertStringToNumber(keyStr); if (!pkIndexes.at(nodeIDs[relDirection].tableID) ->lookup(transaction, key, nodeIDs[relDirection].offset)) { throw CopyException("Cannot find key: " + std::to_string(key) + " in the pkIndex."); } } break; - case STRING: { + case LogicalTypeID::STRING: { if (!pkIndexes.at(nodeIDs[relDirection].tableID) ->lookup(transaction, keyStr, nodeIDs[relDirection].offset)) { throw CopyException("Cannot find key: " + std::string(keyStr) + " in the pkIndex."); @@ -378,7 +381,7 @@ void RelCopyExecutor::inferTableIDsAndOffsets(const std::vectorget()->ToString().substr(0, BufferPoolConstants::PAGE_4KB_SIZE); const char* data = stringToken.c_str(); - switch (properties[propertyID].dataType.typeID) { - case INT64: { + switch (properties[propertyID].dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case INT32: { + case LogicalTypeID::INT32: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case INT16: { + case LogicalTypeID::INT16: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case BOOL: { + case LogicalTypeID::FLOAT: { + auto val = TypeUtils::convertStringToNumber(data); + putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, + reinterpret_cast(&val)); + } break; + case LogicalTypeID::BOOL: { auto val = TypeUtils::convertToBoolean(data); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case DATE: { + case LogicalTypeID::DATE: { auto val = Date::FromCString(data, stringToken.length()); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { auto val = Timestamp::FromCString(data, stringToken.length()); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { auto val = Interval::FromCString(data, stringToken.length()); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&val)); } break; - case STRING: { + case LogicalTypeID::STRING: { auto kuStr = inMemOverflowFilePerPropertyID[propertyID]->copyString( data, strlen(data), inMemOverflowFileCursors[propertyID]); putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&kuStr)); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { auto varListVal = getArrowVarList(stringToken, 1, stringToken.length() - 2, properties[propertyID].dataType, copier->copyDescription); auto kuList = inMemOverflowFilePerPropertyID[propertyID]->copyList( @@ -460,22 +468,17 @@ void RelCopyExecutor::putPropsOfLineIntoColumns(RelCopyExecutor* copier, putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, reinterpret_cast(&kuList)); } break; - case FIXED_LIST: { + case LogicalTypeID::FIXED_LIST: { auto fixedListVal = getArrowFixedList(stringToken, 1, stringToken.length() - 2, properties[propertyID].dataType, copier->copyDescription); putValueIntoColumns( propertyID, directionTablePropertyColumnChunks, nodeIDs, fixedListVal.get()); } break; - case FLOAT: { - auto val = TypeUtils::convertStringToNumber(data); - putValueIntoColumns(propertyID, directionTablePropertyColumnChunks, nodeIDs, - reinterpret_cast(&val)); - } break; default: - throw NotImplementedException( - "Not supported data type " + - Types::dataTypeToString(properties[propertyID].dataType.typeID) + - " for RelCopyExecutor::putPropsOfLineIntoColumns."); + throw NotImplementedException("Not supported data type " + + LogicalTypeUtils::dataTypeToString( + properties[propertyID].dataType.getLogicalTypeID()) + + " for RelCopyExecutor::putPropsOfLineIntoColumns."); } ++colIndex; } @@ -504,54 +507,54 @@ void RelCopyExecutor::putPropsOfLineIntoLists(RelCopyExecutor* copier, auto stringToken = currentToken->get()->ToString().substr(0, BufferPoolConstants::PAGE_4KB_SIZE); const char* data = stringToken.c_str(); - switch (properties[propertyID].dataType.typeID) { - case INT64: { + switch (properties[propertyID].dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case INT32: { + case LogicalTypeID::INT32: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case INT16: { + case LogicalTypeID::INT16: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case BOOL: { + case LogicalTypeID::BOOL: { auto val = TypeUtils::convertToBoolean(data); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case DATE: { + case LogicalTypeID::DATE: { auto val = Date::FromCString(data, stringToken.length()); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { auto val = Timestamp::FromCString(data, stringToken.length()); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { auto val = Interval::FromCString(data, stringToken.length()); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; - case STRING: { + case LogicalTypeID::STRING: { auto kuStr = inMemOverflowFilesPerProperty[propertyID]->copyString( data, strlen(data), inMemOverflowFileCursors[propertyID]); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&kuStr)); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { auto varListVal = getArrowVarList(stringToken, 1, stringToken.length() - 2, properties[propertyID].dataType, copyDescription); auto kuList = inMemOverflowFilesPerProperty[propertyID]->copyList( @@ -559,22 +562,22 @@ void RelCopyExecutor::putPropsOfLineIntoLists(RelCopyExecutor* copier, putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&kuList)); } break; - case FIXED_LIST: { + case LogicalTypeID::FIXED_LIST: { auto fixedListVal = getArrowFixedList(stringToken, 1, stringToken.length() - 2, properties[propertyID].dataType, copyDescription); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, fixedListVal.get()); } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { auto val = TypeUtils::convertStringToNumber(data); putValueIntoLists(propertyID, directionTablePropertyLists, directionTableAdjLists, nodeIDs, reversePos, reinterpret_cast(&val)); } break; default: - throw NotImplementedException( - "Not supported data type " + - Types::dataTypeToString(properties[propertyID].dataType.typeID) + - " for RelCopyExecutor::putPropsOfLineIntoLists."); + throw NotImplementedException("Not supported data type " + + LogicalTypeUtils::dataTypeToString( + properties[propertyID].dataType.getLogicalTypeID()) + + " for RelCopyExecutor::putPropsOfLineIntoLists."); } ++colIndex; } @@ -594,13 +597,13 @@ void RelCopyExecutor::copyStringOverflowFromUnorderedToOrderedPages(ku_string_t* } void RelCopyExecutor::copyListOverflowFromUnorderedToOrderedPages(ku_list_t* kuList, - const DataType& dataType, PageByteCursor& unorderedOverflowCursor, + const LogicalType& dataType, PageByteCursor& unorderedOverflowCursor, PageByteCursor& orderedOverflowCursor, InMemOverflowFile* unorderedOverflowFile, InMemOverflowFile* orderedOverflowFile) { TypeUtils::decodeOverflowPtr( kuList->overflowPtr, unorderedOverflowCursor.pageIdx, unorderedOverflowCursor.offsetInPage); orderedOverflowFile->copyListOverflowFromFile(unorderedOverflowFile, unorderedOverflowCursor, - orderedOverflowCursor, kuList, dataType.getChildType()); + orderedOverflowCursor, kuList, VarListType::getChildType(&dataType)); } template @@ -610,7 +613,7 @@ void RelCopyExecutor::populateAdjColumnsAndCountRelsInAdjListsTask(uint64_t bloc copier->logger->debug("Start: path=`{0}` blkIdx={1}", filePath, blockIdx); std::vector requireToReadTableLabels{true, true}; std::vector nodeIDs{2}; - std::vector nodePKTypes{2}; + std::vector nodePKTypes{2}; auto relTableSchema = reinterpret_cast(copier->tableSchema); for (auto& relDirection : REL_DIRECTIONS) { auto boundTableID = relTableSchema->getBoundTableID(relDirection); @@ -665,7 +668,7 @@ void RelCopyExecutor::populateListsTask(uint64_t blockId, uint64_t blockStartRel const std::string& filePath) { copier->logger->trace("Start: path=`{0}` blkIdx={1}", filePath, blockId); std::vector nodeIDs(2); - std::vector nodePKTypes(2); + std::vector nodePKTypes(2); std::vector reversePos(2); auto relTableSchema = reinterpret_cast(copier->tableSchema); for (auto relDirection : REL_DIRECTIONS) { @@ -706,16 +709,16 @@ void RelCopyExecutor::populateListsTask(uint64_t blockId, uint64_t blockStartRel copier->logger->trace("End: path=`{0}` blkIdx={1}", filePath, blockId); } -void RelCopyExecutor::sortOverflowValuesOfPropertyColumnTask(const DataType& dataType, +void RelCopyExecutor::sortOverflowValuesOfPropertyColumnTask(const LogicalType& dataType, offset_t offsetStart, offset_t offsetEnd, InMemColumnChunk* propertyColumnChunk, InMemOverflowFile* unorderedInMemOverflowFile, InMemOverflowFile* orderedInMemOverflowFile) { PageByteCursor unorderedOverflowCursor, orderedOverflowCursor; for (; offsetStart < offsetEnd; offsetStart++) { - if (dataType.typeID == STRING) { + if (dataType.getLogicalTypeID() == LogicalTypeID::STRING) { auto kuStr = propertyColumnChunk->getValue(offsetStart); copyStringOverflowFromUnorderedToOrderedPages(&kuStr, unorderedOverflowCursor, orderedOverflowCursor, unorderedInMemOverflowFile, orderedInMemOverflowFile); - } else if (dataType.typeID == VAR_LIST) { + } else if (dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST) { auto kuList = propertyColumnChunk->getValue(offsetStart); copyListOverflowFromUnorderedToOrderedPages(&kuList, dataType, unorderedOverflowCursor, orderedOverflowCursor, unorderedInMemOverflowFile, orderedInMemOverflowFile); @@ -725,7 +728,7 @@ void RelCopyExecutor::sortOverflowValuesOfPropertyColumnTask(const DataType& dat } } -void RelCopyExecutor::sortOverflowValuesOfPropertyListsTask(const DataType& dataType, +void RelCopyExecutor::sortOverflowValuesOfPropertyListsTask(const LogicalType& dataType, offset_t offsetStart, offset_t offsetEnd, InMemAdjLists* adjLists, InMemLists* propertyLists, InMemOverflowFile* unorderedInMemOverflowFile, InMemOverflowFile* orderedInMemOverflowFile) { PageByteCursor unorderedOverflowCursor, orderedOverflowCursor; @@ -733,14 +736,15 @@ void RelCopyExecutor::sortOverflowValuesOfPropertyListsTask(const DataType& data for (; offsetStart < offsetEnd; offsetStart++) { csr_offset_t listsLen = adjLists->getListSize(offsetStart); for (auto pos = listsLen; pos > 0; pos--) { - propertyListCursor = propertyLists->calcPageElementCursor( - pos, Types::getDataTypeSize(dataType), offsetStart, true /*hasNULLBytes*/); - if (dataType.typeID == STRING) { + propertyListCursor = propertyLists->calcPageElementCursor(pos, + storage::StorageUtils::getDataTypeSize(dataType), offsetStart, + true /*hasNULLBytes*/); + if (dataType.getLogicalTypeID() == LogicalTypeID::STRING) { auto kuStr = reinterpret_cast(propertyLists->getMemPtrToLoc( propertyListCursor.pageIdx, propertyListCursor.elemPosInPage)); copyStringOverflowFromUnorderedToOrderedPages(kuStr, unorderedOverflowCursor, orderedOverflowCursor, unorderedInMemOverflowFile, orderedInMemOverflowFile); - } else if (dataType.typeID == VAR_LIST) { + } else if (dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST) { auto kuList = reinterpret_cast(propertyLists->getMemPtrToLoc( propertyListCursor.pageIdx, propertyListCursor.elemPosInPage)); copyListOverflowFromUnorderedToOrderedPages(kuList, dataType, diff --git a/src/storage/copier/table_copy_executor.cpp b/src/storage/copier/table_copy_executor.cpp index c4984c7cf8..5c770cb116 100644 --- a/src/storage/copier/table_copy_executor.cpp +++ b/src/storage/copier/table_copy_executor.cpp @@ -65,7 +65,7 @@ void TableCopyExecutor::countNumLinesCSV(const std::vector& filePat auto startNodeOffset = numRows; while (true) { throwCopyExceptionIfNotOK(csvStreamingReader->ReadNext(&currBatch)); - if (currBatch == NULL) { + if (currBatch == nullptr) { break; } ++numBlocks; @@ -120,7 +120,8 @@ void TableCopyExecutor::countNumLinesNpy(const std::vector& filePat } static bool skipCopyForProperty(const Property& property) { - return TableSchema::isReservedPropertyName(property.name) || property.dataType.typeID == SERIAL; + return TableSchema::isReservedPropertyName(property.name) || + property.dataType.getLogicalTypeID() == LogicalTypeID::SERIAL; } std::shared_ptr TableCopyExecutor::createCSVReader( @@ -205,79 +206,83 @@ std::vector> TableCopyExecutor::getListElementPos( } std::unique_ptr TableCopyExecutor::getArrowVarList(const std::string& l, int64_t from, - int64_t to, const DataType& dataType, const CopyDescription& copyDescription) { - assert(dataType.typeID == common::VAR_LIST || dataType.typeID == common::FIXED_LIST); + int64_t to, const LogicalType& dataType, const CopyDescription& copyDescription) { + assert(dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST); auto split = getListElementPos(l, from, to, copyDescription); std::vector> values; - auto childDataType = *dataType.getChildType(); + auto childDataType = VarListType::getChildType(&dataType); for (auto pair : split) { std::string element = l.substr(pair.first, pair.second); if (element.empty()) { continue; } - auto value = convertStringToValue(element, *dataType.getChildType(), copyDescription); + auto value = convertStringToValue(element, *childDataType, copyDescription); values.push_back(std::move(value)); } - auto numBytesOfOverflow = values.size() * Types::getDataTypeSize(childDataType.typeID); + auto numBytesOfOverflow = + values.size() * storage::StorageUtils::getDataTypeSize(*childDataType); if (numBytesOfOverflow >= BufferPoolConstants::PAGE_4KB_SIZE) { throw CopyException(StringUtils::string_format( "Maximum num bytes of a LIST is {}. Input list's num bytes is {}.", BufferPoolConstants::PAGE_4KB_SIZE, numBytesOfOverflow)); } return make_unique( - DataType(std::make_unique(childDataType)), std::move(values)); + LogicalType(common::LogicalTypeID::VAR_LIST, + std::make_unique(std::make_unique(*childDataType))), + std::move(values)); } std::unique_ptr TableCopyExecutor::getArrowFixedList(const std::string& l, int64_t from, - int64_t to, const DataType& dataType, const CopyDescription& copyDescription) { - assert(dataType.typeID == common::FIXED_LIST); + int64_t to, const LogicalType& dataType, const CopyDescription& copyDescription) { + assert(dataType.getLogicalTypeID() == common::LogicalTypeID::FIXED_LIST); auto split = getListElementPos(l, from, to, copyDescription); - auto listVal = std::make_unique(Types::getDataTypeSize(dataType)); - auto childDataType = *dataType.getChildType(); + auto listVal = std::make_unique(storage::StorageUtils::getDataTypeSize(dataType)); + auto childDataType = FixedListType::getChildType(&dataType); uint64_t numElementsRead = 0; for (auto pair : split) { std::string element = l.substr(pair.first, pair.second); if (element.empty()) { continue; } - switch (childDataType.typeID) { - case INT64: { + switch (childDataType->getLogicalTypeID()) { + case LogicalTypeID::INT64: { auto val = TypeUtils::convertStringToNumber(element.c_str()); memcpy(listVal.get() + numElementsRead * sizeof(int64_t), &val, sizeof(int64_t)); numElementsRead++; } break; - case INT32: { + case LogicalTypeID::INT32: { auto val = TypeUtils::convertStringToNumber(element.c_str()); memcpy(listVal.get() + numElementsRead * sizeof(int32_t), &val, sizeof(int32_t)); numElementsRead++; } break; - case INT16: { + case LogicalTypeID::INT16: { auto val = TypeUtils::convertStringToNumber(element.c_str()); memcpy(listVal.get() + numElementsRead * sizeof(int16_t), &val, sizeof(int16_t)); numElementsRead++; } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { auto val = TypeUtils::convertStringToNumber(element.c_str()); memcpy(listVal.get() + numElementsRead * sizeof(double_t), &val, sizeof(double_t)); numElementsRead++; } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { auto val = TypeUtils::convertStringToNumber(element.c_str()); memcpy(listVal.get() + numElementsRead * sizeof(float_t), &val, sizeof(float_t)); numElementsRead++; } break; default: { - throw CopyException("Unsupported data type " + - Types::dataTypeToString(dataType.getChildType()->typeID) + - " inside FIXED_LIST"); + throw CopyException( + "Unsupported data type " + + LogicalTypeUtils::dataTypeToString(*VarListType::getChildType(&dataType)) + + " inside FIXED_LIST"); } } } - auto extraTypeInfo = reinterpret_cast(dataType.getExtraTypeInfo()); - if (numElementsRead != extraTypeInfo->getFixedNumElementsInList()) { + auto numElementsInList = FixedListType::getNumElementsInList(&dataType); + if (numElementsRead != numElementsInList) { throw CopyException(StringUtils::string_format( "Each fixed list should have fixed number of elements. Expected: {}, Actual: {}.", - extraTypeInfo->getFixedNumElementsInList(), numElementsRead)); + numElementsInList, numElementsRead)); } return listVal; } @@ -289,87 +294,87 @@ void TableCopyExecutor::throwCopyExceptionIfNotOK(const arrow::Status& status) { } std::shared_ptr TableCopyExecutor::toArrowDataType( - const common::DataType& dataType) { - switch (dataType.typeID) { - case common::BOOL: { + const common::LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case common::LogicalTypeID::BOOL: { return arrow::boolean(); } - case common::INT64: { + case common::LogicalTypeID::INT64: { return arrow::int64(); } - case common::INT32: { + case common::LogicalTypeID::INT32: { return arrow::int32(); } - case common::INT16: { + case common::LogicalTypeID::INT16: { return arrow::int16(); } - case common::DOUBLE: { + case common::LogicalTypeID::DOUBLE: { return arrow::float64(); } - case common::FLOAT: { + case common::LogicalTypeID::FLOAT: { return arrow::float32(); } - case common::TIMESTAMP: - case common::DATE: - case common::INTERVAL: - case common::FIXED_LIST: - case common::VAR_LIST: - case common::STRING: - case common::STRUCT: { + case common::LogicalTypeID::TIMESTAMP: + case common::LogicalTypeID::DATE: + case common::LogicalTypeID::INTERVAL: + case common::LogicalTypeID::FIXED_LIST: + case common::LogicalTypeID::VAR_LIST: + case common::LogicalTypeID::STRING: + case common::LogicalTypeID::STRUCT: { return arrow::utf8(); } default: { - throw CopyException( - "Unsupported data type for CSV " + Types::dataTypeToString(dataType.typeID)); + throw CopyException("Unsupported data type for CSV " + + LogicalTypeUtils::dataTypeToString(dataType.getLogicalTypeID())); } } } std::unique_ptr TableCopyExecutor::convertStringToValue( - std::string element, const DataType& type, const CopyDescription& copyDescription) { + std::string element, const LogicalType& type, const CopyDescription& copyDescription) { std::unique_ptr value; - switch (type.typeID) { - case INT64: { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::INT64: { value = std::make_unique(TypeUtils::convertStringToNumber(element.c_str())); } break; - case INT32: { + case LogicalTypeID::INT32: { value = std::make_unique(TypeUtils::convertStringToNumber(element.c_str())); } break; - case INT16: { + case LogicalTypeID::INT16: { value = std::make_unique(TypeUtils::convertStringToNumber(element.c_str())); } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { value = std::make_unique(TypeUtils::convertStringToNumber(element.c_str())); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { value = std::make_unique(TypeUtils::convertStringToNumber(element.c_str())); } break; - case BOOL: { + case LogicalTypeID::BOOL: { transform(element.begin(), element.end(), element.begin(), ::tolower); std::istringstream is(element); bool b; is >> std::boolalpha >> b; value = std::make_unique(b); } break; - case STRING: { + case LogicalTypeID::STRING: { value = make_unique(element); } break; - case DATE: { + case LogicalTypeID::DATE: { value = std::make_unique(Date::FromCString(element.c_str(), element.length())); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { value = std::make_unique(Timestamp::FromCString(element.c_str(), element.length())); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { value = std::make_unique(Interval::FromCString(element.c_str(), element.length())); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { value = getArrowVarList(element, 1, element.length() - 2, type, copyDescription); } break; default: throw CopyException( - "Unsupported data type " + Types::dataTypeToString(type.typeID) + " inside LIST"); + "Unsupported data type " + LogicalTypeUtils::dataTypeToString(type) + " inside LIST"); } return value; } diff --git a/src/storage/in_mem_storage_structure/in_mem_column.cpp b/src/storage/in_mem_storage_structure/in_mem_column.cpp index 3e70897cac..9e4cb5c27e 100644 --- a/src/storage/in_mem_storage_structure/in_mem_column.cpp +++ b/src/storage/in_mem_storage_structure/in_mem_column.cpp @@ -9,16 +9,15 @@ using namespace kuzu::common; namespace kuzu { namespace storage { -InMemColumn::InMemColumn(std::string filePath, DataType dataType, bool requireNullBits) - : filePath{std::move(filePath)}, - numBytesForValue{(uint16_t)common::Types::getDataTypeSize(dataType)}, dataType{std::move( - dataType)} { +InMemColumn::InMemColumn(std::string filePath, LogicalType dataType, bool requireNullBits) + : filePath{std::move(filePath)}, numBytesForValue{( + uint16_t)storage::StorageUtils::getDataTypeSize(dataType)}, + dataType{std::move(dataType)} { // TODO(Guodong): Separate this as a function. - switch (this->dataType.typeID) { - case STRUCT: { - auto structTypeInfo = reinterpret_cast(this->dataType.getExtraTypeInfo()); - auto childTypes = structTypeInfo->getChildrenTypes(); - auto childNames = structTypeInfo->getChildrenNames(); + switch (this->dataType.getLogicalTypeID()) { + case LogicalTypeID::STRUCT: { + auto childTypes = common::StructType::getStructFieldTypes(&this->dataType); + auto childNames = common::StructType::getStructFieldNames(&this->dataType); childColumns.resize(childTypes.size()); for (auto i = 0u; i < childTypes.size(); i++) { childColumns[i] = std::make_unique( @@ -26,8 +25,8 @@ InMemColumn::InMemColumn(std::string filePath, DataType dataType, bool requireNu true /* hasNull */); } } break; - case STRING: - case VAR_LIST: { + case LogicalTypeID::STRING: + case LogicalTypeID::VAR_LIST: { inMemOverflowFile = std::make_unique(StorageUtils::getOverflowFileName(this->filePath)); fileHandle = std::make_unique( @@ -41,7 +40,7 @@ InMemColumn::InMemColumn(std::string filePath, DataType dataType, bool requireNu if (requireNullBits) { nullColumn = std::make_unique(StorageUtils::getPropertyNullFName(this->filePath), - DataType(BOOL), false /* hasNull */); + LogicalType(LogicalTypeID::BOOL), false /* hasNull */); } } diff --git a/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp b/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp index 8eb1e1438a..3839178d7e 100644 --- a/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp +++ b/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp @@ -10,20 +10,19 @@ namespace kuzu { namespace storage { InMemColumnChunk::InMemColumnChunk( - DataType dataType, offset_t startNodeOffset, offset_t endNodeOffset, bool requireNullBits) + LogicalType dataType, offset_t startNodeOffset, offset_t endNodeOffset, bool requireNullBits) : dataType{std::move(dataType)}, startNodeOffset{startNodeOffset} { numBytesPerValue = getDataTypeSizeInColumn(this->dataType); numBytes = numBytesPerValue * (endNodeOffset - startNodeOffset + 1); buffer = std::make_unique(numBytes); if (requireNullBits) { nullChunk = std::make_unique( - DataType{BOOL}, startNodeOffset, endNodeOffset, false /* hasNull */); + LogicalType{LogicalTypeID::BOOL}, startNodeOffset, endNodeOffset, false /* hasNull */); memset(nullChunk->getData(), UINT8_MAX, nullChunk->getNumBytes()); } // TODO(Guodong): Consider shifting to a hierarchy structure for STRING/LIST/STRUCT. - if (this->dataType.typeID == STRUCT) { - auto structDataInfo = reinterpret_cast(this->dataType.extraTypeInfo.get()); - auto childTypes = structDataInfo->getChildrenTypes(); + if (this->dataType.getLogicalTypeID() == LogicalTypeID::STRUCT) { + auto childTypes = common::StructType::getStructFieldTypes(&this->dataType); childChunks.resize(childTypes.size()); for (auto i = 0u; i < childTypes.size(); i++) { childChunks[i] = @@ -44,16 +43,16 @@ void InMemColumnChunk::flush(FileInfo* walFileInfo) { } } -uint32_t InMemColumnChunk::getDataTypeSizeInColumn(common::DataType& dataType) { - switch (dataType.typeID) { - case STRUCT: { +uint32_t InMemColumnChunk::getDataTypeSizeInColumn(common::LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::STRUCT: { return 0; } - case INTERNAL_ID: { + case LogicalTypeID::INTERNAL_ID: { return sizeof(offset_t); } default: { - return Types::getDataTypeSize(dataType); + return storage::StorageUtils::getDataTypeSize(dataType); } } } @@ -83,29 +82,29 @@ template<> void InMemColumnChunk::templateCopyValuesToPage(arrow::Array& array, InMemOverflowFile* overflowFile, PageByteCursor& overflowCursor, CopyDescription& copyDesc) { - switch (dataType.typeID) { - case DATE: { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::DATE: { templateCopyValuesAsStringToPage(array); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { templateCopyValuesAsStringToPage(array); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { templateCopyValuesAsStringToPage(array); } break; - case FIXED_LIST: { + case LogicalTypeID::FIXED_LIST: { // Fixed list is a fixed-sized blob. templateCopyValuesAsStringToPage(array, copyDesc); } break; - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { templateCopyValuesAsStringToPage(array, overflowFile, overflowCursor, copyDesc); } break; - case STRING: { + case LogicalTypeID::STRING: { templateCopyValuesAsStringToPage( array, overflowFile, overflowCursor); } break; - case STRUCT: { + case LogicalTypeID::STRUCT: { templateCopyValuesAsStringToPage(array, overflowFile, overflowCursor, copyDesc); } break; @@ -144,8 +143,8 @@ void InMemColumnChunk::setValueFromString( auto fixedListVal = TableCopyExecutor::getArrowFixedList(value, 1, length - 2, dataType, copyDescription); // TODO(Guodong): Keep value size as a class field. - memcpy(buffer.get() + pos * Types::getDataTypeSize(dataType), fixedListVal.get(), - Types::getDataTypeSize(dataType)); + memcpy(buffer.get() + pos * storage::StorageUtils::getDataTypeSize(dataType), + fixedListVal.get(), storage::StorageUtils::getDataTypeSize(dataType)); } // Var list @@ -190,9 +189,8 @@ void InMemColumnChunk::setValueFromString(const char* value, uint64_t length, uint64_t pos, InMemOverflowFile* overflowFile, PageByteCursor& overflowCursor, CopyDescription& copyDescription) { - auto structTypeInfo = reinterpret_cast(dataType.getExtraTypeInfo()); - auto structFieldTypes = structTypeInfo->getChildrenTypes(); - auto structFieldNames = structTypeInfo->getChildrenNames(); + auto structFieldTypes = StructType::getStructFieldTypes(&dataType); + auto structFieldNames = StructType::getStructFieldNames(&dataType); std::regex whiteSpacePattern{"\\s"}; // Removes the leading and trailing '{', '}'; auto structString = std::string(value, length).substr(1, length - 2); @@ -223,45 +221,45 @@ field_idx_t InMemStructColumnChunk::getStructFieldIdx( } void InMemStructColumnChunk::setValueToStructColumnField(InMemColumnChunk* chunk, offset_t pos, - field_idx_t structFieldIdx, const std::string& structFieldValue, const DataType& dataType) { - switch (dataType.typeID) { - case INT64: { + field_idx_t structFieldIdx, const std::string& structFieldValue, const LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case INT32: { + case LogicalTypeID::INT32: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case INT16: { + case LogicalTypeID::INT16: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case DOUBLE: { + case LogicalTypeID::DOUBLE: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case FLOAT: { + case LogicalTypeID::FLOAT: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case BOOL: { + case LogicalTypeID::BOOL: { chunk->setValueFromString(structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case DATE: { + case LogicalTypeID::DATE: { chunk->setValueFromString(structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; - case INTERVAL: { + case LogicalTypeID::INTERVAL: { chunk->setValueFromString( structFieldValue.c_str(), structFieldValue.length(), pos); } break; default: { throw NotImplementedException{StringUtils::string_format( - "Unsupported data type: {}.", Types::dataTypeToString(dataType))}; + "Unsupported data type: {}.", LogicalTypeUtils::dataTypeToString(dataType))}; } } } diff --git a/src/storage/in_mem_storage_structure/in_mem_lists.cpp b/src/storage/in_mem_storage_structure/in_mem_lists.cpp index 3c285fc8ad..c9a5b7ba92 100644 --- a/src/storage/in_mem_storage_structure/in_mem_lists.cpp +++ b/src/storage/in_mem_storage_structure/in_mem_lists.cpp @@ -19,7 +19,7 @@ PageElementCursor InMemLists::calcPageElementCursor( } InMemLists::InMemLists( - std::string fName, DataType dataType, uint64_t numBytesForElement, uint64_t numNodes) + std::string fName, LogicalType dataType, uint64_t numBytesForElement, uint64_t numNodes) : fName{std::move(fName)}, dataType{std::move(dataType)}, numBytesForElement{ numBytesForElement} { listsMetadataBuilder = make_unique(this->fName); @@ -28,8 +28,8 @@ InMemLists::InMemLists( numChunks++; } listsMetadataBuilder->initChunkPageLists(numChunks); - inMemFile = make_unique( - this->fName, numBytesForElement, this->dataType.typeID != INTERNAL_ID); + inMemFile = make_unique(this->fName, numBytesForElement, + this->dataType.getLogicalTypeID() != LogicalTypeID::INTERNAL_ID); } void InMemLists::fillWithDefaultVal( @@ -102,7 +102,7 @@ void InMemLists::calculatePagesForList(uint64_t& numPages, uint64_t& offsetInPag void InMemLists::fillInMemListsWithStrValFunc(InMemLists* inMemLists, uint8_t* defaultVal, PageByteCursor& pageByteCursor, offset_t nodeOffset, uint64_t posInList, - const DataType& dataType) { + const LogicalType& dataType) { auto strVal = *(ku_string_t*)defaultVal; inMemLists->getInMemOverflowFile()->copyStringOverflow( pageByteCursor, reinterpret_cast(strVal.overflowPtr), &strVal); @@ -111,28 +111,28 @@ void InMemLists::fillInMemListsWithStrValFunc(InMemLists* inMemLists, uint8_t* d void InMemLists::fillInMemListsWithListValFunc(InMemLists* inMemLists, uint8_t* defaultVal, PageByteCursor& pageByteCursor, offset_t nodeOffset, uint64_t posInList, - const DataType& dataType) { + const LogicalType& dataType) { auto listVal = *reinterpret_cast(defaultVal); inMemLists->getInMemOverflowFile()->copyListOverflowToFile( - pageByteCursor, &listVal, dataType.getChildType()); + pageByteCursor, &listVal, VarListType::getChildType(&dataType)); inMemLists->setElement(nodeOffset, posInList, reinterpret_cast(&listVal)); } -fill_in_mem_lists_function_t InMemLists::getFillInMemListsFunc(const DataType& dataType) { - switch (dataType.typeID) { - case INT64: - case DOUBLE: - case BOOL: - case DATE: - case TIMESTAMP: - case INTERVAL: - case FIXED_LIST: { +fill_in_mem_lists_function_t InMemLists::getFillInMemListsFunc(const LogicalType& dataType) { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::BOOL: + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INTERVAL: + case LogicalTypeID::FIXED_LIST: { return fillInMemListsWithNonOverflowValFunc; } - case STRING: { + case LogicalTypeID::STRING: { return fillInMemListsWithStrValFunc; } - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { return fillInMemListsWithListValFunc; } default: { @@ -146,11 +146,12 @@ void InMemAdjLists::saveToFile() { InMemLists::saveToFile(); } -InMemListsWithOverflow::InMemListsWithOverflow(std::string fName, DataType dataType, +InMemListsWithOverflow::InMemListsWithOverflow(std::string fName, LogicalType dataType, uint64_t numNodes, std::shared_ptr listHeadersBuilder) - : InMemLists{ - std::move(fName), std::move(dataType), Types::getDataTypeSize(dataType), numNodes} { - assert(this->dataType.typeID == STRING || this->dataType.typeID == VAR_LIST); + : InMemLists{std::move(fName), std::move(dataType), + storage::StorageUtils::getDataTypeSize(dataType), numNodes} { + assert(this->dataType.getLogicalTypeID() == LogicalTypeID::STRING || + this->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); overflowInMemFile = make_unique(StorageUtils::getOverflowFileName(this->fName)); this->listHeadersBuilder = std::move(listHeadersBuilder); @@ -162,27 +163,28 @@ void InMemListsWithOverflow::saveToFile() { } std::unique_ptr InMemListsFactory::getInMemPropertyLists(const std::string& fName, - const DataType& dataType, uint64_t numNodes, + const LogicalType& dataType, uint64_t numNodes, std::shared_ptr listHeadersBuilder) { - switch (dataType.typeID) { - case INT64: - case INT32: - case INT16: - case DOUBLE: - case FLOAT: - case BOOL: - case DATE: - case TIMESTAMP: - case INTERVAL: - case FIXED_LIST: - return make_unique(fName, dataType, Types::getDataTypeSize(dataType), numNodes, + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: + case LogicalTypeID::INT32: + case LogicalTypeID::INT16: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::FLOAT: + case LogicalTypeID::BOOL: + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INTERVAL: + case LogicalTypeID::FIXED_LIST: + return make_unique(fName, dataType, + storage::StorageUtils::getDataTypeSize(dataType), numNodes, std::move(listHeadersBuilder)); - case STRING: + case LogicalTypeID::STRING: return make_unique(fName, numNodes, std::move(listHeadersBuilder)); - case VAR_LIST: + case LogicalTypeID::VAR_LIST: return make_unique( fName, dataType, numNodes, std::move(listHeadersBuilder)); - case INTERNAL_ID: + case LogicalTypeID::INTERNAL_ID: return make_unique(fName, numNodes, std::move(listHeadersBuilder)); default: throw CopyException("Invalid type for property list creation."); diff --git a/src/storage/index/hash_index.cpp b/src/storage/index/hash_index.cpp index 33beeb4729..6d7e9ba975 100644 --- a/src/storage/index/hash_index.cpp +++ b/src/storage/index/hash_index.cpp @@ -48,11 +48,11 @@ template class TemplatedHashIndexLocalStorage; HashIndexLocalLookupState HashIndexLocalStorage::lookup(const uint8_t* key, offset_t& result) { std::shared_lock sLck{localStorageSharedMutex}; - if (keyDataType.typeID == INT64) { + if (keyDataType.getLogicalTypeID() == LogicalTypeID::INT64) { auto keyVal = *(int64_t*)key; return templatedLocalStorageForInt.lookup(keyVal, result); } else { - assert(keyDataType.typeID == STRING); + assert(keyDataType.getLogicalTypeID() == LogicalTypeID::STRING); auto keyVal = std::string((char*)key); return templatedLocalStorageForString.lookup(keyVal, result); } @@ -60,11 +60,11 @@ HashIndexLocalLookupState HashIndexLocalStorage::lookup(const uint8_t* key, offs void HashIndexLocalStorage::deleteKey(const uint8_t* key) { std::unique_lock xLck{localStorageSharedMutex}; - if (keyDataType.typeID == INT64) { + if (keyDataType.getLogicalTypeID() == LogicalTypeID::INT64) { auto keyVal = *(int64_t*)key; templatedLocalStorageForInt.deleteKey(keyVal); } else { - assert(keyDataType.typeID == STRING); + assert(keyDataType.getLogicalTypeID() == LogicalTypeID::STRING); auto keyVal = std::string((char*)key); templatedLocalStorageForString.deleteKey(keyVal); } @@ -72,11 +72,11 @@ void HashIndexLocalStorage::deleteKey(const uint8_t* key) { bool HashIndexLocalStorage::insert(const uint8_t* key, offset_t value) { std::unique_lock xLck{localStorageSharedMutex}; - if (keyDataType.typeID == INT64) { + if (keyDataType.getLogicalTypeID() == LogicalTypeID::INT64) { auto keyVal = *(int64_t*)key; return templatedLocalStorageForInt.insert(keyVal, value); } else { - assert(keyDataType.typeID == STRING); + assert(keyDataType.getLogicalTypeID() == LogicalTypeID::STRING); auto keyVal = std::string((char*)key); return templatedLocalStorageForString.insert(keyVal, value); } @@ -84,7 +84,7 @@ bool HashIndexLocalStorage::insert(const uint8_t* key, offset_t value) { void HashIndexLocalStorage::applyLocalChanges(const std::function& deleteOp, const std::function& insertOp) { - if (keyDataType.typeID == INT64) { + if (keyDataType.getLogicalTypeID() == LogicalTypeID::INT64) { for (auto& key : templatedLocalStorageForInt.localDeletions) { deleteOp((uint8_t*)&key); } @@ -92,7 +92,7 @@ void HashIndexLocalStorage::applyLocalChanges(const std::function HashIndex::HashIndex(const StorageStructureIDAndFName& storageStructureIDAndFName, - const DataType& keyDataType, BufferManager& bufferManager, WAL* wal) + const LogicalType& keyDataType, BufferManager& bufferManager, WAL* wal) : BaseHashIndex{keyDataType}, storageStructureIDAndFName{storageStructureIDAndFName}, bm{bufferManager}, wal{wal} { fileHandle = bufferManager.getBMFileHandle(storageStructureIDAndFName.fName, @@ -133,7 +133,7 @@ HashIndex::HashIndex(const StorageStructureIDAndFName& storageStructureIDAndF // Read indexHeader from the headerArray, which contains only one element. indexHeader = std::make_unique( headerArray->get(INDEX_HEADER_IDX_IN_ARRAY, TransactionType::READ_ONLY)); - assert(indexHeader->keyDataTypeID == keyDataType.typeID); + assert(indexHeader->keyDataTypeID == keyDataType.getLogicalTypeID()); pSlots = std::make_unique>>(*fileHandle, storageStructureIDAndFName.storageStructureID, P_SLOTS_HEADER_PAGE_IDX, &bm, wal); oSlots = std::make_unique>>(*fileHandle, @@ -143,7 +143,7 @@ HashIndex::HashIndex(const StorageStructureIDAndFName& storageStructureIDAndF keyInsertFunc = HashIndexUtils::initializeInsertFunc(indexHeader->keyDataTypeID); keyEqualsFunc = HashIndexUtils::initializeEqualsFunc(indexHeader->keyDataTypeID); localStorage = std::make_unique(keyDataType); - if (keyDataType.typeID == STRING) { + if (keyDataType.getLogicalTypeID() == LogicalTypeID::STRING) { diskOverflowFile = std::make_unique(storageStructureIDAndFName, &bm, wal); } } @@ -308,7 +308,7 @@ void HashIndex::rehashSlots(HashIndexHeader& header) { } auto key = slot.entries[entryPos].data; hash_t hash; - if (header.keyDataTypeID == STRING) { + if (header.keyDataTypeID == LogicalTypeID::STRING) { auto str = diskOverflowFile->readString(TransactionType::WRITE, *(ku_string_t*)key); hash = keyHashFunc((const uint8_t*)str.c_str()); } else { @@ -440,7 +440,7 @@ template class HashIndex; bool PrimaryKeyIndex::lookup( Transaction* trx, ValueVector* keyVector, uint64_t vectorPos, offset_t& result) { assert(!keyVector->isNull(vectorPos)); - if (keyDataTypeID == INT64) { + if (keyDataTypeID == LogicalTypeID::INT64) { auto key = keyVector->getValue(vectorPos); return hashIndexForInt64->lookupInternal( trx, reinterpret_cast(&key), result); @@ -453,7 +453,7 @@ bool PrimaryKeyIndex::lookup( void PrimaryKeyIndex::deleteKey(ValueVector* keyVector, uint64_t vectorPos) { assert(!keyVector->isNull(vectorPos)); - if (keyDataTypeID == INT64) { + if (keyDataTypeID == LogicalTypeID::INT64) { auto key = keyVector->getValue(vectorPos); hashIndexForInt64->deleteInternal(reinterpret_cast(&key)); } else { @@ -464,7 +464,7 @@ void PrimaryKeyIndex::deleteKey(ValueVector* keyVector, uint64_t vectorPos) { bool PrimaryKeyIndex::insert(ValueVector* keyVector, uint64_t vectorPos, offset_t value) { assert(!keyVector->isNull(vectorPos)); - if (keyDataTypeID == INT64) { + if (keyDataTypeID == LogicalTypeID::INT64) { auto key = keyVector->getValue(vectorPos); return hashIndexForInt64->insertInternal(reinterpret_cast(&key), value); } else { diff --git a/src/storage/index/hash_index_builder.cpp b/src/storage/index/hash_index_builder.cpp index 3bcd919bd7..bc6619c0c3 100644 --- a/src/storage/index/hash_index_builder.cpp +++ b/src/storage/index/hash_index_builder.cpp @@ -16,11 +16,11 @@ slot_id_t BaseHashIndex::getPrimarySlotIdForKey( } template -HashIndexBuilder::HashIndexBuilder(const std::string& fName, const DataType& keyDataType) +HashIndexBuilder::HashIndexBuilder(const std::string& fName, const LogicalType& keyDataType) : BaseHashIndex{keyDataType}, numEntries{0} { fileHandle = std::make_unique(fName, FileHandle::O_PERSISTENT_FILE_CREATE_NOT_EXISTS); - indexHeader = std::make_unique(keyDataType.typeID); + indexHeader = std::make_unique(keyDataType.getLogicalTypeID()); fileHandle->addNewPage(); // INDEX_HEADER_ARRAY_HEADER_PAGE fileHandle->addNewPage(); // P_SLOTS_HEADER_PAGE fileHandle->addNewPage(); // O_SLOTS_HEADER_PAGE @@ -32,7 +32,7 @@ HashIndexBuilder::HashIndexBuilder(const std::string& fName, const DataType& oSlots = std::make_unique>>( *fileHandle, O_SLOTS_HEADER_PAGE_IDX, 1 /* numElements */); allocatePSlots(2); - if (keyDataType.typeID == STRING) { + if (keyDataType.getLogicalTypeID() == LogicalTypeID::STRING) { inMemOverflowFile = std::make_unique(StorageUtils::getOverflowFileName(fName)); } @@ -178,7 +178,7 @@ void HashIndexBuilder::flush() { headerArray->saveToDisk(); pSlots->saveToDisk(); oSlots->saveToDisk(); - if (indexHeader->keyDataTypeID == STRING) { + if (indexHeader->keyDataTypeID == LogicalTypeID::STRING) { inMemOverflowFile->flush(); } } diff --git a/src/storage/index/hash_index_utils.cpp b/src/storage/index/hash_index_utils.cpp index a1df924ad6..680ba532db 100644 --- a/src/storage/index/hash_index_utils.cpp +++ b/src/storage/index/hash_index_utils.cpp @@ -6,12 +6,12 @@ using namespace kuzu::transaction; namespace kuzu { namespace storage { -in_mem_insert_function_t InMemHashIndexUtils::initializeInsertFunc(DataTypeID dataTypeID) { +in_mem_insert_function_t InMemHashIndexUtils::initializeInsertFunc(LogicalTypeID dataTypeID) { switch (dataTypeID) { - case INT64: { + case LogicalTypeID::INT64: { return insertFuncForInt64; } - case STRING: { + case LogicalTypeID::STRING: { return insertFuncForString; } default: { @@ -46,12 +46,12 @@ bool InMemHashIndexUtils::equalsFuncForString(const uint8_t* keyToLookup, const } } -in_mem_equals_function_t InMemHashIndexUtils::initializeEqualsFunc(DataTypeID dataTypeID) { +in_mem_equals_function_t InMemHashIndexUtils::initializeEqualsFunc(LogicalTypeID dataTypeID) { switch (dataTypeID) { - case INT64: { + case LogicalTypeID::INT64: { return equalsFuncForInt64; } - case STRING: { + case LogicalTypeID::STRING: { return equalsFuncForString; } default: { @@ -61,30 +61,32 @@ in_mem_equals_function_t InMemHashIndexUtils::initializeEqualsFunc(DataTypeID da } } -insert_function_t HashIndexUtils::initializeInsertFunc(DataTypeID dataTypeID) { +insert_function_t HashIndexUtils::initializeInsertFunc(LogicalTypeID dataTypeID) { switch (dataTypeID) { - case INT64: { + case LogicalTypeID::INT64: { return insertFuncForInt64; } - case STRING: { + case LogicalTypeID::STRING: { return insertFuncForString; } default: { - throw StorageException("Type " + Types::dataTypeToString(dataTypeID) + " not supported."); + throw StorageException( + "Type " + LogicalTypeUtils::dataTypeToString(dataTypeID) + " not supported."); } } } -hash_function_t HashIndexUtils::initializeHashFunc(DataTypeID dataTypeID) { +hash_function_t HashIndexUtils::initializeHashFunc(LogicalTypeID dataTypeID) { switch (dataTypeID) { - case INT64: { + case LogicalTypeID::INT64: { return hashFuncForInt64; } - case STRING: { + case LogicalTypeID::STRING: { return hashFuncForString; } default: { - throw StorageException("Type " + Types::dataTypeToString(dataTypeID) + " not supported."); + throw StorageException( + "Type " + LogicalTypeUtils::dataTypeToString(dataTypeID) + " not supported."); } } } @@ -107,12 +109,12 @@ bool HashIndexUtils::equalsFuncForString(TransactionType trxType, const uint8_t* return false; } -equals_function_t HashIndexUtils::initializeEqualsFunc(DataTypeID dataTypeID) { +equals_function_t HashIndexUtils::initializeEqualsFunc(LogicalTypeID dataTypeID) { switch (dataTypeID) { - case INT64: { + case LogicalTypeID::INT64: { return equalsFuncForInt64; } - case STRING: { + case LogicalTypeID::STRING: { return equalsFuncForString; } default: { diff --git a/src/storage/storage_structure/column.cpp b/src/storage/storage_structure/column.cpp index 8eb840a06f..fa6063e033 100644 --- a/src/storage/storage_structure/column.cpp +++ b/src/storage/storage_structure/column.cpp @@ -10,7 +10,7 @@ namespace kuzu { namespace storage { Column::Column(const kuzu::storage::StorageStructureIDAndFName& structureIDAndFName, - const common::DataType& dataType, size_t elementSize, + const common::LogicalType& dataType, size_t elementSize, kuzu::storage::BufferManager* bufferManager, kuzu::storage::WAL* wal, bool requireNullBits) : BaseColumnOrList{ structureIDAndFName, dataType, elementSize, bufferManager, false /*hasNULLBytes*/, wal} { @@ -300,10 +300,9 @@ void ListPropertyColumn::writeListToPage(uint8_t* frame, uint16_t posInFrame, } StructPropertyColumn::StructPropertyColumn(const StorageStructureIDAndFName& structureIDAndFName, - const common::DataType& dataType, BufferManager* bufferManager, WAL* wal) + const common::LogicalType& dataType, BufferManager* bufferManager, WAL* wal) : Column{dataType} { - auto structFields = - reinterpret_cast(dataType.getExtraTypeInfo())->getStructFields(); + auto structFields = common::StructType::getStructFields(&dataType); for (auto structField : structFields) { auto fieldStructureIDAndFName = structureIDAndFName; fieldStructureIDAndFName.fName = StorageUtils::appendStructFieldName( diff --git a/src/storage/storage_structure/disk_overflow_file.cpp b/src/storage/storage_structure/disk_overflow_file.cpp index b8afa21647..24094ea4d2 100644 --- a/src/storage/storage_structure/disk_overflow_file.cpp +++ b/src/storage/storage_structure/disk_overflow_file.cpp @@ -84,7 +84,8 @@ void DiskOverflowFile::readListToVector( *fileHandle, cursor.pageIdx, *wal, trxType); auto listEntry = common::ListVector::addList(vector, kuList.size); vector->setValue(pos, listEntry); - if (vector->dataType.getChildType()->typeID == common::VAR_LIST) { + if (VarListType::getChildType(&vector->dataType)->getLogicalTypeID() == + common::LogicalTypeID::VAR_LIST) { bufferManager->optimisticRead(*fileHandleToPin, pageIdxToPin, [&](uint8_t* frame) { for (auto i = 0u; i < kuList.size; i++) { readListToVector(trxType, ((ku_list_t*)(frame + cursor.offsetInPage))[i], @@ -97,7 +98,7 @@ void DiskOverflowFile::readListToVector( memcpy(bufferToCopy, frame + cursor.offsetInPage, dataVector->getNumBytesPerValue() * kuList.size); }); - if (dataVector->dataType.typeID == STRING) { + if (dataVector->dataType.getLogicalTypeID() == LogicalTypeID::STRING) { auto kuStrings = (ku_string_t*)bufferToCopy; OverflowPageCache overflowPageCache; for (auto i = 0u; i < kuList.size; i++) { @@ -127,13 +128,14 @@ std::string DiskOverflowFile::readString(TransactionType trxType, const ku_strin } std::vector> DiskOverflowFile::readList( - TransactionType trxType, const ku_list_t& listVal, const DataType& dataType) { + TransactionType trxType, const ku_list_t& listVal, const LogicalType& dataType) { PageByteCursor cursor; TypeUtils::decodeOverflowPtr(listVal.overflowPtr, cursor.pageIdx, cursor.offsetInPage); auto [fileHandleToPin, pageIdxToPin] = StorageStructureUtils::getFileHandleAndPhysicalPageIdxToPin( *fileHandle, cursor.pageIdx, *wal, trxType); - auto numBytesOfSingleValue = Types::getDataTypeSize(*dataType.getChildType()); + auto numBytesOfSingleValue = + storage::StorageUtils::getDataTypeSize(*VarListType::getChildType(&dataType)); auto numValuesInList = listVal.size; std::vector> retValues; bufferManager->optimisticRead(*fileHandleToPin, pageIdxToPin, [&](uint8_t* frame) -> void { @@ -144,26 +146,26 @@ std::vector> DiskOverflowFile::readList( } void DiskOverflowFile::readValuesInList(transaction::TransactionType trxType, - const common::DataType& dataType, std::vector>& retValues, + const common::LogicalType& dataType, std::vector>& retValues, uint32_t numBytesOfSingleValue, uint64_t numValuesInList, PageByteCursor& cursor, uint8_t* frame) { - if (dataType.getChildType()->typeID == STRING) { + auto childType = VarListType::getChildType(&dataType); + if (childType->getLogicalTypeID() == LogicalTypeID::STRING) { for (auto i = 0u; i < numValuesInList; i++) { auto kuListVal = *(ku_string_t*)(frame + cursor.offsetInPage); retValues.push_back(make_unique(readString(trxType, kuListVal))); cursor.offsetInPage += numBytesOfSingleValue; } - } else if (dataType.getChildType()->typeID == VAR_LIST) { + } else if (childType->getLogicalTypeID() == LogicalTypeID::VAR_LIST) { for (auto i = 0u; i < numValuesInList; i++) { auto kuListVal = *(ku_list_t*)(frame + cursor.offsetInPage); - retValues.push_back(make_unique( - *dataType.getChildType(), readList(trxType, kuListVal, *dataType.getChildType()))); + retValues.push_back( + make_unique(*childType, readList(trxType, kuListVal, *childType))); cursor.offsetInPage += numBytesOfSingleValue; } } else { for (auto i = 0u; i < numValuesInList; i++) { - retValues.push_back( - std::make_unique(*dataType.getChildType(), frame + cursor.offsetInPage)); + retValues.push_back(std::make_unique(*childType, frame + cursor.offsetInPage)); cursor.offsetInPage += numBytesOfSingleValue; } } @@ -229,8 +231,9 @@ void DiskOverflowFile::writeStringOverflowAndUpdateOverflowPtr( } void DiskOverflowFile::setListRecursiveIfNestedWithoutLock( - const ku_list_t& inMemSrcList, ku_list_t& diskDstList, const DataType& dataType) { - auto elementSize = Types::getDataTypeSize(*dataType.getChildType()); + const ku_list_t& inMemSrcList, ku_list_t& diskDstList, const LogicalType& dataType) { + auto childType = VarListType::getChildType(&dataType); + auto elementSize = storage::StorageUtils::getDataTypeSize(*childType); if (inMemSrcList.size * elementSize > BufferPoolConstants::PAGE_4KB_SIZE) { throw RuntimeException(StringUtils::string_format( "Maximum num bytes of a LIST is %d. Input list's num bytes is %d.", @@ -252,7 +255,7 @@ void DiskOverflowFile::setListRecursiveIfNestedWithoutLock( updatedPageInfoAndWALPageFrame.originalPageIdx, updatedPageInfoAndWALPageFrame.posInPage); StorageStructureUtils::unpinWALPageAndReleaseOriginalPageLock( updatedPageInfoAndWALPageFrame, *fileHandle, *bufferManager, *wal); - if (dataType.getChildType()->typeID == STRING) { + if (childType->getLogicalTypeID() == LogicalTypeID::STRING) { // Copy overflow for string elements in the list. auto dstListElements = reinterpret_cast( updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage); @@ -261,19 +264,19 @@ void DiskOverflowFile::setListRecursiveIfNestedWithoutLock( setStringOverflowWithoutLock( (const char*)kuString.overflowPtr, kuString.len, dstListElements[i]); } - } else if (dataType.getChildType()->typeID == VAR_LIST) { + } else if (childType->getLogicalTypeID() == LogicalTypeID::VAR_LIST) { // Recursively copy overflow for list elements in the list. auto dstListElements = reinterpret_cast( updatedPageInfoAndWALPageFrame.frame + updatedPageInfoAndWALPageFrame.posInPage); for (auto i = 0u; i < diskDstList.size; i++) { - setListRecursiveIfNestedWithoutLock((reinterpret_cast(listValues))[i], - dstListElements[i], *dataType.getChildType()); + setListRecursiveIfNestedWithoutLock( + (reinterpret_cast(listValues))[i], dstListElements[i], *childType); } } } void DiskOverflowFile::writeListOverflowAndUpdateOverflowPtr( - const ku_list_t& listToWriteFrom, ku_list_t& listToWriteTo, const DataType& valueType) { + const ku_list_t& listToWriteFrom, ku_list_t& listToWriteTo, const LogicalType& valueType) { lock_t lck{mtx}; logNewOverflowFileNextBytePosRecordIfNecessaryWithoutLock(); setListRecursiveIfNestedWithoutLock(listToWriteFrom, listToWriteTo, valueType); diff --git a/src/storage/storage_structure/in_mem_file.cpp b/src/storage/storage_structure/in_mem_file.cpp index 94e081f3a9..9944f37d33 100644 --- a/src/storage/storage_structure/in_mem_file.cpp +++ b/src/storage/storage_structure/in_mem_file.cpp @@ -92,17 +92,17 @@ void InMemOverflowFile::copyFixedSizedValuesInList( } } -template +template void InMemOverflowFile::copyVarSizedValuesInList(ku_list_t& resultKUList, const Value& listVal, PageByteCursor& overflowCursor, uint64_t numBytesOfListElement) { auto overflowPageIdx = overflowCursor.pageIdx; auto overflowPageOffset = overflowCursor.offsetInPage; // Reserve space for ku_list or ku_string objects. overflowCursor.offsetInPage += (resultKUList.size * numBytesOfListElement); - if constexpr (DT == STRING) { + if constexpr (DT == LogicalTypeID::STRING) { std::vector kuStrings(listVal.nestedTypeVal.size()); for (auto i = 0u; i < listVal.nestedTypeVal.size(); i++) { - assert(listVal.nestedTypeVal[i]->dataType.typeID == STRING); + assert(listVal.nestedTypeVal[i]->dataType.getLogicalTypeID() == LogicalTypeID::STRING); auto strVal = listVal.nestedTypeVal[i]->strVal; kuStrings[i] = copyString(strVal.c_str(), strVal.length(), overflowCursor); } @@ -113,10 +113,11 @@ void InMemOverflowFile::copyVarSizedValuesInList(ku_list_t& resultKUList, const numBytesOfListElement); } } else { - assert(DT == VAR_LIST); + assert(DT == LogicalTypeID::VAR_LIST); std::vector kuLists(listVal.nestedTypeVal.size()); for (auto i = 0u; i < listVal.nestedTypeVal.size(); i++) { - assert(listVal.nestedTypeVal[i]->dataType.typeID == VAR_LIST); + assert( + listVal.nestedTypeVal[i]->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); kuLists[i] = copyList(*listVal.nestedTypeVal[i], overflowCursor); } std::shared_lock lck(lock); @@ -129,9 +130,10 @@ void InMemOverflowFile::copyVarSizedValuesInList(ku_list_t& resultKUList, const } ku_list_t InMemOverflowFile::copyList(const Value& listValue, PageByteCursor& overflowCursor) { - assert(listValue.dataType.typeID == VAR_LIST); + assert(listValue.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); ku_list_t resultKUList; - auto numBytesOfListElement = Types::getDataTypeSize(*listValue.dataType.getChildType()); + auto numBytesOfListElement = storage::StorageUtils::getDataTypeSize( + *common::VarListType::getChildType(&listValue.dataType)); resultKUList.size = listValue.nestedTypeVal.size(); // Allocate a new page if necessary. if (overflowCursor.offsetInPage + (resultKUList.size * numBytesOfListElement) >= @@ -142,21 +144,21 @@ ku_list_t InMemOverflowFile::copyList(const Value& listValue, PageByteCursor& ov } TypeUtils::encodeOverflowPtr( resultKUList.overflowPtr, overflowCursor.pageIdx, overflowCursor.offsetInPage); - switch (listValue.dataType.getChildType()->typeID) { - case INT64: - case DOUBLE: - case BOOL: - case DATE: - case TIMESTAMP: - case INTERVAL: { + switch (VarListType::getChildType(&listValue.dataType)->getLogicalTypeID()) { + case LogicalTypeID::INT64: + case LogicalTypeID::DOUBLE: + case LogicalTypeID::BOOL: + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP: + case LogicalTypeID::INTERVAL: { copyFixedSizedValuesInList(listValue, overflowCursor, numBytesOfListElement); } break; - case STRING: { - copyVarSizedValuesInList( + case LogicalTypeID::STRING: { + copyVarSizedValuesInList( resultKUList, listValue, overflowCursor, numBytesOfListElement); } break; - case VAR_LIST: { - copyVarSizedValuesInList( + case LogicalTypeID::VAR_LIST: { + copyVarSizedValuesInList( resultKUList, listValue, overflowCursor, numBytesOfListElement); } break; default: { @@ -184,8 +186,8 @@ void InMemOverflowFile::copyStringOverflow( void InMemOverflowFile::copyListOverflowFromFile(InMemOverflowFile* srcInMemOverflowFile, const PageByteCursor& srcOverflowCursor, PageByteCursor& dstOverflowCursor, - ku_list_t* dstKUList, DataType* listChildDataType) { - auto numBytesOfListElement = Types::getDataTypeSize(*listChildDataType); + ku_list_t* dstKUList, LogicalType* listChildDataType) { + auto numBytesOfListElement = storage::StorageUtils::getDataTypeSize(*listChildDataType); // Allocate a new page if necessary. if (dstOverflowCursor.offsetInPage + (dstKUList->size * numBytesOfListElement) >= BufferPoolConstants::PAGE_4KB_SIZE || @@ -199,16 +201,16 @@ void InMemOverflowFile::copyListOverflowFromFile(InMemOverflowFile* srcInMemOver srcOverflowCursor.offsetInPage; auto offsetToCopyInto = dstOverflowCursor.offsetInPage; dstOverflowCursor.offsetInPage += dstKUList->size * numBytesOfListElement; - if (listChildDataType->typeID == VAR_LIST) { + if (listChildDataType->getLogicalTypeID() == LogicalTypeID::VAR_LIST) { auto elementsInList = (ku_list_t*)dataToCopyFrom; for (auto i = 0u; i < dstKUList->size; i++) { PageByteCursor elementCursor; TypeUtils::decodeOverflowPtr( elementsInList[i].overflowPtr, elementCursor.pageIdx, elementCursor.offsetInPage); copyListOverflowFromFile(srcInMemOverflowFile, elementCursor, dstOverflowCursor, - &elementsInList[i], listChildDataType->getChildType()); + &elementsInList[i], VarListType::getChildType(listChildDataType)); } - } else if (listChildDataType->typeID == STRING) { + } else if (listChildDataType->getLogicalTypeID() == LogicalTypeID::STRING) { auto elementsInList = (ku_string_t*)dataToCopyFrom; for (auto i = 0u; i < dstKUList->size; i++) { if (elementsInList[i].len > ku_string_t::SHORT_STR_LENGTH) { @@ -228,8 +230,8 @@ void InMemOverflowFile::copyListOverflowFromFile(InMemOverflowFile* srcInMemOver } void InMemOverflowFile::copyListOverflowToFile( - PageByteCursor& pageByteCursor, ku_list_t* srcKUList, DataType* childDataType) { - auto numBytesOfListElement = Types::getDataTypeSize(*childDataType); + PageByteCursor& pageByteCursor, ku_list_t* srcKUList, LogicalType* childDataType) { + auto numBytesOfListElement = storage::StorageUtils::getDataTypeSize(*childDataType); // Allocate a new page if necessary. if (pageByteCursor.offsetInPage + (srcKUList->size * numBytesOfListElement) >= BufferPoolConstants::PAGE_4KB_SIZE || @@ -272,14 +274,15 @@ std::string InMemOverflowFile::readString(ku_string_t* strInInMemOvfFile) { } void InMemOverflowFile::resetElementsOverflowPtrIfNecessary(PageByteCursor& pageByteCursor, - DataType* elementType, uint64_t numElementsToReset, uint8_t* elementsToReset) { - if (elementType->typeID == VAR_LIST) { + LogicalType* elementType, uint64_t numElementsToReset, uint8_t* elementsToReset) { + if (elementType->getLogicalTypeID() == LogicalTypeID::VAR_LIST) { auto kuListPtr = reinterpret_cast(elementsToReset); for (auto i = 0u; i < numElementsToReset; i++) { - copyListOverflowToFile(pageByteCursor, kuListPtr, elementType->getChildType()); + copyListOverflowToFile( + pageByteCursor, kuListPtr, VarListType::getChildType(elementType)); kuListPtr++; } - } else if (elementType->typeID == STRING) { + } else if (elementType->getLogicalTypeID() == LogicalTypeID::STRING) { auto kuStrPtr = reinterpret_cast(elementsToReset); for (auto i = 0u; i < numElementsToReset; i++) { if (kuStrPtr->len > ku_string_t::SHORT_STR_LENGTH) { diff --git a/src/storage/storage_structure/lists/lists_update_store.cpp b/src/storage/storage_structure/lists/lists_update_store.cpp index b193f8a1b4..f78647bc51 100644 --- a/src/storage/storage_structure/lists/lists_update_store.cpp +++ b/src/storage/storage_structure/lists/lists_update_store.cpp @@ -78,7 +78,8 @@ bool ListsUpdatesStore::hasUpdates() const { // Note: This function also resets the overflowptr of each string in inMemList if necessary. void ListsUpdatesStore::readInsertedRelsToList(ListFileID& listFileID, std::vector tupleIdxes, InMemList& inMemList, - uint64_t numElementsInPersistentStore, DiskOverflowFile* diskOverflowFile, DataType dataType) { + uint64_t numElementsInPersistentStore, DiskOverflowFile* diskOverflowFile, + LogicalType dataType) { ftOfInsertedRels->copyToInMemList(getColIdxInFT(listFileID), tupleIdxes, inMemList.getListData(), inMemList.nullMask.get(), numElementsInPersistentStore, diskOverflowFile, dataType); @@ -254,7 +255,7 @@ void ListsUpdatesStore::readUpdatesToPropertyVectorIfExists(ListFileID& listFile void ListsUpdatesStore::readPropertyUpdateToInMemList(ListFileID& listFileID, ft_tuple_idx_t ftTupleIdx, InMemList& inMemList, uint64_t posToWriteToInMemList, - const DataType& dataType, DiskOverflowFile* overflowFileOfInMemList) { + const LogicalType& dataType, DiskOverflowFile* overflowFileOfInMemList) { assert(listFileID.listType == ListType::REL_PROPERTY_LISTS); auto propertyID = listFileID.relPropertyListID.propertyID; auto tupleIdxesToRead = std::vector{ftTupleIdx}; @@ -291,7 +292,7 @@ void ListsUpdatesStore::initInsertedRelsAndListsUpdates() { auto numBytesForProperty = relProperty.propertyID == RelTableSchema::INTERNAL_REL_ID_PROPERTY_ID ? sizeof(offset_t) : - Types::getDataTypeSize(relProperty.dataType); + storage::StorageUtils::getDataTypeSize(relProperty.dataType); propertyIDToColIdxMap.emplace( relProperty.propertyID, factorizedTableSchema->getNumColumns()); factorizedTableSchema->appendColumn(std::make_unique( diff --git a/src/storage/storage_structure/storage_structure.cpp b/src/storage/storage_structure/storage_structure.cpp index 9e840ab316..67b5054ee5 100644 --- a/src/storage/storage_structure/storage_structure.cpp +++ b/src/storage/storage_structure/storage_structure.cpp @@ -36,7 +36,7 @@ WALPageIdxPosInPageAndFrame StorageStructure::createWALVersionOfPageIfNecessaryF } BaseColumnOrList::BaseColumnOrList(const StorageStructureIDAndFName& storageStructureIDAndFName, - DataType dataType, const size_t& elementSize, BufferManager* bufferManager, + LogicalType dataType, const size_t& elementSize, BufferManager* bufferManager, bool hasInlineNullBytes, WAL* wal) : StorageStructure(storageStructureIDAndFName, bufferManager, wal), dataType{std::move(dataType)}, elementSize{elementSize} { diff --git a/src/storage/storage_utils.cpp b/src/storage/storage_utils.cpp index 99e0061e8d..9244935824 100644 --- a/src/storage/storage_utils.cpp +++ b/src/storage/storage_utils.cpp @@ -1,5 +1,6 @@ #include "storage/storage_utils.h" +#include "common/null_buffer.h" #include "common/string_utils.h" #include "storage/in_mem_storage_structure/in_mem_column.h" #include "storage/in_mem_storage_structure/in_mem_lists.h" @@ -253,6 +254,33 @@ void StorageUtils::createFileForRelListsPropertyWithDefaultVal(table_id_t relTab inMemList->saveToFile(); } +uint32_t StorageUtils::getDataTypeSize(const common::LogicalType& type) { + switch (type.getLogicalTypeID()) { + case common::LogicalTypeID::STRING: { + return sizeof(common::ku_string_t); + } + case common::LogicalTypeID::FIXED_LIST: { + return getDataTypeSize(*common::FixedListType::getChildType(&type)) * + common::FixedListType::getNumElementsInList(&type); + } + case common::LogicalTypeID::VAR_LIST: { + return sizeof(common::ku_list_t); + } + case common::LogicalTypeID::STRUCT: { + uint32_t size = 0; + auto structFieldsTypes = common::StructType::getStructFieldTypes(&type); + for (auto structFieldType : structFieldsTypes) { + size += getDataTypeSize(*structFieldType); + } + size += NullBuffer::getNumBytesForNullValues(structFieldsTypes.size()); + return size; + } + default: { + return common::LogicalTypeUtils::getFixedTypeSize(type.getPhysicalType()); + } + } +} + std::string StorageUtils::appendSuffixOrInsertBeforeWALSuffix( std::string fileName, std::string suffix) { auto pos = fileName.find(StorageConstants::WAL_FILE_SUFFIX); diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 138db5079a..d9427fa73e 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -19,7 +19,7 @@ void NodeTable::initializeData(NodeTableSchema* nodeTableSchema) { StorageUtils::getNodePropertyColumnStructureIDAndFName(wal->getDirectory(), property), property.dataType, &bufferManager, wal); } - if (nodeTableSchema->getPrimaryKey().dataType.typeID != SERIAL) { + if (nodeTableSchema->getPrimaryKey().dataType.getLogicalTypeID() != LogicalTypeID::SERIAL) { pkIndex = std::make_unique( StorageUtils::getNodeIndexIDAndFName(wal->getDirectory(), tableID), nodeTableSchema->getPrimaryKey().dataType, bufferManager, wal); @@ -47,7 +47,7 @@ offset_t NodeTable::addNodeAndResetProperties(ValueVector* primaryKeyVector) { } // TODO(Guodong): Handle SERIAL. if (!pkIndex->insert(primaryKeyVector, pkValPos, nodeOffset)) { - std::string pkStr = primaryKeyVector->dataType.typeID == INT64 ? + std::string pkStr = primaryKeyVector->dataType.getLogicalTypeID() == LogicalTypeID::INT64 ? std::to_string(primaryKeyVector->getValue(pkValPos)) : primaryKeyVector->getValue(pkValPos).getAsString(); throw RuntimeException(Exception::getExistedPKExceptionMsg(pkStr)); diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index 9876a99d6d..680abfaa6c 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -52,7 +52,7 @@ void DirectedRelTableData::initializeData(RelTableSchema* tableSchema, WAL* wal) void DirectedRelTableData::initializeColumns(RelTableSchema* tableSchema, WAL* wal) { adjColumn = ColumnFactory::getColumn(StorageUtils::getAdjColumnStructureIDAndFName( wal->getDirectory(), tableSchema->tableID, direction), - DataType(INTERNAL_ID), &bufferManager, wal); + LogicalType(LogicalTypeID::INTERNAL_ID), &bufferManager, wal); for (auto& property : tableSchema->properties) { propertyColumns[property.propertyID] = ColumnFactory::getColumn( StorageUtils::getRelPropertyColumnStructureIDAndFName( @@ -128,7 +128,7 @@ void DirectedRelTableData::scanLists(transaction::Transaction* transaction, // Fill nbr table IDs for the vector scanned from an adj column. void DirectedRelTableData::fillNbrTableIDs(common::ValueVector* vector) { - assert(vector->dataType.typeID == INTERNAL_ID); + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::INTERNAL_ID); auto nodeIDs = (internalID_t*)vector->getData(); for (auto i = 0u; i < vector->state->selVector->selectedSize; i++) { auto pos = vector->state->selVector->selectedPositions[i]; diff --git a/src/storage/wal_replayer_utils.cpp b/src/storage/wal_replayer_utils.cpp index 5e2f9c4ec6..e59a109bd0 100644 --- a/src/storage/wal_replayer_utils.cpp +++ b/src/storage/wal_replayer_utils.cpp @@ -39,15 +39,15 @@ void WALReplayerUtils::createEmptyDBFilesForNewRelTable(RelTableSchema* relTable void WALReplayerUtils::createEmptyDBFilesForNewNodeTable( NodeTableSchema* nodeTableSchema, const std::string& directory) { for (auto& property : nodeTableSchema->properties) { - if (property.dataType.typeID == SERIAL) { + if (property.dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) { continue; } auto fName = StorageUtils::getNodePropertyColumnFName( directory, nodeTableSchema->tableID, property.propertyID, DBFileType::ORIGINAL); std::make_unique(fName, property.dataType)->saveToFile(); } - switch (nodeTableSchema->getPrimaryKey().dataType.typeID) { - case INT64: { + switch (nodeTableSchema->getPrimaryKey().dataType.getLogicalTypeID()) { + case LogicalTypeID::INT64: { auto pkIndex = make_unique>( StorageUtils::getNodeIndexFName( directory, nodeTableSchema->tableID, DBFileType::ORIGINAL), @@ -55,7 +55,7 @@ void WALReplayerUtils::createEmptyDBFilesForNewNodeTable( pkIndex->bulkReserve(0 /* numNodes */); pkIndex->flush(); } break; - case STRING: { + case LogicalTypeID::STRING: { auto pkIndex = make_unique>( StorageUtils::getNodeIndexFName( directory, nodeTableSchema->tableID, DBFileType::ORIGINAL), @@ -63,7 +63,7 @@ void WALReplayerUtils::createEmptyDBFilesForNewNodeTable( pkIndex->bulkReserve(0 /* numNodes */); pkIndex->flush(); } break; - case SERIAL: { + case LogicalTypeID::SERIAL: { // DO NOTHING. } break; default: { @@ -130,7 +130,7 @@ void WALReplayerUtils::createEmptyDBFilesForColumns( maxNodeOffsetsPerTable.at(boundTableID) + 1; make_unique(StorageUtils::getAdjColumnFName(directory, relTableSchema->tableID, relDirection, DBFileType::ORIGINAL), - DataType(INTERNAL_ID)) + LogicalType(LogicalTypeID::INTERNAL_ID)) ->saveToFile(); createEmptyDBFilesForRelProperties( relTableSchema, directory, relDirection, numNodes, true /* isForRelPropertyColumn */); @@ -194,10 +194,8 @@ void WALReplayerUtils::fileOperationOnNodeFiles(NodeTableSchema* nodeTableSchema const std::string& directory, std::function columnFileOperation, std::function listFileOperation) { for (auto& property : nodeTableSchema->properties) { - if (property.dataType.typeID == common::STRUCT) { - auto structFields = - reinterpret_cast(property.dataType.getExtraTypeInfo()) - ->getStructFields(); + if (property.dataType.getLogicalTypeID() == common::LogicalTypeID::STRUCT) { + auto structFields = StructType::getStructFields(&property.dataType); auto structColumnFName = StorageUtils::getNodePropertyColumnFName( directory, nodeTableSchema->tableID, property.propertyID, DBFileType::ORIGINAL); for (auto& structField : structFields) { diff --git a/test/c_api/data_type_test.cpp b/test/c_api/data_type_test.cpp index fe2e33f515..a5c3054f61 100644 --- a/test/c_api/data_type_test.cpp +++ b/test/c_api/data_type_test.cpp @@ -12,22 +12,21 @@ class CApiDataTypeTest : public CApiTest { TEST_F(CApiDataTypeTest, Create) { auto dataType = kuzu_data_type_create(kuzu_data_type_id::INT64, nullptr, 0); ASSERT_NE(dataType, nullptr); - auto dataTypeCpp = (DataType*)dataType->_data_type; - ASSERT_EQ(dataTypeCpp->getTypeID(), DataTypeID::INT64); + auto dataTypeCpp = (LogicalType*)dataType->_data_type; + ASSERT_EQ(dataTypeCpp->getLogicalTypeID(), LogicalTypeID::INT64); auto dataType2 = kuzu_data_type_create(kuzu_data_type_id::VAR_LIST, dataType, 0); ASSERT_NE(dataType2, nullptr); - auto dataTypeCpp2 = (DataType*)dataType2->_data_type; - ASSERT_EQ(dataTypeCpp2->getTypeID(), DataTypeID::VAR_LIST); - ASSERT_EQ(dataTypeCpp2->getChildType()->getTypeID(), DataTypeID::INT64); + auto dataTypeCpp2 = (LogicalType*)dataType2->_data_type; + ASSERT_EQ(dataTypeCpp2->getLogicalTypeID(), LogicalTypeID::VAR_LIST); + // ASSERT_EQ(dataTypeCpp2->getChildType()->getLogicalTypeID(), LogicalTypeID::INT64); auto dataType3 = kuzu_data_type_create(kuzu_data_type_id::FIXED_LIST, dataType, 100); ASSERT_NE(dataType3, nullptr); - auto dataTypeCpp3 = (DataType*)dataType3->_data_type; - ASSERT_EQ(dataTypeCpp3->getTypeID(), DataTypeID::FIXED_LIST); - ASSERT_EQ(dataTypeCpp3->getChildType()->getTypeID(), DataTypeID::INT64); - auto extraInfo = (FixedListTypeInfo*)dataTypeCpp3->getExtraTypeInfo(); - ASSERT_EQ(extraInfo->getFixedNumElementsInList(), 100); + auto dataTypeCpp3 = (LogicalType*)dataType3->_data_type; + ASSERT_EQ(dataTypeCpp3->getLogicalTypeID(), LogicalTypeID::FIXED_LIST); + // ASSERT_EQ(dataTypeCpp3->getChildType()->getLogicalTypeID(), LogicalTypeID::INT64); + ASSERT_EQ(FixedListType::getNumElementsInList(dataTypeCpp3), 100); // Since child type is copied, we should be able to destroy the original type without an error. kuzu_data_type_destroy(dataType); @@ -40,24 +39,24 @@ TEST_F(CApiDataTypeTest, Clone) { ASSERT_NE(dataType, nullptr); auto dataTypeClone = kuzu_data_type_clone(dataType); ASSERT_NE(dataTypeClone, nullptr); - auto dataTypeCpp = (DataType*)dataType->_data_type; - auto dataTypeCloneCpp = (DataType*)dataTypeClone->_data_type; + auto dataTypeCpp = (LogicalType*)dataType->_data_type; + auto dataTypeCloneCpp = (LogicalType*)dataTypeClone->_data_type; ASSERT_TRUE(*dataTypeCpp == *dataTypeCloneCpp); auto dataType2 = kuzu_data_type_create(kuzu_data_type_id::VAR_LIST, dataType, 0); ASSERT_NE(dataType2, nullptr); auto dataTypeClone2 = kuzu_data_type_clone(dataType2); ASSERT_NE(dataTypeClone2, nullptr); - auto dataTypeCpp2 = (DataType*)dataType2->_data_type; - auto dataTypeCloneCpp2 = (DataType*)dataTypeClone2->_data_type; + auto dataTypeCpp2 = (LogicalType*)dataType2->_data_type; + auto dataTypeCloneCpp2 = (LogicalType*)dataTypeClone2->_data_type; ASSERT_TRUE(*dataTypeCpp2 == *dataTypeCloneCpp2); auto dataType3 = kuzu_data_type_create(kuzu_data_type_id::FIXED_LIST, dataType, 100); ASSERT_NE(dataType3, nullptr); auto dataTypeClone3 = kuzu_data_type_clone(dataType3); ASSERT_NE(dataTypeClone3, nullptr); - auto dataTypeCpp3 = (DataType*)dataType3->_data_type; - auto dataTypeCloneCpp3 = (DataType*)dataTypeClone3->_data_type; + auto dataTypeCpp3 = (LogicalType*)dataType3->_data_type; + auto dataTypeCloneCpp3 = (LogicalType*)dataTypeClone3->_data_type; ASSERT_TRUE(*dataTypeCpp3 == *dataTypeCloneCpp3); kuzu_data_type_destroy(dataType); @@ -117,30 +116,32 @@ TEST_F(CApiDataTypeTest, GetID) { kuzu_data_type_destroy(dataType3); } -TEST_F(CApiDataTypeTest, GetChildType) { - auto dataType = kuzu_data_type_create(kuzu_data_type_id::INT64, nullptr, 0); - ASSERT_NE(dataType, nullptr); - ASSERT_EQ(kuzu_data_type_get_child_type(dataType), nullptr); - - auto dataType2 = kuzu_data_type_create(kuzu_data_type_id::VAR_LIST, dataType, 0); - ASSERT_NE(dataType2, nullptr); - auto childType2 = kuzu_data_type_get_child_type(dataType2); - ASSERT_NE(childType2, nullptr); - ASSERT_EQ(kuzu_data_type_get_id(childType2), kuzu_data_type_id::INT64); - kuzu_data_type_destroy(childType2); - kuzu_data_type_destroy(dataType2); - - auto dataType3 = kuzu_data_type_create(kuzu_data_type_id::FIXED_LIST, dataType, 100); - ASSERT_NE(dataType3, nullptr); - auto childType3 = kuzu_data_type_get_child_type(dataType3); - kuzu_data_type_destroy(dataType3); - // Destroying dataType3 should not destroy childType3. - ASSERT_NE(childType3, nullptr); - ASSERT_EQ(kuzu_data_type_get_id(childType3), kuzu_data_type_id::INT64); - kuzu_data_type_destroy(childType3); - - kuzu_data_type_destroy(dataType); -} +// TODO(Chang): The getChildType interface has been removed from the C++ DataType class. +// Consider adding the StructType/ListType helper to C binding. +// TEST_F(CApiDataTypeTest, GetChildType) { +// auto dataType = kuzu_data_type_create(kuzu_data_type_id::INT64, nullptr, 0); +// ASSERT_NE(dataType, nullptr); +// ASSERT_EQ(kuzu_data_type_get_child_type(dataType), nullptr); +// +// auto dataType2 = kuzu_data_type_create(kuzu_data_type_id::VAR_LIST, dataType, 0); +// ASSERT_NE(dataType2, nullptr); +// auto childType2 = kuzu_data_type_get_child_type(dataType2); +// ASSERT_NE(childType2, nullptr); +// ASSERT_EQ(kuzu_data_type_get_id(childType2), kuzu_data_type_id::INT64); +// kuzu_data_type_destroy(childType2); +// kuzu_data_type_destroy(dataType2); +// +// auto dataType3 = kuzu_data_type_create(kuzu_data_type_id::FIXED_LIST, dataType, 100); +// ASSERT_NE(dataType3, nullptr); +// auto childType3 = kuzu_data_type_get_child_type(dataType3); +// kuzu_data_type_destroy(dataType3); +// // Destroying dataType3 should not destroy childType3. +// ASSERT_NE(childType3, nullptr); +// ASSERT_EQ(kuzu_data_type_get_id(childType3), kuzu_data_type_id::INT64); +// kuzu_data_type_destroy(childType3); +// +// kuzu_data_type_destroy(dataType); +//} TEST_F(CApiDataTypeTest, GetFixedNumElementsInList) { auto dataType = kuzu_data_type_create(kuzu_data_type_id::INT64, nullptr, 0); diff --git a/test/c_api/flat_tuple_test.cpp b/test/c_api/flat_tuple_test.cpp index 4939841ec2..4f133ced08 100644 --- a/test/c_api/flat_tuple_test.cpp +++ b/test/c_api/flat_tuple_test.cpp @@ -20,21 +20,21 @@ TEST_F(CApiFlatTupleTest, GetValue) { ASSERT_NE(value, nullptr); auto valueCpp = static_cast(value->_value); ASSERT_NE(valueCpp, nullptr); - ASSERT_EQ(valueCpp->getDataType().getTypeID(), DataTypeID::STRING); + ASSERT_EQ(valueCpp->getDataType().getLogicalTypeID(), LogicalTypeID::STRING); ASSERT_EQ(valueCpp->getValue(), "Alice"); kuzu_value_destroy(value); value = kuzu_flat_tuple_get_value(flatTuple, 1); ASSERT_NE(value, nullptr); valueCpp = static_cast(value->_value); ASSERT_NE(valueCpp, nullptr); - ASSERT_EQ(valueCpp->getDataType().getTypeID(), DataTypeID::INT64); + ASSERT_EQ(valueCpp->getDataType().getLogicalTypeID(), LogicalTypeID::INT64); ASSERT_EQ(valueCpp->getValue(), 35); kuzu_value_destroy(value); value = kuzu_flat_tuple_get_value(flatTuple, 2); ASSERT_NE(value, nullptr); valueCpp = static_cast(value->_value); ASSERT_NE(valueCpp, nullptr); - ASSERT_EQ(valueCpp->getDataType().getTypeID(), DataTypeID::FLOAT); + ASSERT_EQ(valueCpp->getDataType().getLogicalTypeID(), LogicalTypeID::FLOAT); ASSERT_FLOAT_EQ(valueCpp->getValue(), 1.731); kuzu_value_destroy(value); value = kuzu_flat_tuple_get_value(flatTuple, 222); diff --git a/test/c_api/query_result_test.cpp b/test/c_api/query_result_test.cpp index 83f705a4ac..eb6c20ce51 100644 --- a/test/c_api/query_result_test.cpp +++ b/test/c_api/query_result_test.cpp @@ -62,16 +62,16 @@ TEST_F(CApiQueryResultTest, GetColumnDataType) { kuzu_connection_query(connection, "MATCH (a:person) RETURN a.fName, a.age, a.height"); ASSERT_TRUE(kuzu_query_result_is_success(result)); auto type = kuzu_query_result_get_column_data_type(result, 0); - auto typeCpp = (DataType*)(type->_data_type); - ASSERT_EQ(typeCpp->getTypeID(), DataTypeID::STRING); + auto typeCpp = (LogicalType*)(type->_data_type); + ASSERT_EQ(typeCpp->getLogicalTypeID(), LogicalTypeID::STRING); kuzu_data_type_destroy(type); type = kuzu_query_result_get_column_data_type(result, 1); - typeCpp = (DataType*)(type->_data_type); - ASSERT_EQ(typeCpp->getTypeID(), DataTypeID::INT64); + typeCpp = (LogicalType*)(type->_data_type); + ASSERT_EQ(typeCpp->getLogicalTypeID(), LogicalTypeID::INT64); kuzu_data_type_destroy(type); type = kuzu_query_result_get_column_data_type(result, 2); - typeCpp = (DataType*)(type->_data_type); - ASSERT_EQ(typeCpp->getTypeID(), DataTypeID::FLOAT); + typeCpp = (LogicalType*)(type->_data_type); + ASSERT_EQ(typeCpp->getLogicalTypeID(), LogicalTypeID::FLOAT); kuzu_data_type_destroy(type); type = kuzu_query_result_get_column_data_type(result, 222); ASSERT_EQ(type, nullptr); diff --git a/test/c_api/value_test.cpp b/test/c_api/value_test.cpp index b327d17c6f..560104d3e6 100644 --- a/test/c_api/value_test.cpp +++ b/test/c_api/value_test.cpp @@ -13,7 +13,7 @@ TEST_F(CApiValueTest, CreateNull) { kuzu_value* value = kuzu_value_create_null(); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::ANY); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::ANY); ASSERT_EQ(cppValue->isNull(), true); kuzu_value_destroy(value); } @@ -24,7 +24,7 @@ TEST_F(CApiValueTest, CreateNullWithDatatype) { ASSERT_FALSE(value->_is_owned_by_cpp); kuzu_data_type_destroy(type); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INT64); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INT64); ASSERT_EQ(cppValue->isNull(), true); kuzu_value_destroy(value); } @@ -57,7 +57,7 @@ TEST_F(CApiValueTest, CreateDefault) { kuzu_data_type_destroy(type); auto cppValue = static_cast(value->_value); ASSERT_FALSE(kuzu_value_is_null(value)); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INT64); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INT64); ASSERT_EQ(cppValue->getValue(), 0); kuzu_value_destroy(value); @@ -67,7 +67,7 @@ TEST_F(CApiValueTest, CreateDefault) { kuzu_data_type_destroy(type); cppValue = static_cast(value->_value); ASSERT_FALSE(kuzu_value_is_null(value)); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::STRING); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::STRING); ASSERT_EQ(cppValue->getValue(), ""); kuzu_value_destroy(value); } @@ -76,14 +76,14 @@ TEST_F(CApiValueTest, CreateBool) { kuzu_value* value = kuzu_value_create_bool(true); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::BOOL); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::BOOL); ASSERT_EQ(cppValue->getValue(), true); kuzu_value_destroy(value); value = kuzu_value_create_bool(false); ASSERT_FALSE(value->_is_owned_by_cpp); cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::BOOL); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::BOOL); ASSERT_EQ(cppValue->getValue(), false); kuzu_value_destroy(value); } @@ -92,7 +92,7 @@ TEST_F(CApiValueTest, CreateInt16) { kuzu_value* value = kuzu_value_create_int16(123); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INT16); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INT16); ASSERT_EQ(cppValue->getValue(), 123); kuzu_value_destroy(value); } @@ -101,7 +101,7 @@ TEST_F(CApiValueTest, CreateInt32) { kuzu_value* value = kuzu_value_create_int32(123); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INT32); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INT32); ASSERT_EQ(cppValue->getValue(), 123); kuzu_value_destroy(value); } @@ -110,7 +110,7 @@ TEST_F(CApiValueTest, CreateInt64) { kuzu_value* value = kuzu_value_create_int64(123); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INT64); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INT64); ASSERT_EQ(cppValue->getValue(), 123); kuzu_value_destroy(value); } @@ -119,7 +119,7 @@ TEST_F(CApiValueTest, CreateFloat) { kuzu_value* value = kuzu_value_create_float(123.456); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::FLOAT); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::FLOAT); ASSERT_FLOAT_EQ(cppValue->getValue(), 123.456); kuzu_value_destroy(value); } @@ -128,7 +128,7 @@ TEST_F(CApiValueTest, CreateDouble) { kuzu_value* value = kuzu_value_create_double(123.456); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::DOUBLE); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::DOUBLE); ASSERT_DOUBLE_EQ(cppValue->getValue(), 123.456); kuzu_value_destroy(value); } @@ -138,7 +138,7 @@ TEST_F(CApiValueTest, CreateInternalID) { kuzu_value* value = kuzu_value_create_internal_id(internalID); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INTERNAL_ID); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INTERNAL_ID); auto internalIDCpp = cppValue->getValue(); ASSERT_EQ(internalIDCpp.tableID, 1); ASSERT_EQ(internalIDCpp.offset, 123); @@ -151,7 +151,7 @@ TEST_F(CApiValueTest, CreateNodeVal) { auto value = kuzu_value_create_node_val(nodeVal); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::NODE); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::NODE); auto nodeValCpp = cppValue->getValue(); ASSERT_EQ(nodeValCpp.getNodeID().tableID, 1); ASSERT_EQ(nodeValCpp.getNodeID().offset, 123); @@ -168,7 +168,7 @@ TEST_F(CApiValueTest, CreateRelVal) { auto value = kuzu_value_create_rel_val(relVal); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::REL); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::REL); auto relValCpp = cppValue->getValue(); ASSERT_EQ(relValCpp.getSrcNodeID().tableID, 1); ASSERT_EQ(relValCpp.getSrcNodeID().offset, 123); @@ -184,7 +184,7 @@ TEST_F(CApiValueTest, CreateDate) { kuzu_value* value = kuzu_value_create_date(kuzu_date_t{123}); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::DATE); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::DATE); auto cppDate = cppValue->getValue(); ASSERT_EQ(cppDate.days, 123); kuzu_value_destroy(value); @@ -194,7 +194,7 @@ TEST_F(CApiValueTest, CreateTimeStamp) { kuzu_value* value = kuzu_value_create_timestamp(kuzu_timestamp_t{123}); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::TIMESTAMP); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::TIMESTAMP); auto cppTimeStamp = cppValue->getValue(); ASSERT_EQ(cppTimeStamp.value, 123); kuzu_value_destroy(value); @@ -204,7 +204,7 @@ TEST_F(CApiValueTest, CreateInterval) { kuzu_value* value = kuzu_value_create_interval(kuzu_interval_t{12, 3, 300}); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::INTERVAL); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::INTERVAL); auto cppTimeStamp = cppValue->getValue(); ASSERT_EQ(cppTimeStamp.months, 12); ASSERT_EQ(cppTimeStamp.days, 3); @@ -216,7 +216,7 @@ TEST_F(CApiValueTest, CreateString) { kuzu_value* value = kuzu_value_create_string((char*)"abcdefg"); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::STRING); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::STRING); ASSERT_EQ(cppValue->getValue(), "abcdefg"); kuzu_value_destroy(value); } @@ -225,7 +225,7 @@ TEST_F(CApiValueTest, Clone) { kuzu_value* value = kuzu_value_create_string((char*)"abcdefg"); ASSERT_FALSE(value->_is_owned_by_cpp); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::STRING); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::STRING); ASSERT_EQ(cppValue->getValue(), "abcdefg"); kuzu_value* clone = kuzu_value_clone(value); @@ -233,7 +233,7 @@ TEST_F(CApiValueTest, Clone) { ASSERT_FALSE(clone->_is_owned_by_cpp); auto cppClone = static_cast(clone->_value); - ASSERT_EQ(cppClone->getDataType().getTypeID(), DataTypeID::STRING); + ASSERT_EQ(cppClone->getDataType().getLogicalTypeID(), LogicalTypeID::STRING); ASSERT_EQ(cppClone->getValue(), "abcdefg"); kuzu_value_destroy(clone); } @@ -247,7 +247,7 @@ TEST_F(CApiValueTest, Copy) { ASSERT_FALSE(kuzu_value_is_null(value)); auto cppValue = static_cast(value->_value); - ASSERT_EQ(cppValue->getDataType().getTypeID(), DataTypeID::STRING); + ASSERT_EQ(cppValue->getDataType().getLogicalTypeID(), LogicalTypeID::STRING); ASSERT_EQ(cppValue->getValue(), "abcdefg"); kuzu_value_destroy(value); } @@ -353,21 +353,21 @@ TEST_F(CApiValueTest, GetStructFieldValue) { auto fieldValue = kuzu_value_get_struct_field_value(value, 0); auto fieldType = kuzu_value_get_data_type(fieldValue); - ASSERT_EQ(kuzu_data_type_get_id(fieldType), DataTypeID::DOUBLE); + ASSERT_EQ(kuzu_data_type_get_id(fieldType), DOUBLE); ASSERT_DOUBLE_EQ(kuzu_value_get_double(fieldValue), 1223); kuzu_data_type_destroy(fieldType); kuzu_value_destroy(fieldValue); fieldValue = kuzu_value_get_struct_field_value(value, 1); fieldType = kuzu_value_get_data_type(fieldValue); - ASSERT_EQ(kuzu_data_type_get_id(fieldType), DataTypeID::INT64); + ASSERT_EQ(kuzu_data_type_get_id(fieldType), INT64); ASSERT_EQ(kuzu_value_get_int64(fieldValue), 10003); kuzu_data_type_destroy(fieldType); kuzu_value_destroy(fieldValue); fieldValue = kuzu_value_get_struct_field_value(value, 2); fieldType = kuzu_value_get_data_type(fieldValue); - ASSERT_EQ(kuzu_data_type_get_id(fieldType), DataTypeID::TIMESTAMP); + ASSERT_EQ(kuzu_data_type_get_id(fieldType), TIMESTAMP); auto timestamp = kuzu_value_get_timestamp(fieldValue); ASSERT_EQ(timestamp.value, 1297442662000000); kuzu_data_type_destroy(fieldType); @@ -375,7 +375,7 @@ TEST_F(CApiValueTest, GetStructFieldValue) { fieldValue = kuzu_value_get_struct_field_value(value, 3); fieldType = kuzu_value_get_data_type(fieldValue); - ASSERT_EQ(kuzu_data_type_get_id(fieldType), DataTypeID::DATE); + ASSERT_EQ(kuzu_data_type_get_id(fieldType), DATE); auto date = kuzu_value_get_date(fieldValue); ASSERT_EQ(date.days, 15758); kuzu_data_type_destroy(fieldType); @@ -395,22 +395,19 @@ TEST_F(CApiValueTest, GetDataType) { auto flatTuple = kuzu_query_result_get_next(result); auto value = kuzu_flat_tuple_get_value(flatTuple, 0); auto dataType = kuzu_value_get_data_type(value); - ASSERT_EQ(kuzu_data_type_get_id(dataType), DataTypeID::STRING); + ASSERT_EQ(kuzu_data_type_get_id(dataType), STRING); kuzu_data_type_destroy(dataType); kuzu_value_destroy(value); value = kuzu_flat_tuple_get_value(flatTuple, 1); dataType = kuzu_value_get_data_type(value); - ASSERT_EQ(kuzu_data_type_get_id(dataType), DataTypeID::BOOL); + ASSERT_EQ(kuzu_data_type_get_id(dataType), BOOL); kuzu_data_type_destroy(dataType); kuzu_value_destroy(value); value = kuzu_flat_tuple_get_value(flatTuple, 2); dataType = kuzu_value_get_data_type(value); - ASSERT_EQ(kuzu_data_type_get_id(dataType), DataTypeID::VAR_LIST); - auto childDataType = kuzu_data_type_get_child_type(dataType); - ASSERT_EQ(kuzu_data_type_get_id(childDataType), DataTypeID::INT64); - kuzu_data_type_destroy(childDataType); + ASSERT_EQ(kuzu_data_type_get_id(dataType), VAR_LIST); kuzu_data_type_destroy(dataType); kuzu_value_destroy(value); diff --git a/test/copy/copy_lists_test.cpp b/test/copy/copy_lists_test.cpp index ac9ce316b4..dabcf364f1 100644 --- a/test/copy/copy_lists_test.cpp +++ b/test/copy/copy_lists_test.cpp @@ -10,7 +10,7 @@ class TinySnbListTest : public DBTest { public: static bool CheckEquals(const std::vector& expected, const Value& listVal) { - if (listVal.dataType.typeID != VAR_LIST) { + if (listVal.dataType.getLogicalTypeID() != LogicalTypeID::VAR_LIST) { return false; } if (expected.size() != listVal.nestedTypeVal.size()) { @@ -68,7 +68,6 @@ TEST_F(TinySnbListTest, RelPropertyColumnWithList) { auto graph = getStorageManager(*database); auto catalog = getCatalog(*database); auto tableID = catalog->getReadOnlyVersion()->getTableID("studyAt"); - auto nodeTablesForAdjColumnAndProperties = catalog->getReadOnlyVersion()->getTableID("person"); auto& property = catalog->getReadOnlyVersion()->getRelProperty(tableID, "places"); auto col = graph->getRelsStore().getRelPropertyColumn( RelDataDirection::FWD, tableID, property.propertyID); diff --git a/test/copy/npy_reader_test.cpp b/test/copy/npy_reader_test.cpp index af30d23041..abc9d12c98 100644 --- a/test/copy/npy_reader_test.cpp +++ b/test/copy/npy_reader_test.cpp @@ -24,7 +24,7 @@ TEST_F(NpyReaderTest, ReadNpyOneDimensionalInt64) { std::make_unique(getInputDirForOneDimensional() + "one_dim_int64.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 1); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT64); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT64); ASSERT_EQ(npyReader->getNumDimensions(), 1); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -46,7 +46,7 @@ TEST_F(NpyReaderTest, ReadNpyTwoDimensionalInt64) { std::make_unique(getInputDirForTwoDimensional() + "two_dim_int64.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 3); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT64); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT64); ASSERT_EQ(npyReader->getNumDimensions(), 2); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -72,7 +72,7 @@ TEST_F(NpyReaderTest, ReadNpyThreeDimensionalInt64) { std::make_unique(getInputDirForThreeDimensional() + "three_dim_int64.npy"); ASSERT_EQ(npyReader->getNumRows(), 2); ASSERT_EQ(npyReader->getNumElementsPerRow(), 12); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT64); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT64); ASSERT_EQ(npyReader->getNumDimensions(), 3); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 2); @@ -112,7 +112,7 @@ TEST_F(NpyReaderTest, ReadNpyOneDimensionalInt32) { std::make_unique(getInputDirForOneDimensional() + "one_dim_int32.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 1); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT32); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT32); ASSERT_EQ(npyReader->getNumDimensions(), 1); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -134,7 +134,7 @@ TEST_F(NpyReaderTest, ReadNpyTwoDimensionalInt32) { std::make_unique(getInputDirForTwoDimensional() + "two_dim_int32.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 3); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT32); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT32); ASSERT_EQ(npyReader->getNumDimensions(), 2); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -160,7 +160,7 @@ TEST_F(NpyReaderTest, ReadNpyOneDimensionalInt16) { std::make_unique(getInputDirForOneDimensional() + "one_dim_int16.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 1); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT16); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT16); ASSERT_EQ(npyReader->getNumDimensions(), 1); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -182,7 +182,7 @@ TEST_F(NpyReaderTest, ReadNpyTwoDimensionalInt16) { std::make_unique(getInputDirForTwoDimensional() + "two_dim_int16.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 3); - ASSERT_EQ(npyReader->getType(), DataTypeID::INT16); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::INT16); ASSERT_EQ(npyReader->getNumDimensions(), 2); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -208,7 +208,7 @@ TEST_F(NpyReaderTest, ReadNpyOneDimensionalDouble) { std::make_unique(getInputDirForOneDimensional() + "one_dim_double.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 1); - ASSERT_EQ(npyReader->getType(), DataTypeID::DOUBLE); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::DOUBLE); ASSERT_EQ(npyReader->getNumDimensions(), 1); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -230,7 +230,7 @@ TEST_F(NpyReaderTest, ReadNpyTwoDimensionalDouble) { std::make_unique(getInputDirForTwoDimensional() + "two_dim_double.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 3); - ASSERT_EQ(npyReader->getType(), DataTypeID::DOUBLE); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::DOUBLE); ASSERT_EQ(npyReader->getNumDimensions(), 2); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -256,7 +256,7 @@ TEST_F(NpyReaderTest, ReadNpyOneDimensionalFloat) { std::make_unique(getInputDirForOneDimensional() + "one_dim_float.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 1); - ASSERT_EQ(npyReader->getType(), DataTypeID::FLOAT); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::FLOAT); ASSERT_EQ(npyReader->getNumDimensions(), 1); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); @@ -278,7 +278,7 @@ TEST_F(NpyReaderTest, ReadNpyTwoDimensionalFloat) { std::make_unique(getInputDirForTwoDimensional() + "two_dim_float.npy"); ASSERT_EQ(npyReader->getNumRows(), 3); ASSERT_EQ(npyReader->getNumElementsPerRow(), 3); - ASSERT_EQ(npyReader->getType(), DataTypeID::FLOAT); + ASSERT_EQ(npyReader->getType(), LogicalTypeID::FLOAT); ASSERT_EQ(npyReader->getNumDimensions(), 2); auto shape = npyReader->getShape(); ASSERT_EQ(shape[0], 3); diff --git a/test/graph_test/graph_test.cpp b/test/graph_test/graph_test.cpp index 36bb65a2c9..ce2beaebb6 100644 --- a/test/graph_test/graph_test.cpp +++ b/test/graph_test/graph_test.cpp @@ -44,11 +44,12 @@ void BaseGraphTest::validateNodeColumnFilesExistence( for (auto& property : nodeTableSchema->properties) { validateColumnFilesExistence(StorageUtils::getNodePropertyColumnFName(databasePath, nodeTableSchema->tableID, property.propertyID, dbFileType), - existence, containsOverflowFile(property.dataType.typeID)); + existence, containsOverflowFile(property.dataType.getLogicalTypeID())); } validateColumnFilesExistence( StorageUtils::getNodeIndexFName(databasePath, nodeTableSchema->tableID, dbFileType), - existence, containsOverflowFile(nodeTableSchema->getPrimaryKey().dataType.typeID)); + existence, + containsOverflowFile(nodeTableSchema->getPrimaryKey().dataType.getLogicalTypeID())); } void BaseGraphTest::validateRelColumnAndListFilesExistence( @@ -96,7 +97,7 @@ void BaseGraphTest::commitOrRollbackConnectionAndInitDBIfNecessary( void BaseGraphTest::validateRelPropertyFiles(catalog::RelTableSchema* relTableSchema, RelDataDirection relDirection, bool isColumnProperty, DBFileType dbFileType, bool existence) { for (auto& property : relTableSchema->properties) { - auto hasOverflow = containsOverflowFile(property.dataType.typeID); + auto hasOverflow = containsOverflowFile(property.dataType.getLogicalTypeID()); if (isColumnProperty) { validateColumnFilesExistence( StorageUtils::getRelPropertyColumnFName(databasePath, relTableSchema->tableID, diff --git a/test/include/graph_test/graph_test.h b/test/include/graph_test/graph_test.h index a9ad2c9ec5..e8c33567e9 100644 --- a/test/include/graph_test/graph_test.h +++ b/test/include/graph_test/graph_test.h @@ -109,8 +109,8 @@ class BaseGraphTest : public Test { sort(expectedResult.begin(), expectedResult.end()); ASSERT_EQ(actualResult, expectedResult); } - static inline bool containsOverflowFile(common::DataTypeID typeID) { - return typeID == common::STRING || typeID == common::VAR_LIST; + static inline bool containsOverflowFile(common::LogicalTypeID typeID) { + return typeID == common::LogicalTypeID::STRING || typeID == common::LogicalTypeID::VAR_LIST; } void validateColumnFilesExistence(std::string fileName, bool existence, bool hasOverflow); diff --git a/test/processor/order_by/key_block_merger_test.cpp b/test/processor/order_by/key_block_merger_test.cpp index 7b2399f845..190d8eb50c 100644 --- a/test/processor/order_by/key_block_merger_test.cpp +++ b/test/processor/order_by/key_block_merger_test.cpp @@ -53,7 +53,7 @@ class KeyBlockMergerTest : public Test { template OrderByKeyEncoder prepareSingleOrderByColEncoder(const std::vector& sortingData, - const std::vector& nullMasks, DataTypeID dataTypeID, bool isAsc, + const std::vector& nullMasks, LogicalTypeID dataTypeID, bool isAsc, uint16_t factorizedTableIdx, bool hasPayLoadCol, std::vector>& factorizedTables, std::shared_ptr& dataChunk) { @@ -76,11 +76,12 @@ class KeyBlockMergerTest : public Test { std::unique_ptr tableSchema = std::make_unique(); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(dataTypeID))); + tableSchema->appendColumn(std::make_unique(false /* isUnflat */, + 0 /* dataChunkPos */, FactorizedTable::getDataTypeSize(LogicalType{dataTypeID}))); if (hasPayLoadCol) { - auto payloadValueVector = std::make_shared(STRING, memoryManager.get()); + auto payloadValueVector = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); for (auto i = 0u; i < dataChunk->state->selVector->selectedSize; i++) { payloadValueVector->setValue(i, std::to_string(i)); } @@ -88,8 +89,8 @@ class KeyBlockMergerTest : public Test { // To test whether the orderByCol -> factorizedTableColIdx works properly, we put the // payload column at index 0, and the orderByCol at index 1. allVectors.insert(allVectors.begin(), payloadValueVector.get()); - tableSchema->appendColumn(std::make_unique( - false, 0 /* dataChunkPos */, Types::getDataTypeSize(dataTypeID))); + tableSchema->appendColumn(std::make_unique(false, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{dataTypeID}))); } auto factorizedTable = @@ -111,8 +112,8 @@ class KeyBlockMergerTest : public Test { const std::vector& leftNullMasks, const std::vector& rightSortingData, const std::vector& rightNullMasks, const std::vector& expectedBlockOffsetOrder, - const std::vector& expectedFactorizedTableIdxOrder, const DataTypeID dataTypeID, - const bool isAsc, bool hasPayLoadCol) { + const std::vector& expectedFactorizedTableIdxOrder, + const LogicalTypeID dataTypeID, const bool isAsc, bool hasPayLoadCol) { std::vector> factorizedTables; auto dataChunk0 = std::make_shared(hasPayLoadCol ? 2 : 1); auto dataChunk1 = std::make_shared(hasPayLoadCol ? 2 : 1); @@ -191,9 +192,12 @@ class KeyBlockMergerTest : public Test { dataChunk->state->initOriginalAndSelectedSize(int64Values.size()); dataChunk->state->currIdx = 0; - auto int64ValueVector = std::make_shared(INT64, memoryManager.get()); - auto doubleValueVector = std::make_shared(DOUBLE, memoryManager.get()); - auto timestampValueVector = std::make_shared(TIMESTAMP, memoryManager.get()); + auto int64ValueVector = + std::make_shared(LogicalTypeID::INT64, memoryManager.get()); + auto doubleValueVector = + std::make_shared(LogicalTypeID::DOUBLE, memoryManager.get()); + auto timestampValueVector = + std::make_shared(LogicalTypeID::TIMESTAMP, memoryManager.get()); dataChunk->insert(0, int64ValueVector); dataChunk->insert(1, doubleValueVector); @@ -230,18 +234,24 @@ class KeyBlockMergerTest : public Test { std::unique_ptr tableSchema = std::make_unique(); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(INT64))); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(DOUBLE))); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(TIMESTAMP))); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::INT64}))); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::DOUBLE}))); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::TIMESTAMP}))); if (hasStrCol) { - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); - auto stringValueVector1 = std::make_shared(STRING, memoryManager.get()); - auto stringValueVector2 = std::make_shared(STRING, memoryManager.get()); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::STRING}))); + auto stringValueVector1 = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); + auto stringValueVector2 = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); dataChunk1->insert(3, stringValueVector1); dataChunk2->insert(3, stringValueVector2); @@ -272,8 +282,9 @@ class KeyBlockMergerTest : public Test { if (hasStrCol) { strKeyColsInfo.emplace_back( StrKeyColInfo(tableSchema->getColOffset(3 /* colIdx */) /* colOffsetInFT */, - Types::getDataTypeSize(INT64) + Types::getDataTypeSize(DOUBLE) + - Types::getDataTypeSize(TIMESTAMP) + 3, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::INT64}) + + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::DOUBLE}) + + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::TIMESTAMP}) + 3, true /* isAscOrder */)); expectedBlockOffsetOrder = {0, 0, 1, 1, 2, 2, 3}; expectedFactorizedTableIdxOrder = {4, 5, 4, 5, 4, 5, 4}; @@ -306,7 +317,8 @@ class KeyBlockMergerTest : public Test { dataChunk->state->currIdx = 0; dataChunk->state->initOriginalAndSelectedSize(strValues[0].size()); for (auto i = 0u; i < strValues.size(); i++) { - auto strValueVector = std::make_shared(STRING, memoryManager.get()); + auto strValueVector = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); dataChunk->insert(i, strValueVector); for (auto j = 0u; j < strValues[i].size(); j++) { strValueVector->setValue(j, strValues[i][j]); @@ -323,14 +335,16 @@ class KeyBlockMergerTest : public Test { std::unique_ptr tableSchema = std::make_unique(); + auto stringColumnSize = + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::STRING}); tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); + false /* isUnflat */, 0 /* dataChunkPos */, stringColumnSize)); tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); + false /* isUnflat */, 0 /* dataChunkPos */, stringColumnSize)); tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); + false /* isUnflat */, 0 /* dataChunkPos */, stringColumnSize)); tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); + false /* isUnflat */, 0 /* dataChunkPos */, stringColumnSize)); auto factorizedTable = std::make_unique(memoryManager.get(), std::move(tableSchema)); @@ -363,8 +377,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColInt64Test) { std::vector expectedFactorizedTableIdxOrder = { 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1}; singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, INT64, true /* isAsc */, - false /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::INT64, + true /* isAsc */, false /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, singleOrderByColInt64NoNullTest) { @@ -375,8 +389,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColInt64NoNullTest) { std::vector expectedBlockOffsetOrder = {0, 0, 1, 1, 2, 3, 2, 4, 3}; std::vector expectedFactorizedTableIdxOrder = {0, 1, 1, 0, 0, 0, 1, 0, 1}; singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, INT64, true /* isAsc */, - false /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::INT64, + true /* isAsc */, false /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, singleOrderByColInt64SameValueTest) { @@ -388,8 +402,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColInt64SameValueTest) { std::vector expectedFactorizedTableIdxOrder = { 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}; singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, INT64, false /* isAsc */, - false /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::INT64, + false /* isAsc */, false /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, singleOrderByColInt64LargeNumTuplesTest) { @@ -415,8 +429,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColInt64LargeNumTuplesTest) { std::vector leftNullMasks(leftSortingData.size(), false); std::vector rightNullMasks(rightSortingData.size(), false); singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, INT64, true /* isAsc */, - false /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::INT64, + true /* isAsc */, false /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, singleOrderByColStringTest) { @@ -430,8 +444,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColStringTest) { std::vector expectedBlockOffsetOrder = {0, 0, 1, 2, 1, 2, 3, 3, 4, 4, 5, 6}; std::vector expectedFactorizedTableIdxOrder = {0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1}; singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, STRING, false /* isAsc */, - false /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::STRING, + false /* isAsc */, false /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, singleOrderByColStringNoNullTest) { @@ -444,8 +458,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColStringNoNullTest) { std::vector expectedBlockOffsetOrder = {0, 0, 1, 1, 2, 2, 3, 4, 3, 4}; std::vector expectedFactorizedTableIdxOrder = {0, 1, 0, 1, 0, 1, 0, 0, 1, 1}; singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, STRING, true /* isAsc */, - false /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::STRING, + true /* isAsc */, false /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, singleOrderByColStringWithPayLoadTest) { @@ -458,8 +472,8 @@ TEST_F(KeyBlockMergerTest, singleOrderByColStringWithPayLoadTest) { std::vector expectedBlockOffsetOrder = {0, 1, 0, 2, 3, 4, 1, 2, 3, 4}; std::vector expectedFactorizedTableIdxOrder = {0, 0, 1, 0, 0, 0, 1, 1, 1, 1}; singleOrderByColMergeTest(leftSortingData, leftNullMasks, rightSortingData, rightNullMasks, - expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, STRING, true /* isAsc */, - true /* hasPayLoadCol */); + expectedBlockOffsetOrder, expectedFactorizedTableIdxOrder, LogicalTypeID::STRING, + true /* isAsc */, true /* hasPayLoadCol */); } TEST_F(KeyBlockMergerTest, multiple0rderByColNoStrTest) { @@ -501,9 +515,11 @@ TEST_F(KeyBlockMergerTest, multipleStrKeyColsTest) { StrKeyColInfo(factorizedTables[0]->getTableSchema()->getColOffset(0 /* colIdx */), 0 /* colOffsetInEncodedKeyBlock */, true /* isAscOrder */), StrKeyColInfo(factorizedTables[0]->getTableSchema()->getColOffset(1 /* colIdx */), - orderByKeyEncoder1.getEncodingSize(DataType(STRING)), true /* isAscOrder */), + orderByKeyEncoder1.getEncodingSize(LogicalType(LogicalTypeID::STRING)), + true /* isAscOrder */), StrKeyColInfo(factorizedTables[0]->getTableSchema()->getColOffset(3 /* colIdx */), - orderByKeyEncoder1.getEncodingSize(DataType(STRING)) * 2, true /* isAscOrder */)}; + orderByKeyEncoder1.getEncodingSize(LogicalType(LogicalTypeID::STRING)) * 2, + true /* isAscOrder */)}; KeyBlockMerger keyBlockMerger = KeyBlockMerger(factorizedTables, strKeyColsInfo, orderByKeyEncoder1.getNumBytesPerTuple()); diff --git a/test/processor/order_by/order_by_key_encoder_test.cpp b/test/processor/order_by/order_by_key_encoder_test.cpp index 731a058ad9..44bd482d3b 100644 --- a/test/processor/order_by/order_by_key_encoder_test.cpp +++ b/test/processor/order_by/order_by_key_encoder_test.cpp @@ -44,8 +44,8 @@ class OrderByKeyEncoderTest : public Test { // 0xFF(null flag) + 0xFF...FF(padding) // if the col is in desc order, the encoding string is: // 0x00(null flag) + 0x00...00(padding) - inline void checkNullVal(uint8_t*& keyBlockPtr, DataTypeID dataTypeID, bool isAsc) { - for (auto i = 0u; i < OrderByKeyEncoder::getEncodingSize(DataType(dataTypeID)); i++) { + inline void checkNullVal(uint8_t*& keyBlockPtr, LogicalTypeID dataTypeID, bool isAsc) { + for (auto i = 0u; i < OrderByKeyEncoder::getEncodingSize(LogicalType(dataTypeID)); i++) { ASSERT_EQ(*(keyBlockPtr++), isAsc ? 0xFF : 0x00); } } @@ -58,7 +58,7 @@ class OrderByKeyEncoderTest : public Test { std::vector valueVectors; for (auto i = 0u; i < numOfOrderByCols; i++) { std::shared_ptr valueVector = - std::make_shared(INT64, memoryManager.get()); + std::make_shared(LogicalTypeID::INT64, memoryManager.get()); for (auto j = 0u; j < numOfElementsPerCol; j++) { valueVector->setValue(j, (int64_t)5); } @@ -138,7 +138,8 @@ class OrderByKeyEncoderTest : public Test { TEST_F(OrderByKeyEncoderTest, singleOrderByColInt64UnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 6; - auto int64ValueVector = std::make_shared(INT64, memoryManager.get()); + auto int64ValueVector = + std::make_shared(LogicalTypeID::INT64, memoryManager.get()); int64ValueVector->setValue(0, (int64_t)73); // positive number int64ValueVector->setNull(1, true); int64ValueVector->setValue(2, (int64_t)-132); // negative 1 byte number @@ -163,7 +164,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColInt64UnflatTest) { ASSERT_EQ(*(keyBlockPtr++), 0x49); checkTupleIdxAndFactorizedTableIdx(0, keyBlockPtr); - checkNullVal(keyBlockPtr, INT64, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::INT64, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); // Check encoding for: NULL FLAG(0x00) + -132=0x7FFFFFFFFFFFFF7C(big endian). @@ -205,7 +206,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColInt64UnflatWithFilterTest) { // valueVector. std::shared_ptr dataChunk = std::make_shared(1); std::shared_ptr int64ValueVector = - std::make_shared(INT64, memoryManager.get()); + std::make_shared(LogicalTypeID::INT64, memoryManager.get()); int64ValueVector->setValue(0, (int64_t)73); int64ValueVector->setValue(1, (int64_t)-52); int64ValueVector->setValue(2, (int64_t)-132); @@ -247,7 +248,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColBoolUnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 3; std::shared_ptr boolValueVector = - std::make_shared(BOOL, memoryManager.get()); + std::make_shared(LogicalTypeID::BOOL, memoryManager.get()); boolValueVector->setValue(0, true); boolValueVector->setValue(1, false); boolValueVector->setNull(2, true); @@ -270,7 +271,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColBoolUnflatTest) { ASSERT_EQ(*(keyBlockPtr++), 0xFF); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); - checkNullVal(keyBlockPtr, BOOL, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::BOOL, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(2, keyBlockPtr); } @@ -278,7 +279,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColDateUnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 3; std::shared_ptr dateValueVector = - std::make_shared(DATE, memoryManager.get()); + std::make_shared(LogicalTypeID::DATE, memoryManager.get()); dateValueVector->setValue( 0, Date::FromCString("2035-07-04", strlen("2035-07-04"))); // date after 1970-01-01 dateValueVector->setNull(1, true); @@ -301,7 +302,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColDateUnflatTest) { ASSERT_EQ(*(keyBlockPtr++), 0x75); checkTupleIdxAndFactorizedTableIdx(0, keyBlockPtr); - checkNullVal(keyBlockPtr, DATE, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::DATE, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); // Check encoding for: NULL FLAG(0x00) + "1949-10-01"=0x7FFFE31B(-7397 days in big endian). @@ -317,7 +318,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColTimestampUnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 3; std::shared_ptr timestampValueVector = - std::make_shared(TIMESTAMP, memoryManager.get()); + std::make_shared(LogicalTypeID::TIMESTAMP, memoryManager.get()); // timestamp before 1970-01-01 timestampValueVector->setValue( 0, Timestamp::FromCString("1962-04-07 11:12:35.123", strlen("1962-04-07 11:12:35.123"))); @@ -347,7 +348,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColTimestampUnflatTest) { ASSERT_EQ(*(keyBlockPtr++), 0x38); checkTupleIdxAndFactorizedTableIdx(0, keyBlockPtr); - checkNullVal(keyBlockPtr, TIMESTAMP, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::TIMESTAMP, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); // Check encoding for: NULL FLAG(0x00) + "2035-07-01 11:14:33"=0x800757D5F429B840 @@ -368,7 +369,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColIntervalUnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 2; std::shared_ptr intervalValueVector = - std::make_shared(INTERVAL, memoryManager.get()); + std::make_shared(LogicalTypeID::INTERVAL, memoryManager.get()); intervalValueVector->setValue( 0, Interval::FromCString("18 hours 55 days 13 years 8 milliseconds 3 months", strlen("18 hours 55 days 13 years 8 milliseconds 3 months"))); @@ -407,7 +408,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColIntervalUnflatTest) { ASSERT_EQ(*(keyBlockPtr++), 0x40); checkTupleIdxAndFactorizedTableIdx(0, keyBlockPtr); - checkNullVal(keyBlockPtr, INTERVAL, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::INTERVAL, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); } @@ -415,7 +416,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColStringUnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 4; std::shared_ptr stringValueVector = - std::make_shared(STRING, memoryManager.get()); + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); stringValueVector->setValue(0, "short str"); // short std::string stringValueVector->setNull(1, true); stringValueVector->setValue( @@ -448,7 +449,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColStringUnflatTest) { checkLongStrFlag(keyBlockPtr, isAscOrder[0], false /* isLongStr */); checkTupleIdxAndFactorizedTableIdx(0, keyBlockPtr); - checkNullVal(keyBlockPtr, STRING, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::STRING, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); // Check encoding for: NULL FLAG(0x00) + "commonprefix string1". @@ -490,7 +491,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColDoubleUnflatTest) { std::shared_ptr dataChunk = std::make_shared(1); dataChunk->state->selVector->selectedSize = 6; std::shared_ptr doubleValueVector = - std::make_shared(DOUBLE, memoryManager.get()); + std::make_shared(LogicalTypeID::DOUBLE, memoryManager.get()); doubleValueVector->setValue(0, (double_t)3.452); // small positive number doubleValueVector->setNull(1, true); doubleValueVector->setValue(2, (double_t)-0.00031213); // very small negative number @@ -518,7 +519,7 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColDoubleUnflatTest) { ASSERT_EQ(*(keyBlockPtr++), 0x04); checkTupleIdxAndFactorizedTableIdx(0, keyBlockPtr); - checkNullVal(keyBlockPtr, INT64, isAscOrder[0]); + checkNullVal(keyBlockPtr, LogicalTypeID::INT64, isAscOrder[0]); checkTupleIdxAndFactorizedTableIdx(1, keyBlockPtr); // Check encoding for: NULL FLAG(0x00) + -0.00031213=0x40CB8B53DB9F4D8D(big endian). @@ -621,11 +622,16 @@ TEST_F(OrderByKeyEncoderTest, singleOrderByColMultiBlockFlatTest) { TEST_F(OrderByKeyEncoderTest, multipleOrderByColSingleBlockTest) { std::vector isAscOrder = {true, false, true, true, true}; - auto intFlatValueVector = std::make_shared(INT64, memoryManager.get()); - auto doubleFlatValueVector = std::make_shared(DOUBLE, memoryManager.get()); - auto stringFlatValueVector = std::make_shared(STRING, memoryManager.get()); - auto timestampFlatValueVector = std::make_shared(TIMESTAMP, memoryManager.get()); - auto dateFlatValueVector = std::make_shared(DATE, memoryManager.get()); + auto intFlatValueVector = + std::make_shared(LogicalTypeID::INT64, memoryManager.get()); + auto doubleFlatValueVector = + std::make_shared(LogicalTypeID::DOUBLE, memoryManager.get()); + auto stringFlatValueVector = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); + auto timestampFlatValueVector = + std::make_shared(LogicalTypeID::TIMESTAMP, memoryManager.get()); + auto dateFlatValueVector = + std::make_shared(LogicalTypeID::DATE, memoryManager.get()); auto mockDataChunk = std::make_shared(5); mockDataChunk->insert(0, intFlatValueVector); @@ -696,7 +702,7 @@ TEST_F(OrderByKeyEncoderTest, multipleOrderByColSingleBlockTest) { ASSERT_EQ(*(keyBlockPtr++), 0x31); ASSERT_EQ(*(keyBlockPtr++), 0x26); - checkNullVal(keyBlockPtr, STRING, isAscOrder[2]); + checkNullVal(keyBlockPtr, LogicalTypeID::STRING, isAscOrder[2]); // Check encoding for: NULL FLAG(0x00) + "2008-08-08 20:20:20"=0x800453F888DCA900 // (1218226820000000 micros in big endian). @@ -785,7 +791,7 @@ TEST_F(OrderByKeyEncoderTest, multipleOrderByColSingleBlockTest) { ASSERT_EQ(*(keyBlockPtr++), 0xB5); ASSERT_EQ(*(keyBlockPtr++), 0x02); - checkNullVal(keyBlockPtr, DOUBLE, isAscOrder[1]); + checkNullVal(keyBlockPtr, LogicalTypeID::DOUBLE, isAscOrder[1]); // Check encoding for: "short str". checkNonNullFlag(keyBlockPtr, isAscOrder[2]); @@ -803,9 +809,9 @@ TEST_F(OrderByKeyEncoderTest, multipleOrderByColSingleBlockTest) { ASSERT_EQ(*(keyBlockPtr++), '\0'); checkLongStrFlag(keyBlockPtr, isAscOrder[2], false /* isLongStr */); - checkNullVal(keyBlockPtr, TIMESTAMP, isAscOrder[3]); + checkNullVal(keyBlockPtr, LogicalTypeID::TIMESTAMP, isAscOrder[3]); - checkNullVal(keyBlockPtr, DATE, isAscOrder[4]); + checkNullVal(keyBlockPtr, LogicalTypeID::DATE, isAscOrder[4]); checkTupleIdxAndFactorizedTableIdx(2, keyBlockPtr); } diff --git a/test/processor/order_by/radix_sort_test.cpp b/test/processor/order_by/radix_sort_test.cpp index 391adbccaa..51005e6777 100644 --- a/test/processor/order_by/radix_sort_test.cpp +++ b/test/processor/order_by/radix_sort_test.cpp @@ -64,7 +64,7 @@ class RadixSortTest : public Test { template void singleOrderByColTest(const std::vector& sortingData, const std::vector& nullMasks, - const std::vector& expectedBlockOffsetOrder, const DataTypeID dataTypeID, + const std::vector& expectedBlockOffsetOrder, const LogicalTypeID dataTypeID, const bool isAsc, bool hasPayLoadCol) { KU_ASSERT(sortingData.size() == nullMasks.size()); KU_ASSERT(sortingData.size() == expectedBlockOffsetOrder.size()); @@ -87,13 +87,14 @@ class RadixSortTest : public Test { std::unique_ptr tableSchema = std::make_unique(); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(dataTypeID))); + tableSchema->appendColumn(std::make_unique(false /* isUnflat */, + 0 /* dataChunkPos */, FactorizedTable::getDataTypeSize(LogicalType{dataTypeID}))); std::vector strKeyColsInfo; if (hasPayLoadCol) { // Create a new payloadValueVector for the payload column. - auto payloadValueVector = std::make_shared(STRING, memoryManager.get()); + auto payloadValueVector = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); for (auto i = 0u; i < dataChunk->state->selVector->selectedSize; i++) { payloadValueVector->setValue(i, std::to_string(i)); } @@ -101,8 +102,8 @@ class RadixSortTest : public Test { // To test whether the orderByCol -> ftIdx works properly, we put the // payload column at index 0, and the orderByCol at index 1. allVectors.insert(allVectors.begin(), payloadValueVector.get()); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(dataTypeID))); + tableSchema->appendColumn(std::make_unique(false /* isUnflat */, + 0 /* dataChunkPos */, FactorizedTable::getDataTypeSize(LogicalType{dataTypeID}))); strKeyColsInfo.emplace_back( StrKeyColInfo(tableSchema->getColOffset(1) /* colOffsetInFT */, 0 /* colOffsetInEncodedKeyBlock */, isAsc)); @@ -140,7 +141,8 @@ class RadixSortTest : public Test { std::make_unique(); std::vector strKeyColsInfo; for (auto i = 0; i < stringValues.size(); i++) { - auto stringValueVector = std::make_shared(STRING, memoryManager.get()); + auto stringValueVector = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); tableSchema->appendColumn(std::make_unique( false /* isUnflat */, 0 /* dataChunkPos */, sizeof(ku_string_t))); strKeyColsInfo.push_back(StrKeyColInfo(tableSchema->getColOffset(strKeyColsInfo.size()), @@ -184,7 +186,7 @@ TEST_F(RadixSortTest, singleOrderByColInt64Test) { INT64_MIN, 210042 /* positive 2 bytes number */}; std::vector nullMasks = {false, true, false, false, false, false, false}; std::vector expectedFTBlockOffsetOrder = {5, 3, 2, 0, 6, 4, 1}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, INT64, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::INT64, true /* isAsc */, false /* hasPayLoadCol */); } @@ -194,7 +196,7 @@ TEST_F(RadixSortTest, singleOrderByColNoNullInt64Test) { -819321 /* negative 2 bytes number */, INT64_MAX, INT64_MIN}; std::vector nullMasks(6, false); std::vector expectedFTBlockOffsetOrder = {4, 1, 0, 2, 3, 5}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, INT64, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::INT64, false /* isAsc */, false /* hasPayLoadCol */); } @@ -208,7 +210,7 @@ TEST_F(RadixSortTest, singleOrderByColLargeInputInt64Test) { std::vector expectedFTBlockOffsetOrder(240); iota(expectedFTBlockOffsetOrder.begin(), expectedFTBlockOffsetOrder.end(), 0); reverse(expectedFTBlockOffsetOrder.begin(), expectedFTBlockOffsetOrder.end()); - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, INT64, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::INT64, true /* isAsc */, false /* hasPayLoadCol */); } @@ -216,7 +218,7 @@ TEST_F(RadixSortTest, singleOrderByColBoolTest) { std::vector sortingData = {true, false, false /* NULL */}; std::vector nullMasks = {false, false, true}; std::vector expectedFTBlockOffsetOrder = {2, 0, 1}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, BOOL, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::BOOL, false /* isAsc */, false /* hasPayLoadCol */); } @@ -229,8 +231,8 @@ TEST_F(RadixSortTest, singleOrderByColDateTest) { date_t(0) /*NULL*/}; std::vector nullMasks = {false, false, false, false, true}; std::vector expectedFTBlockOffsetOrder = {3, 0, 1, 2, 4}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, DATE, true /* isAsc */, - false /* hasPayLoadCol */); + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::DATE, + true /* isAsc */, false /* hasPayLoadCol */); } TEST_F(RadixSortTest, singleOrderByColTimestampTest) { @@ -247,8 +249,8 @@ TEST_F(RadixSortTest, singleOrderByColTimestampTest) { std::vector nullMasks = {false, false, true, false, false}; std::vector expectedFTBlockOffsetOrder = {2, 3, 1, 0, 4}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, TIMESTAMP, - false /* isAsc */, false /* hasPayLoadCol */); + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, + LogicalTypeID::TIMESTAMP, false /* isAsc */, false /* hasPayLoadCol */); } TEST_F(RadixSortTest, singleOrderByColIntervalTest) { @@ -267,8 +269,8 @@ TEST_F(RadixSortTest, singleOrderByColIntervalTest) { std::vector nullMasks = {true, false, false, false}; std::vector expectedFTBlockOffsetOrder = {0, 3, 2, 1}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, INTERVAL, - false /* isAsc */, false /* hasPayLoadCol */); + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, + LogicalTypeID::INTERVAL, false /* isAsc */, false /* hasPayLoadCol */); } TEST_F(RadixSortTest, singleOrderByColDoubleTest) { @@ -277,7 +279,7 @@ TEST_F(RadixSortTest, singleOrderByColDoubleTest) { -76123 /* large negative number */, 0, 0 /* NULL */}; std::vector nullMasks = {false, false, false, false, false, true}; std::vector expectedFTBlockOffsetOrder = {5, 2, 0, 4, 1, 3}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, DOUBLE, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::DOUBLE, false /* isAsc */, false /* hasPayLoadCol */); } @@ -289,7 +291,7 @@ TEST_F(RadixSortTest, singleOrderByColStringTest) { "common prefix rank2", "another common prefix1", "another short string", "" /*NULL*/}; std::vector nullMasks = {false, false, false, false, false, false, false, false, true}; std::vector expectedFTBlockOffsetOrder = {0, 6, 2, 7, 3, 5, 4, 1, 8}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, STRING, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::STRING, true /* isAsc */, false /* hasPayLoadCol */); } @@ -301,7 +303,7 @@ TEST_F(RadixSortTest, singleOrderByColNoNullStringTest) { "common prefix rank2", "other common prefix test3", "another short string"}; std::vector nullMasks(8, false); std::vector expectedFTBlockOffsetOrder = {0, 6, 1, 4, 5, 3, 7, 2}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, STRING, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::STRING, false /* isAsc */, false /* hasPayLoadCol */); } @@ -312,7 +314,7 @@ TEST_F(RadixSortTest, singleOrderByColAllTiesStringTest) { std::vector sortingData(20, "same string for all tuples"); std::vector nullMasks(20, false); std::vector expectedFTBlockOffsetOrder(20, -1); - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, STRING, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::STRING, true /* isAsc */, false /* hasPayLoadCol */); } @@ -324,17 +326,22 @@ TEST_F(RadixSortTest, singleOrderByColWithPayloadTest) { "string column with payload col long long", "very long long long string"}; std::vector nullMasks(5, false); std::vector expectedFTBlockOffsetOrder = {2, 3, 1, 0, 4}; - singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, STRING, + singleOrderByColTest(sortingData, nullMasks, expectedFTBlockOffsetOrder, LogicalTypeID::STRING, true /* isAsc */, true /* hasPayLoadCol */); } TEST_F(RadixSortTest, multipleOrderByColNoTieTest) { std::vector isAscOrder = {true, false, true, false, false}; - auto intFlatValueVector = std::make_shared(INT64, memoryManager.get()); - auto doubleFlatValueVector = std::make_shared(DOUBLE, memoryManager.get()); - auto stringFlatValueVector = std::make_shared(STRING, memoryManager.get()); - auto timestampFlatValueVector = std::make_shared(TIMESTAMP, memoryManager.get()); - auto dateFlatValueVector = std::make_shared(DATE, memoryManager.get()); + auto intFlatValueVector = + std::make_shared(LogicalTypeID::INT64, memoryManager.get()); + auto doubleFlatValueVector = + std::make_shared(LogicalTypeID::DOUBLE, memoryManager.get()); + auto stringFlatValueVector = + std::make_shared(LogicalTypeID::STRING, memoryManager.get()); + auto timestampFlatValueVector = + std::make_shared(LogicalTypeID::TIMESTAMP, memoryManager.get()); + auto dateFlatValueVector = + std::make_shared(LogicalTypeID::DATE, memoryManager.get()); auto mockDataChunk = std::make_shared(5); mockDataChunk->insert(0, intFlatValueVector); @@ -382,21 +389,24 @@ TEST_F(RadixSortTest, multipleOrderByColNoTieTest) { dateFlatValueVector->setValue(4, Date::FromCString("2000-11-13", strlen("2000-11-13"))); std::unique_ptr tableSchema = std::make_unique(); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(INT64))); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(DOUBLE))); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(STRING))); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(TIMESTAMP))); - tableSchema->appendColumn(std::make_unique( - false /* isUnflat */, 0 /* dataChunkPos */, Types::getDataTypeSize(DATE))); + tableSchema->appendColumn(std::make_unique(false /* isUnflat */, + 0 /* dataChunkPos */, FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::INT64}))); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::DOUBLE}))); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::STRING}))); + tableSchema->appendColumn( + std::make_unique(false /* isUnflat */, 0 /* dataChunkPos */, + FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::TIMESTAMP}))); + tableSchema->appendColumn(std::make_unique(false /* isUnflat */, + 0 /* dataChunkPos */, FactorizedTable::getDataTypeSize(LogicalType{LogicalTypeID::DATE}))); FactorizedTable factorizedTable(memoryManager.get(), std::move(tableSchema)); std::vector strKeyColsInfo = {StrKeyColInfo(16 /* colOffsetInFT */, - OrderByKeyEncoder::getEncodingSize(DataType(INT64)) + - OrderByKeyEncoder::getEncodingSize(DataType(DOUBLE)), + OrderByKeyEncoder::getEncodingSize(LogicalType(LogicalTypeID::INT64)) + + OrderByKeyEncoder::getEncodingSize(LogicalType(LogicalTypeID::DOUBLE)), true /* isAscOrder */)}; auto orderByKeyEncoder = diff --git a/test/runner/e2e_ddl_test.cpp b/test/runner/e2e_ddl_test.cpp index 1d5bbb97c2..288ba86cf1 100644 --- a/test/runner/e2e_ddl_test.cpp +++ b/test/runner/e2e_ddl_test.cpp @@ -277,7 +277,7 @@ class TinySnbDDLTest : public DBTest { catalog->getReadOnlyVersion()->getNodeProperty(personTableID, "gender"); auto propertyFileName = StorageUtils::getNodePropertyColumnFName( databasePath, personTableID, propertyToDrop.propertyID, DBFileType::ORIGINAL); - bool hasOverflowFile = containsOverflowFile(propertyToDrop.dataType.typeID); + bool hasOverflowFile = containsOverflowFile(propertyToDrop.dataType.getLogicalTypeID()); executeQueryWithoutCommit("ALTER TABLE person DROP gender"); validateColumnFilesExistence(propertyFileName, true /* existence */, hasOverflowFile); ASSERT_TRUE(catalog->getReadOnlyVersion() @@ -314,7 +314,7 @@ class TinySnbDDLTest : public DBTest { studyAtTableID, RelDataDirection::FWD, propertyToDrop.propertyID, DBFileType::ORIGINAL); auto propertyBWDListFileName = StorageUtils::getRelPropertyListsFName(databasePath, studyAtTableID, RelDataDirection::BWD, propertyToDrop.propertyID, DBFileType::ORIGINAL); - bool hasOverflowFile = containsOverflowFile(propertyToDrop.dataType.typeID); + bool hasOverflowFile = containsOverflowFile(propertyToDrop.dataType.getLogicalTypeID()); executeQueryWithoutCommit("ALTER TABLE studyAt DROP places"); validateColumnFilesExistence( propertyFWDColumnFileName, true /* existence */, hasOverflowFile); @@ -389,7 +389,7 @@ class TinySnbDDLTest : public DBTest { auto tableSchema = catalog->getWriteVersion()->getTableSchema(personTableID); auto propertyID = tableSchema->getPropertyID("random"); auto hasOverflow = - containsOverflowFile(tableSchema->getProperty(propertyID).dataType.typeID); + containsOverflowFile(tableSchema->getProperty(propertyID).dataType.getLogicalTypeID()); auto columnOriginalVersionFileName = StorageUtils::getNodePropertyColumnFName( databasePath, personTableID, propertyID, DBFileType::ORIGINAL); auto columnWALVersionFileName = StorageUtils::getNodePropertyColumnFName( @@ -420,7 +420,7 @@ class TinySnbDDLTest : public DBTest { auto tableSchema = catalog->getWriteVersion()->getTableSchema(personTableID); auto propertyID = tableSchema->getPropertyID("random"); auto hasOverflow = - containsOverflowFile(tableSchema->getProperty(propertyID).dataType.typeID); + containsOverflowFile(tableSchema->getProperty(propertyID).dataType.getLogicalTypeID()); auto columnOriginalVersionFileName = StorageUtils::getNodePropertyColumnFName( databasePath, personTableID, propertyID, DBFileType::ORIGINAL); auto columnWALVersionFileName = StorageUtils::getNodePropertyColumnFName( @@ -453,7 +453,7 @@ class TinySnbDDLTest : public DBTest { auto tableSchema = catalog->getWriteVersion()->getTableSchema(studyAtTableID); auto propertyID = tableSchema->getPropertyID("random"); auto hasOverflow = - containsOverflowFile(tableSchema->getProperty(propertyID).dataType.typeID); + containsOverflowFile(tableSchema->getProperty(propertyID).dataType.getLogicalTypeID()); auto fwdColumnOriginalVersionFileName = StorageUtils::getRelPropertyColumnFName( databasePath, studyAtTableID, RelDataDirection::FWD, propertyID, DBFileType::ORIGINAL); auto fwdColumnWALVersionFileName = StorageUtils::getRelPropertyColumnFName(databasePath, @@ -493,8 +493,8 @@ class TinySnbDDLTest : public DBTest { "ALTER TABLE studyAt ADD random {} DEFAULT {}", propertyType, defaultVal)); auto relTableSchema = catalog->getWriteVersion()->getTableSchema(studyAtTableID); auto propertyID = relTableSchema->getPropertyID("random"); - auto hasOverflow = - containsOverflowFile(relTableSchema->getProperty(propertyID).dataType.typeID); + auto hasOverflow = containsOverflowFile( + relTableSchema->getProperty(propertyID).dataType.getLogicalTypeID()); auto fwdColumnOriginalVersionFileName = StorageUtils::getRelPropertyColumnFName( databasePath, studyAtTableID, RelDataDirection::FWD, propertyID, DBFileType::ORIGINAL); auto fwdColumnWALVersionFileName = StorageUtils::getRelPropertyColumnFName(databasePath, diff --git a/test/storage/node_insertion_deletion_test.cpp b/test/storage/node_insertion_deletion_test.cpp index f49f1a2d04..ee64b7560c 100644 --- a/test/storage/node_insertion_deletion_test.cpp +++ b/test/storage/node_insertion_deletion_test.cpp @@ -47,9 +47,11 @@ class NodeInsertionDeletionTests : public DBTest { auto dataChunk = std::make_shared(2); // Flatten the data chunk dataChunk->state->currIdx = 0; - auto nodeIDVector = std::make_shared(INTERNAL_ID, getMemoryManager(*database)); + auto nodeIDVector = + std::make_shared(LogicalTypeID::INTERNAL_ID, getMemoryManager(*database)); dataChunk->insert(0, nodeIDVector); - auto idVector = std::make_shared(INT64, getMemoryManager(*database)); + auto idVector = + std::make_shared(LogicalTypeID::INT64, getMemoryManager(*database)); dataChunk->insert(1, idVector); ((nodeID_t*)nodeIDVector->getData())[0].offset = nodeOffset; idVector->setNull(0, true /* is null */); diff --git a/test/transaction/transaction_test.cpp b/test/transaction/transaction_test.cpp index 2cd8f911d6..91836054e4 100644 --- a/test/transaction/transaction_test.cpp +++ b/test/transaction/transaction_test.cpp @@ -40,16 +40,17 @@ class TransactionTests : public DBTest { .propertyID; dataChunk = std::make_shared(3); - nodeVector = std::make_shared(INTERNAL_ID, getMemoryManager(*database)); + nodeVector = + std::make_shared(LogicalTypeID::INTERNAL_ID, getMemoryManager(*database)); dataChunk->insert(0, nodeVector); ((nodeID_t*)nodeVector->getData())[0].offset = 0; ((nodeID_t*)nodeVector->getData())[1].offset = 1; agePropertyVectorToReadDataInto = - std::make_shared(INT64, getMemoryManager(*database)); + std::make_shared(LogicalTypeID::INT64, getMemoryManager(*database)); dataChunk->insert(1, agePropertyVectorToReadDataInto); eyeSightVectorToReadDataInto = - std::make_shared(DOUBLE, getMemoryManager(*database)); + std::make_shared(LogicalTypeID::DOUBLE, getMemoryManager(*database)); dataChunk->insert(2, eyeSightVectorToReadDataInto); personAgeColumn = getStorageManager(*database)->getNodesStore().getNodePropertyColumn( @@ -94,7 +95,7 @@ class TransactionTests : public DBTest { dataChunk->state->selVector->resetSelectorToValuePosBuffer(); dataChunk->state->selVector->selectedPositions[0] = nodeOffset; auto propertyVectorToWriteDataTo = - std::make_shared(INT64, getMemoryManager(*database)); + std::make_shared(LogicalTypeID::INT64, getMemoryManager(*database)); propertyVectorToWriteDataTo->state = dataChunk->state; if (isNull) { propertyVectorToWriteDataTo->setNull(dataChunk->state->currIdx, true /* is null */); @@ -112,7 +113,7 @@ class TransactionTests : public DBTest { dataChunk->state->selVector->resetSelectorToValuePosBuffer(); dataChunk->state->selVector->selectedPositions[0] = nodeOffset; auto propertyVectorToWriteDataTo = - std::make_shared(DOUBLE, getMemoryManager(*database)); + std::make_shared(LogicalTypeID::DOUBLE, getMemoryManager(*database)); propertyVectorToWriteDataTo->state = dataChunk->state; if (isNull) { propertyVectorToWriteDataTo->setNull(dataChunk->state->currIdx, true /* is null */); diff --git a/tools/python_api/src_cpp/include/py_query_result_converter.h b/tools/python_api/src_cpp/include/py_query_result_converter.h index 2cf60ff01a..8096968366 100644 --- a/tools/python_api/src_cpp/include/py_query_result_converter.h +++ b/tools/python_api/src_cpp/include/py_query_result_converter.h @@ -6,18 +6,18 @@ struct NPArrayWrapper { public: - NPArrayWrapper(const kuzu::common::DataType& type, uint64_t numFlatTuple); + NPArrayWrapper(const kuzu::common::LogicalType& type, uint64_t numFlatTuple); void appendElement(kuzu::common::Value* value); private: - py::dtype convertToArrayType(const kuzu::common::DataType& type); + py::dtype convertToArrayType(const kuzu::common::LogicalType& type); public: py::array data; uint8_t* dataBuffer; py::array mask; - kuzu::common::DataType type; + kuzu::common::LogicalType type; uint64_t numElements; }; diff --git a/tools/python_api/src_cpp/py_query_result.cpp b/tools/python_api/src_cpp/py_query_result.cpp index 24756d84dd..37f5368d33 100644 --- a/tools/python_api/src_cpp/py_query_result.cpp +++ b/tools/python_api/src_cpp/py_query_result.cpp @@ -67,35 +67,35 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { return py::none(); } auto dataType = value.getDataType(); - switch (dataType.typeID) { - case BOOL: { + switch (dataType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { return py::cast(value.getValue()); } - case INT16: { + case LogicalTypeID::INT16: { return py::cast(value.getValue()); } - case INT32: { + case LogicalTypeID::INT32: { return py::cast(value.getValue()); } - case INT64: { + case LogicalTypeID::INT64: { return py::cast(value.getValue()); } - case FLOAT: { + case LogicalTypeID::FLOAT: { return py::cast(value.getValue()); } - case DOUBLE: { + case LogicalTypeID::DOUBLE: { return py::cast(value.getValue()); } - case STRING: { + case LogicalTypeID::STRING: { return py::cast(value.getValue()); } - case DATE: { + case LogicalTypeID::DATE: { auto dateVal = value.getValue(); int32_t year, month, day; Date::Convert(dateVal, year, month, day); return py::cast(PyDate_FromDate(year, month, day)); } - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { auto timestampVal = value.getValue(); int32_t year, month, day, hour, min, sec, micros; date_t date; @@ -106,15 +106,15 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { return py::cast( PyDateTime_FromDateAndTime(year, month, day, hour, min, sec, micros)); } - case INTERVAL: { + case LogicalTypeID::INTERVAL: { auto intervalVal = value.getValue(); auto days = Interval::DAYS_PER_MONTH * intervalVal.months + intervalVal.days; return py::cast(py::module::import("datetime") .attr("timedelta")(py::arg("days") = days, py::arg("microseconds") = intervalVal.micros)); } - case VAR_LIST: - case FIXED_LIST: { + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::FIXED_LIST: { auto& listVal = value.getListValReference(); py::list list; for (auto i = 0u; i < listVal.size(); ++i) { @@ -122,9 +122,8 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { } return std::move(list); } - case STRUCT: { - auto structTypeInfo = reinterpret_cast(dataType.getExtraTypeInfo()); - auto childrenNames = structTypeInfo->getChildrenNames(); + case LogicalTypeID::STRUCT: { + auto childrenNames = StructType::getStructFieldNames(&dataType); py::dict dict; auto& structVals = value.getListValReference(); for (auto i = 0u; i < structVals.size(); ++i) { @@ -134,25 +133,25 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { } return dict; } - case NODE: { + case LogicalTypeID::NODE: { auto nodeVal = value.getValue(); auto dict = PyQueryResult::getPyDictFromProperties(nodeVal.getProperties()); dict["_label"] = py::cast(nodeVal.getLabelName()); dict["_id"] = convertNodeIdToPyDict(nodeVal.getNodeID()); return std::move(dict); } - case REL: { + case LogicalTypeID::REL: { auto relVal = value.getValue(); auto dict = PyQueryResult::getPyDictFromProperties(relVal.getProperties()); dict["_src"] = convertNodeIdToPyDict(relVal.getSrcNodeID()); dict["_dst"] = convertNodeIdToPyDict(relVal.getDstNodeID()); return std::move(dict); } - case INTERNAL_ID: { + case LogicalTypeID::INTERNAL_ID: { return convertNodeIdToPyDict(value.getValue()); } default: - throw NotImplementedException("Unsupported type: " + Types::dataTypeToString(dataType)); + throw NotImplementedException("Unsupported type: " + LogicalTypeUtils::dataTypeToString(dataType)); } } @@ -199,7 +198,7 @@ py::list PyQueryResult::getColumnDataTypes() { auto columnDataTypes = queryResult->getColumnDataTypes(); py::tuple result(columnDataTypes.size()); for (auto i = 0u; i < columnDataTypes.size(); ++i) { - result[i] = py::cast(Types::dataTypeToString(columnDataTypes[i])); + result[i] = py::cast(LogicalTypeUtils::dataTypeToString(columnDataTypes[i])); } return std::move(result); } diff --git a/tools/python_api/src_cpp/py_query_result_converter.cpp b/tools/python_api/src_cpp/py_query_result_converter.cpp index b243326a08..e33684c0a3 100644 --- a/tools/python_api/src_cpp/py_query_result_converter.cpp +++ b/tools/python_api/src_cpp/py_query_result_converter.cpp @@ -2,11 +2,10 @@ #include "common/types/value.h" #include "include/py_query_result.h" -#include "processor/result/flat_tuple.h" using namespace kuzu::common; -NPArrayWrapper::NPArrayWrapper(const DataType& type, uint64_t numFlatTuple) +NPArrayWrapper::NPArrayWrapper(const LogicalType& type, uint64_t numFlatTuple) : type{type}, numElements{0} { data = py::array(convertToArrayType(type), numFlatTuple); dataBuffer = (uint8_t*)data.mutable_data(); @@ -16,35 +15,35 @@ NPArrayWrapper::NPArrayWrapper(const DataType& type, uint64_t numFlatTuple) void NPArrayWrapper::appendElement(Value* value) { ((uint8_t*)mask.mutable_data())[numElements] = value->isNull(); if (!value->isNull()) { - switch (type.typeID) { - case BOOL: { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { ((uint8_t*)dataBuffer)[numElements] = value->getValue(); break; } - case INT64: { + case LogicalTypeID::INT64: { ((int64_t*)dataBuffer)[numElements] = value->getValue(); break; } - case DOUBLE: { + case LogicalTypeID::DOUBLE: { ((double_t*)dataBuffer)[numElements] = value->getValue(); break; } - case DATE: { + case LogicalTypeID::DATE: { ((int64_t*)dataBuffer)[numElements] = Date::getEpochNanoSeconds(value->getValue()); break; } - case TIMESTAMP: { + case LogicalTypeID::TIMESTAMP: { ((int64_t*)dataBuffer)[numElements] = Timestamp::getEpochNanoSeconds(value->getValue()); break; } - case INTERVAL: { + case LogicalTypeID::INTERVAL: { ((int64_t*)dataBuffer)[numElements] = Interval::getNanoseconds(value->getValue()); break; } - case STRING: { + case LogicalTypeID::STRING: { auto val = value->getValue(); auto result = PyUnicode_New(val.length(), 127); auto target_data = PyUnicode_DATA(result); @@ -52,12 +51,12 @@ void NPArrayWrapper::appendElement(Value* value) { ((PyObject**)dataBuffer)[numElements] = result; break; } - case NODE: - case REL: { + case LogicalTypeID::NODE: + case LogicalTypeID::REL: { ((py::dict*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value); break; } - case VAR_LIST: { + case LogicalTypeID::VAR_LIST: { ((py::list*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value); break; } @@ -69,34 +68,34 @@ void NPArrayWrapper::appendElement(Value* value) { numElements++; } -py::dtype NPArrayWrapper::convertToArrayType(const DataType& type) { +py::dtype NPArrayWrapper::convertToArrayType(const LogicalType& type) { std::string dtype; - switch (type.typeID) { - case INT64: { + switch (type.getLogicalTypeID()) { + case LogicalTypeID::INT64: { dtype = "int64"; break; } - case DOUBLE: { + case LogicalTypeID::DOUBLE: { dtype = "float64"; break; } - case BOOL: { + case LogicalTypeID::BOOL: { dtype = "bool"; break; } - case NODE: - case REL: - case VAR_LIST: - case STRING: { + case LogicalTypeID::NODE: + case LogicalTypeID::REL: + case LogicalTypeID::VAR_LIST: + case LogicalTypeID::STRING: { dtype = "object"; break; } - case DATE: - case TIMESTAMP: { + case LogicalTypeID::DATE: + case LogicalTypeID::TIMESTAMP: { dtype = "datetime64[ns]"; break; } - case INTERVAL: { + case LogicalTypeID::INTERVAL: { dtype = "timedelta64[ns]"; break; }