Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement call statement #1717

Merged
merged 1 commit into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ grammar Cypher;
}

oC_Cypher
: SP ? oC_AnyCypherOption? SP? ( oC_Statement | kU_DDL | kU_CopyNPY | kU_CopyCSV ) ( SP? ';' )? SP? EOF ;
: SP ? oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )? SP? EOF ;

kU_CopyCSV
: COPY SP oC_SchemaName SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? ;

kU_CopyNPY
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;

kU_Call
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
acquamarin marked this conversation as resolved.
Show resolved Hide resolved

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

kU_FilePaths
: '[' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ']'
| StringLiteral
Expand Down Expand Up @@ -131,7 +136,11 @@ oC_Profile
PROFILE : ( 'P' | 'p' ) ( 'R' | 'r' ) ( 'O' | 'o' ) ( 'F' | 'f' ) ( 'I' | 'i' ) ( 'L' | 'l' ) ( 'E' | 'e' ) ;

oC_Statement
: oC_Query ;
: oC_Query
| kU_DDL
| kU_CopyNPY
| kU_CopyCSV
| kU_Call ;

oC_Query
: oC_RegularQuery ;
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(
kuzu_binder_bind
OBJECT
bind_call.cpp
bind_copy.cpp
bind_ddl.cpp
bind_graph_pattern.cpp
Expand Down
24 changes: 24 additions & 0 deletions src/binder/bind/bind_call.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "binder/binder.h"
#include "binder/call/bound_call.h"
#include "common/string_utils.h"
#include "parser/call/call.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCallClause(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::Call&>(statement);
auto option = main::DBConfig::getOptionByName(callStatement.getOptionName());
if (option == nullptr) {
throw common::BinderException{
"Invalid option name: " + callStatement.getOptionName() + "."};
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
auto boundCall = std::make_unique<BoundCall>(*option, std::move(optionValue));
return boundCall;
}

} // namespace binder
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::QUERY: {
return bindQuery((const RegularQuery&)statement);
}
case StatementType::CALL: {
return bindCallClause(statement);
}
default:
assert(false);
}
Expand Down
3 changes: 3 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::COPY: {
visitCopy(statement);
} break;
case StatementType::CALL: {
visitCall(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
}
Expand Down
28 changes: 17 additions & 11 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ uint32_t BuiltInVectorOperations::getCastCost(
case common::LogicalTypeID::SERIAL:
return castSerial(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}
}
Expand All @@ -121,7 +121,13 @@ uint32_t BuiltInVectorOperations::getCastCost(
switch (inputType.getLogicalTypeID()) {
case common::LogicalTypeID::FIXED_LIST:
case common::LogicalTypeID::VAR_LIST:
return UINT32_MAX;
case common::LogicalTypeID::MAP:
case common::LogicalTypeID::UNION:
case common::LogicalTypeID::STRUCT:
case common::LogicalTypeID::INTERNAL_ID:
// TODO(Ziyi): add boolean cast operations.
case common::LogicalTypeID::BOOL:
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
return UNDEFINED_CAST_COST;
default:
return getCastCost(inputType.getLogicalTypeID(), targetType.getLogicalTypeID());
}
Expand Down Expand Up @@ -157,7 +163,7 @@ uint32_t BuiltInVectorOperations::castInt64(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -168,7 +174,7 @@ uint32_t BuiltInVectorOperations::castInt32(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -180,14 +186,14 @@ uint32_t BuiltInVectorOperations::castInt16(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

uint32_t BuiltInVectorOperations::castDouble(common::LogicalTypeID targetTypeID) {
switch (targetTypeID) {
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -196,7 +202,7 @@ uint32_t BuiltInVectorOperations::castFloat(common::LogicalTypeID targetTypeID)
case common::LogicalTypeID::DOUBLE:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -205,7 +211,7 @@ uint32_t BuiltInVectorOperations::castDate(common::LogicalTypeID targetTypeID) {
case common::LogicalTypeID::TIMESTAMP:
return getTargetTypeCost(targetTypeID);
default:
return UINT32_MAX;
return UNDEFINED_CAST_COST;
}
}

Expand All @@ -224,7 +230,7 @@ VectorOperationDefinition* BuiltInVectorOperations::getBestMatch(
std::vector<VectorOperationDefinition*>& functions) {
assert(functions.size() > 1);
VectorOperationDefinition* result = nullptr;
auto cost = UINT32_MAX;
auto cost = UNDEFINED_CAST_COST;
for (auto& function : functions) {
std::unordered_set<LogicalTypeID> distinctParameterTypes;
for (auto& parameterTypeID : function->parameterTypeIDs) {
Expand Down Expand Up @@ -259,7 +265,7 @@ uint32_t BuiltInVectorOperations::matchParameters(const std::vector<LogicalType>
auto cost = 0u;
for (auto i = 0u; i < inputTypes.size(); ++i) {
auto castCost = getCastCost(inputTypes[i].getLogicalTypeID(), targetTypeIDs[i]);
if (castCost == UINT32_MAX) {
if (castCost == UNDEFINED_CAST_COST) {
return UINT32_MAX;
}
cost += castCost;
Expand All @@ -272,7 +278,7 @@ uint32_t BuiltInVectorOperations::matchVarLengthParameters(
auto cost = 0u;
for (auto& inputType : inputTypes) {
auto castCost = getCastCost(inputType.getLogicalTypeID(), targetTypeID);
if (castCost == UINT32_MAX) {
if (castCost == UNDEFINED_CAST_COST) {
return UINT32_MAX;
}
cost += castCost;
Expand Down
3 changes: 3 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ class Binder {
std::unique_ptr<BoundSingleQuery> bindSingleQuery(const parser::SingleQuery& singleQuery);
std::unique_ptr<BoundQueryPart> bindQueryPart(const parser::QueryPart& queryPart);

/*** bind call ***/
std::unique_ptr<BoundStatement> bindCallClause(const parser::Statement& statement);

/*** bind reading clause ***/
std::unique_ptr<BoundReadingClause> bindReadingClause(
const parser::ReadingClause& readingClause);
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class BoundStatementVisitor {
virtual void visitDropProperty(const BoundStatement& statement) {}
virtual void visitRenameProperty(const BoundStatement& statement) {}
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitCall(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
Expand Down
25 changes: 25 additions & 0 deletions src/include/binder/call/bound_call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "binder/expression/expression.h"
#include "main/db_config.h"

namespace kuzu {
namespace binder {

class BoundCall : public BoundStatement {
public:
BoundCall(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL, BoundStatementResult::createEmptyResult()},
option{option}, optionValue{optionValue} {}

inline main::ConfigurationOption getOption() const { return option; }

inline std::shared_ptr<Expression> getOptionValue() const { return optionValue; }

private:
main::ConfigurationOption option;
std::shared_ptr<Expression> optionValue;
};

} // namespace binder
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/common/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum class StatementType : uint8_t {
DROP_PROPERTY = 7,
RENAME_PROPERTY = 8,
COPY = 20,
CALL = 21,
};

class StatementTypeUtils {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ using struct_field_idx_t = uint64_t;
using union_field_idx_t = uint64_t;
constexpr struct_field_idx_t INVALID_STRUCT_FIELD_IDX = UINT64_MAX;
using tuple_idx_t = uint64_t;
constexpr uint32_t UNDEFINED_CAST_COST = UINT32_MAX;

// System representation for a variable-sized overflow value.
struct overflow_value_t {
Expand Down
7 changes: 5 additions & 2 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ namespace main {

struct ActiveQuery {
explicit ActiveQuery();

std::atomic<bool> interrupted;
common::Timer timer;
};
Expand All @@ -26,6 +25,8 @@ class ClientContext {
friend class Connection;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;
friend class ThreadsSetting;
friend class TimeoutSetting;

public:
explicit ClientContext();
Expand All @@ -36,7 +37,9 @@ class ClientContext {

bool isInterrupted() const { return activeQuery->interrupted; }

inline bool isTimeOut() { return activeQuery->timer.getElapsedTimeInMS() > timeoutInMS; }
inline bool isTimeOut() {
return isTimeOutEnabled() && activeQuery->timer.getElapsedTimeInMS() > timeoutInMS;
}

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }

Expand Down
6 changes: 6 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ class Connection {
*/
KUZU_API void setQueryTimeOut(uint64_t timeoutInMS);

/**
* @brief gets the query timeout value of the current connection. A value of zero (the default)
* disables the timeout.
*/
KUZU_API uint64_t getQueryTimeOut();

protected:
ConnectionTransactionMode getTransactionMode();
void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode);
Expand Down
23 changes: 23 additions & 0 deletions src/include/main/db_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/types/types.h"
#include "main/client_context.h"

namespace kuzu {
namespace main {

typedef void (*set_context)(ClientContext* context, const common::Value& parameter);

struct ConfigurationOption {
const char* name;
common::LogicalTypeID parameterType;
set_context setContext;
};

class DBConfig {
public:
static ConfigurationOption* getOptionByName(const std::string& optionName);
};

} // namespace main
} // namespace kuzu
29 changes: 29 additions & 0 deletions src/include/main/settings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include "common/types/value.h"
#include "main/client_context.h"

namespace kuzu {
namespace main {

struct ThreadsSetting {
static constexpr const char* name = "threads";
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
static void setContext(ClientContext* context, const common::Value& parameter) {
assert(parameter.getDataType().getLogicalTypeID() == common::LogicalTypeID::INT64);
context->numThreadsForExecution = parameter.getValue<int64_t>();
}
};

struct TimeoutSetting {
static constexpr const char* name = "timeout";
static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64;
static void setContext(ClientContext* context, const common::Value& parameter) {
assert(parameter.getDataType().getLogicalTypeID() == common::LogicalTypeID::INT64);
context->timeoutInMS = parameter.getValue<int64_t>();
context->startTimingIfEnabled();
}
};

} // namespace main
} // namespace kuzu
25 changes: 25 additions & 0 deletions src/include/parser/call/call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "parser/expression/parsed_expression.h"
#include "parser/statement.h"

namespace kuzu {
namespace parser {

class Call : public Statement {
public:
explicit Call(std::string optionName, std::unique_ptr<ParsedExpression> optionValue)
: Statement{common::StatementType::CALL}, optionName{std::move(optionName)},
optionValue{std::move(optionValue)} {}

inline std::string getOptionName() const { return optionName; }

inline ParsedExpression* getOptionValue() const { return optionValue.get(); }

private:
std::string optionName;
std::unique_ptr<ParsedExpression> optionValue;
};

} // namespace parser
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class Transformer {
std::unique_ptr<Statement> transform();

private:
std::unique_ptr<Statement> transformOcStatement(CypherParser::OC_StatementContext& ctx);

std::unique_ptr<RegularQuery> transformQuery(CypherParser::OC_QueryContext& ctx);

std::unique_ptr<RegularQuery> transformRegularQuery(CypherParser::OC_RegularQueryContext& ctx);
Expand Down Expand Up @@ -246,6 +248,8 @@ class Transformer {

std::unique_ptr<Statement> transformCopyNPY(CypherParser::KU_CopyNPYContext& ctx);

std::unique_ptr<Statement> transformCall(CypherParser::KU_CallContext& ctx);

std::vector<std::string> transformFilePaths(
std::vector<antlr4::tree::TerminalNode*> stringLiteral);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ enum class LogicalOperatorType : uint8_t {
ACCUMULATE,
ADD_PROPERTY,
AGGREGATE,
CALL,
COPY,
CREATE_NODE,
CREATE_REL,
Expand Down
Loading
Loading