-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
callback_config.py
128 lines (110 loc) · 5.05 KB
/
callback_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from abc import ABC, abstractmethod
from typing import Union, List
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class TrainerCallbackConfigMixin(ABC):
    """Trainer mixin that resolves checkpoint, early-stopping and progress-bar callbacks.

    The attributes annotated below are read and/or written by this mixin; the
    concrete class it is mixed into is responsible for initialising them.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    callbacks: List[Callback]
    default_root_dir: str
    logger: Union[LightningLoggerBase, bool]
    weights_save_path: str
    ckpt_path: str
    checkpoint_callback: ModelCheckpoint
    progress_bar_refresh_rate: int
    process_position: int

    @property
    @abstractmethod
    def slurm_job_id(self) -> int:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def save_checkpoint(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def configure_checkpoint_callback(self):
        """Resolve ``self.checkpoint_callback`` and the directory checkpoints are written to.

        ``self.checkpoint_callback`` is interpreted as:

        * ``True``: a default :class:`ModelCheckpoint` is built, monitoring
          ``'loss'`` when ``validation_step`` is not overridden and ``'val_loss'``
          otherwise. Its directory is created.
        * a :class:`ModelCheckpoint` whose ``dirpath`` is ``None``: it receives the
          default directory below and the filename ``'{epoch}'``, and the
          directory is created.
        * a :class:`ModelCheckpoint` with its own ``dirpath``: used as-is, so its
          path takes precedence over everything else.
        * ``False``: checkpointing is disabled (the attribute becomes ``None``).

        The default checkpoint base directory is chosen in this priority:

        1. ``weights_save_path`` if the user provided one;
        2. the logger's ``save_dir`` (or ``_save_dir``) when a logger is set;
        3. ``default_root_dir``.

        With a logger, ``<logger name>/<version>/checkpoints`` is appended to the
        base directory. Without one, only ``checkpoints`` is appended.

        Side effects: sets ``self.checkpoint_callback``, ``self.ckpt_path`` and
        ``self.weights_save_path`` and may create directories on disk.
        """
        ckpt_path = self.default_root_dir
        if self.checkpoint_callback:
            # weights_save_path overrides any other base directory
            user_save_dir = self.weights_save_path
            if self.logger is not None:
                save_dir = (getattr(self.logger, 'save_dir', None) or
                            getattr(self.logger, '_save_dir', None) or
                            self.default_root_dir)
                if user_save_dir is not None:
                    save_dir = user_save_dir

                # read once: `version` may be a computed property on the logger
                raw_version = self.logger.version
                if isinstance(raw_version, str):
                    version = raw_version
                else:
                    version = f'version_{raw_version}'
                ckpt_path = os.path.join(
                    save_dir,
                    self.logger.name,
                    version,
                    "checkpoints"
                )
            else:
                # honour weights_save_path without a logger too, as documented above
                base_dir = user_save_dir if user_save_dir is not None else self.default_root_dir
                ckpt_path = os.path.join(base_dir, "checkpoints")

            # when no val step is defined, use 'loss' otherwise 'val_loss'
            # (`is_overridden` is expected to come from another Trainer mixin — not visible here)
            train_step_only = not self.is_overridden('validation_step')
            monitor_key = 'loss' if train_step_only else 'val_loss'

            if self.checkpoint_callback is True:
                os.makedirs(ckpt_path, exist_ok=True)
                self.checkpoint_callback = ModelCheckpoint(
                    filepath=ckpt_path,
                    monitor=monitor_key
                )
            # If user specified None in filepath, override with runtime default
            elif (isinstance(self.checkpoint_callback, ModelCheckpoint)
                  and self.checkpoint_callback.dirpath is None):
                self.checkpoint_callback.dirpath = ckpt_path
                self.checkpoint_callback.filename = '{epoch}'
                os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
        elif self.checkpoint_callback is False:
            self.checkpoint_callback = None

        self.ckpt_path = ckpt_path

        if self.checkpoint_callback:
            # set the path for the callbacks
            self.checkpoint_callback.save_function = self.save_checkpoint

            # if checkpoint callback used, then override the weights path
            self.weights_save_path = self.checkpoint_callback.dirpath

        # if weights_save_path is still None here, fall back to default_root_dir
        if self.weights_save_path is None:
            self.weights_save_path = self.default_root_dir

    def configure_early_stopping(self, early_stop_callback):
        """Resolve the early-stopping callback.

        Args:
            early_stop_callback: ``True`` builds a default :class:`EarlyStopping`
                monitoring ``'val_loss'`` (patience 3, strict, verbose, mode
                ``'min'``). Any falsy value (``False``, ``None``, ...) disables
                early stopping. Any other value is used as the callback unchanged.

        Sets ``self.early_stop_callback`` and ``self.enable_early_stop``.
        """
        # Only the literal True selects the default callback; `None` is falsy and is
        # handled by the "disabled" branch below. (The previous `is True or None`
        # evaluated `or None` as a dead operand, with this same behaviour.)
        if early_stop_callback is True:
            self.early_stop_callback = EarlyStopping(
                monitor='val_loss',
                patience=3,
                strict=True,
                verbose=True,
                mode='min'
            )
            self.enable_early_stop = True
        elif not early_stop_callback:
            self.early_stop_callback = None
            self.enable_early_stop = False
        else:
            self.early_stop_callback = early_stop_callback
            self.enable_early_stop = True

    def configure_progress_bar(self):
        """Select the single progress-bar callback, creating a default one if needed.

        If exactly one :class:`ProgressBarBase` is already in ``self.callbacks``, it is
        used. Otherwise, when ``progress_bar_refresh_rate > 0``, a default
        :class:`ProgressBar` is built and appended to ``self.callbacks``. In every
        other case no progress bar is used.

        Raises:
            MisconfigurationException: if more than one progress-bar callback
                was registered.
        """
        progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
        if len(progress_bars) > 1:
            raise MisconfigurationException(
                'You added multiple progress bar callbacks to the Trainer, but currently only one'
                ' progress bar is supported.'
            )
        elif len(progress_bars) == 1:
            self.progress_bar_callback = progress_bars[0]
        elif self.progress_bar_refresh_rate > 0:
            self.progress_bar_callback = ProgressBar(
                refresh_rate=self.progress_bar_refresh_rate,
                process_position=self.process_position,
            )
            self.callbacks.append(self.progress_bar_callback)
        else:
            self.progress_bar_callback = None