From 8d6b5e90a246a1ae62fcd74844a6201a614f001a Mon Sep 17 00:00:00 2001 From: Elron Bandel Date: Mon, 29 Jan 2024 03:33:32 -0500 Subject: [PATCH] Add side affect operators and fix operators base classes Signed-off-by: Elron Bandel --- src/unitxt/artifact.py | 4 +-- src/unitxt/fusion.py | 6 ++-- src/unitxt/load.py | 6 ++-- src/unitxt/operator.py | 60 ++++++---------------------------------- src/unitxt/operators.py | 61 +++++++++++++++++++++++++++++++++++++++-- 5 files changed, 75 insertions(+), 62 deletions(-) diff --git a/src/unitxt/artifact.py b/src/unitxt/artifact.py index 90afe47ad0..5d1ab30eb2 100644 --- a/src/unitxt/artifact.py +++ b/src/unitxt/artifact.py @@ -134,8 +134,8 @@ def register_class(cls, artifact_class): if cls.is_registered_type(snake_case_key): assert ( - cls._class_register[snake_case_key] == artifact_class - ), f"Artifact class name must be unique, '{snake_case_key}' already exists for '{cls._class_register[snake_case_key]}'" + str(cls._class_register[snake_case_key]) == str(artifact_class) + ), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overriden by {artifact_class}." return snake_case_key diff --git a/src/unitxt/fusion.py b/src/unitxt/fusion.py index 9807aac5ea..f2bfe6d891 100644 --- a/src/unitxt/fusion.py +++ b/src/unitxt/fusion.py @@ -3,7 +3,7 @@ from typing import Generator, List, Optional from .dataclass import NonPositionalField -from .operator import SourceOperator, StreamSource +from .operator import SourceOperator from .random_utils import new_random_generator from .stream import MultiStream, Stream @@ -15,7 +15,7 @@ class BaseFusion(SourceOperator): include_splits: List of splits to include. If None, all splits are included. """ - origins: List[StreamSource] + origins: List[SourceOperator] include_splits: Optional[List[str]] = NonPositionalField(default=None) @abstractmethod @@ -73,7 +73,7 @@ class WeightedFusion(BaseFusion): max_total_examples: Total number of examples to return. If None, all examples are returned. """ - origins: List[StreamSource] = None + origins: List[SourceOperator] = None weights: List[float] = None max_total_examples: int = None diff --git a/src/unitxt/load.py b/src/unitxt/load.py index 9c2efd46c2..e985153713 100644 --- a/src/unitxt/load.py +++ b/src/unitxt/load.py @@ -3,12 +3,12 @@ from datasets import DatasetDict from .artifact import fetch_artifact -from .operator import StreamSource +from .operator import SourceOperator -def load_dataset(source: Union[StreamSource, str]) -> DatasetDict: +def load_dataset(source: Union[SourceOperator, str]) -> DatasetDict: assert isinstance( - source, (StreamSource, str) + source, (SourceOperator, str) ), "source must be a StreamSource or a string" if isinstance(source, str): source, _ = fetch_artifact(source) diff --git a/src/unitxt/operator.py b/src/unitxt/operator.py index 9f63a5bb87..ee04102701 100644 --- a/src/unitxt/operator.py +++ b/src/unitxt/operator.py @@ -1,11 +1,8 @@ import re -import zipfile from abc import abstractmethod from dataclasses import field from typing import Any, Dict, Generator, List, Optional -import requests - from .artifact import Artifact from .dataclass import NonPositionalField from .stream import MultiStream, Stream @@ -15,43 +12,6 @@ class Operator(Artifact): pass -class DownloadError(Exception): - def __init__( - self, - message, - ): - self.__super__(message) - - -class UnexpectedHttpCodeError(Exception): - def __init__(self, http_code): - self.__super__(f"unexpected http code {http_code}") - - -class DownloadOperator(Operator): - source: str - target: str - - def __call__(self): - try: - response = requests.get(self.source, allow_redirects=True) - except Exception as e: - raise DownloadError(f"Unabled to download {self.source}") from e - if response.status_code != 200: - raise UnexpectedHttpCodeError(response.status_code) - with open(self.target, "wb") as f: - f.write(response.content) - - -class ZipExtractorOperator(Operator): - zip_file: str - target_dir: str - - def __call__(self): - with zipfile.ZipFile(self.zip_file) as zf: - zf.extractall(self.target_dir) - - class OperatorError(Exception): def __init__(self, exception: Exception, operators: List[Operator]): super().__init__( @@ -97,21 +57,19 @@ def __call__(self, streams: Optional[MultiStream] = None) -> MultiStream: """ -class StreamSource(StreamingOperator): - """A class representing a stream source operator in the streaming system. - - A stream source operator is a special type of `StreamingOperator` that generates a data stream without taking any input streams. It serves as the starting point in a stream processing pipeline, providing the initial data that other operators in the pipeline can process. - - When called, a `StreamSource` should generate a `MultiStream`. This behavior must be implemented by any classes that inherit from `StreamSource`. +class SideEffectOperator(StreamingOperator): + """Base class for operators that does not affect the stream.""" - """ + def __call__(self, streams: Optional[MultiStream] = None) -> MultiStream: + self.process() + return streams @abstractmethod - def __call__(self) -> MultiStream: + def process() -> None: pass -class SourceOperator(StreamSource): +class SourceOperator(StreamingOperator): """A class representing a source operator in the streaming system. A source operator is responsible for generating the data stream from some source, such as a database or a file. @@ -126,7 +84,7 @@ class SourceOperator(StreamSource): caching: bool = NonPositionalField(default=None) - def __call__(self) -> MultiStream: + def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream: multi_stream = self.process() if self.caching is not None: multi_stream.set_caching(self.caching) @@ -137,7 +95,7 @@ def process(self) -> MultiStream: pass -class StreamInitializerOperator(StreamSource): +class StreamInitializerOperator(SourceOperator): """A class representing a stream initializer operator in the streaming system. A stream initializer operator is a special type of `StreamSource` that is capable of taking parameters during the stream generation process. This can be useful in situations where the stream generation process needs to be customized or configured based on certain parameters. diff --git a/src/unitxt/operators.py b/src/unitxt/operators.py index 09d09ecaf7..3339eff19f 100644 --- a/src/unitxt/operators.py +++ b/src/unitxt/operators.py @@ -36,6 +36,7 @@ import operator import os import uuid +import zipfile from abc import abstractmethod from collections import Counter from copy import deepcopy @@ -54,6 +55,8 @@ Union, ) +import requests + from .artifact import Artifact, fetch_artifact from .dataclass import NonPositionalField from .dict_utils import dict_delete, dict_get, dict_set, is_subpath @@ -62,12 +65,13 @@ MultiStreamOperator, PagedStreamOperator, SequentialOperator, + SideEffectOperator, SingleStreamOperator, SingleStreamReducer, + SourceOperator, StreamingOperator, StreamInitializerOperator, StreamInstanceOperator, - StreamSource, ) from .random_utils import new_random_generator from .stream import Stream @@ -89,7 +93,7 @@ def process(self, iterables: Dict[str, Iterable]) -> MultiStream: return MultiStream.from_iterables(iterables) -class IterableSource(StreamSource): +class IterableSource(SourceOperator): """Creates a MultiStream from a dict of named iterables. It is a callable. @@ -105,7 +109,7 @@ class IterableSource(StreamSource): iterables: Dict[str, Iterable] - def __call__(self) -> MultiStream: + def process(self) -> MultiStream: return MultiStream.from_iterables(self.iterables) @@ -1782,3 +1786,54 @@ def signature(self, instance): if total_len < val: return i return i + 1 + + +class DownloadError(Exception): + def __init__( + self, + message, + ): + self.__super__(message) + + +class UnexpectedHttpCodeError(Exception): + def __init__(self, http_code): + self.__super__(f"unexpected http code {http_code}") + + +class DownloadOperator(SideEffectOperator): + """Operator for downloading a file from a given URL to a specified local path. + + Attributes: + source (str): URL of the file to be downloaded. + target (str): Local path where the downloaded file should be saved. + """ + + source: str + target: str + + def process(self): + try: + response = requests.get(self.source, allow_redirects=True) + except Exception as e: + raise DownloadError(f"Unabled to download {self.source}") from e + if response.status_code != 200: + raise UnexpectedHttpCodeError(response.status_code) + with open(self.target, "wb") as f: + f.write(response.content) + + +class ExtractZipFile(SideEffectOperator): + """Operator for extracting files from a zip archive. + + Attributes: + zip_file (str): Path of the zip file to be extracted. + target_dir (str): Directory where the contents of the zip file will be extracted. + """ + + zip_file: str + target_dir: str + + def process(self): + with zipfile.ZipFile(self.zip_file) as zf: + zf.extractall(self.target_dir)