Activation Function Experiments #441

Closed · glenn-jocher opened this issue Aug 10, 2019 · 16 comments
Labels: enhancement, question, Stale

@glenn-jocher (Member) commented Aug 10, 2019

This issue documents studies on the YOLOv3 activation function. The PyTorch 1.2 release updated some of the BatchNorm2d weight initializations (from uniform random on [0, 1] to all 1s), so I thought this would be a good time to benchmark the model and test the repo default against 3 possible improvements (sketched below):

  1. Default: nn.LeakyReLU(0.1, inplace=True)
  2. Swish: class Swish(nn.Module)
  3. PReLU: nn.PReLU(num_parameters=filters)
  4. PReLU: nn.PReLU(num_parameters=1)
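
For reference, a rough sketch of these four candidates as PyTorch modules (a sketch only: filters is a placeholder value, and the Swish definition mirrors the one discussed later in this thread, written without the in-place mul):

import torch
import torch.nn as nn

class Swish(nn.Module):  # 2. Swish: x * sigmoid(x)
    def forward(self, x):
        return x * torch.sigmoid(x)

filters = 64                               # placeholder: output channels of the preceding conv
act1 = nn.LeakyReLU(0.1, inplace=True)     # 1. current default
act2 = Swish()                             # 2. Swish
act3 = nn.PReLU(num_parameters=filters)    # 3. one learnable slope per channel
act4 = nn.PReLU(num_parameters=1)          # 4. a single shared learnable slope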

I benchmarked 5ff6e6b with each of the above activations on the small coco_64img.data tutorial dataset:

python3 train.py --img-size 416 --data data/coco_64img.data --batch-size 16 --accumulate 4 --nosave

[results plot: results_activations_416]

PReLU looks promising, but we can't draw any conclusions from this small dataset. In my next post I will plot the results on the full coco dataset trained to 10% of the final epochs, which should be a much more useful comparison.

python3 train.py --img-size 320 --data data/coco.data --batch-size 32 --accumulate 2 --epochs 27 --nosave

glenn-jocher added the enhancement and question labels and self-assigned this issue on Aug 10, 2019

@glenn-jocher (Member, Author) commented Aug 11, 2019

Experiments on 5ff6e6b below. Results are test.py mAP at conf_thres = 0.1 > conf_thres = 0.001 > --save-json at the end of 27 coco.data epochs. No multi-scale.

Swish produces the best results, with the highest mAP and lowest validation losses across almost all epochs (not just the final epoch), but the difference is small and the increase in GPU memory is significant. LeakyReLU is in-place, reducing GPU memory, whereas Swish (being a custom module) requires about 50% more GPU memory and PReLU requires about 30% more (a rough measurement sketch follows at the end of this comment).

python3 train.py --img-size 320 --epochs 27 --batch-size 64 --accumulate 1  --nosave
  1. nn.LeakyReLU(0.1, inplace=True) (old default): 44.6
  2. class Swish(nn.Module): 44.9
  3. nn.PReLU(num_parameters=filters, init=0.10): 43.4
  4. scale_xy=1.2: 44.3
  5. scale_xy=1.1: 44.0
  6. scale_xy=1.5: 44.4
  7. ((1.0 - giou) ** 2).mean()  # giou^2 loss: 44.2
  8. yolov3-spp-pan.cfg: 44.4
  9. Initialize cls/obj biases to -5 (new default): 44.9
  10. Adam 9E-5: 45.2
  11. Adam uFBCE 8192: 45.4

[results plot]
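
Regarding the GPU-memory comparison above, a rough, hypothetical way to reproduce those per-activation numbers (a sketch only: the Conv-BN-activation stack below is a small stand-in, not the actual YOLOv3 model, and the Swish used here is the non-in-place x * sigmoid(x) form):

import torch
import torch.nn as nn

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

def peak_memory_gb(act_factory, channels=64, n_blocks=8, device='cuda'):
    # Build a stand-in Conv-BN-activation stack and measure peak GPU memory
    # for one forward + backward pass.
    layers = []
    for _ in range(n_blocks):
        layers += [nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                   nn.BatchNorm2d(channels),
                   act_factory()]
    model = nn.Sequential(*layers).to(device)
    x = torch.randn(8, channels, 128, 128, device=device)
    torch.cuda.reset_peak_memory_stats(device)
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated(device) / 1e9

for name, factory in [('LeakyReLU', lambda: nn.LeakyReLU(0.1, inplace=True)),
                      ('Swish', Swish),
                      ('PReLU', lambda: nn.PReLU(num_parameters=64))]:  # num_parameters must match channels
    print(f'{name}: {peak_memory_gb(factory):.2f} GB')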

@okanlv commented Nov 25, 2019

@glenn-jocher Following lukemelas/EfficientNet-PyTorch#88, GPU memory consumption for Swish decreases if the Swish implementation inherits from torch.autograd.Function (code taken from the same pull request):

import torch
import torch.nn as nn

class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)  # save only the input; the sigmoid is recomputed in backward
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_variables[0]
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)

However, this will increase the training time, since the custom Function saves only the input tensor and recomputes the sigmoid during the backward pass, trading compute for memory. If you are still using Swish for some of your experiments and getting out-of-memory errors, it could be useful.
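
Since the backward pass here is written by hand, one quick sanity check is torch.autograd.gradcheck, which compares it against a numerical gradient (a small sketch, assuming the SwishImplementation class above; gradcheck needs double-precision inputs):

import torch
from torch.autograd import gradcheck

x = torch.randn(8, dtype=torch.double, requires_grad=True)
# Should print True if the hand-written backward matches the numerical gradient.
print(gradcheck(SwishImplementation.apply, (x,), eps=1e-6, atol=1e-4))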

@glenn-jocher (Member, Author) commented:

@okanlv nice find!! It looks like Swish is indeed improving performance in this repo, so this new class may be very useful. I will test it on 1 epoch of COCO using the command below on a V100 GCP instance.

python3 train.py --data data/coco.data --cfg cfg/yolov3s.cfg --weights '' --epochs 1

Comparing it to the default LeakyReLU(0.1) and our current Swish() implementation:

class Swish(nn.Module):
    def forward(self, x):
        return x.mul_(torch.sigmoid(x))
Activation                     Loss   mAP@0.5   Time    Mem
LeakyReLU(0.1, inplace=True)   15.5   0.0309    17:20   9.6G
Swish()                        15.6   0.0483    19:28   13.0G
MemoryEfficientSwish()         15.7   0.0445    19:54   10.6G

That's strange: the two Swish versions return different losses and mAPs, with the memory-efficient version worse on both. I had expected them to produce exactly the same results.

@okanlv commented Nov 26, 2019

Hmm, I didn't expect that. If I find anything else, I will keep you updated.

@glenn-jocher (Member, Author) commented Nov 26, 2019

To double-check, I trained to 27 epochs and saw the same pattern: MemoryEfficientSwish() produces worse results, 49.3 mAP vs 49.7 mAP for the default Swish() implementation. I don't know exactly why. I use Apex for mixed-precision training, BTW; not sure if that has any effect.

@okanlv commented Dec 4, 2019

Edit: Both forward and backward for both functions produce the same results, as expected. I have updated the code to plot gradients for both functions.

@glenn-jocher, I might have found the problem. The in-place operation x.mul_() in the Swish class also modifies its input. Running the following code shows the difference. Also, changing mul_() to mul() fixes this problem. I have not traced the loss in the backward direction to inspect the computation graph, so I cannot say anything about the higher mAP at the moment.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x.mul_(torch.sigmoid(x))


class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_variables[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def __init__(self):
        super(MemoryEfficientSwish, self).__init__()

    def forward(self, x):
        return SwishImplementation.apply(x)


f1 = Swish()
f2 = MemoryEfficientSwish()

# 1st method
# returns RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
# x = torch.linspace(-5, 5, 1000)
# x1 = x.clone().detach()
# x1.requires_grad = True
# y1 = f1(x1)

# 2nd method
x = torch.linspace(-5, 5, 1000, requires_grad=True)
x_copy = x.clone().detach()
x_copy.requires_grad = True
x1 = x.clone()
x2 = x_copy.clone()

y1 = f1(x1)
y2 = f2(x2)

print('\nDid Swish change its input?')
print(not torch.allclose(x, x1))
print('\nDid MemoryEfficientSwish change its input?')
print(not torch.allclose(x, x2))

plt.xlim(-6, 6)
plt.ylim(-1, 6)
plt.plot(x.detach().numpy(), y1.detach().numpy())
plt.plot(x.detach().numpy(), y2.detach().numpy())
plt.title('Swish functions')
plt.legend(['Swish', 'MemoryEfficientSwish'], loc='upper left')
plt.show()

y1.backward(torch.ones_like(x))
y2.backward(torch.ones_like(x))

assert torch.allclose(y1, y2)
assert torch.allclose(x.grad, x_copy.grad)

def getBack(var_grad_fn):
    print(var_grad_fn)
    for n in var_grad_fn.next_functions:
        if n[0]:
            try:
                tensor = getattr(n[0], 'variable')
                print('\t', n[0])
                # print('Tensor with grad found:', tensor)
                # print(' - gradient:', tensor.grad)
                print()
            except AttributeError as e:
                getBack(n[0])


print('\nTracing backward functions for Swish')
getBack(y1.grad_fn)
print('\nTracing backward functions for MemoryEfficientSwish')
getBack(y2.grad_fn)

plt.xlim(-6, 6)
plt.ylim(-1, 2)
plt.plot(x.detach().numpy(), x.grad.detach().numpy())
plt.plot(x.detach().numpy(), x_copy.grad.detach().numpy())
plt.title('Swish gradient functions')
plt.legend(['Swish', 'MemoryEfficientSwish'], loc='upper left')
plt.show()

@FranciscoReveriano (Contributor) commented:

How much better are the results?

@okanlv commented Dec 4, 2019

@FranciscoReveriano I am referring to @glenn-jocher's results in this thread. I have not trained the model myself.

@glenn-jocher (Member, Author) commented:

@okanlv ah, so you are saying that the in-place operator in Swish() is interfering with the gradient computation? That's odd, because I previously trained with Swish() both with and without the in-place operator .mul_ and got identical results (the in-place operator reduced memory a small amount, so I kept it).

So do you think the better results with Swish() might be a random occurrence?

@FranciscoReveriano (Contributor) commented:

I think they are a random occurrence.

@okanlv commented Dec 5, 2019

@glenn-jocher @FranciscoReveriano I am not sure, actually, because using x.clone() before applying both Swish functions produced the same output and the same gradient for x. Also, you did not get any errors during training with this in-place operator, whereas it raised a RuntimeError in my case. Considering we are creating yolov3 with nn.Sequential(), the internal backward calculation might be different. That being said, we can create yolov3 models with the different Swish implementations and compare their parameters (values and grads) after a backward pass. I will test this in a few days.
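
For what it's worth, a minimal sketch of that comparison, assuming a small Conv-BN-activation stand-in rather than the actual yolov3 network and reusing the Swish and MemoryEfficientSwish classes from the script above (identical seeding so both copies start from the same weights; whether the in-place mul_ raises a RuntimeError here is part of what the test would reveal):

import torch
import torch.nn as nn

def make_model(act_cls):
    # Stand-in for yolov3: a tiny Conv-BN-activation stack, seeded so both copies
    # get identical initial weights.
    torch.manual_seed(0)
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), act_cls(),
        nn.Conv2d(8, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), act_cls(),
    )

m1 = make_model(Swish)                  # in-place x.mul_(sigmoid(x)) version from above
m2 = make_model(MemoryEfficientSwish)   # autograd.Function version from above

x = torch.randn(2, 3, 32, 32)
m1(x).sum().backward()
m2(x.clone()).sum().backward()

# Compare parameter values and gradients layer by layer.
for (n1, p1), (_, p2) in zip(m1.named_parameters(), m2.named_parameters()):
    print(f'{n1}: values match={torch.allclose(p1, p2)}, grads match={torch.allclose(p1.grad, p2.grad)}')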

@glenn-jocher (Member, Author) commented:

@okanlv hmm interesting, ok, keep us updated! I think Swish might be something that we'll want to integrate more in the future, as it does seem to increase mAP a bit in most circumstances.

@FranciscoReveriano (Contributor) commented:

Yeah I am looking more into understanding Swish. Might be very beneficial.

@github-actions bot commented Mar 7, 2020

This issue is stale because it has been open 30 days with no activity. Remove Stale label or comment or this will be closed in 5 days.

@sudo-rm-covid19 commented:

@glenn-jocher Hi, I wonder: when you change the loss (say, from SmoothL1 to GIoU) or the activation (from ReLU to Swish), do you train the entire model from scratch, or do you load part of the pretrained weights from the version before the change as a starting point?

@glenn-jocher (Member, Author) commented:

@sudo-rm-covid19 from scratch always.
