Implement postgres-scanner #3139

Merged 6 commits on Mar 26, 2024.
Changes from all commits
38 changes: 38 additions & 0 deletions .github/workflows/ci-workflow.yml
@@ -46,6 +46,15 @@ jobs:
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test

- name: Ensure Python dependencies
run: |
pip install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
@@ -179,6 +188,15 @@ jobs:
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test

- name: Ensure Python dependencies
run: |
pip install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
@@ -225,10 +243,20 @@ jobs:
AWS_S3_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
PG_HOST: ${{ secrets.PG_HOST }}
RUN_ID: "$(hostname)-$([Math]::Floor((Get-Date).TimeOfDay.TotalSeconds))"
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test

- name: Ensure Python dependencies
run: |
pip install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
@@ -430,10 +458,20 @@ jobs:
AWS_S3_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
PG_HOST: ${{ secrets.PG_HOST }}
RUN_ID: "$(hostname)-$(date +%s)"
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test

- name: Ensure Python dependencies
run: |
pip3 install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
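The "Update PostgreSQL host" step repeated in each job above is a Node one-liner: it reads postgres_scanner.test, replaces every occurrence of FIND ("localhost") with the PG_HOST secret, and writes the file back in place so the tests target the CI database. The same substitution spelled out as a minimal C++ sketch (the file name and environment variable mirror the workflow; everything else is illustrative):

#include <cstdlib>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>

// Equivalent of the node one-liner: read the test file, replace every
// occurrence of "localhost" with the PG_HOST secret, write it back in place.
int main() {
    const char* hostEnv = std::getenv("PG_HOST");
    if (hostEnv == nullptr) {
        throw std::runtime_error("PG_HOST is not set");
    }
    const std::string host = hostEnv;
    const std::string fname = "postgres_scanner.test"; // FNAME in the workflow
    const std::string find = "localhost";              // FIND in the workflow

    std::stringstream buffer;
    buffer << std::ifstream(fname).rdbuf();
    std::string data = buffer.str();

    // replaceAll: advance past each replacement so substituted text is never rescanned.
    for (std::size_t pos = 0; (pos = data.find(find, pos)) != std::string::npos;
         pos += host.size()) {
        data.replace(pos, find.size(), host);
    }

    std::ofstream(fname) << data;
}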
10 changes: 5 additions & 5 deletions Makefile
@@ -64,7 +64,7 @@ allconfig:
$(call config-cmake-release, \
-DBUILD_BENCHMARK=TRUE \
-DBUILD_EXAMPLES=TRUE \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_JAVA=TRUE \
-DBUILD_NODEJS=TRUE \
-DBUILD_PYTHON=TRUE \
@@ -79,7 +79,7 @@ alldebug:
$(call run-cmake-debug, \
-DBUILD_BENCHMARK=TRUE \
-DBUILD_EXAMPLES=TRUE \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_JAVA=TRUE \
-DBUILD_NODEJS=TRUE \
-DBUILD_PYTHON=TRUE \
@@ -156,21 +156,21 @@ example:

extension-test:
$(call run-cmake-release, \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_EXTENSION_TESTS=TRUE \
)
ctest --test-dir build/release/extension --output-on-failure -j ${TEST_JOBS}
aws s3 rm s3://kuzu-dataset-us/${RUN_ID}/ --recursive

extension-debug:
$(call run-cmake-debug, \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_KUZU=FALSE \
)

extension-release:
$(call run-cmake-release, \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_KUZU=FALSE \
)

4 changes: 4 additions & 0 deletions extension/CMakeLists.txt
@@ -10,6 +10,10 @@ if ("duckdb_scanner" IN_LIST BUILD_EXTENSIONS)
endif()
endif()

if ("postgres_scanner" IN_LIST BUILD_EXTENSIONS)
add_subdirectory(postgres_scanner)
endif()

if (${BUILD_EXTENSION_TESTS})
add_definitions(-DTEST_FILES_DIR="extension")
add_subdirectory(${CMAKE_SOURCE_DIR}/test/gtest ${CMAKE_CURRENT_BINARY_DIR}/test/gtest EXCLUDE_FROM_ALL)
6 changes: 0 additions & 6 deletions extension/duckdb_scanner/CMakeLists.txt
@@ -14,12 +14,6 @@ add_library(duckdb_scanner
src/duckdb_catalog.cpp
src/duckdb_table_catalog_entry.cpp)

-set_target_properties(duckdb_scanner PROPERTIES
-    OUTPUT_NAME duckdb_scanner
-    PREFIX "lib"
-    SUFFIX ".kuzu_extension"
-)

set_target_properties(duckdb_scanner
PROPERTIES
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build"
113 changes: 106 additions & 7 deletions extension/duckdb_scanner/src/duckdb_catalog.cpp
@@ -1,28 +1,127 @@
#include "duckdb_catalog.h"

#include "common/exception/binder.h"
#include "duckdb_type_converter.h"

namespace kuzu {
namespace duckdb_scanner {

-common::table_id_t DuckDBCatalogContent::createForeignTable(
-    const binder::BoundCreateTableInfo& info) {
-    auto tableID = assignNextTableID();
void DuckDBCatalogContent::init(
const std::string& dbPath, const std::string& catalogName, main::ClientContext* context) {
auto con = getConnection(dbPath);
auto query = common::stringFormat(
"select table_name from information_schema.tables where table_catalog = '{}' and "
"table_schema = '{}';",
catalogName, getDefaultSchemaName());
auto result = con.Query(query);
std::unique_ptr<duckdb::DataChunk> resultChunk;
try {
resultChunk = result->Fetch();
} catch (std::exception& e) { throw common::BinderException(e.what()); }
if (resultChunk->size() == 0) {
return;
}
common::ValueVector tableNamesVector{
*common::LogicalType::STRING(), context->getMemoryManager()};
duckdb_scanner::duckdb_conversion_func_t conversionFunc;
duckdb_scanner::getDuckDBVectorConversionFunc(common::PhysicalTypeID::STRING, conversionFunc);
conversionFunc(resultChunk->data[0], tableNamesVector, resultChunk->size());
for (auto i = 0u; i < resultChunk->size(); i++) {
auto tableName = tableNamesVector.getValue<common::ku_string_t>(i).getAsString();
createForeignTable(con, tableName, dbPath, catalogName);
}
}

static std::string getQuery(const binder::BoundCreateTableInfo& info) {
auto extraInfo = common::ku_dynamic_cast<binder::BoundExtraCreateCatalogEntryInfo*,
BoundExtraCreateDuckDBTableInfo*>(info.extraInfo.get());
return common::stringFormat(
"SELECT * FROM {}.{}.{}", extraInfo->catalogName, extraInfo->schemaName, info.tableName);
}

void DuckDBCatalogContent::createForeignTable(duckdb::Connection& con, const std::string& tableName,
const std::string& dbPath, const std::string& catalogName) {
auto tableID = assignNextTableID();
auto info = bindCreateTableInfo(con, tableName, dbPath, catalogName);
if (info == nullptr) {
return;
}
auto extraInfo = common::ku_dynamic_cast<binder::BoundExtraCreateCatalogEntryInfo*,
BoundExtraCreateDuckDBTableInfo*>(info->extraInfo.get());
std::vector<common::LogicalType> columnTypes;
std::vector<std::string> columnNames;
for (auto& propertyInfo : extraInfo->propertyInfos) {
columnNames.push_back(propertyInfo.name);
columnTypes.push_back(propertyInfo.type);
}
-    DuckDBScanBindData bindData(common::stringFormat("SELECT * FROM {}", info.tableName),
-        extraInfo->dbPath, std::move(columnTypes), std::move(columnNames));
+    DuckDBScanBindData bindData(getQuery(*info), std::move(columnTypes), std::move(columnNames),
+        std::bind(&DuckDBCatalogContent::getConnection, this, dbPath));
auto tableEntry = std::make_unique<catalog::DuckDBTableCatalogEntry>(
-        info.tableName, tableID, getScanFunction(std::move(bindData)));
+        info->tableName, tableID, getScanFunction(std::move(bindData)));
for (auto& propertyInfo : extraInfo->propertyInfos) {
tableEntry->addProperty(propertyInfo.name, propertyInfo.type.copy());
}
tables->createEntry(std::move(tableEntry));
-    return tableID;
}

static bool getTableInfo(duckdb::Connection& con, const std::string& tableName,
const std::string& schemaName, const std::string& catalogName,
std::vector<common::LogicalType>& columnTypes, std::vector<std::string>& columnNames) {
auto query =
common::stringFormat("select data_type,column_name from information_schema.columns where "
"table_name = '{}' and table_schema = '{}' and table_catalog = '{}';",
tableName, schemaName, catalogName);
auto result = con.Query(query);
if (result->RowCount() == 0) {
return false;
}
columnTypes.reserve(result->RowCount());
columnNames.reserve(result->RowCount());
for (auto i = 0u; i < result->RowCount(); i++) {
try {
columnTypes.push_back(DuckDBTypeConverter::convertDuckDBType(
result->GetValue(0, i).GetValue<std::string>()));
} catch (common::BinderException& e) { return false; }
columnNames.push_back(result->GetValue(1, i).GetValue<std::string>());
}
return true;
}

bool DuckDBCatalogContent::bindPropertyInfos(duckdb::Connection& con, const std::string& tableName,
const std::string& catalogName, std::vector<binder::PropertyInfo>& propertyInfos) {
std::vector<common::LogicalType> columnTypes;
std::vector<std::string> columnNames;
if (!getTableInfo(
con, tableName, getDefaultSchemaName(), catalogName, columnTypes, columnNames)) {
return false;
}
for (auto i = 0u; i < columnNames.size(); i++) {
auto propertyInfo = binder::PropertyInfo(columnNames[i], columnTypes[i]);
propertyInfos.push_back(std::move(propertyInfo));
}
return true;
}

std::unique_ptr<binder::BoundCreateTableInfo> DuckDBCatalogContent::bindCreateTableInfo(
duckdb::Connection& con, const std::string& tableName, const std::string& dbPath,
const std::string& catalogName) {
std::vector<binder::PropertyInfo> propertyInfos;
if (!bindPropertyInfos(con, tableName, catalogName, propertyInfos)) {
return nullptr;
}
return std::make_unique<binder::BoundCreateTableInfo>(common::TableType::FOREIGN, tableName,
std::make_unique<duckdb_scanner::BoundExtraCreateDuckDBTableInfo>(
dbPath, catalogName, getDefaultSchemaName(), std::move(propertyInfos)));
}

std::string DuckDBCatalogContent::getDefaultSchemaName() const {
return "main";
}

duckdb::Connection DuckDBCatalogContent::getConnection(const std::string& dbPath) const {
duckdb::DuckDB db(dbPath);
duckdb::Connection con(db);
return con;
}

} // namespace duckdb_scanner
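A note on the shape of this refactor: DuckDBScanBindData no longer stores a dbPath that the scan opens itself (the old initSharedState constructed a fresh duckdb::DuckDB from it). Instead, the catalog packs a connection factory into the bind data via std::bind(&DuckDBCatalogContent::getConnection, this, dbPath), and the scan invokes it at init time. A self-contained sketch of that pattern, with stand-in types rather than the extension's real classes:

#include <functional>
#include <iostream>
#include <string>

// Stand-in for duckdb::Connection; illustration only.
struct Connection {
    std::string dbPath;
};

// Mirrors DuckDBScanBindData: instead of carrying a dbPath and opening the
// database inside the scan, the bind data carries a factory that produces a
// connection on demand.
struct ScanBindData {
    std::string query;
    std::function<Connection()> initConn;
};

struct CatalogContent {
    Connection getConnection(const std::string& dbPath) const {
        return Connection{dbPath};
    }

    ScanBindData makeBindData(const std::string& dbPath, const std::string& table) {
        // Same shape as std::bind(&DuckDBCatalogContent::getConnection, this, dbPath)
        // in the diff: `this` and the path are captured, the call is deferred.
        return ScanBindData{"SELECT * FROM " + table,
            std::bind(&CatalogContent::getConnection, this, dbPath)};
    }
};

int main() {
    CatalogContent catalog;
    auto bindData = catalog.makeBindData("/tmp/test.duckdb", "person");
    // Scan initialization invokes the factory, as initSharedState now does.
    Connection conn = bindData.initConn();
    std::cout << bindData.query << " @ " << conn.dbPath << "\n";
}

One constraint this design carries: the bound `this` must outlive every copy of the bind data, which appears safe here because the catalog content itself owns the table entries that store the scan functions.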
17 changes: 10 additions & 7 deletions extension/duckdb_scanner/src/duckdb_scan.cpp
@@ -1,6 +1,5 @@
#include "duckdb_scan.h"

#include "common/exception/binder.h"
#include "common/types/types.h"
#include "function/table/bind_input.h"

@@ -13,10 +12,11 @@ namespace duckdb_scanner {
void getDuckDBVectorConversionFunc(
PhysicalTypeID physicalTypeID, duckdb_conversion_func_t& conversion_func);

-DuckDBScanBindData::DuckDBScanBindData(std::string query, std::string dbPath,
-    std::vector<common::LogicalType> columnTypes, std::vector<std::string> columnNames)
+DuckDBScanBindData::DuckDBScanBindData(std::string query,
+    std::vector<common::LogicalType> columnTypes, std::vector<std::string> columnNames,
+    init_duckdb_conn_t initDuckDBConn)
     : TableFuncBindData{std::move(columnTypes), std::move(columnNames)}, query{std::move(query)},
-      dbPath{std::move(dbPath)} {
+      initDuckDBConn{std::move(initDuckDBConn)} {
conversionFunctions.resize(this->columnTypes.size());
for (auto i = 0u; i < this->columnTypes.size(); i++) {
getDuckDBVectorConversionFunc(
@@ -25,7 +25,7 @@ DuckDBScanBindData::DuckDBScanBindData(std::string query, std::string dbPath,
}

std::unique_ptr<TableFuncBindData> DuckDBScanBindData::copy() const {
-    return std::make_unique<DuckDBScanBindData>(query, dbPath, columnTypes, columnNames);
+    return std::make_unique<DuckDBScanBindData>(query, columnTypes, columnNames, initDuckDBConn);
}

DuckDBScanSharedState::DuckDBScanSharedState(std::unique_ptr<duckdb::QueryResult> queryResult)
Expand All @@ -52,9 +52,12 @@ struct DuckDBScanFunction {
std::unique_ptr<function::TableFuncSharedState> DuckDBScanFunction::initSharedState(
function::TableFunctionInitInput& input) {
auto scanBindData = reinterpret_cast<DuckDBScanBindData*>(input.bindData);
-    auto db = duckdb::DuckDB(scanBindData->dbPath);
-    auto conn = duckdb::Connection(db);
+    auto conn = scanBindData->initDuckDBConn();
auto result = conn.SendQuery(scanBindData->query);
+    if (result->HasError()) {
+        throw common::RuntimeException(
+            common::stringFormat("Failed to execute query: {} in duckdb.", result->GetError()));
+    }
return std::make_unique<DuckDBScanSharedState>(std::move(result));
}

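The added HasError() check in initSharedState is the other behavioral change here: a failed SendQuery previously flowed into the shared state unexamined, whereas it now surfaces immediately as a RuntimeException. The shape of that check, sketched with a stand-in result type:

#include <memory>
#include <stdexcept>
#include <string>

// Stand-in for duckdb::QueryResult; illustration only.
struct QueryResult {
    bool hasError = false;
    std::string error;
    bool HasError() const { return hasError; }
    std::string GetError() const { return error; }
};

// Same shape as the added check in initSharedState: surface a DuckDB-side
// failure as the host's exception type instead of handing a failed result
// to the shared state.
std::unique_ptr<QueryResult> checkQueryResult(std::unique_ptr<QueryResult> result) {
    if (result->HasError()) {
        throw std::runtime_error(
            "Failed to execute query: " + result->GetError() + " in duckdb.");
    }
    return result;
}

int main() {
    auto ok = checkQueryResult(std::make_unique<QueryResult>());
    (void)ok; // a failing result would have thrown before this point
}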