Can't convert trained AMP model to full precision #349

Closed
jtiscione opened this issue Jun 9, 2019 · 3 comments

@jtiscione

I have a basic benchmark in which I train a CNN on MNIST with and without AMP. The problem is that I can't get the f16 types out of the model or export it for use on a CPU. Calling float() on the model doesn't seem to do anything.

import time
import torch.optim as optim
import torch
import torch.nn as nn
import torch.utils.data
import torch.onnx
import torchvision.datasets as dsets
import torchvision.transforms as trans
from apex import amp

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.lin1 = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.lin1(out)
        return out

trainSet = dsets.MNIST(root='./data', train=True, transform=trans.ToTensor(), download=True)
trainLoader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=100, shuffle=True)

testSet = dsets.MNIST(root='./data', train=False, transform=trans.ToTensor(), download=True)
testLoader = torch.utils.data.DataLoader(dataset=testSet, batch_size=100, shuffle=True)

def accuracy(testLoader, model):
    model.eval()  # use BatchNorm running statistics and stop updating them on test data
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testLoader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

NUM_EPOCHS = 3
device = torch.device('cuda:0')
criterion = nn.CrossEntropyLoss()
model = Net().to(device)  # Net() already runs __init__; calling it again is redundant
optimizer = optim.SGD(model.parameters(), 0.01)
start = time.time()
for epoch in range(NUM_EPOCHS):
    print('Epoch {}'.format(epoch))
    for i, (images, labels) in enumerate(trainLoader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
print('f32: Accuracy: {0:.4f}'.format(accuracy(testLoader, model)))
print('f32: Training time: {0:.2f}'.format(time.time() - start))
torch.save(model.state_dict(), './f32.pth')
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(model, dummy_input, './f32.onnx', verbose=True)

print('*****************************************************************************')

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), 0.01)
amp_model, amp_optimizer = amp.initialize(model, optimizer, opt_level="O1")
start = time.time()
for epoch in range(NUM_EPOCHS):
    print('Epoch {}'.format(epoch))
    for i, (images, labels) in enumerate(trainLoader):
        images = images.to(device)
        labels = labels.to(device)
        amp_optimizer.zero_grad()
        outputs = amp_model(images)
        loss = criterion(outputs, labels)
        # Instead of loss.backward(), backpropagate through Amp's scaled loss:
        with amp.scale_loss(loss, amp_optimizer) as scaled_loss:
            scaled_loss.backward()
        amp_optimizer.step()
print('AMP: Accuracy: {0:.4f}'.format(accuracy(testLoader, amp_model)))
print('AMP: Training time: {0:.2f}'.format(time.time() - start))

amp_model = amp_model.float()  # This line doesn't do anything

torch.save(amp_model.state_dict(), './amp.pth')
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(amp_model, dummy_input, './amp.onnx', verbose=True)

This code prints the following:

Epoch 0
Epoch 1
Epoch 2
f32: Accuracy: 0.9803
f32: Training time: 13.65
graph(%0 : Float(1, 1, 28, 28),
      %layer1.0.weight : Float(16, 1, 5, 5),
      %layer1.0.bias : Float(16),
      %layer1.1.weight : Float(16),
      %layer1.1.bias : Float(16),
      %layer1.1.running_mean : Float(16),
      %layer1.1.running_var : Float(16),
      %layer1.1.num_batches_tracked : Long(),
      %layer2.0.weight : Float(32, 16, 5, 5),
      %layer2.0.bias : Float(32),
      %layer2.1.weight : Float(32),
      %layer2.1.bias : Float(32),
      %layer2.1.running_mean : Float(32),
      %layer2.1.running_var : Float(32),
      %layer2.1.num_batches_tracked : Long(),
      %lin1.weight : Float(10, 1568),
      %lin1.bias : Float(10)):
  %17 : Float(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%0, %layer1.0.weight, %layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
  %18 : Float(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%17, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
  %19 : Float(1, 16, 28, 28) = onnx::Relu(%18), scope: Net/Sequential[layer1]/ReLU[2]
  %20 : Float(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%19), scope: Net/Sequential[layer1]/MaxPool2d[3]
  %21 : Float(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%20, %layer2.0.weight, %layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
  %22 : Float(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%21, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
  %23 : Float(1, 32, 14, 14) = onnx::Relu(%22), scope: Net/Sequential[layer2]/ReLU[2]
  %24 : Float(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%23), scope: Net/Sequential[layer2]/MaxPool2d[3]
  %25 : Long() = onnx::Constant[value={0}](), scope: Net
  %26 : Tensor = onnx::Shape(%24), scope: Net
  %27 : Long() = onnx::Gather[axis=0](%26, %25), scope: Net
  %28 : Long() = onnx::Constant[value={-1}](), scope: Net
  %29 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
  %30 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
  %31 : Tensor = onnx::Concat[axis=0](%29, %30)
  %32 : Float(1, 1568) = onnx::Reshape(%24, %31), scope: Net
  %33 : Float(1, 10) = onnx::Gemm[alpha=1, beta=1, transB=1](%32, %lin1.weight, %lin1.bias), scope: Net/Linear[lin1]
  return (%33)

*****************************************************************************
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Epoch 0
Epoch 1
Epoch 2
AMP: Accuracy: 0.9788
AMP: Training time: 19.55
graph(%x.3 : Float(1, 1, 28, 28),
      %layer1.0.weight : Float(16, 1, 5, 5),
      %layer1.0.bias : Float(16),
      %layer1.1.weight : Float(16),
      %layer1.1.bias : Float(16),
      %layer1.1.running_mean : Float(16),
      %layer1.1.running_var : Float(16),
      %layer1.1.num_batches_tracked : Long(),
      %layer2.0.weight : Float(32, 16, 5, 5),
      %layer2.0.bias : Float(32),
      %layer2.1.weight : Float(32),
      %layer2.1.bias : Float(32),
      %layer2.1.running_mean : Float(32),
      %layer2.1.running_var : Float(32),
      %layer2.1.num_batches_tracked : Long(),
      %lin1.weight : Float(10, 1568),
      %lin1.bias : Float(10)):
  %17 : Half(16, 1, 5, 5) = onnx::Cast[to=10](%layer1.0.weight), scope: Net/Sequential[layer1]/Conv2d[0]
  %18 : Half(16) = onnx::Cast[to=10](%layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
  %19 : Half(1, 1, 28, 28) = onnx::Cast[to=10](%x.3), scope: Net/Sequential[layer1]/Conv2d[0]
  %20 : Half(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%19, %17, %18), scope: Net/Sequential[layer1]/Conv2d[0]
  %21 : Half(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%20, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
  %22 : Half(1, 16, 28, 28) = onnx::Relu(%21), scope: Net/Sequential[layer1]/ReLU[2]
  %23 : Half(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%22), scope: Net/Sequential[layer1]/MaxPool2d[3]
  %24 : Half(32, 16, 5, 5) = onnx::Cast[to=10](%layer2.0.weight), scope: Net/Sequential[layer2]/Conv2d[0]
  %25 : Half(32) = onnx::Cast[to=10](%layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
  %26 : Half(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%23, %24, %25), scope: Net/Sequential[layer2]/Conv2d[0]
  %27 : Half(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%26, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
  %28 : Half(1, 32, 14, 14) = onnx::Relu(%27), scope: Net/Sequential[layer2]/ReLU[2]
  %29 : Half(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: Net/Sequential[layer2]/MaxPool2d[3]
  %30 : Long() = onnx::Constant[value={0}](), scope: Net
  %31 : Tensor = onnx::Shape(%29), scope: Net
  %32 : Long() = onnx::Gather[axis=0](%31, %30), scope: Net
  %33 : Long() = onnx::Constant[value={-1}](), scope: Net
  %34 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
  %35 : Tensor = onnx::Unsqueeze[axes=[0]](%33)
  %36 : Tensor = onnx::Concat[axis=0](%34, %35)
  %37 : Half(1, 1568) = onnx::Reshape(%29, %36), scope: Net
  %38 : Half(10, 1568) = onnx::Cast[to=10](%lin1.weight), scope: Net/Linear[lin1]
  %39 : Half(10) = onnx::Cast[to=10](%lin1.bias), scope: Net/Linear[lin1]
  %40 : Half(1, 10) = onnx::Gemm[alpha=1, beta=1, transB=1](%37, %38, %39), scope: Net/Linear[lin1]
  return (%40)

Using mixed precision, the training time per batch went up 43%, but I can now increase the batch size, so that's OK. (This is on an RTX 2060 with 6 GB.) What's more concerning is that I seem to be stuck in f16-land.
The expected behavior of model.float() is to convert all parameters and buffers to Float, but the exported graph is still riddled with these Half types, which makes it useless in any environment without f16 support. How do I get them out of there if float() doesn't do anything?
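If the immediate goal is just a checkpoint that loads anywhere, one option, sketched here assuming plain PyTorch (`portable_state_dict` is a hypothetical helper, not part of torch or apex), is to copy the state dict to the CPU and cast every floating-point tensor to float32 while leaving integer buffers alone. Note that this only normalizes the stored tensors; it does not change which ops run, or in what precision, during a forward pass.

```python
import torch
import torch.nn as nn


def portable_state_dict(model: nn.Module) -> dict:
    """Copy of model.state_dict() with floating-point tensors as float32 on the CPU.

    Integer buffers such as BatchNorm's num_batches_tracked keep their dtype.
    The model itself is not modified.
    """
    portable = {}
    for name, tensor in model.state_dict().items():
        tensor = tensor.detach().cpu()
        if tensor.is_floating_point():
            tensor = tensor.float()
        portable[name] = tensor
    return portable


# Usage: even a model whose weights really were cast to half yields a float32 checkpoint.
half_model = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, padding=2), nn.BatchNorm2d(16)).half()
state = portable_state_dict(half_model)
print(state["0.weight"].dtype, state["1.num_batches_tracked"].dtype)  # torch.float32 torch.int64
torch.save(state, "./portable.pth")
```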

@mcarilli
Contributor

mcarilli commented Jun 10, 2019

With opt_level="O1", Amp shouldn't directly change the parameter or buffer attributes of your model at all. In other words, they should remain fp32, so amp_model.float() should be a no-op. Can you post the results of amp_model.state_dict(), or the output of

for param in model.parameters():
    print(param.dtype)

All of the leaves should be float.
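A slightly broader version of that check, sketched here assuming plain PyTorch (the helper name `non_fp32_leaves` is hypothetical), also walks the buffers, such as BatchNorm's running statistics, and reports only the floating-point tensors that are not float32, so an empty list means every leaf is fp32.

```python
import torch
import torch.nn as nn


def non_fp32_leaves(model: nn.Module) -> list:
    """Names and dtypes of floating-point parameters/buffers that are not float32."""
    tensors = list(model.named_parameters()) + list(model.named_buffers())
    return [
        f"{name}: {t.dtype}"
        for name, t in tensors
        if t.is_floating_point() and t.dtype != torch.float32
    ]


model = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, padding=2), nn.BatchNorm2d(16))
print(non_fp32_leaves(model))              # []: freshly built modules are float32
print(len(non_fp32_leaves(model.half())))  # 6: 4 parameters + running_mean + running_var
print(non_fp32_leaves(model.float()))      # []: .float() converts them back
```

Integer buffers like `num_batches_tracked` are skipped on purpose, since `.half()` and `.float()` leave them as `int64` anyway.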

However, ONNX is also recording ops and temporaries internal to the graph. These will execute in a mixture of float and half, regardless of what type the leaves are, because with opt_level="O1", torch functions are patched to cast their inputs on the fly. If you want an ONNX graph that's pure float end to end, you can keep using opt_level="O1" but run the export with auto-casting disabled:

# amp_model = amp_model.float() # with O1, this line should not be necessary
with amp.disable_casts():
    dummy_input = torch.randn(1, 1, 28, 28).to(device)
    torch.onnx.export(amp_model, dummy_input, './amp.onnx', verbose=True)

Unrelated: I suspect that your training time went up with mixed precision because your model is fairly small, and therefore not fully utilizing the device, such that the overhead of mixed precision casts becomes significant relative to the actual model ops. Also, your final linear layer's output size is 10, which is not a multiple of 8, and therefore won't be able to use Tensor Cores (#221 (comment)). You could probably pad the output size to 16 (6 unused/dummy classes) and see better performance.
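The padding arithmetic fits in a one-line helper. `pad_to_multiple` is a hypothetical name; the usage covers the 10-to-16 case for this model and the 343-to-344 case that comes up later in this thread.

```python
def pad_to_multiple(n: int, multiple: int = 8) -> int:
    """Round n up to the nearest multiple of `multiple` (unchanged if already aligned)."""
    return ((n + multiple - 1) // multiple) * multiple


num_classes = 10
padded = pad_to_multiple(num_classes)
print(padded, padded - num_classes)  # 16 6  (16 outputs, 6 dummy classes)
print(pad_to_multiple(343))          # 344
print(pad_to_multiple(32 * 7 * 7))   # 1568: the Linear input size is already aligned
```

Since the dummy classes never appear as labels, one option at inference time is to keep only the first `num_classes` logits (for example `outputs[:, :num_classes]`).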

@jtiscione
Author

jtiscione commented Jun 10, 2019

That line with amp.disable_casts() was exactly what I needed. Without it, all the leaf nodes were indeed f32, as you pointed out, but the ONNX export was recording the flow back and forth through f16. If I run the export under disable_casts(), this is the ONNX output:

graph(%input.1 : Float(1, 1, 28, 28),
      %layer1.0.weight : Float(16, 1, 5, 5),
      %layer1.0.bias : Float(16),
      %layer1.1.weight : Float(16),
      %layer1.1.bias : Float(16),
      %layer1.1.running_mean : Float(16),
      %layer1.1.running_var : Float(16),
      %layer1.1.num_batches_tracked : Long(),
      %layer2.0.weight : Float(32, 16, 5, 5),
      %layer2.0.bias : Float(32),
      %layer2.1.weight : Float(32),
      %layer2.1.bias : Float(32),
      %layer2.1.running_mean : Float(32),
      %layer2.1.running_var : Float(32),
      %layer2.1.num_batches_tracked : Long(),
      %lin1.weight : Float(16, 1568),
      %lin1.bias : Float(16)):
  %17 : Float(1, 16, 28, 28) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%input.1, %layer1.0.weight, %layer1.0.bias), scope: Net/Sequential[layer1]/Conv2d[0]
  %18 : Float(1, 16, 28, 28) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%17, %layer1.1.weight, %layer1.1.bias, %layer1.1.running_mean, %layer1.1.running_var), scope: Net/Sequential[layer1]/BatchNorm2d[1]
  %19 : Float(1, 16, 28, 28) = onnx::Relu(%18), scope: Net/Sequential[layer1]/ReLU[2]
  %20 : Float(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%19), scope: Net/Sequential[layer1]/MaxPool2d[3]
  %21 : Float(1, 32, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[5, 5], pads=[2, 2, 2, 2], strides=[1, 1]](%20, %layer2.0.weight, %layer2.0.bias), scope: Net/Sequential[layer2]/Conv2d[0]
  %22 : Float(1, 32, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%21, %layer2.1.weight, %layer2.1.bias, %layer2.1.running_mean, %layer2.1.running_var), scope: Net/Sequential[layer2]/BatchNorm2d[1]
  %23 : Float(1, 32, 14, 14) = onnx::Relu(%22), scope: Net/Sequential[layer2]/ReLU[2]
  %24 : Float(1, 32, 7, 7) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%23), scope: Net/Sequential[layer2]/MaxPool2d[3]
  %25 : Long() = onnx::Constant[value={0}](), scope: Net
  %26 : Tensor = onnx::Shape(%24), scope: Net
  %27 : Long() = onnx::Gather[axis=0](%26, %25), scope: Net
  %28 : Long() = onnx::Constant[value={-1}](), scope: Net
  %29 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
  %30 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
  %31 : Tensor = onnx::Concat[axis=0](%29, %30)
  %32 : Float(1, 1568) = onnx::Reshape(%24, %31), scope: Net
  %33 : Float(1, 16) = onnx::Gemm[alpha=1, beta=1, transB=1](%32, %lin1.weight, %lin1.bias), scope: Net/Linear[lin1]

Without those mentions of f16, it runs just fine in other ONNX environments. These little numbers are so weird-looking that it's no wonder they cause so much trouble.
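To check an export without eyeballing it, a small plain-Python sketch (`fp16_lines` is a hypothetical helper) can scan the `verbose=True` printout for the two fp16 markers seen in the logs in this thread: `Half(...)` type annotations and `Cast[to=10]` nodes, 10 being `FLOAT16` in ONNX's `TensorProto.DataType` enum. It assumes the printout format shown here; other PyTorch versions may print the graph differently.

```python
def fp16_lines(graph_dump: str) -> list:
    """Return the lines of a verbose torch.onnx.export printout that mention fp16."""
    markers = ("Half(", "Cast[to=10]")  # ONNX TensorProto.DataType 10 is FLOAT16
    return [
        line.strip()
        for line in graph_dump.splitlines()
        if any(marker in line for marker in markers)
    ]


# Abridged lines from the two printouts in this thread.
amp_dump = """
  %17 : Half(16, 1, 5, 5) = onnx::Cast[to=10](%layer1.0.weight)
  %23 : Half(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2]](%22)
"""
fixed_dump = """
  %20 : Float(1, 16, 14, 14) = onnx::MaxPool[kernel_shape=[2, 2]](%19)
  %33 : Float(1, 16) = onnx::Gemm[alpha=1, beta=1, transB=1](%32, %lin1.weight, %lin1.bias)
"""
print(len(fp16_lines(amp_dump)))  # 2
print(fp16_lines(fixed_dump))     # []
```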

I'm glad you mentioned that multiple-of-8 thing; it's easy to miss crap like that because stuff just works anyway. (Is there a way to tell whether Tensor Cores are actually being used? Should the batch size also be a multiple of 8?) Kicking the number of outputs up to 16 seems to have no performance impact with this little model, but the model I'm actually using is about 3000 times larger, with several huge fully connected layers and 343 output classes. It barely fits on the card, but I'll try it with 344.

@mcarilli
Contributor

Yes, the batch size should also be a multiple of 8. I pinned the issue I linked earlier (#221 (comment)), but it's still easy to miss. I'm planning to augment the Amp patching so that it checks the sizes of tensors entering FP16 linear layers and warns once if they are not multiples of 8.
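That warning is only planned at this point, so the following is a hypothetical plain-Python sketch of the idea, not Amp's code: `check_fp16_linear_shape` checks the batch size and a Linear layer's in/out features against a multiple of 8 and warns once per distinct problem. Applied to the benchmark above, it flags both the output size of 10 and the batch size of 100.

```python
import warnings

_already_warned = set()


def check_fp16_linear_shape(batch_size, in_features, out_features, multiple=8):
    """Hypothetical check: return the sizes that are not multiples of `multiple`,
    warning once per distinct set of offending sizes."""
    dims = {"batch_size": batch_size, "in_features": in_features, "out_features": out_features}
    bad = {name: size for name, size in dims.items() if size % multiple}
    key = tuple(sorted(bad.items()))
    if bad and key not in _already_warned:
        _already_warned.add(key)
        warnings.warn(f"sizes not a multiple of {multiple}, Tensor Cores may not be used: {bad}")
    return bad


# The MNIST benchmark: batch_size=100, Linear(32 * 7 * 7, 10).
print(check_fp16_linear_shape(100, 32 * 7 * 7, 10))  # {'batch_size': 100, 'out_features': 10}
```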
