Skip to content

Commit

Permalink
fix(flair_pipelines): Fix flair pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
ktagowski committed Apr 19, 2022
1 parent 362feb0 commit b182ae7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 13 deletions.
20 changes: 16 additions & 4 deletions embeddings/pipeline/flair_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation, Transformation
 from embeddings.utils.json_dict_persister import JsonPersister


Expand All @@ -45,14 +48,23 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = False,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
         transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = ClassificationCorpusTransformation(input_column_name, target_column_name)
+            Transformation[datasets.DatasetDict, datasets.DatasetDict],
+            Transformation[datasets.DatasetDict, Corpus],
+            Transformation[Corpus, Corpus],
+        ] = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            ClassificationCorpusTransformation(input_column_name, target_column_name)
+        )
         if sample_missing_splits:
             transformation = transformation.then(
                 SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
Expand Down
20 changes: 15 additions & 5 deletions embeddings/pipeline/flair_pair_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation, Transformation
 from embeddings.utils.json_dict_persister import JsonPersister


Expand All @@ -45,15 +48,22 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = False,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
         transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = PairClassificationCorpusTransformation(
-            input_columns_names_pair, target_column_name
+            Transformation[datasets.DatasetDict, datasets.DatasetDict],
+            Transformation[datasets.DatasetDict, Corpus],
+            Transformation[Corpus, Corpus],
+        ] = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            PairClassificationCorpusTransformation(input_columns_names_pair, target_column_name)
         )
         if sample_missing_splits:
             transformation = transformation.then(
Expand Down
22 changes: 18 additions & 4 deletions embeddings/pipeline/flair_sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
 from embeddings.transformation.flair_transformation.split_sample_corpus_transformation import (
     SampleSplitsFlairCorpusTransformation,
 )
-from embeddings.transformation.transformation import Transformation
+from embeddings.transformation.hf_transformation.class_encode_column_transformation import (
+    ClassEncodeColumnTransformation,
+)
+from embeddings.transformation.transformation import DummyTransformation, Transformation
 from embeddings.utils.json_dict_persister import JsonPersister


Expand All @@ -47,14 +50,25 @@ def __init__(
         sample_missing_splits: Optional[Tuple[Optional[float], Optional[float]]] = None,
         seed: int = 441,
         load_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        encode_classes: bool = True,
     ):
         output_path = Path(output_path)
         dataset = Dataset(dataset_name, **load_dataset_kwargs if load_dataset_kwargs else {})
         data_loader = HuggingFaceDataLoader()
+
         transformation: Union[
-            Transformation[datasets.DatasetDict, Corpus], Transformation[Corpus, Corpus]
-        ]
-        transformation = ColumnCorpusTransformation(input_column_name, target_column_name)
+            Transformation[datasets.DatasetDict, datasets.DatasetDict],
+            Transformation[datasets.DatasetDict, Corpus],
+            Transformation[Corpus, Corpus],
+        ] = DummyTransformation()
+        if encode_classes:
+            transformation = transformation.then(
+                ClassEncodeColumnTransformation(column=target_column_name)
+            )
+        transformation = transformation.then(
+            ColumnCorpusTransformation(input_column_name, target_column_name)
+        )
+
         if sample_missing_splits:
             transformation = transformation.then(
                 SampleSplitsFlairCorpusTransformation(*sample_missing_splits, seed=seed)
Expand Down

0 comments on commit b182ae7

Please sign in to comment.