
New FusedAdam (since #424 commit 8 Aug 2019) Issues #475

Open
vince62s opened this issue Sep 5, 2019 · 35 comments

@vince62s

vince62s commented Sep 5, 2019

Team,

This huge PR #424 was not squashed, and its commits, spread over a long period of time, trigger lots of side effects.

As an example, if I use FusedAdam as of commit 4a8c4ac, the GPU memory footprint allows training with a certain batch size (e.g. 3072 tokens per batch).

On master, I have to reduce the batch size to avoid a CUDA OOM
(e.g. 2944 tokens per batch; the use case is OpenNMT-py).

More concerning, during training, using the same config (batch size 2944 tokens) and all other params being equal:

MASTER of Sept 5 2019
[2019-09-05 12:29:50,751 INFO] Step 100/50000; acc: 3.55; ppl: 6812.21; xent: 8.83; lr: 0.00002; 23557/28574 tok/s; 148 sec
[2019-09-05 12:31:02,317 INFO] Step 200/50000; acc: 6.69; ppl: 2112.41; xent: 7.66; lr: 0.00003; 48483/58933 tok/s; 220 sec
[2019-09-05 12:32:14,252 INFO] Step 300/50000; acc: 10.53; ppl: 597.04; xent: 6.39; lr: 0.00005; 48349/58938 tok/s; 292 sec
[2019-09-05 12:33:26,245 INFO] Step 400/50000; acc: 13.50; ppl: 339.95; xent: 5.83; lr: 0.00006; 48524/58874 tok/s; 364 sec
[2019-09-05 12:34:41,417 INFO] Step 500/50000; acc: 15.14; ppl: 235.72; xent: 5.46; lr: 0.00008; 46461/56365 tok/s; 439 sec
[2019-09-05 12:35:53,654 INFO] Step 600/50000; acc: 17.41; ppl: 176.58; xent: 5.17; lr: 0.00009; 48243/58731 tok/s; 511 sec
[2019-09-05 12:37:06,072 INFO] Step 700/50000; acc: 19.30; ppl: 139.99; xent: 4.94; lr: 0.00011; 48149/58553 tok/s; 584 sec

Commit 4a8c4ac
[2019-09-05 12:42:05,446 INFO] Step 100/50000; acc: 3.61; ppl: 6631.33; xent: 8.80; lr: 0.00002; 23708/28735 tok/s; 147 sec
[2019-09-05 12:43:16,428 INFO] Step 200/50000; acc: 7.97; ppl: 1824.98; xent: 7.51; lr: 0.00003; 48888/59853 tok/s; 218 sec
[2019-09-05 12:44:27,514 INFO] Step 300/50000; acc: 11.75; ppl: 524.17; xent: 6.26; lr: 0.00005; 48952/59220 tok/s; 289 sec
[2019-09-05 12:45:38,700 INFO] Step 400/50000; acc: 14.74; ppl: 278.04; xent: 5.63; lr: 0.00006; 49226/59553 tok/s; 360 sec
[2019-09-05 12:46:53,512 INFO] Step 500/50000; acc: 17.45; ppl: 181.13; xent: 5.20; lr: 0.00008; 46560/56793 tok/s; 435 sec
[2019-09-05 12:48:05,109 INFO] Step 600/50000; acc: 19.77; ppl: 132.15; xent: 4.88; lr: 0.00009; 48644/59191 tok/s; 507 sec
[2019-09-05 12:49:16,393 INFO] Step 700/50000; acc: 22.03; ppl: 101.56; xent: 4.62; lr: 0.00011; 48692/59330 tok/s; 578 sec

The accuracy / ppl seem much better on the old FusedAdam.

Any clue?

@FDecaYed @mcarilli

@vince62s vince62s closed this as completed Sep 5, 2019
@vince62s vince62s reopened this Sep 5, 2019
@FDecaYed
Contributor

FDecaYed commented Sep 5, 2019

In terms of accuracy, the updated fusedadam incorporated this upstream fix:
pytorch/pytorch@fed5ca1
Results will be slightly different compared to before, but should now be correct and the same as upstream AdamW. Do you observe any regression in terms of final accuracy? This fix affects the initial steps more.

In terms of memory:
the updated fused adam no longer fuses grad unscaling and the update, so an fp32 grad is now instantiated. This could be the main source of the difference if you are comparing against the old fusedadam.

There could also be other changes affecting how memory gets reused, but those should not be strictly more memory use. Given the dynamic nature of PyTorch's caching allocator, any change to the order and life cycle of tensors can cause an OOM if you are very close to the limit, even though you are not using more memory in total.
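To picture the fp32-grad point, a rough sketch (illustrative only, not apex's actual code; the sizes and loss scale are made up):

        import torch

        # Pretend fp16 parameter with an fp16 gradient, as produced by an O2-style backward.
        p = torch.zeros(1024, 1024, dtype=torch.float16, requires_grad=True)
        p.grad = torch.randn(1024, 1024).half()
        loss_scale = 128.0

        # Old fused path (sketch): the fused kernel reads p.grad directly and applies
        # 1/loss_scale inside the update, so no fp32 gradient buffer is ever allocated.

        # New path (sketch): the grad is first unscaled into an fp32 master gradient,
        # so a full fp32 copy of the gradients exists while the step runs.
        master_grad = p.grad.float().div_(loss_scale)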

@vince62s
Author

vince62s commented Sep 5, 2019

I'll go for a full run and check if there is some regression, thanks for the explanation.

In terms of memory:
the updated fused adam no longer fuses grad unscaling and the update, so an fp32 grad is now instantiated. This could be the main source of the difference if you are comparing against the old fusedadam.

What is the rationale for this change?

@FDecaYed
Contributor

FDecaYed commented Sep 5, 2019

We want fused optimizers to work with AMP O1, without special-casing in both the optimizer and the amp backend.
We expect the speed benefit to be marginal (in your case ~1%) and the memory benefit to be small as well (in most cases the optimizer is not the peak of the caching allocator). So trading it for extensibility and maintainability seems to be a good deal.

@mcarilli speaking of the optimizer not being the peak, I think the new fusedadam now keeps more memory than before after zero_grad.
On 4a8c4ac, FusedAdam is still handled by the FP16_Optimizer path, which means the fp32 grad never exists and the fp16 grad becomes None after zero_grad().
In latest master, only the fp32 grad becomes None after zero_grad(); the fp16 grad is zeroed but keeps existing through forward/backward (higher potential for OOM compared to before).
Any thoughts on changing this as well to improve that?
https://github.com/NVIDIA/apex/blob/master/apex/amp/_process_optimizer.py#L366
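A tiny sketch of the difference I mean (illustrative only, not the actual _process_optimizer code):

        import torch

        p = torch.zeros(1024, 1024, dtype=torch.float16, requires_grad=True)
        p.grad = torch.randn(1024, 1024).half()

        # FP16_Optimizer-style zero_grad (sketch): the fp16 grad is dropped, so its memory
        # can be reused by the caching allocator during the next forward/backward.
        p.grad = None

        # New-API zero_grad for fp16 params (sketch): the grad tensor is kept alive and
        # merely zeroed, so its buffer stays allocated through forward/backward.
        p.grad = torch.randn(1024, 1024).half()
        p.grad.zero_()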

@vince62s
Author

vince62s commented Sep 6, 2019

OK, I trained a bit longer but I am pretty sure something is not right.
Here is training of a base Transformer for the first 15k steps (on 4 GPUs):

Old FusedAdam:
[2019-09-05 14:58:39,291 INFO] Step 14000/50000; acc: 87.11; ppl: 1.56; xent: 0.44; lr: 0.00075; 99195/112956 tok/s; 8319 sec
[2019-09-05 15:08:35,610 INFO] Step 15000/50000; acc: 87.16; ppl: 1.55; xent: 0.44; lr: 0.00072; 98597/112762 tok/s; 8915 sec
[2019-09-05 15:08:35,709 INFO] number of examples: 3000
[2019-09-05 15:08:43,581 INFO] Validation perplexity: 11.5924
[2019-09-05 15:08:43,581 INFO] Validation accuracy: 62.038

New FusedAdam:
[2019-09-06 11:14:15,308 INFO] Step 14000/50000; acc: 83.72; ppl: 1.83; xent: 0.60; lr: 0.00075; 99332/114016 tok/s; 8272 sec
[2019-09-06 11:24:08,907 INFO] Step 15000/50000; acc: 83.71; ppl: 1.83; xent: 0.60; lr: 0.00072; 99135/113397 tok/s; 8865 sec
[2019-09-06 11:24:09,016 INFO] number of examples: 3000
[2019-09-06 11:24:16,982 INFO] Validation perplexity: 11.0455
[2019-09-06 11:24:16,982 INFO] Validation accuracy: 60.5374

Losing 1 to 2 BLEU depending on the test set.

@FDecaYed
Contributor

Hi @vince62s
Have you got any final result that shows poor convergence?
Also, the best thing we can try is probably to test your code with the latest upstream AdamW. That is our 'gold standard' here. If both it and the new fusedadam show worse convergence, then it will be very interesting to investigate whether the fix causes this and start a wider conversation.

I have tested the new fused adam against the latest upstream PyTorch AdamW, with this example:
https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-optim
The new fusedadam is a lot closer to AdamW in fp32. With amp O2, I don't see any meaningful difference either; everything seems to be within numerical error range.
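The comparison I ran was roughly of this shape (a sketch based on the linked tutorial, not the exact script; the hyperparameters and seed are illustrative):

        import torch
        from apex.optimizers import FusedAdam

        N, D_in, H, D_out = 64, 1000, 100, 10
        x, y = torch.randn(N, D_in).cuda(), torch.randn(N, D_out).cuda()

        def make_model():
            torch.manual_seed(0)                      # identical init for both runs
            return torch.nn.Sequential(
                torch.nn.Linear(D_in, H),
                torch.nn.ReLU(),
                torch.nn.Linear(H, D_out)).cuda()

        loss_fn = torch.nn.MSELoss()
        for opt_cls in (torch.optim.AdamW, FusedAdam):
            model = make_model()
            optimizer = opt_cls(model.parameters(), lr=1e-4)
            for step in range(500):
                loss = loss_fn(model(x), y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # The two final losses should be close (within numerical error).
            print(opt_cls.__name__, loss.item())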

@vince62s
Author

My O2 example above had already converged.
I'll try AdamW, but as a matter of fact there is a difference between the old and new fusedadam, and not only because of the "minor fix" you mentioned. Or the fix is not minor.

@FDecaYed
Contributor

FDecaYed commented Sep 10, 2019

How big is the difference on final result?

It is 'minor' per upstream. pytorch/pytorch@fed5ca1
And I definitely agree with you: we need to find out whether the change you're seeing is from this fix (meaning the fix is not minor), or whether there is another issue on our side.

@vince62s
Author

@mcarilli @FDecaYed

It seems it is not related to the "minor" fix at all.
Let me explain.

At OpenNMT-py, before the integration of the new Apex API (this commit: OpenNMT/OpenNMT-py@aaa220b),
we used to test both Adam and FusedAdam and they gave similar results (both going through FP16_Optimizer).
Let me call this "Good results"

After the above commit, Adam in FP16 was using the new API, while FusedAdam was still using the FP16_Optimizer path, as mentioned by @FDecaYed in a previous post.

Since then we did all our training with FusedAdam.

Starting Aug 7/8, with the merge of #424, FusedAdam switched to the new API.
We started to have weird results (the reason why I opened this issue).

We have tested several runs with the same preprocessed data and all hyperparameters being the same. Here is what we have:
Old API (as stated above): good results for both FusedAdam and Adam
New API: wrong results for both FusedAdam and Adam
FP32 Adam without Apex: good results
PyTorch 1.1 vs. 1.2: no impact (hence the 'minor fix' has no impact)

Bottom line: the issue is with the new API, or with the way we use the new API.
However, using the new API with the old path for FusedAdam was fine, so we think our calls to the new API are OK.

We really need to find out what is going on because if we want no regression we would have to go back to FP16_Optimizer.

NB: I can give you some detailed numbers by email if needed. Also, the dataset can be downloaded and I can give you a small script so that you can replicate.

Cheers.
Vincent

@mcarilli
Contributor

mcarilli commented Sep 11, 2019

https://github.com/OpenNMT/OpenNMT-py/blob/97ad4c66997d6d4d80290f66ddc3f5dab6557ab5/onmt/utils/optimizers.py#L89-L90

Looks to me like your integration isn't right. You're passing a list with 2 models, but only receiving one on the LHS. Also, if model.generator is a submodule of model, you don't need to pass both. I would try changing this to

        model, optimizer = apex.amp.initialize(
            model,
            optimizer,
            opt_level=opt.apex_opt_level,
            loss_scale=loss_scale,
            # keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)
            keep_batchnorm_fp32=True)

One benefit of the latest FusedAdam API is that it can handle a mixture of fp32 and fp16 params, so you may say keep_batchnorm_fp32=True in all cases. FP16 batchnorm may well be a source of numerical degradation (and performance degradation because cudnn batchnorm requires keep_batchnorm_fp32=True).

@vince62s
Author

vince62s commented Sep 11, 2019

Your suggestion does not work (we tested it as a first attempt at the beginning, and I retried it); it spits out a cast error.
I just added a list on the LHS and am getting the same results.
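For reference, the list-on-the-LHS variant I tried looks roughly like this (a sketch following our optimizers.py, not the exact diff):

        [model, model.generator], optimizer = apex.amp.initialize(
            [model, model.generator],
            optimizer,
            opt_level=opt.apex_opt_level,
            loss_scale=loss_scale,
            keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)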

Not even talking about FusedAdam: is there anywhere a regression test between the old API (FP16_Optimizer) and the new API, with plain Adam, on a real-world dataset (not just a toy set)?

I'll give a try with keep_batchnorm_fp32=True

@mcarilli
Contributor

There are small unit tests but not an end-to-end test. However, we also use the new FusedAdam for full workloads and it seems fine. Does the latest version with keep_batchnorm_fp32=True improve accuracy?

@FDecaYed
Contributor

FDecaYed commented Sep 11, 2019

Does the amp API with PyTorch Adam work before #424?
If not:
then the problem is amp (or how you use it), not the new fusedadam. You happened to start seeing this now because the old fusedadam with amp falls back to fp16_optimizer in the backend and hides the problem.
If it works:
then the only difference between that and now should be the optimizer (ignoring other possible changes
in the O2 backend that got pulled in from the multi_tensor_sgd branch? @mcarilli)

This question will help us determine where to look next.

Some other points:

  • both PyTorch 1.1/1.2 have the old Adam/AdamW. If we want to test the effect of the optimizer fix, we need the 1.3 AdamW (we may be able to rule this out without testing, given the question above).
  • it seems the grad clipping code in your repo is broken for fp16, no matter what. This should not be the reason for this issue, though.

@vince62s
Author

vince62s commented Sep 11, 2019

To your first question:
it works in the sense that it runs, but the results are not good: significantly worse than before amp.
So yes, the problem is either amp or the way we use it.

However, when we use amp the way we use it, with the old fusedadam, then it works, maybe because it falls back to fp16_optimizer, but at least it shows that we call amp correctly,
right?

NB: can you be more specific about the clipping code thing? (yes, it should be unrelated to our issue since I am not using max_grad_norm in my tests)

@mcarilli
Contributor

For reference, how does accuracy look with O1? In general O1 is always best practice, and the new FusedAdam should permit it (i.e. if that was the only reason you used O2, it is no longer relevant, but maybe there is some other reason).

@FDecaYed
Contributor

For your way of using amp O2, I feel it is probably OK:
you passed in a list of model and model.generator, but never caught the return, so both of them will be cast in place and your old handles to them should now see the fp16 models correctly.
I do agree with @mcarilli that O1 is much better maintained. If O1 works, and since it can now be used with fusedadam, then switching to it might be the correct way moving forward.

For clipping:
I think 'update_master_grads' and 'clip_master_grads' are things from the old fp16_utils.fp16_optimizer. So after you switched to the amp API, neither real amp nor the old FusedAdam (falling back to optimizers.fp16_optimizer) has those, so the clipping code has never been hit since then.
Plus, apex.amp.master_params(self._optimizer) should be called instead of apex.amp.master_params(self). If that code path were ever executed, you would already see
'Optimizer' object has no attribute 'param_groups'

@vince62s
Author

Just finished the first 5000 iterations on 4 GPUs with keep_batchnorm_fp32=True ==> no change.
Trying O1 now, but I fear that for our use case not keeping a master copy of the weights in fp32 might be an issue.

@mcarilli
Contributor

With O1, the model weights are left as FP32 (in other words, the model weights are the master weights). O1 is more conservative about casting to FP32 for internal functions (which is why it's recommended as an out-of-the-box approach) so it may use more memory, but hopefully not much.
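For reference, the basic O1 recipe is just (a minimal sketch; model, optimizer and loss stand for your own objects):

        from apex import amp

        # Weights stay FP32; amp patches functions to run in FP16 where it is safe.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        # Training step: scale the loss for backward, then step the optimizer as usual.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()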

@vince62s
Author

It actually starts much, much better. Letting it run with Adam first; I will do FusedAdam just after.
Now, it means something might be wrong with O2 then.

@mcarilli
Contributor

O2 was our original hypothesized recipe for mixed precision, but maintaining separate master weights is more confusing for everyone. IMO the only remaining utility of O2 is to support certain internal use cases. O1 is what I hope all public codes are able to use, and O1 is the implementation that I'm working on for Amp upstream (pytorch/pytorch#25081). The API will be different (as requested by upstream) but more flexible and powerful.

@vince62s
Author

One thing though:
O1 Adam: 61k tok/sec
O1 FusedAdam: 69k tok/sec
O2 FusedAdam: 79k tok/sec
Is O1 supposed to be that much slower?

@mcarilli
Contributor

It depends. Currently O1 in Apex relies on a lot of Python side patching logic, so if any sections of your model are CPU bound, those sections may slow down. The upstream integration of O1-style casting will be entirely on the C++ side, and be faster.

@vince62s
Author

okay an update with FusedAdam O1:
it seems to be numerically unstable

Getting this:
[2019-09-12 10:16:09,307 INFO] Step 6100/100000; acc: 56.92; ppl: 7.94; xent: 2.07; lr: 0.00113; 59343/68299 tok/s; 2599 sec
[2019-09-12 10:16:51,006 INFO] Step 6200/100000; acc: 56.59; ppl: 8.13; xent: 2.10; lr: 0.00112; 60399/68907 tok/s; 2641 sec
[2019-09-12 10:17:32,916 INFO] Step 6300/100000; acc: 56.81; ppl: 8.00; xent: 2.08; lr: 0.00111; 60540/68854 tok/s; 2683 sec
[2019-09-12 10:18:14,820 INFO] Step 6400/100000; acc: 56.72; ppl: 8.04; xent: 2.08; lr: 0.00110; 60109/67698 tok/s; 2725 sec
[2019-09-12 10:18:56,717 INFO] Step 6500/100000; acc: 57.68; ppl: 7.58; xent: 2.03; lr: 0.00110; 61251/68279 tok/s; 2767 sec
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1048576.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 262144.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 65536.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1024.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 256.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 64.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 16.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.25
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.0625
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.015625
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.00390625
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.0009765625
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.000244140625
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 6.103515625e-05
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.52587890625e-05
... many many lines
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4.0389678347315804e-28
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.0097419586828951e-28
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.524354896707238e-29
[2019-09-12 10:19:38,376 INFO] Step 6600/100000; acc: 23.13; ppl: nan; xent: nan; lr: 0.00109; 60433/68886 tok/s; 2808 sec
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1048576.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 262144.0
... again many many lines.
Several times the same scenario, and it ends up with:
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 7.571533991467358e-270
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.8928834978668395e-270
[2019-09-12 10:22:59,617 INFO] Loading dataset from exp/dataset.en-de.train.2.pt
Traceback (most recent call last):
File "/pytorchwork/OpenNMT-py-MASTER/onmt/trainer.py", line 372, in _gradient_accumulation
self.optim.backward(loss)
File "/pytorchwork/OpenNMT-py-MASTER/onmt/utils/optimizers.py", line 317, in backward
scaled_loss.backward()
File "/usr/lib/python3.6/contextlib.py", line 88, in exit
next(self.gen)
File "/usr/local/lib/python3.6/dist-packages/apex/amp/handle.py", line 120, in scale_loss
optimizer._post_amp_backward(loss_scaler)
File "/usr/local/lib/python3.6/dist-packages/apex/amp/_process_optimizer.py", line 241, in post_backward_no_master_weights
post_backward_models_are_masters(scaler, params, stashed_grads)
File "/usr/local/lib/python3.6/dist-packages/apex/amp/_process_optimizer.py", line 127, in post_backward_models_are_masters
scale_override=(grads_have_scale, stashed_have_scale, out_scale))
File "/usr/local/lib/python3.6/dist-packages/apex/amp/scaler.py", line 176, in unscale_with_stashed
out_scale/grads_have_scale, # 1./scale,
ZeroDivisionError: float division by zero

@vince62s
Author

Since I need a working solution, I tried the following:

I copied the old FusedAdam class into my optimizers.py code, and used the current master FP16_Optimizer wrapper. It works fine, and is actually a bit faster.

If I try to use the current master FusedAdam in the same way (that is to say, FusedAdam without amp), then it does not work.
Either, in theory, the new FusedAdam without amp should be wrapped in FP16_Optimizer, but then, since its step() expects no arguments, we have an issue here: https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fp16_optimizer.py#L151
OR
if FusedAdam without amp is not supposed to be wrapped and is used the same way as Adam, then it does not work (does not learn).

I think I may stick to the old way until all of this gets integrated into PyTorch. But if O1 is the right way, then something might be missing for FusedAdam.

@mcarilli
Contributor

Hmm, it's beginning to sound more like there's something wrong with the new FusedAdam, but like I said, we have used it ourselves for several applications. Where are you implementing gradient clipping in your code, if anywhere?

@vince62s
Author

here: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/utils/optimizers.py#L328-L339
but the code is wrong, as mentioned by @FDecaYed. I will fix it.
However, in my experiment, I am training a transformer with max_grad_norm = 0, so this code is not hit at all.

@mcarilli
Contributor

mcarilli commented Sep 12, 2019

Yes, with the new FusedAdam+new API (no FP16_Optimizer) I don't believe you need to special-case the clipping code for FP16 at all. You can replace

        if self._fp16:
            if hasattr(self._optimizer, "update_master_grads"):
                self._optimizer.update_master_grads()
            if hasattr(self._optimizer, "clip_master_grads") and \
               self._max_grad_norm > 0:
                import apex
                torch.nn.utils.clip_grad_norm_(
                    apex.amp.master_params(self), self._max_grad_norm)
        for group in self._optimizer.param_groups:
            group['lr'] = learning_rate
            if not self._fp16 and self._max_grad_norm > 0:
                clip_grad_norm_(group['params'], self._max_grad_norm)

with simply

        for group in self._optimizer.param_groups:
            group['lr'] = learning_rate
            if not self._fp16 and self._max_grad_norm > 0:
                clip_grad_norm_(group['params'], self._max_grad_norm)

if your intention is to clip per-param-group.

However, if you're not clipping at all and still seeing problems, this is likely unrelated...Are you passing any sparse params/gradients to FusedAdam? I don't think it supports sparse gradients, but I don't think the old FusedAdam did either...

@vince62s
Author

Are you passing any sparse params/gradients to FusedAdam? I don't think it supports sparse gradients, but I don't think the old FusedAdam did either...

no


@mcarilli
Contributor

amp.master_params simply pulls from the optimizer param groups. It's recommended in the documentation to prevent people from, for example, passing model.parameters() to the clipping code, which with O2 may be different from the params present in the param groups. Pulling directly from the param groups is also always acceptable. https://github.com/OpenNMT/OpenNMT-py/pull/1560/files#diff-423be3bd1890af1b892704ba31891075R349-R364 could be condensed to

        if self._fp16 == "legacy":
            if hasattr(self._optimizer, "update_master_grads"):
                self._optimizer.update_master_grads()
            if hasattr(self._optimizer, "clip_master_grads") and \
               self._max_grad_norm > 0:
                self._optimizer.clip_master_grads(self._max_grad_norm)

        for group in self._optimizer.param_groups:
            group['lr'] = learning_rate
            if self._fp16 is None and self._max_grad_norm > 0:
                clip_grad_norm_(group['params'], self._max_grad_norm)

In the non-legacy case you are clipping per param group though, as opposed to clipping the gradients of all the params together.
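If you'd rather keep the old behavior of clipping all the gradients together, something along these lines should work (a sketch using your self._optimizer / self._max_grad_norm names):

        import torch
        from apex import amp

        # Clip the gradients of all params together (global norm), not per param group.
        if self._max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(
                amp.master_params(self._optimizer), self._max_grad_norm)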

@vince62s vince62s changed the title New FusedAdam (since #424 commit 8 Aug 2019) memory Footprint New FusedAdam (since #424 commit 8 Aug 2019) Issues Sep 13, 2019
@FDecaYed
Contributor

@mcarilli Using the old fusedadam and thus clipping by group is also a known issue. We are working on our BERT fine-tuning example to see what's the best way to address that.
@vince62s The only known difference between PyTorch AdamW (make sure it is AdamW and post-1.2) and the new fusedadam is the fix. We would really like to rule that out first.
You can try that (the easiest way is probably to copy-paste the top-of-tree adamw.py, since that's a Python-only change upstream) or you can provide me a repro so I can help. A Dockerfile or versions would be nice to rule out other differences.

@vince62s
Author

@FDecaYed I am not sure what we want to prove here.
Based on my previous post, trying to clarify:

  1. New FusedAdam with AMP/O1 looks unstable, but that does not mean it is algorithmically wrong. Do you want me to compare with AdamW with AMP/O1?
  2. New FusedAdam without AMP does not work. Or maybe I am not getting the proper way to use it without AMP.

Hope this is clearer.
Anyhow, I'll give AdamW a try when I get time.

@FDecaYed
Contributor

FDecaYed commented Sep 13, 2019

  1. I'm trying to rule out whether the instability is caused by the upstream algo change. We know that before the change, with O1, it works.
    Thus running the new AdamW with O1 should tell us:
    if it does not work, then the upstream change is the problem;
    if it works, then the new fusedadam has its own problem.
  2. 'New FusedAdam without AMP does not work' — meaning you have tried it with fp32, no amp?

@vince62s
Author

ok...

  1. Same crash with AdamW + O1
  2. I misunderstood the comment in FusedAdam. I thought "without amp" meant "the old FP16_Optimizer path". So yes, the new FusedAdam with FP32 works fine. What I meant is that with current apex master there is no way to use FusedAdam with the old API, correct?

Anyhow, we now need to find out why AdamW makes things unstable with FP16.

@FDecaYed
Contributor

If the new upstream AdamW with O1 also causes the crash (and I suspect upstream AdamW with O2 would also cause an accuracy drop similar to what we see with the new fusedadam with O2), then it does seem the algorithm change causes issues when working with fp16.

I feel it's a problem worth looking into. Let's probably start with a repro and add the old algorithm back to FusedAdam.

@FDecaYed
Contributor

FDecaYed commented Oct 1, 2019

@mcarilli shall we look into this and track it?
The latest results seem to indicate the new upstream Adam/AdamW is unstable when used with amp (not related to fused, though).
