Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/update pipelines #230

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions embeddings/config/flair_config_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class FlairTextClassificationConfigSpaceMapping:
LOAD_MODEL_KEYS_MAPPING: ClassVar[Mapping[str, Set[str]]] = MappingProxyType(
{
"FlairDocumentCNNEmbeddings": {
"FlairDocumentRNNEmbeddings": {
"hidden_size",
"rnn_type",
"rnn_layers",
Expand All @@ -27,7 +27,7 @@ class FlairTextClassificationConfigSpaceMapping:
"word_dropout",
"reproject_words",
},
"FlairDocumentRNNEmbeddings": {
"FlairDocumentCNNEmbeddings": {
"cnn_pool_kernels",
"dropout",
"word_dropout",
Expand Down
4 changes: 3 additions & 1 deletion embeddings/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
DOWNSAMPLE_SPLITS_TYPE = Tuple[Optional[float], Optional[float], Optional[float]]
SAMPLE_MISSING_SPLITS_TYPE = Optional[Tuple[Optional[float], Optional[float]]]
FLAIR_DATASET_TRANSFORMATIONS_TYPE = Union[
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
Transformation[datasets.DatasetDict, datasets.DatasetDict],
Transformation[datasets.DatasetDict, Corpus],
Transformation[Corpus, Corpus],
]
FLAIR_PERSISTERS_TYPE = Union[FlairConllPersister[Corpus], FlairPicklePersister[Corpus, Corpus]]
21 changes: 15 additions & 6 deletions embeddings/pipeline/flair_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import datasets
from flair.data import Corpus
Expand All @@ -15,6 +15,7 @@
from embeddings.embedding.flair_loader import FlairDocumentPoolEmbeddingLoader
from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
from embeddings.model.flair_model import FlairModel
from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
from embeddings.pipeline.standard_pipeline import StandardPipeline
from embeddings.task.flair_task.text_classification import TextClassification
from embeddings.transformation.flair_transformation.classification_corpus_transformation import (
Expand All @@ -23,7 +24,10 @@
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
SampleSplitsFlairCorpusTransformation,
)
from embeddings.transformation.transformation import Transformation
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
ClassEncodeColumnTransformation,
)
from embeddings.transformation.transformation import DummyTransformation
from embeddings.utils.json_dict_persister import JsonPersister


Expand All @@ -45,14 +49,19 @@ def __init__(
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
seed: int = 441,
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
encode_classes: bool = False,
):
output_path = Path(output_path)
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
data_loader = HuggingFaceDataLoader()
transformation: Union[
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
]
transformation = ClassificationCorpusTransformation(input_column_name, target_column_name)
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
if encode_classes:
transformation = transformation.then(
ClassEncodeColumnTransformation(column=target_column_name)
)
transformation = transformation.then(
ClassificationCorpusTransformation(input_column_name, target_column_name)
)
if sample_missing_splits:
transformation = transformation.then(
SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
Expand Down
26 changes: 26 additions & 0 deletions embeddings/pipeline/flair_hps_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -73,6 +74,15 @@ class _OptimizedFlairPipelineDefaultsBase(_HuggingFaceOptimizedPipelineDefaultsB
init=False, default_factory=TemporaryDirectory
)

@staticmethod
def _revert_default_hps_task_train_kwargs(
task_train_kwargs: Dict[str, ParameterValues]
) -> Dict[str, ParameterValues]:
out = deepcopy(task_train_kwargs)
out["param_selection_mode"] = False
out["save_final_model"] = True
return out


# Mypy currently doesn't properly handle dataclasses with abstract methods https://github.com/python/mypy/issues/5374
@dataclass # type: ignore
Expand Down Expand Up @@ -157,6 +167,12 @@ def _get_metadata(self, parameters: SampledParameters) -> FlairClassificationPip
task_train_kwargs,
load_model_kwargs,
) = self._pop_sampled_parameters(parameters=parameters)

task_train_kwargs = (
OptimizedFlairClassificationPipeline._revert_default_hps_task_train_kwargs(
task_train_kwargs
)
)
metadata: FlairClassificationPipelineMetadata = {
"embedding_name": embedding_name,
"dataset_name": str(self.dataset_name_or_path),
Expand Down Expand Up @@ -257,6 +273,11 @@ def _get_metadata(
task_train_kwargs,
load_model_kwargs,
) = self._pop_sampled_parameters(parameters=parameters)
task_train_kwargs = (
OptimizedFlairPairClassificationPipeline._revert_default_hps_task_train_kwargs(
task_train_kwargs
)
)
metadata: FlairPairClassificationPipelineMetadata = {
"embedding_name": embedding_name,
"dataset_name": str(self.dataset_name_or_path),
Expand Down Expand Up @@ -384,6 +405,11 @@ def _get_metadata(self, parameters: SampledParameters) -> FlairSequenceLabelingP
task_train_kwargs,
task_model_kwargs,
) = self._pop_sampled_parameters(parameters)
task_train_kwargs = (
OptimizedFlairSequenceLabelingPipeline._revert_default_hps_task_train_kwargs(
task_train_kwargs
)
)
metadata: FlairSequenceLabelingPipelineMetadata = {
"embedding_name": embedding_name,
"dataset_name": str(self.dataset_name_or_path),
Expand Down
21 changes: 14 additions & 7 deletions embeddings/pipeline/flair_pair_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import datasets
from flair.data import Corpus
Expand All @@ -15,6 +15,7 @@
from embeddings.embedding.flair_loader import FlairDocumentPoolEmbeddingLoader
from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
from embeddings.model.flair_model import FlairModel
from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
from embeddings.pipeline.standard_pipeline import StandardPipeline
from embeddings.task.flair_task.text_pair_classification import TextPairClassification
from embeddings.transformation.flair_transformation.pair_classification_corpus_transformation import (
Expand All @@ -23,7 +24,10 @@
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
SampleSplitsFlairCorpusTransformation,
)
from embeddings.transformation.transformation import Transformation
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
ClassEncodeColumnTransformation,
)
from embeddings.transformation.transformation import DummyTransformation
from embeddings.utils.json_dict_persister import JsonPersister


Expand All @@ -45,15 +49,18 @@ def __init__(
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
seed: int = 441,
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
encode_classes: bool = False,
):
output_path = Path(output_path)
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
data_loader = HuggingFaceDataLoader()
transformation: Union[
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
]
transformation = PairClassificationCorpusTransformation(
input_columns_names_pair, target_column_name
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
if encode_classes:
transformation = transformation.then(
ClassEncodeColumnTransformation(column=target_column_name)
)
transformation = transformation.then(
PairClassificationCorpusTransformation(input_columns_names_pair, target_column_name)
)
if sample_missing_splits:
transformation = transformation.then(
Expand Down
36 changes: 26 additions & 10 deletions embeddings/pipeline/flair_preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
SampleSplitsFlairCorpusTransformation,
)
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
ClassEncodeColumnTransformation,
)
from embeddings.transformation.transformation import DummyTransformation
from embeddings.utils.flair_corpus_persister import FlairConllPersister, FlairPicklePersister

Expand All @@ -62,6 +65,7 @@ class FlairPreprocessingPipeline(
ignore_test_subset: bool = False
seed: int = 441
load_dataset_kwargs: Optional[Dict[str, Any]] = None
encode_labels: bool = False

def __post_init__(self) -> None:
self.persister = self._get_persister()
Expand All @@ -71,7 +75,7 @@ def __post_init__(self) -> None:
super(FlairPreprocessingPipeline, self).__init__(dataset, data_loader, transformation)

@abc.abstractmethod
def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
pass

@abc.abstractmethod
Expand All @@ -81,25 +85,37 @@ def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
def _get_dataset(self) -> Dataset:
return Dataset(
self.dataset_name_or_path,
**self.load_dataset_kwargs if self.load_dataset_kwargs else {}
**self.load_dataset_kwargs if self.load_dataset_kwargs else {},
)

def _get_dataloader(self, dataset: Dataset) -> FLAIR_DATALOADERS:
return get_flair_dataloader(dataset)

def _get_dataset_transformation(
self, data_loader: FLAIR_DATALOADERS
) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
) -> Optional[FLAIR_DATASET_TRANSFORMATIONS_TYPE]:
if isinstance(data_loader, (ConllFlairCorpusDataLoader, PickleFlairCorpusDataLoader)):
return DummyTransformation()
return None

return self._get_base_dataset_transformation()
return self._get_to_flair_dataset_transformation()

def _get_transformations(
self, data_loader: FLAIR_DATALOADERS
) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:

transformation = self._get_dataset_transformation(data_loader)
transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
if self.encode_labels:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider early exit, without extra nested if?
if self.encode_labels and isinstance(data_loader, (ConllFlairCorpusDataLoader, PickleFlairCorpusDataLoader)):
raise ValueError(
"ClassEncodeColumnTransformation transformation is unavailable for Flair DataLoaders. "
"Set parameter encode_labels value to True"
)
transformation= transformation.then(
ClassEncodeColumnTransformation(column=self.target_column_name)
) if self.encode_labels else transformation

if isinstance(data_loader, (ConllFlairCorpusDataLoader, PickleFlairCorpusDataLoader)):
raise ValueError(
"ClassEncodeColumnTransformation transformation is unavailable for Flair DataLoaders. "
"Set parameter `encode_labels` value to True"
)
transformation = transformation.then(
ClassEncodeColumnTransformation(column=self.target_column_name)
)

if to_flair_dataset_transformation := self._get_dataset_transformation(data_loader):
transformation = transformation.then(to_flair_dataset_transformation)

if self.sample_missing_splits:
transformation = transformation.then(
Expand All @@ -114,7 +130,7 @@ def _get_transformations(
DownsampleFlairCorpusTransformation(
*self.downsample_splits,
stratify=self.downsample_splits_stratification,
seed=self.seed
seed=self.seed,
)
)

Expand All @@ -126,7 +142,7 @@ class FlairTextClassificationPreprocessingPipeline(FlairPreprocessingPipeline):
def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
return FlairPicklePersister(self.persist_path)

def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
assert isinstance(self.input_column_name, str)
return ClassificationCorpusTransformation(
input_column_name=self.input_column_name,
Expand All @@ -138,7 +154,7 @@ def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE
class FlairTextPairClassificationPreprocessingPipeline(
FlairTextClassificationPreprocessingPipeline
):
def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
assert isinstance(self.input_column_name, (tuple, list))
return PairClassificationCorpusTransformation(
input_columns_names_pair=self.input_column_name,
Expand All @@ -151,7 +167,7 @@ class FlairSequenceLabelingPreprocessingPipeline(FlairPreprocessingPipeline):
def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
return FlairConllPersister(self.persist_path)

def _get_base_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
def _get_to_flair_dataset_transformation(self) -> FLAIR_DATASET_TRANSFORMATIONS_TYPE:
assert isinstance(self.input_column_name, str)
return ColumnCorpusTransformation(
input_column_name=self.input_column_name,
Expand Down
23 changes: 17 additions & 6 deletions embeddings/pipeline/flair_sequence_labeling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

import datasets
from flair.data import Corpus
Expand All @@ -15,6 +15,7 @@
from embeddings.embedding.flair_loader import FlairWordEmbeddingLoader
from embeddings.evaluator.sequence_labeling_evaluator import SequenceLabelingEvaluator
from embeddings.model.flair_model import FlairModel
from embeddings.pipeline import FLAIR_DATASET_TRANSFORMATIONS_TYPE
from embeddings.pipeline.standard_pipeline import StandardPipeline
from embeddings.task.flair_task.sequence_labeling import SequenceLabeling
from embeddings.transformation.flair_transformation.column_corpus_transformation import (
Expand All @@ -23,7 +24,10 @@
from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
SampleSplitsFlairCorpusTransformation,
)
from embeddings.transformation.transformation import Transformation
from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
ClassEncodeColumnTransformation,
)
from embeddings.transformation.transformation import DummyTransformation
from embeddings.utils.json_dict_persister import JsonPersister


Expand All @@ -47,14 +51,21 @@ def __init__(
sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
seed: int = 441,
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
encode_classes: bool = True,
):
output_path = Path(output_path)
dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
data_loader = HuggingFaceDataLoader()
transformation: Union[
Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
]
transformation = ColumnCorpusTransformation(input_column_name, target_column_name)

transformation: FLAIR_DATASET_TRANSFORMATIONS_TYPE = DummyTransformation()
if encode_classes:
transformation = transformation.then(
ClassEncodeColumnTransformation(column=target_column_name)
)
transformation = transformation.then(
ColumnCorpusTransformation(input_column_name, target_column_name)
)

if sample_missing_splits:
transformation = transformation.then(
SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import datasets

from embeddings.transformation.transformation import Transformation


class ClassEncodeColumnTransformation(Transformation[datasets.DatasetDict, datasets.DatasetDict]):
    """Cast a dataset column to :class:`datasets.ClassLabel`.

    Thin wrapper around :meth:`datasets.DatasetDict.class_encode_column`, which
    maps the raw label values of ``column`` to integer class ids in every split
    of the dataset.

    Args:
        column: Name of the column to encode as class labels.
        include_nulls: Whether ``None`` values are encoded as their own class
            (forwarded to ``class_encode_column``). Defaults to ``False``,
            which preserves the previous behaviour; the keyword is only passed
            when ``True`` so older ``datasets`` releases keep working.
    """

    def __init__(self, column: str, include_nulls: bool = False):
        self.column = column
        self.include_nulls = include_nulls

    def transform(self, data: datasets.DatasetDict) -> datasets.DatasetDict:
        """Return ``data`` with ``self.column`` re-encoded as a ClassLabel feature."""
        if self.include_nulls:
            return data.class_encode_column(column=self.column, include_nulls=True)
        return data.class_encode_column(column=self.column)