Skip to content

Commit

Permalink
Support read after update (#3126)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 25, 2024
1 parent a8b15dc commit 4d21128
Show file tree
Hide file tree
Showing 21 changed files with 187 additions and 88 deletions.
1 change: 0 additions & 1 deletion src/binder/bind/bind_query.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ std::unique_ptr<BoundRegularQuery> Binder::bindQuery(const RegularQuery& regular
auto boundRegularQuery = std::make_unique<BoundRegularQuery>(
regularQuery.getIsUnionAll(), normalizedSingleQueries[0].getStatementResult()->copy());
for (auto& normalizedSingleQuery : normalizedSingleQueries) {
validateReadNotFollowUpdate(normalizedSingleQuery);
boundRegularQuery->addSingleQuery(std::move(normalizedSingleQuery));
}
validateIsAllUnionOrUnionAll(*boundRegularQuery);
Expand Down
12 changes: 0 additions & 12 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,6 @@ void Binder::validateOrderByFollowedBySkipOrLimitInWithClause(
}
}

// Walks the query parts of `singleQuery` in order and throws a BinderException as soon as a
// query part with a reading clause is encountered after an earlier query part that had an
// updating clause.
void Binder::validateReadNotFollowUpdate(const NormalizedSingleQuery& singleQuery) {
    bool updateSeen = false;
    for (auto partIdx = 0u; partIdx < singleQuery.getNumQueryParts(); ++partIdx) {
        auto queryPart = singleQuery.getQueryPart(partIdx);
        // Only parts that come strictly after an updating part are rejected for reading.
        const bool readsAfterUpdate = updateSeen && queryPart->hasReadingClause();
        if (readsAfterUpdate) {
            throw BinderException(
                "Read after update is not supported. Try query with multiple statements.");
        }
        // Once set, the flag stays set for all remaining parts.
        if (queryPart->hasUpdatingClause()) {
            updateSeen = true;
        }
    }
}

void Binder::validateTableType(table_id_t tableID, TableType expectedTableType) {
auto tableEntry =
clientContext->getCatalog()->getTableCatalogEntry(clientContext->getTx(), tableID);
Expand Down
4 changes: 0 additions & 4 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,6 @@ class Binder {
static void validateOrderByFollowedBySkipOrLimitInWithClause(
const BoundProjectionBody& boundProjectionBody);

// We don't support read after write for simplicity. Users should instead query through
// multiple statements.
static void validateReadNotFollowUpdate(const NormalizedSingleQuery& singleQuery);

void validateTableType(common::table_id_t tableID, common::TableType expectedTableType);
void validateTableExist(const std::string& tableName);

Expand Down
30 changes: 14 additions & 16 deletions src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,45 +58,43 @@ enum class LogicalOperatorType : uint8_t {
IMPORT_DATABASE,
};

class LogicalOperatorUtils {
public:
struct LogicalOperatorUtils {
static std::string logicalOperatorTypeToString(LogicalOperatorType type);
static bool isUpdate(LogicalOperatorType type);
};

class LogicalOperator;
using logical_op_vector_t = std::vector<std::shared_ptr<LogicalOperator>>;

class LogicalOperator {
public:
// Leaf operator.
explicit LogicalOperator(LogicalOperatorType operatorType) : operatorType{operatorType} {}
// Unary operator.
explicit LogicalOperator(
LogicalOperatorType operatorType, std::shared_ptr<LogicalOperator> child);
// Binary operator.
explicit LogicalOperator(LogicalOperatorType operatorType,
std::shared_ptr<LogicalOperator> left, std::shared_ptr<LogicalOperator> right);
explicit LogicalOperator(LogicalOperatorType operatorType, const logical_op_vector_t& children);

virtual ~LogicalOperator() = default;

inline uint32_t getNumChildren() const { return children.size(); }

inline std::shared_ptr<LogicalOperator> getChild(uint64_t idx) const { return children[idx]; }
inline std::vector<std::shared_ptr<LogicalOperator>> getChildren() const { return children; }
inline void setChild(uint64_t idx, std::shared_ptr<LogicalOperator> child) {
uint32_t getNumChildren() const { return children.size(); }
std::shared_ptr<LogicalOperator> getChild(uint64_t idx) const { return children[idx]; }
std::vector<std::shared_ptr<LogicalOperator>> getChildren() const { return children; }
void setChild(uint64_t idx, std::shared_ptr<LogicalOperator> child) {
children[idx] = std::move(child);
}
inline void setChildren(logical_op_vector_t children_) { children = std::move(children_); }

inline LogicalOperatorType getOperatorType() const { return operatorType; }
// Operator type.
LogicalOperatorType getOperatorType() const { return operatorType; }
bool hasUpdateRecursive();

inline Schema* getSchema() const { return schema.get(); }
// Schema
Schema* getSchema() const { return schema.get(); }
virtual void computeFactorizedSchema() = 0;
virtual void computeFlatSchema() = 0;

// Printing.
virtual std::string getExpressionsForPrinting() const = 0;

// Print the sub-plan rooted at this operator.
virtual std::string toString(uint64_t depth = 0) const;

Expand All @@ -105,8 +103,8 @@ class LogicalOperator {
static logical_op_vector_t copy(const logical_op_vector_t& ops);

protected:
inline void createEmptySchema() { schema = std::make_unique<Schema>(); }
inline void copyChildSchema(uint32_t idx) { schema = children[idx]->getSchema()->copy(); }
void createEmptySchema() { schema = std::make_unique<Schema>(); }
void copyChildSchema(uint32_t idx) { schema = children[idx]->getSchema()->copy(); }

protected:
LogicalOperatorType operatorType;
Expand Down
28 changes: 11 additions & 17 deletions src/include/planner/operator/logical_plan.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include "logical_explain.h"
#include "logical_operator.h"

namespace kuzu {
Expand All @@ -13,28 +12,23 @@ class LogicalPlan {
public:
LogicalPlan() : estCardinality{1}, cost{0} {}

inline void setLastOperator(std::shared_ptr<LogicalOperator> op) {
lastOperator = std::move(op);
}
void setLastOperator(std::shared_ptr<LogicalOperator> op) { lastOperator = std::move(op); }

inline bool isEmpty() const { return lastOperator == nullptr; }
bool isEmpty() const { return lastOperator == nullptr; }

inline std::shared_ptr<LogicalOperator> getLastOperator() const { return lastOperator; }
inline Schema* getSchema() const { return lastOperator->getSchema(); }
std::shared_ptr<LogicalOperator> getLastOperator() const { return lastOperator; }
Schema* getSchema() const { return lastOperator->getSchema(); }

inline void setCardinality(uint64_t cardinality) { estCardinality = cardinality; }
inline uint64_t getCardinality() const { return estCardinality; }
void setCardinality(uint64_t cardinality) { estCardinality = cardinality; }
uint64_t getCardinality() const { return estCardinality; }

inline void setCost(uint64_t cost_) { cost = cost_; }
inline uint64_t getCost() const { return cost; }
void setCost(uint64_t cost_) { cost = cost_; }
uint64_t getCost() const { return cost; }

inline std::string toString() const { return lastOperator->toString(); }
std::string toString() const { return lastOperator->toString(); }

inline bool isProfile() const {
return lastOperator->getOperatorType() == LogicalOperatorType::EXPLAIN &&
reinterpret_cast<LogicalExplain*>(lastOperator.get())->getExplainType() ==
common::ExplainType::PROFILE;
}
bool isProfile() const;
bool hasUpdate() const;

std::unique_ptr<LogicalPlan> shallowCopy() const;

Expand Down
6 changes: 3 additions & 3 deletions src/include/planner/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,16 +225,16 @@ class Planner {

// Append Join operators
void appendHashJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
LogicalPlan& probePlan, LogicalPlan& buildPlan);
LogicalPlan& probePlan, LogicalPlan& buildPlan, LogicalPlan& resultPlan);
void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<binder::Expression>& mark, LogicalPlan& probePlan,
LogicalPlan& buildPlan);
void appendIntersect(const std::shared_ptr<binder::Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);

void appendCrossProduct(
common::AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan);
void appendCrossProduct(common::AccumulateType accumulateType, const LogicalPlan& probePlan,
const LogicalPlan& buildPlan, LogicalPlan& resultPlan);

// Append accumulate
void appendAccumulate(common::AccumulateType accumulateType, LogicalPlan& plan);
Expand Down
26 changes: 26 additions & 0 deletions src/planner/operator/logical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ std::string LogicalOperatorUtils::logicalOperatorTypeToString(LogicalOperatorTyp
}
}

// Returns true when `type` is one of the operator types treated as an update here:
// INSERT, DELETE_NODE, DELETE_REL, SET_NODE_PROPERTY, SET_REL_PROPERTY or MERGE.
// Every other operator type yields false.
bool LogicalOperatorUtils::isUpdate(LogicalOperatorType type) {
    return type == LogicalOperatorType::INSERT ||
           type == LogicalOperatorType::DELETE_NODE ||
           type == LogicalOperatorType::DELETE_REL ||
           type == LogicalOperatorType::SET_NODE_PROPERTY ||
           type == LogicalOperatorType::SET_REL_PROPERTY ||
           type == LogicalOperatorType::MERGE;
}

LogicalOperator::LogicalOperator(
LogicalOperatorType operatorType, std::shared_ptr<LogicalOperator> child)
: operatorType{operatorType} {
Expand All @@ -129,6 +143,18 @@ LogicalOperator::LogicalOperator(
}
}

bool LogicalOperator::hasUpdateRecursive() {
if (LogicalOperatorUtils::isUpdate(operatorType)) {
return true;
}
for (auto& child : children) {
if (child->hasUpdateRecursive()) {
return true;
}
}
return false;
}

std::string LogicalOperator::toString(uint64_t depth) const {
auto padding = std::string(depth * 4, ' ');
std::string result = padding;
Expand Down
12 changes: 12 additions & 0 deletions src/planner/operator/logical_plan.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
#include "planner/operator/logical_plan.h"

#include "planner/operator/logical_explain.h"

namespace kuzu {
namespace planner {

// Reports whether this plan is a PROFILE plan: its last operator has type
// LogicalOperatorType::EXPLAIN and that operator's explain type is common::ExplainType::PROFILE.
// Precondition: `lastOperator` is dereferenced unconditionally, so the plan must not be empty.
bool LogicalPlan::isProfile() const {
    if (lastOperator->getOperatorType() != LogicalOperatorType::EXPLAIN) {
        return false;
    }
    // The type tag was checked above, so a static downcast is enough. static_cast (unlike
    // reinterpret_cast) applies any base-to-derived pointer adjustment the hierarchy needs.
    // NOTE(review): assumes LogicalExplain publicly derives from LogicalOperator — confirm in
    // logical_explain.h.
    auto explain = static_cast<LogicalExplain*>(lastOperator.get());
    return explain->getExplainType() == common::ExplainType::PROFILE;
}

// Reports whether the operator tree rooted at `lastOperator` contains an update operator, as
// determined by LogicalOperator::hasUpdateRecursive. `lastOperator` is dereferenced without a
// null check.
bool LogicalPlan::hasUpdate() const {
    const auto& root = lastOperator;
    const bool containsUpdate = root->hasUpdateRecursive();
    return containsUpdate;
}

std::unique_ptr<LogicalPlan> LogicalPlan::shallowCopy() const {
auto plan = std::make_unique<LogicalPlan>();
plan->lastOperator = lastOperator; // shallow copy sub-plan
Expand Down
10 changes: 5 additions & 5 deletions src/planner/plan/append_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ using namespace kuzu::common;
namespace kuzu {
namespace planner {

void Planner::appendCrossProduct(
AccumulateType accumulateType, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
void Planner::appendCrossProduct(AccumulateType accumulateType, const LogicalPlan& probePlan,
const LogicalPlan& buildPlan, LogicalPlan& resultPlan) {
auto crossProduct = make_shared<LogicalCrossProduct>(
accumulateType, probePlan.getLastOperator(), buildPlan.getLastOperator());
crossProduct->computeFactorizedSchema();
// update cost
probePlan.setCost(probePlan.getCardinality() + buildPlan.getCardinality());
resultPlan.setCost(probePlan.getCardinality() + buildPlan.getCardinality());
// update cardinality
probePlan.setCardinality(cardinalityEstimator.estimateCrossProduct(probePlan, buildPlan));
probePlan.setLastOperator(std::move(crossProduct));
resultPlan.setCardinality(cardinalityEstimator.estimateCrossProduct(probePlan, buildPlan));
resultPlan.setLastOperator(std::move(crossProduct));
}

} // namespace planner
Expand Down
3 changes: 2 additions & 1 deletion src/planner/plan/append_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ void Planner::appendNonRecursiveExtend(const std::shared_ptr<NodeExpression>& bo
appendScanInternalID(rdfInfo->predicateID, rdfInfo->resourceTableIDs, *tmpPlan);
appendScanNodeProperties(
rdfInfo->predicateID, rdfInfo->resourceTableIDs, expression_vector{iri}, *tmpPlan);
appendHashJoin(expression_vector{rdfInfo->predicateID}, JoinType::INNER, plan, *tmpPlan);
appendHashJoin(
expression_vector{rdfInfo->predicateID}, JoinType::INNER, plan, *tmpPlan, plan);
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/planner/plan/append_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace kuzu {
namespace planner {

void Planner::appendHashJoin(const binder::expression_vector& joinNodeIDs, JoinType joinType,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
LogicalPlan& probePlan, LogicalPlan& buildPlan, LogicalPlan& resultPlan) {
std::vector<join_condition_t> joinConditions;
for (auto& joinNodeID : joinNodeIDs) {
joinConditions.emplace_back(joinNodeID, joinNodeID);
Expand All @@ -32,11 +32,11 @@ void Planner::appendHashJoin(const binder::expression_vector& joinNodeIDs, JoinT
hashJoin->setSIP(SidewaysInfoPassing::PROHIBIT_BUILD_TO_PROBE);
}
// Update cost
probePlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan));
resultPlan.setCost(CostModel::computeHashJoinCost(joinNodeIDs, probePlan, buildPlan));
// Update cardinality
probePlan.setCardinality(
resultPlan.setCardinality(
cardinalityEstimator.estimateHashJoin(joinNodeIDs, probePlan, buildPlan));
probePlan.setLastOperator(std::move(hashJoin));
resultPlan.setLastOperator(std::move(hashJoin));
}

void Planner::appendMarkJoin(const binder::expression_vector& joinNodeIDs,
Expand Down
11 changes: 6 additions & 5 deletions src/planner/plan/plan_join_order.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,8 @@ void Planner::planInnerHashJoin(const SubqueryGraph& subgraph, const SubqueryGra
if (CostModel::computeHashJoinCost(joinNodeIDs, *leftPlan, *rightPlan) < maxCost) {
auto leftPlanProbeCopy = leftPlan->shallowCopy();
auto rightPlanBuildCopy = rightPlan->shallowCopy();
appendHashJoin(
joinNodeIDs, JoinType::INNER, *leftPlanProbeCopy, *rightPlanBuildCopy);
appendHashJoin(joinNodeIDs, JoinType::INNER, *leftPlanProbeCopy,
*rightPlanBuildCopy, *leftPlanProbeCopy);
appendFilters(predicates, *leftPlanProbeCopy);
context.addPlan(newSubgraph, std::move(leftPlanProbeCopy));
}
Expand All @@ -571,8 +571,8 @@ void Planner::planInnerHashJoin(const SubqueryGraph& subgraph, const SubqueryGra
CostModel::computeHashJoinCost(joinNodeIDs, *rightPlan, *leftPlan) < maxCost) {
auto leftPlanBuildCopy = leftPlan->shallowCopy();
auto rightPlanProbeCopy = rightPlan->shallowCopy();
appendHashJoin(
joinNodeIDs, JoinType::INNER, *rightPlanProbeCopy, *leftPlanBuildCopy);
appendHashJoin(joinNodeIDs, JoinType::INNER, *rightPlanProbeCopy,
*leftPlanBuildCopy, *rightPlanProbeCopy);
appendFilters(predicates, *rightPlanProbeCopy);
context.addPlan(newSubgraph, std::move(rightPlanProbeCopy));
}
Expand All @@ -588,7 +588,8 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::planCrossProduct(
for (auto& rightPlan : rightPlans) {
auto leftPlanCopy = leftPlan->shallowCopy();
auto rightPlanCopy = rightPlan->shallowCopy();
appendCrossProduct(AccumulateType::REGULAR, *leftPlanCopy, *rightPlanCopy);
appendCrossProduct(
AccumulateType::REGULAR, *leftPlanCopy, *rightPlanCopy, *leftPlanCopy);
result.push_back(std::move(leftPlanCopy));
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/planner/plan/plan_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void Planner::planInQueryCall(
if (!predicatesToPushDown.empty()) {
appendFilters(predicatesToPushDown, *tmpPlan);
}
appendCrossProduct(AccumulateType::REGULAR, *plan, *tmpPlan);
appendCrossProduct(AccumulateType::REGULAR, *plan, *tmpPlan, *plan);
} else {
appendInQueryCall(*readingClause, *plan);
if (!predicatesToPushDown.empty()) {
Expand Down Expand Up @@ -139,7 +139,7 @@ void Planner::planLoadFrom(
if (!predicatesToPushDown.empty()) {
appendFilters(predicatesToPushDown, *tmpPlan);
}
appendCrossProduct(AccumulateType::REGULAR, *plan, *tmpPlan);
appendCrossProduct(AccumulateType::REGULAR, *plan, *tmpPlan, *plan);
} else {
appendScanFile(loadFrom->getInfo(), *plan);
if (!predicatesToPushDown.empty()) {
Expand Down
Loading

0 comments on commit 4d21128

Please sign in to comment.