Skip to content

Commit

Permalink
Add query processing for serial
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 4, 2023
1 parent 2e0f617 commit 86cfd03
Show file tree
Hide file tree
Showing 30 changed files with 855 additions and 287 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ property_id_t Binder::bindPropertyName(TableSchema* tableSchema, const std::stri
LogicalType Binder::bindDataType(const std::string& dataType) {
auto boundType = LogicalTypeUtils::dataTypeFromString(dataType);
if (boundType.getLogicalTypeID() == common::LogicalTypeID::FIXED_LIST) {
auto validNumericTypes = common::LogicalType::getNumericalLogicalTypeIDs();
auto validNumericTypes = common::LogicalTypeUtils::getNumericalLogicalTypeIDs();
auto childType = common::FixedListType::getChildType(&boundType);
auto numElementsInList = common::FixedListType::getNumElementsInList(&boundType);
if (find(validNumericTypes.begin(), validNumericTypes.end(),
Expand Down
71 changes: 40 additions & 31 deletions src/common/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,58 +31,67 @@ bool TypeUtils::convertToBoolean(const char* data) {
". Input is not equal to True or False (in a case-insensitive manner)");
}

std::string TypeUtils::listValueToString(
const LogicalType& dataType, uint8_t* listValues, uint64_t pos) {
std::string TypeUtils::castValueToString(
const LogicalType& dataType, uint8_t* value, void* vector) {
auto valueVector = reinterpret_cast<ValueVector*>(vector);
switch (dataType.getLogicalTypeID()) {
case LogicalTypeID::BOOL:
return TypeUtils::toString(((bool*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<bool*>(value));
case LogicalTypeID::INT64:
return TypeUtils::toString(((int64_t*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<int64_t*>(value));
case LogicalTypeID::INT32:
return TypeUtils::toString(*reinterpret_cast<int32_t*>(value));
case LogicalTypeID::INT16:
return TypeUtils::toString(*reinterpret_cast<int16_t*>(value));
case LogicalTypeID::DOUBLE:
return TypeUtils::toString(((double_t*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<double_t*>(value));
case LogicalTypeID::FLOAT:
return TypeUtils::toString(*reinterpret_cast<float_t*>(value));
case LogicalTypeID::DATE:
return TypeUtils::toString(((date_t*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<date_t*>(value));
case LogicalTypeID::TIMESTAMP:
return TypeUtils::toString(((timestamp_t*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<timestamp_t*>(value));
case LogicalTypeID::INTERVAL:
return TypeUtils::toString(((interval_t*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<interval_t*>(value));
case LogicalTypeID::STRING:
return TypeUtils::toString(((ku_string_t*)listValues)[pos]);
return TypeUtils::toString(*reinterpret_cast<ku_string_t*>(value));
case LogicalTypeID::INTERNAL_ID:
return TypeUtils::toString(*reinterpret_cast<internalID_t*>(value));
case LogicalTypeID::VAR_LIST:
return TypeUtils::toString(((ku_list_t*)listValues)[pos], dataType);
return TypeUtils::toString(*reinterpret_cast<list_entry_t*>(value), valueVector);
case LogicalTypeID::STRUCT:
return TypeUtils::toString(*reinterpret_cast<struct_entry_t*>(value), valueVector);
default:
throw RuntimeException("Invalid data type " + LogicalTypeUtils::dataTypeToString(dataType) +
" for TypeUtils::listValueToString.");
" for TypeUtils::castValueToString.");
}
}

std::string TypeUtils::toString(const ku_list_t& val, const LogicalType& dataType) {
std::string TypeUtils::toString(const list_entry_t& val, void* valueVector) {
auto listVector = (common::ValueVector*)valueVector;
std::string result = "[";
auto values = ListVector::getListValues(listVector, val);
auto childType = VarListType::getChildType(&listVector->dataType);
auto dataVector = ListVector::getDataVector(listVector);
for (auto i = 0u; i < val.size; ++i) {
result += listValueToString(
*VarListType::getChildType(&dataType), reinterpret_cast<uint8_t*>(val.overflowPtr), i);
result += (i == val.size - 1 ? "]" : ",");
result += castValueToString(*childType, values, dataVector);
result += (val.size - 1 == i ? "]" : ",");
values += ListVector::getDataVector(listVector)->getNumBytesPerValue();
}
return result;
}

std::string TypeUtils::toString(const list_entry_t& val, void* valVector) {
auto listVector = (common::ValueVector*)valVector;
std::string result = "[";
auto values = ListVector::getListValues(listVector, val);
auto childType = VarListType::getChildType(&listVector->dataType);
for (auto i = 0u; i < val.size - 1; ++i) {
result += (childType->getLogicalTypeID() == LogicalTypeID::VAR_LIST ?
toString(reinterpret_cast<common::list_entry_t*>(values)[i],
ListVector::getDataVector(listVector)) :
listValueToString(*childType, values, i)) +
",";
std::string TypeUtils::toString(const struct_entry_t& val, void* valVector) {
auto structVector = (common::ValueVector*)valVector;
std::string result = "{";
auto fields = StructType::getFields(&structVector->dataType);
for (auto i = 0u; i < fields.size(); ++i) {
auto field = fields[i];
auto fieldVector = StructVector::getChildVector(structVector, i);
auto value = fieldVector->getData() + fieldVector->getNumBytesPerValue() * val.pos;
result += castValueToString(*field->getType(), value, fieldVector.get());
result += (fields.size() - 1 == i ? "}" : ",");
}
result += (childType->getLogicalTypeID() == LogicalTypeID::VAR_LIST ?
toString(reinterpret_cast<common::list_entry_t*>(values)[val.size - 1],
ListVector::getDataVector(listVector)) :
listValueToString(*childType, values, val.size - 1)) +
"]";
return result;
}

Expand Down
44 changes: 24 additions & 20 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,26 +135,6 @@ LogicalType::LogicalType(LogicalType&& other) noexcept
: typeID{other.typeID}, physicalType{other.physicalType}, extraTypeInfo{
std::move(other.extraTypeInfo)} {}

std::vector<LogicalTypeID> LogicalType::getNumericalLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT};
}

std::vector<LogicalTypeID> LogicalType::getAllValidComparableLogicalTypes() {
return std::vector<LogicalTypeID>{LogicalTypeID::BOOL, LogicalTypeID::INT64,
LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT,
LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP, LogicalTypeID::INTERVAL,
LogicalTypeID::STRING};
}

std::vector<LogicalTypeID> LogicalType::getAllValidLogicTypeIDs() {
// TODO(Ziyi): Add FIX_LIST type to allValidTypeID when we support functions on VAR_LIST.
return std::vector<LogicalTypeID>{LogicalTypeID::INTERNAL_ID, LogicalTypeID::BOOL,
LogicalTypeID::INT64, LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::DOUBLE,
LogicalTypeID::STRING, LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP,
LogicalTypeID::INTERVAL, LogicalTypeID::VAR_LIST, LogicalTypeID::FLOAT};
}

LogicalType& LogicalType::operator=(const LogicalType& other) {
typeID = other.typeID;
physicalType = other.physicalType;
Expand Down Expand Up @@ -472,12 +452,36 @@ bool LogicalTypeUtils::isNumerical(const kuzu::common::LogicalType& dataType) {
case LogicalTypeID::INT16:
case LogicalTypeID::DOUBLE:
case LogicalTypeID::FLOAT:
case LogicalTypeID::SERIAL:
return true;
default:
return false;
}
}

std::vector<LogicalType> LogicalTypeUtils::getAllValidComparableLogicalTypes() {
return std::vector<LogicalType>{LogicalType{LogicalTypeID::BOOL},
LogicalType{LogicalTypeID::INT64}, LogicalType{LogicalTypeID::INT32},
LogicalType{LogicalTypeID::INT16}, LogicalType{LogicalTypeID::DOUBLE},
LogicalType{LogicalTypeID::FLOAT}, LogicalType{LogicalTypeID::DATE},
LogicalType{LogicalTypeID::TIMESTAMP}, LogicalType{LogicalTypeID::INTERVAL},
LogicalType{LogicalTypeID::STRING}, LogicalType{LogicalTypeID::SERIAL}};
}

std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::DOUBLE, LogicalTypeID::FLOAT, LogicalTypeID::SERIAL};
}

std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidLogicTypeIDs() {
// TODO(Ziyi): Add FIX_LIST,STRUCT type to allValidTypeID when we support functions on VAR_LIST.
return std::vector<LogicalTypeID>{LogicalTypeID::INTERNAL_ID, LogicalTypeID::BOOL,
LogicalTypeID::INT64, LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::DOUBLE,
LogicalTypeID::STRING, LogicalTypeID::DATE, LogicalTypeID::TIMESTAMP,
LogicalTypeID::INTERVAL, LogicalTypeID::VAR_LIST, LogicalTypeID::FLOAT,
LogicalTypeID::SERIAL};
}

std::vector<std::string> LogicalTypeUtils::parseStructFields(const std::string& structTypeStr) {
std::vector<std::string> structFieldsStr;
auto startPos = 0u;
Expand Down
35 changes: 13 additions & 22 deletions src/function/aggregate_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getCountFunction(
std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getAvgFunction(
const LogicalType& inputType, bool isDistinct) {
switch (inputType.getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
return std::make_unique<AggregateFunction>(AvgFunction<int64_t>::initialize,
AvgFunction<int64_t>::updateAll, AvgFunction<int64_t>::updatePos,
Expand Down Expand Up @@ -59,6 +60,7 @@ std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getAvgFunction(
std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getSumFunction(
const LogicalType& inputType, bool isDistinct) {
switch (inputType.getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64:
return std::make_unique<AggregateFunction>(SumFunction<int64_t>::initialize,
SumFunction<int64_t>::updateAll, SumFunction<int64_t>::updatePos,
Expand Down Expand Up @@ -105,61 +107,50 @@ std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getCollectFunction(

template<typename FUNC>
std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getMinMaxFunction(
const LogicalType& inputType, bool isDistinct) {
switch (inputType.getLogicalTypeID()) {
case LogicalTypeID::BOOL:
const common::LogicalType& inputType, bool isDistinct) {
switch (inputType.getPhysicalType()) {
case PhysicalTypeID::BOOL:
return std::make_unique<AggregateFunction>(MinMaxFunction<bool>::initialize,
MinMaxFunction<bool>::updateAll<FUNC>, MinMaxFunction<bool>::updatePos<FUNC>,
MinMaxFunction<bool>::combine<FUNC>, MinMaxFunction<bool>::finalize, inputType,
isDistinct);
case LogicalTypeID::INT64:
case PhysicalTypeID::INT64:
return std::make_unique<AggregateFunction>(MinMaxFunction<int64_t>::initialize,
MinMaxFunction<int64_t>::updateAll<FUNC>, MinMaxFunction<int64_t>::updatePos<FUNC>,
MinMaxFunction<int64_t>::combine<FUNC>, MinMaxFunction<int64_t>::finalize, inputType,
isDistinct);
case LogicalTypeID::INT32:
case PhysicalTypeID::INT32:
return std::make_unique<AggregateFunction>(MinMaxFunction<int32_t>::initialize,
MinMaxFunction<int32_t>::updateAll<FUNC>, MinMaxFunction<int32_t>::updatePos<FUNC>,
MinMaxFunction<int32_t>::combine<FUNC>, MinMaxFunction<int32_t>::finalize, inputType,
isDistinct);
case LogicalTypeID::INT16:
case PhysicalTypeID::INT16:
return std::make_unique<AggregateFunction>(MinMaxFunction<int16_t>::initialize,
MinMaxFunction<int16_t>::updateAll<FUNC>, MinMaxFunction<int16_t>::updatePos<FUNC>,
MinMaxFunction<int16_t>::combine<FUNC>, MinMaxFunction<int16_t>::finalize, inputType,
isDistinct);
case LogicalTypeID::DOUBLE:
case PhysicalTypeID::DOUBLE:
return std::make_unique<AggregateFunction>(MinMaxFunction<double_t>::initialize,
MinMaxFunction<double_t>::updateAll<FUNC>, MinMaxFunction<double_t>::updatePos<FUNC>,
MinMaxFunction<double_t>::combine<FUNC>, MinMaxFunction<double_t>::finalize, inputType,
isDistinct);
case LogicalTypeID::FLOAT:
case PhysicalTypeID::FLOAT:
return std::make_unique<AggregateFunction>(MinMaxFunction<float_t>::initialize,
MinMaxFunction<float_t>::updateAll<FUNC>, MinMaxFunction<float_t>::updatePos<FUNC>,
MinMaxFunction<float_t>::combine<FUNC>, MinMaxFunction<float_t>::finalize, inputType,
isDistinct);
case LogicalTypeID::DATE:
return std::make_unique<AggregateFunction>(MinMaxFunction<date_t>::initialize,
MinMaxFunction<date_t>::updateAll<FUNC>, MinMaxFunction<date_t>::updatePos<FUNC>,
MinMaxFunction<date_t>::combine<FUNC>, MinMaxFunction<date_t>::finalize, inputType,
isDistinct);
case LogicalTypeID::TIMESTAMP:
return std::make_unique<AggregateFunction>(MinMaxFunction<timestamp_t>::initialize,
MinMaxFunction<timestamp_t>::updateAll<FUNC>,
MinMaxFunction<timestamp_t>::updatePos<FUNC>,
MinMaxFunction<timestamp_t>::combine<FUNC>, MinMaxFunction<timestamp_t>::finalize,
inputType, isDistinct);
case LogicalTypeID::INTERVAL:
case PhysicalTypeID::INTERVAL:
return std::make_unique<AggregateFunction>(MinMaxFunction<interval_t>::initialize,
MinMaxFunction<interval_t>::updateAll<FUNC>,
MinMaxFunction<interval_t>::updatePos<FUNC>, MinMaxFunction<interval_t>::combine<FUNC>,
MinMaxFunction<interval_t>::finalize, inputType, isDistinct);
case LogicalTypeID::STRING:
case PhysicalTypeID::STRING:
return std::make_unique<AggregateFunction>(MinMaxFunction<ku_string_t>::initialize,
MinMaxFunction<ku_string_t>::updateAll<FUNC>,
MinMaxFunction<ku_string_t>::updatePos<FUNC>,
MinMaxFunction<ku_string_t>::combine<FUNC>, MinMaxFunction<ku_string_t>::finalize,
inputType, isDistinct);
case LogicalTypeID::INTERNAL_ID:
case PhysicalTypeID::INTERNAL_ID:
return std::make_unique<AggregateFunction>(MinMaxFunction<nodeID_t>::initialize,
MinMaxFunction<nodeID_t>::updateAll<FUNC>, MinMaxFunction<nodeID_t>::updatePos<FUNC>,
MinMaxFunction<nodeID_t>::combine<FUNC>, MinMaxFunction<nodeID_t>::finalize, inputType,
Expand Down
20 changes: 9 additions & 11 deletions src/function/built_in_aggregate_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void BuiltInAggregateFunctions::registerCountStar() {
void BuiltInAggregateFunctions::registerCount() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
LogicalType inputType;
for (auto& typeID : LogicalType::getAllValidLogicTypeIDs()) {
for (auto& typeID : LogicalTypeUtils::getAllValidLogicTypeIDs()) {
if (typeID == LogicalTypeID::VAR_LIST) {
inputType = LogicalType(
typeID, std::make_unique<VarListTypeInfo>(std::make_unique<LogicalType>()));
Expand All @@ -104,7 +104,7 @@ void BuiltInAggregateFunctions::registerCount() {

void BuiltInAggregateFunctions::registerSum() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) {
for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(SUM_FUNC_NAME,
std::vector<LogicalTypeID>{typeID}, typeID,
Expand All @@ -117,7 +117,7 @@ void BuiltInAggregateFunctions::registerSum() {

void BuiltInAggregateFunctions::registerAvg() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto typeID : LogicalType::getNumericalLogicalTypeIDs()) {
for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(AVG_FUNC_NAME,
std::vector<LogicalTypeID>{typeID}, LogicalTypeID::DOUBLE,
Expand All @@ -130,25 +130,23 @@ void BuiltInAggregateFunctions::registerAvg() {

void BuiltInAggregateFunctions::registerMin() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto typeID : LogicalType::getAllValidComparableLogicalTypes()) {
for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(MIN_FUNC_NAME,
std::vector<LogicalTypeID>{typeID}, typeID,
AggregateFunctionUtil::getMinFunction(LogicalType(typeID), isDistinct),
isDistinct));
std::vector<LogicalTypeID>{type.getLogicalTypeID()}, type.getLogicalTypeID(),
AggregateFunctionUtil::getMinFunction(type, isDistinct), isDistinct));
}
}
aggregateFunctions.insert({MIN_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerMax() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto typeID : LogicalType::getAllValidComparableLogicalTypes()) {
for (auto& type : LogicalTypeUtils::getAllValidComparableLogicalTypes()) {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(MAX_FUNC_NAME,
std::vector<LogicalTypeID>{typeID}, typeID,
AggregateFunctionUtil::getMaxFunction(LogicalType(typeID), isDistinct),
isDistinct));
std::vector<LogicalTypeID>{type.getLogicalTypeID()}, type.getLogicalTypeID(),
AggregateFunctionUtil::getMaxFunction(type, isDistinct), isDistinct));
}
}
aggregateFunctions.insert({MAX_FUNC_NAME, std::move(definitions)});
Expand Down
25 changes: 25 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ uint32_t BuiltInVectorOperations::getCastCost(
return castDouble(targetTypeID);
case common::LogicalTypeID::FLOAT:
return castFloat(targetTypeID);
case common::LogicalTypeID::DATE:
return castDate(targetTypeID);
case common::LogicalTypeID::SERIAL:
return castSerial(targetTypeID);
default:
return UINT32_MAX;
}
Expand Down Expand Up @@ -133,6 +137,9 @@ uint32_t BuiltInVectorOperations::getTargetTypeCost(common::LogicalTypeID typeID
case common::LogicalTypeID::DOUBLE: {
return 102;
}
case common::LogicalTypeID::TIMESTAMP: {
return 120;
}
default: {
throw InternalException("Unsupported casting operation.");
}
Expand Down Expand Up @@ -188,6 +195,24 @@ uint32_t BuiltInVectorOperations::castFloat(common::LogicalTypeID targetTypeID)
}
}

uint32_t BuiltInVectorOperations::castDate(common::LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case common::LogicalTypeID::TIMESTAMP:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
}
}

uint32_t BuiltInVectorOperations::castSerial(common::LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case common::LogicalTypeID::INT64:
return 0;
default:
return castInt64(targetTypeID);
}
}

// When there is multiple candidates functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double.
VectorOperationDefinition* BuiltInVectorOperations::getBestMatch(
Expand Down
Loading

0 comments on commit 86cfd03

Please sign in to comment.