diff --git a/tools/python_api/CMakeLists.txt b/tools/python_api/CMakeLists.txt index e1c3e1f4d9..1319a8c43c 100644 --- a/tools/python_api/CMakeLists.txt +++ b/tools/python_api/CMakeLists.txt @@ -9,6 +9,8 @@ file(GLOB SOURCE_PY pybind11_add_module(_kuzu SHARED src_cpp/kuzu_binding.cpp + src_cpp/cached_import/py_cached_item.cpp + src_cpp/cached_import/py_cached_import.cpp src_cpp/py_connection.cpp src_cpp/py_database.cpp src_cpp/py_prepared_statement.cpp diff --git a/tools/python_api/src_cpp/cached_import/py_cached_import.cpp b/tools/python_api/src_cpp/cached_import/py_cached_import.cpp new file mode 100644 index 0000000000..e0317ab9fd --- /dev/null +++ b/tools/python_api/src_cpp/cached_import/py_cached_import.cpp @@ -0,0 +1,18 @@ +#include "cached_import/py_cached_import.h" + +namespace kuzu { + +PythonCachedImport::~PythonCachedImport() { + py::gil_scoped_acquire acquire; + allObjects.clear(); +} + +py::handle PythonCachedImport::addToCache(py::object obj) { + auto ptr = obj.ptr(); + allObjects.push_back(obj); + return ptr; +} + +std::shared_ptr importCache; + +} // namespace kuzu diff --git a/tools/python_api/src_cpp/cached_import/py_cached_item.cpp b/tools/python_api/src_cpp/cached_import/py_cached_item.cpp new file mode 100644 index 0000000000..ebbc1ecb39 --- /dev/null +++ b/tools/python_api/src_cpp/cached_import/py_cached_item.cpp @@ -0,0 +1,24 @@ +#include "cached_import/py_cached_item.h" + + +#include "cached_import/py_cached_import.h" +#include "common/exception/runtime.h" + +namespace kuzu { + +py::handle PythonCachedItem::operator()() { + assert((bool)PyGILState_Check()); + // load if unloaded, return cached object if already loaded + if (loaded) { + return object; + } + if (parent == nullptr) { + object = importCache->addToCache(std::move(py::module::import(name.c_str()))); + } else { + object = importCache->addToCache(std::move((*parent)().attr(name.c_str()))); + } + loaded = true; + return object; +} + +} // namespace kuzu diff --git a/tools/python_api/src_cpp/include/cached_import/py_cached_import.h b/tools/python_api/src_cpp/include/cached_import/py_cached_import.h new file mode 100644 index 0000000000..cd093224e6 --- /dev/null +++ b/tools/python_api/src_cpp/include/cached_import/py_cached_import.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include "py_cached_modules.h" + +namespace kuzu { + +class PythonCachedImport { +public: + // Note: Callers generally acquire the GIL prior to entering functions + // that require the import cache. + + PythonCachedImport() = default; + ~PythonCachedImport(); + + py::handle addToCache(py::object obj); + + DateTimeCachedItem datetime; + DecimalCachedItem decimal; + InspectCachedItem inspect; + NumpyMaCachedItem numpyma; + PandasCachedItem pandas; + PyarrowCachedItem pyarrow; + UUIDCachedItem uuid; + +private: + std::vector allObjects; +}; + +extern std::shared_ptr importCache; + +} // namespace kuzu diff --git a/tools/python_api/src_cpp/include/cached_import/py_cached_item.h b/tools/python_api/src_cpp/include/cached_import/py_cached_item.h new file mode 100644 index 0000000000..159220e79b --- /dev/null +++ b/tools/python_api/src_cpp/include/cached_import/py_cached_item.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +#include "pybind_include.h" + +namespace kuzu { + +class PythonCachedItem { +public: + explicit PythonCachedItem(const std::string& name, PythonCachedItem* parent = nullptr) + : name(name), parent(parent), loaded(false) {} + virtual ~PythonCachedItem() = default; + + bool isLoaded() const {return loaded;} + py::handle operator()(); + +private: + std::string name; + PythonCachedItem* parent; + bool loaded; + py::handle object; +}; + +} // namespace kuzu diff --git a/tools/python_api/src_cpp/include/cached_import/py_cached_modules.h b/tools/python_api/src_cpp/include/cached_import/py_cached_modules.h new file mode 100644 index 0000000000..6bc374df4c --- /dev/null +++ b/tools/python_api/src_cpp/include/cached_import/py_cached_modules.h @@ -0,0 +1,128 @@ +#pragma once + +#include "py_cached_item.h" + +namespace kuzu { + +class DateTimeCachedItem : public PythonCachedItem { + +public: + DateTimeCachedItem() : PythonCachedItem("datetime"), date("date", this), + datetime("datetime", this), timedelta("timedelta", this) {} + + PythonCachedItem date; + PythonCachedItem datetime; + PythonCachedItem timedelta; +}; + +class DecimalCachedItem : public PythonCachedItem { + +public: + DecimalCachedItem() : PythonCachedItem("decimal"), Decimal("Decimal", this) {} + + PythonCachedItem Decimal; +}; + +class InspectCachedItem : public PythonCachedItem { + +public: + InspectCachedItem() : PythonCachedItem("inspect"), currentframe("currentframe", this) {} + + PythonCachedItem currentframe; +}; + +class NumpyMaCachedItem : public PythonCachedItem { + +public: + NumpyMaCachedItem() : PythonCachedItem("numpy.ma"), masked_array("masked_array", this) {} + + PythonCachedItem masked_array; +}; + +class PandasCachedItem : public PythonCachedItem { + + class SeriesCachedItem : public PythonCachedItem { + public: + explicit SeriesCachedItem(PythonCachedItem* parent): PythonCachedItem("series", parent), + Series("Series", this) {} + + PythonCachedItem Series; + }; + + class CoreCachedItem : public PythonCachedItem { + public: + explicit CoreCachedItem(PythonCachedItem* parent): PythonCachedItem("core", parent), + series(this) {} + + SeriesCachedItem series; + }; + + class DataFrameCachedItem : public PythonCachedItem { + public: + explicit DataFrameCachedItem(PythonCachedItem* parent): PythonCachedItem("DataFrame", parent), + from_dict("from_dict", this) {} + + PythonCachedItem from_dict; + }; + +public: + PandasCachedItem() : PythonCachedItem("pandas"), core(this), DataFrame(this), NA("NA", this), + NaT("NaT", this) {} + + CoreCachedItem core; + DataFrameCachedItem DataFrame; + PythonCachedItem NA; + PythonCachedItem NaT; +}; + +class PyarrowCachedItem : public PythonCachedItem { + + class RecordBatchCachedItem : public PythonCachedItem { + public: + explicit RecordBatchCachedItem(PythonCachedItem* parent): PythonCachedItem("RecordBatch", parent), + _import_from_c("_import_from_c", this) {} + + PythonCachedItem _import_from_c; + }; + + class SchemaCachedItem : public PythonCachedItem { + public: + explicit SchemaCachedItem(PythonCachedItem* parent): PythonCachedItem("Schema", parent), + _import_from_c("_import_from_c", this) {} + + PythonCachedItem _import_from_c; + }; + + class TableCachedItem : public PythonCachedItem { + public: + explicit TableCachedItem(PythonCachedItem* parent): PythonCachedItem("Table", parent), + from_batches("from_batches", this) {} + + PythonCachedItem from_batches; + }; + + class LibCachedItem : public PythonCachedItem { + public: + explicit LibCachedItem(PythonCachedItem* parent): PythonCachedItem("lib", parent), + RecordBatch(this), Schema(this), Table(this) {} + + RecordBatchCachedItem RecordBatch; + SchemaCachedItem Schema; + TableCachedItem Table; + }; + +public: + PyarrowCachedItem(): PythonCachedItem("pyarrow"), lib(this) {} + + LibCachedItem lib; +}; + +class UUIDCachedItem : public PythonCachedItem { + +public: + UUIDCachedItem() : PythonCachedItem("uuid"), UUID("UUID", this) {} + + PythonCachedItem UUID; +}; + +} // namespace kuzu diff --git a/tools/python_api/src_cpp/include/py_database.h b/tools/python_api/src_cpp/include/py_database.h index 2e65ec9e56..0cc9265e9c 100644 --- a/tools/python_api/src_cpp/include/py_database.h +++ b/tools/python_api/src_cpp/include/py_database.h @@ -2,6 +2,7 @@ #include "main/kuzu.h" #include "main/storage_driver.h" +#include "cached_import/py_cached_import.h" #include "pybind_include.h" // IWYU pragma: keep (used for py:: namespace) #define PYBIND11_DETAILED_ERROR_MESSAGES using namespace kuzu::main; @@ -23,7 +24,7 @@ class PyDatabase { explicit PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize, uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize); - ~PyDatabase() = default; + ~PyDatabase(); template void scanNodeTable(const std::string& tableName, const std::string& propName, diff --git a/tools/python_api/src_cpp/kuzu_binding.cpp b/tools/python_api/src_cpp/kuzu_binding.cpp index b3f8581e17..69a51a4dfc 100644 --- a/tools/python_api/src_cpp/kuzu_binding.cpp +++ b/tools/python_api/src_cpp/kuzu_binding.cpp @@ -1,3 +1,4 @@ +#include "include/cached_import/py_cached_import.h" #include "include/py_connection.h" #include "include/py_database.h" #include "include/py_prepared_statement.h" @@ -8,6 +9,10 @@ void bind(py::module& m) { PyConnection::initialize(m); PyPreparedStatement::initialize(m); PyQueryResult::initialize(m); + auto cleanImportCache = []() { + kuzu::importCache.reset(); + }; + m.add_object("_clean_import_cache", py::capsule(cleanImportCache)); } PYBIND11_MODULE(_kuzu, m) { diff --git a/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp b/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp index 5fd2af3d43..354aa9a55b 100644 --- a/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp +++ b/tools/python_api/src_cpp/pandas/pandas_analyzer.cpp @@ -1,6 +1,7 @@ #include "pandas/pandas_analyzer.h" #include "function/built_in_function_utils.h" +#include "cached_import/py_cached_import.h" #include "py_conversion.h" namespace kuzu { @@ -37,7 +38,7 @@ common::LogicalType PandasAnalyzer::getListType(py::object& ele, bool& canConver for (auto pyVal : ele) { auto object = py::reinterpret_borrow(pyVal); auto itemType = getItemType(object, canConvert); - if (i != 0) { + if (i == 0) { listType = itemType; } else { if (!upgradeType(listType, itemType)) { @@ -88,8 +89,8 @@ static py::object findFirstNonNull(const py::handle& row, uint64_t numRows) { common::LogicalType PandasAnalyzer::innerAnalyze(py::object column, bool& canConvert) { auto numRows = py::len(column); - auto pandasModule = py::module::import("pandas"); - auto pandasSeries = pandasModule.attr("core").attr("series").attr("Series"); + auto pandasModule = importCache->pandas; + auto pandasSeries = pandasModule.core.series.Series(); if (py::isinstance(column, pandasSeries)) { column = column.attr("__array__")(); diff --git a/tools/python_api/src_cpp/pandas/pandas_scan.cpp b/tools/python_api/src_cpp/pandas/pandas_scan.cpp index a8027245bc..f9473a2a9f 100644 --- a/tools/python_api/src_cpp/pandas/pandas_scan.cpp +++ b/tools/python_api/src_cpp/pandas/pandas_scan.cpp @@ -1,6 +1,7 @@ #include "pandas/pandas_scan.h" #include "function/table/bind_input.h" +#include "cached_import/py_cached_import.h" #include "numpy/numpy_scan.h" #include "py_connection.h" #include "pybind11/pytypes.h" @@ -127,10 +128,9 @@ std::unique_ptr tryReplacePD(py::dict& dict, py::str& tableName) { } std::unique_ptr replacePD(common::Value* value) { - py::gil_scoped_acquire acquire; auto pyTableName = py::str(value->getValue()); // Here we do an exhaustive search on the frame lineage. - auto currentFrame = py::module::import("inspect").attr("currentframe")(); + auto currentFrame = importCache->inspect.currentframe()(); while (hasattr(currentFrame, "f_locals")) { auto localDict = py::reinterpret_borrow(currentFrame.attr("f_locals")); if (localDict) { diff --git a/tools/python_api/src_cpp/py_connection.cpp b/tools/python_api/src_cpp/py_connection.cpp index 8ddef71b33..6517cd0953 100644 --- a/tools/python_api/src_cpp/py_connection.cpp +++ b/tools/python_api/src_cpp/py_connection.cpp @@ -4,12 +4,14 @@ #include "common/string_format.h" #include "datetime.h" // from Python +#include "cached_import/py_cached_import.h" #include "main/connection.h" #include "pandas/pandas_scan.h" #include "processor/result/factorized_table.h" #include "common/types/uuid.h" using namespace kuzu::common; +using namespace kuzu; void PyConnection::initialize(py::handle& m) { py::class_(m, "Connection") @@ -151,9 +153,7 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t& npArray, } bool PyConnection::isPandasDataframe(const py::object& object) { - // TODO(Ziyi): introduce PythonCachedImport to avoid unnecessary import. - py::module pandas = py::module::import("pandas"); - return py::isinstance(object, pandas.attr("DataFrame")); + return py::isinstance(object, importCache->pandas.DataFrame()); } static Value transformPythonValue(py::handle val); @@ -176,11 +176,10 @@ std::unordered_map> transformPythonParameter } Value transformPythonValue(py::handle val) { - auto datetime_mod = py::module::import("datetime"); - auto datetime_datetime = datetime_mod.attr("datetime"); - auto time_delta = datetime_mod.attr("timedelta"); - auto datetime_date = datetime_mod.attr("date"); - auto uuid = py::module::import("uuid").attr("UUID"); + auto datetime_datetime = importCache->datetime.datetime(); + auto time_delta = importCache->datetime.timedelta(); + auto datetime_date = importCache->datetime.date(); + auto uuid = importCache->uuid.UUID(); if (py::isinstance(val)) { return Value::createValue(val.cast()); } else if (py::isinstance(val)) { diff --git a/tools/python_api/src_cpp/py_conversion.cpp b/tools/python_api/src_cpp/py_conversion.cpp index d194a4b86c..bad439c29d 100644 --- a/tools/python_api/src_cpp/py_conversion.cpp +++ b/tools/python_api/src_cpp/py_conversion.cpp @@ -1,18 +1,18 @@ #include "py_conversion.h" #include "common/type_utils.h" +#include "cached_import/py_cached_import.h" namespace kuzu { using namespace kuzu::common; +using kuzu::importCache; PythonObjectType getPythonObjectType(py::handle& ele) { - py::object pandas = py::module::import("pandas"); - auto pandasNa = pandas.attr("NA"); - auto pandasNat = pandas.attr("NaT"); - py::object datetime = py::module::import("datetime"); - auto pyDateTime = datetime.attr("datetime"); - auto pyDate = datetime.attr("date"); + auto pandasNa = importCache->pandas.NA(); + auto pyDateTime = importCache->datetime.datetime(); + auto pandasNat = importCache->pandas.NaT(); + auto pyDate = importCache->datetime.date(); if (ele.is_none() || ele.is(pandasNa) || ele.is(pandasNat)) { return PythonObjectType::None; } else if (py::isinstance(ele)) { diff --git a/tools/python_api/src_cpp/py_database.cpp b/tools/python_api/src_cpp/py_database.cpp index 0ed4c234fd..b2ed26c506 100644 --- a/tools/python_api/src_cpp/py_database.cpp +++ b/tools/python_api/src_cpp/py_database.cpp @@ -1,5 +1,8 @@ #include "include/py_database.h" +#include "include/cached_import/py_cached_import.h" +#include "pandas/pandas_scan.h" + #include #include "main/version.h" @@ -48,8 +51,14 @@ PyDatabase::PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize, database = std::make_unique(databasePath, systemConfig); database->addBuiltInFunction(READ_PANDAS_FUNC_NAME, kuzu::PandasScanFunction::getFunctionSet()); storageDriver = std::make_unique(database.get()); + py::gil_scoped_acquire acquire; + if (kuzu::importCache.get() == nullptr) { + kuzu::importCache = std::make_shared(); + } } +PyDatabase::~PyDatabase() {} + template void PyDatabase::scanNodeTable(const std::string& tableName, const std::string& propName, const py::array_t& indices, py::array_t& result, int numThreads) { diff --git a/tools/python_api/src_cpp/py_query_result.cpp b/tools/python_api/src_cpp/py_query_result.cpp index a0c1edf388..cfa88c6f54 100644 --- a/tools/python_api/src_cpp/py_query_result.cpp +++ b/tools/python_api/src_cpp/py_query_result.cpp @@ -9,9 +9,11 @@ #include "common/types/value/node.h" #include "common/types/value/rel.h" #include "datetime.h" // python lib +#include "cached_import/py_cached_import.h" #include "include/py_query_result_converter.h" using namespace kuzu::common; +using kuzu::importCache; #define PyDateTimeTZ_FromDateAndTime(year, month, day, hour, min, sec, usec, timezone) \ PyDateTimeAPI->DateTime_FromDateAndTime( \ @@ -125,8 +127,7 @@ py::object convertRdfVariantToPyObject(const Value& value) { case LogicalTypeID::INTERVAL: { auto intervalVal = RdfVariant::getValue(&value); auto days = Interval::DAYS_PER_MONTH * intervalVal.months + intervalVal.days; - return py::cast(py::module::import("datetime") - .attr("timedelta")(py::arg("days") = days, + return py::cast(importCache->datetime.timedelta()(py::arg("days") = days, py::arg("microseconds") = intervalVal.micros)); } default: { @@ -172,7 +173,8 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { case LogicalTypeID::INT128: { kuzu::common::int128_t result = value.getValue(); std::string int128_string = kuzu::common::Int128_t::ToString(result); - py::object Decimal = py::module_::import("decimal").attr("Decimal"); + + auto Decimal = importCache->decimal.Decimal(); py::object largeInt = Decimal(int128_string); return largeInt; } @@ -193,7 +195,7 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { case LogicalTypeID::UUID: { kuzu::common::int128_t result = value.getValue(); std::string uuidString = kuzu::common::UUID::toString(result); - py::object UUID = py::module_::import("uuid").attr("UUID"); + auto UUID = importCache->uuid.UUID(); return UUID(uuidString); } case LogicalTypeID::DATE: { @@ -236,8 +238,8 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { case LogicalTypeID::INTERVAL: { auto intervalVal = value.getValue(); auto days = Interval::DAYS_PER_MONTH * intervalVal.months + intervalVal.days; - return py::cast(py::module::import("datetime") - .attr("timedelta")(py::arg("days") = days, + + return py::cast(importCache->datetime.timedelta()(py::arg("days") = days, py::arg("microseconds") = intervalVal.micros)); } case LogicalTypeID::VAR_LIST: @@ -331,9 +333,9 @@ bool PyQueryResult::getNextArrowChunk(const std::vectorpyarrow.lib.RecordBatch._import_from_c(); + auto schema = ArrowConverter::toArrowSchema(typesInfo); batches.append(batchImportFunc((std::uint64_t)&data, (std::uint64_t)schema.get())); return true; @@ -341,22 +343,19 @@ bool PyQueryResult::getNextArrowChunk(const std::vector>& typesInfo, std::int64_t chunkSize) { - auto pyarrowLibModule = py::module::import("pyarrow").attr("lib"); py::list batches; while (getNextArrowChunk(typesInfo, batches, chunkSize)) {} return batches; } kuzu::pyarrow::Table PyQueryResult::getAsArrow(std::int64_t chunkSize) { - py::gil_scoped_acquire acquire; - - auto pyarrowLibModule = py::module::import("pyarrow").attr("lib"); - auto fromBatchesFunc = pyarrowLibModule.attr("Table").attr("from_batches"); - auto schemaImportFunc = pyarrowLibModule.attr("Schema").attr("_import_from_c"); auto typesInfo = queryResult->getColumnTypesInfo(); py::list batches = getArrowChunks(typesInfo, chunkSize); auto schema = ArrowConverter::toArrowSchema(typesInfo); + + auto fromBatchesFunc = importCache->pyarrow.lib.Table.from_batches(); + auto schemaImportFunc = importCache->pyarrow.lib.Schema._import_from_c(); auto schemaObj = schemaImportFunc((std::uint64_t)schema.get()); return py::cast(fromBatchesFunc(batches, schemaObj)); } diff --git a/tools/python_api/src_cpp/py_query_result_converter.cpp b/tools/python_api/src_cpp/py_query_result_converter.cpp index 7b2a4e5bca..d7e996b50d 100644 --- a/tools/python_api/src_cpp/py_query_result_converter.cpp +++ b/tools/python_api/src_cpp/py_query_result_converter.cpp @@ -1,9 +1,11 @@ #include "include/py_query_result_converter.h" #include "common/types/value/value.h" +#include "cached_import/py_cached_import.h" #include "include/py_query_result.h" using namespace kuzu::common; +using namespace kuzu; NPArrayWrapper::NPArrayWrapper(const LogicalType& type, uint64_t numFlatTuple) : type{type}, numElements{0} { @@ -205,9 +207,12 @@ py::object QueryResultConverter::toDF() { py::dict result; auto colNames = queryResult->getColumnNames(); + auto maskedArray = importCache->numpyma.masked_array(); + auto fromDict = importCache->pandas.DataFrame.from_dict(); + for (auto i = 0u; i < colNames.size(); i++) { result[colNames[i].c_str()] = - py::module::import("numpy.ma").attr("masked_array")(columns[i]->data, columns[i]->mask); + maskedArray(columns[i]->data, columns[i]->mask); } - return py::module::import("pandas").attr("DataFrame").attr("from_dict")(result); + return fromDict(result); }