Skip to content

Commit

Permalink
Refactor list_range and list_sort functions (#3393)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Apr 28, 2024
1 parent 84032a2 commit c60d6a4
Show file tree
Hide file tree
Showing 17 changed files with 474 additions and 304 deletions.
117 changes: 108 additions & 9 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,101 @@ using kuzu::function::BuiltInFunctionsUtils;
namespace kuzu {
namespace common {

LogicalType* ListType::getChildType(const kuzu::common::LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::LIST ||
type->getPhysicalType() == PhysicalTypeID::ARRAY);
auto listTypeInfo = type->extraTypeInfo->constPtrCast<ListTypeInfo>();
return listTypeInfo->getChildType();
}

LogicalType* ArrayType::getChildType(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::ARRAY);
auto arrayTypeInfo = type->extraTypeInfo->constPtrCast<ArrayTypeInfo>();
return arrayTypeInfo->getChildType();
}

uint64_t ArrayType::getNumElements(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::ARRAY);
auto arrayTypeInfo = type->extraTypeInfo->constPtrCast<ArrayTypeInfo>();
return arrayTypeInfo->getNumElements();
}

std::vector<LogicalType*> StructType::getFieldTypes(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getChildrenTypes();
}

std::vector<std::string> StructType::getFieldNames(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getChildrenNames();
}

uint64_t StructType::getNumFields(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
return getFieldTypes(type).size();
}

std::vector<const StructField*> StructType::getFields(const LogicalType* type) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructFields();
}

bool StructType::hasField(const LogicalType* type, const std::string& key) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->hasField(key);
}

const StructField* StructType::getField(const LogicalType* type, struct_field_idx_t idx) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructField(idx);
}

const StructField* StructType::getField(const LogicalType* type, const std::string& key) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructField(key);
}

struct_field_idx_t StructType::getFieldIdx(const LogicalType* type, const std::string& key) {
KU_ASSERT(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = type->extraTypeInfo->constPtrCast<StructTypeInfo>();
return structTypeInfo->getStructFieldIdx(key);
}

LogicalType* MapType::getKeyType(const LogicalType* type) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::MAP);
return StructType::getFieldTypes(ListType::getChildType(type))[0];
}

LogicalType* MapType::getValueType(const LogicalType* type) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::MAP);
return StructType::getFieldTypes(ListType::getChildType(type))[1];
}

union_field_idx_t UnionType::getInternalFieldIdx(union_field_idx_t idx) {
return idx + 1;
}

std::string UnionType::getFieldName(const LogicalType* type, union_field_idx_t idx) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::UNION);
return StructType::getFieldNames(type)[getInternalFieldIdx(idx)];
}

LogicalType* UnionType::getFieldType(const LogicalType* type, union_field_idx_t idx) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::UNION);
return StructType::getFieldTypes(type)[getInternalFieldIdx(idx)];
}

uint64_t UnionType::getNumFields(const LogicalType* type) {
KU_ASSERT(type->getLogicalTypeID() == LogicalTypeID::UNION);
return StructType::getNumFields(type) - 1;
}

std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType) {
// LCOV_EXCL_START
switch (physicalType) {
Expand Down Expand Up @@ -921,17 +1016,21 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidComparableLogicalTypes()
LogicalTypeID::UUID, LogicalTypeID::STRING, LogicalTypeID::SERIAL};
}

std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::UINT64, LogicalTypeID::UINT32,
LogicalTypeID::UINT16, LogicalTypeID::UINT8, LogicalTypeID::INT128, LogicalTypeID::DOUBLE,
LogicalTypeID::FLOAT, LogicalTypeID::SERIAL};
std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT128, LogicalTypeID::INT64,
LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL,
LogicalTypeID::UINT64, LogicalTypeID::UINT32, LogicalTypeID::UINT16, LogicalTypeID::UINT8};
}

static std::vector<LogicalTypeID> getFloatingPointTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT};
}

// TODO(Ziyi): Support int128 and uint types here.
std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL};
std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
auto integerTypes = getIntegerTypeIDs();
auto floatingPointTypes = getFloatingPointTypeIDs();
integerTypes.insert(integerTypes.end(), floatingPointTypes.begin(), floatingPointTypes.end());
return integerTypes;
}

std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidLogicTypes() {
Expand Down
2 changes: 1 addition & 1 deletion src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Function* BuiltInFunctionsUtils::matchFunction(const std::string& name,
std::vector<Function*> candidateFunctions;
uint32_t minCost = UINT32_MAX;
for (auto& function : functionSet) {
auto func = reinterpret_cast<Function*>(function.get());
auto func = function.get();
auto cost = getFunctionCost(inputTypes, func, isOverload);
if (cost == UINT32_MAX) {
continue;
Expand Down
14 changes: 11 additions & 3 deletions src/function/list/list_range_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ struct Range {
// - when start = end: there is only one element in result list
// - when end - start are of opposite sign of step, the result will be empty
// - default step = 1
template<typename T>
static void operation(T& end, list_entry_t& result, ValueVector& endVector,
ValueVector& resultVector) {
T step = 1;
T start = 0;
operation(start, end, step, result, endVector, resultVector);
}

template<typename T>
static void operation(T& start, T& end, list_entry_t& result, ValueVector& leftVector,
ValueVector& /*rightVector*/, ValueVector& resultVector) {
Expand Down Expand Up @@ -46,7 +54,7 @@ static scalar_func_exec_t getBinaryExecFunc(const LogicalType& type) {
scalar_func_exec_t execFunc;
TypeUtils::visit(
type.getLogicalTypeID(),
[&execFunc]<NumericTypes T>(T) {
[&execFunc]<IntegerTypes T>(T) {
execFunc = ScalarFunction::BinaryExecListStructFunction<T, T, list_entry_t, Range>;
},
[](auto) { KU_UNREACHABLE; });
Expand All @@ -57,7 +65,7 @@ static scalar_func_exec_t getTernaryExecFunc(const LogicalType& type) {
scalar_func_exec_t execFunc;
TypeUtils::visit(
type.getLogicalTypeID(),
[&execFunc]<NumericTypes T>(T) {
[&execFunc]<IntegerTypes T>(T) {
execFunc = ScalarFunction::TernaryExecListStructFunction<T, T, T, list_entry_t, Range>;
},
[](auto) { KU_UNREACHABLE; });
Expand All @@ -78,7 +86,7 @@ static std::unique_ptr<FunctionBindData> bindFunc(const binder::expression_vecto

function_set ListRangeFunction::getFunctionSet() {
function_set result;
for (auto typeID : LogicalTypeUtils::getIntegerLogicalTypeIDs()) {
for (auto typeID : LogicalTypeUtils::getIntegerTypeIDs()) {
// start, end
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{typeID, typeID},
Expand Down
Loading

0 comments on commit c60d6a4

Please sign in to comment.