From 81653870ded8f8d9e17f4f7d574ddeccdb1ef142 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 25 Feb 2022 12:33:09 +0100 Subject: [PATCH] Update EMA decay `tau` (#6769) * Update EMA * Update EMA * ratio invert * fix ratio invert * fix2 ratio invert * warmup iterations to 100 * ema_k * implement tau * implement tau --- utils/torch_utils.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index c5257c6ebfeb..c11d2a4269ef 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -32,9 +32,7 @@ @contextmanager def torch_distributed_zero_first(local_rank: int): - """ - Decorator to make all processes in distributed training wait for each local_master to do something. - """ + # Decorator to make all processes in distributed training wait for each local_master to do something if local_rank not in [-1, 0]: dist.barrier(device_ids=[local_rank]) yield @@ -43,13 +41,13 @@ def torch_distributed_zero_first(local_rank: int): def date_modified(path=__file__): - # return human-readable file modification date, i.e. '2021-3-26' + # Return human-readable file modification date, i.e. '2021-3-26' t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime) return f'{t.year}-{t.month}-{t.day}' def git_describe(path=Path(__file__).parent): # path must be a directory - # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe + # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe s = f'git -C {path} describe --tags --long --always' try: return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] @@ -99,7 +97,7 @@ def select_device(device='', batch_size=0, newline=True): def time_sync(): - # pytorch-accurate time + # PyTorch-accurate time if torch.cuda.is_available(): torch.cuda.synchronize() return time.time() @@ -205,7 +203,7 @@ def prune(model, amount=0.3): def fuse_conv_and_bn(conv, bn): - # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ fusedconv = nn.Conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size, @@ -214,12 +212,12 @@ def fuse_conv_and_bn(conv, bn): groups=conv.groups, bias=True).requires_grad_(False).to(conv.weight.device) - # prepare filters + # Prepare filters w_conv = conv.weight.clone().view(conv.out_channels, -1) w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) - # prepare spatial bias + # Prepare spatial bias b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) @@ -252,7 +250,7 @@ def model_info(model, verbose=False, img_size=640): def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) - # scales img(bs,3,y,x) by ratio constrained to gs-multiple + # Scales img(bs,3,y,x) by ratio constrained to gs-multiple if ratio == 1.0: return img else: @@ -302,13 +300,13 @@ class ModelEMA: For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage """ - def __init__(self, model, decay=0.9999, updates=0): + def __init__(self, model, decay=0.9999, tau=2000, updates=0): # Create EMA self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA # if next(model.parameters()).device.type != 'cpu': # self.ema.half() # FP16 EMA self.updates = updates # number of EMA updates - self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) + self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) for p in self.ema.parameters(): p.requires_grad_(False)