Update EMA decay tau #6769

Merged
14 commits merged on Feb 25, 2022
22 changes: 10 additions & 12 deletions utils/torch_utils.py
@@ -32,9 +32,7 @@

@contextmanager
def torch_distributed_zero_first(local_rank: int):
"""
Decorator to make all processes in distributed training wait for each local_master to do something.
"""
# Decorator to make all processes in distributed training wait for each local_master to do something
if local_rank not in [-1, 0]:
dist.barrier(device_ids=[local_rank])
yield
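
For orientation, a minimal usage sketch (not part of this diff; the matching post-yield barrier for the local master sits in the collapsed lines below): the local master runs the guarded block first while the other ranks wait at the barrier.

from utils.torch_utils import torch_distributed_zero_first

local_rank = -1  # -1 = not distributed; 0 = local master; 1+ = ranks that wait
with torch_distributed_zero_first(local_rank):
    cache = list(range(10))  # stand-in for expensive once-per-node work
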
@@ -43,13 +41,13 @@ def torch_distributed_zero_first(local_rank: int):


def date_modified(path=__file__):
-    # return human-readable file modification date, i.e. '2021-3-26'
+    # Return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def git_describe(path=Path(__file__).parent): # path must be a directory
-    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    s = f'git -C {path} describe --tags --long --always'
    try:
        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
@@ -99,7 +97,7 @@ def select_device(device='', batch_size=0, newline=True):


def time_sync():
-    # pytorch-accurate time
+    # PyTorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
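
As a usage note (a sketch, not part of this change): taking timestamps with time_sync() flushes any queued CUDA kernels before reading the clock, so asynchronous GPU work is not silently excluded from the measurement. On a CPU-only machine it degrades to plain time.time().

import torch
from utils.torch_utils import time_sync

x = torch.randn(1, 3, 640, 640)
t1 = time_sync()
y = torch.nn.functional.avg_pool2d(x, 2)  # stand-in for a model forward pass
t2 = time_sync()
print(f'forward: {(t2 - t1) * 1E3:.2f} ms')
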
@@ -205,7 +203,7 @@ def prune(model, amount=0.3):


def fuse_conv_and_bn(conv, bn):
-    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+    # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
@@ -214,12 +212,12 @@ def fuse_conv_and_bn(conv, bn):
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

-    # prepare filters
+    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

-    # prepare spatial bias
+    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
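
The two copy_ calls above implement w_fused = w_bn @ w_conv and b_fused = w_bn @ b_conv + b_bn. A quick equivalence check (a sketch, assuming fuse_conv_and_bn() is imported from this file; layer sizes are arbitrary):

import torch
import torch.nn as nn
from utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
bn(conv(torch.randn(4, 3, 32, 32)))  # one training-mode pass to populate running stats
bn.eval()  # use running stats, as the fusion does

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), fuse_conv_and_bn(conv, bn)(x), atol=1e-5)
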
@@ -252,7 +250,7 @@ def model_info(model, verbose=False, img_size=640):


def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
-    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
+    # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
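
A shape-level sketch of the intended behavior (hedged: the interpolate-and-pad logic lives in the collapsed else branch, so the expected shape assumes the upstream implementation with the default gs=32):

import torch
from utils.torch_utils import scale_img

im = torch.zeros(16, 3, 256, 416)
out = scale_img(im, ratio=0.5)  # resize to 128x208, then pad up to a gs-multiple
print(out.shape)  # expected: torch.Size([16, 3, 128, 224])
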
@@ -302,13 +300,13 @@ class ModelEMA:
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

-    def __init__(self, model, decay=0.9999, updates=0):
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
-        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

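The substantive change in this hunk is the new tau argument: the ramp decay * (1 - exp(-x / tau)) starts near zero, so the EMA tracks the raw model closely during early updates, and reaches 1 - 1/e of the final decay exactly at x = tau. A small sketch of the ramp with the new defaults (not part of the diff):

import math

def ema_decay(x, decay=0.9999, tau=2000):
    return decay * (1 - math.exp(-x / tau))  # same ramp as self.decay above

for updates in (100, 1000, 2000, 10000):
    print(f'{updates:>6d} updates -> decay {ema_decay(updates):.4f}')
# 100 -> 0.0488, 1000 -> 0.3934, 2000 -> 0.6321 (1 - 1/e of ~1), 10000 -> 0.9932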