Skip to content

Commit

Permalink
Consolidate init_seeds() (ultralytics#4849)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored and CesarBazanAV committed Sep 29, 2021
1 parent f6abca6 commit 3ee27ac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
8 changes: 5 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
Expand Down Expand Up @@ -91,10 +90,13 @@ def set_logging(rank=-1, verbose=True):


def init_seeds(seed=0):
# Initialize random number generator (RNG) seeds
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
import torch.backends.cudnn as cudnn
random.seed(seed)
np.random.seed(seed)
init_torch_seeds(seed)
torch.manual_seed(seed)
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def get_latest_run(search_dir='.'):
Expand Down
10 changes: 0 additions & 10 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -41,15 +40,6 @@ def torch_distributed_zero_first(local_rank: int):
dist.barrier(device_ids=[0])


def init_torch_seeds(seed=0):
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(seed)
if seed == 0: # slower, more reproducible
cudnn.benchmark, cudnn.deterministic = False, True
else: # faster, less reproducible
cudnn.benchmark, cudnn.deterministic = True, False


def date_modified(path=__file__):
# return human-readable file modification date, i.e. '2021-3-26'
t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
Expand Down

0 comments on commit 3ee27ac

Please sign in to comment.