Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push recursive path length into operator #1555

Merged
merged 1 commit into from
May 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,6 @@ std::vector<std::string> Binder::bindFilePaths(const std::vector<std::string>& f
return boundFilePaths;
}

// Pairs every property of table `tableID` with an npy file path, by position: the i-th entry of
// `filePaths` is assigned to the i-th property in the table schema.
// Returns a map from property ID to file path.
// Throws BinderException when the number of paths differs from the number of properties.
std::unordered_map<common::property_id_t, std::string> Binder::bindPropertyToNpyMap(
    common::table_id_t tableID, const std::vector<std::string>& filePaths) {
    auto catalogContent = catalog.getReadOnlyVersion();
    auto tableSchema = catalogContent->getTableSchema(tableID);
    if (tableSchema->properties.size() != filePaths.size()) {
        throw BinderException(StringUtils::string_format(
            "Number of npy files is not equal to number of properties in table {}.",
            tableSchema->tableName));
    }
    std::unordered_map<common::property_id_t, std::string> propertyIDToNpyMap;
    propertyIDToNpyMap.reserve(filePaths.size());
    // Unsigned index: filePaths.size() is unsigned, so a signed counter would mix signedness in
    // the loop condition.
    for (auto i = 0u; i < filePaths.size(); i++) {
        propertyIDToNpyMap[tableSchema->properties[i].propertyID] = filePaths[i];
    }
    return propertyIDToNpyMap;
}

CSVReaderConfig Binder::bindParsingOptions(
const std::unordered_map<std::string, std::unique_ptr<ParsedExpression>>* parsingOptions) {
CSVReaderConfig csvReaderConfig;
Expand Down
4 changes: 4 additions & 0 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
if (isVariableLength) {
queryRel->setInternalLengthExpression(
expressionBinder.createInternalLengthExpression(*queryRel));
}
// resolve properties associate with rel table
std::vector<RelTableSchema*> relTableSchemas;
for (auto tableID : tableIDs) {
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression&
expression_vector result;
auto& node = (NodeExpression&)expression;
result.push_back(node.getInternalIDProperty());
result.push_back(expressionBinder.bindNodeLabelFunction(node));
result.push_back(expressionBinder.bindLabelFunction(node));
for (auto& property : node.getPropertyExpressions()) {
result.push_back(property->copy());
}
Expand All @@ -92,7 +92,7 @@ expression_vector Binder::rewriteRelExpression(const Expression& expression) {
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
result.push_back(expressionBinder.bindRelLabelFunction(rel));
result.push_back(expressionBinder.bindLabelFunction(rel));
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
Expand Down
153 changes: 86 additions & 67 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
// check for special function binding
if (functionName == ID_FUNC_NAME) {
return bindInternalIDExpression(parsedExpression);
} else if (functionName == LABEL_FUNC_NAME) {
return bindLabelFunction(parsedExpression);
auto result = rewriteFunctionExpression(parsedExpression, functionName);
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
if (result != nullptr) {
return result;
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
Expand Down Expand Up @@ -123,23 +121,29 @@ std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
return createLiteralExpression(std::move(value));
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const ParsedExpression& parsedExpression) {
auto child = bindExpression(*parsedExpression.getChild(0));
validateExpectedDataType(
*child, std::unordered_set<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
return bindInternalIDExpression(*child);
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const Expression& expression) {
if (expression.dataType.getLogicalTypeID() == LogicalTypeID::NODE) {
auto& node = (NodeExpression&)expression;
return node.getInternalIDProperty();
} else {
assert(expression.dataType.getLogicalTypeID() == LogicalTypeID::REL);
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
}
// Some functions are not bound as regular functions but rewritten at bind time. This is done when
// the function exposes an internal property (so that the property stays read-only for the user) or
// when it needs catalog information. The rewrites currently performed are
//      Before    |    After
//      ID(a)     |    a._id
//      LABEL(a)  |    LIST_EXTRACT(offset(a), [table names from catalog])
//      LENGTH(e) |    e._length
// Returns nullptr when `functionName` is not one of the rewritten functions; for LENGTH the result
// is whatever bindRecursiveJoinLengthFunction produces for the bound child.
std::shared_ptr<Expression> ExpressionBinder::rewriteFunctionExpression(
    const parser::ParsedExpression& parsedExpression, const std::string& functionName) {
    const bool isIDFunction = functionName == ID_FUNC_NAME;
    if (isIDFunction || functionName == LABEL_FUNC_NAME) {
        // ID and LABEL share the same preparation: bind the argument and require a node or rel.
        auto boundChild = bindExpression(*parsedExpression.getChild(0));
        validateExpectedDataType(*boundChild,
            std::unordered_set<LogicalTypeID>{LogicalTypeID::NODE, LogicalTypeID::REL});
        if (isIDFunction) {
            return bindInternalIDExpression(*boundChild);
        }
        return bindLabelFunction(*boundChild);
    }
    if (functionName == LENGTH_FUNC_NAME) {
        auto boundChild = bindExpression(*parsedExpression.getChild(0));
        return bindRecursiveJoinLengthFunction(*boundChild);
    }
    return nullptr;
}

std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
Expand All @@ -149,20 +153,22 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
for (auto tableID : node.getTableIDs()) {
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
auto result = std::make_unique<PropertyExpression>(LogicalType(LogicalTypeID::INTERNAL_ID),
return std::make_unique<PropertyExpression>(LogicalType(LogicalTypeID::INTERNAL_ID),
INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable), false /* isPrimaryKey */);
return result;
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const ParsedExpression& parsedExpression) {
// bind child node
auto child = bindExpression(*parsedExpression.getChild(0));
if (child->dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) {
return bindNodeLabelFunction(*child);
} else {
assert(child->dataType.getLogicalTypeID() == common::LogicalTypeID::REL);
return bindRelLabelFunction(*child);
std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
const Expression& expression) {
switch (expression.getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
auto& node = (NodeExpression&)expression;
return node.getInternalIDProperty();
}
case common::LogicalTypeID::REL: {
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
}
default:
throw NotImplementedException("ExpressionBinder::bindInternalIDExpression");
}
}

Expand All @@ -183,22 +189,41 @@ static std::vector<std::unique_ptr<Value>> populateLabelValues(
return labels;
}

std::shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
}
auto nodeTableIDs = catalogContent->getNodeTableIDs();
auto varListTypeInfo =
std::make_unique<VarListTypeInfo>(std::make_unique<LogicalType>(LogicalTypeID::STRING));
auto listType =
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, std::move(varListTypeInfo));
expression_vector children;
children.push_back(node.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(LogicalType(LogicalTypeID::VAR_LIST,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::STRING))),
populateLabelValues(nodeTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
switch (expression.getDataType().getLogicalTypeID()) {
case common::LogicalTypeID::NODE: {
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
}
auto nodeTableIDs = catalogContent->getNodeTableIDs();
children.push_back(node.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(*listType, populateLabelValues(nodeTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
} break;
case common::LogicalTypeID::REL: {
auto& rel = (RelExpression&)expression;
if (!rel.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(rel.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
}
auto relTableIDs = catalogContent->getRelTableIDs();
children.push_back(rel.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(*listType, populateLabelValues(relTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
} break;
default:
throw NotImplementedException("ExpressionBinder::bindLabelFunction");
}
auto execFunc = function::LabelVectorOperation::execFunction;
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
Expand All @@ -207,28 +232,22 @@ std::shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expres
std::move(children), execFunc, nullptr, uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindRelLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
std::unique_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
const Expression& expression) {
auto& rel = (RelExpression&)expression;
if (!rel.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(rel.getSingleTableID());
return createLiteralExpression(std::make_unique<Value>(labelName));
std::unordered_map<table_id_t, property_id_t> propertyIDPerTable;
propertyIDPerTable.insert({rel.getSingleTableID(), INVALID_PROPERTY_ID});
return std::make_unique<PropertyExpression>(LogicalType(common::LogicalTypeID::INT64),
INTERNAL_LENGTH_SUFFIX, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindRecursiveJoinLengthFunction(
const Expression& expression) {
if (expression.getDataType().getLogicalTypeID() != common::LogicalTypeID::RECURSIVE_REL) {
return nullptr;
}
auto relTableIDs = catalogContent->getRelTableIDs();
expression_vector children;
children.push_back(rel.getInternalIDProperty());
auto labelsValue =
std::make_unique<Value>(LogicalType(LogicalTypeID::VAR_LIST,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::STRING))),
populateLabelValues(relTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
auto execFunc = function::LabelVectorOperation::execFunction;
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, std::move(bindData),
std::move(children), execFunc, nullptr, uniqueExpressionName);
auto& rel = (RelExpression&)expression;
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
return rel.getInternalLengthExpression();
}

} // namespace binder
Expand Down
3 changes: 1 addition & 2 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,7 @@ void LogicalType::setPhysicalType() {
physicalType = PhysicalTypeID::STRUCT;
} break;
default:
throw NotImplementedException{
"Unsupported LogicalType: " + LogicalTypeUtils::dataTypeToString(typeID) + "."};
throw NotImplementedException{"LogicalType::setPhysicalType()."};
}
}

Expand Down
10 changes: 0 additions & 10 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ void BuiltInVectorOperations::registerVectorOperations() {
registerStringOperations();
registerCastOperations();
registerListOperations();
registerInternalIDOperation();
registerStructOperation();
// register internal offset operation
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
Expand Down Expand Up @@ -454,15 +453,6 @@ void BuiltInVectorOperations::registerListOperations() {
{LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.push_back(make_unique<VectorOperationDefinition>(ID_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::NODE}, LogicalTypeID::INTERNAL_ID, nullptr));
definitions.push_back(make_unique<VectorOperationDefinition>(ID_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::REL}, LogicalTypeID::INTERNAL_ID, nullptr));
vectorOperations.insert({ID_FUNC_NAME, std::move(definitions)});
}

void BuiltInVectorOperations::registerStructOperation() {
vectorOperations.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorOperations::getDefinitions()});
vectorOperations.insert(
Expand Down
3 changes: 0 additions & 3 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> ListLenVectorOperation::
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
return result;
}

Expand Down
3 changes: 0 additions & 3 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ class Binder {

std::vector<std::string> bindFilePaths(const std::vector<std::string>& filePaths);

std::unordered_map<common::property_id_t, std::string> bindPropertyToNpyMap(
common::table_id_t tableId, const std::vector<std::string>& filePaths);

common::CSVReaderConfig bindParsingOptions(
const std::unordered_map<std::string, std::unique_ptr<parser::ParsedExpression>>*
parsingOptions);
Expand Down
8 changes: 8 additions & 0 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ class RelExpression : public NodeOrRelExpression {
inline std::shared_ptr<Expression> getInternalIDProperty() const {
return getPropertyExpression(common::INTERNAL_ID_SUFFIX);
}
// Stores the length expression for this rel. Ownership of `expression` is transferred: it is
// moved into the internalLengthExpression member.
inline void setInternalLengthExpression(std::unique_ptr<Expression> expression) {
internalLengthExpression = std::move(expression);
}
// Returns the result of copy() on the stored length expression rather than the stored object
// itself. An expression must have been set via setInternalLengthExpression beforehand; this is
// checked with an assert only, so it guards just the builds where assertions are enabled.
inline std::shared_ptr<Expression> getInternalLengthExpression() const {
assert(internalLengthExpression != nullptr);
return internalLengthExpression->copy();
}

private:
std::shared_ptr<NodeExpression> srcNode;
Expand All @@ -47,6 +54,7 @@ class RelExpression : public NodeOrRelExpression {
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
std::unique_ptr<Expression> internalLengthExpression;
};

} // namespace binder
Expand Down
14 changes: 7 additions & 7 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,24 @@ class ExpressionBinder {

std::shared_ptr<Expression> bindFunctionExpression(
const parser::ParsedExpression& parsedExpression);

std::shared_ptr<Expression> bindScalarFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName);
std::shared_ptr<Expression> bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName);
std::shared_ptr<Expression> bindAggregateFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName,
bool isDistinct);

std::shared_ptr<Expression> staticEvaluate(
const std::string& functionName, const expression_vector& children);

std::shared_ptr<Expression> bindInternalIDExpression(
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
std::shared_ptr<Expression> rewriteFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName);
std::unique_ptr<Expression> createInternalNodeIDExpression(const Expression& node);
std::shared_ptr<Expression> bindLabelFunction(const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindNodeLabelFunction(const Expression& expression);
std::shared_ptr<Expression> bindRelLabelFunction(const Expression& expression);
std::shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
std::shared_ptr<Expression> bindLabelFunction(const Expression& expression);
std::unique_ptr<Expression> createInternalLengthExpression(const Expression& expression);
std::shared_ptr<Expression> bindRecursiveJoinLengthFunction(const Expression& expression);

std::shared_ptr<Expression> bindParameterExpression(
const parser::ParsedExpression& parsedExpression);
Expand Down
2 changes: 0 additions & 2 deletions src/include/function/built_in_vector_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ class BuiltInVectorOperations {
void registerStringOperations();
void registerCastOperations();
void registerListOperations();
void registerInternalIDOperation();
void registerInternalOffsetOperation();
void registerStructOperation();

private:
Expand Down
8 changes: 4 additions & 4 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace optimizer {

// ProjectionPushDownOptimizer implements the logic to avoid materializing unnecessary properties
// for hash join build.
// Note the optimization is for properties only but not for general expressions. This is because
// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the
// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age
// is evaluate. For simplicity, we only consider the push down for property.
// Note the optimization applies to properties & variables only, not to general expressions. This
// is because it's hard to figure out which expression is in use, e.g. for COUNT(a.age) + 1 it could
// be that the whole expression was evaluated in a WITH clause, or only COUNT(a.age) was evaluated,
// or only a.age was evaluated. For simplicity, we only consider push down for properties and
// variables.
class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);
Expand Down
Loading