Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return node and rel data type #1168

Merged
merged 1 commit into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,31 @@ expression_vector Binder::bindProjectionExpressions(
}

// Expands a node or rel variable into the expressions it projects to, dispatching on the
// variable's data type. Anything that is not a NODE is expected (asserted) to be a REL.
expression_vector Binder::rewriteNodeOrRelExpression(const Expression& expression) {
    const bool isNode = expression.dataType.typeID == common::NODE;
    if (isNode) {
        return rewriteNodeExpression(expression);
    }
    assert(expression.dataType.typeID == common::REL);
    return rewriteRelExpression(expression);
}

// Expands a node variable into, in this order: its internal ID property, the expression produced
// by bindNodeLabelFunction for it, and a copy of each of its property expressions.
expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression& expression) {
    auto& nodeExpr = (NodeExpression&)expression;
    expression_vector projected;
    projected.push_back(nodeExpr.getInternalIDProperty());
    projected.push_back(expressionBinder.bindNodeLabelFunction(nodeExpr));
    for (auto& propertyExpr : nodeExpr.getPropertyExpressions()) {
        projected.push_back(propertyExpr->copy());
    }
    return projected;
}

expression_vector Binder::rewriteRelExpression(const Expression& expression) {
expression_vector result;
auto& nodeOrRel = (NodeOrRelExpression&)expression;
for (auto& property : nodeOrRel.getPropertyExpressions()) {
auto& rel = (RelExpression&)expression;
result.push_back(rel.getSrcNode()->getInternalIDProperty());
result.push_back(rel.getDstNode()->getInternalIDProperty());
for (auto& property : rel.getPropertyExpressions()) {
result.push_back(property->copy());
}
return result;
Expand Down
53 changes: 49 additions & 4 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "binder/expression/rel_expression.h"
#include "common/type_utils.h"
#include "function/boolean/vector_boolean_operations.h"
#include "function/node/vector_node_operations.h"
#include "function/null/vector_null_operations.h"
#include "parser/expression/parsed_case_expression.h"
#include "parser/expression/parsed_function_expression.h"
Expand Down Expand Up @@ -208,6 +209,12 @@ shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression;
auto functionName = parsedFunctionExpression.getFunctionName();
StringUtils::toUpper(functionName);
// check for special function binding
if (functionName == ID_FUNC_NAME) {
return bindInternalIDExpression(parsedExpression);
} else if (functionName == LABEL_FUNC_NAME) {
return bindLabelFunction(parsedExpression);
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
return bindScalarFunctionExpression(parsedExpression, functionName);
Expand Down Expand Up @@ -285,13 +292,11 @@ shared_ptr<Expression> ExpressionBinder::staticEvaluate(const string& functionNa
auto strVal = ((LiteralExpression*)children[0].get())->literal->strVal;
return make_shared<LiteralExpression>(DataType(TIMESTAMP),
make_unique<Literal>(Timestamp::FromCString(strVal.c_str(), strVal.length())));
} else if (functionName == CAST_TO_INTERVAL_FUNC_NAME) {
} else {
assert(functionName == CAST_TO_INTERVAL_FUNC_NAME);
auto strVal = ((LiteralExpression*)children[0].get())->literal->strVal;
return make_shared<LiteralExpression>(DataType(INTERVAL),
make_unique<Literal>(Interval::FromCString(strVal.c_str(), strVal.length())));
} else {
assert(functionName == ID_FUNC_NAME);
return bindInternalIDExpression(parsedExpression);
}
}

Expand Down Expand Up @@ -324,6 +329,46 @@ unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
return result;
}

// Binds a parsed LABEL(...) call: binds its first child and hands the bound expression to
// bindNodeLabelFunction.
// NOTE(review): the NODE type requirement is only enforced by assert; with asserts disabled a
// non-node argument would reach bindNodeLabelFunction, which casts it to NodeExpression& —
// confirm whether a binder error should be raised here instead.
shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
    const ParsedExpression& parsedExpression) {
    auto boundNode = bindExpression(*parsedExpression.getChild(0));
    assert(boundNode->dataType.typeID == common::NODE);
    return bindNodeLabelFunction(*boundNode);
}

// Builds the expression that yields the label (table name) of a node.
//
// - Single-labeled node: the table name is looked up in the catalog at bind time and returned as
//   a STRING literal expression.
// - Multi-labeled node: returns a scalar function expression (LABEL_FUNC_NAME, result type
//   STRING) whose children are the node's internal ID property and a LIST<STRING> literal. The
//   list has one slot per table ID in [0, max node table ID]; slot i holds the name of node table
//   i, or an empty string when no node table has that ID.
//
// `expression` is cast to NodeExpression& without a check; callers must pass a node expression.
shared_ptr<Expression> ExpressionBinder::bindNodeLabelFunction(const Expression& expression) {
    auto catalogContent = binder->catalog.getReadOnlyVersion();
    auto& node = (NodeExpression&)expression;
    if (!node.isMultiLabeled()) {
        auto labelName = catalogContent->getNodeTableSchema(node.getSingleTableID())->tableName;
        return make_shared<LiteralExpression>(STRING, make_unique<Literal>(labelName));
    }
    // bind string node labels as list literal
    auto nodeTableIDs = catalogContent->getNodeTableIDs();
    auto maxTableIDIter = std::max_element(nodeTableIDs.begin(), nodeTableIDs.end());
    // max_element returns end() for an empty range; dereferencing that would be undefined.
    assert(maxTableIDIter != nodeTableIDs.end());
    table_id_t maxNodeTableID = *maxTableIDIter;
    vector<Literal> nodeLabels;
    nodeLabels.resize(maxNodeTableID + 1);
    // Unsigned index: it is compared against size() and used as a table ID.
    for (auto i = 0u; i < nodeLabels.size(); ++i) {
        if (catalogContent->containNodeTable(i)) {
            auto tableSchema = catalogContent->getNodeTableSchema(i);
            nodeLabels[i] = Literal(tableSchema->tableName);
        } else {
            // TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
            nodeLabels[i] = Literal(string(""));
        }
    }
    auto literalDataType = DataType(LIST, make_unique<DataType>(STRING));
    expression_vector children;
    children.push_back(node.getInternalIDProperty());
    children.push_back(make_shared<LiteralExpression>(
        literalDataType, make_unique<Literal>(nodeLabels, literalDataType)));
    auto execFunc = NodeLabelVectorOperation::execFunction;
    auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
    return make_shared<ScalarFunctionExpression>(
        FUNCTION, DataType(STRING), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

shared_ptr<Expression> ExpressionBinder::bindParameterExpression(
const ParsedExpression& parsedExpression) {
auto& parsedParameterExpression = (ParsedParameterExpression&)parsedExpression;
Expand Down
18 changes: 0 additions & 18 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,24 +239,6 @@ table_id_t CatalogContent::addRelTableSchema(string tableName, RelMultiplicity r
return tableID;
}

// Scans the properties of node table `tableID` and returns true as soon as one is named
// `propertyName`; returns false when none matches.
bool CatalogContent::containNodeProperty(table_id_t tableID, const string& propertyName) const {
    auto& tableProperties = nodeTableSchemas.at(tableID)->properties;
    for (auto it = tableProperties.begin(); it != tableProperties.end(); ++it) {
        if (propertyName == it->name) {
            return true;
        }
    }
    return false;
}

// Scans the properties of rel table `tableID` and returns true as soon as one is named
// `propertyName`; returns false when none matches.
bool CatalogContent::containRelProperty(table_id_t tableID, const string& propertyName) const {
    auto& tableProperties = relTableSchemas.at(tableID)->properties;
    for (auto it = tableProperties.begin(); it != tableProperties.end(); ++it) {
        if (propertyName == it->name) {
            return true;
        }
    }
    return false;
}

const Property& CatalogContent::getNodeProperty(
table_id_t tableID, const string& propertyName) const {
for (auto& property : nodeTableSchemas.at(tableID)->properties) {
Expand Down
52 changes: 52 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,57 @@ template void ValueVector::setValue<interval_t>(uint32_t pos, interval_t val);
template void ValueVector::setValue<ku_string_t>(uint32_t pos, ku_string_t val);
template void ValueVector::setValue<ku_list_t>(uint32_t pos, ku_list_t val);

// Writes `literal` into slot `pos` of this vector. The literal's data type must equal the
// vector's (asserted). A null literal only marks the slot as null; otherwise the value is copied
// into the slot's bytes via copyLiteral.
void ValueVector::setLiteral(uint32_t pos, const common::Literal& literal) {
    assert(dataType == literal.dataType);
    if (!literal.isNull()) {
        auto slotWidth = Types::getDataTypeSize(dataType);
        auto slot = getData() + slotWidth * pos;
        copyLiteral(slot, literal);
        return;
    }
    setNull(pos, true);
}

// Copies the value held by `literal` into the raw bytes at `dest`.
//
// Fixed-width types are memcpy'd. STRING is copied through InMemOverflowBufferUtils::copyString
// using this vector's overflow buffer. LIST allocates space for its elements in the overflow
// buffer and copies each element by recursing into this function.
//
// Everything type-dependent is derived from `literal.dataType`, not from the vector's own
// `dataType`: in a recursive call for a nested list the two differ, and using the vector's child
// type would give the wrong element size and stride for the inner list.
//
// Throws NotImplementedException for any other type ID.
void ValueVector::copyLiteral(uint8_t* dest, const common::Literal& literal) {
    auto size = Types::getDataTypeSize(literal.dataType);
    switch (literal.dataType.typeID) {
    case INT64: {
        memcpy(dest, &literal.val.int64Val, size);
    } break;
    case DOUBLE: {
        memcpy(dest, &literal.val.doubleVal, size);
    } break;
    case BOOL: {
        memcpy(dest, &literal.val.booleanVal, size);
    } break;
    case DATE: {
        memcpy(dest, &literal.val.dateVal, size);
    } break;
    case TIMESTAMP: {
        memcpy(dest, &literal.val.timestampVal, size);
    } break;
    case INTERVAL: {
        memcpy(dest, &literal.val.intervalVal, size);
    } break;
    case STRING: {
        InMemOverflowBufferUtils::copyString(literal.strVal.data(), literal.strVal.length(),
            *(ku_string_t*)dest, getOverflowBuffer());
    } break;
    case LIST: {
        auto& entry = *(ku_list_t*)dest;
        auto numElements = literal.listVal.size();
        // Element width comes from the literal being copied at this recursion level.
        auto elementSize = Types::getDataTypeSize(*literal.dataType.childType);
        InMemOverflowBufferUtils::allocateSpaceForList(
            entry, numElements * elementSize, getOverflowBuffer());
        entry.size = numElements;
        for (auto i = 0u; i < numElements; ++i) {
            copyLiteral((uint8_t*)entry.overflowPtr + i * elementSize, literal.listVal[i]);
        }
    } break;
    default:
        throw NotImplementedException(
            "Unimplemented setLiteral() for type " + Types::dataTypeToString(literal.dataType));
    }
}

} // namespace common
} // namespace kuzu
33 changes: 0 additions & 33 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,6 @@
using namespace kuzu;
using namespace common;

// Stores `literal` at position `pos` of `resultVector`. A null literal marks the position as
// null; otherwise the member of the literal matching its type ID is passed to setValue. Type IDs
// other than the ones listed hit assert(false).
void ValueVectorUtils::addLiteralToValueVector(
    ValueVector& resultVector, uint64_t pos, const Literal& literal) {
    if (literal.isNull()) {
        resultVector.setNull(pos, true);
        return;
    }
    switch (literal.dataType.typeID) {
    case INT64:
        resultVector.setValue(pos, literal.val.int64Val);
        return;
    case DOUBLE:
        resultVector.setValue(pos, literal.val.doubleVal);
        return;
    case BOOL:
        resultVector.setValue(pos, literal.val.booleanVal);
        return;
    case DATE:
        resultVector.setValue(pos, literal.val.dateVal);
        return;
    case TIMESTAMP:
        resultVector.setValue(pos, literal.val.timestampVal);
        return;
    case INTERVAL:
        resultVector.setValue(pos, literal.val.intervalVal);
        return;
    case STRING:
        resultVector.setValue(pos, literal.strVal);
        return;
    default:
        assert(false);
    }
}

void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
ValueVector& resultVector, uint64_t pos, const uint8_t* srcData) {
copyNonNullDataWithSameType(resultVector.dataType, srcData,
Expand Down
2 changes: 1 addition & 1 deletion src/expression_evaluator/literal_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ bool LiteralExpressionEvaluator::select(SelectionVector& selVector) {
// Allocates the evaluator's result vector with the literal's data type, writes the literal into
// position 0, and assigns the state returned by DataChunkState::getSingleValueDataChunkState().
// `resultSet` is not used in this body.
void LiteralExpressionEvaluator::resolveResultVector(
const ResultSet& resultSet, MemoryManager* memoryManager) {
resultVector = make_shared<ValueVector>(literal->dataType, memoryManager);
// NOTE(review): the next two statements both write the literal into position 0. They look like
// the removed and added sides of a single changed line — confirm only one is meant to remain.
ValueVectorUtils::addLiteralToValueVector(*resultVector, 0, *literal);
resultVector->setLiteral(0, *literal);
resultVector->state = DataChunkState::getSingleValueDataChunkState();
}

Expand Down
3 changes: 0 additions & 3 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ void BuiltInVectorOperations::registerVectorOperations() {

bool BuiltInVectorOperations::canApplyStaticEvaluation(
const string& functionName, const expression_vector& children) {
if (functionName == ID_FUNC_NAME) {
return true; // bind as property
}
if ((functionName == CAST_TO_DATE_FUNC_NAME || functionName == CAST_TO_TIMESTAMP_FUNC_NAME ||
functionName == CAST_TO_INTERVAL_FUNC_NAME) &&
children[0]->expressionType == LITERAL && children[0]->dataType.typeID == STRING) {
Expand Down
2 changes: 2 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class Binder {
const vector<unique_ptr<ParsedExpression>>& projectionExpressions, bool containsStar);
// Rewrite variable "v" as all properties of "v"
expression_vector rewriteNodeOrRelExpression(const Expression& expression);
expression_vector rewriteNodeExpression(const Expression& expression);
expression_vector rewriteRelExpression(const Expression& expression);

void bindOrderBySkipLimitIfNecessary(
BoundProjectionBody& boundProjectionBody, const ProjectionBody& projectionBody);
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ class ExpressionBinder {

shared_ptr<Expression> bindNullOperatorExpression(const ParsedExpression& parsedExpression);

// bind to an existing property expression.
shared_ptr<Expression> bindPropertyExpression(const ParsedExpression& parsedExpression);
// bind to an existing property expression of given node table.
shared_ptr<Expression> bindNodePropertyExpression(
const Expression& expression, const string& propertyName);
shared_ptr<Expression> bindRelPropertyExpression(
Expand All @@ -56,6 +54,8 @@ class ExpressionBinder {
shared_ptr<Expression> bindInternalIDExpression(const ParsedExpression& parsedExpression);
shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
unique_ptr<Expression> createInternalNodeIDExpression(const Expression& node);
shared_ptr<Expression> bindLabelFunction(const ParsedExpression& parsedExpression);
shared_ptr<Expression> bindNodeLabelFunction(const Expression& expression);

shared_ptr<Expression> bindParameterExpression(const ParsedExpression& parsedExpression);

Expand Down
39 changes: 23 additions & 16 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,49 +44,56 @@ class CatalogContent {
const vector<PropertyNameDataType>& propertyDefinitions,
vector<pair<table_id_t, table_id_t>> srcDstTableIDs);

virtual inline string getNodeTableName(table_id_t tableID) const {
// Returns true if nodeTableSchemas has an entry keyed by `tableID`.
inline bool containNodeTable(table_id_t tableID) const {
return nodeTableSchemas.contains(tableID);
}
// Returns true if relTableSchemas has an entry keyed by `tableID`.
inline bool containRelTable(table_id_t tableID) const {
return relTableSchemas.contains(tableID);
}

// Returns a copy of the name of node table `tableID`. The table must exist (asserted).
inline string getNodeTableName(table_id_t tableID) const {
assert(containNodeTable(tableID));
return nodeTableSchemas.at(tableID)->tableName;
}
virtual inline string getRelTableName(table_id_t tableID) const {
// Returns a copy of the name of rel table `tableID`. The table must exist (asserted).
inline string getRelTableName(table_id_t tableID) const {
assert(containRelTable(tableID));
return relTableSchemas.at(tableID)->tableName;
}

// Returns a non-owning pointer to the schema of node table `tableID`; the schema stays owned by
// nodeTableSchemas. The table must exist (asserted).
inline NodeTableSchema* getNodeTableSchema(table_id_t tableID) const {
assert(containNodeTable(tableID));
return nodeTableSchemas.at(tableID).get();
}
virtual inline RelTableSchema* getRelTableSchema(table_id_t tableID) const {
// Returns a non-owning pointer to the schema of rel table `tableID`; the schema stays owned by
// relTableSchemas. The table must exist (asserted).
inline RelTableSchema* getRelTableSchema(table_id_t tableID) const {
assert(containRelTable(tableID));
return relTableSchemas.at(tableID).get();
}

virtual inline bool containNodeTable(const string& tableName) const {
return end(nodeTableNameToIDMap) != nodeTableNameToIDMap.find(tableName);
// Returns true if nodeTableNameToIDMap has an entry keyed by `tableName`.
inline bool containNodeTable(const string& tableName) const {
return nodeTableNameToIDMap.contains(tableName);
}
virtual inline bool containRelTable(const string& tableName) const {
return end(relTableNameToIDMap) != relTableNameToIDMap.find(tableName);
// Returns true if relTableNameToIDMap has an entry keyed by `tableName`.
inline bool containRelTable(const string& tableName) const {
return relTableNameToIDMap.contains(tableName);
}

virtual inline table_id_t getNodeTableIDFromName(const string& tableName) const {
// Looks up the ID of the node table named `tableName` in nodeTableNameToIDMap. The lookup uses
// at() with no prior existence check, so `tableName` must name an existing node table.
inline table_id_t getNodeTableIDFromName(const string& tableName) const {
return nodeTableNameToIDMap.at(tableName);
}
virtual inline table_id_t getRelTableIDFromName(const string& tableName) const {
// Looks up the ID of the rel table named `tableName` in relTableNameToIDMap. The lookup uses
// at() with no prior existence check, so `tableName` must name an existing rel table.
inline table_id_t getRelTableIDFromName(const string& tableName) const {
return relTableNameToIDMap.at(tableName);
}

virtual inline bool isSingleMultiplicityInDirection(
table_id_t tableID, RelDirection direction) const {
// Forwards to the schema of rel table `tableID`, returning its isSingleMultiplicityInDirection
// result for `direction`. The lookup uses at() with no prior existence check, so `tableID` must
// refer to an existing rel table.
inline bool isSingleMultiplicityInDirection(table_id_t tableID, RelDirection direction) const {
return relTableSchemas.at(tableID)->isSingleMultiplicityInDirection(direction);
}

/**
* Node and Rel property functions.
*/
virtual bool containNodeProperty(table_id_t tableID, const string& propertyName) const;
virtual bool containRelProperty(table_id_t tableID, const string& propertyName) const;

// getNodeProperty and getRelProperty should be called after checking if property exists
// (containNodeProperty and containRelProperty).
virtual const Property& getNodeProperty(table_id_t tableID, const string& propertyName) const;
virtual const Property& getRelProperty(table_id_t tableID, const string& propertyName) const;
const Property& getNodeProperty(table_id_t tableID, const string& propertyName) const;
const Property& getRelProperty(table_id_t tableID, const string& propertyName) const;

vector<Property> getAllNodeProperties(table_id_t tableID) const;
inline const vector<Property>& getRelProperties(table_id_t tableID) const {
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ const string TO_SECONDS_FUNC_NAME = "TO_SECONDS";
const string TO_MILLISECONDS_FUNC_NAME = "TO_MILLISECONDS";
const string TO_MICROSECONDS_FUNC_NAME = "TO_MICROSECONDS";

// Node/Rel functions.
const string ID_FUNC_NAME = "ID";
const string LABEL_FUNC_NAME = "LABEL";

enum ExpressionType : uint8_t {

Expand Down
1 change: 0 additions & 1 deletion src/include/common/in_mem_overflow_buffer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class InMemOverflowBufferUtils {
reinterpret_cast<T*>(result.overflowPtr)[elementPos] = element;
}

private:
static inline void allocateSpaceForList(
ku_list_t& list, uint64_t numBytes, InMemOverflowBuffer& buffer) {
list.overflowPtr = reinterpret_cast<uint64_t>(buffer.allocateSpace(numBytes));
Expand Down
Loading