generator seed fix for DDP mAP drop (#9545)
* Try to fix the DDP mAP drop by deriving each dataloader generator's seed from RANK (see the sketch after the commit message)

* Fix default activation bug

* Update dataloaders.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataloaders.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
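
The core change derives each DDP process's dataloader generator seed from its RANK instead of using the same constant everywhere, so the random streams handed to dataloader workers differ across ranks while staying reproducible. A minimal, self-contained sketch of the pattern, assuming a toy TensorDataset and an inline stand-in for the repo's seed_worker helper (both hypothetical, for illustration only):

import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Standard PyTorch worker-seeding recipe: derive the numpy/random seeds from
    # the torch seed that the DataLoader's generator assigns to each worker.
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


RANK = int(os.getenv('RANK', -1))  # set by torchrun in DDP; -1 for single-process runs

dataset = TensorDataset(torch.arange(32).float())  # toy stand-in for the real dataset

generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)  # fixed base seed, offset per rank

loader = DataLoader(dataset,
                    batch_size=8,
                    shuffle=True,
                    num_workers=2,
                    worker_init_fn=seed_worker,
                    generator=generator)

With the previous manual_seed(0), every rank started its workers from the same base seed, so augmentation randomness was duplicated across GPUs; offsetting the seed by RANK decorrelates the ranks, which appears to be the intent behind the reported mAP fix.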
3 people committed Sep 24, 2022
1 parent c8e5230 commit f11a8a6
Showing 4 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions models/common.py
@@ -40,13 +40,13 @@ def autopad(k, p=None, d=1):  # kernel, padding, dilation
 
 class Conv(nn.Module):
     # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
-    act = nn.SiLU()  # default activation
+    default_act = nn.SiLU()  # default activation
 
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
         super().__init__()
         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
         self.bn = nn.BatchNorm2d(c2)
-        self.act = self.act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
+        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
     def forward(self, x):
         return self.act(self.bn(self.conv(x)))
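Renaming the class-level default from act to default_act removes the name collision with the per-instance self.act attribute (the "default activation bug" mentioned in the commit message): the fallback used when act=True and the activation an instance actually ends up with now have distinct names. A condensed, standalone sketch of the resulting behaviour (conv/bn omitted for brevity):

import torch.nn as nn


class Conv(nn.Module):
    default_act = nn.SiLU()  # class-level default, used only as a fallback

    def __init__(self, act=True):
        super().__init__()
        # per-instance attribute: fall back to the class default only when act is True
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()


Conv.default_act = nn.LeakyReLU(0.1)  # affects every Conv built after this line
print(Conv().act)            # LeakyReLU(negative_slope=0.1)
print(Conv(act=False).act)   # Identity()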
2 changes: 1 addition & 1 deletion models/yolo.py
@@ -301,7 +301,7 @@ def parse_model(d, ch):  # model_dict, input_channels(3)
     LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
     anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
     if act:
-        Conv.act = eval(act)  # redefine default activation, i.e. Conv.act = nn.SiLU()
+        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
         LOGGER.info(f"{colorstr('activation:')} {act}")  # print
     na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
     no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
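parse_model reads an optional activation entry from the model dict and, when present, evals it into the new class-wide default before any layers are constructed. A hypothetical illustration of that flow (the 'nn.LeakyReLU(0.1)' string stands in for whatever the YAML's activation field would contain; assumes the repo root is importable):

import torch.nn as nn

from models.common import Conv

act = 'nn.LeakyReLU(0.1)'  # e.g. d.get('activation') parsed from a model YAML
if act:
    Conv.default_act = eval(act)  # every Conv constructed afterwards defaults to LeakyReLU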
5 changes: 3 additions & 2 deletions utils/dataloaders.py
@@ -40,6 +40,7 @@
 VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
 BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}'  # tqdm bar format
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
 PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders
 
 # Get orientation exif tag
@@ -139,7 +140,7 @@ def create_dataloader(path,
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
     generator = torch.Generator()
-    generator.manual_seed(0)
+    generator.manual_seed(6148914691236517205 + RANK)
     return loader(dataset,
                   batch_size=batch_size,
                   shuffle=shuffle and sampler is None,
@@ -1169,7 +1170,7 @@ def create_classification_dataloader(path,
     nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     generator = torch.Generator()
-    generator.manual_seed(0)
+    generator.manual_seed(6148914691236517205 + RANK)
     return InfiniteDataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=shuffle and sampler is None,
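Both dataloader factories now share the same fixed base seed, offset by RANK. The base value 6148914691236517205 is the alternating-bit constant 0x5555555555555555; a quick standalone check of the per-rank seeds it produces, assuming a hypothetical 4-GPU job:

BASE_SEED = 6148914691236517205
assert BASE_SEED == 0x5555555555555555  # alternating 0101... bit pattern

for rank in range(4):  # hypothetical world size of 4
    print(f'rank {rank}: seed {BASE_SEED + rank}')

The particular constant looks arbitrary; what matters for the fix is that it is fixed (reproducible runs) and offset per rank (so worker seeds and shuffle behaviour differ across processes).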
8 changes: 5 additions & 3 deletions utils/segment/dataloaders.py
@@ -17,6 +17,8 @@
 from ..torch_utils import torch_distributed_zero_first
 from .augmentations import mixup, random_perspective
 
+RANK = int(os.getenv('RANK', -1))
+
 
 def create_dataloader(path,
                       imgsz,
@@ -61,8 +63,8 @@ def create_dataloader(path,
     nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
-    # generator = torch.Generator()
-    # generator.manual_seed(0)
+    generator = torch.Generator()
+    generator.manual_seed(6148914691236517205 + RANK)
     return loader(
         dataset,
         batch_size=batch_size,
@@ -72,7 +74,7 @@
         pin_memory=True,
         collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
         worker_init_fn=seed_worker,
-        # generator=generator,
+        generator=generator,
     ), dataset


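The segmentation dataloader previously had the generator wiring commented out; it now matches the detection dataloader above. A small standalone check with toy data and hypothetical seeds, showing that the generator handed to a DataLoader drives its random behaviour (here the shuffle order; under DDP, where a DistributedSampler shuffles instead, the same generator still seeds the workers):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))


def shuffle_order(seed):
    g = torch.Generator()
    g.manual_seed(seed)
    return [int(x) for (x,) in DataLoader(ds, shuffle=True, generator=g)]


print(shuffle_order(0))  # deterministic for a fixed seed
print(shuffle_order(1))  # a different seed generally yields a different order
assert shuffle_order(0) == shuffle_order(0)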
