Simplify variables: step, epoch, max_epochs, min_epochs (#589)
elliotwaite authored and williamFalcon committed Dec 7, 2019
1 parent c6e0dbe commit 1051c18
Showing 18 changed files with 90 additions and 88 deletions.
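For orientation, here is a hedged before/after summary of the renames this commit makes (my own sketch, not part of the diff):

```python
from pytorch_lightning import Trainer

# Old spellings replaced by this commit:
#   Trainer(max_num_epochs=1000, min_num_epochs=1)
#   LightningLoggerBase.log_metrics(metrics, step_idx=...)
#   LightningModule.optimizer_step(epoch_idx, batch_idx, optimizer, optimizer_idx)

# New spellings introduced here:
trainer = Trainer(max_epochs=1000, min_epochs=1)
#   LightningLoggerBase.log_metrics(metrics, step=...)
#   LightningModule.optimizer_step(epoch, batch_idx, optimizer, optimizer_idx)
```

The still-older `max_nb_epochs` / `min_nb_epochs` names keep their deprecation shim and, per the trainer diff below, are slated for removal in v0.8.0.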
8 changes: 4 additions & 4 deletions README.md
@@ -154,16 +154,16 @@ use something other than tensorboard).
Here are more advanced examples
```python
# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(max_num_epochs=1, train_percent_check=0.1)
trainer = Trainer(max_epochs=1, train_percent_check=0.1)

# train on 4 gpus (lightning chooses GPUs for you)
# trainer = Trainer(max_num_epochs=1, gpus=4, distributed_backend='ddp')
# trainer = Trainer(max_epochs=1, gpus=4, distributed_backend='ddp')

# train on 4 gpus (you choose GPUs)
# trainer = Trainer(max_num_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')
# trainer = Trainer(max_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')

# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
# trainer = Trainer(max_num_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')
# trainer = Trainer(max_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')

# train (1 epoch only here for demo)
trainer.fit(model)
2 changes: 1 addition & 1 deletion pl_examples/full_examples/imagenet/imagenet_example.py
@@ -234,7 +234,7 @@ def main(hparams):
trainer = pl.Trainer(
default_save_path=hparams.save_path,
gpus=hparams.gpus,
max_num_epochs=hparams.epochs,
max_epochs=hparams.epochs,
distributed_backend=hparams.distributed_backend,
use_amp=hparams.use_16bit
)
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -694,10 +694,10 @@ def configure_optimizers(self):
"""
raise NotImplementedError

def optimizer_step(self, epoch_idx, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
"""Do something instead of the standard optimizer behavior
:param int epoch_idx:
:param int epoch:
:param int batch_idx:
:param optimizer:
:param optimizer_idx:
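To illustrate the renamed hook argument, a minimal override sketch (my own example, assuming the `LightningModule` API at this commit; other required methods omitted):

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # ... training_step, configure_optimizers, etc. omitted for brevity ...

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None):
        # `epoch` replaces the old `epoch_idx` argument name.
        if second_order_closure is not None:
            optimizer.step(second_order_closure)  # e.g. LBFGS passes a closure
        else:
            optimizer.step()
        optimizer.zero_grad()
```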
2 changes: 1 addition & 1 deletion pytorch_lightning/logging/__init__.py
@@ -52,7 +52,7 @@ def log_hyperparams(self, params):
pass
@rank_zero_only
def log_metrics(self, metrics, step_idx):
def log_metrics(self, metrics, step):
# metrics is a dictionary of metric names and values
# your code to record metrics goes here
pass
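Expanding the docstring excerpt above into a fuller sketch: a hypothetical `PrintLogger`, assuming `LightningLoggerBase` and `rank_zero_only` are exposed as shown in this commit:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only


class PrintLogger(LightningLoggerBase):
    """Toy logger that prints metrics instead of sending them anywhere."""

    @rank_zero_only
    def log_hyperparams(self, params):
        print(f'hparams: {vars(params)}')

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        # `step` replaces the old `step_idx` keyword.
        print(f'step {step}: {metrics}')

    def save(self):
        pass

    @rank_zero_only
    def finalize(self, status):
        print(f'finished with status: {status}')


# During fit(), the trainer forwards scalar metrics as
# logger.log_metrics(scalar_metrics, step=self.global_step), from rank 0 only.
# trainer = Trainer(logger=PrintLogger(), max_epochs=1, train_percent_check=0.1)
# trainer.fit(model)  # given some LightningModule instance `model`
```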
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/base.py
@@ -21,11 +21,11 @@ class LightningLoggerBase(object):
def __init__(self):
self._rank = 0

def log_metrics(self, metrics, step_idx):
def log_metrics(self, metrics, step):
"""Record metrics.
:param float metric: Dictionary with metric names as keys and measured quantities as values
:param int|None step_idx: Step number at which the metrics should be recorded
:param int|None step: Step number at which the metrics should be recorded
"""
raise NotImplementedError()

4 changes: 2 additions & 2 deletions pytorch_lightning/logging/comet.py
@@ -145,13 +145,13 @@ def log_hyperparams(self, params):
self.experiment.log_parameters(vars(params))

@rank_zero_only
def log_metrics(self, metrics, step_idx=None):
def log_metrics(self, metrics, step=None):
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
metrics[key] = val.cpu().detach()

self.experiment.log_metrics(metrics, step=step_idx)
self.experiment.log_metrics(metrics, step=step)

@rank_zero_only
def finalize(self, status):
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/mlflow.py
@@ -68,15 +68,15 @@ def log_hyperparams(self, params):
self.experiment.log_param(self.run_id, k, v)

@rank_zero_only
def log_metrics(self, metrics, step_idx=None):
def log_metrics(self, metrics, step=None):
timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
if isinstance(v, str):
logger.warning(
f"Discarding metric with string value {k}={v}"
)
continue
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step_idx)
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

def save(self):
pass
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/test_tube.py
@@ -76,10 +76,10 @@ def log_hyperparams(self, params):
self.experiment.argparse(params)

@rank_zero_only
def log_metrics(self, metrics, step_idx=None):
def log_metrics(self, metrics, step=None):
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
self.experiment.log(metrics, global_step=step_idx)
self.experiment.log(metrics, global_step=step)

@rank_zero_only
def save(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
@@ -43,7 +43,7 @@ def log_metrics(self, metrics, grad_norm_dic):

# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step_idx=self.global_step)
self.logger.log_metrics(scalar_metrics, step=self.global_step)
self.logger.save()

def add_tqdm_metrics(self, metrics):
26 changes: 13 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -72,8 +72,8 @@ def __init__(
accumulate_grad_batches=1,
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
max_num_epochs=1000,
min_num_epochs=1,
max_epochs=1000,
min_epochs=1,
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0,
@@ -111,8 +111,8 @@ def __init__(
:param int check_val_every_n_epoch: check val every n train epochs
:param bool fast_dev_run: runs full iteration over everything to find bugs
:param int accumulate_grad_batches: Accumulates grads every k batches
:param int max_num_epochs:
:param int min_num_epochs:
:param int max_epochs:
:param int min_epochs:
:param int train_percent_check: How much of train set to check
:param int val_percent_check: How much of val set to check
:param int test_percent_check: How much of test set to check
@@ -158,17 +158,17 @@ def __init__(
self.process_position = process_position
self.weights_summary = weights_summary
if max_nb_epochs is not None: # Backward compatibility
warnings.warn("`max_nb_epochs` has renamed to `max_num_epochs` since v0.5.0"
warnings.warn("`max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not max_num_epochs: # in case you did not set the proper value
max_num_epochs = max_nb_epochs
self.max_num_epochs = max_num_epochs
if not max_epochs: # in case you did not set the proper value
max_epochs = max_nb_epochs
self.max_epochs = max_epochs
if min_nb_epochs is not None: # Backward compatibility
warnings.warn("`min_nb_epochs` has renamed to `min_num_epochs` since v0.5.0"
warnings.warn("`min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not min_num_epochs: # in case you did not set the proper value
min_num_epochs = min_nb_epochs
self.min_num_epochs = min_num_epochs
if not min_epochs: # in case you did not set the proper value
min_epochs = min_nb_epochs
self.min_epochs = min_epochs
if nb_sanity_val_steps is not None: # Backward compatibility
warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
@@ -183,7 +183,7 @@ def __init__(
self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
self.num_sanity_val_steps = 1
self.max_num_epochs = 1
self.max_epochs = 1
m = '''
Running in fast_dev_run mode: will run a full train,
val loop using a single batch
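As a sanity check of the backward-compatibility branch above (my own sketch, assuming a default `Trainer` can be constructed standalone), the deprecated spelling should still only warn:

```python
import warnings

from pytorch_lightning import Trainer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    Trainer(max_nb_epochs=5)  # deprecated name; prefer Trainer(max_epochs=5)

# The shim above issues a DeprecationWarning pointing at max_epochs.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```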
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -23,7 +23,7 @@
.. code-block:: python
# DEFAULT
trainer = Trainer(min_num_epochs=1, max_num_epochs=1000)
trainer = Trainer(min_epochs=1, max_epochs=1000)
Early stopping
--------------
@@ -259,17 +259,17 @@ def process_output(self, output, train):

def train(self):
# run all epochs
for epoch_idx in range(self.current_epoch, self.max_num_epochs):
for epoch in range(self.current_epoch, self.max_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch_idx)
self.get_train_dataloader().sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch_idx
self.current_epoch = epoch_idx
model.current_epoch = epoch
self.current_epoch = epoch

# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
@@ -294,11 +294,11 @@ def train(self):
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch_idx + 1}' if not self.is_iterable_train_dataloader else ''
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin(epoch_idx, self)
self.accumulation_scheduler.on_epoch_begin(epoch, self)

# -----------------
# RUN TNG EPOCH
@@ -319,9 +319,9 @@ def train(self):
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

# early stopping
met_min_epochs = epoch_idx > self.min_num_epochs
met_min_epochs = epoch > self.min_epochs
if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_idx,
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
logs=self.callback_metrics)
# stop training
stop = should_stop and met_min_epochs
6 changes: 4 additions & 2 deletions pytorch_lightning/utilities/arg_parse.py
@@ -15,8 +15,10 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
parser.opt_list('--accumulate_grad_batches', default=1, type=int, tunable=False,
help='accumulates gradients k times before applying update.'
' Simulates huge batch size')
parser.add_argument('--max_num_epochs', default=200, type=int, help='cap epochs')
parser.add_argument('--min_num_epochs', default=2, type=int, help='min epochs')
parser.add_argument('--max_epochs', default=200, type=int,
help='maximum number of epochs')
parser.add_argument('--min_epochs', default=2, type=int,
help='minimum number of epochs')
parser.add_argument('--train_percent_check', default=1.0, type=float,
help='how much of training set to check')
parser.add_argument('--val_percent_check', default=1.0, type=float,
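The renamed CLI flags feed the Trainer in the usual way; a plain-`argparse` sketch of that wiring (the file itself uses test-tube's `opt_list` parser, and the helper name here is illustrative):

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer


def build_trainer(argv=None):
    parser = ArgumentParser()
    parser.add_argument('--max_epochs', default=200, type=int,
                        help='maximum number of epochs')
    parser.add_argument('--min_epochs', default=2, type=int,
                        help='minimum number of epochs')
    parser.add_argument('--train_percent_check', default=1.0, type=float,
                        help='how much of training set to check')
    args = parser.parse_args(argv)

    return Trainer(max_epochs=args.max_epochs,
                   min_epochs=args.min_epochs,
                   train_percent_check=args.train_percent_check)


# trainer = build_trainer(['--max_epochs', '3'])
# trainer.fit(model)  # given a LightningModule instance `model`
```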
12 changes: 6 additions & 6 deletions tests/test_amp.py
@@ -23,7 +23,7 @@ def test_amp_single_gpu(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=1,
distributed_backend='ddp',
use_amp=True
@@ -45,7 +45,7 @@ def test_no_amp_single_gpu(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=1,
distributed_backend='dp',
use_amp=True
@@ -69,7 +69,7 @@ def test_amp_gpu_ddp(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=2,
distributed_backend='ddp',
use_amp=True
@@ -94,7 +94,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):

trainer_options = dict(
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=[0],
distributed_backend='ddp',
use_amp=True
@@ -153,7 +153,7 @@ def test_cpu_model_with_amp(tmpdir):
default_save_path=tmpdir,
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4,
use_amp=True
@@ -175,7 +175,7 @@ def test_amp_gpu_dp(tmpdir):
model, hparams = tutils.get_model()
trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
gpus='0, 1', # test init with gpu string
distributed_backend='dp',
use_amp=True
18 changes: 9 additions & 9 deletions tests/test_cpu_models.py
@@ -46,7 +46,7 @@ def test_lbfgs_cpu_model(tmpdir):

trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
print_nan_grads=True,
show_progress_bar=False,
weights_summary='top',
@@ -64,7 +64,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):

trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
print_nan_grads=True,
@@ -97,7 +97,7 @@ def test_running_test_after_fitting(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=False,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
test_percent_check=0.2,
@@ -135,7 +135,7 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):

trainer_options = dict(
show_progress_bar=False,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
test_percent_check=0.2,
@@ -209,7 +209,7 @@ def test_simple_cpu(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.1,
)
@@ -230,7 +230,7 @@ def test_cpu_model(tmpdir):
default_save_path=tmpdir,
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4
)
@@ -253,7 +253,7 @@ def test_all_features_cpu_model(tmpdir):
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
accumulate_grad_batches=2,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4
)
@@ -314,7 +314,7 @@ def train_dataloader(self):

trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
val_percent_check=0,
weights_summary=None,
@@ -348,7 +348,7 @@ def test_single_gpu_model(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=False,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
gpus=1