Skip to content

Commit

Permalink
refactor: Refactor code after review
Browse files Browse the repository at this point in the history
  • Loading branch information
ktagowski committed Apr 20, 2022
1 parent 9e69742 commit d902d26
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
26 changes: 10 additions & 16 deletions embeddings/pipeline/flair_hps_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -73,6 +74,15 @@ class _OptimizedFlairPipelineDefaultsBase(_HuggingFaceOptimizedPipelineDefaultsB
init=False, default_factory=TemporaryDirectory
)

@staticmethod
def _revert_default_hps_task_train_kwargs(
    task_train_kwargs: Dict[str, ParameterValues]
) -> Dict[str, ParameterValues]:
    """Return a deep copy of ``task_train_kwargs`` with two flags overridden.

    In the returned dict ``param_selection_mode`` is ``False`` and
    ``save_final_model`` is ``True``. The dict passed in is not modified.
    """
    reverted = deepcopy(task_train_kwargs)
    reverted.update(param_selection_mode=False, save_final_model=True)
    return reverted


# Mypy currently doesn't properly handle dataclasses with abstract methods: https://github.com/python/mypy/issues/5374
@dataclass # type: ignore
Expand Down Expand Up @@ -104,14 +114,6 @@ def _pop_sampled_parameters(
assert isinstance(load_model_kwargs, dict)
return embedding_name, document_embedding, task_train_kwargs, load_model_kwargs

@staticmethod
def _revert_default_hps_task_train_kwargs(
    task_train_kwargs: Dict[str, ParameterValues]
) -> Dict[str, ParameterValues]:
    """Return a copy of ``task_train_kwargs`` with two flags overridden.

    The returned dict has ``param_selection_mode`` set to ``False`` and
    ``save_final_model`` set to ``True``.

    The input is deep-copied first, so the caller's dict (and any nested
    values it holds) is never mutated by this call.
    """
    out = deepcopy(task_train_kwargs)
    out["param_selection_mode"] = False
    out["save_final_model"] = True
    return out


@dataclass
class OptimizedFlairClassificationPipeline(
Expand Down Expand Up @@ -396,14 +398,6 @@ def _pop_sampled_parameters(
assert isinstance(task_model_kwargs, dict)
return embedding_name, hidden_size, task_train_kwargs, task_model_kwargs

@staticmethod
def _revert_default_hps_task_train_kwargs(
    task_train_kwargs: Dict[str, ParameterValues]
) -> Dict[str, ParameterValues]:
    """Return a copy of ``task_train_kwargs`` with two flags overridden.

    The returned dict has ``param_selection_mode`` set to ``False`` and
    ``save_final_model`` set to ``True``.

    The input is deep-copied first, so the caller's dict (and any nested
    values it holds) is never mutated by this call.
    """
    out = deepcopy(task_train_kwargs)
    out["param_selection_mode"] = False
    out["save_final_model"] = True
    return out

def _get_metadata(self, parameters: SampledParameters) -> FlairSequenceLabelingPipelineMetadata:
(
embedding_name,
Expand Down
4 changes: 2 additions & 2 deletions embeddings/pipeline/flair_preprocessing_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _get_persister(self) -> FLAIR_PERSISTERS_TYPE:
def _get_dataset(self) -> Dataset:
    """Build the ``Dataset`` for ``self.dataset_name_or_path``.

    ``self.load_dataset_kwargs`` is forwarded as keyword arguments; when it
    is falsy (``None`` or empty) no extra keyword arguments are passed.
    """
    return Dataset(
        self.dataset_name_or_path,
        # The keyword-unpacking argument must appear exactly once.
        **(self.load_dataset_kwargs or {}),
    )

def _get_dataloader(self, dataset: Dataset) -> FLAIR_DATALOADERS:
Expand Down Expand Up @@ -130,7 +130,7 @@ def _get_transformations(
DownsampleFlairCorpusTransformation(
*self.downsample_splits,
stratify=self.downsample_splits_stratification,
seed=self.seed
seed=self.seed,
)
)

Expand Down

0 comments on commit d902d26

Please sign in to comment.