From 84bfa892365cd9d5938ea78494727783482dcad4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 18 Sep 2021 13:28:42 +0200 Subject: [PATCH] Consolidate `init_seeds()` (#4849) --- utils/general.py | 8 +++++--- utils/torch_utils.py | 10 ---------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/utils/general.py b/utils/general.py index dc9a10fe8617..561602323ab2 100755 --- a/utils/general.py +++ b/utils/general.py @@ -29,7 +29,6 @@ from utils.downloads import gsutil_getsize from utils.metrics import box_iou, fitness -from utils.torch_utils import init_torch_seeds # Settings torch.set_printoptions(linewidth=320, precision=5, profile='long') @@ -91,10 +90,13 @@ def set_logging(rank=-1, verbose=True): def init_seeds(seed=0): - # Initialize random number generator (RNG) seeds + # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html + # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible + import torch.backends.cudnn as cudnn random.seed(seed) np.random.seed(seed) - init_torch_seeds(seed) + torch.manual_seed(seed) + cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) def get_latest_run(search_dir='.'): diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 04e1446bb908..352ecf572c9f 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -15,7 +15,6 @@ from pathlib import Path import torch -import torch.backends.cudnn as cudnn import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F @@ -41,15 +40,6 @@ def torch_distributed_zero_first(local_rank: int): dist.barrier(device_ids=[0]) -def init_torch_seeds(seed=0): - # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html - torch.manual_seed(seed) - if seed == 0: # slower, more reproducible - cudnn.benchmark, cudnn.deterministic = False, True - else: # faster, less reproducible - cudnn.benchmark, cudnn.deterministic = True, False - - def date_modified(path=__file__): # return human-readable file modification date, i.e. '2021-3-26' t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)