WIP: Multi-GPU training #306

Draft · wants to merge 28 commits into base: main

Commits (28)
018649d  feat: add/test gathering outputs in lightning predict step  (djaniak, Nov 27, 2023)
0e54b1c  feat: restrict lightning module to only use 1 GPU for predict  (djaniak, Nov 27, 2023)
131b64b  feat: restrict lightning module to only use 1 GPU for predict  (djaniak, Nov 27, 2023)
3b6a5df  feat: restrict lightning module to only use 1 GPU for predict  (djaniak, Nov 27, 2023)
91d2163  fix: restrict lightning module to only use 1 GPU for predict  (djaniak, Nov 27, 2023)
b7b3ed3  fix: restrict lightning module to only use 1 GPU for predict v2  (djaniak, Nov 27, 2023)
7d9cf3f  fix: restrict lightning module to only use 1 GPU for predict v3  (djaniak, Nov 27, 2023)
e7cba38  feat/fix: use custom writer to save distributed predictions  (djaniak, Nov 27, 2023)
90d3128  fix: custom writer  (djaniak, Nov 27, 2023)
d388849  feat: finish custom writer  (djaniak, Nov 27, 2023)
439553b  finish gathering preds in multi-gpu setup  (djaniak, Nov 30, 2023)
6ace8bb  try fix wandb lightning loggers  (djaniak, Nov 30, 2023)
1471dc8  fix reading and sorting preds  (djaniak, Nov 30, 2023)
041d048  fix poetry  (djaniak, Nov 30, 2023)
543306f  try fix preds gpu  (djaniak, Nov 30, 2023)
60273ca  try fix preds gpu  (djaniak, Nov 30, 2023)
8432792  try fix preds gpu  (djaniak, Nov 30, 2023)
77312d1  try fix preds gpu  (djaniak, Nov 30, 2023)
b1a28f9  try fix preds gpu  (djaniak, Nov 30, 2023)
91d76c3  try fix preds gpu  (djaniak, Nov 30, 2023)
bb4895a  try fix preds gpu  (djaniak, Nov 30, 2023)
dcff7fa  try fix preds gpu  (djaniak, Nov 30, 2023)
e6dd22f  try fix preds gpu  (djaniak, Nov 30, 2023)
73dbbff  try fix preds gpu  (djaniak, Nov 30, 2023)
042795f  try fix preds gpu  (djaniak, Nov 30, 2023)
4d0f5c5  try fix preds gpu  (djaniak, Nov 30, 2023)
9da5d26  add profiler  (djaniak, Nov 30, 2023)
90b406e  fix single gpu training  (djaniak, Nov 30, 2023)
73 changes: 60 additions & 13 deletions embeddings/model/lightning_module/lightning_module.py
@@ -1,5 +1,7 @@
 import abc
 import inspect
+import os
+import pickle
 from inspect import signature
 from typing import Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar
@@ -17,6 +19,7 @@
 
 from embeddings.data.datamodule import HuggingFaceDataset
 from embeddings.utils.loggers import get_logger
+from embeddings.utils.utils import flatten
 
 Model = TypeVar("Model")
 
@@ -66,31 +69,75 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
     def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         pass
 
-    def predict_step(self, *args: Any, **kwargs: Any) -> Optional[Tuple[STEP_OUTPUT, STEP_OUTPUT]]:
+    def predict_step(
+        self, *args: Any, **kwargs: Any
+    ) -> Optional[Tuple[STEP_OUTPUT, STEP_OUTPUT, STEP_OUTPUT]]:
         batch, batch_idx = args
         loss, logits, preds = self.shared_step(**batch)
-        return logits, preds
+        labels = batch.get("labels", None)
+        return logits, preds, labels
 
     def predict(
-        self, dataloader: DataLoader[HuggingFaceDataset]
+        self, dataloader: DataLoader[HuggingFaceDataset], predpath: str
     ) -> Dict[str, nptyping.NDArray[Any]]:
-        predict_output = self._predict_with_trainer(dataloader)
-        assert predict_output
-        logits, predictions = zip(*predict_output)
-        probabilities = softmax(torch.cat(logits), dim=1).numpy()
-        predictions = torch.cat(predictions).numpy()
-        ground_truth = torch.cat([x["labels"] for x in dataloader]).numpy()
-        result = {"y_pred": predictions, "y_true": ground_truth, "y_probabilities": probabilities}
+        assert self.trainer is not None
+        if self.trainer.num_devices <= 1:
+            return_predictions = True
+        else:
+            return_predictions = False
+
+        predictions = self._predict_with_trainer(dataloader, return_predictions=return_predictions)
+
+        if return_predictions:
+            assert predictions is not None
+            logits, preds, labels = zip(*predictions)
+            probabilities = softmax(torch.cat(logits), dim=1)
+            preds = torch.cat(preds)
+            labels = torch.cat(labels)
+            # labels = torch.cat([x["labels"] for x in dataloader])
+        else:
+            files = sorted(os.listdir(predpath))
+            all_preds = []
+            all_logits = []
+            all_labels = []
+            # all_batch_indices = []
+            for file in files:
+                if "predictions" in file:
+                    with open(os.path.join(predpath, file), "rb") as f:
+                        predictions = pickle.load(f)
+                    logits, preds, labels = zip(*predictions)
+                    all_logits.append(torch.cat(logits))
+                    all_preds.append(torch.cat(preds))
+                    all_labels.append(torch.cat(labels))
+                # elif "batch_indices" in file:
+                #     with open(os.path.join(predpath, file), "rb") as f:
+                #         batch_indices = pickle.load(f)
+                #     all_batch_indices.append(list(flatten(batch_indices)))
+            # all_batch_indices = torch.Tensor([y for x in all_batch_indices for y in x]).long()
+            probabilities = softmax(torch.cat(all_logits), dim=1)
+            preds = torch.cat(all_preds)
+            labels = torch.cat(all_labels)
+
+        result = {
+            "y_pred": preds.numpy(),
+            "y_true": labels.numpy(),
+            "y_probabilities": probabilities.numpy(),
+        }
+
         assert all(isinstance(x, np.ndarray) for x in result.values())
         return result
 
     def _predict_with_trainer(
-        self, dataloader: DataLoader[HuggingFaceDataset]
+        self, dataloader: DataLoader[HuggingFaceDataset], return_predictions: bool
    ) -> Optional[_PREDICT_OUTPUT]:
         assert self.trainer is not None
 
         try:
             return self.trainer.predict(
-                model=self, dataloaders=dataloader, return_predictions=True, ckpt_path="last"
+                model=self,
+                dataloaders=dataloader,
+                return_predictions=return_predictions,
+                ckpt_path="last",
             )
         except MisconfigurationException:  # model loaded but not fitted
             _logger.warning(
@@ -99,7 +146,7 @@ def _predict_with_trainer(
             return self.trainer.predict(
                 model=self,
                 dataloaders=dataloader,
-                return_predictions=True,
+                return_predictions=return_predictions,
             )
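For readers following the multi-GPU branch above: with more than one device, `trainer.predict(..., return_predictions=False)` avoids materializing every rank's outputs in memory, and `predict()` instead merges the per-rank pickles written into `predpath`. A minimal standalone sketch of that merge step, assuming each rank wrote a `predictions_<rank>.pkl` containing a list of `(logits, preds, labels)` batch tuples (as the `CustomPredictionWriter` below does):

```python
import os
import pickle
from typing import Any, Dict

import torch
from torch.nn.functional import softmax


def merge_rank_predictions(predpath: str) -> Dict[str, Any]:
    # Collect every rank's pickled batches and concatenate them.
    all_logits, all_preds, all_labels = [], [], []
    for file in sorted(os.listdir(predpath)):
        if "predictions" not in file:
            continue
        with open(os.path.join(predpath, file), "rb") as f:
            batches = pickle.load(f)  # list of (logits, preds, labels) tuples
        logits, preds, labels = zip(*batches)
        all_logits.append(torch.cat(logits))
        all_preds.append(torch.cat(preds))
        all_labels.append(torch.cat(labels))
    logits = torch.cat(all_logits)
    return {
        "y_pred": torch.cat(all_preds).numpy(),
        "y_true": torch.cat(all_labels).numpy(),
        "y_probabilities": softmax(logits, dim=1).numpy(),
    }
```

Note that under a `DistributedSampler` the concatenated order interleaves examples across ranks, which is presumably what the commented-out `batch_indices` handling is meant to undo (roughly `order = torch.argsort(indices); preds = preds[order]`).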
21 changes: 15 additions & 6 deletions embeddings/pipeline/lightning_pipeline.py
@@ -8,7 +8,7 @@
 from embeddings.evaluator.evaluator import Evaluator
 from embeddings.model.model import Model
 from embeddings.pipeline.pipeline import Pipeline
-from embeddings.utils.loggers import LightningLoggingConfig, WandbWrapper
+from embeddings.utils.loggers import LightningLoggingConfig, LightningWandbWrapper
 from embeddings.utils.utils import get_installed_packages, standardize_name
 
 EvaluationResult = TypeVar("EvaluationResult")
@@ -46,25 +46,34 @@ def __init__(
         self.pipeline_kwargs = pipeline_kwargs
         self.pipeline_kwargs.pop("self")
         self.pipeline_kwargs.pop("pipeline_kwargs")
+        self.result: Optional[EvaluationResult] = None
 
     def run(self, run_name: Optional[str] = None) -> EvaluationResult:
         if run_name:
             run_name = standardize_name(run_name)
         self._save_artifacts()
         model_result = self.model.execute(data=self.datamodule, run_name=run_name)
-        result = self.evaluator.evaluate(model_result)
+        self.result = self.evaluator.evaluate(model_result)
+        self._save_metrics()
         self._finish_logging()
-        return result
+        return self.result
 
     def _save_artifacts(self) -> None:
         srsly.write_json(self.output_path / "packages.json", get_installed_packages())
         with open(self.output_path / "pipeline_config.yaml", "w") as f:
             yaml.dump(self.pipeline_kwargs, stream=f)
 
+    def _save_metrics(self) -> None:
+        metrics = getattr(self.result, "metrics")
+        with open(self.output_path / "metrics.yaml", "w") as f:
+            yaml.dump(metrics, stream=f)
+
     def _finish_logging(self) -> None:
         if self.logging_config.use_wandb():
-            logger = WandbWrapper()
-            logger.log_output(
+            wrapper = LightningWandbWrapper(self.logging_config)
+            wrapper.log_output(
                 self.output_path, ignore={"wandb", "csv", "tensorboard", "checkpoints"}
             )
-            logger.finish_logging()
+            metrics = getattr(self.result, "metrics")
+            wrapper.log_metrics(metrics)
+            wrapper.finish_logging()
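The new `_save_metrics()` step persists evaluation metrics alongside the other run artifacts. A small round-trip sketch of that idea (the file name matches the diff; the metric values are made up):

```python
import yaml

metrics = {"accuracy": 0.93, "f1_macro": 0.91}  # placeholder values

# What _save_metrics() does, in isolation:
with open("metrics.yaml", "w") as f:
    yaml.dump(metrics, stream=f)

# A finished run can then be inspected without re-running evaluation:
with open("metrics.yaml") as f:
    assert yaml.safe_load(f) == metrics
```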
26 changes: 20 additions & 6 deletions embeddings/task/lightning_task/lightning_task.py
@@ -1,10 +1,13 @@
 import abc
+import os
 from pathlib import Path
 from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+import torch
+from pytorch_lightning.callbacks import BasePredictionWriter, Callback, ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.profilers import AdvancedProfiler, PyTorchProfiler
 from torch.utils.data import DataLoader
 from transformers import AutoModel, AutoTokenizer
 
@@ -18,6 +21,7 @@
 from embeddings.task.lightning_task.hf_task import HuggingFaceTaskName
 from embeddings.task.task import Output, Task
 from embeddings.utils.lightning_callbacks.best_epoch_callback import BestEpochCallback
+from embeddings.utils.lightning_callbacks.custom_prediction_writer import CustomPredictionWriter
 from embeddings.utils.loggers import LightningLoggingConfig, get_logger
 from embeddings.utils.torch_utils import cleanup_torch_model_artifacts
 
@@ -76,10 +80,12 @@ def best_validation_score(self) -> Optional[float]:
             return None
 
     def _get_callbacks(self, dataset_subsets: Sequence[str]) -> List[Callback]:
+        self.predpath = self.output_path.joinpath("predictions")
+        self.predpath.mkdir(parents=False, exist_ok=True)
+        dirpath = self.output_path.joinpath("checkpoints")
         callbacks: List[Callback] = [
-            ModelCheckpoint(
-                dirpath=self.output_path.joinpath("checkpoints"), **self.model_checkpoint_kwargs
-            )
+            ModelCheckpoint(dirpath=dirpath, **self.model_checkpoint_kwargs),
+            CustomPredictionWriter(output_dir=str(self.predpath), write_interval="epoch"),
         ]
         if "validation" in dataset_subsets:
             callbacks.append(BestEpochCallback())
@@ -112,12 +118,20 @@ def fit(
                 "PyTorch 2.0 compile mode does not support inference_mode! Setting Lightning Trainer inference_mode to False!"
             )
             inference_mode = False
+
+        profiler_kwarg = self.task_train_kwargs.pop("profiler")
+        if profiler_kwarg == "pytorch":
+            profiler_dirpath = self.output_path / "profiler_logs"
+            profiler_dirpath.mkdir(exist_ok=True, parents=False)
+            profiler = PyTorchProfiler(dirpath=profiler_dirpath, filename="perf_logs")
+        else:
+            profiler = None
+            # profiler = AdvancedProfiler(dirpath=str(self.output_path), filename="perf_logs")
         self.trainer = pl.Trainer(
             default_root_dir=str(self.output_path),
             callbacks=callbacks,
-            logger=self.logging_config.get_lightning_loggers(self.output_path, run_name),
+            logger=self.logging_config.get_lightning_loggers(run_name),
             inference_mode=inference_mode,
+            profiler=profiler,
             **self.task_train_kwargs,
         )
         try:
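For context, a hedged sketch of how the new `profiler` option translates into a Lightning `Trainer` (the `"pytorch"` value and the `profiler_logs`/`perf_logs` names mirror the diff; the output directory and epoch count are examples):

```python
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.profilers import PyTorchProfiler

output_path = Path("output")  # example output directory
profiler_dirpath = output_path / "profiler_logs"
profiler_dirpath.mkdir(exist_ok=True, parents=True)

# task_train_kwargs={"profiler": "pytorch"} ends up as roughly:
trainer = pl.Trainer(
    default_root_dir=str(output_path),
    profiler=PyTorchProfiler(dirpath=profiler_dirpath, filename="perf_logs"),
    max_epochs=1,
)
```

One caveat worth flagging: `self.task_train_kwargs.pop("profiler")` without a default raises `KeyError` when the key is absent, so every caller must now pass `profiler`; `pop("profiler", None)` would be the defensive variant.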
2 changes: 1 addition & 1 deletion embeddings/task/lightning_task/text_classification.py
@@ -55,7 +55,7 @@ def build_task_model(self) -> None:
 
     def predict(self, dataloader: DataLoader[Any], return_names: bool = True) -> Predictions:
         assert self.model is not None
-        results = self.model.predict(dataloader=dataloader)
+        results = self.model.predict(dataloader=dataloader, predpath=str(self.predpath))
         results["names"] = np.array(self.model.target_names)
         return Predictions(**results)
 
23 changes: 23 additions & 0 deletions embeddings/utils/lightning_callbacks/custom_prediction_writer.py
@@ -0,0 +1,23 @@
+import os
+import pickle
+
+from pytorch_lightning.callbacks import BasePredictionWriter
+
+
+class CustomPredictionWriter(BasePredictionWriter):
+    def __init__(self, output_dir, write_interval):
+        super().__init__(write_interval)
+        self.output_dir = output_dir
+
+    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
+        # this creates N (num processes) files in `output_dir`, each containing
+        # the predictions of its respective rank
+        predpath = os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pkl")
+        with open(predpath, "wb") as f:
+            pickle.dump(predictions, f)
+
+        # optionally, `batch_indices` can also be saved to recover the dataset
+        # index of each prediction
+        idxpath = os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pkl")
+        with open(idxpath, "wb") as f:
+            pickle.dump(batch_indices, f)
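A usage sketch for the writer; `model` (a LightningModule) and `dataloader` are assumed to exist, and with `devices=2` each rank writes its own `predictions_<rank>.pkl` and `batch_indices_<rank>.pkl` into `output_dir`:

```python
import pytorch_lightning as pl

# `model` and `dataloader` are assumed to exist.
writer = CustomPredictionWriter(output_dir="predictions", write_interval="epoch")
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", callbacks=[writer])

# return_predictions=False keeps outputs off the host process;
# the callback persists them to disk instead.
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
```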
76 changes: 51 additions & 25 deletions embeddings/utils/loggers.py
@@ -7,6 +7,7 @@
 
 import wandb
 from pytorch_lightning import loggers as pl_loggers
+from pytorch_lightning.loggers.wandb import WandbLogger
 from typing_extensions import Literal
 
 from embeddings.data.io import T_path
@@ -30,10 +31,12 @@ def get_logger(name: str, log_level: Union[str, int] = DEFAULT_LOG_LEVEL) -> logging.Logger:
 
 @dataclass
 class LightningLoggingConfig:
+    output_path: Union[Path, str] = "."
     loggers_names: List[Literal["wandb", "csv", "tensorboard"]] = field(default_factory=list)
     tracking_project_name: Optional[str] = None
     wandb_entity: Optional[str] = None
     wandb_logger_kwargs: Dict[str, Any] = field(default_factory=dict)
+    loggers: Optional[Dict[str, pl_loggers.Logger]] = field(init=False, default=None)
 
     def __post_init__(self) -> None:
         if "wandb" not in self.loggers_names and (
@@ -80,48 +83,41 @@ def use_tensorboard(self) -> bool:
 
     def get_lightning_loggers(
         self,
-        output_path: T_path,
         run_name: Optional[str] = None,
     ) -> List[pl_loggers.Logger]:
         """Based on configuration, provides pytorch-lightning loggers' callbacks."""
-        output_path = Path(output_path)
-        loggers: List[pl_loggers.Logger] = []
+        if not self.loggers:
+            self.output_path = Path(self.output_path)
+            self.loggers = {}
 
-        if self.use_tensorboard():
-            loggers.append(
-                pl_loggers.TensorBoardLogger(
+            if self.use_tensorboard():
+                self.loggers["tensorboard"] = pl_loggers.TensorBoardLogger(
                     name=run_name,
-                    save_dir=str(output_path.joinpath("tensorboard")),
+                    save_dir=str(self.output_path / "tensorboard"),
                 )
-            )
 
-        if self.use_wandb():
-            if not self.tracking_project_name:
-                raise ValueError(
-                    "Tracking project name is not passed. Pass tracking_project_name argument!"
-                )
-            save_dir = output_path.joinpath("wandb")
-            save_dir.mkdir(exist_ok=True)
-            loggers.append(
-                pl_loggers.wandb.WandbLogger(
+            if self.use_wandb():
+                if not self.tracking_project_name:
+                    raise ValueError(
+                        "Tracking project name is not passed. Pass tracking_project_name argument!"
+                    )
+                save_dir = self.output_path / "wandb"
+                save_dir.mkdir(exist_ok=True, parents=True)
+                self.loggers["wandb"] = pl_loggers.wandb.WandbLogger(
                     name=run_name,
                     save_dir=str(save_dir),
                     project=self.tracking_project_name,
                     entity=self.wandb_entity,
                     reinit=True,
                     **self.wandb_logger_kwargs
                 )
-            )
 
-        if self.use_csv():
-            loggers.append(
-                pl_loggers.CSVLogger(
+            if self.use_csv():
+                self.loggers["csv"] = pl_loggers.CSVLogger(
                     name=run_name if run_name else "",
-                    save_dir=str(output_path.joinpath("csv")),
+                    save_dir=self.output_path / "csv",
                 )
-            )
 
-        return loggers
+        return list(self.loggers.values())
 
 
 class ExperimentLogger(abc.ABC):
@@ -170,3 +166,33 @@ def log_artifact(self, paths: Iterable[T_path], artifact_name: str, artifact_type: str) -> None:
         for path in paths:
             artifact.add_file(path)
         wandb.log_artifact(artifact)
+
+
+class LightningWandbWrapper:
+    def __init__(self, logging_config: LightningLoggingConfig) -> None:
+        assert logging_config.use_wandb()
+        assert isinstance(logging_config.loggers, dict)
+        assert "wandb" in logging_config.loggers
+        assert isinstance(logging_config.loggers["wandb"], WandbLogger)
+        self.wandb_logger: WandbLogger = logging_config.loggers["wandb"]
+
+    def log_output(
+        self,
+        output_path: T_path,
+        ignore: Optional[Iterable[str]] = None,
+    ) -> None:
+        for entry in os.scandir(output_path):
+            if not ignore or entry.name not in ignore:
+                self.wandb_logger.experiment.save(entry.path, output_path)
+
+    def log_metrics(self, metrics: Dict[str, Any]) -> None:
+        self.wandb_logger.log_metrics(metrics)
+
+    def finish_logging(self) -> None:
+        self.wandb_logger.experiment.finish()
+
+    def log_artifact(self, paths: Iterable[T_path], artifact_name: str, artifact_type: str) -> None:
+        artifact = wandb.Artifact(name=artifact_name, type=artifact_type)
+        for path in paths:
+            artifact.add_file(path)
+        self.wandb_logger.experiment.log_artifact(artifact)
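How the pieces fit together, as a sketch (the project and run names are placeholders, and an active wandb login is assumed): `get_lightning_loggers()` must run first so that `logging_config.loggers["wandb"]` exists, which is exactly what the asserts in `__init__` enforce:

```python
config = LightningLoggingConfig(
    output_path="output",
    loggers_names=["wandb"],
    tracking_project_name="my-project",  # placeholder
)
loggers = config.get_lightning_loggers(run_name="run-1")  # populates config.loggers

wrapper = LightningWandbWrapper(config)
wrapper.log_metrics({"f1": 0.91})  # example value
wrapper.log_output("output", ignore={"wandb", "checkpoints"})
wrapper.finish_logging()
```

Caching the loggers on the config is what lets the pipeline reuse the same `WandbLogger` after training instead of opening a second wandb run.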
9 changes: 9 additions & 0 deletions embeddings/utils/utils.py
@@ -3,6 +3,7 @@
 import os.path
 import pprint
 import zipfile
+from collections.abc import Iterable
 from datetime import datetime
 from pathlib import Path
 from tempfile import NamedTemporaryFile
@@ -152,3 +153,11 @@ def compress_and_remove(filepath: T_path) -> None:
         ) as arc:
             arc.write(filepath, arcname=filepath.name)
         filepath.unlink()
+
+
+def flatten(xs: Iterable[Any]):
+    for x in xs:
+        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
+            yield from flatten(x)
+        else:
+            yield x
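The new `flatten` helper recurses through arbitrarily nested iterables while treating strings and bytes as atoms; this is what the commented-out `batch_indices` post-processing in `lightning_module.py` relies on. For example:

```python
nested = [[0, 1], [2, [3, [4]]], "ab"]
assert list(flatten(nested)) == [0, 1, 2, 3, 4, "ab"]
```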