Skip to content

Commit

Permalink
Push recursive join length into operator
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 21, 2023
1 parent deff05a commit 09f0cc7
Show file tree
Hide file tree
Showing 27 changed files with 290 additions and 204 deletions.
18 changes: 0 additions & 18 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,6 @@ std::vector<std::string> Binder::bindFilePaths(const std::vector<std::string>& f
return boundFilePaths;
}

std::unordered_map<common::property_id_t, std::string> Binder::bindPropertyToNpyMap(
common::table_id_t tableID, const std::vector<std::string>& filePaths) {
auto catalogContent = catalog.getReadOnlyVersion();
auto tableSchema = catalogContent->getTableSchema(tableID);
if (tableSchema->properties.size() != filePaths.size()) {
throw BinderException(StringUtils::string_format(
"Number of npy files is not equal to number of properties in table {}.",
tableSchema->tableName));
}
std::unordered_map<common::property_id_t, std::string> propertyIDToNpyMap;
for (int i = 0; i < filePaths.size(); i++) {
auto& filePath = filePaths[i];
auto& propertyID = tableSchema->properties[i].propertyID;
propertyIDToNpyMap[propertyID] = filePath;
}
return propertyIDToNpyMap;
}

CSVReaderConfig Binder::bindParsingOptions(
const std::unordered_map<std::string, std::unique_ptr<ParsedExpression>>* parsingOptions) {
CSVReaderConfig csvReaderConfig;
Expand Down
4 changes: 4 additions & 0 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
if (isVariableLength) {
queryRel->setInternalLengthExpression(
expressionBinder.createInternalLengthExpression(*queryRel));
}
// resolve properties associate with rel table
std::vector<RelTableSchema*> relTableSchemas;
for (auto tableID : tableIDs) {
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression&
expression_vector result;
auto& node = (NodeExpression&)expression;
result.push_back(node.getInternalIDProperty());
result.push_back(expressionBinder.bindNodeLabelFunction(node));
result.push_back(expressionBinder.bindLabelFunction(node));
for (auto& property : node.getPropertyExpressions()) {
result.push_back(property->copy());
}
Expand All @@ -92,7 +92,7 @@ expression_vector Binder::rewriteRelExpression(const Expression& expression) {
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
result.push_back(expressionBinder.bindRelLabelFunction(rel));
result.push_back(expressionBinder.bindLabelFunction(rel));
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
Expand Down
153 changes: 86 additions & 67 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
// check for special function binding
if (functionName == ID_FUNC_NAME) {
return bindInternalIDExpression(parsedExpression);
} else if (functionName == LABEL_FUNC_NAME) {
return bindLabelFunction(parsedExpression);
auto result = rewriteFunctionExpression(parsedExpression, functionName);
if (result != nullptr) {
return result;
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
Expand Down Expand Up @@ -123,23 +121,29 @@ std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
return createLiteralExpression(std::move(value));
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const ParsedExpression& parsedExpression) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
*child, std::unordered_set<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
return bindInternalIDExpression(*child);
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const Expression& expression) {
if (expression.dataType.getLogicalTypeID() == LogicalTypeID::NODE) {
auto& node = (NodeExpression&)expression;
return node.getInternalIDProperty();
} else {
assert(expression.dataType.getLogicalTypeID() == LogicalTypeID::REL);
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
}
// Function rewriting happens when we need to expose internal property access through function so
// that it becomes read-only or the function involves catalog information. Currently we write
// Before | After
// ID(a) | a._id
// LABEL(a) | LIST_EXTRACT(offset(a), [table names from catalog])
// LENGTH(e) | e._length
std::shared_ptr<Expression> ExpressionBinder::rewriteFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName) {
if (functionName == ID_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
*child, std::unordered_set<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
return bindInternalIDExpression(*child);
} else if (functionName == LABEL_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
*child, std::unordered_set<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
return bindLabelFunction(*child);
} else if (functionName == LENGTH_FUNC_NAME) {
auto child = bindExpression(*parsedExpression.getChild(0));
return bindRecursiveJoinLengthFunction(*child);
}
return nullptr;
}

std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
Expand All @@ -149,20 +153,22 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
for (auto tableID : node.getTableIDs()) {
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
auto result = std::make_unique<PropertyExpression>(LogicalType(LogicalTypeID::INTERNAL_ID),
return std::make_unique<PropertyExpression>(LogicalType(LogicalTypeID::INTERNAL_ID),
INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable), false /* isPrimaryKey */);
return result;
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const ParsedExpression& parsedExpression) {
// bind child node
auto child = bindExpression(*parsedExpression.getChild(0));
if (child->dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) {
return bindNodeLabelFunction(*child);
} else {
assert(child->dataType.getLogicalTypeID() == common::LogicalTypeID::REL);
return bindRelLabelFunction(*child);
std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const Expression& expression) {
switch (expression.getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
auto& node = (NodeExpression&)expression;
return node.getInternalIDProperty();
}
case common::LogicalTypeID::REL: {
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
}
default:
throw NotImplementedException("ExpressionBinder::bindInternalIDExpression");
}
}

Expand All @@ -183,22 +189,41 @@ static std::vector<std::unique_ptr<Value>> populateLabelValues(
return labels;
}

std::shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
}
auto nodeTableIDs = catalogContent->getNodeTableIDs();
auto varListTypeInfo =
std::make_unique<VarListTypeInfo>(std::make_unique<LogicalType>(LogicalTypeID::STRING));
auto listType =
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, std::move(varListTypeInfo));
expression_vector children;
children.push_back(node.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(LogicalType(LogicalTypeID::VAR_LIST,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::STRING))),
populateLabelValues(nodeTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
switch (expression.getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
}
auto nodeTableIDs = catalogContent->getNodeTableIDs();
children.push_back(node.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(*listType, populateLabelValues(nodeTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
} break;
case common::LogicalTypeID::REL: {
auto& rel = (RelExpression&)expression;
if (!rel.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(rel.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
}
auto relTableIDs = catalogContent->getRelTableIDs();
children.push_back(rel.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(*listType, populateLabelValues(relTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
} break;
default:
throw NotImplementedException("ExpressionBinder::bindLabelFunction");
}
auto execFunc = function::LabelVectorOperation::execFunction;
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
Expand All @@ -207,28 +232,22 @@ std::shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expres
std::move(children), execFunc, nullptr, uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindRelLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
std::unique_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
const Expression& expression) {
auto& rel = (RelExpression&)expression;
if (!rel.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(rel.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
std::unordered_map<table_id_t, property_id_t> propertyIDPerTable;
propertyIDPerTable.insert({rel.getSingleTableID(), INVALID_PROPERTY_ID});
return std::make_unique<PropertyExpression>(LogicalType(common::LogicalTypeID::INT64),
INTERNAL_LENGTH_SUFFIX, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindRecursiveJoinLengthFunction(
const Expression& expression) {
if (expression.getDataType().getLogicalTypeID() != common::LogicalTypeID::RECURSIVE_REL) {
return nullptr;
}
auto relTableIDs = catalogContent->getRelTableIDs();
expression_vector children;
children.push_back(rel.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(LogicalType(LogicalTypeID::VAR_LIST,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::STRING))),
populateLabelValues(relTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
auto execFunc = function::LabelVectorOperation::execFunction;
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, std::move(bindData),
std::move(children), execFunc, nullptr, uniqueExpressionName);
auto& rel = (RelExpression&)expression;
return rel.getInternalLengthExpression();
}

} // namespace binder
Expand Down
3 changes: 1 addition & 2 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,7 @@ void LogicalType::setPhysicalType() {
physicalType = PhysicalTypeID::STRUCT;
} break;
default:
throw NotImplementedException{
"Unsupported LogicalType: " + LogicalTypeUtils::dataTypeToString(typeID) + "."};
throw NotImplementedException{"LogicalType::setPhysicalType()."};
}
}

Expand Down
10 changes: 0 additions & 10 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ void BuiltInVectorOperations::registerVectorOperations() {
registerStringOperations();
registerCastOperations();
registerListOperations();
registerInternalIDOperation();
registerStructOperation();
// register internal offset operation
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
Expand Down Expand Up @@ -454,15 +453,6 @@ void BuiltInVectorOperations::registerListOperations() {
{LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.push_back(make_unique<VectorOperationDefinition>(ID_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::NODE}, LogicalTypeID::INTERNAL_ID, nullptr));
definitions.push_back(make_unique<VectorOperationDefinition>(ID_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::REL}, LogicalTypeID::INTERNAL_ID, nullptr));
vectorOperations.insert({ID_FUNC_NAME, std::move(definitions)});
}

void BuiltInVectorOperations::registerStructOperation() {
vectorOperations.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorOperations::getDefinitions()});
vectorOperations.insert(
Expand Down
3 changes: 0 additions & 3 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> ListLenVectorOperation::
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
return result;
}

Expand Down
3 changes: 0 additions & 3 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ class Binder {

std::vector<std::string> bindFilePaths(const std::vector<std::string>& filePaths);

std::unordered_map<common::property_id_t, std::string> bindPropertyToNpyMap(
common::table_id_t tableId, const std::vector<std::string>& filePaths);

common::CSVReaderConfig bindParsingOptions(
const std::unordered_map<std::string, std::unique_ptr<parser::ParsedExpression>>*
parsingOptions);
Expand Down
8 changes: 8 additions & 0 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class RelExpression : public NodeOrRelExpression {
inline std::shared_ptr<Expression> getInternalIDProperty() const {
return getPropertyExpression(common::INTERNAL_ID_SUFFIX);
}
inline void setInternalLengthExpression(std::unique_ptr<Expression> expression) {
internalLengthExpression = std::move(expression);
}
inline std::shared_ptr<Expression> getInternalLengthExpression() const {
assert(internalLengthExpression != nullptr);
return internalLengthExpression->copy();
}

private:
std::shared_ptr<NodeExpression> srcNode;
Expand All @@ -47,6 +54,7 @@ class RelExpression : public NodeOrRelExpression {
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
std::unique_ptr<Expression> internalLengthExpression;
};

} // namespace binder
Expand Down
14 changes: 7 additions & 7 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,24 @@ class ExpressionBinder {

std::shared_ptr<Expression> bindFunctionExpression(
const parser::ParsedExpression& parsedExpression);

std::shared_ptr<Expression> bindScalarFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName);
std::shared_ptr<Expression> bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName);
std::shared_ptr<Expression> bindAggregateFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName,
bool isDistinct);

std::shared_ptr<Expression> staticEvaluate(
const std::string& functionName, const expression_vector& children);

std::shared_ptr<Expression> bindInternalIDExpression(
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
std::shared_ptr<Expression> rewriteFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName);
std::unique_ptr<Expression> createInternalNodeIDExpression(const Expression& node);
std::shared_ptr<Expression> bindLabelFunction(const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindNodeLabelFunction(const Expression& expression);
std::shared_ptr<Expression> bindRelLabelFunction(const Expression& expression);
std::shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
std::shared_ptr<Expression> bindLabelFunction(const Expression& expression);
std::unique_ptr<Expression> createInternalLengthExpression(const Expression& expression);
std::shared_ptr<Expression> bindRecursiveJoinLengthFunction(const Expression& expression);

std::shared_ptr<Expression> bindParameterExpression(
const parser::ParsedExpression& parsedExpression);
Expand Down
2 changes: 0 additions & 2 deletions src/include/function/built_in_vector_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ class BuiltInVectorOperations {
void registerStringOperations();
void registerCastOperations();
void registerListOperations();
void registerInternalIDOperation();
void registerInternalOffsetOperation();
void registerStructOperation();

private:
Expand Down
8 changes: 4 additions & 4 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace optimizer {

// ProjectionPushDownOptimizer implements the logic to avoid materializing unnecessary properties
// for hash join build.
// Note the optimization is for properties only but not for general expressions. This is because
// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the
// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age
// is evaluate. For simplicity, we only consider the push down for property.
// Note the optimization is for properties & variables only but not for general expressions. This is
// because it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be
// either the whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or
// only a.age is evaluate. For simplicity, we only consider the push down for property.
class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);
Expand Down
Loading

0 comments on commit 09f0cc7

Please sign in to comment.