[Docs] update docs for resume_from_checkpoint (#5164)
* update docs and add pathlib support

* fix

(cherry picked from commit dd442b6)
rohitgr7 authored and Borda committed Jan 6, 2021
1 parent 2d70a80 commit cc607d5
Showing 3 changed files with 12 additions and 8 deletions.
docs/source/trainer.rst (3 changes: 2 additions & 1 deletion)

@@ -1328,7 +1328,8 @@ resume_from_checkpoint
 |
-To resume training from a specific checkpoint pass in the path here.
+To resume training from a specific checkpoint pass in the path here. If resuming from a mid-epoch
+checkpoint, training will start from the beginning of the next epoch.
 .. testcode::
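The testcode block that follows this hunk is collapsed above; for reference, a minimal sketch of how the documented option is used (the checkpoint path is hypothetical):

    from pytorch_lightning import Trainer

    # Resume training from a previously saved checkpoint; a mid-epoch checkpoint
    # resumes at the start of the next epoch.
    trainer = Trainer(resume_from_checkpoint="path/to/my_checkpoint.ckpt")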
pytorch_lightning/trainer/connectors/checkpoint_connector.py (11 changes: 6 additions & 5 deletions)

@@ -15,7 +15,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Union, Optional
+from typing import Optional, Union

 import torch

@@ -24,8 +24,8 @@
 from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 if _APEX_AVAILABLE:
     from apex import amp

@@ -156,9 +156,10 @@ def restore_training_state(self, checkpoint):
         expected_steps = self.trainer.num_training_batches / n_accum
         if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
             rank_zero_warn(
-                "You're resuming from a checkpoint that ended mid-epoch. "
-                "This can cause unreliable results if further training is done, "
-                "consider using an end of epoch checkpoint. "
+                "You're resuming from a checkpoint that ended mid-epoch."
+                " Training will start from the beginning of the next epoch."
+                " This can cause unreliable results if further training is done,"
+                " consider using an end of epoch checkpoint."
             )

         # restore the optimizers
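The warning above is driven by a simple remainder check on global_step. A standalone sketch of that condition, using a hypothetical helper name (the real check lives inline in restore_training_state):

    def looks_mid_epoch(global_step: int, num_training_batches: int, n_accum: int) -> bool:
        # One optimizer step is taken every `n_accum` batches, so a full epoch
        # advances global_step by num_training_batches / n_accum; a remainder
        # greater than 1 suggests the checkpoint was written before the epoch ended.
        expected_steps = num_training_batches / n_accum
        return num_training_batches != 0 and global_step % expected_steps > 1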
pytorch_lightning/trainer/trainer.py (6 changes: 4 additions & 2 deletions)

@@ -16,6 +16,7 @@

 import os
 import warnings
+from pathlib import Path
 from typing import Dict, Iterable, List, Optional, Union

 import torch

@@ -117,7 +118,7 @@ def __init__(
         weights_save_path: Optional[str] = None,
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,
-        resume_from_checkpoint: Optional[str] = None,
+        resume_from_checkpoint: Optional[Union[Path, str]] = None,
         profiler: Optional[Union[BaseProfiler, bool, str]] = None,
         benchmark: bool = False,
         deterministic: bool = False,

@@ -252,7 +253,8 @@ def __init__(
                 you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
             resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
-                This can be a URL.
+                This can be a URL. If resuming from mid-epoch checkpoint, training will start from
+                the beginning of the next epoch.
             sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
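With the widened Optional[Union[Path, str]] hint, a pathlib.Path is accepted in addition to a plain string or URL. A small usage sketch (the checkpoint location is hypothetical):

    from pathlib import Path
    from pytorch_lightning import Trainer

    ckpt = Path("checkpoints") / "last.ckpt"  # hypothetical local checkpoint
    trainer = Trainer(resume_from_checkpoint=ckpt)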
