fix(flair_pipelines): Fix flair pipelines
ktagowski committed Apr 19, 2022
1 parent 362feb0 commit 9572f26
Showing 5 changed files with 69 additions and 28 deletions.
embeddings/pipeline/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -9,6 +9,8 @@
 DOWNSAMPLE_SPLITS_TYPE = Tuple[Optional[float], Optional[float], Optional[float]]
 SAMPLE_MISSING_SPLITS_TYPE = Optional[Tuple[Optional[float], Optional[float]]]
 FLAIR_DATASET_TRANSFORMATIONS_TYPE = Union[
-    Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
+    Transformation[datasets.DatasetDict, datasets.DatasetDict],
+    Transformation[datasets.DatasetDict, Corpus],
+    Transformation[Corpus, Corpus],
 ]
 FLAIR_PERSISTERS_TYPE = Union[FlairConllPersister[Corpus], FlairPicklePersister[Corpus, Corpus]]
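Note on the widened union above: the pipelines changed below now start their transformation chains with DatasetDict-level steps (optional label encoding) before any Flair Corpus exists, so the alias needs a Transformation[datasets.DatasetDict, datasets.DatasetDict] member. A minimal sketch of why, built only from classes imported in this commit; treating DummyTransformation as an identity step is an inference from how the diffs use it, and the "text"/"label" column names are placeholders:

from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
from embeddings.transformation.flair_transformation.classification_corpus_transformation import (
    ClassificationCorpusTransformation,
)
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
    ClassEncodeColumnTransformation,
)
from embeddings.transformation.transformation import DummyTransformation

# Identity head of the chain: DatasetDict -> DatasetDict.
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
# Still DatasetDict -> DatasetDict here, a shape the old two-member union could not express.
transformation = transformation.then(ClassEncodeColumnTransformation(column="label"))
# Only this step produces a Corpus, matching Transformation[datasets.DatasetDict, Corpus].
transformation = transformation.then(ClassificationCorpusTransformation("text", "label"))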
embeddings/pipeline/flair_classification.py (21 changes: 15 additions & 6 deletions)
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 import datasets
 from flair.data import Corpus
@@ -15,6 +15,7 @@
 from embeddings.embedding.flair_loader import FlairDocumentPoolEmbeddingLoader
 from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
 from embeddings.model.flair_model import FlairModel
+from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
 from embeddings.pipeline.standard_pipeline import StandardPipeline
 from embeddings.task.flair_task.text_classification import TextClassification
 from embeddings.transformation.flair_transformation.classification_corpus_transformation import (
@@ -23,7 +24,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation
 from embeddings.utils.json_dict_persister import JsonPersister
 
 
@@ -45,14 +49,19 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = False,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
-        transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = ClassificationCorpusTransformation(input_column_name, target_column_name)
+        transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            ClassificationCorpusTransformation(input_column_name, target_column_name)
+        )
        if sample_missing_splits:
             transformation = transformation.then(
                 SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
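For context on the new encode_classes flag: the ClassEncodeColumnTransformation used here presumably wraps Hugging Face datasets' class_encode_column (an assumption about its internals; only its constructor is visible in this diff), which casts a label column to a ClassLabel feature with a stable string-to-id mapping. A standalone sketch of that underlying call on made-up data:

import datasets

dataset_dict = datasets.DatasetDict(
    {"train": datasets.Dataset.from_dict({"text": ["good", "bad"], "label": ["pos", "neg"]})}
)
# Casts the string "label" column to a ClassLabel feature across all splits.
encoded = dataset_dict.class_encode_column("label")
print(encoded["train"].features["label"])  # e.g. ClassLabel(names=['neg', 'pos'], ...)
print(encoded["train"]["label"])           # integer ids instead of the raw strings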
embeddings/pipeline/flair_pair_classification.py (21 changes: 14 additions & 7 deletions)
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 import datasets
 from flair.data import Corpus
@@ -15,6 +15,7 @@
 from embeddings.embedding.flair_loader import FlairDocumentPoolEmbeddingLoader
 from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
 from embeddings.model.flair_model import FlairModel
+from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
 from embeddings.pipeline.standard_pipeline import StandardPipeline
 from embeddings.task.flair_task.text_pair_classification import TextPairClassification
 from embeddings.transformation.flair_transformation.pair_classification_corpus_transformation import (
@@ -23,7 +24,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation
 from embeddings.utils.json_dict_persister import JsonPersister
 
 
@@ -45,15 +49,18 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = False,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
-        transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = PairClassificationCorpusTransformation(
-            input_columns_names_pair, target_column_name
+        transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            PairClassificationCorpusTransformation(input_columns_names_pair, target_column_name)
         )
         if sample_missing_splits:
             transformation = transformation.then(
embeddings/pipeline/flair_preprocessing_pipeline.py (28 changes: 20 additions & 8 deletions)
@@ -39,6 +39,9 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
 from embeddings.transformation.transformation import DummyTransformation
 from embeddings.utils.flair_corpus_persister import FlairConllPersister, FlairPicklePersister
 
@@ -62,6 +65,7 @@ class FlairPreprocessingPipeline(
     ignore_test_subset: bool = False
     seed: int = 441
     load_dataset_kwargs: Optional[Dict[str, Any]] = None
+    encode_labels: bool = False
 
     def __post_init__(self) -> None:
         self.persister = self._get_persister()
@@ -71,7 +75,7 @@ def __post_init__(self) -> None:
         super(FlairPreprocessingPipeline, self).__init__(dataset, data_loader, transformation)
 
     @abc.abstractmethod
-    def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
+    def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
         pass
 
     @abc.abstractmethod
@@ -89,17 +93,25 @@ def _get_dataloader(self, dataset: Dataset) -> FLAIR_DATALOADERS:
 
     def _get_dataset_transformation(
         self, data_loader: FLAIR_DATALOADERS
-    ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
+    ) -> Optional[FLAIR_DATASET_TRANSFORMATIONS_TYPE]:
         if isinstance(data_loader, (ConllFlairCorpusDataLoader, PickleFlairCorpusDataLoader)):
-            return DummyTransformation()
+            return None
 
-        return self._get_base_dataset_transformation()
+        return self._get_to_flair_dataset_transformation()
 
     def _get_transformations(
         self, data_loader: FLAIR_DATALOADERS
     ) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
 
-        transformation = self._get_dataset_transformation(data_loader)
+        transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
+        if self.encode_labels:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=self.target_column_name)
+            )
+
+        to_flair_dataset_transformation = self._get_dataset_transformation(data_loader)
+        if to_flair_dataset_transformation:
+            transformation = transformation.then(to_flair_dataset_transformation)
 
         if self.sample_missing_splits:
             transformation = transformation.then(
@@ -126,7 +138,7 @@ class FlairTextClassificationPreprocessingPipeline(FlairPreprocessingPipeline):
     def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
         return FlairPicklePersister(self.persist_path)
 
-    def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
+    def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
         assert isinstance(self.input_column_name, str)
         return ClassificationCorpusTransformation(
             input_column_name=self.input_column_name,
@@ -138,7 +150,7 @@ def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE
 class FlairTextPairClassificationPreprocessingPipeline(
     FlairTextClassificationPreprocessingPipeline
 ):
-    def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
+    def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
         assert isinstance(self.input_column_name, (tuple, list))
         return PairClassificationCorpusTransformation(
             input_columns_names_pair=self.input_column_name,
@@ -151,7 +163,7 @@ class FlairSequenceLabelingPreprocessingPipeline(FlairPreprocessingPipeline):
     def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
         return FlairConllPersister(self.persist_path)
 
-    def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
+    def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
         assert isinstance(self.input_column_name, str)
         return ColumnCorpusTransformation(
             input_column_name=self.input_column_name,
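To make the new control flow in _get_transformations concrete, here is a hedged, standalone restatement of the chain it builds when the data loader is the Hugging Face one, encode_labels is True, and sample_missing_splits is set. Column names, split fractions, and the seed are placeholders; only classes and call shapes that appear in this commit are used. With a ConllFlairCorpusDataLoader or PickleFlairCorpusDataLoader, _get_dataset_transformation now returns None and the corpus-building step is simply skipped:

from embeddings.transformation.flair_transformation.classification_corpus_transformation import (
    ClassificationCorpusTransformation,
)
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
    SampleSplitsFlairCorpusTransformation,
)
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
    ClassEncodeColumnTransformation,
)
from embeddings.transformation.transformation import DummyTransformation

transformation = (
    DummyTransformation()                                             # identity head of the chain
    .then(ClassEncodeColumnTransformation(column="label"))            # only when encode_labels=True
    .then(ClassificationCorpusTransformation("text", "label"))        # subclass hook: _get_to_flair_dataset_transformation()
    .then(SampleSplitsFlairCorpusTransformation(0.1, 0.1, seed=441))  # only when sample_missing_splits is set
)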
embeddings/pipeline/flair_sequence_labeling.py (23 changes: 17 additions & 6 deletions)
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 import datasets
 from flair.data import Corpus
@@ -15,6 +15,7 @@
 from embeddings.embedding.flair_loader import FlairWordEmbeddingLoader
 from embeddings.evaluator.sequence_labeling_evaluator import SequenceLabelingEvaluator
 from embeddings.model.flair_model import FlairModel
+from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
 from embeddings.pipeline.standard_pipeline import StandardPipeline
 from embeddings.task.flair_task.sequence_labeling import SequenceLabeling
 from embeddings.transformation.flair_transformation.column_corpus_transformation import (
@@ -23,7 +24,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation
 from embeddings.utils.json_dict_persister import JsonPersister
 
 
@@ -47,14 +51,21 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = True,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
-        transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = ColumnCorpusTransformation(input_column_name, target_column_name)
+
+        transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            ColumnCorpusTransformation(input_column_name, target_column_name)
+        )
+
         if sample_missing_splits:
             transformation = transformation.then(
                 SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
