Skip to content

Commit

Permalink
Merge pull request #1230 from kuzudb/rel-label-function
Browse files Browse the repository at this point in the history
Rel label function
  • Loading branch information
andyfengHKU committed Feb 3, 2023
2 parents 7dec459 + 6b02213 commit c3d2b8c
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 95 deletions.
1 change: 1 addition & 0 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ expression_vector Binder::rewriteRelExpression(const Expression& expression) {
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
result.push_back(expressionBinder.bindRelLabelFunction(rel));
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
Expand Down
69 changes: 48 additions & 21 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
#include "function/node/vector_node_operations.h"
#include "function/schema/vector_label_operations.h"
#include "parser/expression/parsed_function_expression.h"

namespace kuzu {
Expand Down Expand Up @@ -135,37 +135,64 @@ shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const ParsedExpression& parsedExpression) {
// bind child node
auto child = bindExpression(*parsedExpression.getChild(0));
assert(child->dataType.typeID == common::NODE);
return bindNodeLabelFunction(*child);
if (child->dataType.typeID == common::NODE) {
return bindNodeLabelFunction(*child);
} else {
assert(child->dataType.typeID == common::REL);
return bindRelLabelFunction(*child);
}
}

/**
 * Builds a label-name lookup list indexed by table ID.
 *
 * The returned vector has (max table ID + 1) entries. Entry i holds the table name
 * reported by the catalog when i is one of the given table IDs, and an empty string
 * otherwise, so that a table ID can be used directly as an index into the list.
 *
 * @param tableIDs table IDs whose names should be populated; duplicates are allowed.
 * @param catalogContent catalog used to resolve a table ID to its table name.
 * @return one Value per slot in [0, max table ID]; empty when tableIDs is empty.
 */
static vector<unique_ptr<Value>> populateLabelValues(
    const vector<table_id_t>& tableIDs, const CatalogContent& catalogContent) {
    vector<unique_ptr<Value>> labels;
    if (tableIDs.empty()) {
        // std::max_element returns the end iterator for an empty range; dereferencing
        // it is undefined behaviour, so there is nothing to populate here.
        return labels;
    }
    const auto tableIDsSet = unordered_set<table_id_t>(tableIDs.begin(), tableIDs.end());
    const table_id_t maxTableID = *std::max_element(tableIDs.begin(), tableIDs.end());
    labels.resize(maxTableID + 1);
    for (size_t i = 0; i < labels.size(); ++i) {
        const auto tableID = static_cast<table_id_t>(i);
        if (tableIDsSet.contains(tableID)) {
            labels[i] = make_unique<Value>(catalogContent.getTableName(tableID));
        } else {
            // TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
            labels[i] = make_unique<Value>(string(""));
        }
    }
    return labels;
}

shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto& node = (NodeExpression&)expression;
if (!node.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(node.getSingleTableID());
auto value = make_unique<Value>(labelName);
return createLiteralExpression(std::move(value));
return createLiteralExpression(make_unique<Value>(labelName));
}
// bind string node labels as list literal
auto nodeTableIDs = catalogContent->getNodeTableIDs();
table_id_t maxNodeTableID = *std::max_element(nodeTableIDs.begin(), nodeTableIDs.end());
vector<unique_ptr<Value>> nodeLabels;
nodeLabels.resize(maxNodeTableID + 1);
for (auto i = 0; i < nodeLabels.size(); ++i) {
if (catalogContent->containNodeTable(i)) {
nodeLabels[i] = make_unique<Value>(catalogContent->getTableName(i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
nodeLabels[i] = make_unique<Value>(string(""));
}
}
auto literalDataType = DataType(LIST, make_unique<DataType>(STRING));
expression_vector children;
children.push_back(node.getInternalIDProperty());
auto value = make_unique<Value>(literalDataType, std::move(nodeLabels));
children.push_back(createLiteralExpression(std::move(value)));
auto execFunc = NodeLabelVectorOperation::execFunction;
auto labelsValue = make_unique<Value>(DataType(LIST, make_unique<DataType>(STRING)),
populateLabelValues(nodeTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
auto execFunc = LabelVectorOperation::execFunction;
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(
FUNCTION, DataType(STRING), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::bindRelLabelFunction(const Expression& expression) {
auto catalogContent = binder->catalog.getReadOnlyVersion();
auto& rel = (RelExpression&)expression;
if (!rel.isMultiLabeled()) {
auto labelName = catalogContent->getTableName(rel.getSingleTableID());
return createLiteralExpression(make_unique<Value>(labelName));
}
auto relTableIDs = catalogContent->getRelTableIDs();
expression_vector children;
children.push_back(rel.getInternalIDProperty());
auto labelsValue = make_unique<Value>(DataType(LIST, make_unique<DataType>(STRING)),
populateLabelValues(relTableIDs, *catalogContent));
children.push_back(createLiteralExpression(std::move(labelsValue)));
auto execFunc = LabelVectorOperation::execFunction;
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(
FUNCTION, DataType(STRING), std::move(children), execFunc, nullptr, uniqueExpressionName);
Expand Down
11 changes: 8 additions & 3 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ string NodeVal::getLabelName() const {

string NodeVal::toString() const {
std::string result = "(";
result += idVal->toString();
result += ":" + labelVal->toString() + " ";
result += "label:" + labelVal->toString() + ", ";
result += idVal->toString() + ", ";
result += propertiesToString(properties);
result += ")";
return result;
Expand All @@ -284,6 +284,7 @@ string NodeVal::toString() const {
RelVal::RelVal(const RelVal& other) {
srcNodeIDVal = other.srcNodeIDVal->copy();
dstNodeIDVal = other.dstNodeIDVal->copy();
labelVal = other.labelVal->copy();
for (auto& [key, val] : other.properties) {
addProperty(key, val->copy());
}
Expand All @@ -297,10 +298,14 @@ nodeID_t RelVal::getDstNodeID() const {
return dstNodeIDVal->getValue<nodeID_t>();
}

// Returns the label name of this rel, read as a string from the value held in labelVal.
// NOTE(review): NodeVal::getLabelName is declared const while this one is not; consider
// aligning the two (requires changing the declaration in value.h as well).
string RelVal::getLabelName() {
    Value& label = *labelVal;
    return label.getValue<string>();
}

string RelVal::toString() const {
std::string result;
result += "(" + srcNodeIDVal->toString() + ")";
result += "-[" + propertiesToString(properties) + "]->";
result += "-[label:" + labelVal->toString() + ", " + propertiesToString(properties) + "]->";
result += "(" + dstNodeIDVal->toString() + ")";
return result;
}
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class ExpressionBinder {
unique_ptr<Expression> createInternalNodeIDExpression(const Expression& node);
shared_ptr<Expression> bindLabelFunction(const ParsedExpression& parsedExpression);
shared_ptr<Expression> bindNodeLabelFunction(const Expression& expression);
shared_ptr<Expression> bindRelLabelFunction(const Expression& expression);

shared_ptr<Expression> bindParameterExpression(const ParsedExpression& parsedExpression);

Expand Down
12 changes: 8 additions & 4 deletions src/include/common/types/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@ class NodeVal {

class RelVal {
public:
RelVal(unique_ptr<Value> srcNodeIDVal, unique_ptr<Value> dstNodeIDVal)
: srcNodeIDVal{std::move(srcNodeIDVal)}, dstNodeIDVal{std::move(dstNodeIDVal)} {}
RelVal(
unique_ptr<Value> srcNodeIDVal, unique_ptr<Value> dstNodeIDVal, unique_ptr<Value> labelVal)
: srcNodeIDVal{std::move(srcNodeIDVal)},
dstNodeIDVal{std::move(dstNodeIDVal)}, labelVal{std::move(labelVal)} {}
RelVal(const RelVal& other);

inline void addProperty(const std::string& key, unique_ptr<Value> value) {
Expand All @@ -146,14 +148,16 @@ class RelVal {
inline Value* getSrcNodeIDVal() { return srcNodeIDVal.get(); }
inline Value* getDstNodeIDVal() { return dstNodeIDVal.get(); }

inline unique_ptr<RelVal> copy() const { return make_unique<RelVal>(*this); }

nodeID_t getSrcNodeID() const;
nodeID_t getDstNodeID() const;
string getLabelName();

string toString() const;

inline unique_ptr<RelVal> copy() const { return make_unique<RelVal>(*this); }

private:
unique_ptr<Value> labelVal;
unique_ptr<Value> srcNodeIDVal;
unique_ptr<Value> dstNodeIDVal;
vector<pair<std::string, unique_ptr<Value>>> properties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace operation {

struct Label {
static inline void operation(
nodeID_t& left, ku_list_t& right, ku_string_t& result, ValueVector& resultVector) {
internalID_t& left, ku_list_t& right, ku_string_t& result, ValueVector& resultVector) {
assert(left.tableID < right.size);
auto& value = ((ku_string_t*)right.overflowPtr)[left.tableID];
if (!ku_string_t::isShortString(value.len)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
#pragma once

#include "function/vector_operations.h"
#include "node_operations.h"
#include "label_operations.h"

namespace kuzu {
namespace function {

struct VectorNodeOperations : public VectorOperations {};

struct NodeLabelVectorOperation : public VectorNodeOperations {
struct LabelVectorOperation : public VectorOperations {
static void execFunction(const vector<shared_ptr<ValueVector>>& params, ValueVector& result) {
assert(params.size() == 2);
BinaryOperationExecutor::executeStringAndList<nodeID_t, ku_list_t, ku_string_t,
BinaryOperationExecutor::executeStringAndList<internalID_t, ku_list_t, ku_string_t,
operation::Label>(*params[0], *params[1], result);
}
};
Expand Down
8 changes: 6 additions & 2 deletions src/main/query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,12 @@ void QueryResult::initResultTableAndIterator(
auto dstNodeIDVal =
make_unique<Value>(Value::createDefaultValue(DataType(INTERNAL_ID)));
valuesToCollect.push_back(dstNodeIDVal.get());
auto relVal = make_unique<RelVal>(std::move(srcNodeIDVal), std::move(dstNodeIDVal));
for (auto j = 2u; j < expressionsToCollect.size(); ++j) {
// third expression is rel label function.
auto labelNameVal = make_unique<Value>(Value::createDefaultValue(DataType(STRING)));
valuesToCollect.push_back(labelNameVal.get());
auto relVal = make_unique<RelVal>(
std::move(srcNodeIDVal), std::move(dstNodeIDVal), std::move(labelNameVal));
for (auto j = 3u; j < expressionsToCollect.size(); ++j) {
assert(expressionsToCollect[j]->expressionType == common::PROPERTY);
auto property = (PropertyExpression*)expressionsToCollect[j].get();
auto propertyValue =
Expand Down
8 changes: 4 additions & 4 deletions test/copy/arrow_node_copy_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using namespace kuzu::common;
using namespace kuzu::storage;
using namespace kuzu::testing;

class arrowNodeCopyTest : public DBTest {
class ArrowNodeCopyTest : public DBTest {
void SetUp() override {
BaseGraphTest::SetUp();
createDBAndConn();
Expand All @@ -15,17 +15,17 @@ class arrowNodeCopyTest : public DBTest {
}
};

TEST_F(arrowNodeCopyTest, arrowNodeCopyCSVTest) {
TEST_F(ArrowNodeCopyTest, ArrowNodeCopyCSVTest) {
initGraphFromPath(TestHelper::appendKuzuRootPath("dataset/copy-test/node/csv/"));
runTest(TestHelper::appendKuzuRootPath("test/test_files/copy/copy_node.test"));
}

TEST_F(arrowNodeCopyTest, arrowNodeCopyArrowTest) {
TEST_F(ArrowNodeCopyTest, ArrowNodeCopyArrowTest) {
initGraphFromPath(TestHelper::appendKuzuRootPath("dataset/copy-test/node/arrow/"));
runTest(TestHelper::appendKuzuRootPath("test/test_files/copy/copy_node.test"));
}

TEST_F(arrowNodeCopyTest, arrowNodeCopyParquetTest) {
TEST_F(ArrowNodeCopyTest, ArrowNodeCopyParquetTest) {
initGraphFromPath(TestHelper::appendKuzuRootPath("dataset/copy-test/node/parquet/"));
runTest(TestHelper::appendKuzuRootPath("test/test_files/copy/copy_node.test"));
}
10 changes: 5 additions & 5 deletions test/demo_db/demo_db_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ TEST_F(DemoDBTest, CreateAvgNullTest) {
auto result =
conn->query("MATCH (a:User) WITH a, avg(a.age) AS b, SUM(a.age) AS c, COUNT(a.age) AS d, "
"COUNT(*) AS e RETURN a, b, c,d, e ORDER BY c DESC");
auto groundTruth = vector<string>{"(0:4:User {name:Alice, age:})|||0|1",
"(0:2:User {name:Zhang, age:50})|50.000000|50|1|1",
"(0:1:User {name:Karissa, age:40})|40.000000|40|1|1",
"(0:0:User {name:Adam, age:30})|30.000000|30|1|1",
"(0:3:User {name:Noura, age:25})|25.000000|25|1|1"};
auto groundTruth = vector<string>{"(label:User, 0:4, {name:Alice, age:})|||0|1",
"(label:User, 0:2, {name:Zhang, age:50})|50.000000|50|1|1",
"(label:User, 0:1, {name:Karissa, age:40})|40.000000|40|1|1",
"(label:User, 0:0, {name:Adam, age:30})|30.000000|30|1|1",
"(label:User, 0:3, {name:Noura, age:25})|25.000000|25|1|1"};
ASSERT_EQ(TestHelper::convertResultToString(*result, true /* check order */), groundTruth);
}

Expand Down
4 changes: 2 additions & 2 deletions test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ class TinySnbDDLTest : public DBTest {
auto result = conn->query("MATCH (p:person) RETURN * ORDER BY p.ID LIMIT 1");
ASSERT_EQ(TestHelper::convertResultToString(*result),
vector<string>{
"(0:0:person {ID:0, fName:Alice, isStudent:True, isWorker:False, age:35, "
"(label:person, 0:0, {ID:0, fName:Alice, isStudent:True, isWorker:False, age:35, "
"eyeSight:5.000000, birthdate:1900-01-01, registerTime:2011-08-20 11:25:30, "
"lastJobDuration:3 years 2 days 13:02:00, workedHours:[10,5], "
"usedNames:[Aida], courseScoresPerTerm:[[10,8],[6,7,8]]})"});
Expand Down Expand Up @@ -338,7 +338,7 @@ class TinySnbDDLTest : public DBTest {
auto result = conn->query(
"MATCH (:person)-[s:studyAt]->(:organisation) RETURN * ORDER BY s.year DESC LIMIT 1");
ASSERT_EQ(TestHelper::convertResultToString(*result),
vector<string>{"(0:0)-[{_id:4:0, year:2021}]->(1:0)"});
vector<string>{"(0:0)-[label:studyAt, {_id:4:0, year:2021}]->(1:0)"});
}

void ddlStatementsInsideActiveTransactionErrorTest(string query) {
Expand Down
10 changes: 5 additions & 5 deletions test/runner/e2e_update_node_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ TEST_F(TinySnbUpdateTest, InsertSingleNToNRelTest) {
"MATCH (a:person), (b:person) WHERE a.ID = 9 AND b.ID = 10 "
"CREATE (a)-[:knows {meetTime:timestamp('1976-12-23 11:21:42'), validInterval:interval('2 "
"years'), comments:['A', 'k'], date:date('1997-03-22')}]->(b);");
auto groundTruth =
vector<string>{"9|10|(0:6)-[{_id:3:14, date:1997-03-22, meetTime:1976-12-23 11:21:42, "
"validInterval:2 years, comments:[A,k]}]->(0:7)|3:14"};
auto groundTruth = vector<string>{
"9|10|(0:6)-[label:knows, {_id:3:14, date:1997-03-22, meetTime:1976-12-23 11:21:42, "
"validInterval:2 years, comments:[A,k]}]->(0:7)|3:14"};
auto result = conn->query(
"MATCH (a:person)-[e:knows]->(b:person) WHERE a.ID > 8 RETURN a.ID, b.ID, e, ID(e)");
ASSERT_EQ(TestHelper::convertResultToString(*result), groundTruth);
Expand All @@ -245,9 +245,9 @@ TEST_F(TinySnbUpdateTest, InsertSingleNTo1RelTest) {
// insert studyAt edge between Greg and CsWork
conn->query("MATCH (a:person), (b:organisation) WHERE a.ID = 9 AND b.orgCode = 934 "
"CREATE (a)-[:studyAt {year:2022}]->(b);");
auto groundTruth = vector<string>{"8|325|(0:5)-[{_id:4:2, year:2020, "
auto groundTruth = vector<string>{"8|325|(0:5)-[label:studyAt, {_id:4:2, year:2020, "
"places:[awndsnjwejwen,isuhuwennjnuhuhuwewe]}]->(1:0)|4:2",
"9|934|(0:6)-[{_id:4:3, year:2022, places:}]->(1:1)|4:3"};
"9|934|(0:6)-[label:studyAt, {_id:4:3, year:2022, places:}]->(1:1)|4:3"};
auto result = conn->query("MATCH (a:person)-[e:studyAt]->(b:organisation) WHERE a.ID > 5 "
"RETURN a.ID, b.orgCode, e, ID(e)");
ASSERT_EQ(TestHelper::convertResultToString(*result), groundTruth);
Expand Down
14 changes: 7 additions & 7 deletions test/test_files/copy/copy_node.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
-NAME SubsetTest
-QUERY MATCH (row:tableOfTypes) WHERE row.id >= 20 AND row.id <= 24 RETURN *;
---- 5
(0:20:tableOfTypes {id:20, int64Column:0, doubleColumn:57.579280, booleanColumn:True, dateColumn:1731-09-26, timestampColumn:1731-09-26 03:30:08, stringColumn:OdM})
(0:21:tableOfTypes {id:21, int64Column:7, doubleColumn:64.630960, booleanColumn:False, dateColumn:1307-01-26, timestampColumn:1307-01-26 03:31:08, stringColumn:AjbxHQThEtDDlOjbzMjCQSXlvGQEjcFLykESrnFHwPKX})
(0:22:tableOfTypes {id:22, int64Column:71, doubleColumn:37.963386, booleanColumn:True, dateColumn:1455-07-26, timestampColumn:1455-07-26 03:07:03, stringColumn:dRvHHdyNXYfSUcicaxBoQEKQUfgex})
(0:23:tableOfTypes {id:23, int64Column:58, doubleColumn:42.774957, booleanColumn:False, dateColumn:1181-10-16, timestampColumn:1181-10-16 18:19:43, stringColumn:ISImRVpUjynGMFRQyYmeIUVjM})
(0:24:tableOfTypes {id:24, int64Column:75, doubleColumn:53.813224, booleanColumn:False, dateColumn:1942-10-24, timestampColumn:1942-10-24 09:30:16, stringColumn:naDlQ})
(label:tableOfTypes, 0:20, {id:20, int64Column:0, doubleColumn:57.579280, booleanColumn:True, dateColumn:1731-09-26, timestampColumn:1731-09-26 03:30:08, stringColumn:OdM})
(label:tableOfTypes, 0:21, {id:21, int64Column:7, doubleColumn:64.630960, booleanColumn:False, dateColumn:1307-01-26, timestampColumn:1307-01-26 03:31:08, stringColumn:AjbxHQThEtDDlOjbzMjCQSXlvGQEjcFLykESrnFHwPKX})
(label:tableOfTypes, 0:22, {id:22, int64Column:71, doubleColumn:37.963386, booleanColumn:True, dateColumn:1455-07-26, timestampColumn:1455-07-26 03:07:03, stringColumn:dRvHHdyNXYfSUcicaxBoQEKQUfgex})
(label:tableOfTypes, 0:23, {id:23, int64Column:58, doubleColumn:42.774957, booleanColumn:False, dateColumn:1181-10-16, timestampColumn:1181-10-16 18:19:43, stringColumn:ISImRVpUjynGMFRQyYmeIUVjM})
(label:tableOfTypes, 0:24, {id:24, int64Column:75, doubleColumn:53.813224, booleanColumn:False, dateColumn:1942-10-24, timestampColumn:1942-10-24 09:30:16, stringColumn:naDlQ})

-NAME CheckNumLinesTest
-QUERY MATCH (row:tableOfTypes) RETURN count(*)
Expand All @@ -31,7 +31,7 @@
-NAME EmptyStringTest
-QUERY MATCH (row:tableOfTypes) WHERE row.id = 49992 RETURN *;
---- 1
(0:49992:tableOfTypes {id:49992, int64Column:50, doubleColumn:31.582059, booleanColumn:False, dateColumn:1551-07-19, timestampColumn:1551-07-19 16:28:31, stringColumn:})
(label:tableOfTypes, 0:49992, {id:49992, int64Column:50, doubleColumn:31.582059, booleanColumn:False, dateColumn:1551-07-19, timestampColumn:1551-07-19 16:28:31, stringColumn:})

-NAME FloatTest
-QUERY MATCH (row:tableOfTypes) WHERE row.doubleColumn = 68.73718401556897 RETURN row.dateColumn;
Expand All @@ -41,7 +41,7 @@
-NAME DateTest
-QUERY MATCH (row:tableOfTypes) WHERE row.id = 25531 RETURN *;
---- 1
(0:25531:tableOfTypes {id:25531, int64Column:77, doubleColumn:28.417543, booleanColumn:False, dateColumn:1895-03-13, timestampColumn:1895-03-13 04:31:22, stringColumn:XB})
(label:tableOfTypes, 0:25531, {id:25531, int64Column:77, doubleColumn:28.417543, booleanColumn:False, dateColumn:1895-03-13, timestampColumn:1895-03-13 04:31:22, stringColumn:XB})

-NAME IntervalTest
-QUERY MATCH (row:tableOfTypes) WHERE 0 <= row.doubleColumn AND row.doubleColumn <= 10 AND 0 <= row.int64Column AND row.int64Column <= 10 RETURN count(*);
Expand Down
8 changes: 4 additions & 4 deletions test/test_files/demo_db/demo_db.test
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ Adam
-NAME Match1
-QUERY MATCH (a:User) RETURN a;
---- 4
(0:0:User {name:Adam, age:30})
(0:1:User {name:Karissa, age:40})
(0:2:User {name:Zhang, age:50})
(0:3:User {name:Noura, age:25})
(label:User, 0:0, {name:Adam, age:30})
(label:User, 0:1, {name:Karissa, age:40})
(label:User, 0:2, {name:Zhang, age:50})
(label:User, 0:3, {name:Noura, age:25})

-NAME OptionalMatch1
-QUERY MATCH (u:User) OPTIONAL MATCH (u)-[:Follows]->(u1:User) RETURN u.name, u1.name;
Expand Down
Loading

0 comments on commit c3d2b8c

Please sign in to comment.