Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add direct scan npy, add where predicate in LOAD FROM #2093

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ oC_ReadingClause
;

kU_LoadFrom
: LOAD SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? ;
: LOAD SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? (SP? oC_Where)? ;

LOAD : ( 'L' | 'l' ) ( 'O' | 'o' ) ( 'A' | 'a' ) ( 'D' | 'd' ) ;

Expand Down
33 changes: 29 additions & 4 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/unwind_clause.h"
#include "processor/operator/persistent/reader/csv/csv_reader.h"
#include "processor/operator/persistent/reader/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -45,8 +46,8 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
auto boundMatchClause = make_unique<BoundMatchClause>(
std::move(queryGraphCollection), matchClause.getMatchClauseType());
std::shared_ptr<Expression> whereExpression;
if (matchClause.hasWhereClause()) {
whereExpression = bindWhereExpression(*matchClause.getWhereClause());
if (matchClause.hasWherePredicate()) {
whereExpression = bindWhereExpression(*matchClause.getWherePredicate());
}
// Rewrite self loop edge
// e.g. rewrite (a)-[e]->(a) as [a]-[e]->(b) WHERE id(a) = id(b)
Expand Down Expand Up @@ -79,7 +80,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
expressionBinder.combineConjunctiveExpressions(predicate, whereExpression);
}

boundMatchClause->setWhereExpression(std::move(whereExpression));
boundMatchClause->setWherePredicate(std::move(whereExpression));
return boundMatchClause;
}

Expand Down Expand Up @@ -113,6 +114,17 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
tableFunctionDefinition->tableFunc, std::move(bindData), std::move(outputExpressions));
}

static std::unique_ptr<LogicalType> bindFixedListType(
const std::vector<size_t>& shape, LogicalTypeID typeID) {
if (shape.size() == 1) {
return std::make_unique<LogicalType>(typeID);
}
auto childShape = std::vector<size_t>{shape.begin() + 1, shape.end()};
auto childType = bindFixedListType(childShape, typeID);
auto extraInfo = std::make_unique<FixedListTypeInfo>(std::move(childType), (uint32_t)shape[0]);
return std::make_unique<LogicalType>(LogicalTypeID::FIXED_LIST, std::move(extraInfo));
}

std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
const parser::ReadingClause& readingClause) {
auto& loadFrom = reinterpret_cast<const LoadFrom&>(readingClause);
Expand Down Expand Up @@ -151,13 +163,26 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
columns.push_back(createVariable(columnName, *columnType));
}
} break;
case FileType::NPY: {
auto reader = NpyReader(readerConfig->filePaths[0]);
auto columnType = bindFixedListType(reader.getShape(), reader.getType());
auto columnName = std::string("column0");
readerConfig->columnNames.push_back(columnName);
readerConfig->columnTypes.push_back(columnType->copy());
columns.push_back(createVariable(columnName, *columnType));
} break;
default:
throw BinderException(StringUtils::string_format(
"Load from {} file is not supported.", FileTypeUtils::toString(fileType)));
}
auto info = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), std::move(columns), nullptr, TableType::UNKNOWN);
return std::make_unique<BoundLoadFrom>(std::move(info));
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
boundLoadFrom->setWherePredicate(std::move(wherePredicate));
}
return boundLoadFrom;
}

} // namespace binder
Expand Down
12 changes: 10 additions & 2 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/visitor/property_collector.h"

#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "binder/query/updating_clause/bound_create_clause.h"
Expand Down Expand Up @@ -28,8 +29,8 @@ void PropertyCollector::visitMatch(const BoundReadingClause& readingClause) {
properties.insert(rel->getInternalIDProperty());
}
}
if (matchClause.hasWhereExpression()) {
collectPropertyExpressions(matchClause.getWhereExpression());
if (matchClause.hasWherePredicate()) {
collectPropertyExpressions(matchClause.getWherePredicate());
}
}

Expand All @@ -38,6 +39,13 @@ void PropertyCollector::visitUnwind(const BoundReadingClause& readingClause) {
collectPropertyExpressions(unwindClause.getExpression());
}

void PropertyCollector::visitLoadFrom(const BoundReadingClause& readingClause) {
auto& loadFromClause = reinterpret_cast<const BoundLoadFrom&>(readingClause);
if (loadFromClause.hasWherePredicate()) {
collectPropertyExpressions(loadFromClause.getWherePredicate());
}
}

void PropertyCollector::visitSet(const BoundUpdatingClause& updatingClause) {
auto& boundSetClause = (BoundSetClause&)updatingClause;
for (auto& info : boundSetClause.getInfosRef()) {
Expand Down
17 changes: 15 additions & 2 deletions src/include/binder/query/reading_clause/bound_load_from.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,30 @@

class BoundLoadFrom : public BoundReadingClause {
public:
BoundLoadFrom(std::unique_ptr<BoundFileScanInfo> info)
explicit BoundLoadFrom(std::unique_ptr<BoundFileScanInfo> info)
: BoundReadingClause{common::ClauseType::LOAD_FROM}, info{std::move(info)} {}
BoundLoadFrom(const BoundLoadFrom& other)
: BoundReadingClause{common::ClauseType::LOAD_FROM}, info{other.info->copy()},
wherePredicate{other.wherePredicate} {}

Check warning on line 15 in src/include/binder/query/reading_clause/bound_load_from.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/reading_clause/bound_load_from.h#L13-L15

Added lines #L13 - L15 were not covered by tests

inline BoundFileScanInfo* getInfo() const { return info.get(); }

inline void setWherePredicate(std::shared_ptr<Expression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline std::shared_ptr<Expression> getWherePredicate() const { return wherePredicate; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWherePredicate() ? wherePredicate->splitOnAND() : expression_vector{};
}

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundLoadFrom>(info->copy());
return std::make_unique<BoundLoadFrom>(*this);

Check warning on line 29 in src/include/binder/query/reading_clause/bound_load_from.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/reading_clause/bound_load_from.h#L29

Added line #L29 was not covered by tests
}

private:
std::unique_ptr<BoundFileScanInfo> info;
std::shared_ptr<Expression> wherePredicate;
};

} // namespace binder
Expand Down
14 changes: 7 additions & 7 deletions src/include/binder/query/reading_clause/bound_match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@
BoundMatchClause(const BoundMatchClause& other)
: BoundReadingClause(common::ClauseType::MATCH),
queryGraphCollection{other.queryGraphCollection->copy()},
whereExpression(other.whereExpression), matchClauseType{other.matchClauseType} {}
wherePredicate(other.wherePredicate), matchClauseType{other.matchClauseType} {}

Check warning on line 19 in src/include/binder/query/reading_clause/bound_match_clause.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/query/reading_clause/bound_match_clause.h#L19

Added line #L19 was not covered by tests

inline QueryGraphCollection* getQueryGraphCollection() const {
return queryGraphCollection.get();
}

inline void setWhereExpression(std::shared_ptr<Expression> expression) {
whereExpression = std::move(expression);
inline void setWherePredicate(std::shared_ptr<Expression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWhereExpression() const { return whereExpression != nullptr; }
inline std::shared_ptr<Expression> getWhereExpression() const { return whereExpression; }
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline std::shared_ptr<Expression> getWherePredicate() const { return wherePredicate; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{};
return hasWherePredicate() ? wherePredicate->splitOnAND() : expression_vector{};
}

inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }
Expand All @@ -39,7 +39,7 @@

private:
std::unique_ptr<QueryGraphCollection> queryGraphCollection;
std::shared_ptr<Expression> whereExpression;
std::shared_ptr<Expression> wherePredicate;
common::MatchClauseType matchClauseType;
};

Expand Down
1 change: 1 addition & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class PropertyCollector : public BoundStatementVisitor {
private:
void visitMatch(const BoundReadingClause& readingClause) final;
void visitUnwind(const BoundReadingClause& readingClause) final;
void visitLoadFrom(const BoundReadingClause& readingClause) final;

void visitSet(const BoundUpdatingClause& updatingClause) final;
void visitDelete(const BoundUpdatingClause& updatingClause) final;
Expand Down
7 changes: 7 additions & 0 deletions src/include/parser/query/reading_clause/load_from.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@ class LoadFrom : public ReadingClause {
inline std::vector<std::string> getFilePaths() const { return filePaths; }
inline const parsing_option_t& getParsingOptionsRef() const { return parsingOptions; }

inline void setWherePredicate(std::unique_ptr<ParsedExpression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline ParsedExpression* getWherePredicate() const { return wherePredicate.get(); }

private:
std::vector<std::string> filePaths;
parsing_option_t parsingOptions;
std::unique_ptr<ParsedExpression> wherePredicate;
};

} // namespace parser
Expand Down
11 changes: 5 additions & 6 deletions src/include/parser/query/reading_clause/match_clause.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@ class MatchClause : public ReadingClause {
return patternElements;
}

inline void setWhereClause(std::unique_ptr<ParsedExpression> expression) {
whereClause = std::move(expression);
inline void setWherePredicate(std::unique_ptr<ParsedExpression> expression) {
wherePredicate = std::move(expression);
}

inline bool hasWhereClause() const { return whereClause != nullptr; }
inline ParsedExpression* getWhereClause() const { return whereClause.get(); }
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline ParsedExpression* getWherePredicate() const { return wherePredicate.get(); }

inline common::MatchClauseType getMatchClauseType() const { return matchClauseType; }

private:
std::vector<std::unique_ptr<PatternElement>> patternElements;
std::unique_ptr<ParsedExpression> whereClause;
std::unique_ptr<ParsedExpression> wherePredicate;
common::MatchClauseType matchClauseType;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class NpyReader {

// Used in tests only.
inline common::LogicalTypeID getType() const { return type; }
inline std::vector<size_t> const& getShape() const { return shape; }
inline size_t getNumDimensions() const { return shape.size(); }
inline std::vector<size_t> getShape() const { return shape; }

void validate(const common::LogicalType& type_, common::offset_t numRows);

Expand Down
8 changes: 6 additions & 2 deletions src/parser/transform/transform_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::unique_ptr<ReadingClause> Transformer::transformMatch(CypherParser::OC_Matc
auto matchClause =
std::make_unique<MatchClause>(transformPattern(*ctx.oC_Pattern()), matchClauseType);
if (ctx.oC_Where()) {
matchClause->setWhereClause(transformWhere(*ctx.oC_Where()));
matchClause->setWherePredicate(transformWhere(*ctx.oC_Where()));
}
return matchClause;
}
Expand Down Expand Up @@ -60,7 +60,11 @@ std::unique_ptr<ReadingClause> Transformer::transformLoadFrom(
if (ctx.kU_ParsingOptions()) {
parsingOptions = transformParsingOptions(*ctx.kU_ParsingOptions());
}
return std::make_unique<LoadFrom>(std::move(filePaths), std::move(parsingOptions));
auto loadFrom = std::make_unique<LoadFrom>(std::move(filePaths), std::move(parsingOptions));
if (ctx.oC_Where()) {
loadFrom->setWherePredicate(transformWhere(*ctx.oC_Where()));
}
return loadFrom;
}

} // namespace parser
Expand Down
34 changes: 34 additions & 0 deletions src/planner/plan/plan_read.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "planner/query_planner.h"
Expand Down Expand Up @@ -78,16 +79,49 @@
}
}

static bool hasExternalDependency(const std::shared_ptr<Expression>& expression,
const std::unordered_set<std::string>& variableNameSet) {
auto collector = ExpressionCollector();
for (auto& name : collector.getDependentVariableNames(expression)) {
if (!variableNameSet.contains(name)) {
return true;
}
}
return false;
}

void QueryPlanner::planLoadFrom(
binder::BoundReadingClause* readingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans) {
auto loadFrom = reinterpret_cast<BoundLoadFrom*>(readingClause);
std::unordered_set<std::string> columnNameSet;
for (auto& column : loadFrom->getInfo()->columns) {
columnNameSet.insert(column->getUniqueName());
}
expression_vector predicatesToPushDown;
expression_vector predicatesToPullUp;
for (auto& predicate : loadFrom->getPredicatesSplitOnAnd()) {
if (hasExternalDependency(predicate, columnNameSet)) {
predicatesToPullUp.push_back(predicate);
} else {
predicatesToPushDown.push_back(predicate);
}
}
for (auto& plan : plans) {
if (!plan->isEmpty()) {
auto tmpPlan = std::make_unique<LogicalPlan>();
appendScanFile(loadFrom->getInfo(), *tmpPlan);
if (!predicatesToPushDown.empty()) {
appendFilters(predicatesToPushDown, *tmpPlan);

Check warning on line 114 in src/planner/plan/plan_read.cpp

View check run for this annotation

Codecov / codecov/patch

src/planner/plan/plan_read.cpp#L114

Added line #L114 was not covered by tests
}
appendCrossProduct(AccumulateType::REGULAR, *plan, *tmpPlan);
} else {
appendScanFile(loadFrom->getInfo(), *plan);
if (!predicatesToPushDown.empty()) {
appendFilters(predicatesToPushDown, *plan);
}
}
if (!predicatesToPullUp.empty()) {
appendFilter(loadFrom->getWherePredicate(), *plan);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ NpyMultiFileReader::NpyMultiFileReader(const std::vector<std::string>& filePaths
}

void NpyMultiFileReader::readBlock(block_idx_t blockIdx, common::DataChunk* dataChunkToRead) const {
assert(fileReaders.size() > 1);
for (auto i = 0u; i < fileReaders.size(); i++) {
fileReaders[i]->readBlock(blockIdx, dataChunkToRead->getValueVector(i).get());
}
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/copy/copy_to.test
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Binder exception: Cannot find property non_prop for a.
-LOG Non-Query command
-STATEMENT COPY (EXPLAIN MATCH (p:person) RETURN p.ID) TO "test.csv"
---- error
Parser exception: mismatched input 'EXPLAIN' expecting {CALL, LOAD, OPTIONAL, MATCH, UNWIND, CREATE, MERGE, SET, DELETE, WITH, RETURN} (line: 1, offset: 6)
Parser exception: extraneous input 'EXPLAIN' expecting {CALL, LOAD, OPTIONAL, MATCH, UNWIND, CREATE, MERGE, SET, DELETE, WITH, RETURN, SP} (line: 1, offset: 6)
"COPY (EXPLAIN MATCH (p:person) RETURN p.ID) TO "test.csv""
^^^^^^^

Expand Down
37 changes: 36 additions & 1 deletion test/test_files/tinysnb/load_from/load_from.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,35 @@
-DATASET CSV tinysnb

--
-CASE LoadFromParquetTest1

-CASE LoadFromNpyTest
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/npy-1d/one_dim_double.npy" RETURN * ORDER BY column0 LIMIT 5;
---- 3
1.000000
2.000000
3.000000
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/npy-1d/one_dim_int64.npy" WHERE column0 > 1 RETURN * ;
---- 2
2
3
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/npy-2d/two_dim_int64.npy" RETURN * ;
---- 3
[1,2,3]
[4,5,6]
[7,8,9]

-CASE LoadFromParquetTest
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/node/parquet/types_50k_0.parquet" RETURN * ORDER BY id LIMIT 5;
---- 5
0|73|3.258507|True|1994-01-12|FrPZkcHFuepVxcAiMwyAsRqDlRtQx|[65,25]|[4deQc5]|[[163,237],[28,60,77,31,137],[286,186,249,206]]|{id: 764, name: CwFRaCoEp}
1|17|45.692842|False|2077-04-16|La|[64,41]|[Yq029g79TUAiq9VA,5h5ozRjtfsxbtCeb,2WLnSHVZojagYe,3HsiFD7b7DRk6n]|[[189,84,16],[143,135],[284,182,219,45],[250,143,195,210,244],[31,85]]|{id: 461, name: PmAvlzC0MVN2kr5}
2|30|13.397253|True|2015-01-06|uQJCBEePLuGkoAp|[47,27,57,46]|[INx9T8cF,fQds,GVbSmwovuURxXiRQ6vI3]|[[89,232],[186,224],[278,106,154]]|{id: 275, name: LeJHI4vdgjFDl}
3|4|3.174669|True|2104-03-14|fjyKxMjhXXgCkZmwBACpRrjNHlhrDtkQPl|[58,77,66,48]|[SUFT8NmyhMQ,DaTDnzkotQ2pjvdCN]|[[44]]|{id: 545, name: 0jhUkRv7R8}
4|99|17.608944|True|2089-10-27||[78,93,50,3]|[7Jyqki,Y0FQsTGx,7LqWTypucemvMYm,t5spe07tWSCJ]|[[267,172,283],[74,37],[148,62,96,47],[277,95]]|{id: 460, name: 1e6nIx}
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/node/parquet/types_50k_0.parquet"
WHERE id = 2 RETURN column1, column2;
---- 1
30|13.397253
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/node/parquet/types_50k_0.parquet" RETURN id, column1, column2 ORDER BY column1, id LIMIT 3;
---- 3
20|0|57.579280
Expand Down Expand Up @@ -64,3 +85,17 @@ Binder exception: Load from TURTLE file is not supported.
3|4|Carol
5|6|Dan
7|6|Elizabeth
-STATEMENT MATCH (a:person) WHERE a.ID < 2
LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" (HEADER=True)
WHERE column1 = "Alice" or column1 = "Bob" AND a.fName = column1
RETURN column0, column1;
---- 1
0|Alice
-STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/eWorkAt.csv" (HEADER=False)
WHERE column0 = "3"
WITH column0 AS a, column1 AS b
LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/tinysnb/vPerson.csv" (HEADER=True)
WHERE column0 = a
RETURN a, b, column1;
---- 1
3|4|Carol
Loading