Skip to content

Commit

Permalink
Add node evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jul 2, 2023
1 parent 3c20011 commit 32b1e8f
Show file tree
Hide file tree
Showing 36 changed files with 450 additions and 280 deletions.
21 changes: 17 additions & 4 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ std::shared_ptr<NodeExpression> Binder::bindQueryNode(
}
} else {
queryNode = createQueryNode(nodePattern);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryNode);
}
}
for (auto& [propertyName, rhs] : nodePattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodePropertyExpression(*queryNode, propertyName);
Expand All @@ -348,14 +351,24 @@ std::shared_ptr<NodeExpression> Binder::createQueryNode(const NodePattern& nodeP

std::shared_ptr<NodeExpression> Binder::createQueryNode(
const std::string& parsedName, const std::vector<common::table_id_t>& tableIDs) {
auto queryNode =
make_shared<NodeExpression>(getUniqueExpressionName(parsedName), parsedName, tableIDs);
auto queryNode = make_shared<NodeExpression>(LogicalType(common::LogicalTypeID::NODE),
getUniqueExpressionName(parsedName), parsedName, tableIDs);
queryNode->setAlias(parsedName);
queryNode->setInternalIDProperty(expressionBinder.createInternalNodeIDExpression(*queryNode));
queryNode->setLabelExpression(expressionBinder.bindLabelFunction(*queryNode));
bindQueryNodeProperties(*queryNode);
if (!parsedName.empty()) {
variableScope->addExpression(parsedName, queryNode);
std::vector<std::unique_ptr<StructField>> nodeFields;
nodeFields.push_back(std::make_unique<StructField>(
InternalKeyword::ID, std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID)));
nodeFields.push_back(std::make_unique<StructField>(
InternalKeyword::LABEL, std::make_unique<LogicalType>(LogicalTypeID::STRING)));
for (auto& expression : queryNode->getPropertyExpressions()) {
auto propertyExpression = (PropertyExpression*)expression.get();
nodeFields.push_back(std::make_unique<StructField>(
propertyExpression->getPropertyName(), propertyExpression->getDataType().copy()));
}
common::NodeType::setStructTypeInfo(
queryNode->getDataTypeReference(), std::make_unique<StructTypeInfo>(std::move(nodeFields)));
return queryNode;
}

Expand Down
21 changes: 3 additions & 18 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ std::unique_ptr<BoundReturnClause> Binder::bindReturnClause(const ReturnClause&
auto statementResult = std::make_unique<BoundStatementResult>();
for (auto& expression : boundProjectionExpressions) {
auto dataType = expression->getDataType();
if (dataType.getLogicalTypeID() == common::LogicalTypeID::NODE ||
dataType.getLogicalTypeID() == common::LogicalTypeID::REL) {
if (dataType.getLogicalTypeID() == common::LogicalTypeID::REL) {
statementResult->addColumn(expression, rewriteNodeOrRelExpression(*expression));
} else {
statementResult->addColumn(expression, expression_vector{expression});
Expand Down Expand Up @@ -164,23 +163,9 @@ expression_vector Binder::bindProjectionExpressions(
}

expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) {
if (expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) {
return rewriteNodeExpression(expression);
} else {
assert(expression.dataType.getLogicalTypeID() == common::LogicalTypeID::REL);
return rewriteRelExpression(expression);
}
}

expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression& expression) {
expression_vector result;
auto& node = (NodeExpression&)expression;
result.push_back(node.getInternalIDProperty());
result.push_back(expressionBinder.bindLabelFunction(node));
for (auto& property : node.getPropertyExpressions()) {
result.push_back(property->copy());
}
return result;
assert(expression.dataType.getLogicalTypeID() == common::LogicalTypeID::REL);
return rewriteRelExpression(expression);
}

expression_vector Binder::rewriteRelExpression(const Expression& expression) {
Expand Down
22 changes: 18 additions & 4 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Value Value::createDefaultValue(const LogicalType& dataType) {
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::UNION:
case LogicalTypeID::STRUCT:
case LogicalTypeID::NODE:
return Value(dataType, std::vector<std::unique_ptr<Value>>{});
default:
throw RuntimeException("Data type " + LogicalTypeUtils::dataTypeToString(dataType) +
Expand Down Expand Up @@ -333,8 +334,23 @@ std::string Value::toString() const {
std::string result = "{";
auto fieldNames = StructType::getFieldNames(&dataType);
for (auto i = 0u; i < nestedTypeVal.size(); ++i) {
result += fieldNames[i];
result += ": ";
result += fieldNames[i] + ": ";
result += nestedTypeVal[i]->toString();
if (i != nestedTypeVal.size() - 1) {
result += ", ";
}
}
result += "}";
return result;
}
case LogicalTypeID::NODE: {
std::string result = "{";
auto fieldNames = StructType::getFieldNames(&dataType);
for (auto i = 0u; i < nestedTypeVal.size(); ++i) {
if (nestedTypeVal[i]->isNull_) {
continue;
}
result += fieldNames[i] + ": ";
result += nestedTypeVal[i]->toString();
if (i != nestedTypeVal.size() - 1) {
result += ", ";
Expand All @@ -343,8 +359,6 @@ std::string Value::toString() const {
result += "}";
return result;
}
case LogicalTypeID::NODE:
return nodeVal->toString();
case LogicalTypeID::REL:
return relVal->toString();
default:
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_library(kuzu_expression_evaluator
case_evaluator.cpp
function_evaluator.cpp
literal_evaluator.cpp
node_rel_evaluator.cpp
reference_evaluator.cpp)

set(ALL_OBJECT_FILES
Expand Down
4 changes: 1 addition & 3 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,11 @@ std::unique_ptr<BaseExpressionEvaluator> FunctionExpressionEvaluator::clone() {

void FunctionExpressionEvaluator::resolveResultVector(
const ResultSet& resultSet, MemoryManager* memoryManager) {
for (auto& child : children) {
parameters.push_back(child->resultVector);
}
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
std::vector<BaseExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
parameters.push_back(child->resultVector);
inputEvaluators.push_back(child.get());
}
resolveResultStateFromChildren(inputEvaluators);
Expand Down
32 changes: 32 additions & 0 deletions src/expression_evaluator/node_rel_evaluator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "expression_evaluator/node_rel_evaluator.h"

#include "function/struct/vector_struct_operations.h"

using namespace kuzu::common;
using namespace kuzu::function;

namespace kuzu {
namespace evaluator {

void NodeExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
StructPackVectorOperations::execFunc(parameters, *resultVector);
}

void NodeExpressionEvaluator::resolveResultVector(
const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(node->getDataType(), memoryManager);
std::vector<BaseExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
parameters.push_back(child->resultVector);
inputEvaluators.push_back(child.get());
}
resolveResultStateFromChildren(inputEvaluators);
StructPackVectorOperations::compileFunc(nullptr, parameters, resultVector);
}

} // namespace evaluator
} // namespace kuzu
15 changes: 8 additions & 7 deletions src/function/vector_struct_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ void StructPackVectorOperations::execFunc(
}
// If the parameter's state is inconsistent with the result's state, we need to copy the
// parameter's value to the corresponding child vector.
copyParameterValueToStructFieldVector(
parameter.get(), common::StructVector::getFieldVector(&result, i).get());
copyParameterValueToStructFieldVector(parameter.get(),
common::StructVector::getFieldVector(&result, i).get(), result.state.get());
}
}

Expand All @@ -63,15 +63,16 @@ void StructPackVectorOperations::compileFunc(FunctionBindData* bindData,
}

void StructPackVectorOperations::copyParameterValueToStructFieldVector(
const common::ValueVector* parameter, common::ValueVector* structField) {
const common::ValueVector* parameter, common::ValueVector* structField,
common::DataChunkState* structVectorState) {
// If the parameter is unFlat, then its state must be consistent with the result's state.
// Thus, we don't need to copy values to structFieldVector.
assert(parameter->state->isFlat());
auto srcPos = parameter->state->selVector->selectedPositions[0];
auto srcValue = parameter->getData() + parameter->getNumBytesPerValue() * srcPos;
bool isSrcValueNull = parameter->isNull(srcPos);
if (structField->state->isFlat()) {
auto pos = structField->state->selVector->selectedPositions[0];
if (structVectorState->isFlat()) {
auto pos = structVectorState->selVector->selectedPositions[0];
if (isSrcValueNull) {
structField->setNull(pos, true /* isNull */);
} else {
Expand All @@ -80,8 +81,8 @@ void StructPackVectorOperations::copyParameterValueToStructFieldVector(
srcValue);
}
} else {
for (auto j = 0u; j < structField->state->selVector->selectedSize; j++) {
auto pos = structField->state->selVector->selectedPositions[j];
for (auto j = 0u; j < structVectorState->selVector->selectedSize; j++) {
auto pos = structVectorState->selVector->selectedPositions[j];
if (isSrcValueNull) {
structField->setNull(pos, true /* isNull */);
} else {
Expand Down
1 change: 0 additions & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ class Binder {
const parser::parsed_expression_vector& parsedExpressions, bool star);
// Rewrite variable "v" as all properties of "v"
expression_vector rewriteNodeOrRelExpression(const Expression& expression);
expression_vector rewriteNodeExpression(const Expression& expression);
expression_vector rewriteRelExpression(const Expression& expression);

expression_vector bindOrderByExpressions(
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class Expression : public std::enable_shared_from_this<Expression> {
}

inline common::LogicalType getDataType() const { return dataType; }
inline common::LogicalType& getDataTypeReference() { return dataType; }

inline bool hasAlias() const { return !alias.empty(); }

Expand Down
14 changes: 10 additions & 4 deletions src/include/binder/expression/node_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace binder {

class NodeExpression : public NodeOrRelExpression {
public:
NodeExpression(
std::string uniqueName, std::string variableName, std::vector<common::table_id_t> tableIDs)
: NodeOrRelExpression{common::LogicalType(common::LogicalTypeID::NODE),
std::move(uniqueName), std::move(variableName), std::move(tableIDs)} {}
NodeExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs)
: NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName),
std::move(tableIDs)} {}

inline void setInternalIDProperty(std::unique_ptr<Expression> expression) {
internalIDExpression = std::move(expression);
Expand All @@ -24,6 +24,12 @@ class NodeExpression : public NodeOrRelExpression {
return internalIDExpression->getUniqueName();
}

expression_vector getChildren() const override {
auto result = NodeOrRelExpression::getChildren();
result.push_back(internalIDExpression->copy());
return result;
}

private:
std::unique_ptr<Expression> internalIDExpression;
};
Expand Down
15 changes: 15 additions & 0 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,28 @@ class NodeOrRelExpression : public Expression {
return properties;
}

// TODO: move into constructor
inline void setLabelExpression(std::shared_ptr<Expression> expression) {
labelExpression = std::move(expression);
}
inline std::shared_ptr<Expression> getLabelExpression() const { return labelExpression; }

expression_vector getChildren() const override {
expression_vector result;
for (auto& property : properties) {
result.push_back(property->copy());
}
return result;
}

std::string toString() const override { return variableName; }

protected:
std::string variableName;
std::vector<common::table_id_t> tableIDs;
std::unordered_map<std::string, common::vector_idx_t> propertyNameToIdx;
std::vector<std::unique_ptr<Expression>> properties;
std::shared_ptr<Expression> labelExpression;
};

} // namespace binder
Expand Down
1 change: 1 addition & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ constexpr uint64_t DEFAULT_CHECKPOINT_WAIT_TIMEOUT_FOR_TRANSACTIONS_TO_LEAVE_IN_
struct InternalKeyword {
static constexpr char ANONYMOUS[] = "";
static constexpr char ID[] = "_id";
static constexpr char LABEL[] = "_label";
static constexpr char LENGTH[] = "_length";
static constexpr char NODES[] = "_nodes";
static constexpr char RELS[] = "_rels";
Expand Down
12 changes: 12 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ class LogicalType {

inline PhysicalTypeID getPhysicalType() const { return physicalType; }

inline void setExtraTypeInfo(std::unique_ptr<ExtraTypeInfo> typeInfo) {
extraTypeInfo = std::move(typeInfo);
}

std::unique_ptr<LogicalType> copy();

private:
Expand Down Expand Up @@ -257,6 +261,14 @@ struct FixedListType {
}
};

struct NodeType {
static inline void setStructTypeInfo(
LogicalType& type, std::unique_ptr<ExtraTypeInfo> extraTypeInfo) {
assert(type.getLogicalTypeID() == LogicalTypeID::NODE);
type.setExtraTypeInfo(std::move(extraTypeInfo));
}
};

struct StructType {
static inline std::vector<LogicalType*> getFieldTypes(const LogicalType* type) {
assert(type->getPhysicalType() == PhysicalTypeID::STRUCT);
Expand Down
39 changes: 39 additions & 0 deletions src/include/expression_evaluator/node_rel_evaluator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include "base_evaluator.h"
#include "binder/expression/node_expression.h"

namespace kuzu {
namespace evaluator {

class NodeExpressionEvaluator : public BaseExpressionEvaluator {
public:
NodeExpressionEvaluator(std::shared_ptr<binder::Expression> node,
std::vector<std::unique_ptr<BaseExpressionEvaluator>> children)
: BaseExpressionEvaluator{std::move(children)}, node{std::move(node)} {}

void evaluate() override;

bool select(common::SelectionVector& selVector) override {
throw common::NotImplementedException("NodeExpressionEvaluator::select");
}

std::unique_ptr<BaseExpressionEvaluator> clone() override {
std::vector<std::unique_ptr<BaseExpressionEvaluator>> clonedChildren;
for (auto& child : children) {
clonedChildren.push_back(child->clone());
}
return make_unique<NodeExpressionEvaluator>(node, std::move(clonedChildren));
}

private:
void resolveResultVector(
const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) override;

private:
std::shared_ptr<binder::Expression> node;
std::vector<std::shared_ptr<common::ValueVector>> parameters;
};

} // namespace evaluator
} // namespace kuzu
4 changes: 2 additions & 2 deletions src/include/function/struct/vector_struct_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ struct StructPackVectorOperations {
static void compileFunc(FunctionBindData* bindData,
const std::vector<std::shared_ptr<common::ValueVector>>& parameters,
std::shared_ptr<common::ValueVector>& result);
static void copyParameterValueToStructFieldVector(
const common::ValueVector* parameter, common::ValueVector* structField);
static void copyParameterValueToStructFieldVector(const common::ValueVector* parameter,
common::ValueVector* structField, common::DataChunkState* structVectorState);
};

struct StructExtractBindData : public FunctionBindData {
Expand Down
5 changes: 3 additions & 2 deletions src/include/planner/projection_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ class ProjectionPlanner {
explicit ProjectionPlanner(QueryPlanner* queryPlanner) : queryPlanner{queryPlanner} {}

void planProjectionBody(const binder::BoundProjectionBody& projectionBody,
const std::vector<std::unique_ptr<LogicalPlan>>& plans);
const std::vector<std::unique_ptr<LogicalPlan>>& plans, bool isLastProjection);

private:
void planProjectionBody(const binder::BoundProjectionBody& projectionBody, LogicalPlan& plan);
void planProjectionBody(const binder::BoundProjectionBody& projectionBody, LogicalPlan& plan,
bool isLastProjection);
void planAggregate(const binder::expression_vector& expressionsToAggregate,
const binder::expression_vector& expressionsToGroupBy, LogicalPlan& plan);
void planOrderBy(const binder::expression_vector& expressionsToProject,
Expand Down
Loading

0 comments on commit 32b1e8f

Please sign in to comment.