Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor implicit cast and logical scan of unstructured properties #1052

Merged
merged 1 commit into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 4 additions & 60 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "src/binder/include/binder.h"
#include "src/common/include/type_utils.h"
#include "src/function/boolean/include/vector_boolean_operations.h"
#include "src/function/cast/include/vector_cast_operations.h"
#include "src/function/null/include/vector_null_operations.h"
#include "src/parser/expression/include/parsed_function_expression.h"
#include "src/parser/expression/include/parsed_literal_expression.h"
Expand Down Expand Up @@ -391,65 +390,10 @@ void ExpressionBinder::resolveAnyDataType(Expression& expression, DataType targe

shared_ptr<Expression> ExpressionBinder::implicitCast(
const shared_ptr<Expression>& expression, DataType targetType) {
switch (targetType.typeID) {
case BOOL: {
return implicitCastToBool(expression);
}
case INT64: {
return implicitCastToInt64(expression);
}
case STRING: {
return implicitCastToString(expression);
}
case TIMESTAMP: {
return implicitCastToTimestamp(expression);
}
default:
throw BinderException("Expression " + expression->getRawName() + " has data type " +
Types::dataTypeToString(expression->dataType) + " but expect " +
Types::dataTypeToString(targetType) +
". Implicit cast is not supported.");
}
}

// Wraps `expression` in a ScalarFunctionExpression of type FUNCTION whose result
// data type is BOOL. The exec function is obtained from
// VectorCastOperations::bindImplicitCastToBool and no select function is supplied.
// NOTE(review): the bindImplicitCastToBool shown elsewhere in this diff has only a
// throwing default branch, so this presumably never returns normally — confirm.
shared_ptr<Expression> ExpressionBinder::implicitCastToBool(
const shared_ptr<Expression>& expression) {
// The expression being cast is the sole child of the cast function.
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToBool(children);
// The unique name is computed while `children` is still intact; it is moved from below.
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_BOOL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(BOOL), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
}

// Wraps `expression` in a ScalarFunctionExpression of type FUNCTION whose result
// data type is INT64. The exec function is obtained from
// VectorCastOperations::bindImplicitCastToInt64 and no select function is supplied.
// NOTE(review): the bindImplicitCastToInt64 shown elsewhere in this diff has only a
// throwing default branch, so this presumably never returns normally — confirm.
shared_ptr<Expression> ExpressionBinder::implicitCastToInt64(
const shared_ptr<Expression>& expression) {
// The expression being cast is the sole child of the cast function.
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToInt64(children);
// The unique name is computed while `children` is still intact; it is moved from below.
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_INT_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(INT64), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
}

// Wraps `expression` in a ScalarFunctionExpression of type FUNCTION whose result
// data type is STRING. The exec function is obtained from
// VectorCastOperations::bindImplicitCastToString and no select function is supplied.
// NOTE(review): the bindImplicitCastToString shown elsewhere in this diff has only a
// throwing default branch, so this presumably never returns normally — confirm.
shared_ptr<Expression> ExpressionBinder::implicitCastToString(
const shared_ptr<Expression>& expression) {
// The expression being cast is the sole child of the cast function.
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToString(children);
// The unique name is computed while `children` is still intact; it is moved from below.
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_STRING_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(STRING), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::implicitCastToTimestamp(
const shared_ptr<Expression>& expression) {
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToTimestamp(children);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_TIMESTAMP_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(TIMESTAMP), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
throw BinderException("Expression " + expression->getRawName() + " has data type " +
Types::dataTypeToString(expression->dataType) + " but expect " +
Types::dataTypeToString(targetType) +
". Implicit cast is not supported.");
}

void ExpressionBinder::validateExpectedDataType(
Expand Down
4 changes: 0 additions & 4 deletions src/binder/include/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ class ExpressionBinder {
static void resolveAnyDataType(Expression& expression, DataType targetType);
static shared_ptr<Expression> implicitCast(
const shared_ptr<Expression>& expression, DataType targetType);
static shared_ptr<Expression> implicitCastToBool(const shared_ptr<Expression>& expression);
static shared_ptr<Expression> implicitCastToInt64(const shared_ptr<Expression>& expression);
static shared_ptr<Expression> implicitCastToString(const shared_ptr<Expression>& expression);
static shared_ptr<Expression> implicitCastToTimestamp(const shared_ptr<Expression>& expression);

/****** validation *****/
static void validateExpectedDataType(const Expression& expression, DataTypeID target) {
Expand Down
9 changes: 0 additions & 9 deletions src/function/cast/include/vector_cast_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,7 @@ namespace function {
* Implicit casts are added internally.
*/
class VectorCastOperations : public VectorOperations {

public:
static scalar_exec_func bindImplicitCastToBool(const expression_vector& children);

static scalar_exec_func bindImplicitCastToInt64(const expression_vector& children);

static scalar_exec_func bindImplicitCastToString(const expression_vector& children);

static scalar_exec_func bindImplicitCastToTimestamp(const expression_vector& children);

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void UnaryCastExecFunction(
const vector<shared_ptr<ValueVector>>& params, ValueVector& result) {
Expand Down
45 changes: 0 additions & 45 deletions src/function/cast/vector_cast_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,6 @@
namespace kuzu {
namespace function {

// Selects the exec function for implicitly casting the single child expression to BOOL.
// The switch below has no case other than `default`, so every call that gets past the
// assert throws NotImplementedException; no exec function is ever returned.
scalar_exec_func VectorCastOperations::bindImplicitCastToBool(const expression_vector& children) {
// Exactly one child is expected, and it must not already be of type BOOL.
assert(children.size() == 1 && children[0]->dataType.typeID != BOOL);
auto child = children[0];
switch (children[0]->dataType.typeID) {
default:
// No source type is supported: report the child's raw name and its actual data type.
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect BOOL. Implicit cast is not supported.");
}
}

// Selects the exec function for implicitly casting the single child expression to INT64.
// The switch below has no case other than `default`, so every call that gets past the
// assert throws NotImplementedException; no exec function is ever returned.
scalar_exec_func VectorCastOperations::bindImplicitCastToInt64(const expression_vector& children) {
// Exactly one child is expected, and it must not already be of type INT64.
assert(children.size() == 1 && children[0]->dataType.typeID != INT64);
auto child = children[0];
switch (children[0]->dataType.typeID) {
default:
// No source type is supported: report the child's raw name and its actual data type.
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect INT64. Implicit cast is not supported.");
}
}

// Selects the exec function for implicitly casting the single child expression to STRING.
// The switch below has no case other than `default`, so every call that gets past the
// assert throws NotImplementedException; no exec function is ever returned.
scalar_exec_func VectorCastOperations::bindImplicitCastToString(const expression_vector& children) {
// Exactly one child is expected, and it must not already be of type STRING.
assert(children.size() == 1 && children[0]->dataType.typeID != STRING);
auto child = children[0];
switch (child->dataType.typeID) {
default:
// No source type is supported: report the child's raw name and its actual data type.
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect STRING. Implicit cast is not supported.");
}
}

// Selects the exec function for implicitly casting the single child expression to TIMESTAMP.
// The switch below has no case other than `default`, so every call that gets past the
// assert throws NotImplementedException; no exec function is ever returned.
scalar_exec_func VectorCastOperations::bindImplicitCastToTimestamp(
const expression_vector& children) {
// Exactly one child is expected, and it must not already be of type TIMESTAMP.
assert(children.size() == 1 && children[0]->dataType.typeID != TIMESTAMP);
auto child = children[0];
switch (child->dataType.typeID) {
default:
// No source type is supported: report the child's raw name and its actual data type.
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect TIMESTAMP. Implicit cast is not supported.");
}
}

vector<unique_ptr<VectorOperationDefinition>> CastToDateVectorOperation::getDefinitions() {
vector<unique_ptr<VectorOperationDefinition>> result;
result.push_back(
Expand Down
5 changes: 3 additions & 2 deletions src/planner/include/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ class JoinOrderEnumerator {

void planNodeScan(uint32_t nodePos);
// Filter push down for node table.
void planFiltersForNode(expression_vector& predicates, NodeExpression& node, LogicalPlan& plan);
void planFiltersForNode(
expression_vector& predicates, shared_ptr<NodeExpression> node, LogicalPlan& plan);
// Property push down for node table.
void planPropertyScansForNode(NodeExpression& node, LogicalPlan& plan);
void planPropertyScansForNode(shared_ptr<NodeExpression> node, LogicalPlan& plan);

void planRelScan(uint32_t relPos);
inline void planRelExtendFiltersAndProperties(shared_ptr<RelExpression>& rel,
Expand Down
9 changes: 2 additions & 7 deletions src/planner/include/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,8 @@ class QueryPlanner {

void appendFilter(const shared_ptr<Expression>& expression, LogicalPlan& plan);

inline void appendScanNodePropIfNecessarySwitch(
shared_ptr<Expression> property, NodeExpression& node, LogicalPlan& plan) {
expression_vector properties{move(property)};
appendScanNodePropIfNecessarySwitch(properties, node, plan);
}
void appendScanNodePropIfNecessarySwitch(
expression_vector& properties, NodeExpression& node, LogicalPlan& plan);
void appendScanNodePropIfNecessary(const expression_vector& propertyExpressions,
shared_ptr<NodeExpression> node, LogicalPlan& plan);

inline void appendScanRelPropsIfNecessary(expression_vector& properties, RelExpression& rel,
RelDirection direction, LogicalPlan& plan) {
Expand Down
17 changes: 9 additions & 8 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,26 +199,27 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
} else {
appendScanNode(node, *plan);
}
planFiltersForNode(predicatesToApply, *node, *plan);
planPropertyScansForNode(*node, *plan);
planFiltersForNode(predicatesToApply, node, *plan);
planPropertyScansForNode(node, *plan);
} else {
appendScanNode(node, *plan);
}
context->addPlan(newSubgraph, std::move(plan));
}

void JoinOrderEnumerator::planFiltersForNode(
expression_vector& predicates, NodeExpression& node, LogicalPlan& plan) {
expression_vector& predicates, shared_ptr<NodeExpression> node, LogicalPlan& plan) {
for (auto& predicate : predicates) {
auto propertiesToScan = getPropertiesForVariable(*predicate, node);
queryPlanner->appendScanNodePropIfNecessarySwitch(propertiesToScan, node, plan);
auto propertiesToScan = getPropertiesForVariable(*predicate, *node);
queryPlanner->appendScanNodePropIfNecessary(propertiesToScan, node, plan);
queryPlanner->appendFilter(predicate, plan);
}
}

void JoinOrderEnumerator::planPropertyScansForNode(NodeExpression& node, LogicalPlan& plan) {
auto properties = queryPlanner->getPropertiesForNode(node);
queryPlanner->appendScanNodePropIfNecessarySwitch(properties, node, plan);
void JoinOrderEnumerator::planPropertyScansForNode(
shared_ptr<NodeExpression> node, LogicalPlan& plan) {
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, plan);
}

void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,37 @@

#include "base_logical_operator.h"

#include "src/binder/expression/include/expression.h"

namespace kuzu {
namespace planner {

class LogicalScanNodeProperty : public LogicalOperator {
using namespace kuzu::binder;

class LogicalScanNodeProperty : public LogicalOperator {
public:
LogicalScanNodeProperty(string nodeID, table_id_t tableID, vector<string> propertyNames,
vector<uint32_t> propertyIDs, shared_ptr<LogicalOperator> child)
: LogicalOperator{move(child)}, nodeID{move(nodeID)}, tableID{tableID},
propertyNames{move(propertyNames)}, propertyIDs{move(propertyIDs)} {}
LogicalScanNodeProperty(shared_ptr<NodeExpression> node, expression_vector properties,
shared_ptr<LogicalOperator> child)
: LogicalOperator{move(child)}, node{std::move(node)}, properties{std::move(properties)} {}

LogicalOperatorType getLogicalOperatorType() const override {
inline LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_SCAN_NODE_PROPERTY;
}

string getExpressionsForPrinting() const override {
auto result = string();
for (auto& propertyName : propertyNames) {
result += ", " + propertyName;
}
return result;
inline string getExpressionsForPrinting() const override {
return ExpressionUtil::toString(properties);
}

inline string getNodeID() const { return nodeID; }

inline table_id_t getTableID() const { return tableID; }

inline vector<string> getPropertyNames() const { return propertyNames; }

inline vector<uint32_t> getPropertyIDs() const { return propertyIDs; }
inline shared_ptr<NodeExpression> getNode() const { return node; }
inline expression_vector getProperties() const { return properties; }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalScanNodeProperty>(
nodeID, tableID, propertyNames, propertyIDs, children[0]->copy());
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalScanNodeProperty>(node, properties, children[0]->copy());
}

private:
string nodeID;
table_id_t tableID;
vector<string> propertyNames;
vector<uint32_t> propertyIDs;
shared_ptr<NodeExpression> node;
expression_vector properties;
};

} // namespace planner
Expand Down
32 changes: 12 additions & 20 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,32 +376,24 @@ void QueryPlanner::appendFilter(const shared_ptr<Expression>& expression, Logica
plan.setLastOperator(std::move(filter));
}

void QueryPlanner::appendScanNodePropIfNecessarySwitch(
expression_vector& properties, NodeExpression& node, LogicalPlan& plan) {
expression_vector structuredProperties;
for (auto& property : properties) {
structuredProperties.push_back(property);
}
void QueryPlanner::appendScanNodePropIfNecessary(const expression_vector& propertyExpressions,
shared_ptr<NodeExpression> node, LogicalPlan& plan) {
auto schema = plan.getSchema();
vector<string> propertyNames;
vector<uint32_t> propertyIDs;
auto groupPos = schema->getGroupPos(node.getIDProperty());
for (auto& expression : properties) {
if (schema->isExpressionInScope(*expression)) {
expression_vector propertyExpressionToScan;
auto groupPos = schema->getGroupPos(node->getIDProperty());
for (auto& propertyExpression : propertyExpressions) {
if (schema->isExpressionInScope(*propertyExpression)) {
continue;
}
assert(expression->expressionType == PROPERTY);
auto property = static_pointer_cast<PropertyExpression>(expression);
propertyNames.push_back(property->getUniqueName());
propertyIDs.push_back(property->getPropertyID());
schema->insertToGroupAndScope(property, groupPos);
propertyExpressionToScan.push_back(propertyExpression);
schema->insertToGroupAndScope(propertyExpression, groupPos);
}
if (propertyNames.empty()) { // all properties have been scanned before
if (propertyExpressionToScan.empty()) { // all properties have been scanned before
return;
}
auto scanNodeProperty = make_shared<LogicalScanNodeProperty>(node.getIDProperty(),
node.getTableID(), move(propertyNames), move(propertyIDs), plan.getLastOperator());
plan.setLastOperator(move(scanNodeProperty));
auto scanNodeProperty = make_shared<LogicalScanNodeProperty>(
std::move(node), std::move(propertyExpressionToScan), plan.getLastOperator());
plan.setLastOperator(std::move(scanNodeProperty));
}

void QueryPlanner::appendScanRelPropIfNecessary(shared_ptr<Expression>& expression,
Expand Down
20 changes: 10 additions & 10 deletions src/processor/mapper/map_scan_node_property.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ unique_ptr<PhysicalOperator> PlanMapper::mapLogicalScanNodePropertyToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext) {
auto& scanProperty = (const LogicalScanNodeProperty&)*logicalOperator;
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0), mapperContext);
auto inputNodeIDVectorPos = mapperContext.getDataPos(scanProperty.getNodeID());
auto paramsString = scanProperty.getExpressionsForPrinting();
vector<DataPos> outputPropertyVectorsPos;
for (auto& propertyName : scanProperty.getPropertyNames()) {
outputPropertyVectorsPos.push_back(mapperContext.getDataPos(propertyName));
mapperContext.addComputedExpressions(propertyName);
}
auto node = scanProperty.getNode();
auto inputNodeIDVectorPos = mapperContext.getDataPos(node->getIDProperty());
auto& nodeStore = storageManager.getNodesStore();
vector<DataPos> outputPropertyVectorsPos;
vector<Column*> propertyColumns;
for (auto& propertyID : scanProperty.getPropertyIDs()) {
for (auto& expression : scanProperty.getProperties()) {
auto property = static_pointer_cast<PropertyExpression>(expression);
outputPropertyVectorsPos.push_back(mapperContext.getDataPos(property->getUniqueName()));
mapperContext.addComputedExpressions(property->getUniqueName());
propertyColumns.push_back(
nodeStore.getNodePropertyColumn(scanProperty.getTableID(), propertyID));
nodeStore.getNodePropertyColumn(node->getTableID(), property->getPropertyID()));
}
return make_unique<ScanStructuredProperty>(inputNodeIDVectorPos, move(outputPropertyVectorsPos),
move(propertyColumns), move(prevOperator), getOperatorID(), paramsString);
move(propertyColumns), move(prevOperator), getOperatorID(),
scanProperty.getExpressionsForPrinting());
}

} // namespace processor
Expand Down
4 changes: 0 additions & 4 deletions src/processor/operator/scan_column/include/scan_column.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ namespace kuzu {
namespace processor {

class BaseScanColumn : public PhysicalOperator {

public:
BaseScanColumn(const DataPos& inputNodeIDVectorPos, unique_ptr<PhysicalOperator> child,
uint32_t id, const string& paramsString)
Expand All @@ -27,7 +26,6 @@ class BaseScanColumn : public PhysicalOperator {
};

class ScanSingleColumn : public BaseScanColumn {

protected:
ScanSingleColumn(const DataPos& inputNodeIDVectorPos, const DataPos& outputVectorPos,
unique_ptr<PhysicalOperator> child, uint32_t id, const string& paramsString)
Expand All @@ -36,12 +34,10 @@ class ScanSingleColumn : public BaseScanColumn {

protected:
DataPos outputVectorPos;

shared_ptr<ValueVector> outputVector;
};

class ScanMultipleColumns : public BaseScanColumn {

protected:
ScanMultipleColumns(const DataPos& inputNodeIDVectorPos, vector<DataPos> outputVectorsPos,
unique_ptr<PhysicalOperator> child, uint32_t id, const string& paramsString)
Expand Down
Loading