Commit d144ccf
fix(qa_datamodule): Fix qa_datamodule
ktagowski committed Apr 29, 2023
1 parent da15444 commit d144ccf
Showing 3 changed files with 84 additions and 38 deletions.
44 changes: 32 additions & 12 deletions embeddings/data/qa_datamodule.py
@@ -1,4 +1,5 @@
import os
import pickle
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -157,17 +158,34 @@ def __init__(
self.splits = ["train", "validation"]
self.processed_data_cache_path = None
if use_cache:
datasets.disable_caching()
self.processed_data_cache_path = (
QuestionAnsweringDataModule.CACHE_DEFAULT_DIR
/ f"{standardize_name(str(dataset_name_or_path))}__{standardize_name(str(tokenizer_name_or_path))}"
)
self.processed_data_cache_path.mkdir(parents=True, exist_ok=True)
_logger.warning(
f"Using datamodule caching. Cache path={self.processed_data_cache_path}"
f"Using embeddingsdatamodule caching. Cache path={self.processed_data_cache_path}"
)

self.process_data_with_cache(stage="fit")

def cache_datamodule(self, path: Path, stage: str) -> None:
self.dataset.save_to_disk(str(path))
if stage != "fit":
with open(path / "overflow_to_sample_mapping", "wb") as f:
pickle.dump(obj=self.overflow_to_sample_mapping, file=f)
with open(path / "offset_mapping", "wb") as f:
pickle.dump(obj=self.offset_mapping, file=f)

def load_cached_datamodule(self, path: Path, stage: str) -> None:
self.dataset = datasets.load_from_disk(dataset_path=str(path))
if stage != "fit":
with open(path / "overflow_to_sample_mapping", "rb") as f:
self.overflow_to_sample_mapping = pickle.load(f)
with open(path / "offset_mapping", "rb") as f:
self.offset_mapping = pickle.load(f)

def process_data_with_cache(self, stage: Optional[str] = None) -> None:
if stage is None:
return
@@ -176,17 +194,17 @@ def process_data_with_cache(self, stage: Optional[str] = None) -> None:
data_cache_path = self.processed_data_cache_path / stage
if data_cache_path.exists():
_logger.warning(f"Loading cached datamodule from path {data_cache_path}")
self.dataset = datasets.load_from_disk(dataset_path=str(data_cache_path))
self.load_cached_datamodule(data_cache_path, stage=stage)
_logger.warning("Load completed!")
else:
_logger.warning(
f"Cached datamodule not found. Processing datamodule {data_cache_path}"
)
self.process_data(stage=stage)
self.dataset.set_format(type="torch")
self.dataset.save_to_disk(str(data_cache_path))
_logger.warning(f"Saving cached datamodule at path {data_cache_path}")
self.cache_datamodule(data_cache_path, stage=stage)
else:
self.process_data(stage=stage)
self.dataset.set_format(type="torch")

def process_data(self, stage: Optional[str] = None) -> None:
assert isinstance(self.dataset_raw, datasets.DatasetDict)
@@ -213,14 +231,16 @@ def process_data(self, stage: Optional[str] = None) -> None:
batch_size=self.processing_batch_size,
remove_columns=columns,
)
if stage != "fit":
self.overflow_to_sample_mapping[split] = self.dataset[split][
"overflow_to_sample_mapping"
]
self.offset_mapping[split] = self.dataset[split]["offset_mapping"]
self.dataset[split] = self.dataset[split].remove_columns(
["offset_mapping", "overflow_to_sample_mapping"]
)

self.overflow_to_sample_mapping[split] = self.dataset[split][
"overflow_to_sample_mapping"
]
self.offset_mapping[split] = self.dataset[split]["offset_mapping"]
self.dataset[split] = self.dataset[split].remove_columns(
["offset_mapping", "overflow_to_sample_mapping"]
)
self.dataset.set_format(type="torch")

def prepare_data(self) -> None:
AutoTokenizer.from_pretrained(self.tokenizer_name_or_path)
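
For context, here is a minimal, self-contained sketch of the caching round-trip that the new cache_datamodule and load_cached_datamodule helpers perform: the tokenized DatasetDict is written with save_to_disk, and the plain-Python offset mappings are pickled alongside it. The cache directory and the toy split below are illustrative assumptions, not values taken from the repository.

import pickle
from pathlib import Path

import datasets

# Illustrative cache location and toy split; the real module derives the path
# from the dataset and tokenizer names under CACHE_DEFAULT_DIR.
cache_dir = Path("/tmp/qa_datamodule_cache/validation")

dataset = datasets.DatasetDict(
    {"validation": datasets.Dataset.from_dict({"input_ids": [[101, 102]], "example_id": [0]})}
)
offset_mapping = {"validation": [[(0, 1), (1, 2)]]}

# Save: the tokenized DatasetDict goes through save_to_disk, and the per-split
# offset mappings (plain Python objects) are pickled next to it.
dataset.save_to_disk(str(cache_dir))
with open(cache_dir / "offset_mapping", "wb") as f:
    pickle.dump(obj=offset_mapping, file=f)

# Load: the mirror image of the above, as in load_cached_datamodule.
restored = datasets.load_from_disk(dataset_path=str(cache_dir))
with open(cache_dir / "offset_mapping", "rb") as f:
    restored_offsets = pickle.load(f)
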
70 changes: 45 additions & 25 deletions embeddings/task/lightning_task/lightning_task.py
@@ -1,8 +1,10 @@
import abc
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
@@ -52,6 +54,22 @@ def __init__(
self.trainer: Optional[pl.Trainer] = None
self.logging_config = logging_config
self.tokenizer: Optional[AutoTokenizer] = None
self.callbacks: Optional[List[Callback]] = None

self.inference_mode = (
self.task_train_kwargs.pop("inference_mode")
if "inference_mode" in self.task_train_kwargs.keys()
else None
)
if isinstance(self.compile_model_kwargs, dict):
_logger.warning(
"PyTorch 2.0 compile mode is turned on! Pass None to compile_model_kwargs if the behavior is unintended."
)
if self.inference_mode or self.inference_mode is None:
_logger.warning(
"PyTorch 2.0 compile mode does not support inference_mode! Setting Lightning Trainer inference_mode to False!"
)
self.inference_mode = False

@property
def best_epoch(self) -> Optional[float]:
@@ -87,6 +105,27 @@ def _get_callbacks(self, dataset_subsets: Sequence[str]) -> List[Callback]:
callbacks.append(EarlyStopping(**self.early_stopping_kwargs))
return callbacks

def setup_trainer(
self,
run_name: str,
accelerator: Optional[Union[str, Accelerator]] = None,
devices: Optional[Union[List[int], str, int]] = None,
) -> None:
accelerator = accelerator if accelerator else self.task_train_kwargs["accelerator"]
devices = devices if devices else self.task_train_kwargs["devices"]
task_train_kwargs = {
k: v for k, v in self.task_train_kwargs.items() if k not in ("accelerator", "devices")
}
self.trainer = pl.Trainer(
default_root_dir=str(self.output_path),
callbacks=self.callbacks,
logger=self.logging_config.get_lightning_loggers(self.output_path, run_name),
inference_mode=self.inference_mode,
accelerator=accelerator,
devices=devices,
**task_train_kwargs,
)

def fit(
self,
data: LightningDataModule,
@@ -95,31 +134,8 @@ def fit(
if not self.model:
raise self.MODEL_UNDEFINED_EXCEPTION
self.tokenizer = data.tokenizer

callbacks = self._get_callbacks(dataset_subsets=list(data.load_dataset().keys()))

inference_mode = (
self.task_train_kwargs.pop("inference_mode")
if "inference_mode" in self.task_train_kwargs.keys()
else None
)
if isinstance(self.compile_model_kwargs, dict):
_logger.warning(
"PyTorch 2.0 compile mode is turned on! Pass None to compile_model_kwargs if the behavior is unintended."
)
if inference_mode or inference_mode is None:
_logger.warning(
"PyTorch 2.0 compile mode does not support inference_mode! Setting Lightning Trainer inference_mode to False!"
)
inference_mode = False

self.trainer = pl.Trainer(
default_root_dir=str(self.output_path),
callbacks=callbacks,
logger=self.logging_config.get_lightning_loggers(self.output_path, run_name),
inference_mode=inference_mode,
**self.task_train_kwargs,
)
self.callbacks = self._get_callbacks(dataset_subsets=list(data.load_dataset().keys()))
self.setup_trainer(run_name=run_name)
try:
self.trainer.fit(self.model, data)
except Exception as e:
@@ -200,6 +216,10 @@ def fit_predict(
self.fit(data, run_name=run_name)
dataloader = data.get_subset(subset=predict_subset)
assert isinstance(dataloader, DataLoader)
if isinstance(self.trainer.strategy, pl.strategies.ddp.DDPStrategy):
self.setup_trainer(
run_name=run_name, accelerator="gpu", devices=0  # run predict on a single GPU only
)
result = self.predict(dataloader=dataloader)
return result

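
A rough, standalone sketch of the accelerator/devices override that the new setup_trainer method enables: rebuilding a Trainer from stored kwargs while forcing a single device for prediction after DDP training. The task_train_kwargs values and the devices=1 override below are illustrative assumptions, not taken from the task configuration.

from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl

# Illustrative stored kwargs; in the task these come from task_train_kwargs.
task_train_kwargs: Dict[str, Any] = {"accelerator": "auto", "devices": "auto", "max_epochs": 1}

def build_trainer(
    accelerator: Optional[str] = None, devices: Optional[Union[int, str]] = None
) -> pl.Trainer:
    # Fall back to the stored values, drop the overridden keys from the kwargs,
    # then pass accelerator/devices explicitly, mirroring setup_trainer above.
    accelerator = accelerator if accelerator else task_train_kwargs["accelerator"]
    devices = devices if devices else task_train_kwargs["devices"]
    rest = {k: v for k, v in task_train_kwargs.items() if k not in ("accelerator", "devices")}
    return pl.Trainer(accelerator=accelerator, devices=devices, **rest)

# One trainer for (possibly multi-device) training, a fresh single-device one for predict.
train_trainer = build_trainer()
predict_trainer = build_trainer(devices=1)
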
8 changes: 7 additions & 1 deletion embeddings/task/lightning_task/question_answering.py
@@ -64,7 +64,9 @@ def build_task_model(self) -> None:
def predict(self, dataloader: Any, return_names: bool = True) -> Any:
assert self.model is not None
assert self.trainer is not None
return self.trainer.predict(model=self.model, dataloaders=dataloader)
return self.trainer.predict(
model=self.model, dataloaders=dataloader, return_predictions=True
)

@staticmethod
def postprocess_outputs(
@@ -92,6 +94,10 @@ def fit_predict(

dataloader = data.get_subset(subset=predict_subset)
assert isinstance(dataloader, DataLoader)
if isinstance(self.trainer.strategy, pl.strategies.ddp.DDPStrategy):
self.setup_trainer(
run_name=run_name, accelerator="gpu", devices=0  # run predict on a single GPU only
)
model_outputs = self.predict(dataloader=dataloader)
result = self.postprocess_outputs(
model_outputs=model_outputs, data=data, predict_subset=predict_subset
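
Finally, a small, self-contained sketch of what the explicit return_predictions=True on trainer.predict yields: a list with one entry per batch, which postprocess_outputs can then flatten. The ToyModule below and its predict_step are hypothetical, included only to make the call runnable.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical module whose predict_step returns raw logits for each batch.
class ToyModule(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx):
        (features,) = batch
        return self.linear(features)

dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
trainer = pl.Trainer(accelerator="cpu", devices=1, logger=False)

# return_predictions=True makes predict() hand back the collected outputs,
# here a list of two (4, 2) logit tensors, one per batch.
predictions = trainer.predict(model=ToyModule(), dataloaders=dataloader, return_predictions=True)
print(len(predictions), predictions[0].shape)
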
