Skip to content

Commit

Permalink
Merge pull request #1052 from kuzudb/unstr-removal-refactor
Browse files Browse the repository at this point in the history
refactor unstr related code
  • Loading branch information
andyfengHKU committed Nov 21, 2022
2 parents c148fea + 3f66507 commit 8ce2d31
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 201 deletions.
64 changes: 4 additions & 60 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "src/binder/include/binder.h"
#include "src/common/include/type_utils.h"
#include "src/function/boolean/include/vector_boolean_operations.h"
#include "src/function/cast/include/vector_cast_operations.h"
#include "src/function/null/include/vector_null_operations.h"
#include "src/parser/expression/include/parsed_function_expression.h"
#include "src/parser/expression/include/parsed_literal_expression.h"
Expand Down Expand Up @@ -391,65 +390,10 @@ void ExpressionBinder::resolveAnyDataType(Expression& expression, DataType targe

shared_ptr<Expression> ExpressionBinder::implicitCast(
const shared_ptr<Expression>& expression, DataType targetType) {
switch (targetType.typeID) {
case BOOL: {
return implicitCastToBool(expression);
}
case INT64: {
return implicitCastToInt64(expression);
}
case STRING: {
return implicitCastToString(expression);
}
case TIMESTAMP: {
return implicitCastToTimestamp(expression);
}
default:
throw BinderException("Expression " + expression->getRawName() + " has data type " +
Types::dataTypeToString(expression->dataType) + " but expect " +
Types::dataTypeToString(targetType) +
". Implicit cast is not supported.");
}
}

shared_ptr<Expression> ExpressionBinder::implicitCastToBool(
const shared_ptr<Expression>& expression) {
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToBool(children);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_BOOL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(BOOL), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::implicitCastToInt64(
const shared_ptr<Expression>& expression) {
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToInt64(children);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_INT_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(INT64), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::implicitCastToString(
const shared_ptr<Expression>& expression) {
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToString(children);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_STRING_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(STRING), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::implicitCastToTimestamp(
const shared_ptr<Expression>& expression) {
auto children = expression_vector{expression};
auto execFunc = VectorCastOperations::bindImplicitCastToTimestamp(children);
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(IMPLICIT_CAST_TO_TIMESTAMP_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(FUNCTION, DataType(TIMESTAMP), move(children),
move(execFunc), nullptr /* selectFunc */, uniqueExpressionName);
throw BinderException("Expression " + expression->getRawName() + " has data type " +
Types::dataTypeToString(expression->dataType) + " but expect " +
Types::dataTypeToString(targetType) +
". Implicit cast is not supported.");
}

void ExpressionBinder::validateExpectedDataType(
Expand Down
4 changes: 0 additions & 4 deletions src/binder/include/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ class ExpressionBinder {
static void resolveAnyDataType(Expression& expression, DataType targetType);
static shared_ptr<Expression> implicitCast(
const shared_ptr<Expression>& expression, DataType targetType);
static shared_ptr<Expression> implicitCastToBool(const shared_ptr<Expression>& expression);
static shared_ptr<Expression> implicitCastToInt64(const shared_ptr<Expression>& expression);
static shared_ptr<Expression> implicitCastToString(const shared_ptr<Expression>& expression);
static shared_ptr<Expression> implicitCastToTimestamp(const shared_ptr<Expression>& expression);

/****** validation *****/
static void validateExpectedDataType(const Expression& expression, DataTypeID target) {
Expand Down
9 changes: 0 additions & 9 deletions src/function/cast/include/vector_cast_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,7 @@ namespace function {
* Implicit casts are added internally.
*/
class VectorCastOperations : public VectorOperations {

public:
static scalar_exec_func bindImplicitCastToBool(const expression_vector& children);

static scalar_exec_func bindImplicitCastToInt64(const expression_vector& children);

static scalar_exec_func bindImplicitCastToString(const expression_vector& children);

static scalar_exec_func bindImplicitCastToTimestamp(const expression_vector& children);

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void UnaryCastExecFunction(
const vector<shared_ptr<ValueVector>>& params, ValueVector& result) {
Expand Down
45 changes: 0 additions & 45 deletions src/function/cast/vector_cast_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,6 @@
namespace kuzu {
namespace function {

scalar_exec_func VectorCastOperations::bindImplicitCastToBool(const expression_vector& children) {
assert(children.size() == 1 && children[0]->dataType.typeID != BOOL);
auto child = children[0];
switch (children[0]->dataType.typeID) {
default:
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect BOOL. Implicit cast is not supported.");
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastToInt64(const expression_vector& children) {
assert(children.size() == 1 && children[0]->dataType.typeID != INT64);
auto child = children[0];
switch (children[0]->dataType.typeID) {
default:
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect INT64. Implicit cast is not supported.");
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastToString(const expression_vector& children) {
assert(children.size() == 1 && children[0]->dataType.typeID != STRING);
auto child = children[0];
switch (child->dataType.typeID) {
default:
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect STRING. Implicit cast is not supported.");
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastToTimestamp(
const expression_vector& children) {
assert(children.size() == 1 && children[0]->dataType.typeID != TIMESTAMP);
auto child = children[0];
switch (child->dataType.typeID) {
default:
throw NotImplementedException("Expression " + child->getRawName() + " has data type " +
Types::dataTypeToString(child->dataType) +
" but expect TIMESTAMP. Implicit cast is not supported.");
}
}

vector<unique_ptr<VectorOperationDefinition>> CastToDateVectorOperation::getDefinitions() {
vector<unique_ptr<VectorOperationDefinition>> result;
result.push_back(
Expand Down
5 changes: 3 additions & 2 deletions src/planner/include/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ class JoinOrderEnumerator {

void planNodeScan(uint32_t nodePos);
// Filter push down for node table.
void planFiltersForNode(expression_vector& predicates, NodeExpression& node, LogicalPlan& plan);
void planFiltersForNode(
expression_vector& predicates, shared_ptr<NodeExpression> node, LogicalPlan& plan);
// Property push down for node table.
void planPropertyScansForNode(NodeExpression& node, LogicalPlan& plan);
void planPropertyScansForNode(shared_ptr<NodeExpression> node, LogicalPlan& plan);

void planRelScan(uint32_t relPos);
inline void planRelExtendFiltersAndProperties(shared_ptr<RelExpression>& rel,
Expand Down
9 changes: 2 additions & 7 deletions src/planner/include/query_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,8 @@ class QueryPlanner {

void appendFilter(const shared_ptr<Expression>& expression, LogicalPlan& plan);

inline void appendScanNodePropIfNecessarySwitch(
shared_ptr<Expression> property, NodeExpression& node, LogicalPlan& plan) {
expression_vector properties{move(property)};
appendScanNodePropIfNecessarySwitch(properties, node, plan);
}
void appendScanNodePropIfNecessarySwitch(
expression_vector& properties, NodeExpression& node, LogicalPlan& plan);
void appendScanNodePropIfNecessary(const expression_vector& propertyExpressions,
shared_ptr<NodeExpression> node, LogicalPlan& plan);

inline void appendScanRelPropsIfNecessary(expression_vector& properties, RelExpression& rel,
RelDirection direction, LogicalPlan& plan) {
Expand Down
17 changes: 9 additions & 8 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,26 +199,27 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
} else {
appendScanNode(node, *plan);
}
planFiltersForNode(predicatesToApply, *node, *plan);
planPropertyScansForNode(*node, *plan);
planFiltersForNode(predicatesToApply, node, *plan);
planPropertyScansForNode(node, *plan);
} else {
appendScanNode(node, *plan);
}
context->addPlan(newSubgraph, std::move(plan));
}

void JoinOrderEnumerator::planFiltersForNode(
expression_vector& predicates, NodeExpression& node, LogicalPlan& plan) {
expression_vector& predicates, shared_ptr<NodeExpression> node, LogicalPlan& plan) {
for (auto& predicate : predicates) {
auto propertiesToScan = getPropertiesForVariable(*predicate, node);
queryPlanner->appendScanNodePropIfNecessarySwitch(propertiesToScan, node, plan);
auto propertiesToScan = getPropertiesForVariable(*predicate, *node);
queryPlanner->appendScanNodePropIfNecessary(propertiesToScan, node, plan);
queryPlanner->appendFilter(predicate, plan);
}
}

void JoinOrderEnumerator::planPropertyScansForNode(NodeExpression& node, LogicalPlan& plan) {
auto properties = queryPlanner->getPropertiesForNode(node);
queryPlanner->appendScanNodePropIfNecessarySwitch(properties, node, plan);
void JoinOrderEnumerator::planPropertyScansForNode(
shared_ptr<NodeExpression> node, LogicalPlan& plan) {
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, plan);
}

void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,37 @@

#include "base_logical_operator.h"

#include "src/binder/expression/include/expression.h"

namespace kuzu {
namespace planner {

class LogicalScanNodeProperty : public LogicalOperator {
using namespace kuzu::binder;

class LogicalScanNodeProperty : public LogicalOperator {
public:
LogicalScanNodeProperty(string nodeID, table_id_t tableID, vector<string> propertyNames,
vector<uint32_t> propertyIDs, shared_ptr<LogicalOperator> child)
: LogicalOperator{move(child)}, nodeID{move(nodeID)}, tableID{tableID},
propertyNames{move(propertyNames)}, propertyIDs{move(propertyIDs)} {}
LogicalScanNodeProperty(shared_ptr<NodeExpression> node, expression_vector properties,
shared_ptr<LogicalOperator> child)
: LogicalOperator{move(child)}, node{std::move(node)}, properties{std::move(properties)} {}

LogicalOperatorType getLogicalOperatorType() const override {
inline LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_SCAN_NODE_PROPERTY;
}

string getExpressionsForPrinting() const override {
auto result = string();
for (auto& propertyName : propertyNames) {
result += ", " + propertyName;
}
return result;
inline string getExpressionsForPrinting() const override {
return ExpressionUtil::toString(properties);
}

inline string getNodeID() const { return nodeID; }

inline table_id_t getTableID() const { return tableID; }

inline vector<string> getPropertyNames() const { return propertyNames; }

inline vector<uint32_t> getPropertyIDs() const { return propertyIDs; }
inline shared_ptr<NodeExpression> getNode() const { return node; }
inline expression_vector getProperties() const { return properties; }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalScanNodeProperty>(
nodeID, tableID, propertyNames, propertyIDs, children[0]->copy());
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalScanNodeProperty>(node, properties, children[0]->copy());
}

private:
string nodeID;
table_id_t tableID;
vector<string> propertyNames;
vector<uint32_t> propertyIDs;
shared_ptr<NodeExpression> node;
expression_vector properties;
};

} // namespace planner
Expand Down
32 changes: 12 additions & 20 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,32 +376,24 @@ void QueryPlanner::appendFilter(const shared_ptr<Expression>& expression, Logica
plan.setLastOperator(std::move(filter));
}

void QueryPlanner::appendScanNodePropIfNecessarySwitch(
expression_vector& properties, NodeExpression& node, LogicalPlan& plan) {
expression_vector structuredProperties;
for (auto& property : properties) {
structuredProperties.push_back(property);
}
void QueryPlanner::appendScanNodePropIfNecessary(const expression_vector& propertyExpressions,
shared_ptr<NodeExpression> node, LogicalPlan& plan) {
auto schema = plan.getSchema();
vector<string> propertyNames;
vector<uint32_t> propertyIDs;
auto groupPos = schema->getGroupPos(node.getIDProperty());
for (auto& expression : properties) {
if (schema->isExpressionInScope(*expression)) {
expression_vector propertyExpressionToScan;
auto groupPos = schema->getGroupPos(node->getIDProperty());
for (auto& propertyExpression : propertyExpressions) {
if (schema->isExpressionInScope(*propertyExpression)) {
continue;
}
assert(expression->expressionType == PROPERTY);
auto property = static_pointer_cast<PropertyExpression>(expression);
propertyNames.push_back(property->getUniqueName());
propertyIDs.push_back(property->getPropertyID());
schema->insertToGroupAndScope(property, groupPos);
propertyExpressionToScan.push_back(propertyExpression);
schema->insertToGroupAndScope(propertyExpression, groupPos);
}
if (propertyNames.empty()) { // all properties have been scanned before
if (propertyExpressionToScan.empty()) { // all properties have been scanned before
return;
}
auto scanNodeProperty = make_shared<LogicalScanNodeProperty>(node.getIDProperty(),
node.getTableID(), move(propertyNames), move(propertyIDs), plan.getLastOperator());
plan.setLastOperator(move(scanNodeProperty));
auto scanNodeProperty = make_shared<LogicalScanNodeProperty>(
std::move(node), std::move(propertyExpressionToScan), plan.getLastOperator());
plan.setLastOperator(std::move(scanNodeProperty));
}

void QueryPlanner::appendScanRelPropIfNecessary(shared_ptr<Expression>& expression,
Expand Down
20 changes: 10 additions & 10 deletions src/processor/mapper/map_scan_node_property.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ unique_ptr<PhysicalOperator> PlanMapper::mapLogicalScanNodePropertyToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext) {
auto& scanProperty = (const LogicalScanNodeProperty&)*logicalOperator;
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0), mapperContext);
auto inputNodeIDVectorPos = mapperContext.getDataPos(scanProperty.getNodeID());
auto paramsString = scanProperty.getExpressionsForPrinting();
vector<DataPos> outputPropertyVectorsPos;
for (auto& propertyName : scanProperty.getPropertyNames()) {
outputPropertyVectorsPos.push_back(mapperContext.getDataPos(propertyName));
mapperContext.addComputedExpressions(propertyName);
}
auto node = scanProperty.getNode();
auto inputNodeIDVectorPos = mapperContext.getDataPos(node->getIDProperty());
auto& nodeStore = storageManager.getNodesStore();
vector<DataPos> outputPropertyVectorsPos;
vector<Column*> propertyColumns;
for (auto& propertyID : scanProperty.getPropertyIDs()) {
for (auto& expression : scanProperty.getProperties()) {
auto property = static_pointer_cast<PropertyExpression>(expression);
outputPropertyVectorsPos.push_back(mapperContext.getDataPos(property->getUniqueName()));
mapperContext.addComputedExpressions(property->getUniqueName());
propertyColumns.push_back(
nodeStore.getNodePropertyColumn(scanProperty.getTableID(), propertyID));
nodeStore.getNodePropertyColumn(node->getTableID(), property->getPropertyID()));
}
return make_unique<ScanStructuredProperty>(inputNodeIDVectorPos, move(outputPropertyVectorsPos),
move(propertyColumns), move(prevOperator), getOperatorID(), paramsString);
move(propertyColumns), move(prevOperator), getOperatorID(),
scanProperty.getExpressionsForPrinting());
}

} // namespace processor
Expand Down
4 changes: 0 additions & 4 deletions src/processor/operator/scan_column/include/scan_column.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ namespace kuzu {
namespace processor {

class BaseScanColumn : public PhysicalOperator {

public:
BaseScanColumn(const DataPos& inputNodeIDVectorPos, unique_ptr<PhysicalOperator> child,
uint32_t id, const string& paramsString)
Expand All @@ -27,7 +26,6 @@ class BaseScanColumn : public PhysicalOperator {
};

class ScanSingleColumn : public BaseScanColumn {

protected:
ScanSingleColumn(const DataPos& inputNodeIDVectorPos, const DataPos& outputVectorPos,
unique_ptr<PhysicalOperator> child, uint32_t id, const string& paramsString)
Expand All @@ -36,12 +34,10 @@ class ScanSingleColumn : public BaseScanColumn {

protected:
DataPos outputVectorPos;

shared_ptr<ValueVector> outputVector;
};

class ScanMultipleColumns : public BaseScanColumn {

protected:
ScanMultipleColumns(const DataPos& inputNodeIDVectorPos, vector<DataPos> outputVectorsPos,
unique_ptr<PhysicalOperator> child, uint32_t id, const string& paramsString)
Expand Down
Loading

0 comments on commit 8ce2d31

Please sign in to comment.