diff --git a/utils/general.py b/utils/general.py index dc9a10fe8617..561602323ab2 100755 --- a/utils/general.py +++ b/utils/general.py @@ -29,7 +29,6 @@ from utils.downloads import gsutil_getsize from utils.metrics import box_iou, fitness -from utils.torch_utils import init_torch_seeds # Settings torch.set_printoptions(linewidth=320, precision=5, profile='long') @@ -91,10 +90,13 @@ def set_logging(rank=-1, verbose=True): def init_seeds(seed=0): - # Initialize random number generator (RNG) seeds + # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html + # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible + import torch.backends.cudnn as cudnn random.seed(seed) np.random.seed(seed) - init_torch_seeds(seed) + torch.manual_seed(seed) + cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) def get_latest_run(search_dir='.'): diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 04e1446bb908..352ecf572c9f 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -15,7 +15,6 @@ from pathlib import Path import torch -import torch.backends.cudnn as cudnn import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -41,15 +40,6 @@ def torch_distributed_zero_first(local_rank: int): dist.barrier(device_ids=[0]) -def init_torch_seeds(seed=0): - # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html - torch.manual_seed(seed) - if seed == 0: # slower, more reproducible - cudnn.benchmark, cudnn.deterministic = False, True - else: # faster, less reproducible - cudnn.benchmark, cudnn.deterministic = True, False - - def date_modified(path=__file__): # return human-readable file modification date, i.e. '2021-3-26' t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)