Skip to content

Commit

Permalink
refactor: data handler super class
Browse files Browse the repository at this point in the history
  • Loading branch information
david20571015 committed Jun 7, 2024
1 parent 23f23f4 commit 303848b
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 47 deletions.
15 changes: 15 additions & 0 deletions sync_crawler/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Optional


class DataReader(ABC):
@abstractmethod
def read(self, num: Optional[int] = None) -> Iterable:
pass


class DataWriter(ABC):
@abstractmethod
def write(self, data: Iterable):
pass
3 changes: 2 additions & 1 deletion sync_crawler/store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Write and read intermediate data to/from a local database."""

from sync_crawler.store.lmdb_store import LmdbStore
from .base_store import BaseStore as BaseStore
from .lmdb_store import LmdbStore

__all__ = ["LmdbStore"]
37 changes: 19 additions & 18 deletions sync_crawler/store/base_store.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
import abc
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Optional, override

from sync_crawler.handlers import DataReader, DataWriter
from sync_crawler.models import News


class BaseStore(abc.ABC):
"""Base class for store implementation."""

@abc.abstractmethod
def put(self, news: Iterable[News]):
"""Store data to storage.
class BaseStore(DataReader, DataWriter, ABC):
@override
@abstractmethod
def read(self, num: Optional[int] = None) -> Iterable[News]:
"""Read data from store.
Args:
news: News to be stored.
num: Number of data to read. Defaults to None, read all data.
Returns:
Iterable of News.
"""
raise NotImplementedError
pass

@abc.abstractmethod
def pop(self, nums=1) -> Iterable[News]:
"""Fetch data from store then delete them.
@override
@abstractmethod
def write(self, news: Iterable[News]):
"""Write data to store.
Args:
nums: Number of data to be popped. If the remaining data is less than `nums`, all
remaining data will be popped.
Returns:
List of popped data.
news: Iterable of News to be written.
"""
raise NotImplementedError
pass
15 changes: 5 additions & 10 deletions sync_crawler/store/lmdb_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from collections.abc import Callable, Iterable
from contextlib import closing
from typing import override
from typing import Optional, override

import lmdb

Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
self._key_factory = key_factory

@override
def put(self, news: Iterable[News]):
def write(self, news: Iterable[News]):
key_value_pairs = ((self._key_factory(ns), pickle.dumps(ns)) for ns in news)

with (
Expand All @@ -46,18 +46,13 @@ def put(self, news: Iterable[News]):
cur.putmulti(key_value_pairs)

@override
def pop(self, nums=1) -> Iterable[News]:
values: Iterable[News] = []

def read(self, num: Optional[int] = None) -> Iterable[News]:
with (
self._env.begin(write=True) as txn,
closing(txn.cursor()) as cur,
):
for key, value in itertools.islice(cur, nums):
values.append(pickle.loads(value))
txn.delete(key)

return values
for key, value in itertools.islice(cur, num):
yield pickle.loads(cur.pop(key))

def __del__(self):
self._env.close()
5 changes: 3 additions & 2 deletions sync_crawler/writer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Store data to a host database."""

from sync_crawler.writer.chromadb_writer import ChromaDBWriter
from sync_crawler.writer.mongodb_writer import MongoDBWriter
from .base_writer import BaseWriter as BaseWriter
from .chromadb_writer import ChromaDBWriter
from .mongodb_writer import MongoDBWriter

__all__ = ["ChromaDBWriter", "MongoDBWriter"]
22 changes: 10 additions & 12 deletions sync_crawler/writer/base_writer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
import abc
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Protocol
from typing import Protocol, override

from sync_crawler.handlers import DataWriter
from sync_crawler.models import News


class SupportsStr(Protocol):
def __str__(self) -> str: ...


class BaseWriter(abc.ABC):
"""Base class for writer implementation."""

@abc.abstractmethod
def put(
class BaseWriter(DataWriter, ABC):
@override
@abstractmethod
def write(
self,
ids: Iterable[SupportsStr],
news: Iterable[News],
news_with_id: Iterable[tuple[SupportsStr, News]],
):
"""Store data to storage.
Args:
ids: Object IDs of each news.
news: News to be stored.
news_with_id: Iterable of tuple of id and News.
"""
raise NotImplementedError
pass
4 changes: 2 additions & 2 deletions sync_crawler/writer/chromadb_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
)

@override
def put(self, ids, news):
def write(self, news_with_id):
docs = [
Document(
doc_id=str(id_),
Expand All @@ -51,6 +51,6 @@ def put(self, ids, news):
excluded_embed_metadata_keys=News.excluded_metadata_keys,
excluded_llm_metadata_keys=News.excluded_metadata_keys,
)
for id_, ns in zip(ids, news)
for id_, ns in news_with_id
]
self._index.insert_nodes(docs)
4 changes: 2 additions & 2 deletions sync_crawler/writer/mongodb_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
self._collection = self._client[database][collection]

@override
def put(self, ids, news):
def write(self, news_with_id):
news_dicts = (
{"_id": ObjectId(str(_id)), **ns.model_dump()} for _id, ns in zip(ids, news)
{"_id": ObjectId(str(_id)), **ns.model_dump()} for _id, ns in news_with_id
)
self._collection.insert_many(news_dicts)

0 comments on commit 303848b

Please sign in to comment.