From 0d1e36358eff3538b1359fd2948f2eade4de4f95 Mon Sep 17 00:00:00 2001 From: Guodong Jin Date: Tue, 11 Apr 2023 19:11:16 -0400 Subject: [PATCH] rework init reader --- src/include/catalog/catalog_structs.h | 3 +- src/include/storage/copier/node_copier.h | 68 ++- src/include/storage/copier/rel_copier.h | 6 +- src/include/storage/copier/table_copier.h | 31 +- src/storage/copier/node_copier.cpp | 90 ++-- src/storage/copier/rel_copier.cpp | 49 +- src/storage/copier/table_copier.cpp | 162 ++++--- test/copy/copy_fault_test.cpp | 21 +- test/copy/copy_test.cpp | 3 +- test/runner/e2e_copy_transaction_test.cpp | 3 +- tools/python_api/test/conftest.py | 23 +- tools/python_api/test/ground_truth.py | 285 ++++++++++++ tools/python_api/test/test_arrow.py | 226 ++------- tools/python_api/test/test_networkx.py | 5 + .../test/test_query_result_close.py | 2 +- tools/python_api/test/test_torch_geometric.py | 437 ++++-------------- 16 files changed, 652 insertions(+), 762 deletions(-) create mode 100644 tools/python_api/test/ground_truth.py diff --git a/src/include/catalog/catalog_structs.h b/src/include/catalog/catalog_structs.h index 51a84ae9c4..f0de97aac5 100644 --- a/src/include/catalog/catalog_structs.h +++ b/src/include/catalog/catalog_structs.h @@ -139,7 +139,8 @@ struct RelTableSchema : TableSchema { RelTableSchema() : TableSchema{"", common::INVALID_TABLE_ID, false /* isNodeTable */, {} /* properties */}, - relMultiplicity{MANY_MANY} {} + relMultiplicity{MANY_MANY}, srcTableID{common::INVALID_TABLE_ID}, + dstTableID{common::INVALID_TABLE_ID} {} RelTableSchema(std::string tableName, common::table_id_t tableID, RelMultiplicity relMultiplicity, std::vector properties, common::table_id_t srcTableID, common::table_id_t dstTableID) diff --git a/src/include/storage/copier/node_copier.h b/src/include/storage/copier/node_copier.h index d8a6df67a2..7856794997 100644 --- a/src/include/storage/copier/node_copier.h +++ b/src/include/storage/copier/node_copier.h @@ -8,8 +8,6 @@ namespace kuzu { namespace storage { -using lock_t = std::unique_lock; - using set_element_func_t = std::function; @@ -39,7 +37,7 @@ class CSVNodeCopyMorsel : public NodeCopyMorsel { public: CSVNodeCopyMorsel(std::shared_ptr recordBatch, common::offset_t startOffset, common::block_idx_t blockIdx) - : NodeCopyMorsel{startOffset, blockIdx}, recordBatch{recordBatch} {}; + : NodeCopyMorsel{startOffset, blockIdx}, recordBatch{std::move(recordBatch)} {}; const std::vector>& getArrowColumns() override { return recordBatch->columns(); @@ -54,7 +52,7 @@ class ParquetNodeCopyMorsel : public NodeCopyMorsel { public: ParquetNodeCopyMorsel(std::shared_ptr currTable, common::offset_t startOffset, common::block_idx_t blockIdx) - : NodeCopyMorsel{startOffset, blockIdx}, currTable{currTable} {}; + : NodeCopyMorsel{startOffset, blockIdx}, currTable{std::move(currTable)} {}; const std::vector>& getArrowColumns() override { return currTable->columns(); @@ -64,21 +62,21 @@ class ParquetNodeCopyMorsel : public NodeCopyMorsel { std::shared_ptr currTable; }; -template +template class NodeCopySharedState { public: NodeCopySharedState( - std::string filePath, HashIndexBuilder* pkIndex, common::offset_t startOffset) - : filePath{filePath}, pkIndex{pkIndex}, startOffset{startOffset}, blockIdx{0} {}; + std::string filePath, HashIndexBuilder* pkIndex, common::offset_t startOffset) + : filePath{std::move(filePath)}, pkIndex{pkIndex}, startOffset{startOffset}, blockIdx{0} {}; virtual ~NodeCopySharedState() = default; - virtual std::unique_ptr> getMorsel() = 0; + virtual std::unique_ptr> getMorsel() = 0; public: std::string filePath; - HashIndexBuilder* pkIndex; + HashIndexBuilder* pkIndex; common::offset_t startOffset; protected: @@ -86,29 +84,29 @@ class NodeCopySharedState { std::mutex mtx; }; -template -class CSVNodeCopySharedState : public NodeCopySharedState { +template +class CSVNodeCopySharedState : public NodeCopySharedState { public: - CSVNodeCopySharedState(std::string filePath, HashIndexBuilder* pkIndex, + CSVNodeCopySharedState(std::string filePath, HashIndexBuilder* pkIndex, common::offset_t startOffset, std::shared_ptr csvStreamingReader) - : NodeCopySharedState{filePath, pkIndex, startOffset}, - csvStreamingReader{move(csvStreamingReader)} {}; + : NodeCopySharedState{filePath, pkIndex, startOffset}, + csvStreamingReader{std::move(csvStreamingReader)} {}; std::unique_ptr> getMorsel() override; private: std::shared_ptr csvStreamingReader; }; -template -class ParquetNodeCopySharedState : public NodeCopySharedState { +template +class ParquetNodeCopySharedState : public NodeCopySharedState { public: - ParquetNodeCopySharedState(std::string filePath, HashIndexBuilder* pkIndex, + ParquetNodeCopySharedState(std::string filePath, HashIndexBuilder* pkIndex, common::offset_t startOffset, uint64_t numBlocks, std::unique_ptr parquetReader) - : NodeCopySharedState{filePath, pkIndex, startOffset}, + : NodeCopySharedState{filePath, pkIndex, startOffset}, numBlocks{numBlocks}, parquetReader{std::move(parquetReader)} {}; std::unique_ptr> getMorsel() override; @@ -135,39 +133,39 @@ class NodeCopier : public TableCopier { void saveToFile() override; - template + template static void populatePKIndex(InMemColumnChunk* chunk, InMemOverflowFile* overflowFile, - common::NullMask* nullMask, HashIndexBuilder* pkIndex, common::offset_t startOffset, - uint64_t numValues); + common::NullMask* nullMask, HashIndexBuilder* pkIndex, + common::offset_t startOffset, uint64_t numValues); std::unordered_map> columns; private: - template - arrow::Status populateColumns(processor::ExecutionContext* executionContext); + template + void populateColumns(processor::ExecutionContext* executionContext); - template - arrow::Status populateColumnsFromCSV(processor::ExecutionContext* executionContext, - std::unique_ptr>& pkIndex); + template + void populateColumnsFromCSV(processor::ExecutionContext* executionContext, + std::unique_ptr>& pkIndex); - template - arrow::Status populateColumnsFromParquet(processor::ExecutionContext* executionContext, - std::unique_ptr>& pkIndex); + template + void populateColumnsFromParquet(processor::ExecutionContext* executionContext, + std::unique_ptr>& pkIndex); - template + template static void putPropsOfLinesIntoColumns(InMemColumnChunk* columnChunk, NodeInMemColumn* column, - std::shared_ptr arrowArray, common::offset_t startNodeOffset, + std::shared_ptr arrowArray, common::offset_t startNodeOffset, uint64_t numLinesInCurBlock, common::CopyDescription& copyDescription, PageByteCursor& overflowCursor); // Concurrent tasks. - template - static void batchPopulateColumnsTask(NodeCopySharedState* sharedState, + template + static void batchPopulateColumnsTask(NodeCopySharedState* sharedState, NodeCopier* copier, processor::ExecutionContext* executionContext); - template + template static void appendPKIndex(InMemColumnChunk* chunk, InMemOverflowFile* overflowFile, - common::offset_t offset, HashIndexBuilder* pkIndex) { + common::offset_t offset, HashIndexBuilder* pkIndex) { assert(false); } diff --git a/src/include/storage/copier/rel_copier.h b/src/include/storage/copier/rel_copier.h index f8de2a75af..78041a044f 100644 --- a/src/include/storage/copier/rel_copier.h +++ b/src/include/storage/copier/rel_copier.h @@ -39,11 +39,11 @@ class RelCopier : public TableCopier { void initializePkIndexes(common::table_id_t nodeTableID, BufferManager& bufferManager); - arrow::Status executePopulateTask(PopulateTaskType populateTaskType); + void executePopulateTask(PopulateTaskType populateTaskType); - arrow::Status populateFromCSV(PopulateTaskType populateTaskType); + void populateFromCSV(PopulateTaskType populateTaskType); - arrow::Status populateFromParquet(PopulateTaskType populateTaskType); + void populateFromParquet(PopulateTaskType populateTaskType); void populateAdjColumnsAndCountRelsInAdjLists(); diff --git a/src/include/storage/copier/table_copier.h b/src/include/storage/copier/table_copier.h index b7514c310c..ff199b07aa 100644 --- a/src/include/storage/copier/table_copier.h +++ b/src/include/storage/copier/table_copier.h @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -52,28 +51,15 @@ class TableCopier { virtual void populateInMemoryStructures(processor::ExecutionContext* executionContext); - inline void updateTableStatistics() { - tablesStatistics->setNumTuplesForTable(tableSchema->tableID, numRows); - } - - void countNumLines(const std::vector& filePath); - - arrow::Status countNumLinesCSV(const std::vector& filePaths); + void countNumLines(const std::vector& filePaths); - arrow::Status countNumLinesParquet(const std::vector& filePaths); + void countNumLinesCSV(const std::vector& filePaths); - arrow::Status initCSVReaderAndCheckStatus( - std::shared_ptr& csv_streaming_reader, - const std::string& filePath); + void countNumLinesParquet(const std::vector& filePaths); - arrow::Status initCSVReader(std::shared_ptr& csv_streaming_reader, - const std::string& filePath); + std::shared_ptr initCSVReader(const std::string& filePath) const; - arrow::Status initParquetReaderAndCheckStatus( - std::unique_ptr& reader, const std::string& filePath); - - arrow::Status initParquetReader( - std::unique_ptr& reader, const std::string& filePath); + std::unique_ptr initParquetReader(const std::string& filePath); static std::vector> getListElementPos( const std::string& l, int64_t from, int64_t to, common::CopyDescription& copyDescription); @@ -86,7 +72,10 @@ class TableCopier { static void throwCopyExceptionIfNotOK(const arrow::Status& status); - uint64_t getNumBlocks() const { + inline void updateTableStatistics() { + tablesStatistics->setNumTuplesForTable(tableSchema->tableID, numRows); + } + inline uint64_t getNumBlocks() const { uint64_t numBlocks = 0; for (auto& [_, info] : fileBlockInfos) { numBlocks += info.numBlocks; @@ -94,6 +83,8 @@ class TableCopier { return numBlocks; } + static std::shared_ptr toArrowDataType(const common::DataType& dataType); + protected: std::shared_ptr logger; common::CopyDescription& copyDescription; diff --git a/src/storage/copier/node_copier.cpp b/src/storage/copier/node_copier.cpp index 917f495ba9..d31799bb56 100644 --- a/src/storage/copier/node_copier.cpp +++ b/src/storage/copier/node_copier.cpp @@ -10,8 +10,8 @@ using namespace kuzu::common; namespace kuzu { namespace storage { -template -std::unique_ptr> CSVNodeCopySharedState::getMorsel() { +template +std::unique_ptr> CSVNodeCopySharedState::getMorsel() { lock_t lck{this->mtx}; std::shared_ptr recordBatch; auto result = csvStreamingReader->ReadNext(&recordBatch); @@ -20,22 +20,23 @@ std::unique_ptr> CSVNodeCopySharedState::getMors "Error reading a batch of rows from CSV using Arrow CSVStreamingReader."); } if (recordBatch == NULL) { - return make_unique(move(recordBatch), INVALID_NODE_OFFSET, + return make_unique(std::move(recordBatch), INVALID_NODE_OFFSET, NodeCopyMorsel::INVALID_BLOCK_IDX); } auto numRows = recordBatch->num_rows(); this->startOffset += numRows; this->blockIdx++; return make_unique( - move(recordBatch), this->startOffset - numRows, this->blockIdx - 1); + std::move(recordBatch), this->startOffset - numRows, this->blockIdx - 1); } -template -std::unique_ptr> ParquetNodeCopySharedState::getMorsel() { +template +std::unique_ptr> +ParquetNodeCopySharedState::getMorsel() { lock_t lck{this->mtx}; std::shared_ptr currTable; if (this->blockIdx == numBlocks) { - return make_unique(move(currTable), INVALID_NODE_OFFSET, + return make_unique(std::move(currTable), INVALID_NODE_OFFSET, NodeCopyMorsel::INVALID_BLOCK_IDX); } auto result = parquetReader->RowGroup(this->blockIdx)->ReadTable(&currTable); @@ -43,17 +44,15 @@ std::unique_ptr> ParquetNodeCopySharedState< throw common::CopyException( "Error reading a batch of rows from CSV using Arrow CSVStreamingReader."); } - // TODO(Semih): I have not verified that, similar to CSV reading, that if ReadTable runs out of - // blocks to read, then it sets the currTable to NULL. if (currTable == NULL) { - return make_unique(move(currTable), INVALID_NODE_OFFSET, + return make_unique(std::move(currTable), INVALID_NODE_OFFSET, NodeCopyMorsel::INVALID_BLOCK_IDX); } auto numRows = currTable->num_rows(); this->startOffset += numRows; this->blockIdx++; return make_unique( - move(currTable), this->startOffset - numRows, this->blockIdx - 1); + std::move(currTable), this->startOffset - numRows, this->blockIdx - 1); } void NodeCopier::initializeColumnsAndLists() { @@ -68,21 +67,19 @@ void NodeCopier::initializeColumnsAndLists() { } void NodeCopier::populateColumnsAndLists(processor::ExecutionContext* executionContext) { - arrow::Status status; auto primaryKey = reinterpret_cast(tableSchema)->getPrimaryKey(); switch (primaryKey.dataType.typeID) { case INT64: { - status = populateColumns(executionContext); + populateColumns(executionContext); } break; case STRING: { - status = populateColumns(executionContext); + populateColumns(executionContext); } break; default: { throw CopyException(StringUtils::string_format("Unsupported data type {} for the ID index.", Types::dataTypeToString(primaryKey.dataType))); } } - throwCopyExceptionIfNotOK(status); } void NodeCopier::saveToFile() { @@ -96,21 +93,20 @@ void NodeCopier::saveToFile() { logger->debug("Done writing node columns to disk."); } -template -arrow::Status NodeCopier::populateColumns(processor::ExecutionContext* executionContext) { +template +void NodeCopier::populateColumns(processor::ExecutionContext* executionContext) { logger->info("Populating properties"); - auto pkIndex = - std::make_unique>(StorageUtils::getNodeIndexFName(this->outputDirectory, - tableSchema->tableID, DBFileType::WAL_VERSION), - reinterpret_cast(tableSchema)->getPrimaryKey().dataType); + auto pkIndex = std::make_unique>( + StorageUtils::getNodeIndexFName( + this->outputDirectory, tableSchema->tableID, DBFileType::WAL_VERSION), + reinterpret_cast(tableSchema)->getPrimaryKey().dataType); pkIndex->bulkReserve(numRows); - arrow::Status status; switch (copyDescription.fileType) { case CopyDescription::FileType::CSV: - status = populateColumnsFromCSV(executionContext, pkIndex); + populateColumnsFromCSV(executionContext, pkIndex); break; case CopyDescription::FileType::PARQUET: - status = populateColumnsFromParquet(executionContext, pkIndex); + populateColumnsFromParquet(executionContext, pkIndex); break; default: { throw CopyException(StringUtils::string_format("Unsupported file type {}.", @@ -120,48 +116,42 @@ arrow::Status NodeCopier::populateColumns(processor::ExecutionContext* execution logger->info("Flush the pk index to disk."); pkIndex->flush(); logger->info("Done populating properties, constructing the pk index."); - return status; } -template -arrow::Status NodeCopier::populateColumnsFromCSV( - processor::ExecutionContext* executionContext, std::unique_ptr>& pkIndex) { +template +void NodeCopier::populateColumnsFromCSV(processor::ExecutionContext* executionContext, + std::unique_ptr>& pkIndex) { for (auto& filePath : copyDescription.filePaths) { - std::shared_ptr csvStreamingReader; - auto status = initCSVReaderAndCheckStatus(csvStreamingReader, filePath); - throwCopyExceptionIfNotOK(status); + std::shared_ptr csvStreamingReader = initCSVReader(filePath); CSVNodeCopySharedState sharedState{ filePath, pkIndex.get(), fileBlockInfos.at(filePath).startOffset, csvStreamingReader}; taskScheduler.scheduleTaskAndWaitOrError( CopyTaskFactory::createParallelCopyTask(executionContext->numThreads, - batchPopulateColumnsTask, &sharedState, this, executionContext), + batchPopulateColumnsTask, &sharedState, this, + executionContext), executionContext); } - return arrow::Status::OK(); } -template -arrow::Status NodeCopier::populateColumnsFromParquet( - processor::ExecutionContext* executionContext, std::unique_ptr>& pkIndex) { +template +void NodeCopier::populateColumnsFromParquet(processor::ExecutionContext* executionContext, + std::unique_ptr>& pkIndex) { for (auto& filePath : copyDescription.filePaths) { - std::unique_ptr parquetReader; - auto status = initParquetReaderAndCheckStatus(parquetReader, filePath); - throwCopyExceptionIfNotOK(status); + std::unique_ptr parquetReader = initParquetReader(filePath); ParquetNodeCopySharedState sharedState{filePath, pkIndex.get(), fileBlockInfos.at(filePath).startOffset, fileBlockInfos.at(filePath).numBlocks, std::move(parquetReader)}; taskScheduler.scheduleTaskAndWaitOrError( CopyTaskFactory::createParallelCopyTask(executionContext->numThreads, - batchPopulateColumnsTask, &sharedState, this, + batchPopulateColumnsTask, &sharedState, this, executionContext), executionContext); } - return arrow::Status::OK(); } -template +template void NodeCopier::populatePKIndex(InMemColumnChunk* chunk, InMemOverflowFile* overflowFile, - common::NullMask* nullMask, HashIndexBuilder* pkIndex, offset_t startOffset, + common::NullMask* nullMask, HashIndexBuilder* pkIndex, offset_t startOffset, uint64_t numValues) { for (auto i = 0u; i < numValues; i++) { auto offset = i + startOffset; @@ -172,8 +162,8 @@ void NodeCopier::populatePKIndex(InMemColumnChunk* chunk, InMemOverflowFile* ove } } -template -void NodeCopier::batchPopulateColumnsTask(NodeCopySharedState* sharedState, +template +void NodeCopier::batchPopulateColumnsTask(NodeCopySharedState* sharedState, NodeCopier* copier, processor::ExecutionContext* executionContext) { while (true) { if (executionContext->clientContext->isInterrupted()) { @@ -183,8 +173,6 @@ void NodeCopier::batchPopulateColumnsTask(NodeCopySharedState* sharedSta if (!result->success()) { break; } - copier->logger->trace( - "Start: path={0} blkIdx={1}", sharedState->filePath, result->blockIdx); auto numLinesInCurBlock = copier->fileBlockInfos.at(sharedState->filePath).numLinesPerBlock[result->blockIdx]; // Create a column chunk for tuples within the [StartOffset, endOffset] range. @@ -204,20 +192,18 @@ void NodeCopier::batchPopulateColumnsTask(NodeCopySharedState* sharedSta for (auto& [propertyIdx, column] : copier->columns) { column->flushChunk(chunks[propertyIdx].get(), result->startOffset, endOffset); } - auto primaryKeyPropertyIdx = reinterpret_cast(copier->tableSchema)->primaryKeyPropertyID; auto pkColumn = copier->columns.at(primaryKeyPropertyIdx).get(); populatePKIndex(chunks[primaryKeyPropertyIdx].get(), pkColumn->getInMemOverflowFile(), pkColumn->getNullMask(), sharedState->pkIndex, result->startOffset, numLinesInCurBlock); - copier->logger->info("End: path={0} blkIdx={1}", sharedState->filePath, result->blockIdx); } } -template +template void NodeCopier::putPropsOfLinesIntoColumns(InMemColumnChunk* columnChunk, NodeInMemColumn* column, - std::shared_ptr arrowArray, common::offset_t startNodeOffset, uint64_t numLinesInCurBlock, - CopyDescription& copyDescription, PageByteCursor& overflowCursor) { + std::shared_ptr arrowArray, common::offset_t startNodeOffset, + uint64_t numLinesInCurBlock, CopyDescription& copyDescription, PageByteCursor& overflowCursor) { auto setElementFunc = getSetElementFunc(column->getDataType().typeID, copyDescription, overflowCursor); for (auto i = 0u; i < numLinesInCurBlock; i++) { diff --git a/src/storage/copier/rel_copier.cpp b/src/storage/copier/rel_copier.cpp index 97b5cb3cf8..c4d60186a4 100644 --- a/src/storage/copier/rel_copier.cpp +++ b/src/storage/copier/rel_copier.cpp @@ -178,20 +178,22 @@ void RelCopier::initializePkIndexes(table_id_t nodeTableID, BufferManager& buffe pkIndexes.emplace(nodeTableID, nodesStore.getPKIndex(nodeTableID)); } -arrow::Status RelCopier::executePopulateTask(PopulateTaskType populateTaskType) { - arrow::Status status; +void RelCopier::executePopulateTask(PopulateTaskType populateTaskType) { switch (copyDescription.fileType) { case CopyDescription::FileType::CSV: { - status = populateFromCSV(populateTaskType); + populateFromCSV(populateTaskType); } break; case CopyDescription::FileType::PARQUET: { - status = populateFromParquet(populateTaskType); + populateFromParquet(populateTaskType); } break; + default: { + throw CopyException(StringUtils::string_format("Unsupported file type {}.", + CopyDescription::getFileTypeName(copyDescription.fileType))); + } } - return status; } -arrow::Status RelCopier::populateFromCSV(PopulateTaskType populateTaskType) { +void RelCopier::populateFromCSV(PopulateTaskType populateTaskType) { auto populateTask = populateAdjColumnsAndCountRelsInAdjListsTask; if (populateTaskType == PopulateTaskType::populateListsTask) { populateTask = populateListsTask; @@ -200,33 +202,33 @@ arrow::Status RelCopier::populateFromCSV(PopulateTaskType populateTaskType) { for (auto& filePath : copyDescription.filePaths) { offset_t startOffset = fileBlockInfos.at(filePath).startOffset; - std::shared_ptr csv_streaming_reader; - auto status = initCSVReaderAndCheckStatus(csv_streaming_reader, filePath); + auto reader = initCSVReader(filePath); std::shared_ptr currBatch; int blockIdx = 0; - auto it = csv_streaming_reader->begin(); - auto endIt = csv_streaming_reader->end(); - while (it != endIt) { - for (int i = 0; i < CopyConstants::NUM_COPIER_TASKS_TO_SCHEDULE_PER_BATCH; ++i) { - if (it == endIt) { + while (true) { + for (auto i = 0u; i < CopyConstants::NUM_COPIER_TASKS_TO_SCHEDULE_PER_BATCH; i++) { + throwCopyExceptionIfNotOK(reader->ReadNext(&currBatch)); + if (currBatch == nullptr) { + // No more batches left, thus, no more tasks to be scheduled. break; } - ARROW_ASSIGN_OR_RAISE(currBatch, *it); taskScheduler.scheduleTask(CopyTaskFactory::createCopyTask( populateTask, blockIdx, startOffset, this, currBatch->columns(), filePath)); startOffset += currBatch->num_rows(); ++blockIdx; - ++it; + } + if (currBatch == nullptr) { + // No more batches left, thus, no more tasks to be scheduled. + break; } taskScheduler.waitUntilEnoughTasksFinish( CopyConstants::MINIMUM_NUM_COPIER_TASKS_TO_SCHEDULE_MORE); } taskScheduler.waitAllTasksToCompleteOrError(); } - return arrow::Status::OK(); } -arrow::Status RelCopier::populateFromParquet(PopulateTaskType populateTaskType) { +void RelCopier::populateFromParquet(PopulateTaskType populateTaskType) { auto populateTask = populateAdjColumnsAndCountRelsInAdjListsTask; if (populateTaskType == PopulateTaskType::populateListsTask) { populateTask = populateListsTask; @@ -234,8 +236,7 @@ arrow::Status RelCopier::populateFromParquet(PopulateTaskType populateTaskType) logger->debug("Assigning task {0}", getTaskTypeName(populateTaskType)); for (auto& filePath : copyDescription.filePaths) { - std::unique_ptr reader; - auto status = initParquetReaderAndCheckStatus(reader, filePath); + auto reader = initParquetReader(filePath); std::shared_ptr currTable; int blockIdx = 0; offset_t startOffset = 0; @@ -245,7 +246,7 @@ arrow::Status RelCopier::populateFromParquet(PopulateTaskType populateTaskType) if (blockIdx == numBlocks) { break; } - ARROW_RETURN_NOT_OK(reader->RowGroup(blockIdx)->ReadTable(&currTable)); + throwCopyExceptionIfNotOK(reader->RowGroup(blockIdx)->ReadTable(&currTable)); taskScheduler.scheduleTask(CopyTaskFactory::createCopyTask( populateTask, blockIdx, startOffset, this, currTable->columns(), filePath)); startOffset += currTable->num_rows(); @@ -257,23 +258,19 @@ arrow::Status RelCopier::populateFromParquet(PopulateTaskType populateTaskType) taskScheduler.waitAllTasksToCompleteOrError(); } - return arrow::Status::OK(); } void RelCopier::populateAdjColumnsAndCountRelsInAdjLists() { logger->info( "Populating adj columns and rel property columns for rel {}.", tableSchema->tableName); - auto status = - executePopulateTask(PopulateTaskType::populateAdjColumnsAndCountRelsInAdjListsTask); - throwCopyExceptionIfNotOK(status); + executePopulateTask(PopulateTaskType::populateAdjColumnsAndCountRelsInAdjListsTask); logger->info( "Done populating adj columns and rel property columns for rel {}.", tableSchema->tableName); } void RelCopier::populateLists() { logger->debug("Populating adjLists and rel property lists for rel {}.", tableSchema->tableName); - auto status = executePopulateTask(PopulateTaskType::populateListsTask); - throwCopyExceptionIfNotOK(status); + executePopulateTask(PopulateTaskType::populateListsTask); logger->debug( "Done populating adjLists and rel property lists for rel {}.", tableSchema->tableName); } diff --git a/src/storage/copier/table_copier.cpp b/src/storage/copier/table_copier.cpp index d88f93607a..faf6f7a957 100644 --- a/src/storage/copier/table_copier.cpp +++ b/src/storage/copier/table_copier.cpp @@ -36,36 +36,33 @@ void TableCopier::populateInMemoryStructures(processor::ExecutionContext* execut } void TableCopier::countNumLines(const std::vector& filePaths) { - arrow::Status status; switch (copyDescription.fileType) { case CopyDescription::FileType::CSV: { - status = countNumLinesCSV(filePaths); + countNumLinesCSV(filePaths); } break; case CopyDescription::FileType::PARQUET: { - status = countNumLinesParquet(filePaths); + countNumLinesParquet(filePaths); } break; default: { throw CopyException{StringUtils::string_format("Unrecognized file type: {}.", CopyDescription::getFileTypeName(copyDescription.fileType))}; } } - throwCopyExceptionIfNotOK(status); } -arrow::Status TableCopier::countNumLinesCSV(const std::vector& filePaths) { +void TableCopier::countNumLinesCSV(const std::vector& filePaths) { numRows = 0; - arrow::Status status; for (auto& filePath : filePaths) { - std::shared_ptr csv_streaming_reader; - status = initCSVReaderAndCheckStatus(csv_streaming_reader, filePath); - throwCopyExceptionIfNotOK(status); + auto csvStreamingReader = initCSVReader(filePath); std::shared_ptr currBatch; uint64_t numBlocks = 0; std::vector numLinesPerBlock; - auto endIt = csv_streaming_reader->end(); auto startNodeOffset = numRows; - for (auto it = csv_streaming_reader->begin(); it != endIt; ++it) { - ARROW_ASSIGN_OR_RAISE(currBatch, *it); + while (true) { + throwCopyExceptionIfNotOK(csvStreamingReader->ReadNext(&currBatch)); + if (currBatch == NULL) { + break; + } ++numBlocks; auto currNumRows = currBatch->num_rows(); numLinesPerBlock.push_back(currNumRows); @@ -74,84 +71,82 @@ arrow::Status TableCopier::countNumLinesCSV(const std::vector& file fileBlockInfos.emplace( filePath, FileBlockInfo{startNodeOffset, numBlocks, numLinesPerBlock}); } - return status; } -arrow::Status TableCopier::countNumLinesParquet(const std::vector& filePaths) { +void TableCopier::countNumLinesParquet(const std::vector& filePaths) { numRows = 0; - arrow::Status status; for (auto& filePath : filePaths) { - std::unique_ptr reader; - status = initParquetReaderAndCheckStatus(reader, filePath); - throwCopyExceptionIfNotOK(status); + std::unique_ptr reader = initParquetReader(filePath); uint64_t numBlocks = reader->num_row_groups(); std::vector numLinesPerBlock; std::shared_ptr table; auto startNodeOffset = numRows; for (auto blockIdx = 0; blockIdx < numBlocks; ++blockIdx) { - ARROW_RETURN_NOT_OK(reader->RowGroup(blockIdx)->ReadTable(&table)); + throwCopyExceptionIfNotOK(reader->RowGroup(blockIdx)->ReadTable(&table)); numLinesPerBlock.push_back(table->num_rows()); numRows += table->num_rows(); } fileBlockInfos.emplace( filePath, FileBlockInfo{startNodeOffset, numBlocks, numLinesPerBlock}); } - return status; } -arrow::Status TableCopier::initCSVReaderAndCheckStatus( - std::shared_ptr& csv_streaming_reader, - const std::string& filePath) { - auto status = initCSVReader(csv_streaming_reader, filePath); - throwCopyExceptionIfNotOK(status); - return status; -} - -arrow::Status TableCopier::initCSVReader( - std::shared_ptr& csv_streaming_reader, - const std::string& filePath) { - std::shared_ptr arrow_input_stream; - ARROW_ASSIGN_OR_RAISE(arrow_input_stream, arrow::io::ReadableFile::Open(filePath)); - auto arrowRead = arrow::csv::ReadOptions::Defaults(); - arrowRead.block_size = CopyConstants::CSV_READING_BLOCK_SIZE; - if (!copyDescription.csvReaderConfig->hasHeader) { - arrowRead.autogenerate_column_names = true; +std::shared_ptr TableCopier::initCSVReader( + const std::string& filePath) const { + std::shared_ptr inputStream; + throwCopyExceptionIfNotOK(arrow::io::ReadableFile::Open(filePath).Value(&inputStream)); + auto csvReadOptions = arrow::csv::ReadOptions::Defaults(); + csvReadOptions.block_size = CopyConstants::CSV_READING_BLOCK_SIZE; + if (!tableSchema->isNodeTable) { + csvReadOptions.column_names.emplace_back("_FROM"); + csvReadOptions.column_names.emplace_back("_TO"); + } + for (auto& property : tableSchema->properties) { + if (!TableSchema::isReservedPropertyName(property.name)) { + csvReadOptions.column_names.push_back(property.name); + } + } + if (copyDescription.csvReaderConfig->hasHeader) { + csvReadOptions.skip_rows = 1; } - auto arrowConvert = arrow::csv::ConvertOptions::Defaults(); - arrowConvert.strings_can_be_null = true; - // Only the empty string is treated as NULL. - arrowConvert.null_values = {""}; - arrowConvert.quoted_strings_can_be_null = false; - - auto arrowParse = arrow::csv::ParseOptions::Defaults(); - arrowParse.delimiter = copyDescription.csvReaderConfig->delimiter; - arrowParse.escape_char = copyDescription.csvReaderConfig->escapeChar; - arrowParse.quote_char = copyDescription.csvReaderConfig->quoteChar; - arrowParse.ignore_empty_lines = false; - arrowParse.escaping = true; + auto csvParseOptions = arrow::csv::ParseOptions::Defaults(); + csvParseOptions.delimiter = copyDescription.csvReaderConfig->delimiter; + csvParseOptions.escape_char = copyDescription.csvReaderConfig->escapeChar; + csvParseOptions.quote_char = copyDescription.csvReaderConfig->quoteChar; + csvParseOptions.ignore_empty_lines = false; + csvParseOptions.escaping = true; - ARROW_ASSIGN_OR_RAISE( - csv_streaming_reader, arrow::csv::StreamingReader::Make(arrow::io::default_io_context(), - arrow_input_stream, arrowRead, arrowParse, arrowConvert)); - return arrow::Status::OK(); -} + auto csvConvertOptions = arrow::csv::ConvertOptions::Defaults(); + csvConvertOptions.strings_can_be_null = true; + // Only the empty string is treated as NULL. + csvConvertOptions.null_values = {""}; + csvConvertOptions.quoted_strings_can_be_null = false; + for (auto& property : tableSchema->properties) { + if (property.name == "_FROM" || property.name == "_TO") { + csvConvertOptions.column_types[property.name] = arrow::int64(); + continue; + } + if (!TableSchema::isReservedPropertyName(property.name)) { + csvConvertOptions.column_types[property.name] = toArrowDataType(property.dataType); + } + } -arrow::Status TableCopier::initParquetReaderAndCheckStatus( - std::unique_ptr& reader, const std::string& filePath) { - auto status = initParquetReader(reader, filePath); - throwCopyExceptionIfNotOK(status); - return status; + std::shared_ptr csvStreamingReader; + throwCopyExceptionIfNotOK(arrow::csv::StreamingReader::Make(arrow::io::default_io_context(), + inputStream, csvReadOptions, csvParseOptions, csvConvertOptions) + .Value(&csvStreamingReader)); + return csvStreamingReader; } -arrow::Status TableCopier::initParquetReader( - std::unique_ptr& reader, const std::string& filePath) { +std::unique_ptr TableCopier::initParquetReader( + const std::string& filePath) { std::shared_ptr infile; - ARROW_ASSIGN_OR_RAISE( - infile, arrow::io::ReadableFile::Open(filePath, arrow::default_memory_pool())); - - ARROW_RETURN_NOT_OK(parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader)); - return arrow::Status::OK(); + throwCopyExceptionIfNotOK(arrow::io::ReadableFile::Open(filePath).Value(&infile)); + std::unique_ptr reader; + throwCopyExceptionIfNotOK( + parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader)); + return reader; } std::vector> TableCopier::getListElementPos( @@ -294,5 +289,40 @@ void TableCopier::throwCopyExceptionIfNotOK(const arrow::Status& status) { } } +std::shared_ptr TableCopier::toArrowDataType(const common::DataType& dataType) { + switch (dataType.typeID) { + case common::BOOL: { + return arrow::boolean(); + } + case common::INT64: { + return arrow::int64(); + } + case common::INT32: { + return arrow::int32(); + } + case common::INT16: { + return arrow::int16(); + } + case common::DOUBLE: { + return arrow::float64(); + } + case common::FLOAT: { + return arrow::float32(); + } + case common::TIMESTAMP: + case common::DATE: + case common::FIXED_LIST: + case common::VAR_LIST: + case common::STRING: + case common::INTERVAL: { + return arrow::utf8(); + } + default: { + throw CopyException( + "Unsupported data type for CSV " + Types::dataTypeToString(dataType.typeID)); + } + } +} + } // namespace storage } // namespace kuzu diff --git a/test/copy/copy_fault_test.cpp b/test/copy/copy_fault_test.cpp index 9f20dc4fe6..05bb9992cb 100644 --- a/test/copy/copy_fault_test.cpp +++ b/test/copy/copy_fault_test.cpp @@ -56,17 +56,16 @@ TEST_F(CopyDuplicateIDTest, DuplicateIDsError) { } TEST_F(CopyNodeUnmatchedColumnTypeTest, UnMatchedColumnTypeError) { - conn->query( - "create node table person (ID INT64, fName INT64, gender INT64, isStudent BOOLEAN, " - "isWorker BOOLEAN, " - "age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration " - "INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], " - "PRIMARY " - "KEY (fName))"); + conn->query("create node table person (ID INT64, fName INT64, gender INT64, isStudent BOOLEAN, " + "isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, registerTime " + "TIMESTAMP, lastJobDuration " + "INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], " + "grades INT64[4], height float, PRIMARY KEY (fName))"); auto result = conn->query("COPY person FROM \"" + TestHelper::appendKuzuRootPath("dataset/tinysnb/vPerson.csv\" (HEADER=true)")); - ASSERT_EQ(result->getErrorMessage(), "Invalid number: Alice."); + ASSERT_EQ(result->getErrorMessage(), "Copy exception: Invalid: In CSV column #1: CSV " + "conversion error to int64: invalid value 'Alice'"); } TEST_F(CopyWrongHeaderTest, HeaderError) { @@ -151,14 +150,16 @@ TEST_F(CopyInvalidNumberTest, INT32OverflowError) { validateCopyException( "COPY person FROM \"" + TestHelper::appendKuzuRootPath("dataset/copy-fault-tests/invalid-number/vPerson.csv\""), - "Invalid number: 2147483650."); + "Copy exception: Invalid: In CSV column #1: CSV conversion error to int32: invalid value " + "'2147483650'"); } TEST_F(CopyInvalidNumberTest, InvalidNumberError) { validateCopyException( "COPY person FROM \"" + TestHelper::appendKuzuRootPath("dataset/copy-fault-tests/invalid-number/vMovie.csv\""), - "Invalid number: 312abc."); + "Copy exception: Invalid: In CSV column #1: CSV conversion error to int32: invalid value " + "'312abc'"); } TEST_F(CopyNullPKTest, NullPKErrpr) { diff --git a/test/copy/copy_test.cpp b/test/copy/copy_test.cpp index da1c670f27..9f3a3a20dd 100644 --- a/test/copy/copy_test.cpp +++ b/test/copy/copy_test.cpp @@ -300,7 +300,8 @@ TEST_F(CopyNodeInitRelTablesTest, CopyNodeAndQueryEmptyRelTable) { "create node table person (ID INt64, fName StRING, gender INT64, isStudent " "BoOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE, birthdate DATE, " "registerTime TIMESTAMP, lastJobDuration interval, workedHours INT64[], usedNames " - "STRING[], courseScoresPerTerm INT64[][], PRIMARY KEY (ID));") + "STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, PRIMARY " + "KEY (ID));") ->isSuccess()); ASSERT_TRUE(conn->query("create rel table knows (FROM person TO person, date DATE, meetTime " "TIMESTAMP, validInterval INTERVAL, comments STRING[], MANY_MANY);") diff --git a/test/runner/e2e_copy_transaction_test.cpp b/test/runner/e2e_copy_transaction_test.cpp index 4cfdd9833e..02efffc52d 100644 --- a/test/runner/e2e_copy_transaction_test.cpp +++ b/test/runner/e2e_copy_transaction_test.cpp @@ -194,7 +194,8 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { "CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, " "isWorker BOOLEAN, " "age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration " - "INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], " + "INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades " + "INT64[4], height float, " "PRIMARY KEY (ID))"; std::string copyPersonTableCMD = "COPY person FROM \"" + diff --git a/tools/python_api/test/conftest.py b/tools/python_api/test/conftest.py index 9b249ddf29..750320bd0e 100644 --- a/tools/python_api/test/conftest.py +++ b/tools/python_api/test/conftest.py @@ -16,8 +16,8 @@ def init_tiny_snb(tmp_path): conn = kuzu.Connection(db) conn.execute("CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, " "age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration " - "INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], PRIMARY " - "KEY (ID))") + "INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], " + "height float, PRIMARY KEY (ID))") conn.execute("COPY person FROM \"../../../dataset/tinysnb/vPerson.csv\" (HEADER=true)") conn.execute( "create rel table knows (FROM person TO person, date DATE, meetTime TIMESTAMP, validInterval INTERVAL, " @@ -26,11 +26,13 @@ def init_tiny_snb(tmp_path): conn.execute("create node table organisation (ID INT64, name STRING, orgCode INT64, mark DOUBLE, score INT64, " "history STRING, licenseValidInterval INTERVAL, rating DOUBLE, PRIMARY KEY (ID))") conn.execute('COPY organisation FROM "../../../dataset/tinysnb/vOrganisation.csv"') - conn.execute('CREATE NODE TABLE movies (name STRING, PRIMARY KEY (name))') + conn.execute('CREATE NODE TABLE movies (name STRING, length INT32, note STRING, PRIMARY KEY (name))') conn.execute('COPY movies FROM "../../../dataset/tinysnb/vMovies.csv"') - conn.execute('create rel table workAt (FROM person TO organisation, year INT64, MANY_ONE)') + conn.execute('create rel table workAt (FROM person TO organisation, year INT64, grading DOUBLE[2], rating float,' + ' MANY_ONE)') conn.execute('COPY workAt FROM "../../../dataset/tinysnb/eWorkAt.csv"') - conn.execute('create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));') + conn.execute('create node table tensor (ID INT64, boolTensor BOOLEAN[], doubleTensor DOUBLE[][], ' + 'intTensor INT64[][][], oneDimInt INT64, PRIMARY KEY (ID));') conn.execute( 'COPY tensor FROM "../../../dataset/tensor-list/vTensor.csv" (HEADER=true)') conn.execute( @@ -45,13 +47,18 @@ def init_tiny_snb(tmp_path): 'create node table npyoned (i64 INT64,i32 INT32,i16 INT16,f64 DOUBLE,f32 FLOAT, PRIMARY KEY(i64));' ) conn.execute( - 'copy npyoned from ("../../../dataset/npy-1d/one_dim_int64.npy", "../../../dataset/npy-1d/one_dim_int32.npy", "../../../dataset/npy-1d/one_dim_int16.npy", "../../../dataset/npy-1d/one_dim_double.npy", "../../../dataset/npy-1d/one_dim_float.npy") by column;' + 'copy npyoned from ("../../../dataset/npy-1d/one_dim_int64.npy", "../../../dataset/npy-1d/one_dim_int32.npy", ' + ' "../../../dataset/npy-1d/one_dim_int16.npy", "../../../dataset/npy-1d/one_dim_double.npy", ' + '"../../../dataset/npy-1d/one_dim_float.npy") by column;' ) conn.execute( - 'create node table npytwod (id INT64, i64 INT64[3],i32 INT32[3],i16 INT16[3],f64 DOUBLE[3],f32 FLOAT[3],PRIMARY KEY(id));' + 'create node table npytwod (id INT64, i64 INT64[3], i32 INT32[3], i16 INT16[3], f64 DOUBLE[3], f32 FLOAT[3],' + 'PRIMARY KEY(id));' ) conn.execute( - 'copy npytwod from ("../../../dataset/npy-2d/id_int64.npy", "../../../dataset/npy-2d/two_dim_int64.npy", "../../../dataset/npy-2d/two_dim_int32.npy", "../../../dataset/npy-2d/two_dim_int16.npy", "../../../dataset/npy-2d/two_dim_double.npy", "../../../dataset/npy-2d/two_dim_float.npy") by column;' + 'copy npytwod from ("../../../dataset/npy-2d/id_int64.npy", "../../../dataset/npy-2d/two_dim_int64.npy", ' + '"../../../dataset/npy-2d/two_dim_int32.npy", "../../../dataset/npy-2d/two_dim_int16.npy", ' + ' "../../../dataset/npy-2d/two_dim_double.npy", "../../../dataset/npy-2d/two_dim_float.npy") by column;' ) return output_path diff --git a/tools/python_api/test/ground_truth.py b/tools/python_api/test/ground_truth.py new file mode 100644 index 0000000000..4d41cd5eba --- /dev/null +++ b/tools/python_api/test/ground_truth.py @@ -0,0 +1,285 @@ +import datetime + +TINY_SNB_PERSONS_GROUND_TRUTH = {0: {'ID': 0, + 'fName': 'Alice', + 'gender': 1, + 'isStudent': True, + 'isWorker': False, + 'age': 35, + 'eyeSight': 5.0, + 'birthdate': datetime.date(1900, 1, 1), + 'registerTime': datetime.datetime(2011, 8, 20, 11, 25, 30), + 'lastJobDuration': datetime.timedelta(days=1082, seconds=46920), + 'workedHours': [10, 5], + 'usedNames': ['Aida'], + 'courseScoresPerTerm': [[10, 8], [6, 7, 8]], + '_label': 'person', + '_id': {'offset': 0, 'table': 0}}, + 2: {'ID': 2, + 'fName': 'Bob', + 'gender': 2, + 'isStudent': True, + 'isWorker': False, + 'age': 30, + 'eyeSight': 5.1, + 'birthdate': datetime.date(1900, 1, 1), + 'registerTime': datetime.datetime(2008, 11, 3, 15, 25, 30, 526), + 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'workedHours': [12, 8], + 'usedNames': ['Bobby'], + 'courseScoresPerTerm': [[8, 9], [9, 10]], + '_label': 'person', + '_id': {'offset': 1, 'table': 0}}, + 3: {'ID': 3, + 'fName': 'Carol', + 'gender': 1, + 'isStudent': False, + 'isWorker': True, + 'age': 45, + 'eyeSight': 5.0, + 'birthdate': datetime.date(1940, 6, 22), + 'registerTime': datetime.datetime(1911, 8, 20, 2, 32, 21), + 'lastJobDuration': datetime.timedelta(days=2, seconds=1451), + 'workedHours': [4, 5], + 'usedNames': ['Carmen', 'Fred'], + 'courseScoresPerTerm': [[8, 10]], + '_label': 'person', + '_id': {'offset': 2, 'table': 0}}, + 5: {'ID': 5, + 'fName': 'Dan', + 'gender': 2, + 'isStudent': False, + 'isWorker': True, + 'age': 20, + 'eyeSight': 4.8, + 'birthdate': datetime.date(1950, 7, 23), + 'registerTime': datetime.datetime(2031, 11, 30, 12, 25, 30), + 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'workedHours': [1, 9], + 'usedNames': ['Wolfeschlegelstein', 'Daniel'], + 'courseScoresPerTerm': [[7, 4], [8, 8], [9]], + '_label': 'person', + '_id': {'offset': 3, 'table': 0}}, + 7: {'ID': 7, + 'fName': 'Elizabeth', + 'gender': 1, + 'isStudent': False, + 'isWorker': True, + 'age': 20, + 'eyeSight': 4.7, + 'birthdate': datetime.date(1980, 10, 26), + 'registerTime': datetime.datetime(1976, 12, 23, 11, 21, 42), + 'lastJobDuration': datetime.timedelta(days=2, seconds=1451), + 'workedHours': [2], + 'usedNames': ['Ein'], + 'courseScoresPerTerm': [[6], [7], [8]], + '_label': 'person', + '_id': {'offset': 4, 'table': 0}}, + 8: {'ID': 8, + 'fName': 'Farooq', + 'gender': 2, + 'isStudent': True, + 'isWorker': False, + 'age': 25, + 'eyeSight': 4.5, + 'birthdate': datetime.date(1980, 10, 26), + 'registerTime': datetime.datetime(1972, 7, 31, 13, 22, 30, 678559), + 'lastJobDuration': datetime.timedelta(seconds=1080, microseconds=24000), + 'workedHours': [3, 4, 5, 6, 7], + 'usedNames': ['Fesdwe'], + 'courseScoresPerTerm': [[8]], + '_label': 'person', + '_id': {'offset': 5, 'table': 0}}, + 9: {'ID': 9, + 'fName': 'Greg', + 'gender': 2, + 'isStudent': False, + 'isWorker': False, + 'age': 40, + 'eyeSight': 4.9, + 'birthdate': datetime.date(1980, 10, 26), + 'registerTime': datetime.datetime(1976, 12, 23, 4, 41, 42), + 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'workedHours': [1], + 'usedNames': ['Grad'], + 'courseScoresPerTerm': [[10]], + '_label': 'person', + '_id': {'offset': 6, 'table': 0}}, + 10: {'ID': 10, + 'fName': 'Hubert Blaine Wolfeschlegelsteinhausenbergerdorff', + 'gender': 2, + 'isStudent': False, + 'isWorker': True, + 'age': 83, + 'eyeSight': 4.9, + 'birthdate': datetime.date(1990, 11, 27), + 'registerTime': datetime.datetime(2023, 2, 21, 13, 25, 30), + 'lastJobDuration': datetime.timedelta(days=1082, seconds=46920), + 'workedHours': [10, 11, 12, 3, 4, 5, 6, 7], + 'usedNames': ['Ad', 'De', 'Hi', 'Kye', 'Orlan'], + 'courseScoresPerTerm': [[7], [10], [6, 7]], + '_label': 'person', + '_id': {'offset': 7, 'table': 0}}} + + +TINY_SNB_ORGANISATIONS_GROUND_TRUTH = {1: {'ID': 1, + 'name': 'ABFsUni', + 'orgCode': 325, + 'mark': 3.7, + 'score': -2, + 'history': '10 years 5 months 13 hours 24 us', + 'licenseValidInterval': datetime.timedelta(days=1085), + 'rating': 1.0, + '_label': 'organisation', + '_id': {'offset': 0, 'table': 2}}, + 4: {'ID': 4, + 'name': 'CsWork', + 'orgCode': 934, + 'mark': 4.1, + 'score': -100, + 'history': '2 years 4 days 10 hours', + 'licenseValidInterval': datetime.timedelta(days=9414), + 'rating': 0.78, + '_label': 'organisation', + '_id': {'offset': 1, 'table': 2}}, + 6: {'ID': 6, + 'name': 'DEsWork', + 'orgCode': 824, + 'mark': 4.1, + 'score': 7, + 'history': '2 years 4 hours 22 us 34 minutes', + 'licenseValidInterval': datetime.timedelta(days=3, seconds=36000, microseconds=100000), + 'rating': 0.52, + '_label': 'organisation', + '_id': {'offset': 2, 'table': 2}}} + +TINY_SNB_KNOWS_GROUND_TRUTH = { + 0: [2, 3, 5], + 2: [0, 3, 5], + 3: [0, 2, 5], + 5: [0, 2, 3], + 7: [8, 9], +} + +TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH = { + (0, 2): {'date': datetime.date(2021, 6, 30), + 'meetTime': datetime.datetime(1986, 10, 21, 21, 8, 31, 521000), + 'validInterval': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'comments': ['rnme', 'm8sihsdnf2990nfiwf']}, + (0, 3): {'date': datetime.date(2021, 6, 30), + 'meetTime': datetime.datetime(1946, 8, 25, 19, 7, 22), + 'validInterval': datetime.timedelta(days=7232), + 'comments': ['njnojppo9u0jkmf', 'fjiojioh9h9h89hph']}, + (0, 5): {'date': datetime.date(2021, 6, 30), + 'meetTime': datetime.datetime(2012, 12, 11, 20, 7, 22), + 'validInterval': datetime.timedelta(days=10), + 'comments': ['ioji232', 'jifhe8w99u43434']}, + (2, 0): {'date': datetime.date(2021, 6, 30), + 'meetTime': datetime.datetime(1946, 8, 25, 19, 7, 22), + 'validInterval': datetime.timedelta(days=3750, seconds=46800, microseconds=24), + 'comments': ['2huh9y89fsfw23', '23nsihufhw723']}, + (2, 3): {'date': datetime.date(1950, 5, 14), + 'meetTime': datetime.datetime(1946, 8, 25, 19, 7, 22), + 'validInterval': datetime.timedelta(seconds=1380), + 'comments': ['fwehu9h9832wewew', '23u9h989sdfsss']}, + (2, 5): {'date': datetime.date(1950, 5, 14), + 'meetTime': datetime.datetime(2012, 12, 11, 20, 7, 22), + 'validInterval': datetime.timedelta(days=7232), + 'comments': ['fwh9y81232uisuiehuf', 'ewnuihxy8dyf232']}, + (3, 0): {'date': datetime.date(2021, 6, 30), + 'meetTime': datetime.datetime(2002, 7, 31, 11, 42, 53, 123420), + 'validInterval': datetime.timedelta(days=41, seconds=21600), + 'comments': ['fnioh8323aeweae34d', 'osd89e2ejshuih12']}, + (3, 2): {'date': datetime.date(1950, 5, 14), + 'meetTime': datetime.datetime(2007, 2, 12, 12, 11, 42, 123000), + 'validInterval': datetime.timedelta(seconds=1680, microseconds=30000), + 'comments': ['fwh983-sdjisdfji', 'ioh89y32r2huir']}, + (3, 5): {'date': datetime.date(2000, 1, 1), + 'meetTime': datetime.datetime(1998, 10, 2, 13, 9, 22, 423000), + 'validInterval': datetime.timedelta(microseconds=300000), + 'comments': ['psh989823oaaioe', 'nuiuah1nosndfisf']}, + (5, 0): {'date': datetime.date(2021, 6, 30), + 'meetTime': datetime.datetime(1936, 11, 2, 11, 2, 1), + 'validInterval': datetime.timedelta(microseconds=480), + 'comments': ['fwewe']}, + (5, 2): {'date': datetime.date(1950, 5, 14), + 'meetTime': datetime.datetime(1982, 11, 11, 13, 12, 5, 123000), + 'validInterval': datetime.timedelta(seconds=1380), + 'comments': ['fewh9182912e3', + 'h9y8y89soidfsf', + 'nuhudf78w78efw', + 'hioshe0f9023sdsd']}, + (5, 3): {'date': datetime.date(2000, 1, 1), + 'meetTime': datetime.datetime(1999, 4, 21, 15, 12, 11, 420000), + 'validInterval': datetime.timedelta(days=2, microseconds=52000), + 'comments': ['23h9sdslnfowhu2932', 'shuhf98922323sf']}, + (7, 8): {'date': datetime.date(1905, 12, 12), + 'meetTime': datetime.datetime(2025, 1, 1, 11, 22, 33, 520000), + 'validInterval': datetime.timedelta(seconds=2878), + 'comments': ['ahu2333333333333', '12weeeeeeeeeeeeeeeeee']}, + (7, 9): {'date': datetime.date(1905, 12, 12), + 'meetTime': datetime.datetime(2020, 3, 1, 12, 11, 41, 655200), + 'validInterval': datetime.timedelta(seconds=2878), + 'comments': ['peweeeeeeeeeeeeeeeee', 'kowje9w0eweeeeeeeee']} +} + + +TINY_SNB_WORKS_AT_GROUND_TRUTH = { + 3: [4], + 5: [6], + 7: [6], +} + +TINY_SNB_WORKS_AT_PROPERTIES_GROUND_TRUTH = { + (3, 4): {'year': 2015}, + (5, 6): {'year': 2010}, + (7, 6): {'year': 2015} +} + +TENSOR_LIST_GROUND_TRUTH = { + 0: { + 'boolTensor': [True, False], + 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], + 'intTensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + }, + 3: { + 'boolTensor': [True, False], + 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], + 'intTensor': [[[3, 4], [5, 6]], [[7, 8], [9, 10]]] + }, + 4: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.4, 0.8], [0.7, 0.6]], + 'intTensor': [[[5, 6], [7, 8]], [[9, 10], [11, 12]]] + }, + 5: { + 'boolTensor': [True, True], + 'doubleTensor': [[0.4, 0.9], [0.5, 0.2]], + 'intTensor': [[[7, 8], [9, 10]], [[11, 12], [13, 14]]] + }, + 6: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.2, 0.4], [0.5, 0.1]], + 'intTensor': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]] + }, + 8: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.6, 0.4], [0.6, 0.1]], + 'intTensor': [[[11, 12], [13, 14]], [[15, 16], [17, 18]]] + } +} + +PERSONLONGSTRING_GROUND_TRUTH = { + 'AAAAAAAAAAAAAAAAAAAA': { + 'name': 'AAAAAAAAAAAAAAAAAAAA', + 'spouse': 'Bob', + }, + 'Bob': { + 'name': 'Bob', + 'spouse': 'AAAAAAAAAAAAAAAAAAAA', + }, +} + +PERSONLONGSTRING_KNOWS_GROUND_TRUTH = { + 'AAAAAAAAAAAAAAAAAAAA': ['Bob'], +} diff --git a/tools/python_api/test/test_arrow.py b/tools/python_api/test/test_arrow.py index 355f4b48a7..8226133a4c 100644 --- a/tools/python_api/test/test_arrow.py +++ b/tools/python_api/test/test_arrow.py @@ -4,6 +4,7 @@ import kuzu import pyarrow as pa import datetime +import ground_truth def test_to_arrow(establish_connection): @@ -121,182 +122,49 @@ def _test_with_nulls(_conn): _test_with_nulls(conn) -def test_to_arrow_complex(establish_connection): - conn, db = establish_connection - - def _test_node(_conn): - query = "MATCH (p:person) RETURN p ORDER BY p.ID" - query_result = _conn.execute(query) - arrow_tbl = query_result.get_as_arrow(12) - p_col = arrow_tbl.column(0) - - assert p_col.to_pylist() == [{'ID': 0, - '_id': {'offset': 0, 'tableID': 0}, - '_label': 'person', - 'age': 35, - 'birthdate': datetime.date(1900, 1, 1), - 'courseScoresPerTerm': [[10, 8], [6, 7, 8]], - 'eyeSight': 5.0, - 'fName': 'Alice', - 'gender': 1, - 'isStudent': True, - 'isWorker': False, - 'lastJobDuration': datetime.timedelta(days=99, seconds=36334, - microseconds=628000), - 'registerTime': datetime.datetime(2011, 8, 20, 11, 25, 30), - 'usedNames': ['Aida'], - 'workedHours': [10, 5]}, - {'ID': 2, - '_id': {'offset': 1, 'tableID': 0}, - '_label': 'person', - 'age': 30, - 'birthdate': datetime.date(1900, 1, 1), - 'courseScoresPerTerm': [[8, 9], [9, 10]], - 'eyeSight': 5.1, - 'fName': 'Bob', - 'gender': 2, - 'isStudent': True, - 'isWorker': False, - 'lastJobDuration': datetime.timedelta(days=543, seconds=4800), - 'registerTime': datetime.datetime(2008, 11, 3, 15, 25, 30, 526), - 'usedNames': ['Bobby'], - 'workedHours': [12, 8]}, - {'ID': 3, - '_id': {'offset': 2, 'tableID': 0}, - '_label': 'person', - 'age': 45, - 'birthdate': datetime.date(1940, 6, 22), - 'courseScoresPerTerm': [[8, 10]], - 'eyeSight': 5.0, - 'fName': 'Carol', - 'gender': 1, - 'isStudent': False, - 'isWorker': True, - 'lastJobDuration': datetime.timedelta(microseconds=125000), - 'registerTime': datetime.datetime(1911, 8, 20, 2, 32, 21), - 'usedNames': ['Carmen', 'Fred'], - 'workedHours': [4, 5]}, - {'ID': 5, - '_id': {'offset': 3, 'tableID': 0}, - '_label': 'person', - 'age': 20, - 'birthdate': datetime.date(1950, 7, 23), - 'courseScoresPerTerm': [[7, 4], [8, 8], [9]], - 'eyeSight': 4.8, - 'fName': 'Dan', - 'gender': 2, - 'isStudent': False, - 'isWorker': True, - 'lastJobDuration': datetime.timedelta(days=541, seconds=57600, - microseconds=24000), - 'registerTime': datetime.datetime(2031, 11, 30, 12, 25, 30), - 'usedNames': ['Wolfeschlegelstein', 'Daniel'], - 'workedHours': [1, 9]}, - {'ID': 7, - '_id': {'offset': 4, 'tableID': 0}, - '_label': 'person', - 'age': 20, - 'birthdate': datetime.date(1980, 10, 26), - 'courseScoresPerTerm': [[6], [7], [8]], - 'eyeSight': 4.7, - 'fName': 'Elizabeth', - 'gender': 1, - 'isStudent': False, - 'isWorker': True, - 'lastJobDuration': datetime.timedelta(0), - 'registerTime': datetime.datetime(1976, 12, 23, 11, 21, 42), - 'usedNames': ['Ein'], - 'workedHours': [2]}, - {'ID': 8, - '_id': {'offset': 5, 'tableID': 0}, - '_label': 'person', - 'age': 25, - 'birthdate': datetime.date(1980, 10, 26), - 'courseScoresPerTerm': [[8]], - 'eyeSight': 4.5, - 'fName': 'Farooq', - 'gender': 2, - 'isStudent': True, - 'isWorker': False, - 'lastJobDuration': datetime.timedelta(days=2016, seconds=68600), - 'registerTime': datetime.datetime(1972, 7, 31, 13, 22, 30, 678559), - 'usedNames': ['Fesdwe'], - 'workedHours': [3, 4, 5, 6, 7]}, - {'ID': 9, - '_id': {'offset': 6, 'tableID': 0}, - '_label': 'person', - 'age': 40, - 'birthdate': datetime.date(1980, 10, 26), - 'courseScoresPerTerm': [[10]], - 'eyeSight': 4.9, - 'fName': 'Greg', - 'gender': 2, - 'isStudent': False, - 'isWorker': False, - 'lastJobDuration': datetime.timedelta(microseconds=125000), - 'registerTime': datetime.datetime(1976, 12, 23, 4, 41, 42), - 'usedNames': ['Grad'], - 'workedHours': [1]}, - {'ID': 10, - '_id': {'offset': 7, 'tableID': 0}, - '_label': 'person', - 'age': 83, - 'birthdate': datetime.date(1990, 11, 27), - 'courseScoresPerTerm': [[7], [10], [6, 7]], - 'eyeSight': 4.9, - 'fName': 'Hubert Blaine Wolfeschlegelsteinhausenbergerdorff', - 'gender': 2, - 'isStudent': False, - 'isWorker': True, - 'lastJobDuration': datetime.timedelta(days=541, seconds=57600, - microseconds=24000), - 'registerTime': datetime.datetime(2023, 2, 21, 13, 25, 30), - 'usedNames': ['Ad', 'De', 'Hi', 'Kye', 'Orlan'], - 'workedHours': [10, 11, 12, 3, 4, 5, 6, 7]}] - - def _test_node_rel(_conn): - query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;" - query_result = _conn.execute(query) - arrow_tbl = query_result.get_as_arrow(12) - assert arrow_tbl.num_columns == 3 - a_col = arrow_tbl.column(0) - assert a_col.length() == 3 - e_col = arrow_tbl.column(1) - assert a_col.length() == 3 - b_col = arrow_tbl.column(2) - assert a_col.length() == 3 - assert a_col.to_pylist() == [ - {'_id': {'offset': 2, 'tableID': 0}, '_label': 'person', 'ID': 3, 'fName': 'Carol', 'gender': 1, - 'isStudent': False, 'isWorker': True, 'age': 45, 'eyeSight': 5.0, 'birthdate': datetime.date(1940, 6, 22), - 'registerTime': datetime.datetime(1911, 8, 20, 2, 32, 21), 'lastJobDuration': datetime.timedelta(0), - 'workedHours': [4, 5], 'usedNames': ['Carmen', 'Fred'], 'courseScoresPerTerm': [[8, 10]]}, - {'_id': {'offset': 3, 'tableID': 0}, '_label': 'person', 'ID': 5, 'fName': 'Dan', 'gender': 2, - 'isStudent': False, 'isWorker': True, 'age': 20, 'eyeSight': 4.8, 'birthdate': datetime.date(1950, 7, 23), - 'registerTime': datetime.datetime(2031, 11, 30, 12, 25, 30), - 'lastJobDuration': datetime.timedelta(days=2016, seconds=68600), 'workedHours': [1, 9], - 'usedNames': ['Wolfeschlegelstein', 'Daniel'], 'courseScoresPerTerm': [[7, 4], [8, 8], [9]]}, - {'_id': {'offset': 4, 'tableID': 0}, '_label': 'person', 'ID': 7, 'fName': 'Elizabeth', 'gender': 1, - 'isStudent': False, 'isWorker': True, 'age': 20, 'eyeSight': 4.7, 'birthdate': datetime.date(1980, 10, 26), - 'registerTime': datetime.datetime(1976, 12, 23, 11, 21, 42), - 'lastJobDuration': datetime.timedelta(microseconds=125000), 'workedHours': [2], 'usedNames': ['Ein'], - 'courseScoresPerTerm': [[6], [7], [8]]}] - assert e_col.to_pylist() == [ - {'_src': {'offset': 2, 'tableID': 0}, '_dst': {'offset': 1, 'tableID': 2}, - '_id': {'offset': 0, 'tableID': 4}, 'year': 2015}, - {'_src': {'offset': 3, 'tableID': 0}, '_dst': {'offset': 2, 'tableID': 2}, - '_id': {'offset': 1, 'tableID': 4}, 'year': 2010}, - {'_src': {'offset': 4, 'tableID': 0}, '_dst': {'offset': 2, 'tableID': 2}, - '_id': {'offset': 2, 'tableID': 4}, 'year': 2015}] - assert b_col.to_pylist() == [ - {'_id': {'offset': 1, 'tableID': 2}, '_label': 'organisation', 'ID': 4, 'name': 'CsWork', 'orgCode': 934, - 'mark': 4.1, 'score': -100, 'history': '2 years 4 days 10 hours', - 'licenseValidInterval': datetime.timedelta(days=2584, seconds=80699, microseconds=704000), 'rating': 0.78}, - {'_id': {'offset': 2, 'tableID': 2}, '_label': 'organisation', 'ID': 6, 'name': 'DEsWork', 'orgCode': 824, - 'mark': 4.1, 'score': 7, 'history': '2 years 4 hours 22 us 34 minutes', - 'licenseValidInterval': datetime.timedelta(days=2000), 'rating': 0.52}, - {'_id': {'offset': 2, 'tableID': 2}, '_label': 'organisation', 'ID': 6, 'name': 'DEsWork', 'orgCode': 824, - 'mark': 4.1, 'score': 7, 'history': '2 years 4 hours 22 us 34 minutes', - 'licenseValidInterval': datetime.timedelta(0), 'rating': 0.52}] - - _test_node(conn) - _test_node_rel(conn) +# TODO: enable this test once we support fixed_size_list +# def test_to_arrow_complex(establish_connection): +# conn, db = establish_connection +# +# def _test_node(_conn): +# query = "MATCH (p:person) RETURN p ORDER BY p.ID" +# query_result = _conn.execute(query) +# arrow_tbl = query_result.get_as_arrow(12) +# p_col = arrow_tbl.column(0) +# +# assert p_col.to_pylist() == [ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[0], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[2], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[5], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[8], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[9], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[10]] +# +# def _test_node_rel(_conn): +# query = "MATCH (a:person)-[e:workAt]->(b:organisation) RETURN a, e, b;" +# query_result = _conn.execute(query) +# arrow_tbl = query_result.get_as_arrow(12) +# assert arrow_tbl.num_columns == 3 +# a_col = arrow_tbl.column(0) +# assert a_col.length() == 3 +# e_col = arrow_tbl.column(1) +# assert a_col.length() == 3 +# b_col = arrow_tbl.column(2) +# assert a_col.length() == 3 +# assert a_col.to_pylist() == [ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[3], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[5], +# ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[7]] +# assert e_col.to_pylist() == [ +# {'_src': {'offset': 2, 'tableID': 0}, '_dst': {'offset': 1, 'tableID': 2}, +# '_id': {'offset': 0, 'tableID': 4}, 'year': 2015}, +# {'_src': {'offset': 3, 'tableID': 0}, '_dst': {'offset': 2, 'tableID': 2}, +# '_id': {'offset': 1, 'tableID': 4}, 'year': 2010}, +# {'_src': {'offset': 4, 'tableID': 0}, '_dst': {'offset': 2, 'tableID': 2}, +# '_id': {'offset': 2, 'tableID': 4}, 'year': 2015}] +# assert b_col.to_pylist() == [ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[4], +# ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6], +# ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[6]] +# +# _test_node(conn) +# _test_node_rel(conn) diff --git a/tools/python_api/test/test_networkx.py b/tools/python_api/test/test_networkx.py index dbb84edf24..d49da0be11 100644 --- a/tools/python_api/test/test_networkx.py +++ b/tools/python_api/test/test_networkx.py @@ -47,6 +47,8 @@ def test_to_networkx_node(establish_connection): [[7, 4], [8, 8], [9]], [ [6], [7], [8]], [[8]], [[10]], [[7], [10], [6, 7]]], + 'grades': [[96, 54, 86, 92], [98, 42, 93, 88], [91, 75, 21, 95], [76, 88, 99, 89], [96, 59, 65, 88], + [80, 78, 34, 83], [43, 83, 67, 43], [77, 64, 100, 54]], '_label': ['person', 'person', 'person', 'person', 'person', 'person', 'person', 'person'], } @@ -106,6 +108,8 @@ def test_networkx_undirected(establish_connection): [[7, 4], [8, 8], [9]], [ [6], [7], [8]], [[8]], [[10]], [[7], [10], [6, 7]]], + 'grades': [[96, 54, 86, 92], [98, 42, 93, 88], [91, 75, 21, 95], [76, 88, 99, 89], [96, 59, 65, 88], + [80, 78, 34, 83], [43, 83, 67, 43], [77, 64, 100, 54]], '_label': ['person', 'person', 'person', 'person', 'person', 'person', 'person', 'person'], } @@ -162,6 +166,7 @@ def test_networkx_directed(establish_connection): 'workedHours': [[4, 5], [1, 9], [2]], 'usedNames': [["Carmen", "Fred"], ['Wolfeschlegelstein', 'Daniel'], ['Ein']], 'courseScoresPerTerm': [[[8, 10]], [[7, 4], [8, 8], [9]], [[6], [7], [8]]], + 'grades': [[91, 75, 21, 95], [76, 88, 99, 89], [96, 59, 65, 88]], '_label': ['person', 'person', 'person'], } diff --git a/tools/python_api/test/test_query_result_close.py b/tools/python_api/test/test_query_result_close.py index 21bbbf76a1..99453fbab8 100644 --- a/tools/python_api/test/test_query_result_close.py +++ b/tools/python_api/test/test_query_result_close.py @@ -12,7 +12,7 @@ def test_query_result_close(get_tmp_path): 'conn.execute(\'CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64,\ isStudent BOOLEAN, isWorker BOOLEAN, age INT64, eyeSight DOUBLE,\ birthdate DATE, registerTime TIMESTAMP, lastJobDuration INTERVAL,\ - workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][],\ + workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], grades INT64[4], height float, \ PRIMARY KEY (ID))\')', 'conn.execute(\'COPY person FROM \"../../../dataset/tinysnb/vPerson.csv\" (HEADER=true)\')', 'result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.isStudent;")', diff --git a/tools/python_api/test/test_torch_geometric.py b/tools/python_api/test/test_torch_geometric.py index fc2c6f69e3..6ddde77b39 100644 --- a/tools/python_api/test/test_torch_geometric.py +++ b/tools/python_api/test/test_torch_geometric.py @@ -1,290 +1,6 @@ -import datetime import torch import warnings - -TINY_SNB_PERSONS_GROUND_TRUTH = {0: {'ID': 0, - 'fName': 'Alice', - 'gender': 1, - 'isStudent': True, - 'isWorker': False, - 'age': 35, - 'eyeSight': 5.0, - 'birthdate': datetime.date(1900, 1, 1), - 'registerTime': datetime.datetime(2011, 8, 20, 11, 25, 30), - 'lastJobDuration': datetime.timedelta(days=1082, seconds=46920), - 'workedHours': [10, 5], - 'usedNames': ['Aida'], - 'courseScoresPerTerm': [[10, 8], [6, 7, 8]], - '_label': 'person', - '_id': {'offset': 0, 'table': 0}}, - 2: {'ID': 2, - 'fName': 'Bob', - 'gender': 2, - 'isStudent': True, - 'isWorker': False, - 'age': 30, - 'eyeSight': 5.1, - 'birthdate': datetime.date(1900, 1, 1), - 'registerTime': datetime.datetime(2008, 11, 3, 15, 25, 30, 526), - 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), - 'workedHours': [12, 8], - 'usedNames': ['Bobby'], - 'courseScoresPerTerm': [[8, 9], [9, 10]], - '_label': 'person', - '_id': {'offset': 1, 'table': 0}}, - 3: {'ID': 3, - 'fName': 'Carol', - 'gender': 1, - 'isStudent': False, - 'isWorker': True, - 'age': 45, - 'eyeSight': 5.0, - 'birthdate': datetime.date(1940, 6, 22), - 'registerTime': datetime.datetime(1911, 8, 20, 2, 32, 21), - 'lastJobDuration': datetime.timedelta(days=2, seconds=1451), - 'workedHours': [4, 5], - 'usedNames': ['Carmen', 'Fred'], - 'courseScoresPerTerm': [[8, 10]], - '_label': 'person', - '_id': {'offset': 2, 'table': 0}}, - 5: {'ID': 5, - 'fName': 'Dan', - 'gender': 2, - 'isStudent': False, - 'isWorker': True, - 'age': 20, - 'eyeSight': 4.8, - 'birthdate': datetime.date(1950, 7, 23), - 'registerTime': datetime.datetime(2031, 11, 30, 12, 25, 30), - 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), - 'workedHours': [1, 9], - 'usedNames': ['Wolfeschlegelstein', 'Daniel'], - 'courseScoresPerTerm': [[7, 4], [8, 8], [9]], - '_label': 'person', - '_id': {'offset': 3, 'table': 0}}, - 7: {'ID': 7, - 'fName': 'Elizabeth', - 'gender': 1, - 'isStudent': False, - 'isWorker': True, - 'age': 20, - 'eyeSight': 4.7, - 'birthdate': datetime.date(1980, 10, 26), - 'registerTime': datetime.datetime(1976, 12, 23, 11, 21, 42), - 'lastJobDuration': datetime.timedelta(days=2, seconds=1451), - 'workedHours': [2], - 'usedNames': ['Ein'], - 'courseScoresPerTerm': [[6], [7], [8]], - '_label': 'person', - '_id': {'offset': 4, 'table': 0}}, - 8: {'ID': 8, - 'fName': 'Farooq', - 'gender': 2, - 'isStudent': True, - 'isWorker': False, - 'age': 25, - 'eyeSight': 4.5, - 'birthdate': datetime.date(1980, 10, 26), - 'registerTime': datetime.datetime(1972, 7, 31, 13, 22, 30, 678559), - 'lastJobDuration': datetime.timedelta(seconds=1080, microseconds=24000), - 'workedHours': [3, 4, 5, 6, 7], - 'usedNames': ['Fesdwe'], - 'courseScoresPerTerm': [[8]], - '_label': 'person', - '_id': {'offset': 5, 'table': 0}}, - 9: {'ID': 9, - 'fName': 'Greg', - 'gender': 2, - 'isStudent': False, - 'isWorker': False, - 'age': 40, - 'eyeSight': 4.9, - 'birthdate': datetime.date(1980, 10, 26), - 'registerTime': datetime.datetime(1976, 12, 23, 4, 41, 42), - 'lastJobDuration': datetime.timedelta(days=3750, seconds=46800, microseconds=24), - 'workedHours': [1], - 'usedNames': ['Grad'], - 'courseScoresPerTerm': [[10]], - '_label': 'person', - '_id': {'offset': 6, 'table': 0}}, - 10: {'ID': 10, - 'fName': 'Hubert Blaine Wolfeschlegelsteinhausenbergerdorff', - 'gender': 2, - 'isStudent': False, - 'isWorker': True, - 'age': 83, - 'eyeSight': 4.9, - 'birthdate': datetime.date(1990, 11, 27), - 'registerTime': datetime.datetime(2023, 2, 21, 13, 25, 30), - 'lastJobDuration': datetime.timedelta(days=1082, seconds=46920), - 'workedHours': [10, 11, 12, 3, 4, 5, 6, 7], - 'usedNames': ['Ad', 'De', 'Hi', 'Kye', 'Orlan'], - 'courseScoresPerTerm': [[7], [10], [6, 7]], - '_label': 'person', - '_id': {'offset': 7, 'table': 0}}} - -TINY_SNB_ORGANISATIONS_GROUND_TRUTH = {1: {'ID': 1, - 'name': 'ABFsUni', - 'orgCode': 325, - 'mark': 3.7, - 'score': -2, - 'history': '10 years 5 months 13 hours 24 us', - 'licenseValidInterval': datetime.timedelta(days=1085), - 'rating': 1.0, - '_label': 'organisation', - '_id': {'offset': 0, 'table': 2}}, - 4: {'ID': 4, - 'name': 'CsWork', - 'orgCode': 934, - 'mark': 4.1, - 'score': -100, - 'history': '2 years 4 days 10 hours', - 'licenseValidInterval': datetime.timedelta(days=9414), - 'rating': 0.78, - '_label': 'organisation', - '_id': {'offset': 1, 'table': 2}}, - 6: {'ID': 6, - 'name': 'DEsWork', - 'orgCode': 824, - 'mark': 4.1, - 'score': 7, - 'history': '2 years 4 hours 22 us 34 minutes', - 'licenseValidInterval': datetime.timedelta(days=3, seconds=36000, microseconds=100000), - 'rating': 0.52, - '_label': 'organisation', - '_id': {'offset': 2, 'table': 2}}} - -TINY_SNB_KNOWS_GROUND_TRUTH = { - 0: [2, 3, 5], - 2: [0, 3, 5], - 3: [0, 2, 5], - 5: [0, 2, 3], - 7: [8, 9], -} - -TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH = { - (0, 2): {'date': datetime.date(2021, 6, 30), - 'meetTime': datetime.datetime(1986, 10, 21, 21, 8, 31, 521000), - 'validInterval': datetime.timedelta(days=3750, seconds=46800, microseconds=24), - 'comments': ['rnme', 'm8sihsdnf2990nfiwf']}, - (0, 3): {'date': datetime.date(2021, 6, 30), - 'meetTime': datetime.datetime(1946, 8, 25, 19, 7, 22), - 'validInterval': datetime.timedelta(days=7232), - 'comments': ['njnojppo9u0jkmf', 'fjiojioh9h9h89hph']}, - (0, 5): {'date': datetime.date(2021, 6, 30), - 'meetTime': datetime.datetime(2012, 12, 11, 20, 7, 22), - 'validInterval': datetime.timedelta(days=10), - 'comments': ['ioji232', 'jifhe8w99u43434']}, - (2, 0): {'date': datetime.date(2021, 6, 30), - 'meetTime': datetime.datetime(1946, 8, 25, 19, 7, 22), - 'validInterval': datetime.timedelta(days=3750, seconds=46800, microseconds=24), - 'comments': ['2huh9y89fsfw23', '23nsihufhw723']}, - (2, 3): {'date': datetime.date(1950, 5, 14), - 'meetTime': datetime.datetime(1946, 8, 25, 19, 7, 22), - 'validInterval': datetime.timedelta(seconds=1380), - 'comments': ['fwehu9h9832wewew', '23u9h989sdfsss']}, - (2, 5): {'date': datetime.date(1950, 5, 14), - 'meetTime': datetime.datetime(2012, 12, 11, 20, 7, 22), - 'validInterval': datetime.timedelta(days=7232), - 'comments': ['fwh9y81232uisuiehuf', 'ewnuihxy8dyf232']}, - (3, 0): {'date': datetime.date(2021, 6, 30), - 'meetTime': datetime.datetime(2002, 7, 31, 11, 42, 53, 123420), - 'validInterval': datetime.timedelta(days=41, seconds=21600), - 'comments': ['fnioh8323aeweae34d', 'osd89e2ejshuih12']}, - (3, 2): {'date': datetime.date(1950, 5, 14), - 'meetTime': datetime.datetime(2007, 2, 12, 12, 11, 42, 123000), - 'validInterval': datetime.timedelta(seconds=1680, microseconds=30000), - 'comments': ['fwh983-sdjisdfji', 'ioh89y32r2huir']}, - (3, 5): {'date': datetime.date(2000, 1, 1), - 'meetTime': datetime.datetime(1998, 10, 2, 13, 9, 22, 423000), - 'validInterval': datetime.timedelta(microseconds=300000), - 'comments': ['psh989823oaaioe', 'nuiuah1nosndfisf']}, - (5, 0): {'date': datetime.date(2021, 6, 30), - 'meetTime': datetime.datetime(1936, 11, 2, 11, 2, 1), - 'validInterval': datetime.timedelta(microseconds=480), - 'comments': ['fwewe']}, - (5, 2): {'date': datetime.date(1950, 5, 14), - 'meetTime': datetime.datetime(1982, 11, 11, 13, 12, 5, 123000), - 'validInterval': datetime.timedelta(seconds=1380), - 'comments': ['fewh9182912e3', - 'h9y8y89soidfsf', - 'nuhudf78w78efw', - 'hioshe0f9023sdsd']}, - (5, 3): {'date': datetime.date(2000, 1, 1), - 'meetTime': datetime.datetime(1999, 4, 21, 15, 12, 11, 420000), - 'validInterval': datetime.timedelta(days=2, microseconds=52000), - 'comments': ['23h9sdslnfowhu2932', 'shuhf98922323sf']}, - (7, 8): {'date': datetime.date(1905, 12, 12), - 'meetTime': datetime.datetime(2025, 1, 1, 11, 22, 33, 520000), - 'validInterval': datetime.timedelta(seconds=2878), - 'comments': ['ahu2333333333333', '12weeeeeeeeeeeeeeeeee']}, - (7, 9): {'date': datetime.date(1905, 12, 12), - 'meetTime': datetime.datetime(2020, 3, 1, 12, 11, 41, 655200), - 'validInterval': datetime.timedelta(seconds=2878), - 'comments': ['peweeeeeeeeeeeeeeeee', 'kowje9w0eweeeeeeeee']} -} - - -TINY_SNB_WORKS_AT_GROUND_TRUTH = { - 3: [4], - 5: [6], - 7: [6], -} - -TINY_SNB_WORKS_AT_PROPERTIES_GROUND_TRUTH = { - (3, 4): {'year': 2015}, - (5, 6): {'year': 2010}, - (7, 6): {'year': 2015} -} - -TENSOR_LIST_GROUND_TRUTH = { - 0: { - 'boolTensor': [True, False], - 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], - 'intTensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] - }, - 3: { - 'boolTensor': [True, False], - 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], - 'intTensor': [[[3, 4], [5, 6]], [[7, 8], [9, 10]]] - }, - 4: { - 'boolTensor': [False, True], - 'doubleTensor': [[0.4, 0.8], [0.7, 0.6]], - 'intTensor': [[[5, 6], [7, 8]], [[9, 10], [11, 12]]] - }, - 5: { - 'boolTensor': [True, True], - 'doubleTensor': [[0.4, 0.9], [0.5, 0.2]], - 'intTensor': [[[7, 8], [9, 10]], [[11, 12], [13, 14]]] - }, - 6: { - 'boolTensor': [False, True], - 'doubleTensor': [[0.2, 0.4], [0.5, 0.1]], - 'intTensor': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]] - }, - 8: { - 'boolTensor': [False, True], - 'doubleTensor': [[0.6, 0.4], [0.6, 0.1]], - 'intTensor': [[[11, 12], [13, 14]], [[15, 16], [17, 18]]] - } -} - -PERSONLONGSTRING_GROUND_TRUTH = { - 'AAAAAAAAAAAAAAAAAAAA': { - 'name': 'AAAAAAAAAAAAAAAAAAAA', - 'spouse': 'Bob', - }, - 'Bob': { - 'name': 'Bob', - 'spouse': 'AAAAAAAAAAAAAAAAAAAA', - }, -} - -PERSONLONGSTRING_KNOWS_GROUND_TRUTH = { - 'AAAAAAAAAAAAAAAAAAAA': ['Bob'], -} - +import ground_truth def test_to_torch_geometric_nodes_only(establish_connection): conn, _ = establish_connection @@ -295,6 +11,7 @@ def test_to_torch_geometric_nodes_only(establish_connection): torch_geometric_data, pos_to_idx, unconverted_properties, _ = res.get_as_torch_geometric() warnings_ground_truth = set([ "Property person.courseScoresPerTerm cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted.", + "Property person.height of type FLOAT is not supported by torch_geometric. The property is marked as unconverted.", "Property person.lastJobDuration of type INTERVAL is not supported by torch_geometric. The property is marked as unconverted.", "Property person.registerTime of type TIMESTAMP is not supported by torch_geometric. The property is marked as unconverted.", "Property person.birthdate of type DATE is not supported by torch_geometric. The property is marked as unconverted.", @@ -302,75 +19,75 @@ def test_to_torch_geometric_nodes_only(establish_connection): "Property person.workedHours has an inconsistent shape. The property is marked as unconverted.", "Property person.usedNames of type STRING is not supported by torch_geometric. The property is marked as unconverted.", ]) - assert len(ws) == 7 + assert len(ws) == 8 for w in ws: assert str(w.message) in warnings_ground_truth assert torch_geometric_data.ID.shape == torch.Size([8]) assert torch_geometric_data.ID.dtype == torch.int64 for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['ID'] == torch_geometric_data.ID[i].item() assert torch_geometric_data.gender.shape == torch.Size([8]) assert torch_geometric_data.gender.dtype == torch.int64 for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['gender'] == torch_geometric_data.gender[i].item() assert torch_geometric_data.isStudent.shape == torch.Size([8]) assert torch_geometric_data.isStudent.dtype == torch.bool for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['isStudent'] == torch_geometric_data.isStudent[i].item() assert torch_geometric_data.isWorker.shape == torch.Size([8]) assert torch_geometric_data.isWorker.dtype == torch.bool for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['isWorker'] == torch_geometric_data.isWorker[i].item() assert torch_geometric_data.age.shape == torch.Size([8]) assert torch_geometric_data.age.dtype == torch.int64 for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['age'] == torch_geometric_data.age[i].item() assert torch_geometric_data.eyeSight.shape == torch.Size([8]) assert torch_geometric_data.eyeSight.dtype == torch.float32 for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]]['eyeSight'] - \ + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]]['eyeSight'] - \ torch_geometric_data.eyeSight[i].item() < 1e-6 - assert len(unconverted_properties) == 7 + assert len(unconverted_properties) == 8 assert 'courseScoresPerTerm' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['courseScoresPerTerm'] == unconverted_properties['courseScoresPerTerm'][i] assert 'lastJobDuration' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['lastJobDuration'] == unconverted_properties['lastJobDuration'][i] assert 'registerTime' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['registerTime'] == unconverted_properties['registerTime'][i] assert 'birthdate' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['birthdate'] == unconverted_properties['birthdate'][i] assert 'fName' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['fName'] == unconverted_properties['fName'][i] assert 'usedNames' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['usedNames'] == unconverted_properties['usedNames'][i] assert 'workedHours' in unconverted_properties for i in range(8): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['workedHours'] == unconverted_properties['workedHours'][i] @@ -383,6 +100,7 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection): torch_geometric_data, pos_to_idx, unconverted_properties, edge_properties = res.get_as_torch_geometric() warnings_ground_truth = set([ "Property person.courseScoresPerTerm cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted.", + "Property person.height of type FLOAT is not supported by torch_geometric. The property is marked as unconverted.", "Property person.lastJobDuration of type INTERVAL is not supported by torch_geometric. The property is marked as unconverted.", "Property person.registerTime of type TIMESTAMP is not supported by torch_geometric. The property is marked as unconverted.", "Property person.birthdate of type DATE is not supported by torch_geometric. The property is marked as unconverted.", @@ -390,75 +108,75 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection): "Property person.workedHours has an inconsistent shape. The property is marked as unconverted.", "Property person.usedNames of type STRING is not supported by torch_geometric. The property is marked as unconverted.", ]) - assert len(ws) == 7 + assert len(ws) == 8 for w in ws: assert str(w.message) in warnings_ground_truth assert torch_geometric_data.ID.shape == torch.Size([7]) assert torch_geometric_data.ID.dtype == torch.int64 for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['ID'] == torch_geometric_data.ID[i].item() assert torch_geometric_data.gender.shape == torch.Size([7]) assert torch_geometric_data.gender.dtype == torch.int64 for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['gender'] == torch_geometric_data.gender[i].item() assert torch_geometric_data.isStudent.shape == torch.Size([7]) assert torch_geometric_data.isStudent.dtype == torch.bool for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['isStudent'] == torch_geometric_data.isStudent[i].item() assert torch_geometric_data.isWorker.shape == torch.Size([7]) assert torch_geometric_data.isWorker.dtype == torch.bool for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['isWorker'] == torch_geometric_data.isWorker[i].item() assert torch_geometric_data.age.shape == torch.Size([7]) assert torch_geometric_data.age.dtype == torch.int64 for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['age'] == torch_geometric_data.age[i].item() assert torch_geometric_data.eyeSight.shape == torch.Size([7]) assert torch_geometric_data.eyeSight.dtype == torch.float32 for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]]['eyeSight'] - \ + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]]['eyeSight'] - \ torch_geometric_data.eyeSight[i].item() < 1e-6 - assert len(unconverted_properties) == 7 + assert len(unconverted_properties) == 8 assert 'courseScoresPerTerm' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['courseScoresPerTerm'] == unconverted_properties['courseScoresPerTerm'][i] assert 'lastJobDuration' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['lastJobDuration'] == unconverted_properties['lastJobDuration'][i] assert 'registerTime' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['registerTime'] == unconverted_properties['registerTime'][i] assert 'birthdate' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['birthdate'] == unconverted_properties['birthdate'][i] assert 'fName' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['fName'] == unconverted_properties['fName'][i] assert 'usedNames' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['usedNames'] == unconverted_properties['usedNames'][i] assert 'workedHours' in unconverted_properties for i in range(7): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i] ]['workedHours'] == unconverted_properties['workedHours'][i] assert torch_geometric_data.edge_index.shape == torch.Size([2, 14]) @@ -468,7 +186,7 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection): assert src in pos_to_idx assert dst in pos_to_idx assert src != dst - assert pos_to_idx[dst] in TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx[src]] + assert pos_to_idx[dst] in ground_truth.TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx[src]] assert len(edge_properties) == 4 assert 'date' in edge_properties @@ -481,14 +199,14 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection): ), torch_geometric_data.edge_index[1][i].item() orginal_src = pos_to_idx[src] orginal_dst = pos_to_idx[dst] - assert (orginal_src, orginal_dst) in TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert (orginal_src, orginal_dst) in ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( orginal_src, orginal_dst)]['date'] == edge_properties['date'][i] - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( orginal_src, orginal_dst)]['meetTime'] == edge_properties['meetTime'][i] - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( orginal_src, orginal_dst)]['validInterval'] == edge_properties['validInterval'][i] - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( orginal_src, orginal_dst)]['comments'] == edge_properties['comments'][i] @@ -500,9 +218,10 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): with warnings.catch_warnings(record=True) as ws: torch_geometric_data, pos_to_idx, unconverted_properties, edge_properties = res.get_as_torch_geometric() - assert len(ws) == 9 + assert len(ws) == 10 warnings_ground_truth = set([ "Property organisation.name of type STRING is not supported by torch_geometric. The property is marked as unconverted.", + "Property person.height of type FLOAT is not supported by torch_geometric. The property is marked as unconverted.", "Property person.courseScoresPerTerm cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted.", "Property person.lastJobDuration of type INTERVAL is not supported by torch_geometric. The property is marked as unconverted.", "Property person.registerTime of type TIMESTAMP is not supported by torch_geometric. The property is marked as unconverted.", @@ -519,64 +238,64 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): assert torch_geometric_data['person'].ID.shape == torch.Size([4]) assert torch_geometric_data['person'].ID.dtype == torch.int64 for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['ID'] == torch_geometric_data['person'].ID[i].item() assert torch_geometric_data['person'].gender.shape == torch.Size([4]) assert torch_geometric_data['person'].gender.dtype == torch.int64 for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['gender'] == torch_geometric_data['person'].gender[i].item() assert torch_geometric_data['person'].isStudent.shape == torch.Size([4]) assert torch_geometric_data['person'].isStudent.dtype == torch.bool for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['isStudent'] == torch_geometric_data['person'].isStudent[i].item() assert torch_geometric_data['person'].isWorker.shape == torch.Size([4]) assert torch_geometric_data['person'].isWorker.dtype == torch.bool for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['isWorker'] == torch_geometric_data['person'].isWorker[i].item() assert torch_geometric_data['person'].age.shape == torch.Size([4]) assert torch_geometric_data['person'].age.dtype == torch.int64 for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['age'] == torch_geometric_data['person'].age[i].item() assert torch_geometric_data['person'].eyeSight.shape == torch.Size([4]) assert torch_geometric_data['person'].eyeSight.dtype == torch.float32 for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i]]['eyeSight'] - \ + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i]]['eyeSight'] - \ torch_geometric_data['person'].eyeSight[i].item() < 1e-6 assert 'person' in unconverted_properties - assert len(unconverted_properties['person']) == 6 + assert len(unconverted_properties['person']) == 7 assert 'courseScoresPerTerm' in unconverted_properties['person'] for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['courseScoresPerTerm'] == unconverted_properties['person']['courseScoresPerTerm'][i] assert 'lastJobDuration' in unconverted_properties['person'] for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['lastJobDuration'] == unconverted_properties['person']['lastJobDuration'][i] assert 'registerTime' in unconverted_properties['person'] for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['registerTime'] == unconverted_properties['person']['registerTime'][i] assert 'birthdate' in unconverted_properties['person'] for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['birthdate'] == unconverted_properties['person']['birthdate'][i] assert 'fName' in unconverted_properties['person'] for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['fName'] == unconverted_properties['person']['fName'][i] assert 'usedNames' in unconverted_properties['person'] for i in range(4): - assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] + assert ground_truth.TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i] ]['usedNames'] == unconverted_properties['person']['usedNames'][i] assert torch_geometric_data['person', 'person'].edge_index.shape == torch.Size([ @@ -587,7 +306,7 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): assert src in pos_to_idx['person'] assert dst in pos_to_idx['person'] assert src != dst - assert pos_to_idx['person'][dst] in TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx['person'][src]] + assert pos_to_idx['person'][dst] in ground_truth.TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx['person'][src]] assert len(edge_properties['person', 'person']) == 4 assert 'date' in edge_properties['person', 'person'] @@ -598,62 +317,62 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): src, dst = torch_geometric_data['person', 'person'].edge_index[0][i].item( ), torch_geometric_data['person', 'person'].edge_index[1][i].item() original_src, original_dst = pos_to_idx['person'][src], pos_to_idx['person'][dst] - assert (original_src, original_dst) in TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert (original_src, original_dst) in ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['date'] == edge_properties['person', 'person']['date'][i] - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['meetTime'] == edge_properties['person', 'person']['meetTime'][i] - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['validInterval'] == edge_properties['person', 'person']['validInterval'][i] - assert TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['comments'] == edge_properties['person', 'person']['comments'][i] assert torch_geometric_data['organisation'].ID.shape == torch.Size([2]) assert torch_geometric_data['organisation'].ID.dtype == torch.int64 for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['ID'] == torch_geometric_data['organisation'].ID[i].item() assert torch_geometric_data['organisation'].orgCode.shape == torch.Size([ 2]) assert torch_geometric_data['organisation'].orgCode.dtype == torch.int64 for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['orgCode'] == torch_geometric_data['organisation'].orgCode[i].item() assert torch_geometric_data['organisation'].mark.shape == torch.Size([2]) assert torch_geometric_data['organisation'].mark.dtype == torch.float32 for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['mark'] - torch_geometric_data['organisation'].mark[i].item() < 1e-6 assert torch_geometric_data['organisation'].score.shape == torch.Size([2]) assert torch_geometric_data['organisation'].score.dtype == torch.int64 for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['score'] - torch_geometric_data['organisation'].score[i].item() < 1e-6 assert torch_geometric_data['organisation'].rating.shape == torch.Size([2]) assert torch_geometric_data['organisation'].rating.dtype == torch.float32 for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['rating'] - torch_geometric_data['organisation'].rating[i].item() < 1e-6 assert 'organisation' in unconverted_properties assert len(unconverted_properties['organisation']) == 3 assert 'name' in unconverted_properties['organisation'] for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['name'] == unconverted_properties['organisation']['name'][i] assert 'history' in unconverted_properties['organisation'] for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['history'] == unconverted_properties['organisation']['history'][i] assert 'licenseValidInterval' in unconverted_properties['organisation'] for i in range(2): - assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] + assert ground_truth.TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i] ]['licenseValidInterval'] == unconverted_properties['organisation']['licenseValidInterval'][i] assert torch_geometric_data['person', 'organisation'].edge_index.shape == torch.Size([ @@ -664,18 +383,18 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): assert src in pos_to_idx['person'] assert dst in pos_to_idx['organisation'] assert src != dst - assert pos_to_idx['organisation'][dst] in TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]] - assert len(edge_properties['person', 'organisation']) == 1 + assert pos_to_idx['organisation'][dst] in ground_truth.TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]] + assert len(edge_properties['person', 'organisation']) == 3 assert 'year' in edge_properties['person', 'organisation'] for i in range(2): src, dst = torch_geometric_data['person', 'organisation'].edge_index[0][i].item( ), torch_geometric_data['person', 'organisation'].edge_index[1][i].item() original_src, original_dst = pos_to_idx['person'][src], pos_to_idx['organisation'][dst] - assert TINY_SNB_WORKS_AT_PROPERTIES_GROUND_TRUTH[( + assert ground_truth.TINY_SNB_WORKS_AT_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['year'] == edge_properties['person', 'organisation']['year'][i] -def test_to_torch_geometric_multi_dimensonal_lists(establish_connection): +def test_to_torch_geometric_multi_dimensional_lists(establish_connection): conn, _ = establish_connection query = "MATCH (t:tensor) RETURN t" @@ -691,9 +410,9 @@ def test_to_torch_geometric_multi_dimensonal_lists(establish_connection): for i in range(len(pos_to_idx)): idx = pos_to_idx[i] - bool_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['boolTensor']) - float_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['doubleTensor']) - int_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['intTensor']) + bool_list.append(ground_truth.TENSOR_LIST_GROUND_TRUTH[idx]['boolTensor']) + float_list.append(ground_truth.TENSOR_LIST_GROUND_TRUTH[idx]['doubleTensor']) + int_list.append(ground_truth.TENSOR_LIST_GROUND_TRUTH[idx]['intTensor']) bool_tensor = torch.tensor(bool_list, dtype=torch.bool) float_tensor = torch.tensor(float_list, dtype=torch.float32) @@ -746,7 +465,7 @@ def test_to_torch_geometric_no_properties_converted(establish_connection): assert src in pos_to_idx['personLongString'] assert dst in pos_to_idx['personLongString'] assert src != dst - assert pos_to_idx['personLongString'][dst] in PERSONLONGSTRING_KNOWS_GROUND_TRUTH[pos_to_idx['personLongString'][src]] + assert pos_to_idx['personLongString'][dst] in ground_truth.PERSONLONGSTRING_KNOWS_GROUND_TRUTH[pos_to_idx['personLongString'][src]] assert len(unconverted_properties) == 1 assert len(unconverted_properties['personLongString']) == 2 @@ -754,10 +473,10 @@ def test_to_torch_geometric_no_properties_converted(establish_connection): assert 'spouse' in unconverted_properties['personLongString'] assert len(unconverted_properties['personLongString']['spouse']) == 2 for i in range(2): - assert PERSONLONGSTRING_GROUND_TRUTH[pos_to_idx['personLongString'][i] + assert ground_truth.PERSONLONGSTRING_GROUND_TRUTH[pos_to_idx['personLongString'][i] ]['spouse'] == unconverted_properties['personLongString']['spouse'][i] assert 'name' in unconverted_properties['personLongString'] for i in range(2): - assert PERSONLONGSTRING_GROUND_TRUTH[pos_to_idx['personLongString'][i] + assert ground_truth.PERSONLONGSTRING_GROUND_TRUTH[pos_to_idx['personLongString'][i] ]['name'] == unconverted_properties['personLongString']['name'][i]