Skip to content

Commit

Permalink
add node label function
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jan 11, 2023
1 parent bd09081 commit bc59bca
Show file tree
Hide file tree
Showing 29 changed files with 421 additions and 435 deletions.
26 changes: 24 additions & 2 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,31 @@ expression_vector Binder::bindProjectionExpressions(
}

// Expands a NODE- or REL-typed variable expression into the list of expressions it stands for,
// dispatching on the expression's data type.
expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) {
    const auto typeID = expression.dataType.typeID;
    if (typeID != common::NODE) {
        // Anything that is not a node is expected to be a rel.
        assert(typeID == common::REL);
        return rewriteRelExpression(expression);
    }
    return rewriteNodeExpression(expression);
}

// Expands a node variable into, in order: its internal ID property, its bound label function
// expression, and a copy of each of its property expressions.
expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression& expression) {
    auto& nodeExpression = (NodeExpression&)expression;
    expression_vector rewritten;
    rewritten.push_back(nodeExpression.getInternalIDProperty());
    rewritten.push_back(expressionBinder.bindNodeLabelFunction(nodeExpression));
    for (auto& propertyExpression : nodeExpression.getPropertyExpressions()) {
        rewritten.push_back(propertyExpression->copy());
    }
    return rewritten;
}

expression_vector Binder::rewriteRelExpression(const Expression& expression) {
expression_vector result;
auto& nodeOrRel = (NodeOrRelExpression&)expression;
for (auto& property : nodeOrRel.getPropertyExpressions()) {
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
return result;
Expand Down
53 changes: 49 additions & 4 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "binder/expression/rel_expression.h"
#include "common/type_utils.h"
#include "function/boolean/vector_boolean_operations.h"
#include "function/node/vector_node_operations.h"
#include "function/null/vector_null_operations.h"
#include "parser/expression/parsed_case_expression.h"
#include "parser/expression/parsed_function_expression.h"
Expand Down Expand Up @@ -208,6 +209,12 @@ shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
// check for special function binding
if (functionName == ID_FUNC_NAME) {
return bindInternalIDExpression(parsedExpression);
} else if (functionName == LABEL_FUNC_NAME) {
return bindLabelFunction(parsedExpression);
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
return bindScalarFunctionExpression(parsedExpression, functionName);
Expand Down Expand Up @@ -285,13 +292,11 @@ shared_ptr<Expression> ExpressionBinder::staticEvaluate(const string& functionNa
auto strVal = ((LiteralExpression*)children[0].get())->literal->strVal;
return make_shared<LiteralExpression>(DataType(TIMESTAMP),
make_unique<Literal>(Timestamp::FromCString(strVal.c_str(), strVal.length())));
} else if (functionName == CAST_TO_INTERVAL_FUNC_NAME) {
} else {
assert(functionName == CAST_TO_INTERVAL_FUNC_NAME);
auto strVal = ((LiteralExpression*)children[0].get())->literal->strVal;
return make_shared<LiteralExpression>(DataType(INTERVAL),
make_unique<Literal>(Interval::FromCString(strVal.c_str(), strVal.length())));
} else {
assert(functionName == ID_FUNC_NAME);
return bindInternalIDExpression(parsedExpression);
}
}

Expand Down Expand Up @@ -324,6 +329,46 @@ unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
return result;
}

// Binds a parsed LABEL(...) call: binds the first child expression and hands it to
// bindNodeLabelFunction().
// NOTE(review): the NODE type requirement is only enforced by an assert, which is compiled out
// under NDEBUG — confirm whether a non-node argument can reach this point from user input.
shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
    const ParsedExpression& parsedExpression) {
    auto boundNode = bindExpression(*parsedExpression.getChild(0));
    assert(boundNode->dataType.typeID == common::NODE);
    return bindNodeLabelFunction(*boundNode);
}

// Binds the label of an already-bound node expression.
//   expression - the node expression (cast to NodeExpression internally).
// Returns:
//   - for a single-labeled node: a STRING literal holding the node table's name;
//   - for a multi-labeled node: a STRING-typed scalar function expression whose children are the
//     node's internal ID property and a LIST<STRING> literal. Position i of that list holds the
//     name of node table i, or an empty string when no node table with ID i exists.
shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
    auto catalogContent = binder->catalog.getReadOnlyVersion();
    auto& node = (NodeExpression&)expression;
    if (!node.isMultiLabeled()) {
        auto labelName = catalogContent->getNodeTableSchema(node.getSingleTableID())->tableName;
        return make_shared<LiteralExpression>(STRING, make_unique<Literal>(labelName));
    }
    // bind string node labels as list literal
    auto nodeTableIDs = catalogContent->getNodeTableIDs();
    table_id_t maxNodeTableID = *std::max_element(nodeTableIDs.begin(), nodeTableIDs.end());
    vector<Literal> nodeLabels;
    nodeLabels.resize(maxNodeTableID + 1);
    // Iterate with table_id_t rather than int: avoids a signed/unsigned comparison against
    // size() and selects the table_id_t overloads of the catalog lookups without conversion.
    for (table_id_t tableID = 0; tableID < nodeLabels.size(); ++tableID) {
        if (catalogContent->containNodeTable(tableID)) {
            auto tableSchema = catalogContent->getNodeTableSchema(tableID);
            nodeLabels[tableID] = Literal(tableSchema->tableName);
        } else {
            // TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
            nodeLabels[tableID] = Literal(string(""));
        }
    }
    auto literalDataType = DataType(LIST, make_unique<DataType>(STRING));
    expression_vector children;
    children.push_back(node.getInternalIDProperty());
    children.push_back(make_shared<LiteralExpression>(
        literalDataType, make_unique<Literal>(nodeLabels, literalDataType)));
    auto execFunc = NodeLabelVectorOperation::execFunction;
    auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
    return make_shared<ScalarFunctionExpression>(
        FUNCTION, DataType(STRING), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::bindParameterExpression(
const ParsedExpression& parsedExpression) {
auto& parsedParameterExpression = (ParsedParameterExpression&)parsedExpression;
Expand Down
18 changes: 0 additions & 18 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,24 +239,6 @@ table_id_t CatalogContent::addRelTableSchema(string tableName, RelMultiplicity r
return tableID;
}

// Returns true if the node table `tableID` has a property whose name equals `propertyName`.
bool CatalogContent::containNodeProperty(table_id_t tableID, const string& propertyName) const {
    bool found = false;
    for (auto& candidate : nodeTableSchemas.at(tableID)->properties) {
        if (propertyName == candidate.name) {
            found = true;
            break;
        }
    }
    return found;
}

// Returns true if the rel table `tableID` has a property whose name equals `propertyName`.
bool CatalogContent::containRelProperty(table_id_t tableID, const string& propertyName) const {
    bool found = false;
    for (auto& candidate : relTableSchemas.at(tableID)->properties) {
        if (propertyName == candidate.name) {
            found = true;
            break;
        }
    }
    return found;
}

const Property& CatalogContent::getNodeProperty(
table_id_t tableID, const string& propertyName) const {
for (auto& property : nodeTableSchemas.at(tableID)->properties) {
Expand Down
52 changes: 52 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,57 @@ template void ValueVector::setValue<interval_t>(uint32_t pos, interval_t val);
template void ValueVector::setValue<ku_string_t>(uint32_t pos, ku_string_t val);
template void ValueVector::setValue<ku_list_t>(uint32_t pos, ku_list_t val);

// Writes `literal` into slot `pos` of this vector. The literal's data type must equal the
// vector's (asserted). A null literal marks the slot as null; any other literal is copied into
// the slot's storage by copyLiteral().
void ValueVector::setLiteral(uint32_t pos, const common::Literal& literal) {
    assert(dataType == literal.dataType);
    if (!literal.isNull()) {
        // Slot address = base of the data buffer + pos * fixed per-value size of this type.
        copyLiteral(getData() + Types::getDataTypeSize(dataType) * pos, literal);
    } else {
        setNull(pos, true);
    }
}

// Copies the value held by `literal` into the raw slot starting at `dest`.
//   dest    - start of the destination slot.
//   literal - value to copy; the dispatch is driven by the literal's own data type.
// INT64/DOUBLE/BOOL/DATE/TIMESTAMP/INTERVAL are memcpy'd from the literal's value union.
// STRING is copied through InMemOverflowBufferUtils::copyString using this vector's overflow
// buffer. LIST allocates element storage from the overflow buffer and copies every element by
// recursing into this function.
// Throws NotImplementedException for any other type.
void ValueVector::copyLiteral(uint8_t* dest, const common::Literal& literal) {
    auto size = Types::getDataTypeSize(literal.dataType);
    switch (literal.dataType.typeID) {
    case INT64: {
        memcpy(dest, &literal.val.int64Val, size);
    } break;
    case DOUBLE: {
        memcpy(dest, &literal.val.doubleVal, size);
    } break;
    case BOOL: {
        memcpy(dest, &literal.val.booleanVal, size);
    } break;
    case DATE: {
        memcpy(dest, &literal.val.dateVal, size);
    } break;
    case TIMESTAMP: {
        memcpy(dest, &literal.val.timestampVal, size);
    } break;
    case INTERVAL: {
        memcpy(dest, &literal.val.intervalVal, size);
    } break;
    case STRING: {
        InMemOverflowBufferUtils::copyString(literal.strVal.data(), literal.strVal.length(),
            *(ku_string_t*)dest, getOverflowBuffer());
    } break;
    case LIST: {
        auto& entry = *(ku_list_t*)dest;
        auto numElements = literal.listVal.size();
        // Use the literal's own child type, not the vector's: on a recursive call for a nested
        // list, `literal` is an inner list whose element type differs from the vector's child
        // type, so the vector's type would give the wrong element stride.
        auto elementSize = Types::getDataTypeSize(*literal.dataType.childType);
        InMemOverflowBufferUtils::allocateSpaceForList(
            entry, numElements * elementSize, getOverflowBuffer());
        entry.size = numElements;
        for (auto i = 0u; i < numElements; ++i) {
            copyLiteral((uint8_t*)entry.overflowPtr + i * elementSize, literal.listVal[i]);
        }
    } break;
    default:
        // Report the type that actually failed to dispatch (the literal's), which may be a
        // nested element type rather than the vector's type.
        throw NotImplementedException(
            "Unimplemented setLiteral() for type " + Types::dataTypeToString(literal.dataType));
    }
}

} // namespace common
} // namespace kuzu
33 changes: 0 additions & 33 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,6 @@
using namespace kuzu;
using namespace common;

// Stores `literal` at position `pos` of `resultVector`. A null literal marks the position as
// null. Otherwise the value is written with setValue() according to the literal's type:
// INT64, DOUBLE, BOOL, DATE, TIMESTAMP, INTERVAL or STRING. Any other type trips assert(false).
void ValueVectorUtils::addLiteralToValueVector(
    ValueVector& resultVector, uint64_t pos, const Literal& literal) {
    if (literal.isNull()) {
        resultVector.setNull(pos, true);
        return;
    }
    const auto typeID = literal.dataType.typeID;
    switch (typeID) {
    case INT64:
        resultVector.setValue(pos, literal.val.int64Val);
        break;
    case DOUBLE:
        resultVector.setValue(pos, literal.val.doubleVal);
        break;
    case BOOL:
        resultVector.setValue(pos, literal.val.booleanVal);
        break;
    case DATE:
        resultVector.setValue(pos, literal.val.dateVal);
        break;
    case TIMESTAMP:
        resultVector.setValue(pos, literal.val.timestampVal);
        break;
    case INTERVAL:
        resultVector.setValue(pos, literal.val.intervalVal);
        break;
    case STRING:
        resultVector.setValue(pos, literal.strVal);
        break;
    default:
        assert(false);
    }
}

void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
ValueVector& resultVector, uint64_t pos, const uint8_t* srcData) {
copyNonNullDataWithSameType(resultVector.dataType, srcData,
Expand Down
2 changes: 1 addition & 1 deletion src/expression_evaluator/literal_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ bool LiteralExpressionEvaluator::select(SelectionVector& selVector) {
// Creates the result vector for this literal, writes the literal into position 0, and attaches
// the single-value data chunk state.
//   resultSet     - not used by this evaluator.
//   memoryManager - passed to the ValueVector constructor.
void LiteralExpressionEvaluator::resolveResultVector(
    const ResultSet& resultSet, MemoryManager* memoryManager) {
    resultVector = make_shared<ValueVector>(literal->dataType, memoryManager);
    // Write the literal exactly once, through ValueVector::setLiteral. The older
    // ValueVectorUtils helper has no LIST case (it asserts on it), and calling both would store
    // the value twice.
    resultVector->setLiteral(0, *literal);
    resultVector->state = DataChunkState::getSingleValueDataChunkState();
}

Expand Down
3 changes: 0 additions & 3 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ void BuiltInVectorOperations::registerVectorOperations() {

bool BuiltInVectorOperations::canApplyStaticEvaluation(
const string& functionName, const expression_vector& children) {
if (functionName == ID_FUNC_NAME) {
return true; // bind as property
}
if ((functionName == CAST_TO_DATE_FUNC_NAME || functionName == CAST_TO_TIMESTAMP_FUNC_NAME ||
functionName == CAST_TO_INTERVAL_FUNC_NAME) &&
children[0]->expressionType == LITERAL && children[0]->dataType.typeID == STRING) {
Expand Down
2 changes: 2 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class Binder {
const vector<unique_ptr<ParsedExpression>>& projectionExpressions, bool containsStar);
// Rewrite variable "v" as all properties of "v"
expression_vector rewriteNodeOrRelExpression(const Expression& expression);
// Node variable: internal ID property, label function expression, then a copy of each property.
expression_vector rewriteNodeExpression(const Expression& expression);
// Rel variable: src node internal ID, dst node internal ID, then a copy of each property.
expression_vector rewriteRelExpression(const Expression& expression);

void bindOrderBySkipLimitIfNecessary(
BoundProjectionBody& boundProjectionBody, const ProjectionBody& projectionBody);
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ class ExpressionBinder {

shared_ptr<Expression> bindNullOperatorExpression(const ParsedExpression& parsedExpression);

// bind to an existing property expression.
shared_ptr<Expression> bindPropertyExpression(const ParsedExpression& parsedExpression);
// bind to an existing property expression of given node table.
shared_ptr<Expression> bindNodePropertyExpression(
const Expression& expression, const string& propertyName);
shared_ptr<Expression> bindRelPropertyExpression(
Expand All @@ -56,6 +54,8 @@ class ExpressionBinder {
shared_ptr<Expression> bindInternalIDExpression(const ParsedExpression& parsedExpression);
shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
unique_ptr<Expression> createInternalNodeIDExpression(const Expression& node);
// Binds a parsed LABEL(...) call: binds its first child (asserted to be a NODE) and forwards it
// to bindNodeLabelFunction.
shared_ptr<Expression> bindLabelFunction(const ParsedExpression& parsedExpression);
// Binds the label of a bound node expression: a STRING literal for a single-labeled node,
// otherwise a STRING-typed scalar function over the node's internal ID and a list of label names.
shared_ptr<Expression> bindNodeLabelFunction(const Expression& expression);

shared_ptr<Expression> bindParameterExpression(const ParsedExpression& parsedExpression);

Expand Down
39 changes: 23 additions & 16 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,49 +44,56 @@ class CatalogContent {
const vector<PropertyNameDataType>& propertyDefinitions,
vector<pair<table_id_t, table_id_t>> srcDstTableIDs);

virtual inline string getNodeTableName(table_id_t tableID) const {
// Returns true if a node table schema is stored under `tableID`.
inline bool containNodeTable(table_id_t tableID) const {
return nodeTableSchemas.contains(tableID);
}
// Returns true if a rel table schema is stored under `tableID`.
inline bool containRelTable(table_id_t tableID) const {
return relTableSchemas.contains(tableID);
}

// Returns (by value) the name of the node table `tableID`.
// The table's existence is checked only by the assert.
inline string getNodeTableName(table_id_t tableID) const {
assert(containNodeTable(tableID));
return nodeTableSchemas.at(tableID)->tableName;
}
virtual inline string getRelTableName(table_id_t tableID) const {
// Returns (by value) the name of the rel table `tableID`.
// The table's existence is checked only by the assert.
inline string getRelTableName(table_id_t tableID) const {
assert(containRelTable(tableID));
return relTableSchemas.at(tableID)->tableName;
}

// Returns a raw pointer to the schema of node table `tableID`, obtained via .get() on the stored
// entry. The table's existence is checked only by the assert.
inline NodeTableSchema* getNodeTableSchema(table_id_t tableID) const {
assert(containNodeTable(tableID));
return nodeTableSchemas.at(tableID).get();
}
virtual inline RelTableSchema* getRelTableSchema(table_id_t tableID) const {
// Returns a raw pointer to the schema of rel table `tableID`, obtained via .get() on the stored
// entry. The table's existence is checked only by the assert.
inline RelTableSchema* getRelTableSchema(table_id_t tableID) const {
assert(containRelTable(tableID));
return relTableSchemas.at(tableID).get();
}

virtual inline bool containNodeTable(const string& tableName) const {
return end(nodeTableNameToIDMap) != nodeTableNameToIDMap.find(tableName);
// Returns true if `tableName` is a key of the node-table name-to-ID map.
inline bool containNodeTable(const string& tableName) const {
return nodeTableNameToIDMap.contains(tableName);
}
virtual inline bool containRelTable(const string& tableName) const {
return end(relTableNameToIDMap) != relTableNameToIDMap.find(tableName);
// Returns true if `tableName` is a key of the rel-table name-to-ID map.
inline bool containRelTable(const string& tableName) const {
return relTableNameToIDMap.contains(tableName);
}

virtual inline table_id_t getNodeTableIDFromName(const string& tableName) const {
// Looks up the ID of the node table named `tableName` with .at(); no existence check is made
// here first.
inline table_id_t getNodeTableIDFromName(const string& tableName) const {
return nodeTableNameToIDMap.at(tableName);
}
virtual inline table_id_t getRelTableIDFromName(const string& tableName) const {
// Looks up the ID of the rel table named `tableName` with .at(); no existence check is made
// here first.
inline table_id_t getRelTableIDFromName(const string& tableName) const {
return relTableNameToIDMap.at(tableName);
}

virtual inline bool isSingleMultiplicityInDirection(
table_id_t tableID, RelDirection direction) const {
// Forwards the question to the schema of rel table `tableID` for the given `direction`.
inline bool isSingleMultiplicityInDirection(table_id_t tableID, RelDirection direction) const {
return relTableSchemas.at(tableID)->isSingleMultiplicityInDirection(direction);
}

/**
* Node and Rel property functions.
*/
virtual bool containNodeProperty(table_id_t tableID, const string& propertyName) const;
virtual bool containRelProperty(table_id_t tableID, const string& propertyName) const;

// getNodeProperty and getRelProperty should be called after checking if property exists
// (containNodeProperty and containRelProperty).
virtual const Property& getNodeProperty(table_id_t tableID, const string& propertyName) const;
virtual const Property& getRelProperty(table_id_t tableID, const string& propertyName) const;
const Property& getNodeProperty(table_id_t tableID, const string& propertyName) const;
const Property& getRelProperty(table_id_t tableID, const string& propertyName) const;

vector<Property> getAllNodeProperties(table_id_t tableID) const;
inline const vector<Property>& getRelProperties(table_id_t tableID) const {
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ const string TO_SECONDS_FUNC_NAME = "TO_SECONDS";
const string TO_MILLISECONDS_FUNC_NAME = "TO_MILLISECONDS";
const string TO_MICROSECONDS_FUNC_NAME = "TO_MICROSECONDS";

// Node/Rel functions.
const string ID_FUNC_NAME = "ID";
const string LABEL_FUNC_NAME = "LABEL";

enum ExpressionType : uint8_t {

Expand Down
1 change: 0 additions & 1 deletion src/include/common/in_mem_overflow_buffer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class InMemOverflowBufferUtils {
reinterpret_cast<T*>(result.overflowPtr)[elementPos] = element;
}

private:
static inline void allocateSpaceForList(
ku_list_t& list, uint64_t numBytes, InMemOverflowBuffer& buffer) {
list.overflowPtr = reinterpret_cast<uint64_t>(buffer.allocateSpace(numBytes));
Expand Down
Loading

0 comments on commit bc59bca

Please sign in to comment.