Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove mapper context #1121

Merged
merged 1 commit into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/expression_evaluator/literal_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ bool LiteralExpressionEvaluator::select(SelectionVector& selVector) {
assert(resultVector->dataType.typeID == BOOL);
auto pos = resultVector->state->selVector->selectedPositions[0];
assert(pos == 0u);
return resultVector->getValue<bool>(pos) == true && (!resultVector->isNull(pos));
return resultVector->getValue<bool>(pos) && (!resultVector->isNull(pos));
}

} // namespace evaluator
Expand Down
17 changes: 15 additions & 2 deletions src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class FactorizationGroup {
FactorizationGroup() : flat{false}, singleState{false}, cardinalityMultiplier{1} {}
FactorizationGroup(const FactorizationGroup& other)
: flat{other.flat}, singleState{other.singleState},
cardinalityMultiplier{other.cardinalityMultiplier}, expressions{other.expressions} {}
cardinalityMultiplier{other.cardinalityMultiplier}, expressions{other.expressions},
expressionNameToPos{other.expressionNameToPos} {}

inline void setFlat() {
assert(!flat);
Expand All @@ -35,15 +36,22 @@ class FactorizationGroup {
// Returns the cardinality multiplier currently stored for this group.
inline uint64_t getMultiplier() const { return cardinalityMultiplier; }

// Appends `expression` to this group and records, under its unique name, the
// index it occupies in `expressions`. Inserting an expression whose unique
// name is already registered is a programming error (checked by assert only).
inline void insertExpression(const shared_ptr<Expression>& expression) {
    const auto& uniqueName = expression->getUniqueName();
    assert(!expressionNameToPos.contains(uniqueName));
    // The new expression lands at the current end of the vector.
    const auto nextPos = expressions.size();
    expressionNameToPos.insert({uniqueName, nextPos});
    expressions.push_back(expression);
}
// Returns the expressions held by this group. The result is returned by value,
// so callers receive a copy of the container rather than a reference to it.
inline expression_vector getExpressions() const { return expressions; }
inline uint32_t getExpressionPos(const Expression& expression) {
assert(expressionNameToPos.contains(expression.getUniqueName()));
return expressionNameToPos.at(expression.getUniqueName());
}

private:
bool flat;
bool singleState;
uint64_t cardinalityMultiplier;
expression_vector expressions;
unordered_map<string, uint32_t> expressionNameToPos;
};

class Schema {
Expand All @@ -68,7 +76,7 @@ class Schema {

void insertToGroupAndScope(const expression_vector& expressions, uint32_t groupPos);

inline uint32_t getGroupPos(const Expression& expression) {
inline uint32_t getGroupPos(const Expression& expression) const {
return getGroupPos(expression.getUniqueName());
}

Expand All @@ -77,6 +85,11 @@ class Schema {
return expressionNameToGroupPos.at(expressionName);
}

// Resolves `expression` to a (group position, position within that group)
// pair: the first component comes from getGroupPos(), the second from the
// owning group's own getExpressionPos().
inline pair<uint32_t, uint32_t> getExpressionPos(const Expression& expression) const {
    const uint32_t groupIdx = getGroupPos(expression);
    const uint32_t posInGroup = groups[groupIdx]->getExpressionPos(expression);
    return {groupIdx, posInGroup};
}

// Marks the group at index `pos` as flat by delegating to its setFlat().
// No bounds check on `pos` is performed in this function.
inline void flattenGroup(uint32_t pos) { groups[pos]->setFlat(); }
// Marks the group at index `pos` as single-state by delegating to its
// setSingleState(). No bounds check on `pos` is performed in this function.
inline void setGroupAsSingleState(uint32_t pos) { groups[pos]->setSingleState(); }

Expand Down
5 changes: 3 additions & 2 deletions src/include/processor/data_pos.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ namespace kuzu {
namespace processor {

struct DataPos {

public:
DataPos(uint32_t dataChunkPos, uint32_t valueVectorPos)
explicit DataPos(uint32_t dataChunkPos, uint32_t valueVectorPos)
: dataChunkPos{dataChunkPos}, valueVectorPos{valueVectorPos} {}
explicit DataPos(std::pair<uint32_t, uint32_t> pos)
: dataChunkPos{pos.first}, valueVectorPos{pos.second} {}

DataPos(const DataPos& other) : DataPos(other.dataChunkPos, other.valueVectorPos) {}

Expand Down
7 changes: 3 additions & 4 deletions src/include/processor/mapper/expression_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "binder/expression/expression.h"
#include "expression_evaluator/base_evaluator.h"
#include "processor/execution_context.h"
#include "processor/mapper/mapper_context.h"
#include "processor/result/result_set.h"
#include "processor/result/result_set_descriptor.h"

Expand All @@ -19,7 +18,7 @@ class ExpressionMapper {

public:
unique_ptr<BaseExpressionEvaluator> mapExpression(
const shared_ptr<Expression>& expression, const MapperContext& mapperContext);
const shared_ptr<Expression>& expression, const Schema& schema);

private:
unique_ptr<BaseExpressionEvaluator> mapLiteralExpression(
Expand All @@ -29,10 +28,10 @@ class ExpressionMapper {
const shared_ptr<Expression>& expression);

unique_ptr<BaseExpressionEvaluator> mapReferenceExpression(
const shared_ptr<Expression>& expression, const MapperContext& mapperContext);
const shared_ptr<Expression>& expression, const Schema& schema);

unique_ptr<BaseExpressionEvaluator> mapFunctionExpression(
const shared_ptr<Expression>& expression, const MapperContext& mapperContext);
const shared_ptr<Expression>& expression, const Schema& schema);
};

} // namespace processor
Expand Down
34 changes: 0 additions & 34 deletions src/include/processor/mapper/mapper_context.h

This file was deleted.

106 changes: 38 additions & 68 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "binder/expression/node_expression.h"
#include "planner/logical_plan/logical_plan.h"
#include "processor/mapper/expression_mapper.h"
#include "processor/mapper/mapper_context.h"
#include "processor/operator/result_collector.h"
#include "processor/physical_plan.h"
#include "storage/storage_manager.h"
Expand All @@ -28,94 +27,65 @@ class PlanMapper {

private:
unique_ptr<PhysicalOperator> mapLogicalOperatorToPhysical(
const shared_ptr<LogicalOperator>& logicalOperator, MapperContext& mapperContext);
const shared_ptr<LogicalOperator>& logicalOperator);

unique_ptr<PhysicalOperator> mapLogicalScanNodeToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalScanNodeToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalIndexScanNodeToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalUnwindToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalExtendToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalFlattenToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalFilterToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalProjectionToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalUnwindToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalExtendToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalFlattenToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalFilterToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalProjectionToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalScanNodePropertyToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalSemiMaskerToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalHashJoinToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalIntersectToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalCrossProductToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalSemiMaskerToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalHashJoinToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalIntersectToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalCrossProductToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalMultiplicityReducerToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalSkipToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalLimitToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalAggregateToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalDistinctToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalOrderByToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalUnionAllToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalAccumulateToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalSkipToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalLimitToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalAggregateToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalDistinctToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalOrderByToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalUnionAllToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalAccumulateToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalExpressionsScanToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalFTableScanToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalCreateNodeToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalCreateRelToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalSetToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalDeleteNodeToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalDeleteRelToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalFTableScanToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalCreateNodeToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalCreateRelToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalSetToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalDeleteNodeToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalDeleteRelToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalCreateNodeTableToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalCreateRelTableToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalCopyCSVToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
unique_ptr<PhysicalOperator> mapLogicalDropTableToPhysical(
LogicalOperator* logicalOperator, MapperContext& mapperContext);
LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalCopyCSVToPhysical(LogicalOperator* logicalOperator);
unique_ptr<PhysicalOperator> mapLogicalDropTableToPhysical(LogicalOperator* logicalOperator);

unique_ptr<ResultCollector> appendResultCollector(const expression_vector& expressionsToCollect,
const Schema& schema, unique_ptr<PhysicalOperator> prevOperator,
MapperContext& mapperContext);
const Schema& schema, unique_ptr<PhysicalOperator> prevOperator);

// Returns the current value of physicalOperatorID and then increments it, so
// each call yields a value one greater than the previous call's.
inline uint32_t getOperatorID() { return physicalOperatorID++; }

unique_ptr<PhysicalOperator> createHashAggregate(
vector<unique_ptr<AggregateFunction>> aggregateFunctions,
vector<DataPos> inputAggVectorsPos, vector<DataPos> outputAggVectorsPos,
vector<DataType> outputAggVectorsDataType, const expression_vector& groupByExpressions,
Schema* schema, unique_ptr<PhysicalOperator> prevOperator,
MapperContext& mapperContextBeforeAggregate, MapperContext& mapperContext,
unique_ptr<PhysicalOperator> prevOperator, const Schema& inSchema, const Schema& outSchema,
const string& paramsString);

void appendGroupByExpressions(const expression_vector& groupByExpressions,
vector<DataPos>& inputGroupByHashKeyVectorsPos, vector<DataPos>& outputGroupByKeyVectorsPos,
vector<DataType>& outputGroupByKeyVectorsDataTypes,
MapperContext& mapperContextBeforeAggregate, MapperContext& mapperContext, Schema* schema,
vector<bool>& isInputGroupByHashKeyVectorFlat);
vector<DataType>& outputGroupByKeyVectorsDataTypes, const Schema& inSchema,
const Schema& outSchema, vector<bool>& isInputGroupByHashKeyVectorFlat);

static BuildDataInfo generateBuildDataInfo(MapperContext& mapperContext,
Schema* buildSideSchema, const vector<shared_ptr<NodeExpression>>& keys,
const expression_vector& payloads);
static BuildDataInfo generateBuildDataInfo(const Schema& buildSideSchema,
const vector<shared_ptr<NodeExpression>>& keys, const expression_vector& payloads);

public:
StorageManager& storageManager;
Expand Down
7 changes: 0 additions & 7 deletions src/include/processor/result/result_set_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ class DataChunkDescriptor {
// Sets the singleState flag of this descriptor to true.
inline void setSingleState() { singleState = true; }
// Returns the current value of the singleState flag.
inline bool isSingleState() const { return singleState; }

inline uint32_t getValueVectorPos(const string& name) const {
assert(expressionNameToValueVectorPosMap.contains(name));
return expressionNameToValueVectorPosMap.at(name);
}

// Returns the number of expressions registered in this descriptor, as reported
// by expressions.size(), converted to the uint32_t return type.
inline uint32_t getNumValueVectors() const { return expressions.size(); }

inline void addExpression(shared_ptr<Expression> expression) {
Expand All @@ -50,8 +45,6 @@ class ResultSetDescriptor {
ResultSetDescriptor(const ResultSetDescriptor& other);
~ResultSetDescriptor() = default;

DataPos getDataPos(const string& name) const;

inline uint32_t getNumDataChunks() const { return dataChunkDescriptors.size(); }

inline DataChunkDescriptor* getDataChunkDescriptor(uint32_t pos) const {
Expand Down
16 changes: 8 additions & 8 deletions src/processor/mapper/expression_mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ namespace kuzu {
namespace processor {

unique_ptr<BaseExpressionEvaluator> ExpressionMapper::mapExpression(
const shared_ptr<Expression>& expression, const MapperContext& mapperContext) {
const shared_ptr<Expression>& expression, const Schema& schema) {
auto expressionType = expression->expressionType;
if (mapperContext.expressionHasComputed(expression->getUniqueName())) {
return mapReferenceExpression(expression, mapperContext);
if (schema.isExpressionInScope(*expression)) {
return mapReferenceExpression(expression, schema);
} else if (isExpressionLiteral(expressionType)) {
return mapLiteralExpression(expression);
} else if (PARAMETER == expressionType) {
return mapParameterExpression((expression));
} else {
return mapFunctionExpression(expression, mapperContext);
return mapFunctionExpression(expression, schema);
}
}

Expand All @@ -38,16 +38,16 @@ unique_ptr<BaseExpressionEvaluator> ExpressionMapper::mapParameterExpression(
}

unique_ptr<BaseExpressionEvaluator> ExpressionMapper::mapReferenceExpression(
const shared_ptr<Expression>& expression, const MapperContext& mapperContext) {
auto vectorPos = mapperContext.getDataPos(expression->getUniqueName());
const shared_ptr<Expression>& expression, const Schema& schema) {
auto vectorPos = DataPos(schema.getExpressionPos(*expression));
return make_unique<ReferenceExpressionEvaluator>(vectorPos);
}

unique_ptr<BaseExpressionEvaluator> ExpressionMapper::mapFunctionExpression(
const shared_ptr<Expression>& expression, const MapperContext& mapperContext) {
const shared_ptr<Expression>& expression, const Schema& schema) {
vector<unique_ptr<BaseExpressionEvaluator>> children;
for (auto i = 0u; i < expression->getNumChildren(); ++i) {
children.push_back(mapExpression(expression->getChild(i), mapperContext));
children.push_back(mapExpression(expression->getChild(i), schema));
}
return make_unique<FunctionExpressionEvaluator>(expression, std::move(children));
}
Expand Down
Loading