Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import cache fix and revert revert #3025

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ file(GLOB SOURCE_PY
pybind11_add_module(_kuzu
SHARED
src_cpp/kuzu_binding.cpp
src_cpp/cached_import/py_cached_item.cpp
src_cpp/cached_import/py_cached_import.cpp
src_cpp/py_connection.cpp
src_cpp/py_database.cpp
src_cpp/py_prepared_statement.cpp
Expand Down
18 changes: 18 additions & 0 deletions tools/python_api/src_cpp/cached_import/py_cached_import.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "cached_import/py_cached_import.h"

#include <utility>

namespace kuzu {

// Releases every cached Python reference while the GIL is held.
PythonCachedImport::~PythonCachedImport() {
    // Empty the vector inside the body so the py::object elements are destroyed under the
    // guard below, not during the implicit member destruction that runs after the body.
    py::gil_scoped_acquire gilGuard;
    allObjects.clear();
}

// Stores `obj` in the cache so the Python object stays referenced for as long as the cache
// holds it, and returns a non-owning handle to that same object.
py::handle PythonCachedImport::addToCache(py::object obj) {
    // Read the raw pointer first: once `obj` has been moved into the vector it no longer
    // refers to the Python object.
    const auto ptr = obj.ptr();
    allObjects.push_back(std::move(obj));
    return ptr;
}

// Definition of the shared import cache declared `extern` in cached_import/py_cached_import.h.
std::shared_ptr<PythonCachedImport> importCache;

} // namespace kuzu
24 changes: 24 additions & 0 deletions tools/python_api/src_cpp/cached_import/py_cached_item.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "cached_import/py_cached_item.h"


#include "cached_import/py_cached_import.h"
#include "common/exception/runtime.h"

namespace kuzu {

// Returns the Python object this item stands for, resolving it on first use.
//
// An item without a parent is resolved with py::module::import(name). An item with a parent is
// resolved as the attribute `name` of the parent's object; invoking the parent loads it first
// if it has not been loaded yet. The resolved object is handed to importCache->addToCache and
// the returned handle is remembered, so subsequent calls return that handle directly.
//
// The caller must hold the GIL; this is checked by the assertion below.
py::handle PythonCachedItem::operator()() {
    assert(PyGILState_Check() != 0);
    if (loaded) {
        return object;
    }
    if (parent == nullptr) {
        object = importCache->addToCache(py::module::import(name.c_str()));
    } else {
        object = importCache->addToCache((*parent)().attr(name.c_str()));
    }
    // Only mark as loaded after the lookup succeeded, so a failed lookup is retried next time.
    loaded = true;
    return object;
}

} // namespace kuzu
33 changes: 33 additions & 0 deletions tools/python_api/src_cpp/include/cached_import/py_cached_import.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include <vector>

#include "py_cached_modules.h"

namespace kuzu {

// Cache of lazily-resolved Python modules and attributes used by the bindings.
//
// Each public member is a cached item for one Python module; nested attributes are reached
// through the item's own members. Objects resolved by the items are stored in `allObjects`
// via addToCache().
class PythonCachedImport {
public:
    // Note: Callers generally acquire the GIL prior to entering functions
    // that require the import cache.

    PythonCachedImport() = default;
    // Defined out of line in py_cached_import.cpp.
    ~PythonCachedImport();

    // Adds `obj` to the cache and returns a handle to it.
    py::handle addToCache(py::object obj);

    // Cached items, one per supported Python module.
    DateTimeCachedItem datetime;
    DecimalCachedItem decimal;
    InspectCachedItem inspect;
    NumpyMaCachedItem numpyma;
    PandasCachedItem pandas;
    PyarrowCachedItem pyarrow;
    UUIDCachedItem uuid;

private:
    // Every object handed to addToCache().
    std::vector<py::object> allObjects;
};

// Shared import cache instance; the definition lives in cached_import/py_cached_import.cpp.
extern std::shared_ptr<PythonCachedImport> importCache;

} // namespace kuzu
26 changes: 26 additions & 0 deletions tools/python_api/src_cpp/include/cached_import/py_cached_item.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include <memory>
#include <string>

#include "pybind_include.h"

namespace kuzu {

// One lazily-resolved Python object: either a module (no parent) or an attribute of the
// object represented by `parent`.
//
// NOTE: `parent` is a raw, non-owning pointer. Copying an item copies that pointer as-is, so
// a copy still refers to the parent it was originally constructed with.
class PythonCachedItem {
public:
    // `name` is the module name (when `parent` is null) or the attribute name to look up on
    // the parent's object. `parent` must outlive this item.
    explicit PythonCachedItem(const std::string& name, PythonCachedItem* parent = nullptr)
        : name(name), parent(parent), loaded(false) {}
    virtual ~PythonCachedItem() = default;

    // True once operator() has resolved and stored the Python object.
    bool isLoaded() const { return loaded; }
    // Returns the resolved Python object, loading it on first use.
    py::handle operator()();

private:
    std::string name;
    PythonCachedItem* parent;
    bool loaded;
    // Handle to the resolved object; only meaningful when `loaded` is true.
    py::handle object;
};

} // namespace kuzu
128 changes: 128 additions & 0 deletions tools/python_api/src_cpp/include/cached_import/py_cached_modules.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#pragma once

#include "py_cached_item.h"

namespace kuzu {

class DateTimeCachedItem : public PythonCachedItem {

public:
DateTimeCachedItem() : PythonCachedItem("datetime"), date("date", this),

Check warning on line 10 in tools/python_api/src_cpp/include/cached_import/py_cached_modules.h

View check run for this annotation

Codecov / codecov/patch

tools/python_api/src_cpp/include/cached_import/py_cached_modules.h#L10

Added line #L10 was not covered by tests
datetime("datetime", this), timedelta("timedelta", this) {}

PythonCachedItem date;
PythonCachedItem datetime;
PythonCachedItem timedelta;
};

// Cached `decimal` module together with its `Decimal` attribute.
class DecimalCachedItem : public PythonCachedItem {

public:
    DecimalCachedItem()
        : PythonCachedItem("decimal"),
          Decimal("Decimal", this) {}

    // Attribute `Decimal`, looked up on this module item.
    PythonCachedItem Decimal;
};

// Cached `inspect` module together with its `currentframe` attribute.
class InspectCachedItem : public PythonCachedItem {

public:
    InspectCachedItem()
        : PythonCachedItem("inspect"),
          currentframe("currentframe", this) {}

    // Attribute `currentframe`, looked up on this module item.
    PythonCachedItem currentframe;
};

// Cached `numpy.ma` module together with its `masked_array` attribute.
class NumpyMaCachedItem : public PythonCachedItem {

public:
    NumpyMaCachedItem()
        : PythonCachedItem("numpy.ma"),
          masked_array("masked_array", this) {}

    // Attribute `masked_array`, looked up on this module item.
    PythonCachedItem masked_array;
};

// Cached `pandas` module. Exposes `core.series.Series`, `DataFrame` (with `from_dict`), `NA`
// and `NaT` as nested cached items.
class PandasCachedItem : public PythonCachedItem {

    // `series` attribute of its parent, with the nested `Series` attribute.
    class SeriesCachedItem : public PythonCachedItem {
    public:
        explicit SeriesCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("series", parent),
              Series("Series", this) {}

        PythonCachedItem Series;
    };

    // `core` attribute of its parent, with the nested `series` item.
    class CoreCachedItem : public PythonCachedItem {
    public:
        explicit CoreCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("core", parent),
              series(this) {}

        SeriesCachedItem series;
    };

    // `DataFrame` attribute of its parent, with the nested `from_dict` attribute.
    class DataFrameCachedItem : public PythonCachedItem {
    public:
        explicit DataFrameCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("DataFrame", parent),
              from_dict("from_dict", this) {}

        PythonCachedItem from_dict;
    };

public:
    // All direct children are parented to this module item.
    PandasCachedItem()
        : PythonCachedItem("pandas"),
          core(this),
          DataFrame(this),
          NA("NA", this),
          NaT("NaT", this) {}

    CoreCachedItem core;
    DataFrameCachedItem DataFrame;
    PythonCachedItem NA;
    PythonCachedItem NaT;
};

// Cached `pyarrow` module. Exposes `lib.RecordBatch._import_from_c`,
// `lib.Schema._import_from_c` and `lib.Table.from_batches` as nested cached items.
class PyarrowCachedItem : public PythonCachedItem {

    // `RecordBatch` attribute of its parent, with the nested `_import_from_c` attribute.
    class RecordBatchCachedItem : public PythonCachedItem {
    public:
        explicit RecordBatchCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("RecordBatch", parent),
              _import_from_c("_import_from_c", this) {}

        PythonCachedItem _import_from_c;
    };

    // `Schema` attribute of its parent, with the nested `_import_from_c` attribute.
    class SchemaCachedItem : public PythonCachedItem {
    public:
        explicit SchemaCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("Schema", parent),
              _import_from_c("_import_from_c", this) {}

        PythonCachedItem _import_from_c;
    };

    // `Table` attribute of its parent, with the nested `from_batches` attribute.
    class TableCachedItem : public PythonCachedItem {
    public:
        explicit TableCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("Table", parent),
              from_batches("from_batches", this) {}

        PythonCachedItem from_batches;
    };

    // `lib` attribute of its parent, grouping the RecordBatch, Schema and Table items.
    class LibCachedItem : public PythonCachedItem {
    public:
        explicit LibCachedItem(PythonCachedItem* parent)
            : PythonCachedItem("lib", parent),
              RecordBatch(this),
              Schema(this),
              Table(this) {}

        RecordBatchCachedItem RecordBatch;
        SchemaCachedItem Schema;
        TableCachedItem Table;
    };

public:
    PyarrowCachedItem()
        : PythonCachedItem("pyarrow"),
          lib(this) {}

    LibCachedItem lib;
};

// Cached `uuid` module together with its `UUID` attribute.
class UUIDCachedItem : public PythonCachedItem {

public:
    UUIDCachedItem()
        : PythonCachedItem("uuid"),
          UUID("UUID", this) {}

    // Attribute `UUID`, looked up on this module item.
    PythonCachedItem UUID;
};

} // namespace kuzu
3 changes: 2 additions & 1 deletion tools/python_api/src_cpp/include/py_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "main/kuzu.h"
#include "main/storage_driver.h"
#include "cached_import/py_cached_import.h"
#include "pybind_include.h" // IWYU pragma: keep (used for py:: namespace)
#define PYBIND11_DETAILED_ERROR_MESSAGES
using namespace kuzu::main;
Expand All @@ -23,7 +24,7 @@ class PyDatabase {
explicit PyDatabase(const std::string& databasePath, uint64_t bufferPoolSize,
uint64_t maxNumThreads, bool compression, bool readOnly, uint64_t maxDBSize);

~PyDatabase() = default;
~PyDatabase();

template<class T>
void scanNodeTable(const std::string& tableName, const std::string& propName,
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_cpp/kuzu_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "include/cached_import/py_cached_import.h"
#include "include/py_connection.h"
#include "include/py_database.h"
#include "include/py_prepared_statement.h"
Expand All @@ -8,6 +9,10 @@ void bind(py::module& m) {
PyConnection::initialize(m);
PyPreparedStatement::initialize(m);
PyQueryResult::initialize(m);
auto cleanImportCache = []() {
kuzu::importCache.reset();
};
m.add_object("_clean_import_cache", py::capsule(cleanImportCache));
}

PYBIND11_MODULE(_kuzu, m) {
Expand Down
7 changes: 4 additions & 3 deletions tools/python_api/src_cpp/pandas/pandas_analyzer.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "pandas/pandas_analyzer.h"

#include "function/built_in_function_utils.h"
#include "cached_import/py_cached_import.h"
#include "py_conversion.h"

namespace kuzu {
Expand Down Expand Up @@ -37,7 +38,7 @@ common::LogicalType PandasAnalyzer::getListType(py::object& ele, bool& canConver
for (auto pyVal : ele) {
auto object = py::reinterpret_borrow<py::object>(pyVal);
auto itemType = getItemType(object, canConvert);
if (i != 0) {
if (i == 0) {
listType = itemType;
} else {
if (!upgradeType(listType, itemType)) {
Expand Down Expand Up @@ -88,8 +89,8 @@ static py::object findFirstNonNull(const py::handle& row, uint64_t numRows) {

common::LogicalType PandasAnalyzer::innerAnalyze(py::object column, bool& canConvert) {
auto numRows = py::len(column);
auto pandasModule = py::module::import("pandas");
auto pandasSeries = pandasModule.attr("core").attr("series").attr("Series");
auto pandasModule = importCache->pandas;
auto pandasSeries = pandasModule.core.series.Series();

if (py::isinstance(column, pandasSeries)) {
column = column.attr("__array__")();
Expand Down
4 changes: 2 additions & 2 deletions tools/python_api/src_cpp/pandas/pandas_scan.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "pandas/pandas_scan.h"

#include "function/table/bind_input.h"
#include "cached_import/py_cached_import.h"
#include "numpy/numpy_scan.h"
#include "py_connection.h"
#include "pybind11/pytypes.h"
Expand Down Expand Up @@ -127,10 +128,9 @@ std::unique_ptr<Value> tryReplacePD(py::dict& dict, py::str& tableName) {
}

std::unique_ptr<common::Value> replacePD(common::Value* value) {
py::gil_scoped_acquire acquire;
auto pyTableName = py::str(value->getValue<std::string>());
// Here we do an exhaustive search on the frame lineage.
auto currentFrame = py::module::import("inspect").attr("currentframe")();
auto currentFrame = importCache->inspect.currentframe()();
while (hasattr(currentFrame, "f_locals")) {
auto localDict = py::reinterpret_borrow<py::dict>(currentFrame.attr("f_locals"));
if (localDict) {
Expand Down
15 changes: 7 additions & 8 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

#include "common/string_format.h"
#include "datetime.h" // from Python
#include "cached_import/py_cached_import.h"
#include "main/connection.h"
#include "pandas/pandas_scan.h"
#include "processor/result/factorized_table.h"
#include "common/types/uuid.h"

using namespace kuzu::common;
using namespace kuzu;

void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "Connection")
Expand Down Expand Up @@ -151,9 +153,7 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
}

bool PyConnection::isPandasDataframe(const py::object& object) {
// TODO(Ziyi): introduce PythonCachedImport to avoid unnecessary import.
py::module pandas = py::module::import("pandas");
return py::isinstance(object, pandas.attr("DataFrame"));
return py::isinstance(object, importCache->pandas.DataFrame());
}

static Value transformPythonValue(py::handle val);
Expand All @@ -176,11 +176,10 @@ std::unordered_map<std::string, std::unique_ptr<Value>> transformPythonParameter
}

Value transformPythonValue(py::handle val) {
auto datetime_mod = py::module::import("datetime");
auto datetime_datetime = datetime_mod.attr("datetime");
auto time_delta = datetime_mod.attr("timedelta");
auto datetime_date = datetime_mod.attr("date");
auto uuid = py::module::import("uuid").attr("UUID");
auto datetime_datetime = importCache->datetime.datetime();
auto time_delta = importCache->datetime.timedelta();
auto datetime_date = importCache->datetime.date();
auto uuid = importCache->uuid.UUID();
if (py::isinstance<py::bool_>(val)) {
return Value::createValue<bool>(val.cast<bool>());
} else if (py::isinstance<py::int_>(val)) {
Expand Down
12 changes: 6 additions & 6 deletions tools/python_api/src_cpp/py_conversion.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#include "py_conversion.h"

#include "common/type_utils.h"
#include "cached_import/py_cached_import.h"

namespace kuzu {

using namespace kuzu::common;
using kuzu::importCache;

PythonObjectType getPythonObjectType(py::handle& ele) {
py::object pandas = py::module::import("pandas");
auto pandasNa = pandas.attr("NA");
auto pandasNat = pandas.attr("NaT");
py::object datetime = py::module::import("datetime");
auto pyDateTime = datetime.attr("datetime");
auto pyDate = datetime.attr("date");
auto pandasNa = importCache->pandas.NA();
auto pyDateTime = importCache->datetime.datetime();
auto pandasNat = importCache->pandas.NaT();
auto pyDate = importCache->datetime.date();
if (ele.is_none() || ele.is(pandasNa) || ele.is(pandasNat)) {
return PythonObjectType::None;
} else if (py::isinstance<py::bool_>(ele)) {
Expand Down
Loading
Loading