diff --git a/.github/workflows/interactive.yml b/.github/workflows/interactive.yml index fc10b28cf220..1039ce0cd073 100644 --- a/.github/workflows/interactive.yml +++ b/.github/workflows/interactive.yml @@ -175,7 +175,7 @@ jobs: INTERACTIVE_WORKSPACE: /tmp/interactive_workspace run: | cd ${GITHUB_WORKSPACE}/flex/tests/hqps - bash hqps_robust_test.sh ${INTERACTIVE_WORKSPACE} ./interactive_config_test.yaml + bash hqps_robust_test.sh ${INTERACTIVE_WORKSPACE} ./interactive_config_test.yaml ./interactive_config_test_cbo.yaml - name: Sample Query test env: diff --git a/flex/engines/graph_db/database/graph_db.cc b/flex/engines/graph_db/database/graph_db.cc index ed27fac9b825..1b8fff41a8f6 100644 --- a/flex/engines/graph_db/database/graph_db.cc +++ b/flex/engines/graph_db/database/graph_db.cc @@ -259,9 +259,8 @@ void GraphDB::Close() { std::fill(app_factories_.begin(), app_factories_.end(), nullptr); } -ReadTransaction GraphDB::GetReadTransaction() { - uint32_t ts = version_manager_.acquire_read_timestamp(); - return {graph_, version_manager_, ts}; +ReadTransaction GraphDB::GetReadTransaction(int thread_id) { + return contexts_[thread_id].session.GetReadTransaction(); } InsertTransaction GraphDB::GetInsertTransaction(int thread_id) { diff --git a/flex/engines/graph_db/database/graph_db.h b/flex/engines/graph_db/database/graph_db.h index 502710abd012..d345838f7be3 100644 --- a/flex/engines/graph_db/database/graph_db.h +++ b/flex/engines/graph_db/database/graph_db.h @@ -98,7 +98,7 @@ class GraphDB { * * @return graph_dir The directory of graph data. */ - ReadTransaction GetReadTransaction(); + ReadTransaction GetReadTransaction(int thread_id = 0); /** @brief Create a transaction to insert vertices and edges with a default * allocator. diff --git a/flex/engines/graph_db/database/graph_db_session.cc b/flex/engines/graph_db/database/graph_db_session.cc index dc2a411c0ef2..8173fa65f5f8 100644 --- a/flex/engines/graph_db/database/graph_db_session.cc +++ b/flex/engines/graph_db/database/graph_db_session.cc @@ -29,7 +29,7 @@ namespace gs { ReadTransaction GraphDBSession::GetReadTransaction() const { uint32_t ts = db_.version_manager_.acquire_read_timestamp(); - return ReadTransaction(db_.graph_, db_.version_manager_, ts); + return ReadTransaction(*this, db_.graph_, db_.version_manager_, ts); } InsertTransaction GraphDBSession::GetInsertTransaction() { diff --git a/flex/engines/graph_db/database/read_transaction.cc b/flex/engines/graph_db/database/read_transaction.cc index 723345c6d30b..70e094b022b6 100644 --- a/flex/engines/graph_db/database/read_transaction.cc +++ b/flex/engines/graph_db/database/read_transaction.cc @@ -19,9 +19,10 @@ namespace gs { -ReadTransaction::ReadTransaction(const MutablePropertyFragment& graph, +ReadTransaction::ReadTransaction(const GraphDBSession& session, + const MutablePropertyFragment& graph, VersionManager& vm, timestamp_t timestamp) - : graph_(graph), vm_(vm), timestamp_(timestamp) {} + : session_(session), graph_(graph), vm_(vm), timestamp_(timestamp) {} ReadTransaction::~ReadTransaction() { release(); } timestamp_t ReadTransaction::timestamp() const { return timestamp_; } @@ -135,4 +136,6 @@ void ReadTransaction::release() { } } +const GraphDBSession& ReadTransaction::GetSession() const { return session_; } + } // namespace gs diff --git a/flex/engines/graph_db/database/read_transaction.h b/flex/engines/graph_db/database/read_transaction.h index 5a88b1807f7e..23b93acf3fe3 100644 --- a/flex/engines/graph_db/database/read_transaction.h +++ b/flex/engines/graph_db/database/read_transaction.h @@ -26,6 +26,7 @@ namespace gs { class MutablePropertyFragment; +class GraphDBSession; class VersionManager; template class AdjListView { @@ -276,7 +277,8 @@ class SingleImmutableGraphView { class ReadTransaction { public: - ReadTransaction(const MutablePropertyFragment& graph, VersionManager& vm, + ReadTransaction(const GraphDBSession& session, + const MutablePropertyFragment& graph, VersionManager& vm, timestamp_t timestamp); ~ReadTransaction(); @@ -429,9 +431,12 @@ class ReadTransaction { return SingleImmutableGraphView(*csr); } + const GraphDBSession& GetSession() const; + private: void release(); + const GraphDBSession& session_; const MutablePropertyFragment& graph_; VersionManager& vm_; timestamp_t timestamp_; diff --git a/flex/engines/graph_db/runtime/adhoc/operators/operators.h b/flex/engines/graph_db/runtime/adhoc/operators/operators.h index 395d94acae9c..c3aeb496ef65 100644 --- a/flex/engines/graph_db/runtime/adhoc/operators/operators.h +++ b/flex/engines/graph_db/runtime/adhoc/operators/operators.h @@ -78,6 +78,11 @@ bl::result eval_join(const physical::Join& opr, Context&& ctx, bl::result eval_limit(const algebra::Limit& opr, Context&& ctx); +bl::result eval_procedure_call(const std::vector& alias, + const physical::ProcedureCall& opr, + const ReadTransaction& txn, + Context&& ctx); + void eval_sink(const Context& ctx, const ReadTransaction& txn, Encoder& output); } // namespace runtime diff --git a/flex/engines/graph_db/runtime/adhoc/operators/procedure_call.cc b/flex/engines/graph_db/runtime/adhoc/operators/procedure_call.cc new file mode 100644 index 000000000000..09b1fb17b0ee --- /dev/null +++ b/flex/engines/graph_db/runtime/adhoc/operators/procedure_call.cc @@ -0,0 +1,387 @@ +/** Copyright 2020 Alibaba Group Holding Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flex/engines/graph_db/database/graph_db.h" +#include "flex/engines/graph_db/database/graph_db_session.h" +#include "flex/engines/graph_db/runtime/adhoc/operators/operators.h" +#include "flex/engines/graph_db/runtime/common/leaf_utils.h" +#include "flex/proto_generated_gie/algebra.pb.h" + +namespace gs { +namespace runtime { + +std::shared_ptr any_vec_to_column( + const std::vector& any_vec) { + if (any_vec.empty()) { + return nullptr; + } + auto first = any_vec[0].type(); + if (first == RTAnyType::kBoolValue) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_bool()); + } + return builder.finish(); + } else if (first == RTAnyType::kI32Value) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_int32()); + } + return builder.finish(); + } else if (first == RTAnyType::kI64Value) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_int64()); + } + return builder.finish(); + } else if (first == RTAnyType::kU64Value) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_uint64()); + } + return builder.finish(); + } else if (first == RTAnyType::kF64Value) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_double()); + } + return builder.finish(); + } else if (first == RTAnyType::kStringValue) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_elem(any); + } + return builder.finish(); + } else if (first == RTAnyType::kStringSetValue) { + ValueColumnBuilder> builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_string_set()); + } + return builder.finish(); + } else if (first == RTAnyType::kDate32) { + ValueColumnBuilder builder; + for (auto& any : any_vec) { + builder.push_back_opt(any.as_date32()); + } + return builder.finish(); + } else { + LOG(FATAL) << "Unsupported RTAny type: " + << static_cast(first.type_enum_); + } +} + +RTAny object_to_rt_any(const common::Value& val) { + if (val.item_case() == common::Value::kBoolean) { + return RTAny::from_bool(val.boolean()); + } else if (val.item_case() == common::Value::kI32) { + return RTAny::from_int32(val.i32()); + } else if (val.item_case() == common::Value::kI64) { + return RTAny::from_int64(val.i64()); + } else if (val.item_case() == common::Value::kF64) { + return RTAny::from_double(val.f64()); + } else if (val.item_case() == common::Value::kStr) { + return RTAny::from_string(val.str()); + } else { + LOG(FATAL) << "Unsupported value type: " << val.item_case(); + } +} + +Any property_to_any(const results::Property& prop) { + // We just need the value; + const auto& val = prop.value(); + Any res; + if (val.item_case() == common::Value::kBoolean) { + res.set_bool(val.boolean()); + } else if (val.item_case() == common::Value::kI32) { + res.set_i32(val.i32()); + } else if (val.item_case() == common::Value::kI64) { + res.set_i64(val.i64()); + } else if (val.item_case() == common::Value::kF64) { + res.set_double(val.f64()); + } else if (val.item_case() == common::Value::kStr) { + res.set_string_view(std::string_view(val.str())); + } else { + LOG(FATAL) << "Unsupported value type: " << val.item_case(); + } + return res; +} + +RTAny vertex_to_rt_any(const results::Vertex& vertex) { + auto label_id = vertex.label().id(); + auto label_id_vid = decode_unique_vertex_id(vertex.id()); + CHECK(label_id == label_id_vid.first) << "Inconsistent label id."; + return RTAny::from_vertex(label_id, label_id_vid.second); +} + +RTAny edge_to_rt_any(const results::Edge& edge) { + LOG(FATAL) << "Not implemented."; + label_t src_label_id = (label_t) edge.src_label().id(); + label_t dst_label_id = (label_t) edge.dst_label().id(); + auto edge_triplet_tuple = decode_edge_label_id(edge.label().id()); + CHECK((src_label_id == std::get<0>(edge_triplet_tuple)) && + (dst_label_id == std::get<1>(edge_triplet_tuple))) + << "Inconsistent src label id."; + auto src_vertex_id = edge.src_id(); + auto dst_vertex_id = edge.dst_id(); + auto [_, src_vid] = decode_unique_vertex_id(src_vertex_id); + auto [__, dst_vid] = decode_unique_vertex_id(dst_vertex_id); + // properties + auto properties = edge.properties(); + LabelTriplet label_triplet{src_label_id, dst_label_id, + std::get<2>(edge_triplet_tuple)}; + if (properties.size() == 0) { + return RTAny::from_edge( + std::tuple{label_triplet, src_vid, dst_vid, Any(), Direction::kOut}); + } else if (properties.size() == 1) { + LOG(FATAL) << "Not implemented."; + return RTAny::from_edge(std::tuple{label_triplet, src_vid, dst_vid, + property_to_any(properties[0]), + Direction::kOut}); + } else { + std::vector props; + for (auto& prop : properties) { + props.push_back(property_to_any(prop)); + } + Any any; + any.set_record(props); + return RTAny::from_edge( + std::tuple{label_triplet, src_vid, dst_vid, any, Direction::kOut}); + } +} // namespace runtime + +RTAny graph_path_to_rt_any(const results::GraphPath& path) { + LOG(FATAL) << "Not implemented."; +} + +RTAny element_to_rt_any(const results::Element& element) { + if (element.inner_case() == results::Element::kVertex) { + return vertex_to_rt_any(element.vertex()); + } else if (element.inner_case() == results::Element::kEdge) { + return edge_to_rt_any(element.edge()); + } else if (element.inner_case() == results::Element::kObject) { + return object_to_rt_any(element.object()); + } else if (element.inner_case() == results::Element::kGraphPath) { + return graph_path_to_rt_any(element.graph_path()); + } else { + LOG(FATAL) << "Unsupported element type: " << element.inner_case(); + } +} + +RTAny collection_to_rt_any(const results::Collection& collection) { + std::vector values; + for (const auto& element : collection.collection()) { + values.push_back(element_to_rt_any(element)); + } + return RTAny::from_tuple(std::move(values)); +} + +RTAny column_to_rt_any(const results::Column& column) { + auto& entry = column.entry(); + if (entry.has_element()) { + return element_to_rt_any(entry.element()); + } else if (entry.has_collection()) { + return collection_to_rt_any(entry.collection()); + } else { + LOG(FATAL) << "Unsupported column entry type: " << entry.inner_case(); + } +} + +std::vector result_to_rt_any(const results::Results& result) { + auto& record = result.record(); + if (record.columns_size() == 0) { + LOG(WARNING) << "Empty result."; + return {}; + } else { + std::vector tuple; + for (int32_t i = 0; i < record.columns_size(); ++i) { + tuple.push_back(column_to_rt_any(record.columns(i))); + } + return tuple; + } +} + +std::pair>, std::vector> +collective_result_vec_to_column( + int32_t expect_col_num, + const std::vector& collective_results_vec) { + std::vector offsets; + offsets.push_back(0); + size_t record_cnt = 0; + for (size_t i = 0; i < collective_results_vec.size(); ++i) { + record_cnt += collective_results_vec[i].results_size(); + offsets.push_back(record_cnt); + } + std::vector> any_vec(expect_col_num); + for (size_t i = 0; i < collective_results_vec.size(); ++i) { + for (int32_t j = 0; j < collective_results_vec[i].results_size(); ++j) { + auto tuple = result_to_rt_any(collective_results_vec[i].results(j)); + CHECK(tuple.size() == (size_t) expect_col_num) + << "Inconsistent column number."; + for (int32_t k = 0; k < expect_col_num; ++k) { + any_vec[k].push_back(tuple[k]); + } + } + } + std::vector> columns; + for (int32_t i = 0; i < expect_col_num; ++i) { + columns.push_back(any_vec_to_column(any_vec[i])); + } + return std::make_pair(columns, offsets); +} + +bl::result fill_in_query(const procedure::Query& query, + const Context& ctx, size_t idx) { + procedure::Query real_query; + real_query.mutable_query_name()->CopyFrom(query.query_name()); + for (auto& param : query.arguments()) { + auto argument = real_query.add_arguments(); + if (param.value_case() == procedure::Argument::kVar) { + auto& var = param.var(); + auto tag = var.tag().id(); + auto col = ctx.get(tag); + if (col == nullptr) { + LOG(ERROR) << "Tag not found: " << tag; + continue; + } + auto val = col->get_elem(idx); + auto const_value = argument->mutable_const_(); + if (val.type() == gs::runtime::RTAnyType::kVertex) { + RETURN_BAD_REQUEST_ERROR("The input param should not be a vertex"); + } else if (val.type() == gs::runtime::RTAnyType::kEdge) { + RETURN_BAD_REQUEST_ERROR("The input param should not be an edge"); + } else if (val.type() == gs::runtime::RTAnyType::kI64Value) { + const_value->set_i64(val.as_int64()); + } else if (val.type() == gs::runtime::RTAnyType::kI32Value) { + const_value->set_i32(val.as_int32()); + } else if (val.type() == gs::runtime::RTAnyType::kStringValue) { + const_value->set_str(std::string(val.as_string())); + } else if (val.type() == gs::runtime::RTAnyType::kF64Value) { + const_value->set_f64(val.as_double()); + } else if (val.type() == gs::runtime::RTAnyType::kBoolValue) { + const_value->set_boolean(val.as_bool()); + } else if (val.type() == gs::runtime::RTAnyType::kDate32) { + const_value->set_i64(val.as_date32()); + } else { + LOG(ERROR) << "Unsupported type: " + << static_cast(val.type().type_enum_); + } + } else { + argument->CopyFrom(param); + } + } + return real_query; +} + +/** + * @brief Evaluate the ProcedureCall operator. + * The ProcedureCall operator is used to call a stored procedure, which is + * already registered in the system. The return value of the stored procedure + * is a result::CollectiveResults object, we need to convert it to a Column, + * and append to the current context. + * + * + * @param opr The ProcedureCall operator. + * @param txn The read transaction. + * @param ctx The input context. + * + * @return bl::result The output context. + * + * + */ +bl::result eval_procedure_call(const std::vector& aliases, + const physical::ProcedureCall& opr, + const ReadTransaction& txn, + Context&& ctx) { + auto& query = opr.query(); + auto& proc_name = query.query_name(); + + if (proc_name.item_case() == common::NameOrId::kName) { + const auto& sess = txn.GetSession(); + // cast off const, to get the app pointer. + // Why do we need to cast off const? Because current GetApp method is not + // const. + // TODO(zhanglei): Refactor the GetApp method to be const(maybe create the + // app once initialize, not on need). + GraphDBSession& sess_cast = const_cast(sess); + AppBase* app = const_cast(sess_cast.GetApp(proc_name.name())); + if (!app) { + RETURN_BAD_REQUEST_ERROR("Stored procedure not found: " + + proc_name.name()); + } + ReadAppBase* read_app = dynamic_cast(app); + if (!app) { + RETURN_BAD_REQUEST_ERROR("Stored procedure is not a read procedure: " + + proc_name.name()); + } + + std::vector results; + // Iterate over current context. + for (size_t i = 0; i < ctx.row_num(); ++i) { + // Call the procedure. + // Use real values from the context to replace the placeholders in the + // query. + BOOST_LEAF_AUTO(real_query, fill_in_query(query, ctx, i)); + // We need to serialize the protobuf-based arguments to the input format + // that a cypher procedure can accept. + auto query_str = real_query.SerializeAsString(); + // append CYPHER_PROTO as the last byte as input_format + query_str.push_back(static_cast( + GraphDBSession::InputFormat::kCypherProtoProcedure)); + std::vector buffer; + Encoder encoder(buffer); + Decoder decoder(query_str.data(), query_str.size()); + if (!read_app->Query(sess, decoder, encoder)) { + RETURN_CALL_PROCEDURE_ERROR("Failed to call procedure: "); + } + // Decode the result from the encoder. + Decoder result_decoder(buffer.data(), buffer.size()); + if (result_decoder.size() < 4) { + LOG(ERROR) << "Unexpected result size: " << result_decoder.size(); + RETURN_CALL_PROCEDURE_ERROR("Unexpected result size"); + } + std::string collective_results_str(result_decoder.get_string()); + results::CollectiveResults collective_results; + if (!collective_results.ParseFromString(collective_results_str)) { + LOG(ERROR) << "Failed to parse CollectiveResults"; + RETURN_CALL_PROCEDURE_ERROR("Failed to parse procedure's result"); + } + results.push_back(collective_results); + } + + auto column_and_offsets = + collective_result_vec_to_column(aliases.size(), results); + auto& columns = column_and_offsets.first; + auto& offsets = column_and_offsets.second; + if (columns.size() != aliases.size()) { + LOG(ERROR) << "Column size mismatch: " << columns.size() << " vs " + << aliases.size(); + RETURN_CALL_PROCEDURE_ERROR("Column size mismatch"); + } + if (columns.size() >= 1) { + ctx.set_with_reshuffle(aliases[0], columns[0], offsets); + } + for (size_t i = 1; i < columns.size(); ++i) { + ctx.set(aliases[i], columns[i]); + } + return std::move(ctx); + } else { + LOG(ERROR) << "Currently only support calling stored procedure by name"; + RETURN_UNSUPPORTED_ERROR( + "Currently only support calling stored procedure by name"); + } +} + +} // namespace runtime +} // namespace gs diff --git a/flex/engines/graph_db/runtime/adhoc/runtime.cc b/flex/engines/graph_db/runtime/adhoc/runtime.cc index 48f35b9b2992..b0d9e27ba78f 100644 --- a/flex/engines/graph_db/runtime/adhoc/runtime.cc +++ b/flex/engines/graph_db/runtime/adhoc/runtime.cc @@ -229,6 +229,15 @@ bl::result runtime_eval_impl( case physical::PhysicalOpr_Operator::OpKindCase::kLimit: { BOOST_LEAF_ASSIGN(ret, eval_limit(opr.opr().limit(), std::move(ret))); } break; + case physical::PhysicalOpr_Operator::OpKindCase::kProcedureCall: { + std::vector aliases; + for (int32_t i = 0; i < opr.meta_data_size(); ++i) { + aliases.push_back(opr.meta_data(i).alias()); + } + BOOST_LEAF_ASSIGN( + ret, eval_procedure_call(aliases, opr.opr().procedure_call(), txn, + std::move(ret))); + } break; default: LOG(ERROR) << "Unknown operator type: " diff --git a/flex/engines/graph_db/runtime/common/leaf_utils.h b/flex/engines/graph_db/runtime/common/leaf_utils.h index e5be15bdd415..88fee5382789 100644 --- a/flex/engines/graph_db/runtime/common/leaf_utils.h +++ b/flex/engines/graph_db/runtime/common/leaf_utils.h @@ -39,4 +39,8 @@ namespace bl = boost::leaf; return ::boost::leaf::new_error( \ ::gs::Status(::gs::StatusCode::UNIMPLEMENTED, PREPEND_LINE_INFO(msg))) +#define RETURN_CALL_PROCEDURE_ERROR(msg) \ + return ::boost::leaf::new_error( \ + ::gs::Status(::gs::StatusCode::QUERY_FAILED, PREPEND_LINE_INFO(msg))) + #endif // RUNTIME_COMMON_LEAF_UTILS_H_ diff --git a/flex/engines/graph_db/runtime/common/types.cc b/flex/engines/graph_db/runtime/common/types.cc index 4f0bf79a5b98..f217595b8ba4 100644 --- a/flex/engines/graph_db/runtime/common/types.cc +++ b/flex/engines/graph_db/runtime/common/types.cc @@ -24,10 +24,17 @@ uint64_t encode_unique_vertex_id(label_t label_id, vid_t vid) { return global_id.global_id; } +std::pair decode_unique_vertex_id(uint64_t unique_id) { + return std::pair{GlobalId::get_label_id(unique_id), + GlobalId::get_vid(unique_id)}; +} + uint32_t generate_edge_label_id(label_t src_label_id, label_t dst_label_id, label_t edge_label_id) { uint32_t unique_edge_label_id = src_label_id; static constexpr int num_bits = sizeof(label_t) * 8; + static_assert(num_bits * 3 <= sizeof(uint32_t) * 8, + "label_t is too large to be encoded in 32 bits"); unique_edge_label_id = unique_edge_label_id << num_bits; unique_edge_label_id = unique_edge_label_id | dst_label_id; unique_edge_label_id = unique_edge_label_id << num_bits; @@ -35,6 +42,20 @@ uint32_t generate_edge_label_id(label_t src_label_id, label_t dst_label_id, return unique_edge_label_id; } +std::tuple decode_edge_label_id( + uint32_t edge_label_id) { + static constexpr int num_bits = sizeof(label_t) * 8; + static_assert(num_bits * 3 <= sizeof(uint32_t) * 8, + "label_t is too large to be encoded in 32 bits"); + auto mask = (1 << num_bits) - 1; + label_t edge_label = edge_label_id & mask; + edge_label_id = edge_label_id >> num_bits; + label_t dst_label = edge_label_id & mask; + edge_label_id = edge_label_id >> num_bits; + label_t src_label = edge_label_id & mask; + return std::make_tuple(src_label, dst_label, edge_label); +} + int64_t encode_unique_edge_id(uint32_t label_id, vid_t src, vid_t dst) { // We assume label_id is only used by 24 bits. int64_t unique_edge_id = label_id; diff --git a/flex/engines/graph_db/runtime/common/types.h b/flex/engines/graph_db/runtime/common/types.h index 28f2a5784bff..af5cd4f80c93 100644 --- a/flex/engines/graph_db/runtime/common/types.h +++ b/flex/engines/graph_db/runtime/common/types.h @@ -26,9 +26,14 @@ namespace gs { namespace runtime { uint64_t encode_unique_vertex_id(label_t label_id, vid_t vid); +std::pair decode_unique_vertex_id(uint64_t unique_id); + uint32_t generate_edge_label_id(label_t src_label_id, label_t dst_label_id, label_t edge_label_id); int64_t encode_unique_edge_id(uint32_t label_id, vid_t src, vid_t dst); + +std::tuple decode_edge_label_id( + uint32_t edge_label_id); enum class Direction { kOut, kIn, diff --git a/flex/engines/hqps_db/app/interactive_app_base.h b/flex/engines/hqps_db/app/interactive_app_base.h index 5186379ad948..987401bc895a 100644 --- a/flex/engines/hqps_db/app/interactive_app_base.h +++ b/flex/engines/hqps_db/app/interactive_app_base.h @@ -35,7 +35,12 @@ inline bool parse_input_argument_from_proto_impl( } else { auto& type = std::get(tuple); auto& argument = args.Get(I); - auto& value = argument.value(); + if (argument.value_case() != procedure::Argument::kConst) { + LOG(ERROR) << "Expect a const value for input param, but got " + << argument.value_case(); + return false; + } + auto& value = argument.const_(); auto item_case = value.item_case(); if (item_case == common::Value::kI32) { if constexpr (std::is_same 0 + + +@pytest.mark.skipif( + os.environ.get("RUN_ON_PROTO", None) != "ON", reason="Only works on proto" +) +def test_call_proc_in_cypher(interactive_session, neo4j_session, create_modern_graph): + print("[Test call procedure in cypher]") + import_data_to_full_modern_graph(interactive_session, create_modern_graph) + result = neo4j_session.run( + 'MATCH(p: person) with p.id as oid CALL k_neighbors("person", oid, 1) return label_name, vertex_oid;' + ) + cnt = 0 + for record in result: + cnt += 1 + assert cnt == 8 diff --git a/flex/tests/hqps/hqps_robust_test.sh b/flex/tests/hqps/hqps_robust_test.sh index b93e3a65f94d..8090d8be8164 100644 --- a/flex/tests/hqps/hqps_robust_test.sh +++ b/flex/tests/hqps/hqps_robust_test.sh @@ -20,14 +20,15 @@ ADMIN_PORT=7777 QUERY_PORT=10000 CYPHER_PORT=7687 -if [ ! $# -eq 2 ]; then - echo "only receives: $# args, need 2" - echo "Usage: $0 " +if [ ! $# -eq 3 ]; then + echo "only receives: $# args, need 3" + echo "Usage: $0 " exit 1 fi INTERACTIVE_WORKSPACE=$1 ENGINE_CONFIG_PATH=$2 +CBO_ENGINE_CONFIG_PATH=$3 if [ ! -d ${INTERACTIVE_WORKSPACE} ]; then echo "INTERACTIVE_WORKSPACE: ${INTERACTIVE_WORKSPACE} not exists" @@ -38,6 +39,11 @@ if [ ! -f ${ENGINE_CONFIG_PATH} ]; then exit 1 fi +if [ ! -f ${CBO_ENGINE_CONFIG_PATH} ]; then + echo "CBO_ENGINE_CONFIG_PATH: ${CBO_ENGINE_CONFIG_PATH} not exists" + exit 1 +fi + RED='\033[0;31m' GREEN='\033[0;32m' NC='\033[0m' # No Color @@ -63,13 +69,19 @@ trap kill_service EXIT # start engine service start_engine_service(){ + # expect one argument + if [ ! $# -eq 1 ]; then + err "start_engine_service need one argument" + exit 1 + fi + local config_path=$1 #check SERVER_BIN exists if [ ! -f ${SERVER_BIN} ]; then err "SERVER_BIN not found" exit 1 fi - cmd="${SERVER_BIN} -c ${ENGINE_CONFIG_PATH} --enable-admin-service true " + cmd="${SERVER_BIN} -c ${config_path} --enable-admin-service true " cmd="${cmd} -w ${INTERACTIVE_WORKSPACE} --start-compiler true &" echo "Start engine service with command: ${cmd}" @@ -91,13 +103,27 @@ run_robust_test(){ popd } +run_additional_robust_test(){ + pushd ${FLEX_HOME}/interactive/sdk/python/gs_interactive + export RUN_ON_PROTO=ON + cmd="python3 -m pytest -s tests/test_robustness.py -k test_call_proc_in_cypher" + echo "Run additional robust test with command: ${cmd}" + eval ${cmd} || (err "Run additional robust test failed"; exit 1) + info "Run additional robust test success" + popd +} + kill_service -start_engine_service +start_engine_service $ENGINE_CONFIG_PATH export INTERACTIVE_ADMIN_ENDPOINT=http://localhost:${ADMIN_PORT} export INTERACTIVE_STORED_PROC_ENDPOINT=http://localhost:${QUERY_PORT} export INTERACTIVE_CYPHER_ENDPOINT=neo4j://localhost:${CYPHER_PORT} export INTERACTIVE_GREMLIN_ENDPOINT=ws://localhost:${GREMLIN_PORT}/gremlin run_robust_test +kill_service +sleep 5 +start_engine_service $CBO_ENGINE_CONFIG_PATH +run_additional_robust_test kill_service \ No newline at end of file diff --git a/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 b/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 index d4edbfacd7b9..1245dd637abb 100644 --- a/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 +++ b/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 @@ -52,7 +52,7 @@ CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ; YIELD : ( 'Y' | 'y' ) ( 'I' | 'i' ) ( 'E' | 'e' ) ( 'L' | 'l' ) ( 'D' | 'd' ) ; oC_RegularQuery - : oC_Match ( SP? ( oC_Match | oC_With | oC_Unwind ) )* ( SP oC_Return ) ; + : oC_Match ( SP? ( oC_Match | oC_With | oC_StandaloneCall | oC_Unwind ) )* ( SP oC_Return ) ; oC_Match : ( OPTIONAL SP )? MATCH SP? oC_Pattern ( SP? oC_Where )? ; diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GSDataTypeConvertor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GSDataTypeConvertor.java index c8e89f7f2e5d..f12a70e8caf8 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GSDataTypeConvertor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GSDataTypeConvertor.java @@ -108,6 +108,8 @@ public RelDataType convert(GSDataTypeDesc from) { Object value; if ((value = typeMap.get("primitive_type")) != null) { switch (value.toString()) { + case "DT_ANY": + return typeFactory.createSqlType(SqlTypeName.ANY); case "DT_SIGNED_INT32": return typeFactory.createSqlType(SqlTypeName.INTEGER); case "DT_SIGNED_INT64": diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphHepPlanner.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphHepPlanner.java index 5ec5e479e702..4963c8e6e3fc 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphHepPlanner.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/planner/GraphHepPlanner.java @@ -145,6 +145,11 @@ public RelNode visit(LogicalJoin join) { return findBestIfRoot(join, visitChildren(join)); } + @Override + public RelNode visit(GraphProcedureCall procedureCall) { + return findBestIfRoot(procedureCall, visitChildren(procedureCall)); + } + @Override public RelNode visit(CommonTableScan tableScan) { RelOptTable optTable = tableScan.getTable(); diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphProcedureCall.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphProcedureCall.java new file mode 100644 index 000000000000..f49bc68eb8f4 --- /dev/null +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphProcedureCall.java @@ -0,0 +1,93 @@ +/* + * + * * Copyright 2020 Alibaba Group Holding Limited. + * * + * * Licensed under the Apache License, Version 2.0 (the "License"); + * * you may not use this file except in compliance with the License. + * * You may obtain a copy of the License at + * * + * * http://www.apache.org/licenses/LICENSE-2.0 + * * + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, + * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * * See the License for the specific language governing permissions and + * * limitations under the License. + * + */ + +package com.alibaba.graphscope.common.ir.rel; + +import com.alibaba.graphscope.common.ir.tools.AliasInference; + +import org.apache.calcite.plan.GraphOptCluster; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.SingleRel; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; +import org.apache.calcite.rel.type.RelRecordType; +import org.apache.calcite.rex.RexNode; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +public class GraphProcedureCall extends SingleRel { + private final RexNode procedure; + + public GraphProcedureCall( + RelOptCluster optCluster, RelTraitSet traitSet, RelNode input, RexNode procedure) { + super(optCluster, traitSet, input); + this.procedure = procedure; + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return super.explainTerms(pw).item("procedure", procedure); + } + + @Override + public RelNode accept(RelShuttle shuttle) { + if (shuttle instanceof GraphShuttle) { + return ((GraphShuttle) shuttle).visit(this); + } + return shuttle.visit(this); + } + + @Override + public RelDataType deriveRowType() { + Set uniqueNameList = AliasInference.getUniqueAliasList(input, true); + List reOrgFields = + this.procedure.getType().getFieldList().stream() + .map( + k -> { + // ensure the name is unique in the query + String checkName = + AliasInference.inferDefault( + k.getName(), uniqueNameList); + uniqueNameList.add(checkName); + return new RelDataTypeFieldImpl( + checkName, + ((GraphOptCluster) getCluster()) + .getIdGenerator() + .generate(checkName), + k.getType()); + }) + .collect(Collectors.toList()); + return new RelRecordType(reOrgFields); + } + + @Override + public GraphProcedureCall copy(RelTraitSet traitSet, List inputs) { + return new GraphProcedureCall(getCluster(), traitSet, sole(inputs), procedure); + } + + public RexNode getProcedure() { + return procedure; + } +} diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphShuttle.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphShuttle.java index c12e6851ce5a..ed9cd4fbd47a 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphShuttle.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/GraphShuttle.java @@ -105,6 +105,10 @@ public RelNode visit(GraphLogicalUnfold unfold) { return visitChildren(unfold); } + public RelNode visit(GraphProcedureCall procedureCall) { + return visitChildren(procedureCall); + } + @Override public RelNode visit(RelNode other) { if (other instanceof MultiJoin) { diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ProcedurePhysicalBuilder.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ProcedurePhysicalBuilder.java index a7b2b3dda339..add9393f0c17 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ProcedurePhysicalBuilder.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ProcedurePhysicalBuilder.java @@ -16,46 +16,28 @@ package com.alibaba.graphscope.common.ir.runtime; +import com.alibaba.graphscope.common.config.Configs; +import com.alibaba.graphscope.common.ir.meta.IrMeta; +import com.alibaba.graphscope.common.ir.runtime.proto.RexToProtoConverter; import com.alibaba.graphscope.common.ir.runtime.proto.Utils; +import com.alibaba.graphscope.common.ir.tools.GraphPlanner; import com.alibaba.graphscope.common.ir.tools.LogicalPlan; -import com.alibaba.graphscope.gaia.proto.Common; import com.alibaba.graphscope.gaia.proto.StoredProcedure; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; -import org.apache.calcite.rex.*; -import org.apache.calcite.sql.SqlOperator; - -import java.util.List; - public class ProcedurePhysicalBuilder extends PhysicalBuilder { private final StoredProcedure.Query.Builder builder; - public ProcedurePhysicalBuilder(LogicalPlan logicalPlan) { + public ProcedurePhysicalBuilder(Configs configs, IrMeta irMeta, LogicalPlan logicalPlan) { super(logicalPlan); - this.builder = StoredProcedure.Query.newBuilder(); - RexCall procedureCall = (RexCall) logicalPlan.getProcedureCall(); - setStoredProcedureName(procedureCall, builder); - setStoredProcedureArgs(procedureCall, builder); - } - - private void setStoredProcedureName( - RexCall procedureCall, StoredProcedure.Query.Builder builder) { - SqlOperator operator = procedureCall.getOperator(); - builder.setQueryName(Common.NameOrId.newBuilder().setName(operator.getName()).build()); - } - - private void setStoredProcedureArgs( - RexCall procedureCall, StoredProcedure.Query.Builder builder) { - List operands = procedureCall.getOperands(); - for (int i = 0; i < operands.size(); ++i) { - builder.addArguments( - StoredProcedure.Argument.newBuilder() - // param name is omitted - .setParamInd(i) - .setValue(Utils.protoValue((RexLiteral) operands.get(i))) - .build()); - } + this.builder = + Utils.protoProcedure( + logicalPlan.getProcedureCall(), + new RexToProtoConverter( + true, + irMeta.getSchema().isColumnId(), + GraphPlanner.rexBuilderFactory.apply(configs))); } @Override diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/GraphRelToProtoConverter.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/GraphRelToProtoConverter.java index 53a56eaab4fd..e1e4524c2cc3 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/GraphRelToProtoConverter.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/GraphRelToProtoConverter.java @@ -105,6 +105,30 @@ public GraphRelToProtoConverter( this.depth = depth; } + @Override + public RelNode visit(GraphProcedureCall procedureCall) { + visitChildren(procedureCall); + physicalBuilder.addPlan( + GraphAlgebraPhysical.PhysicalOpr.newBuilder() + .setOpr( + GraphAlgebraPhysical.PhysicalOpr.Operator.newBuilder() + .setProcedureCall( + GraphAlgebraPhysical.ProcedureCall.newBuilder() + .setQuery( + Utils.protoProcedure( + procedureCall + .getProcedure(), + new RexToProtoConverter( + true, + isColumnId, + this.rexBuilder)))) + .build()) + .addAllMetaData( + Utils.physicalProtoRowType(procedureCall.getRowType(), isColumnId)) + .build()); + return procedureCall; + } + @Override public RelNode visit(GraphLogicalSource source) { GraphAlgebraPhysical.PhysicalOpr.Builder oprBuilder = diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java index 41955150c9a6..bf6c6ba35056 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java @@ -36,6 +36,7 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlOperator; @@ -776,4 +777,32 @@ public static void removeEdgeProperties( } tagColumns.keySet().removeAll(removeKeys); } + + public static final StoredProcedure.Query.Builder protoProcedure( + RexNode procedure, RexToProtoConverter converter) { + RexCall procedureCall = (RexCall) procedure; + StoredProcedure.Query.Builder builder = StoredProcedure.Query.newBuilder(); + SqlOperator operator = procedureCall.getOperator(); + builder.setQueryName(Common.NameOrId.newBuilder().setName(operator.getName()).build()); + List operands = procedureCall.getOperands(); + for (int i = 0; i < operands.size(); ++i) { + // param name is omitted + StoredProcedure.Argument.Builder paramBuilder = + StoredProcedure.Argument.newBuilder().setParamInd(i); + OuterExpression.ExprOpr protoValue = operands.get(i).accept(converter).getOperators(0); + switch (protoValue.getItemCase()) { + case VAR: + paramBuilder.setVar(protoValue.getVar()); + break; + case CONST: + paramBuilder.setConst(protoValue.getConst()); + break; + default: + throw new IllegalArgumentException( + "cannot set value=" + protoValue + " to any parameter in procedure"); + } + builder.addArguments(paramBuilder); + } + return builder; + } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphPlanner.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphPlanner.java index 42c1820c1478..24c821b2f2e9 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphPlanner.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphPlanner.java @@ -169,7 +169,7 @@ public PhysicalPlan planPhysical(LogicalPlan logicalPlan) { } } } else { - return new ProcedurePhysicalBuilder(logicalPlan).build(); + return new ProcedurePhysicalBuilder(graphConfig, irMeta, logicalPlan).build(); } } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/GraphBuilderVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/GraphBuilderVisitor.java index ffa4485ec732..caaee0081332 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/GraphBuilderVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/GraphBuilderVisitor.java @@ -19,6 +19,7 @@ import com.alibaba.graphscope.common.antlr4.ExprUniqueAliasInfer; import com.alibaba.graphscope.common.antlr4.ExprVisitorResult; import com.alibaba.graphscope.common.ir.rel.GraphLogicalAggregate; +import com.alibaba.graphscope.common.ir.rel.GraphProcedureCall; import com.alibaba.graphscope.common.ir.rel.type.group.GraphAggCall; import com.alibaba.graphscope.common.ir.rex.RexTmpVariableConverter; import com.alibaba.graphscope.common.ir.rex.RexVariableAliasCollector; @@ -26,10 +27,13 @@ import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; import com.alibaba.graphscope.grammar.CypherGSBaseVisitor; import com.alibaba.graphscope.grammar.CypherGSParser; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import com.google.common.base.Supplier; import com.google.common.collect.Lists; import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; @@ -47,15 +51,41 @@ public class GraphBuilderVisitor extends CypherGSBaseVisitor { private final GraphBuilder builder; private final ExpressionVisitor expressionVisitor; private final ExprUniqueAliasInfer aliasInfer; + private final Supplier procedureCallVisitorSupplier; - public GraphBuilderVisitor(GraphBuilder builder) { - this(builder, new ExprUniqueAliasInfer()); + public GraphBuilderVisitor( + GraphBuilder builder, Supplier procedureCallVisitorSupplier) { + this(builder, new ExprUniqueAliasInfer(), procedureCallVisitorSupplier); } - public GraphBuilderVisitor(GraphBuilder builder, ExprUniqueAliasInfer aliasInfer) { + public GraphBuilderVisitor( + GraphBuilder builder, + ExprUniqueAliasInfer aliasInfer, + Supplier procedureCallVisitorSupplier) { this.builder = Objects.requireNonNull(builder); this.aliasInfer = Objects.requireNonNull(aliasInfer); this.expressionVisitor = new ExpressionVisitor(this); + this.procedureCallVisitorSupplier = Objects.requireNonNull(procedureCallVisitorSupplier); + } + + @VisibleForTesting + public GraphBuilderVisitor(GraphBuilder builder) { + this(builder, new ExprUniqueAliasInfer(), () -> null); + } + + @Override + public GraphBuilder visitOC_StandaloneCall(CypherGSParser.OC_StandaloneCallContext ctx) { + ProcedureCallVisitor procedureCallVisitor = procedureCallVisitorSupplier.get(); + Preconditions.checkArgument( + procedureCallVisitor != null, + "cannot do procedure call without procedure call visitor"); + RexNode procedure = procedureCallVisitor.visitOC_StandaloneCall(ctx); + return builder.push( + new GraphProcedureCall( + builder.getCluster(), + RelTraitSet.createEmpty(), + builder.build(), + procedure)); } @Override diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/LogicalPlanVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/LogicalPlanVisitor.java index 8472fd682812..8033ad049a4f 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/LogicalPlanVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/LogicalPlanVisitor.java @@ -59,7 +59,10 @@ public LogicalPlan visitOC_Cypher(CypherGSParser.OC_CypherContext ctx) { @Override public LogicalPlan visitOC_Query(CypherGSParser.OC_QueryContext ctx) { if (ctx.oC_RegularQuery() != null) { - GraphBuilderVisitor builderVisitor = new GraphBuilderVisitor(this.builder); + GraphBuilderVisitor builderVisitor = + new GraphBuilderVisitor( + this.builder, + () -> new ProcedureCallVisitor(this.builder, this.irMeta)); RelNode regularQuery = builderVisitor.visitOC_RegularQuery(ctx.oC_RegularQuery()).build(); Map map = builderVisitor.getExpressionVisitor().getDynamicParams(); diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ProcedureCallVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ProcedureCallVisitor.java index b5d690d0751b..915ed3d04f18 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ProcedureCallVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ProcedureCallVisitor.java @@ -40,7 +40,8 @@ public class ProcedureCallVisitor extends CypherGSBaseVisitor { public ProcedureCallVisitor(GraphBuilder builder, IrMeta irMeta) { this.builder = builder; - this.expressionVisitor = new ExpressionVisitor(new GraphBuilderVisitor(this.builder)); + this.expressionVisitor = + new ExpressionVisitor(new GraphBuilderVisitor(this.builder, () -> this)); this.irMeta = irMeta; } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/FfiLogicalPlanTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/FfiLogicalPlanTest.java index fa2cc42c35e1..dcb3decb3899 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/FfiLogicalPlanTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/FfiLogicalPlanTest.java @@ -159,7 +159,8 @@ public void logical_plan_4_test() throws Exception { LogicalPlan logicalPlan = com.alibaba.graphscope.cypher.antlr4.Utils.evalLogicalPlan( "Call ldbc_ic2(10l, 20120112l)", "config/modern/graph.yaml"); - try (PhysicalBuilder ffiBuilder = new ProcedurePhysicalBuilder(logicalPlan)) { + try (PhysicalBuilder ffiBuilder = + new ProcedurePhysicalBuilder(Utils.configs, Utils.schemaMeta, logicalPlan)) { PhysicalPlan plan = ffiBuilder.build(); Assert.assertEquals( FileUtils.readJsonFromResource("call_procedure.json"), plan.explain()); diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java index 417d7f4f4995..d23e36aebe51 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java @@ -25,6 +25,8 @@ import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalSource; import com.alibaba.graphscope.common.ir.tools.GraphBuilder; import com.alibaba.graphscope.common.ir.tools.LogicalPlan; +import com.alibaba.graphscope.cypher.antlr4.parser.CypherAntlr4Parser; +import com.alibaba.graphscope.cypher.antlr4.visitor.LogicalPlanVisitor; import com.google.common.collect.ImmutableMap; import org.apache.calcite.rel.RelNode; @@ -54,7 +56,7 @@ public static void beforeClass() { optimizer = new GraphRelOptimizer(configs); irMeta = com.alibaba.graphscope.common.ir.Utils.mockIrMeta( - "schema/modern.json", + "config/modern/graph.yaml", "statistics/modern_statistics.json", optimizer.getGlogueHolder()); } @@ -607,4 +609,44 @@ public void udf_function_test() { + " alias=[person], opt=[VERTEX])", after3.explain().trim()); } + + @Test + public void shortest_path_test() { + GraphBuilder builder = + com.alibaba.graphscope.common.ir.Utils.mockGraphBuilder(optimizer, irMeta); + LogicalPlanVisitor logicalPlanVisitor = new LogicalPlanVisitor(builder, irMeta); + LogicalPlan logicalPlan = + logicalPlanVisitor.visit( + new CypherAntlr4Parser() + .parse( + "MATCH\n" + + "(person1:person {id:" + + " $person1Id})-[:knows]->(person2:person {id:" + + " $person2Id})\n" + + "CALL shortestPath.dijkstra.stream(\n" + + " person1, person2, 'KNOWS', 'BOTH', 'weight', 5)\n" + + "WITH person1, person2, totalCost\n" + + "WHERE person1.id <> $person2Id\n" + + "Return person1.id AS person1Id, totalCost AS" + + " totalWeight;")); + RelNode after = + optimizer.optimize( + logicalPlan.getRegularQuery(), new GraphIOProcessor(builder, irMeta)); + Assert.assertEquals( + "GraphLogicalProject(person1Id=[person1.id], totalWeight=[totalCost]," + + " isAppend=[false])\n" + + " LogicalFilter(condition=[<>(person1.id, ?1)])\n" + + " GraphLogicalProject(person1=[person1], person2=[person2]," + + " totalCost=[totalCost], isAppend=[false])\n" + + " GraphProcedureCall(procedure=[shortestPath.dijkstra.stream(person1," + + " person2, _UTF-8'KNOWS', _UTF-8'BOTH', _UTF-8'weight', 5)])\n" + + " GraphPhysicalGetV(tableConfig=[{isAll=false, tables=[person]}]," + + " alias=[person2], fusedFilter=[[=(_.id, ?1)]], opt=[END]," + + " physicalOpt=[ITSELF])\n" + + " GraphPhysicalExpand(tableConfig=[{isAll=false, tables=[knows]}]," + + " alias=[_], startAlias=[person1], opt=[OUT], physicalOpt=[VERTEX])\n" + + " GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}]," + + " alias=[person1], opt=[VERTEX], uniqueKeyFilters=[=(_.id, ?0)])", + after.explain().trim()); + } } diff --git a/interactive_engine/compiler/src/test/resources/call_procedure.json b/interactive_engine/compiler/src/test/resources/call_procedure.json index 4381fb62d6ad..06ab8c2aced5 100644 --- a/interactive_engine/compiler/src/test/resources/call_procedure.json +++ b/interactive_engine/compiler/src/test/resources/call_procedure.json @@ -3,12 +3,12 @@ "name": "ldbc_ic2" }, "arguments": [{ - "value": { + "const": { "i64": "10" } }, { "paramInd": 1, - "value": { + "const": { "i64": "20120112" } }] diff --git a/interactive_engine/compiler/src/test/resources/config/modern/graph.yaml b/interactive_engine/compiler/src/test/resources/config/modern/graph.yaml index f4e425d1e381..053d11184603 100644 --- a/interactive_engine/compiler/src/test/resources/config/modern/graph.yaml +++ b/interactive_engine/compiler/src/test/resources/config/modern/graph.yaml @@ -19,6 +19,38 @@ stored_procedures: query: "MATCH(n: PERSON ${personId2}) WHERE n.creationDate < ${maxDate} RETURN n.firstName AS name LIMIT 10;" library: libquery_ic2.so encoding: string + - name: shortestPath.dijkstra.stream + description: "" + type: x_cypher + params: + - name: person1 + type: + primitive_type: DT_ANY + - name: person2 + type: + primitive_type: DT_ANY + - name: label + type: + string: + long_text: + - name: direction + type: + string: + long_text: + - name: property + type: + string: + long_text: + - name: iterations + type: + primitive_type: DT_SIGNED_INT32 + returns: + - name: totalCost + type: + primitive_type: DT_FLOAT + query: "" + library: libquery_shortest_path.so + encoding: string schema: vertex_types: - type_name: person diff --git a/interactive_engine/executor/ir/common/build.rs b/interactive_engine/executor/ir/common/build.rs index 6442a98ee65b..96cb3a96dea9 100644 --- a/interactive_engine/executor/ir/common/build.rs +++ b/interactive_engine/executor/ir/common/build.rs @@ -31,6 +31,7 @@ fn codegen_inplace() -> Result<(), Box> { println!("cargo:rerun-if-changed=../proto/results.proto"); println!("cargo:rerun-if-changed=../proto/physical.proto"); println!("cargo:rerun-if-changed=../proto/type.proto"); + println!("cargo:rerun-if-changed=../proto/stored_procedure.proto"); let out_dir = PathBuf::from(GEN_DIR); if out_dir.exists() { let _ = std::fs::remove_dir_all(GEN_DIR); @@ -48,6 +49,7 @@ fn codegen_inplace() -> Result<(), Box> { "../proto/results.proto", "../proto/physical.proto", "../proto/type.proto", + "../proto/stored_procedure.proto", ], &["../proto"], )?; @@ -64,6 +66,7 @@ fn codegen_inplace() -> Result<(), Box> { println!("cargo:rerun-if-changed=../proto/results.proto"); println!("cargo:rerun-if-changed=../proto/physical.proto"); println!("cargo:rerun-if-changed=../proto/type.proto"); + println!("cargo:rerun-if-changed=../proto/stored_procedure.proto"); prost_build::Config::new() .type_attribute(".", "#[derive(Serialize,Deserialize)]") .compile_protos( @@ -75,6 +78,7 @@ fn codegen_inplace() -> Result<(), Box> { "../proto/results.proto", "../proto/physical.proto", "../proto/type.proto", + "../proto/stored_procedure.proto", ], &["../proto"], )?; diff --git a/interactive_engine/executor/ir/common/src/lib.rs b/interactive_engine/executor/ir/common/src/lib.rs index ea5166005db9..7fd8d0710e2c 100644 --- a/interactive_engine/executor/ir/common/src/lib.rs +++ b/interactive_engine/executor/ir/common/src/lib.rs @@ -46,6 +46,8 @@ pub mod generated { pub mod schema; #[path = "physical.rs"] pub mod physical; + #[path = "procedure.rs"] + pub mod procedure; } #[cfg(not(feature = "proto_inplace"))] @@ -65,6 +67,9 @@ pub mod generated { pub mod physical { tonic::include_proto!("physical"); } + pub mod procedure { + tonic::include_proto!("procedure"); + } } pub type KeyId = i32; diff --git a/interactive_engine/executor/ir/proto/physical.proto b/interactive_engine/executor/ir/proto/physical.proto index 8e4f30d40ac4..3cafff2ebe94 100644 --- a/interactive_engine/executor/ir/proto/physical.proto +++ b/interactive_engine/executor/ir/proto/physical.proto @@ -24,6 +24,7 @@ import "expr.proto"; import "schema.proto"; import "type.proto"; import "algebra.proto"; +import "stored_procedure.proto"; import "google/protobuf/wrappers.proto"; // To project a relation on certain attributes or further their properties @@ -277,6 +278,10 @@ message Repartition { // A dummy node to delegate a source opr for multiple scan cases. message Root {} +message ProcedureCall { + procedure.Query query = 1; +} + message PhysicalOpr { message Operator { oneof op_kind { @@ -300,6 +305,7 @@ message PhysicalOpr { GetV vertex = 30; EdgeExpand edge = 31; PathExpand path = 32; + ProcedureCall procedure_call = 33; } } message MetaData { diff --git a/interactive_engine/executor/ir/proto/stored_procedure.proto b/interactive_engine/executor/ir/proto/stored_procedure.proto index 51bd7d64c431..48446b077233 100644 --- a/interactive_engine/executor/ir/proto/stored_procedure.proto +++ b/interactive_engine/executor/ir/proto/stored_procedure.proto @@ -20,11 +20,16 @@ option java_package = "com.alibaba.graphscope.gaia.proto"; option java_outer_classname = "StoredProcedure"; import "common.proto"; +import "expr.proto"; message Argument { string param_name = 1; // param name int32 param_ind = 2; // index of param - common.Value value = 3; // real value + + oneof value { + common.Value const = 3; // real value + common.Variable var = 4; + } } message Query { diff --git a/interactive_engine/executor/ir/runtime/src/assembly.rs b/interactive_engine/executor/ir/runtime/src/assembly.rs index d6ea8040a923..2dd1f2825a15 100644 --- a/interactive_engine/executor/ir/runtime/src/assembly.rs +++ b/interactive_engine/executor/ir/runtime/src/assembly.rs @@ -949,6 +949,10 @@ impl IRJobAssembly { // this would be processed in assemble, and cannot be reached when install. Err(FnGenError::unsupported_error("unreachable sink in install"))? } + OpKind::ProcedureCall(procedure_call) => Err(FnGenError::unsupported_error(&format!( + "ProcedureCall Operator {:?}", + procedure_call + )))?, } prev_op_kind = to_op_kind(op)?;