Skip to content

Commit

Permalink
add profiler
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Nov 30, 2023
1 parent 4d0f5c5 commit 9da5d26
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions embeddings/task/lightning_task/lightning_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from pytorch_lightning.callbacks import BasePredictionWriter, Callback, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.profilers import AdvancedProfiler, PyTorchProfiler
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

Expand Down Expand Up @@ -117,11 +118,20 @@ def fit(
"PyTorch 2.0 compile mode does not support inference_mode! Setting Lightning Trainer inference_mode to False!"
)
inference_mode = False
profiler_kwarg = self.task_train_kwargs.pop("profiler")
if profiler_kwarg == "pytorch":
profiler_dirpath = self.output_path / "profiler_logs"
profiler_dirpath.mkdir(exist_ok=True, parents=False)
profiler = PyTorchProfiler(dirpath=profiler_dirpath, filename="perf_logs")
else:
profiler = None
# profiler = AdvancedProfiler(dirpath=str(self.output_path), filename="perf_logs")
self.trainer = pl.Trainer(
default_root_dir=str(self.output_path),
callbacks=callbacks,
logger=self.logging_config.get_lightning_loggers(run_name),
inference_mode=inference_mode,
profiler=profiler,
**self.task_train_kwargs,
)
try:
Expand Down

0 comments on commit 9da5d26

Please sign in to comment.