Skip to content

Commit

Permalink
Implement postgres-scanner (#3139)
Browse files Browse the repository at this point in the history
Implement postgres-scanner
  • Loading branch information
acquamarin committed Mar 26, 2024
1 parent 80b3e94 commit 3237e6f
Show file tree
Hide file tree
Showing 21 changed files with 746 additions and 108 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/ci-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ jobs:
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test
- name: Ensure Python dependencies
run: |
pip install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
Expand Down Expand Up @@ -179,6 +188,15 @@ jobs:
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test
- name: Ensure Python dependencies
run: |
pip install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
Expand Down Expand Up @@ -225,10 +243,20 @@ jobs:
AWS_S3_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
PG_HOST: ${{ secrets.PG_HOST }}
RUN_ID: "$(hostname)-$([Math]::Floor((Get-Date).TimeOfDay.TotalSeconds))"
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test
- name: Ensure Python dependencies
run: |
pip install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
Expand Down Expand Up @@ -419,10 +447,20 @@ jobs:
AWS_S3_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_SECRET_ACCESS_KEY }}
PG_HOST: ${{ secrets.PG_HOST }}
RUN_ID: "$(hostname)-$(date +%s)"
steps:
- uses: actions/checkout@v3

- name: Update PostgreSQL host
working-directory: extension/postgres_scanner/test/test_files
env:
FNAME: postgres_scanner.test
FIND: "localhost"
run: |
node -e 'fs=require("fs");fs.readFile(process.env.FNAME,"utf8",(err,data)=>{if(err!=null)throw err;fs.writeFile(process.env.FNAME,data.replaceAll(process.env.FIND,process.env.PG_HOST),"utf8",e=>{if(e!=null)throw e;});});'
cat postgres_scanner.test
- name: Ensure Python dependencies
run: |
pip3 install torch~=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
Expand Down
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ allconfig:
$(call config-cmake-release, \
-DBUILD_BENCHMARK=TRUE \
-DBUILD_EXAMPLES=TRUE \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_JAVA=TRUE \
-DBUILD_NODEJS=TRUE \
-DBUILD_PYTHON=TRUE \
Expand All @@ -79,7 +79,7 @@ alldebug:
$(call run-cmake-debug, \
-DBUILD_BENCHMARK=TRUE \
-DBUILD_EXAMPLES=TRUE \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_JAVA=TRUE \
-DBUILD_NODEJS=TRUE \
-DBUILD_PYTHON=TRUE \
Expand Down Expand Up @@ -156,21 +156,21 @@ example:

extension-test:
$(call run-cmake-release, \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_EXTENSION_TESTS=TRUE \
)
ctest --test-dir build/release/extension --output-on-failure -j ${TEST_JOBS}
aws s3 rm s3://kuzu-dataset-us/${RUN_ID}/ --recursive

extension-debug:
$(call run-cmake-debug, \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_KUZU=FALSE \
)

extension-release:
$(call run-cmake-release, \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner" \
-DBUILD_EXTENSIONS="httpfs;duckdb_scanner;postgres_scanner" \
-DBUILD_KUZU=FALSE \
)

Expand Down
4 changes: 4 additions & 0 deletions extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ if ("duckdb_scanner" IN_LIST BUILD_EXTENSIONS)
endif()
endif()

if ("postgres_scanner" IN_LIST BUILD_EXTENSIONS)
add_subdirectory(postgres_scanner)
endif()

if (${BUILD_EXTENSION_TESTS})
add_definitions(-DTEST_FILES_DIR="extension")
add_subdirectory(${CMAKE_SOURCE_DIR}/test/gtest ${CMAKE_CURRENT_BINARY_DIR}/test/gtest EXCLUDE_FROM_ALL)
Expand Down
6 changes: 0 additions & 6 deletions extension/duckdb_scanner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@ add_library(duckdb_scanner
src/duckdb_catalog.cpp
src/duckdb_table_catalog_entry.cpp)

set_target_properties(duckdb_scanner PROPERTIES
OUTPUT_NAME duckdb_scanner
PREFIX "lib"
SUFFIX ".kuzu_extension"
)

set_target_properties(duckdb_scanner
PROPERTIES
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/build"
Expand Down
113 changes: 106 additions & 7 deletions extension/duckdb_scanner/src/duckdb_catalog.cpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,127 @@
#include "duckdb_catalog.h"

#include "common/exception/binder.h"
#include "duckdb_type_converter.h"

namespace kuzu {
namespace duckdb_scanner {

common::table_id_t DuckDBCatalogContent::createForeignTable(
const binder::BoundCreateTableInfo& info) {
auto tableID = assignNextTableID();
void DuckDBCatalogContent::init(
const std::string& dbPath, const std::string& catalogName, main::ClientContext* context) {
auto con = getConnection(dbPath);
auto query = common::stringFormat(
"select table_name from information_schema.tables where table_catalog = '{}' and "
"table_schema = '{}';",
catalogName, getDefaultSchemaName());
auto result = con.Query(query);
std::unique_ptr<duckdb::DataChunk> resultChunk;
try {
resultChunk = result->Fetch();
} catch (std::exception& e) { throw common::BinderException(e.what()); }
if (resultChunk->size() == 0) {
return;
}
common::ValueVector tableNamesVector{
*common::LogicalType::STRING(), context->getMemoryManager()};
duckdb_scanner::duckdb_conversion_func_t conversionFunc;
duckdb_scanner::getDuckDBVectorConversionFunc(common::PhysicalTypeID::STRING, conversionFunc);
conversionFunc(resultChunk->data[0], tableNamesVector, resultChunk->size());
for (auto i = 0u; i < resultChunk->size(); i++) {
auto tableName = tableNamesVector.getValue<common::ku_string_t>(i).getAsString();
createForeignTable(con, tableName, dbPath, catalogName);
}
}

// Builds the duckdb query that scans every row of a foreign table, using the
// fully-qualified catalog.schema.table name recorded in the bound create info.
static std::string getQuery(const binder::BoundCreateTableInfo& info) {
    auto duckdbExtraInfo = common::ku_dynamic_cast<binder::BoundExtraCreateCatalogEntryInfo*,
        BoundExtraCreateDuckDBTableInfo*>(info.extraInfo.get());
    return common::stringFormat("SELECT * FROM {}.{}.{}", duckdbExtraInfo->catalogName,
        duckdbExtraInfo->schemaName, info.tableName);
}

// Registers a single duckdb table as a foreign-table catalog entry. Tables
// whose schema cannot be bound (e.g. containing column types kuzu does not
// support) are skipped silently.
void DuckDBCatalogContent::createForeignTable(duckdb::Connection& con, const std::string& tableName,
    const std::string& dbPath, const std::string& catalogName) {
    auto info = bindCreateTableInfo(con, tableName, dbPath, catalogName);
    if (info == nullptr) {
        // Binding failed (unsupported type or missing table); skip this table.
        return;
    }
    // Assign the table id only after binding succeeds so skipped tables do not
    // consume ids.
    auto tableID = assignNextTableID();
    auto extraInfo = common::ku_dynamic_cast<binder::BoundExtraCreateCatalogEntryInfo*,
        BoundExtraCreateDuckDBTableInfo*>(info->extraInfo.get());
    std::vector<common::LogicalType> columnTypes;
    std::vector<std::string> columnNames;
    for (auto& propertyInfo : extraInfo->propertyInfos) {
        columnNames.push_back(propertyInfo.name);
        columnTypes.push_back(propertyInfo.type);
    }
    // The bind data captures a connection factory so a fresh connection to the
    // duckdb database can be opened at scan time.
    DuckDBScanBindData bindData(getQuery(*info), std::move(columnTypes), std::move(columnNames),
        std::bind(&DuckDBCatalogContent::getConnection, this, dbPath));
    auto tableEntry = std::make_unique<catalog::DuckDBTableCatalogEntry>(
        info->tableName, tableID, getScanFunction(std::move(bindData)));
    for (auto& propertyInfo : extraInfo->propertyInfos) {
        tableEntry->addProperty(propertyInfo.name, propertyInfo.type.copy());
    }
    tables->createEntry(std::move(tableEntry));
}

// Looks up the column names and types of a duckdb table via
// information_schema.columns, converting each duckdb type to a kuzu
// LogicalType. Returns false (leaving partial output in columnTypes/
// columnNames) when the table does not exist or contains an unsupported type.
static bool getTableInfo(duckdb::Connection& con, const std::string& tableName,
    const std::string& schemaName, const std::string& catalogName,
    std::vector<common::LogicalType>& columnTypes, std::vector<std::string>& columnNames) {
    // ORDER BY ordinal_position: information_schema row order is otherwise
    // unspecified, and the bound column order must match the column order
    // produced by the "SELECT * FROM table" scan query.
    auto query =
        common::stringFormat("select data_type,column_name from information_schema.columns where "
                             "table_name = '{}' and table_schema = '{}' and table_catalog = '{}' "
                             "order by ordinal_position;",
            tableName, schemaName, catalogName);
    auto result = con.Query(query);
    if (result->RowCount() == 0) {
        return false;
    }
    columnTypes.reserve(result->RowCount());
    columnNames.reserve(result->RowCount());
    for (auto i = 0u; i < result->RowCount(); i++) {
        try {
            columnTypes.push_back(DuckDBTypeConverter::convertDuckDBType(
                result->GetValue(0, i).GetValue<std::string>()));
        } catch (common::BinderException&) {
            // Unsupported duckdb type: reject the whole table.
            return false;
        }
        columnNames.push_back(result->GetValue(1, i).GetValue<std::string>());
    }
    return true;
}

// Binds the properties (name + kuzu LogicalType) of a duckdb table into
// propertyInfos. Returns false when the table's schema cannot be resolved.
bool DuckDBCatalogContent::bindPropertyInfos(duckdb::Connection& con, const std::string& tableName,
    const std::string& catalogName, std::vector<binder::PropertyInfo>& propertyInfos) {
    std::vector<common::LogicalType> boundTypes;
    std::vector<std::string> boundNames;
    auto succeeded =
        getTableInfo(con, tableName, getDefaultSchemaName(), catalogName, boundTypes, boundNames);
    if (!succeeded) {
        return false;
    }
    for (auto i = 0u; i < boundNames.size(); i++) {
        propertyInfos.emplace_back(boundNames[i], boundTypes[i]);
    }
    return true;
}

// Builds the BoundCreateTableInfo describing a foreign duckdb table, or
// returns nullptr when the table's properties cannot be bound.
std::unique_ptr<binder::BoundCreateTableInfo> DuckDBCatalogContent::bindCreateTableInfo(
    duckdb::Connection& con, const std::string& tableName, const std::string& dbPath,
    const std::string& catalogName) {
    std::vector<binder::PropertyInfo> propertyInfos;
    if (!bindPropertyInfos(con, tableName, catalogName, propertyInfos)) {
        return nullptr;
    }
    auto extraInfo = std::make_unique<duckdb_scanner::BoundExtraCreateDuckDBTableInfo>(
        dbPath, catalogName, getDefaultSchemaName(), std::move(propertyInfos));
    return std::make_unique<binder::BoundCreateTableInfo>(
        common::TableType::FOREIGN, tableName, std::move(extraInfo));
}

// Schema queried when enumerating/binding tables; "main" is duckdb's default
// schema. NOTE(review): presumably overridden by derived catalogs (e.g. the
// postgres scanner, whose default schema would be "public") — confirm.
std::string DuckDBCatalogContent::getDefaultSchemaName() const {
    return "main";
}

// Opens the duckdb database at dbPath and returns a connection to it. The
// Connection keeps the underlying database instance alive, so returning it
// after the local DuckDB wrapper goes out of scope is safe.
// NOTE(review): this re-opens the database file on every call (catalog init
// and each scan's shared-state init) — confirm this cost is acceptable.
duckdb::Connection DuckDBCatalogContent::getConnection(const std::string& dbPath) const {
    duckdb::DuckDB db(dbPath);
    duckdb::Connection con(db);
    return con;
}

} // namespace duckdb_scanner
Expand Down
17 changes: 10 additions & 7 deletions extension/duckdb_scanner/src/duckdb_scan.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "duckdb_scan.h"

#include "common/exception/binder.h"
#include "common/types/types.h"
#include "function/table/bind_input.h"

Expand All @@ -13,10 +12,11 @@ namespace duckdb_scanner {
void getDuckDBVectorConversionFunc(
PhysicalTypeID physicalTypeID, duckdb_conversion_func_t& conversion_func);

DuckDBScanBindData::DuckDBScanBindData(std::string query, std::string dbPath,
std::vector<common::LogicalType> columnTypes, std::vector<std::string> columnNames)
DuckDBScanBindData::DuckDBScanBindData(std::string query,
std::vector<common::LogicalType> columnTypes, std::vector<std::string> columnNames,
init_duckdb_conn_t initDuckDBConn)
: TableFuncBindData{std::move(columnTypes), std::move(columnNames)}, query{std::move(query)},
dbPath{std::move(dbPath)} {
initDuckDBConn{std::move(initDuckDBConn)} {
conversionFunctions.resize(this->columnTypes.size());
for (auto i = 0u; i < this->columnTypes.size(); i++) {
getDuckDBVectorConversionFunc(
Expand All @@ -25,7 +25,7 @@ DuckDBScanBindData::DuckDBScanBindData(std::string query, std::string dbPath,
}

// Deep-copies the bind data. Re-running the converting constructor also
// rebuilds conversionFunctions for the copy.
std::unique_ptr<TableFuncBindData> DuckDBScanBindData::copy() const {
    return std::make_unique<DuckDBScanBindData>(query, columnTypes, columnNames, initDuckDBConn);
}

DuckDBScanSharedState::DuckDBScanSharedState(std::unique_ptr<duckdb::QueryResult> queryResult)
Expand All @@ -52,9 +52,12 @@ struct DuckDBScanFunction {
// Initializes the shared scan state: opens a duckdb connection via the
// factory captured in the bind data, issues the scan query, and hands the
// streaming result to the shared state.
// Throws RuntimeException when duckdb reports a query error.
std::unique_ptr<function::TableFuncSharedState> DuckDBScanFunction::initSharedState(
    function::TableFunctionInitInput& input) {
    auto scanBindData = reinterpret_cast<DuckDBScanBindData*>(input.bindData);
    auto conn = scanBindData->initDuckDBConn();
    auto result = conn.SendQuery(scanBindData->query);
    if (result->HasError()) {
        throw common::RuntimeException(
            common::stringFormat("Failed to execute query: {} in duckdb.", result->GetError()));
    }
    return std::make_unique<DuckDBScanSharedState>(std::move(result));
}

Expand Down
Loading

0 comments on commit 3237e6f

Please sign in to comment.