Skip to content

Commit

Permalink
Merge pull request #1717 from kuzudb/call
Browse files Browse the repository at this point in the history
Implement call statement
  • Loading branch information
acquamarin committed Jun 23, 2023
2 parents d91ca92 + bdec9d7 commit 2c21a5e
Show file tree
Hide file tree
Showing 42 changed files with 4,288 additions and 3,762 deletions.
13 changes: 11 additions & 2 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ grammar Cypher;
}

oC_Cypher
: SP ? oC_AnyCypherOption? SP? ( oC_Statement | kU_DDL | kU_CopyNPY | kU_CopyCSV ) ( SP? ';' )? SP? EOF ;
: SP ? oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )? SP? EOF ;

kU_CopyCSV
: COPY SP oC_SchemaName SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? ;

kU_CopyNPY
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;

kU_Call
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

kU_FilePaths
: '[' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ']'
| StringLiteral
Expand Down Expand Up @@ -131,7 +136,11 @@ oC_Profile
PROFILE : ( 'P' | 'p' ) ( 'R' | 'r' ) ( 'O' | 'o' ) ( 'F' | 'f' ) ( 'I' | 'i' ) ( 'L' | 'l' ) ( 'E' | 'e' ) ;

oC_Statement
: oC_Query ;
: oC_Query
| kU_DDL
| kU_CopyNPY
| kU_CopyCSV
| kU_Call ;

oC_Query
: oC_RegularQuery ;
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(
kuzu_binder_bind
OBJECT
bind_call.cpp
bind_copy.cpp
bind_ddl.cpp
bind_graph_pattern.cpp
Expand Down
24 changes: 24 additions & 0 deletions src/binder/bind/bind_call.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "binder/binder.h"
#include "binder/call/bound_call.h"
#include "common/string_utils.h"
#include "parser/call/call.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCallClause(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::Call&>(statement);
auto option = main::DBConfig::getOptionByName(callStatement.getOptionName());
if (option == nullptr) {
throw common::BinderException{
"Invalid option name: " + callStatement.getOptionName() + "."};
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
auto boundCall = std::make_unique<BoundCall>(*option, std::move(optionValue));
return boundCall;
}

} // namespace binder
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::QUERY: {
return bindQuery((const RegularQuery&)statement);
}
case StatementType::CALL: {
return bindCallClause(statement);
}
default:
assert(false);
}
Expand Down
3 changes: 3 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::COPY: {
visitCopy(statement);
} break;
case StatementType::CALL: {
visitCall(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
}
Expand Down
28 changes: 17 additions & 11 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ uint32_t BuiltInVectorOperations::getCastCost(
case common::LogicalTypeID::SERIAL:
return castSerial(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}
}
Expand All @@ -121,7 +121,13 @@ uint32_t BuiltInVectorOperations::getCastCost(
switch (inputType.getLogicalTypeID()) {
case common::LogicalTypeID::FIXED_LIST:
case common::LogicalTypeID::VAR_LIST:
return UINT32_MAX;
case common::LogicalTypeID::MAP:
case common::LogicalTypeID::UNION:
case common::LogicalTypeID::STRUCT:
case common::LogicalTypeID::INTERNAL_ID:
// TODO(Ziyi): add boolean cast operations.
case common::LogicalTypeID::BOOL:
return UNDEFINED_CAST_COST;
default:
return getCastCost(inputType.getLogicalTypeID(), targetType.getLogicalTypeID());
}
Expand Down Expand Up @@ -157,7 +163,7 @@ uint32_t BuiltInVectorOperations::castInt64(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -168,7 +174,7 @@ uint32_t BuiltInVectorOperations::castInt32(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -180,14 +186,14 @@ uint32_t BuiltInVectorOperations::castInt16(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

uint32_t BuiltInVectorOperations::castDouble(common::LogicalTypeID targetTypeID) {
switch (targetTypeID) {
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -196,7 +202,7 @@ uint32_t BuiltInVectorOperations::castFloat(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -205,7 +211,7 @@ uint32_t BuiltInVectorOperations::castDate(common::LogicalTypeID targetTypeID) {
case common::LogicalTypeID::TIMESTAMP:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -224,7 +230,7 @@ VectorOperationDefinition* BuiltInVectorOperations::getBestMatch(
std::vector<VectorOperationDefinition*>& functions) {
assert(functions.size() > 1);
VectorOperationDefinition* result = nullptr;
auto cost = UINT32_MAX;
auto cost = UNDEFINED_CAST_COST;
for (auto& function : functions) {
std::unordered_set<LogicalTypeID> distinctParameterTypes;
for (auto& parameterTypeID : function->parameterTypeIDs) {
Expand Down Expand Up @@ -259,7 +265,7 @@ uint32_t BuiltInVectorOperations::matchParameters(const std::vector<LogicalType>
auto cost = 0u;
for (auto i = 0u; i < inputTypes.size(); ++i) {
auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]);
if (castCost == UINT32_MAX) {
if (castCost == UNDEFINED_CAST_COST) {
return UINT32_MAX;
}
cost += castCost;
Expand All @@ -272,7 +278,7 @@ uint32_t BuiltInVectorOperations::matchVarLengthParameters(
auto cost = 0u;
for (auto& inputType : inputTypes) {
auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID);
if (castCost == UINT32_MAX) {
if (castCost == UNDEFINED_CAST_COST) {
return UINT32_MAX;
}
cost += castCost;
Expand Down
3 changes: 3 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class Binder {
std::unique_ptr<BoundSingleQuery> bindSingleQuery(const parser::SingleQuery& singleQuery);
std::unique_ptr<BoundQueryPart> bindQueryPart(const parser::QueryPart& queryPart);

/*** bind call ***/
std::unique_ptr<BoundStatement> bindCallClause(const parser::Statement& statement);

/*** bind reading clause ***/
std::unique_ptr<BoundReadingClause> bindReadingClause(
const parser::ReadingClause& readingClause);
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BoundStatementVisitor {
virtual void visitDropProperty(const BoundStatement& statement) {}
virtual void visitRenameProperty(const BoundStatement& statement) {}
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitCall(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
Expand Down
25 changes: 25 additions & 0 deletions src/include/binder/call/bound_call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "binder/expression/expression.h"
#include "main/db_config.h"

namespace kuzu {
namespace binder {

class BoundCall : public BoundStatement {
public:
BoundCall(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL, BoundStatementResult::createEmptyResult()},
option{option}, optionValue{optionValue} {}

inline main::ConfigurationOption getOption() const { return option; }

inline std::shared_ptr<Expression> getOptionValue() const { return optionValue; }

private:
main::ConfigurationOption option;
std::shared_ptr<Expression> optionValue;
};

} // namespace binder
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/common/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum class StatementType : uint8_t {
DROP_PROPERTY = 7,
RENAME_PROPERTY = 8,
COPY = 20,
CALL = 21,
};

class StatementTypeUtils {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ using struct_field_idx_t = uint64_t;
using union_field_idx_t = uint64_t;
constexpr struct_field_idx_t INVALID_STRUCT_FIELD_IDX = UINT64_MAX;
using tuple_idx_t = uint64_t;
constexpr uint32_t UNDEFINED_CAST_COST = UINT32_MAX;

// System representation for a variable-sized overflow value.
struct overflow_value_t {
Expand Down
7 changes: 5 additions & 2 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ namespace main {

struct ActiveQuery {
explicit ActiveQuery();

std::atomic<bool> interrupted;
common::Timer timer;
};
Expand All @@ -26,6 +25,8 @@ class ClientContext {
friend class Connection;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;
friend class ThreadsSetting;
friend class TimeoutSetting;

public:
explicit ClientContext();
Expand All @@ -36,7 +37,9 @@ class ClientContext {

bool isInterrupted() const { return activeQuery->interrupted; }

inline bool isTimeOut() { return activeQuery->timer.getElapsedTimeInMS() > timeoutInMS; }
inline bool isTimeOut() {
return isTimeOutEnabled() && activeQuery->timer.getElapsedTimeInMS() > timeoutInMS;
}

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }

Expand Down
6 changes: 6 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ class Connection {
*/
KUZU_API void setQueryTimeOut(uint64_t timeoutInMS);

/**
* @brief gets the query timeout value of the current connection. A value of zero (the default)
* disables the timeout.
*/
KUZU_API uint64_t getQueryTimeOut();

protected:
ConnectionTransactionMode getTransactionMode();
void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode);
Expand Down
23 changes: 23 additions & 0 deletions src/include/main/db_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/types/types.h"
#include "main/client_context.h"

namespace kuzu {
namespace main {

typedef void (*set_context)(ClientContext* context, const common::Value& parameter);

struct ConfigurationOption {
const char* name;
common::LogicalTypeID parameterType;
set_context setContext;
};

class DBConfig {
public:
static ConfigurationOption* getOptionByName(const std::string& optionName);
};

} // namespace main
} // namespace kuzu
29 changes: 29 additions & 0 deletions src/include/main/settings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include "common/types/value.h"
#include "main/client_context.h"

namespace kuzu {
namespace main {

struct ThreadsSetting {
static constexpr const char* name = "threads";
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
static void setContext(ClientContext* context, const common::Value& parameter) {
assert(parameter.getDataType().getLogicalTypeID() == common::LogicalTypeID::INT64);
context->numThreadsForExecution = parameter.getValue<int64_t>();
}
};

struct TimeoutSetting {
static constexpr const char* name = "timeout";
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
static void setContext(ClientContext* context, const common::Value& parameter) {
assert(parameter.getDataType().getLogicalTypeID() == common::LogicalTypeID::INT64);
context->timeoutInMS = parameter.getValue<int64_t>();
context->startTimingIfEnabled();
}
};

} // namespace main
} // namespace kuzu
25 changes: 25 additions & 0 deletions src/include/parser/call/call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "parser/expression/parsed_expression.h"
#include "parser/statement.h"

namespace kuzu {
namespace parser {

class Call : public Statement {
public:
explicit Call(std::string optionName, std::unique_ptr<ParsedExpression> optionValue)
: Statement{common::StatementType::CALL}, optionName{std::move(optionName)},
optionValue{std::move(optionValue)} {}

inline std::string getOptionName() const { return optionName; }

inline ParsedExpression* getOptionValue() const { return optionValue.get(); }

private:
std::string optionName;
std::unique_ptr<ParsedExpression> optionValue;
};

} // namespace parser
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class Transformer {
std::unique_ptr<Statement> transform();

private:
std::unique_ptr<Statement> transformOcStatement(CypherParser::OC_StatementContext& ctx);

std::unique_ptr<RegularQuery> transformQuery(CypherParser::OC_QueryContext& ctx);

std::unique_ptr<RegularQuery> transformRegularQuery(CypherParser::OC_RegularQueryContext& ctx);
Expand Down Expand Up @@ -246,6 +248,8 @@ class Transformer {

std::unique_ptr<Statement> transformCopyNPY(CypherParser::KU_CopyNPYContext& ctx);

std::unique_ptr<Statement> transformCall(CypherParser::KU_CallContext& ctx);

std::vector<std::string> transformFilePaths(
std::vector<antlr4::tree::TerminalNode*> stringLiteral);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ enum class LogicalOperatorType : uint8_t {
ACCUMULATE,
ADD_PROPERTY,
AGGREGATE,
CALL,
COPY,
CREATE_NODE,
CREATE_REL,
Expand Down
Loading

0 comments on commit 2c21a5e

Please sign in to comment.