Skip to content

Commit

Permalink
Add primary key information to show_connection (#3372)
Browse files Browse the repository at this point in the history
* Add node primary key to show_connection

* Add type check for in query call functions

* Enable literal expression for stand alone call option value
  • Loading branch information
manh9203 committed Apr 26, 2024
1 parent f61108a commit d9dcc0a
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 43 deletions.
2 changes: 1 addition & 1 deletion scripts/antlr4/Cypher.g4.copy
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ USE:
( 'U' | 'u') ( 'S' | 's') ( 'E' | 'e');

kU_StandaloneCall
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
: CALL SP oC_SymbolicName SP? '=' SP? oC_Expression ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ USE:
( 'U' | 'u') ( 'S' | 's') ( 'E' | 'e');

kU_StandaloneCall
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
: CALL SP oC_SymbolicName SP? '=' SP? oC_Expression ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down
11 changes: 8 additions & 3 deletions src/binder/bind/bind_standalone_call.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/binder.h"
#include "binder/bound_standalone_call.h"
#include "binder/expression/expression_util.h"
#include "binder/expression_visitor.h"
#include "common/exception/binder.h"
#include "extension/extension.h"
#include "main/db_config.h"
Expand All @@ -22,9 +23,13 @@ std::unique_ptr<BoundStatement> Binder::bindStandaloneCall(const parser::Stateme
if (option == nullptr) {
throw BinderException{"Invalid option name: " + callStatement.getOptionName() + "."};
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionUtil::validateDataType(*optionValue, option->parameterType);
auto optionValue = expressionBinder.bindExpression(*callStatement.getOptionValue());
ExpressionUtil::validateExpressionType(*optionValue, ExpressionType::LITERAL);
optionValue =
expressionBinder.implicitCastIfNecessary(optionValue, LogicalType(option->parameterType));
if (ExpressionVisitor::needFold(*optionValue)) {
optionValue = expressionBinder.foldExpression(optionValue);
}
return std::make_unique<BoundStandaloneCall>(option, std::move(optionValue));
}

Expand Down
19 changes: 11 additions & 8 deletions src/binder/bind/read/bind_in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
auto expr = call.getFunctionExpression();
auto functionExpr = expr->constPtrCast<ParsedFunctionExpression>();
auto functionName = functionExpr->getFunctionName();
expression_vector params;
expression_vector children;
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
auto child = functionExpr->getChild(i);
params.push_back(expressionBinder.bindExpression(*child));
auto childExpr = functionExpr->getChild(i);
children.push_back(expressionBinder.bindExpression(*childExpr));
}
TableFunction tableFunction;
std::vector<Value> inputValues;
std::vector<LogicalType> inputTypes;
for (auto& param : params) {
ExpressionUtil::validateExpressionType(*param, ExpressionType::LITERAL);
auto literalExpr = param->constPtrCast<LiteralExpression>();
for (auto& child : children) {
ExpressionUtil::validateExpressionType(*child, ExpressionType::LITERAL);
auto literalExpr = child->constPtrCast<LiteralExpression>();
inputTypes.push_back(literalExpr->getDataType());
inputValues.push_back(*literalExpr->getValue());
}
Expand All @@ -40,9 +40,12 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
if (functionEntry->getType() != CatalogEntryType::TABLE_FUNCTION_ENTRY) {
throw BinderException(stringFormat("{} is not a table function.", functionName));
}
auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(), inputTypes,
functionEntry);
auto func = BuiltInFunctionsUtils::matchFunction(functionName, inputTypes, functionEntry);
tableFunction = *func->constPtrCast<TableFunction>();
for (auto i = 0u; i < children.size(); ++i) {
auto parameterTypeID = tableFunction.parameterTypeIDs[i];
ExpressionUtil::validateDataType(*children[i], parameterTypeID);
}
auto bindInput = function::TableFuncBindInput();
bindInput.inputs = std::move(inputValues);
auto bindData = tableFunction.bindFunc(clientContext, &bindInput);
Expand Down
3 changes: 2 additions & 1 deletion src/binder/expression_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ bool ExpressionVisitor::isConstant(const Expression& expression) {
if (expression.expressionType == ExpressionType::AGGREGATE_FUNCTION) {
return false; // We don't have a framework to fold aggregated constant.
}
if (expression.getNumChildren() == 0) {
if (expression.getNumChildren() == 0 &&
expression.expressionType != ExpressionType::CASE_ELSE) {
return expression.expressionType == ExpressionType::LITERAL;
}
for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) {
Expand Down
44 changes: 29 additions & 15 deletions src/function/table/call/show_connection.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "common/exception/binder.h"
Expand Down Expand Up @@ -28,23 +29,37 @@ struct ShowConnectionBindData : public CallTableFuncBindData {
}
};

static void outputRelTableConnection(ValueVector* srcTableNameVector,
ValueVector* dstTableNameVector, uint64_t outputPos, ClientContext* context,
table_id_t tableID) {
static void outputRelTableConnection(DataChunk& outputDataChunk, uint64_t outputPos,
ClientContext* context, table_id_t tableID) {
auto catalog = context->getCatalog();
auto tableEntry = catalog->getTableCatalogEntry(context->getTx(), tableID);
auto relTableEntry = ku_dynamic_cast<TableCatalogEntry*, RelTableCatalogEntry*>(tableEntry);
KU_ASSERT(tableEntry->getTableType() == TableType::REL);
// Get src and dst name
auto srcTableID = relTableEntry->getSrcTableID();
auto dstTableID = relTableEntry->getDstTableID();
srcTableNameVector->setValue(outputPos, catalog->getTableName(context->getTx(), srcTableID));
dstTableNameVector->setValue(outputPos, catalog->getTableName(context->getTx(), dstTableID));
auto srcTableName = catalog->getTableName(context->getTx(), srcTableID);
auto dstTableName = catalog->getTableName(context->getTx(), dstTableID);
// Get src and dst primary key
auto srcTableEntry = catalog->getTableCatalogEntry(context->getTx(), srcTableID);
auto dstTableEntry = catalog->getTableCatalogEntry(context->getTx(), dstTableID);
auto srcTablePrimaryKey =
ku_dynamic_cast<TableCatalogEntry*, NodeTableCatalogEntry*>(srcTableEntry)
->getPrimaryKey()
->getName();
auto dstTablePrimaryKey =
ku_dynamic_cast<TableCatalogEntry*, NodeTableCatalogEntry*>(dstTableEntry)
->getPrimaryKey()
->getName();
// Write result to dataChunk
outputDataChunk.getValueVector(0)->setValue(outputPos, srcTableName);
outputDataChunk.getValueVector(1)->setValue(outputPos, dstTableName);
outputDataChunk.getValueVector(2)->setValue(outputPos, srcTablePrimaryKey);
outputDataChunk.getValueVector(3)->setValue(outputPos, dstTablePrimaryKey);
}

static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {
auto& dataChunk = output.dataChunk;
auto srcVector = dataChunk.getValueVector(0).get();
auto dstVector = dataChunk.getValueVector(1).get();
auto morsel = input.sharedState->ptrCast<CallFuncSharedState>()->getMorsel();
if (!morsel.hasMoreToOutput()) {
return 0;
Expand All @@ -55,15 +70,14 @@ static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output
auto vectorPos = 0u;
switch (tableEntry->getTableType()) {
case TableType::REL: {
outputRelTableConnection(srcVector, dstVector, vectorPos, bindData->context,
tableEntry->getTableID());
outputRelTableConnection(dataChunk, vectorPos, bindData->context, tableEntry->getTableID());
vectorPos++;
} break;
case TableType::REL_GROUP: {
auto relGroupEntry = ku_dynamic_cast<TableCatalogEntry*, RelGroupCatalogEntry*>(tableEntry);
auto relTableIDs = relGroupEntry->getRelTableIDs();
for (; vectorPos < numRelationsToOutput; vectorPos++) {
outputRelTableConnection(srcVector, dstVector, vectorPos, bindData->context,
outputRelTableConnection(dataChunk, vectorPos, bindData->context,
relTableIDs[morsel.startOffset + vectorPos]);
}
} break;
Expand All @@ -79,10 +93,6 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
TableFuncBindInput* input) {
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
// Special case here Due to any -> string, but lack implicit cast
if (input->inputs[0].getDataType()->getLogicalTypeID() != LogicalTypeID::STRING) {
throw BinderException{"Show connection can only bind to String!"};
}
auto tableName = input->inputs[0].getValue<std::string>();
auto catalog = context->getCatalog();
auto tableID = catalog->getTableID(context->getTx(), tableName);
Expand All @@ -95,6 +105,10 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,
columnTypes.emplace_back(*LogicalType::STRING());
columnNames.emplace_back("destination table name");
columnTypes.emplace_back(*LogicalType::STRING());
columnNames.emplace_back("source table primary key");
columnTypes.emplace_back(*LogicalType::STRING());
columnNames.emplace_back("destination table primary key");
columnTypes.emplace_back(*LogicalType::STRING());
common::offset_t maxOffset = 1;
if (tableEntry->getTableType() == common::TableType::REL_GROUP) {
auto relGroupEntry = ku_dynamic_cast<TableCatalogEntry*, RelGroupCatalogEntry*>(tableEntry);
Expand All @@ -106,7 +120,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* context,

function_set ShowConnectionFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>("db_version", tableFunc, bindFunc,
functionSet.push_back(std::make_unique<TableFunction>(name, tableFunc, bindFunc,
initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}
Expand Down
2 changes: 1 addition & 1 deletion src/parser/transform/transform_standalone_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace parser {
std::unique_ptr<Statement> Transformer::transformStandaloneCall(
CypherParser::KU_StandaloneCallContext& ctx) {
auto optionName = transformSymbolicName(*ctx.oC_SymbolicName());
auto parameter = transformLiteral(*ctx.oC_Literal());
auto parameter = transformExpression(*ctx.oC_Expression());
return std::make_unique<StandaloneCall>(std::move(optionName), std::move(parameter));
}

Expand Down
4 changes: 2 additions & 2 deletions src/processor/map/map_standalone_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace processor {

std::unique_ptr<PhysicalOperator> PlanMapper::mapStandaloneCall(
planner::LogicalOperator* logicalOperator) {
auto logicalStandaloneCall = reinterpret_cast<LogicalStandaloneCall*>(logicalOperator);
auto logicalStandaloneCall = logicalOperator->constPtrCast<LogicalStandaloneCall>();
auto optionValue =
reinterpret_cast<binder::LiteralExpression*>(logicalStandaloneCall->getOptionValue().get());
logicalStandaloneCall->getOptionValue()->constPtrCast<binder::LiteralExpression>();
auto standaloneCallInfo = std::make_unique<StandaloneCallInfo>(
logicalStandaloneCall->getOption(), *optionValue->getValue());
return std::make_unique<StandaloneCall>(std::move(standaloneCallInfo),
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/exceptions/binder/binder_error.test
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ Binder exception: Invalid option name: thread.
-LOG InvalidCallOptionValue
-STATEMENT CALL threads='abc'
---- error
Binder exception: abc has data type STRING but INT64 was expected.
Binder exception: Expression abc has data type STRING but expected INT64. Implicit cast is not supported.

-LOG AllShortestPathInvalidLowerBound
-STATEMENT MATCH p = (a)-[* ALL SHORTEST 2..3]-(b) RETURN p
Expand Down
20 changes: 15 additions & 5 deletions test/test_files/tinysnb/call/call.test
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
-STATEMENT CALL current_setting('timeout') RETURN *
---- 1
20000
-STATEMENT CALL timeout=(1+2+3)*10000
---- ok
-STATEMENT CALL current_setting('timeout') RETURN *
---- 1
60000

-LOG SetGetVarLengthMaxDepth
-STATEMENT CALL var_length_extend_max_depth=10
Expand All @@ -53,6 +58,11 @@ True
-STATEMENT CALL current_setting('progress_bar') RETURN *
---- 1
False
-STATEMENT CALL progress_bar=CASE WHEN 1<2 THEN True ELSE False END
---- ok
-STATEMENT CALL current_setting('progress_bar') RETURN *
---- 1
True

-LOG SetGetProgressBarTime
-STATEMENT CALL progress_bar_time=4000
Expand Down Expand Up @@ -184,16 +194,16 @@ ${KUZU_VERSION}
-LOG ReturnTableConnection
-STATEMENT CALL show_connection('knows') RETURN *
---- 1
person|person
person|person|ID|ID
-STATEMENT CALL show_connection('workAt') RETURN *
---- 1
person|organisation
person|organisation|ID|ID
-STATEMENT CREATE REL TABLE GROUP Knows1 (FROM person To person, FROM person to organisation, year INT64);
---- ok
-STATEMENT CALL show_connection('Knows1') RETURN *
---- 2
person|person
person|organisation
person|person|ID|ID
person|organisation|ID|ID
-STATEMENT CALL show_connection('person') RETURN *
---- error
Binder exception: Show connection can only be called on a rel table!
Expand All @@ -207,7 +217,7 @@ Binder exception: Cannot match a built-in function for given function table_info
-LOG WrongParameterType
-STATEMENT CALL show_connection(123) RETURN *
---- error
Binder exception: Show connection can only bind to String!
Binder exception: 123 has data type INT64 but STRING was expected.

-LOG WrongParameterExprType
-STATEMENT CALL show_connection(upper("person")) RETURN *
Expand Down
8 changes: 4 additions & 4 deletions third_party/antlr4_cypher/cypher_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ void cypherParserInitialize() {
0,519,520,5,49,0,0,520,521,5,150,0,0,521,523,3,286,143,0,522,524,5,150,
0,0,523,522,1,0,0,0,523,524,1,0,0,0,524,525,1,0,0,0,525,527,5,6,0,0,526,
528,5,150,0,0,527,526,1,0,0,0,527,528,1,0,0,0,528,529,1,0,0,0,529,530,
3,238,119,0,530,27,1,0,0,0,531,532,5,50,0,0,532,533,5,150,0,0,533,534,
3,190,95,0,530,27,1,0,0,0,531,532,5,50,0,0,532,533,5,150,0,0,533,534,
5,94,0,0,534,535,5,150,0,0,535,536,5,60,0,0,536,537,5,150,0,0,537,538,
3,284,142,0,538,539,5,150,0,0,539,540,5,125,0,0,540,541,5,150,0,0,541,
542,5,136,0,0,542,29,1,0,0,0,543,544,5,92,0,0,544,545,5,150,0,0,545,546,
Expand Down Expand Up @@ -2612,8 +2612,8 @@ CypherParser::OC_SymbolicNameContext* CypherParser::KU_StandaloneCallContext::oC
return getRuleContext<CypherParser::OC_SymbolicNameContext>(0);
}

CypherParser::OC_LiteralContext* CypherParser::KU_StandaloneCallContext::oC_Literal() {
return getRuleContext<CypherParser::OC_LiteralContext>(0);
CypherParser::OC_ExpressionContext* CypherParser::KU_StandaloneCallContext::oC_Expression() {
return getRuleContext<CypherParser::OC_ExpressionContext>(0);
}


Expand Down Expand Up @@ -2661,7 +2661,7 @@ CypherParser::KU_StandaloneCallContext* CypherParser::kU_StandaloneCall() {
match(CypherParser::SP);
}
setState(529);
oC_Literal();
oC_Expression();

}
catch (RecognitionException &e) {
Expand Down
2 changes: 1 addition & 1 deletion third_party/antlr4_cypher/include/cypher_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ class CypherParser : public antlr4::Parser {
std::vector<antlr4::tree::TerminalNode *> SP();
antlr4::tree::TerminalNode* SP(size_t i);
OC_SymbolicNameContext *oC_SymbolicName();
OC_LiteralContext *oC_Literal();
OC_ExpressionContext *oC_Expression();


};
Expand Down

0 comments on commit d9dcc0a

Please sign in to comment.