Skip to content

Commit

Permalink
Add direct scan npy, add where predicate in LOAD FROM
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Sep 26, 2023
1 parent 9b4fd80 commit 0744f57
Show file tree
Hide file tree
Showing 16 changed files with 2,239 additions and 2,065 deletions.
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ oC_ReadingClause
;

kU_LoadFrom
: LOAD SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? ;
: LOAD SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? (SP? oC_Where)? ;

LOAD : ( 'L' | 'l' ) ( 'O' | 'o' ) ( 'A' | 'a' ) ( 'D' | 'd' ) ;

Expand Down
33 changes: 29 additions & 4 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/unwind_clause.h"
#include "processor/operator/persistent/reader/csv/csv_reader.h"
#include "processor/operator/persistent/reader/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -45,8 +46,8 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
auto boundMatchClause = make_unique<BoundMatchClause>(
std::move(queryGraphCollection), matchClause.getMatchClauseType());
std::shared_ptr<Expression> whereExpression;
if (matchClause.hasWhereClause()) {
whereExpression = bindWhereExpression(*matchClause.getWhereClause());
if (matchClause.hasWherePredicate()) {
whereExpression = bindWhereExpression(*matchClause.getWherePredicate());
}
// Rewrite self loop edge
// e.g. rewrite (a)-[e]->(a) as [a]-[e]->(b) WHERE id(a) = id(b)
Expand Down Expand Up @@ -79,7 +80,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
expressionBinder.combineConjunctiveExpressions(predicate, whereExpression);
}

boundMatchClause->setWhereExpression(std::move(whereExpression));
boundMatchClause->setWherePredicate(std::move(whereExpression));
return boundMatchClause;
}

Expand Down Expand Up @@ -113,6 +114,17 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
tableFunctionDefinition->tableFunc, std::move(bindData), std::move(outputExpressions));
}

static std::unique_ptr<LogicalType> bindFixedListType(

Check warning on line 117 in src/binder/bind/bind_reading_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_reading_clause.cpp#L117

Added line #L117 was not covered by tests
const std::vector<size_t>& shape, LogicalTypeID typeID) {
if (shape.size() == 1) {
return std::make_unique<LogicalType>(typeID);

Check warning on line 120 in src/binder/bind/bind_reading_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_reading_clause.cpp#L119-L120

Added lines #L119 - L120 were not covered by tests
}
auto childShape = std::vector<size_t>{shape.begin() + 1, shape.end()};
auto childType = bindFixedListType(childShape, typeID);
auto extraInfo = std::make_unique<FixedListTypeInfo>(std::move(childType), (uint32_t)shape[0]);
return std::make_unique<LogicalType>(LogicalTypeID::FIXED_LIST, std::move(extraInfo));

Check warning on line 125 in src/binder/bind/bind_reading_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_reading_clause.cpp#L122-L125

Added lines #L122 - L125 were not covered by tests
}

std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
const parser::ReadingClause& readingClause) {
auto& loadFrom = reinterpret_cast<const LoadFrom&>(readingClause);
Expand Down Expand Up @@ -151,13 +163,26 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
columns.push_back(createVariable(columnName, *columnType));
}
} break;
case FileType::NPY: {
auto reader = NpyReader(readerConfig->filePaths[0]);
auto columnType = bindFixedListType(reader.getShape(), reader.getType());
auto columnName = std::string("column0");
readerConfig->columnNames.push_back(columnName);
readerConfig->columnTypes.push_back(columnType->copy());
columns.push_back(createVariable(columnName, *columnType));
} break;

Check warning on line 173 in src/binder/bind/bind_reading_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_reading_clause.cpp#L167-L173

Added lines #L167 - L173 were not covered by tests
default:
throw BinderException(StringUtils::string_format(
"Load from {} file is not supported.", FileTypeUtils::toString(fileType)));
}
auto info = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), std::move(columns), nullptr, TableType::UNKNOWN);
return std::make_unique<BoundLoadFrom>(std::move(info));
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
boundLoadFrom->setWherePredicate(std::move(wherePredicate));
}
return boundLoadFrom;
}

} // namespace binder
Expand Down
12 changes: 10 additions & 2 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/visitor/property_collector.h"

#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "binder/query/updating_clause/bound_create_clause.h"
Expand Down Expand Up @@ -28,8 +29,8 @@ void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
properties.insert(rel->getInternalIDProperty());
}
}
if (matchClause.hasWhereExpression()) {
collectPropertyExpressions(matchClause.getWhereExpression());
if (matchClause.hasWherePredicate()) {
collectPropertyExpressions(matchClause.getWherePredicate());
}
}

Expand All @@ -38,6 +39,13 @@ void PropertyCollector::visitUnwind(const BoundReadingClause& readingClause) {
collectPropertyExpressions(unwindClause.getExpression());
}

void PropertyCollector::visitLoadFrom(const BoundReadingClause& readingClause) {
auto& loadFromClause = reinterpret_cast<const BoundLoadFrom&>(readingClause);
if (loadFromClause.hasWherePredicate()) {
collectPropertyExpressions(loadFromClause.getWherePredicate());
}
}

void PropertyCollector::visitSet(const BoundUpdatingClause& updatingClause) {
auto& boundSetClause = (BoundSetClause&)updatingClause;
for (auto& info : boundSetClause.getInfosRef()) {
Expand Down
17 changes: 15 additions & 2 deletions src/include/binder/query/reading_clause/bound_load_from.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,30 @@ namespace binder {

class BoundLoadFrom : public BoundReadingClause {
public:
BoundLoadFrom(std::unique_ptr<BoundFileScanInfo> info)
explicit BoundLoadFrom(std::unique_ptr<BoundFileScanInfo> info)
: BoundReadingClause{common::ClauseType::LOAD_FROM}, info{std::move(info)} {}
BoundLoadFrom(const BoundLoadFrom& other)
: BoundReadingClause{common::ClauseType::LOAD_FROM}, info{other.info->copy()},
wherePredicate{other.wherePredicate} {}

Check warning on line 15 in src/include/binder/query/reading_clause/bound_load_from.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/reading_clause/bound_load_from.h#L13-L15

Added lines #L13 - L15 were not covered by tests

inline BoundFileScanInfo* getInfo() const { return info.get(); }

inline void setWherePredicate(std::shared_ptr<Expression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline std::shared_ptr<Expression> getWherePredicate() const { return wherePredicate; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWherePredicate() ? wherePredicate->splitOnAND() : expression_vector{};
}

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundLoadFrom>(info->copy());
return std::make_unique<BoundLoadFrom>(*this);

Check warning on line 29 in src/include/binder/query/reading_clause/bound_load_from.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/reading_clause/bound_load_from.h#L29

Added line #L29 was not covered by tests
}

private:
std::unique_ptr<BoundFileScanInfo> info;
std::shared_ptr<Expression> wherePredicate;
};

} // namespace binder
Expand Down
14 changes: 7 additions & 7 deletions src/include/binder/query/reading_clause/bound_match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ class BoundMatchClause : public BoundReadingClause {
BoundMatchClause(const BoundMatchClause& other)
: BoundReadingClause(common::ClauseType::MATCH),
queryGraphCollection{other.queryGraphCollection->copy()},
whereExpression(other.whereExpression), matchClauseType{other.matchClauseType} {}
wherePredicate(other.wherePredicate), matchClauseType{other.matchClauseType} {}

Check warning on line 19 in src/include/binder/query/reading_clause/bound_match_clause.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/reading_clause/bound_match_clause.h#L19

Added line #L19 was not covered by tests

inline QueryGraphCollection* getQueryGraphCollection() const {
return queryGraphCollection.get();
}

inline void setWhereExpression(std::shared_ptr<Expression> expression) {
whereExpression = std::move(expression);
inline void setWherePredicate(std::shared_ptr<Expression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWhereExpression() const { return whereExpression != nullptr; }
inline std::shared_ptr<Expression> getWhereExpression() const { return whereExpression; }
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline std::shared_ptr<Expression> getWherePredicate() const { return wherePredicate; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
return hasWherePredicate() ? wherePredicate->splitOnAND() : expression_vector{};
}

inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }
Expand All @@ -39,7 +39,7 @@ class BoundMatchClause : public BoundReadingClause {

private:
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> whereExpression;
std::shared_ptr<Expression> wherePredicate;
common::MatchClauseType matchClauseType;
};

Expand Down
1 change: 1 addition & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class PropertyCollector : public BoundStatementVisitor {
private:
void visitMatch(const BoundReadingClause& readingClause) final;
void visitUnwind(const BoundReadingClause& readingClause) final;
void visitLoadFrom(const BoundReadingClause& readingClause) final;

void visitSet(const BoundUpdatingClause& updatingClause) final;
void visitDelete(const BoundUpdatingClause& updatingClause) final;
Expand Down
7 changes: 7 additions & 0 deletions src/include/parser/query/reading_clause/load_from.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@ class LoadFrom : public ReadingClause {
inline std::vector<std::string> getFilePaths() const { return filePaths; }
inline const parsing_option_t& getParsingOptionsRef() const { return parsingOptions; }

inline void setWherePredicate(std::unique_ptr<ParsedExpression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline ParsedExpression* getWherePredicate() const { return wherePredicate.get(); }

private:
std::vector<std::string> filePaths;
parsing_option_t parsingOptions;
std::unique_ptr<ParsedExpression> wherePredicate;
};

} // namespace parser
Expand Down
11 changes: 5 additions & 6 deletions src/include/parser/query/reading_clause/match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@ class MatchClause : public ReadingClause {
return patternElements;
}

inline void setWhereClause(std::unique_ptr<ParsedExpression> expression) {
whereClause = std::move(expression);
inline void setWherePredicate(std::unique_ptr<ParsedExpression> expression) {
wherePredicate = std::move(expression);
}

inline bool hasWhereClause() const { return whereClause != nullptr; }
inline ParsedExpression* getWhereClause() const { return whereClause.get(); }
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline ParsedExpression* getWherePredicate() const { return wherePredicate.get(); }

inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }

private:
std::vector<std::unique_ptr<PatternElement>> patternElements;
std::unique_ptr<ParsedExpression> whereClause;
std::unique_ptr<ParsedExpression> wherePredicate;
common::MatchClauseType matchClauseType;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class NpyReader {

// Used in tests only.
inline common::LogicalTypeID getType() const { return type; }
inline std::vector<size_t> const& getShape() const { return shape; }
inline size_t getNumDimensions() const { return shape.size(); }
inline std::vector<size_t> getShape() const { return shape; }

Check warning on line 30 in src/include/processor/operator/persistent/reader/npy_reader.h

View check run for this annotation

Codecov / codecov/patch

src/include/processor/operator/persistent/reader/npy_reader.h#L30

Added line #L30 was not covered by tests

void validate(const common::LogicalType& type_, common::offset_t numRows);

Expand Down
8 changes: 6 additions & 2 deletions src/parser/transform/transform_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::unique_ptr<ReadingClause> Transformer::transformMatch(CypherParser::OC_Matc
auto matchClause =
std::make_unique<MatchClause>(transformPattern(*ctx.oC_Pattern()), matchClauseType);
if (ctx.oC_Where()) {
matchClause->setWhereClause(transformWhere(*ctx.oC_Where()));
matchClause->setWherePredicate(transformWhere(*ctx.oC_Where()));
}
return matchClause;
}
Expand Down Expand Up @@ -60,7 +60,11 @@ std::unique_ptr<ReadingClause> Transformer::transformLoadFrom(
if (ctx.kU_ParsingOptions()) {
parsingOptions = transformParsingOptions(*ctx.kU_ParsingOptions());
}
return std::make_unique<LoadFrom>(std::move(filePaths), std::move(parsingOptions));
auto loadFrom = std::make_unique<LoadFrom>(std::move(filePaths), std::move(parsingOptions));
if (ctx.oC_Where()) {
loadFrom->setWherePredicate(transformWhere(*ctx.oC_Where()));
}
return loadFrom;
}

} // namespace parser
Expand Down
34 changes: 34 additions & 0 deletions src/planner/plan/plan_read.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "planner/query_planner.h"
Expand Down Expand Up @@ -78,16 +79,49 @@ void QueryPlanner::planInQueryCall(
}
}

static bool hasExternalDependency(const std::shared_ptr<Expression>& expression,
const std::unordered_set<std::string>& variableNameSet) {
auto collector = ExpressionCollector();
for (auto& name : collector.getDependentVariableNames(expression)) {
if (!variableNameSet.contains(name)) {
return true;
}
}
return false;
}

void QueryPlanner::planLoadFrom(
binder::BoundReadingClause* readingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans) {
auto loadFrom = reinterpret_cast<BoundLoadFrom*>(readingClause);
std::unordered_set<std::string> columnNameSet;
for (auto& column : loadFrom->getInfo()->columns) {
columnNameSet.insert(column->getUniqueName());
}
expression_vector predicatesToPushDown;
expression_vector predicatesToPullUp;
for (auto& predicate : loadFrom->getPredicatesSplitOnAnd()) {
if (hasExternalDependency(predicate, columnNameSet)) {
predicatesToPullUp.push_back(predicate);
} else {
predicatesToPushDown.push_back(predicate);
}
}
for (auto& plan : plans) {
if (!plan->isEmpty()) {
auto tmpPlan = std::make_unique<LogicalPlan>();
appendScanFile(loadFrom->getInfo(), *tmpPlan);
if (!predicatesToPushDown.empty()) {
appendFilters(predicatesToPushDown, *tmpPlan);

Check warning on line 114 in src/planner/plan/plan_read.cpp

View check run for this annotation

Codecov / codecov/patch

src/planner/plan/plan_read.cpp#L114

Added line #L114 was not covered by tests
}
appendCrossProduct(AccumulateType::REGULAR, *plan, *tmpPlan);
} else {
appendScanFile(loadFrom->getInfo(), *plan);
if (!predicatesToPushDown.empty()) {
appendFilters(predicatesToPushDown, *plan);
}
}
if (!predicatesToPullUp.empty()) {
appendFilter(loadFrom->getWherePredicate(), *plan);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ NpyMultiFileReader::NpyMultiFileReader(const std::vector<std::string>& filePaths
}

void NpyMultiFileReader::readBlock(block_idx_t blockIdx, common::DataChunk* dataChunkToRead) const {
assert(fileReaders.size() > 1);
for (auto i = 0u; i < fileReaders.size(); i++) {
fileReaders[i]->readBlock(blockIdx, dataChunkToRead->getValueVector(i).get());
}
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/copy/copy_to.test
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Binder exception: Cannot find property non_prop for a.
-LOG Non-Query command
-STATEMENT COPY (EXPLAIN MATCH (p:person) RETURN p.ID) TO "test.csv"
---- error
Parser exception: mismatched input 'EXPLAIN' expecting {CALL, LOAD, OPTIONAL, MATCH, UNWIND, CREATE, MERGE, SET, DELETE, WITH, RETURN} (line: 1, offset: 6)
Parser exception: extraneous input 'EXPLAIN' expecting {CALL, LOAD, OPTIONAL, MATCH, UNWIND, CREATE, MERGE, SET, DELETE, WITH, RETURN, SP} (line: 1, offset: 6)
"COPY (EXPLAIN MATCH (p:person) RETURN p.ID) TO "test.csv""
^^^^^^^

Expand Down
39 changes: 38 additions & 1 deletion test/test_files/tinysnb/load_from/load_from.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,37 @@
-DATASET CSV tinysnb

--
-CASE LoadFromParquetTest1

-CASE LoadFromNpyTest
-SKIP
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/npy-1d/one_dim_double.npy" RETURN * ORDER BY column0 LIMIT 5;
---- 3
1.000000
2.000000
3.000000
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/npy-1d/one_dim_int64.npy" WHERE column0 > 1 RETURN * ;
---- 2
2
3
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/npy-2d/two_dim_int64.npy" RETURN * ;
---- 3
[1,2,3]
[4,5,6]
[7,8,9]

-CASE LoadFromParquetTest
-SKIP
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/node/parquet/types_50k_0.parquet" RETURN * ORDER BY id LIMIT 5;
---- 5
0|73|3.258507|True|1994-01-12|FrPZkcHFuepVxcAiMwyAsRqDlRtQx|[65,25]|[4deQc5]|[[163,237],[28,60,77,31,137],[286,186,249,206]]|{id: 764, name: CwFRaCoEp}
1|17|45.692842|False|2077-04-16|La|[64,41]|[Yq029g79TUAiq9VA,5h5ozRjtfsxbtCeb,2WLnSHVZojagYe,3HsiFD7b7DRk6n]|[[189,84,16],[143,135],[284,182,219,45],[250,143,195,210,244],[31,85]]|{id: 461, name: PmAvlzC0MVN2kr5}
2|30|13.397253|True|2015-01-06|uQJCBEePLuGkoAp|[47,27,57,46]|[INx9T8cF,fQds,GVbSmwovuURxXiRQ6vI3]|[[89,232],[186,224],[278,106,154]]|{id: 275, name: LeJHI4vdgjFDl}
3|4|3.174669|True|2104-03-14|fjyKxMjhXXgCkZmwBACpRrjNHlhrDtkQPl|[58,77,66,48]|[SUFT8NmyhMQ,DaTDnzkotQ2pjvdCN]|[[44]]|{id: 545, name: 0jhUkRv7R8}
4|99|17.608944|True|2089-10-27||[78,93,50,3]|[7Jyqki,Y0FQsTGx,7LqWTypucemvMYm,t5spe07tWSCJ]|[[267,172,283],[74,37],[148,62,96,47],[277,95]]|{id: 460, name: 1e6nIx}
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/node/parquet/types_50k_0.parquet"
WHERE id = 2 RETURN column1, column2;
---- 1
30|13.397253
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/node/parquet/types_50k_0.parquet" RETURN id, column1, column2 ORDER BY column1, id LIMIT 3;
---- 3
20|0|57.579280
Expand Down Expand Up @@ -64,3 +87,17 @@ Binder exception: Load from TURTLE file is not supported.
3|4|Carol
5|6|Dan
7|6|Elizabeth
-STATEMENT MATCH (a:person) WHERE a.ID < 2
LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" (HEADER=True)
WHERE column1 = "Alice" or column1 = "Bob" AND a.fName = column1
RETURN column0, column1;
---- 1
0|Alice
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/eWorkAt.csv" (HEADER=False)
WHERE column0 = "3"
WITH column0 AS a, column1 AS b
LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" (HEADER=True)
WHERE column0 = a
RETURN a, b, column1;
---- 1
3|4|Carol
Loading

0 comments on commit 0744f57

Please sign in to comment.