Skip to content

Commit

Permalink
Add side affect operators and fix operators base classes
Browse files Browse the repository at this point in the history
Signed-off-by: Elron Bandel <elron.bandel@ibm.com>
  • Loading branch information
elronbandel committed Jan 29, 2024
1 parent 549e3ee commit 8d6b5e9
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 62 deletions.
4 changes: 2 additions & 2 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def register_class(cls, artifact_class):

if cls.is_registered_type(snake_case_key):
assert (
cls._class_register[snake_case_key] == artifact_class
), f"Artifact class name must be unique, '{snake_case_key}' already exists for '{cls._class_register[snake_case_key]}'"
str(cls._class_register[snake_case_key]) == str(artifact_class)
), f"Artifact class name must be unique, '{snake_case_key}' already exists for {cls._class_register[snake_case_key]}. Cannot be overriden by {artifact_class}."

return snake_case_key

Expand Down
6 changes: 3 additions & 3 deletions src/unitxt/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Generator, List, Optional

from .dataclass import NonPositionalField
from .operator import SourceOperator, StreamSource
from .operator import SourceOperator
from .random_utils import new_random_generator
from .stream import MultiStream, Stream

Expand All @@ -15,7 +15,7 @@ class BaseFusion(SourceOperator):
include_splits: List of splits to include. If None, all splits are included.
"""

origins: List[StreamSource]
origins: List[SourceOperator]
include_splits: Optional[List[str]] = NonPositionalField(default=None)

@abstractmethod
Expand Down Expand Up @@ -73,7 +73,7 @@ class WeightedFusion(BaseFusion):
max_total_examples: Total number of examples to return. If None, all examples are returned.
"""

origins: List[StreamSource] = None
origins: List[SourceOperator] = None
weights: List[float] = None
max_total_examples: int = None

Expand Down
6 changes: 3 additions & 3 deletions src/unitxt/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from datasets import DatasetDict

from .artifact import fetch_artifact
from .operator import StreamSource
from .operator import SourceOperator


def load_dataset(source: Union[StreamSource, str]) -> DatasetDict:
def load_dataset(source: Union[SourceOperator, str]) -> DatasetDict:
assert isinstance(
source, (StreamSource, str)
source, (SourceOperator, str)
), "source must be a StreamSource or a string"
if isinstance(source, str):
source, _ = fetch_artifact(source)
Expand Down
60 changes: 9 additions & 51 deletions src/unitxt/operator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import re
import zipfile
from abc import abstractmethod
from dataclasses import field
from typing import Any, Dict, Generator, List, Optional

import requests

from .artifact import Artifact
from .dataclass import NonPositionalField
from .stream import MultiStream, Stream
Expand All @@ -15,43 +12,6 @@ class Operator(Artifact):
pass


class DownloadError(Exception):
def __init__(
self,
message,
):
self.__super__(message)


class UnexpectedHttpCodeError(Exception):
def __init__(self, http_code):
self.__super__(f"unexpected http code {http_code}")


class DownloadOperator(Operator):
source: str
target: str

def __call__(self):
try:
response = requests.get(self.source, allow_redirects=True)
except Exception as e:
raise DownloadError(f"Unabled to download {self.source}") from e
if response.status_code != 200:
raise UnexpectedHttpCodeError(response.status_code)
with open(self.target, "wb") as f:
f.write(response.content)


class ZipExtractorOperator(Operator):
zip_file: str
target_dir: str

def __call__(self):
with zipfile.ZipFile(self.zip_file) as zf:
zf.extractall(self.target_dir)


class OperatorError(Exception):
def __init__(self, exception: Exception, operators: List[Operator]):
super().__init__(
Expand Down Expand Up @@ -97,21 +57,19 @@ def __call__(self, streams: Optional[MultiStream] = None) -> MultiStream:
"""


class StreamSource(StreamingOperator):
"""A class representing a stream source operator in the streaming system.
A stream source operator is a special type of `StreamingOperator` that generates a data stream without taking any input streams. It serves as the starting point in a stream processing pipeline, providing the initial data that other operators in the pipeline can process.
When called, a `StreamSource` should generate a `MultiStream`. This behavior must be implemented by any classes that inherit from `StreamSource`.
class SideEffectOperator(StreamingOperator):
"""Base class for operators that does not affect the stream."""

"""
def __call__(self, streams: Optional[MultiStream] = None) -> MultiStream:
self.process()
return streams

@abstractmethod
def __call__(self) -> MultiStream:
def process() -> None:
pass


class SourceOperator(StreamSource):
class SourceOperator(StreamingOperator):
"""A class representing a source operator in the streaming system.
A source operator is responsible for generating the data stream from some source, such as a database or a file.
Expand All @@ -126,7 +84,7 @@ class SourceOperator(StreamSource):

caching: bool = NonPositionalField(default=None)

def __call__(self) -> MultiStream:
def __call__(self, multi_stream: Optional[MultiStream] = None) -> MultiStream:
multi_stream = self.process()
if self.caching is not None:
multi_stream.set_caching(self.caching)
Expand All @@ -137,7 +95,7 @@ def process(self) -> MultiStream:
pass


class StreamInitializerOperator(StreamSource):
class StreamInitializerOperator(SourceOperator):
"""A class representing a stream initializer operator in the streaming system.
A stream initializer operator is a special type of `StreamSource` that is capable of taking parameters during the stream generation process. This can be useful in situations where the stream generation process needs to be customized or configured based on certain parameters.
Expand Down
61 changes: 58 additions & 3 deletions src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import operator
import os
import uuid
import zipfile
from abc import abstractmethod
from collections import Counter
from copy import deepcopy
Expand All @@ -54,6 +55,8 @@
Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
Expand All @@ -62,12 +65,13 @@
MultiStreamOperator,
PagedStreamOperator,
SequentialOperator,
SideEffectOperator,
SingleStreamOperator,
SingleStreamReducer,
SourceOperator,
StreamingOperator,
StreamInitializerOperator,
StreamInstanceOperator,
StreamSource,
)
from .random_utils import new_random_generator
from .stream import Stream
Expand All @@ -89,7 +93,7 @@ def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
return MultiStream.from_iterables(iterables)


class IterableSource(StreamSource):
class IterableSource(SourceOperator):
"""Creates a MultiStream from a dict of named iterables.
It is a callable.
Expand All @@ -105,7 +109,7 @@ class IterableSource(StreamSource):

iterables: Dict[str, Iterable]

def __call__(self) -> MultiStream:
def process(self) -> MultiStream:
return MultiStream.from_iterables(self.iterables)


Expand Down Expand Up @@ -1782,3 +1786,54 @@ def signature(self, instance):
if total_len < val:
return i
return i + 1


class DownloadError(Exception):
def __init__(
self,
message,
):
self.__super__(message)


class UnexpectedHttpCodeError(Exception):
def __init__(self, http_code):
self.__super__(f"unexpected http code {http_code}")


class DownloadOperator(SideEffectOperator):
"""Operator for downloading a file from a given URL to a specified local path.
Attributes:
source (str): URL of the file to be downloaded.
target (str): Local path where the downloaded file should be saved.
"""

source: str
target: str

def process(self):
try:
response = requests.get(self.source, allow_redirects=True)
except Exception as e:
raise DownloadError(f"Unabled to download {self.source}") from e
if response.status_code != 200:
raise UnexpectedHttpCodeError(response.status_code)
with open(self.target, "wb") as f:
f.write(response.content)


class ExtractZipFile(SideEffectOperator):
"""Operator for extracting files from a zip archive.
Attributes:
zip_file (str): Path of the zip file to be extracted.
target_dir (str): Directory where the contents of the zip file will be extracted.
"""

zip_file: str
target_dir: str

def process(self):
with zipfile.ZipFile(self.zip_file) as zf:
zf.extractall(self.target_dir)

0 comments on commit 8d6b5e9

Please sign in to comment.