Skip to content

Commit

Permalink
Update EMA decay tau (ultralytics#6769)
Browse files Browse the repository at this point in the history
* Update EMA

* Update EMA

* ratio invert

* fix ratio invert

* fix2 ratio invert

* warmup iterations to 100

* ema_k

* implement tau

* implement tau
  • Loading branch information
glenn-jocher committed Feb 25, 2022
1 parent 4b4db08 commit 8165387
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@

@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
# Decorator to make all processes in distributed training wait for each local_master to do something
if local_rank not in [-1, 0]:
dist.barrier(device_ids=[local_rank])
yield
Expand All @@ -43,13 +41,13 @@ def torch_distributed_zero_first(local_rank: int):


def date_modified(path=__file__):
# return human-readable file modification date, i.e. '2021-3-26'
# Return human-readable file modification date, i.e. '2021-3-26'
t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'


def git_describe(path=Path(__file__).parent): # path must be a directory
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
s = f'git -C {path} describe --tags --long --always'
try:
return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
Expand Down Expand Up @@ -99,7 +97,7 @@ def select_device(device='', batch_size=0, newline=True):


def time_sync():
# pytorch-accurate time
# PyTorch-accurate time
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.time()
Expand Down Expand Up @@ -205,7 +203,7 @@ def prune(model, amount=0.3):


def fuse_conv_and_bn(conv, bn):
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
# Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
fusedconv = nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
Expand All @@ -214,12 +212,12 @@ def fuse_conv_and_bn(conv, bn):
groups=conv.groups,
bias=True).requires_grad_(False).to(conv.weight.device)

# prepare filters
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

# prepare spatial bias
# Prepare spatial bias
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
Expand Down Expand Up @@ -252,7 +250,7 @@ def model_info(model, verbose=False, img_size=640):


def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
# scales img(bs,3,y,x) by ratio constrained to gs-multiple
# Scales img(bs,3,y,x) by ratio constrained to gs-multiple
if ratio == 1.0:
return img
else:
Expand Down Expand Up @@ -302,13 +300,13 @@ class ModelEMA:
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
"""

def __init__(self, model, decay=0.9999, updates=0):
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu':
# self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
for p in self.ema.parameters():
p.requires_grad_(False)

Expand Down

0 comments on commit 8165387

Please sign in to comment.