Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor list_range and list_sort functions #3393

Merged
merged 3 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 108 additions & 9 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,101 @@ using kuzu::function::BuiltInFunctionsUtils;
namespace kuzu {
namespace common {

LogicalType* ListType::getChildType(const kuzu::common::LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::LIST ||
type->getPhysicalType() == PhysicalTypeID::ARRAY);
auto listTypeInfo = type->extraTypeInfo->constPtrCast<ListTypeInfo>();
return listTypeInfo->getChildType();
}

LogicalType* ArrayType::getChildType(const LogicalType* type) {
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::ARRAY);
auto arrayTypeInfo = type->extraTypeInfo->constPtrCast<ArrayTypeInfo>();
return arrayTypeInfo->getChildType();
}

uint64_t ArrayType::getNumElements(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::ARRAY);
auto arrayTypeInfo = type->extraTypeInfo->constPtrCast<ArrayTypeInfo>();
return arrayTypeInfo->getNumElements();
}

std::vector<LogicalType*> StructType::getFieldTypes(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getChildrenTypes();
}

std::vector<std::string> StructType::getFieldNames(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getChildrenNames();
}

uint64_t StructType::getNumFields(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
return getFieldTypes(type).size();
}

std::vector<const StructField*> StructType::getFields(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructFields();
}

bool StructType::hasField(const LogicalType* type, const std::string& key) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->hasField(key);
}

const StructField* StructType::getField(const LogicalType* type, struct_field_idx_t idx) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructField(idx);
}

const StructField* StructType::getField(const LogicalType* type, const std::string& key) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructField(key);
}

struct_field_idx_t StructType::getFieldIdx(const LogicalType* type, const std::string& key) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructFieldIdx(key);
}

LogicalType* MapType::getKeyType(const LogicalType* type) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::MAP);
return StructType::getFieldTypes(ListType::getChildType(type))[0];
}

LogicalType* MapType::getValueType(const LogicalType* type) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::MAP);
return StructType::getFieldTypes(ListType::getChildType(type))[1];
}

union_field_idx_t UnionType::getInternalFieldIdx(union_field_idx_t idx) {
return idx + 1;
}

std::string UnionType::getFieldName(const LogicalType* type, union_field_idx_t idx) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::UNION);
return StructType::getFieldNames(type)[getInternalFieldIdx(idx)];
}

LogicalType* UnionType::getFieldType(const LogicalType* type, union_field_idx_t idx) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::UNION);
return StructType::getFieldTypes(type)[getInternalFieldIdx(idx)];
}

uint64_t UnionType::getNumFields(const LogicalType* type) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::UNION);
return StructType::getNumFields(type) - 1;
}

std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType) {
// LCOV_EXCL_START
switch (physicalType) {
Expand Down Expand Up @@ -921,17 +1016,21 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidComparableLogicalTypes()
LogicalTypeID::UUID, LogicalTypeID::STRING, LogicalTypeID::SERIAL};
}

std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::UINT64, LogicalTypeID::UINT32,
LogicalTypeID::UINT16, LogicalTypeID::UINT8, LogicalTypeID::INT128, LogicalTypeID::DOUBLE,
LogicalTypeID::FLOAT, LogicalTypeID::SERIAL};
std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT128, LogicalTypeID::INT64,
LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL,
LogicalTypeID::UINT64, LogicalTypeID::UINT32, LogicalTypeID::UINT16, LogicalTypeID::UINT8};
}

static std::vector<LogicalTypeID> getFloatingPointTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT};
}

// TODO(Ziyi): Support int128 and uint types here.
std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL};
std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
auto integerTypes = getIntegerTypeIDs();
auto floatingPointTypes = getFloatingPointTypeIDs();
integerTypes.insert(integerTypes.end(), floatingPointTypes.begin(), floatingPointTypes.end());
return integerTypes;
}

std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidLogicTypes() {
Expand Down
2 changes: 1 addition & 1 deletion src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Function* BuiltInFunctionsUtils::matchFunction(const std::string& name,
std::vector<Function*> candidateFunctions;
uint32_t minCost = UINT32_MAX;
for (auto& function : functionSet) {
auto func = reinterpret_cast<Function*>(function.get());
auto func = function.get();
auto cost = getFunctionCost(inputTypes, func, isOverload);
if (cost == UINT32_MAX) {
continue;
Expand Down
14 changes: 11 additions & 3 deletions src/function/list/list_range_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ struct Range {
// - when start = end: there is only one element in result list
// - when end - start are of opposite sign of step, the result will be empty
// - default step = 1
template<typename T>
static void operation(T& end, list_entry_t& result, ValueVector& endVector,
ValueVector& resultVector) {
T step = 1;
T start = 0;
operation(start, end, step, result, endVector, resultVector);
}

template<typename T>
static void operation(T& start, T& end, list_entry_t& result, ValueVector& leftVector,
ValueVector& /*rightVector*/, ValueVector& resultVector) {
Expand Down Expand Up @@ -46,7 +54,7 @@ static scalar_func_exec_t getBinaryExecFunc(const LogicalType& type) {
scalar_func_exec_t execFunc;
TypeUtils::visit(
type.getLogicalTypeID(),
[&execFunc]<NumericTypes T>(T) {
[&execFunc]<IntegerTypes T>(T) {
execFunc = ScalarFunction::BinaryExecListStructFunction<T, T, list_entry_t, Range>;
},
[](auto) { KU_UNREACHABLE; });
Expand All @@ -57,7 +65,7 @@ static scalar_func_exec_t getTernaryExecFunc(const LogicalType& type) {
scalar_func_exec_t execFunc;
TypeUtils::visit(
type.getLogicalTypeID(),
[&execFunc]<NumericTypes T>(T) {
[&execFunc]<IntegerTypes T>(T) {
execFunc = ScalarFunction::TernaryExecListStructFunction<T, T, T, list_entry_t, Range>;
},
[](auto) { KU_UNREACHABLE; });
Expand All @@ -78,7 +86,7 @@ static std::unique_ptr<FunctionBindData> bindFunc(const binder::expression_vecto

function_set ListRangeFunction::getFunctionSet() {
function_set result;
for (auto typeID : LogicalTypeUtils::getIntegerLogicalTypeIDs()) {
for (auto typeID : LogicalTypeUtils::getIntegerTypeIDs()) {
// start, end
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{typeID, typeID},
Expand Down
Loading
Loading