Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add primary key information to show_connection #3372

Merged
merged 5 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/antlr4/Cypher.g4.copy
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ USE:
( 'U' | 'u') ( 'S' | 's') ( 'E' | 'e');

kU_StandaloneCall
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
: CALL SP oC_SymbolicName SP? '=' SP? oC_Expression ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ USE:
( 'U' | 'u') ( 'S' | 's') ( 'E' | 'e');

kU_StandaloneCall
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
: CALL SP oC_SymbolicName SP? '=' SP? oC_Expression ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down
11 changes: 8 additions & 3 deletions src/binder/bind/bind_standalone_call.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/binder.h"
#include "binder/bound_standalone_call.h"
#include "binder/expression/expression_util.h"
#include "binder/expression_visitor.h"
#include "common/exception/binder.h"
#include "extension/extension.h"
#include "main/db_config.h"
Expand All @@ -22,9 +23,13 @@ std::unique_ptr<BoundStatement> Binder::bindStandaloneCall(const parser::Stateme
if (option == nullptr) {
throw BinderException{"Invalid option name: " + callStatement.getOptionName() + "."};
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionUtil::validateDataType(*optionValue, option->parameterType);
auto optionValue = expressionBinder.bindExpression(*callStatement.getOptionValue());
ExpressionUtil::validateExpressionType(*optionValue, ExpressionType::LITERAL);
optionValue =
expressionBinder.implicitCastIfNecessary(optionValue, LogicalType(option->parameterType));
if (ExpressionVisitor::needFold(*optionValue)) {
manh9203 marked this conversation as resolved.
Show resolved Hide resolved
optionValue = expressionBinder.foldExpression(optionValue);
}
return std::make_unique<BoundStandaloneCall>(option, std::move(optionValue));
}

Expand Down
19 changes: 11 additions & 8 deletions src/binder/bind/read/bind_in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
auto expr = call.getFunctionExpression();
auto functionExpr = expr->constPtrCast<ParsedFunctionExpression>();
auto functionName = functionExpr->getFunctionName();
expression_vector params;
expression_vector children;
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
auto child = functionExpr->getChild(i);
params.push_back(expressionBinder.bindExpression(*child));
auto childExpr = functionExpr->getChild(i);
children.push_back(expressionBinder.bindExpression(*childExpr));
}
TableFunction tableFunction;
std::vector<Value> inputValues;
std::vector<LogicalType> inputTypes;
for (auto& param : params) {
ExpressionUtil::validateExpressionType(*param, ExpressionType::LITERAL);
auto literalExpr = param->constPtrCast<LiteralExpression>();
for (auto& child : children) {
ExpressionUtil::validateExpressionType(*child, ExpressionType::LITERAL);
auto literalExpr = child->constPtrCast<LiteralExpression>();
inputTypes.push_back(literalExpr->getDataType());
inputValues.push_back(*literalExpr->getValue());
}
Expand All @@ -40,9 +40,12 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
if (functionEntry->getType() != CatalogEntryType::TABLE_FUNCTION_ENTRY) {
throw BinderException(stringFormat("{} is not a table function.", functionName));
}
auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(), inputTypes,
functionEntry);
auto func = BuiltInFunctionsUtils::matchFunction(functionName, inputTypes, functionEntry);
tableFunction = *func->constPtrCast<TableFunction>();
for (auto i = 0u; i < children.size(); ++i) {
auto parameterTypeID = tableFunction.parameterTypeIDs[i];
ExpressionUtil::validateDataType(*children[i], parameterTypeID);
}
auto bindInput = function::TableFuncBindInput();
bindInput.inputs = std::move(inputValues);
auto bindData = tableFunction.bindFunc(clientContext, &bindInput);
Expand Down
3 changes: 2 additions & 1 deletion src/binder/expression_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ bool ExpressionVisitor::isConstant(const Expression& expression) {
if (expression.expressionType == ExpressionType::AGGREGATE_FUNCTION) {
return false; // We don't have a framework to fold aggregated constant.
}
if (expression.getNumChildren() == 0) {
if (expression.getNumChildren() == 0 &&
manh9203 marked this conversation as resolved.
Show resolved Hide resolved
expression.expressionType != ExpressionType::CASE_ELSE) {
return expression.expressionType == ExpressionType::LITERAL;
}
for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) {
Expand Down
44 changes: 29 additions & 15 deletions src/function/table/call/show_connection.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "common/exception/binder.h"
Expand Down Expand Up @@ -28,23 +29,37 @@ struct ShowConnectionBindData : public CallTableFuncBindData {
}
};

static void outputRelTableConnection(ValueVector* srcTableNameVector,
ValueVector* dstTableNameVector, uint64_t outputPos, ClientContext* context,
table_id_t tableID) {
static void outputRelTableConnection(DataChunk& outputDataChunk, uint64_t outputPos,
ClientContext* context, table_id_t tableID) {
auto catalog = context->getCatalog();
auto tableEntry = catalog->getTableCatalogEntry(context->getTx(), tableID);
auto relTableEntry = ku_dynamic_cast<TableCatalogEntry*, RelTableCatalogEntry*>(tableEntry);
KU_ASSERT(tableEntry->getTableType() == TableType::REL);
// Get src and dst name
auto srcTableID = relTableEntry->getSrcTableID();
auto dstTableID = relTableEntry->getDstTableID();
srcTableNameVector->setValue(outputPos, catalog->getTableName(context->getTx(), srcTableID));
dstTableNameVector->setValue(outputPos, catalog->getTableName(context->getTx(), dstTableID));
auto srcTableName = catalog->getTableName(context->getTx(), srcTableID);
auto dstTableName = catalog->getTableName(context->getTx(), dstTableID);
// Get src and dst primary key
auto srcTableEntry = catalog->getTableCatalogEntry(context->getTx(), srcTableID);
auto dstTableEntry = catalog->getTableCatalogEntry(context->getTx(), dstTableID);
auto srcTablePrimaryKey =
ku_dynamic_cast<TableCatalogEntry*, NodeTableCatalogEntry*>(srcTableEntry)
->getPrimaryKey()
->getName();
auto dstTablePrimaryKey =
ku_dynamic_cast<TableCatalogEntry*, NodeTableCatalogEntry*>(dstTableEntry)
->getPrimaryKey()
->getName();
// Write result to dataChunk
outputDataChunk.getValueVector(0)->setValue(outputPos, srcTableName);
outputDataChunk.getValueVector(1)->setValue(outputPos, dstTableName);
outputDataChunk.getValueVector(2)->setValue(outputPos, srcTablePrimaryKey);
outputDataChunk.getValueVector(3)->setValue(outputPos, dstTablePrimaryKey);
}

static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {
auto& dataChunk = output.dataChunk;
auto srcVector = dataChunk.getValueVector(0).get();
auto dstVector = dataChunk.getValueVector(1).get();
auto morsel = input.sharedState->ptrCast<CallFuncSharedState>()->getMorsel();
if (!morsel.hasMoreToOutput()) {
return 0;
Expand All @@ -55,15 +70,14 @@ static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output
auto vectorPos = 0u;
switch (tableEntry->getTableType()) {
case TableType::REL: {
outputRelTableConnection(srcVector, dstVector, vectorPos, bindData->context,
tableEntry->getTableID());
outputRelTableConnection(dataChunk, vectorPos, bindData->context, tableEntry->getTableID());
vectorPos++;
} break;
case TableType::REL_GROUP: {
auto relGroupEntry = ku_dynamic_cast<TableCatalogEntry*, RelGroupCatalogEntry*>(tableEntry);
auto relTableIDs = relGroupEntry->getRelTableIDs();
for (; vectorPos < numRelationsToOutput; vectorPos++) {
outputRelTableConnection(srcVector, dstVector, vectorPos, bindData->context,
outputRelTableConnection(dataChunk, vectorPos, bindData->context,
relTableIDs[morsel.startOffset + vectorPos]);
}
} break;
Expand All @@ -79,10 +93,6 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
// Special case here Due to any -> string, but lack implicit cast
if (input->inputs[0].getDataType()->getLogicalTypeID() != LogicalTypeID::STRING) {
throw BinderException{"Show connection can only bind to String!"};
}
auto tableName = input->inputs[0].getValue<std::string>();
auto catalog = context->getCatalog();
auto tableID = catalog->getTableID(context->getTx(), tableName);
Expand All @@ -95,6 +105,10 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
columnTypes.emplace_back(*LogicalType::STRING());
columnNames.emplace_back("destination table name");
columnTypes.emplace_back(*LogicalType::STRING());
columnNames.emplace_back("source table primary key");
columnTypes.emplace_back(*LogicalType::STRING());
columnNames.emplace_back("destination table primary key");
columnTypes.emplace_back(*LogicalType::STRING());
common::offset_t maxOffset = 1;
if (tableEntry->getTableType() == common::TableType::REL_GROUP) {
auto relGroupEntry = ku_dynamic_cast<TableCatalogEntry*, RelGroupCatalogEntry*>(tableEntry);
Expand All @@ -106,7 +120,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,

function_set ShowConnectionFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>("db_version", tableFunc, bindFunc,
functionSet.push_back(std::make_unique<TableFunction>(name, tableFunc, bindFunc,
initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}
Expand Down
2 changes: 1 addition & 1 deletion src/parser/transform/transform_standalone_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace parser {
std::unique_ptr<Statement> Transformer::transformStandaloneCall(
CypherParser::KU_StandaloneCallContext& ctx) {
auto optionName = transformSymbolicName(*ctx.oC_SymbolicName());
auto parameter = transformLiteral(*ctx.oC_Literal());
auto parameter = transformExpression(*ctx.oC_Expression());
return std::make_unique<StandaloneCall>(std::move(optionName), std::move(parameter));
}

Expand Down
4 changes: 2 additions & 2 deletions src/processor/map/map_standalone_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace processor {

std::unique_ptr<PhysicalOperator> PlanMapper::mapStandaloneCall(
planner::LogicalOperator* logicalOperator) {
auto logicalStandaloneCall = reinterpret_cast<LogicalStandaloneCall*>(logicalOperator);
auto logicalStandaloneCall = logicalOperator->constPtrCast<LogicalStandaloneCall>();
auto optionValue =
reinterpret_cast<binder::LiteralExpression*>(logicalStandaloneCall->getOptionValue().get());
logicalStandaloneCall->getOptionValue()->constPtrCast<binder::LiteralExpression>();
auto standaloneCallInfo = std::make_unique<StandaloneCallInfo>(
logicalStandaloneCall->getOption(), *optionValue->getValue());
return std::make_unique<StandaloneCall>(std::move(standaloneCallInfo),
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/exceptions/binder/binder_error.test
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ Binder exception: Invalid option name: thread.
-LOG InvalidCallOptionValue
-STATEMENT CALL threads='abc'
---- error
Binder exception: abc has data type STRING but INT64 was expected.
Binder exception: Expression abc has data type STRING but expected INT64. Implicit cast is not supported.

-LOG AllShortestPathInvalidLowerBound
-STATEMENT MATCH p = (a)-[* ALL SHORTEST 2..3]-(b) RETURN p
Expand Down
20 changes: 15 additions & 5 deletions test/test_files/tinysnb/call/call.test
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
-STATEMENT CALL current_setting('timeout') RETURN *
---- 1
20000
-STATEMENT CALL timeout=(1+2+3)*10000
---- ok
-STATEMENT CALL current_setting('timeout') RETURN *
---- 1
60000

-LOG SetGetVarLengthMaxDepth
-STATEMENT CALL var_length_extend_max_depth=10
Expand All @@ -53,6 +58,11 @@ True
-STATEMENT CALL current_setting('progress_bar') RETURN *
---- 1
False
-STATEMENT CALL progress_bar=CASE WHEN 1<2 THEN True ELSE False END
---- ok
-STATEMENT CALL current_setting('progress_bar') RETURN *
---- 1
True

-LOG SetGetProgressBarTime
-STATEMENT CALL progress_bar_time=4000
Expand Down Expand Up @@ -184,16 +194,16 @@ ${KUZU_VERSION}
-LOG ReturnTableConnection
-STATEMENT CALL show_connection('knows') RETURN *
---- 1
person|person
person|person|ID|ID
-STATEMENT CALL show_connection('workAt') RETURN *
---- 1
person|organisation
person|organisation|ID|ID
-STATEMENT CREATE REL TABLE GROUP Knows1 (FROM person To person, FROM person to organisation, year INT64);
---- ok
-STATEMENT CALL show_connection('Knows1') RETURN *
---- 2
person|person
person|organisation
person|person|ID|ID
person|organisation|ID|ID
-STATEMENT CALL show_connection('person') RETURN *
---- error
Binder exception: Show connection can only be called on a rel table!
Expand All @@ -207,7 +217,7 @@ Binder exception: Cannot match a built-in function for given function table_info
-LOG WrongParameterType
-STATEMENT CALL show_connection(123) RETURN *
---- error
Binder exception: Show connection can only bind to String!
Binder exception: 123 has data type INT64 but STRING was expected.

-LOG WrongParameterExprType
-STATEMENT CALL show_connection(upper("person")) RETURN *
Expand Down
8 changes: 4 additions & 4 deletions third_party/antlr4_cypher/cypher_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ void cypherParserInitialize() {
0,519,520,5,49,0,0,520,521,5,150,0,0,521,523,3,286,143,0,522,524,5,150,
0,0,523,522,1,0,0,0,523,524,1,0,0,0,524,525,1,0,0,0,525,527,5,6,0,0,526,
528,5,150,0,0,527,526,1,0,0,0,527,528,1,0,0,0,528,529,1,0,0,0,529,530,
3,238,119,0,530,27,1,0,0,0,531,532,5,50,0,0,532,533,5,150,0,0,533,534,
3,190,95,0,530,27,1,0,0,0,531,532,5,50,0,0,532,533,5,150,0,0,533,534,
5,94,0,0,534,535,5,150,0,0,535,536,5,60,0,0,536,537,5,150,0,0,537,538,
3,284,142,0,538,539,5,150,0,0,539,540,5,125,0,0,540,541,5,150,0,0,541,
542,5,136,0,0,542,29,1,0,0,0,543,544,5,92,0,0,544,545,5,150,0,0,545,546,
Expand Down Expand Up @@ -2612,8 +2612,8 @@ CypherParser::OC_SymbolicNameContext* CypherParser::KU_StandaloneCallContext::oC
return getRuleContext<CypherParser::OC_SymbolicNameContext>(0);
}

CypherParser::OC_LiteralContext* CypherParser::KU_StandaloneCallContext::oC_Literal() {
return getRuleContext<CypherParser::OC_LiteralContext>(0);
CypherParser::OC_ExpressionContext* CypherParser::KU_StandaloneCallContext::oC_Expression() {
return getRuleContext<CypherParser::OC_ExpressionContext>(0);
}


Expand Down Expand Up @@ -2661,7 +2661,7 @@ CypherParser::KU_StandaloneCallContext* CypherParser::kU_StandaloneCall() {
match(CypherParser::SP);
}
setState(529);
oC_Literal();
oC_Expression();

}
catch (RecognitionException &e) {
Expand Down
2 changes: 1 addition & 1 deletion third_party/antlr4_cypher/include/cypher_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ class CypherParser : public antlr4::Parser {
std::vector<antlr4::tree::TerminalNode *> SP();
antlr4::tree::TerminalNode* SP(size_t i);
OC_SymbolicNameContext *oC_SymbolicName();
OC_LiteralContext *oC_Literal();
OC_ExpressionContext *oC_Expression();


};
Expand Down
Loading