Update no_trainer scripts to include gradient accumulation #18436

Open · 9 tasks
muellerzr opened this issue Aug 2, 2022 · 18 comments
Labels
Examples (Which is related to examples in general), Good First Issue

Comments

@muellerzr
Contributor

muellerzr commented Aug 2, 2022

Feature request

🤗 Accelerate has a gradient accumulation wrapper, and the no_trainer scripts should be updated to include it!

An example can be seen here; below is an example diff of what the integration would look like:

-     accelerator = (
-         Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
-     )
+     accelerator = (
+         Accelerator(log_with=args.report_to, logging_dir=args.output_dir, gradient_accumulation_steps=args.gradient_accumulation_steps) if args.with_tracking else Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+     )

As well as:

-     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     num_update_steps_per_epoch = len(train_dataloader)

...


for step, batch in enumerate(train_dataloader):
+     with accelerator.accumulate(model):
-             loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
-             if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1
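
For reference, here is a minimal end-to-end sketch of what the integrated loop could look like. The toy model, data, and hyperparameters are made up purely for illustration; only the Accelerator / accumulate / sync_gradients usage mirrors the diff above:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

gradient_accumulation_steps = 4
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# Toy stand-ins for the model, data, optimizer and scheduler used by the real scripts.
model = torch.nn.Linear(16, 2)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
train_dataloader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=len(train_dataloader))

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

completed_steps = 0
model.train()
for step, (inputs, labels) in enumerate(train_dataloader):
    # accumulate() keeps accumulating gradients and defers the sync until enough batches were seen.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    # sync_gradients is only True on batches where a real optimizer update happened.
    if accelerator.sync_gradients:
        completed_steps += 1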

The list of available scripts to update includes:

  • examples/pytorch/image-classification/run_image_classification_no_trainer.py
  • examples/pytorch/language-modeling/run_clm_no_trainer.py
  • examples/pytorch/language-modeling/run_mlm_no_trainer.py
  • examples/pytorch/multiple-choice/run_swag_no_trainer.py
  • examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
  • examples/pytorch/question-answering/run_qa_no_trainer.py
  • examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
  • examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
  • examples/pytorch/summarization/run_summarization_no_trainer.py

Motivation

This is a great first issue for someone who wants to learn how to use some of the latest bits in Accelerate and get an easy beginner contribution to the library 🤗

Your contribution

If you decide to pick up this issue, feel free to ping me (@muellerzr), @sgugger, or @pacman100 to review 🤗

@muellerzr muellerzr added the Examples (Which is related to examples in general) and Good First Issue labels on Aug 2, 2022
@Rasmusafj
Contributor

Hi @muellerzr

I took a go at this (Accelerate seems awesome!) and implemented the changes quickly. However, I noticed some performance degradation when using the gradient accumulation wrapper.

After some debugging, I think it stems from the lr_scheduler implementation in Accelerate updating the learning rate at every step of the training loop, whereas the example script only updates the learning rate on optimizer steps.

So I think either Accelerate needs to add something like

# Otherwise, first make sure the optimizer was stepped.
for opt in self.optimizers:
    if opt.step_was_skipped or not opt.gradient_state.sync_gradients:
        return

to the scheduler.py implementation at line 59.

Or the script should have

if accelerator.sync_gradients:
    lr_scheduler.step()

I think this should be changed in Accelerate. Let me know what you think or if I'm totally off! I'll be happy to do an issue + PR to fix it in Accelerate, and I'll definitely fix the example scripts in Transformers. :)

@sgugger
Collaborator

sgugger commented Aug 4, 2022

No, we can't do this, as the user would then have to know in advance the number of optimization steps when they create their scheduler (which they don't, since Accelerate handles gradient accumulation behind the scenes). That's why the learning rate scheduler should be created with the full number of training batches prior to gradient accumulation and then stepped at each batch (which is roughly equivalent to creating it with the right number of optimization steps and stepping at every optimization step).
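
A back-of-the-envelope illustration of that equivalence (all numbers below are made up):

import math

num_batches_per_epoch = 1000
num_train_epochs = 3
gradient_accumulation_steps = 4

# Scheduler sized on raw batches and stepped once per batch (what is suggested here):
steps_if_stepped_per_batch = num_train_epochs * num_batches_per_epoch  # 3000

# Scheduler sized on optimizer updates and stepped once per update (the classic setup):
updates_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)
steps_if_stepped_per_update = num_train_epochs * updates_per_epoch  # 750

# Both schedules cover the same fraction of their horizon after any given point in training,
# which is why the two setups end up roughly equivalent.
print(steps_if_stepped_per_batch, steps_if_stepped_per_update)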

@Rasmusafj
Contributor

@sgugger Cool!

So if I understand your comment:

  • The learning rate scheduler should not know anything about the actual optimization steps, but assume every batch is a step
    • Hence, num_training_steps for the lr_scheduler is math.ceil(len(train_dataloader)) * args.num_train_epochs, without taking gradient_accumulation_steps into account
    • This means that if gradient_accumulation_steps is 5, we will take 4 learning rate scheduler steps without actually using them for gradient updates

I've made a WIP pull request for the examples/pytorch/image-classification/run_image_classification_no_trainer.py script (I'll update the rest of the scripts once I'm certain it's the correct approach).

  • The current functionality of progress_bar / completed_steps is to only increment when doing an optimization step, i.e.
if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
    progress_bar.update(1)
    completed_steps += 1

So to keep the functionality, we need to know whether an optimization step occurred, for which I think we can use

if accelerator.sync_gradients:
    progress_bar.update(1)
    completed_steps += 1

but is this also something that should be done away with, i.e. should the logic change a bit so that completed_steps counts completed batches instead of optimization steps?

@sgugger
Collaborator

sgugger commented Aug 4, 2022

It's going to be easier to just have the progress bar display the number of completed steps. Also, we should multiply max_steps (if the user provides it) by the number of gradient accumulation steps for the same reason.
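
For instance, with a hypothetical helper (illustrative only, not code from the scripts):

def scheduler_horizon(max_train_steps, gradient_accumulation_steps, num_epochs, batches_per_epoch):
    # max_train_steps counts optimizer updates, so a scheduler stepped once per batch
    # needs gradient_accumulation_steps times as many steps.
    if max_train_steps is not None:
        return max_train_steps * gradient_accumulation_steps
    return num_epochs * batches_per_epoch

print(scheduler_horizon(10_000, 4, 3, 1_000))  # 40000
print(scheduler_horizon(None, 4, 3, 1_000))    # 3000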

@muellerzr
Contributor Author

I think either option would work fine as well. The reason behind exposing sync_gradients as part of the Accelerator is to provide an open interface for performing exactly this kind of check, so from an API design standpoint it's correct.

My $0.02 is to either briefly explain in a comment what sync_gradients checks, or to do as Sylvain recommended here.

@vedant-z

Hi @muellerzr, I opened PR #18601 for the second example in the list.

@Snimm

Snimm commented Dec 29, 2022

Hi @muellerzr, I opened a PR for the 8th example on the list. Please let me know if something is wrong. (This is my first contribution ever.)

@mszsorondo

Hi @muellerzr!
Are there any scripts left to update?

@sameerreddy13

sameerreddy13 commented Feb 2, 2023

Hi, I believe there is an issue with this PR (Rasmusafj:issue_18436), particularly for run_mlm_no_trainer.py. I am running BERT pretraining with this script on 8 GPUs with the following arguments:
--num_warmup_steps 10000 --max_train_steps 200000 --checkpointing_steps 500 --per_device_batch_size 256 --gradient_accumulation_steps 2

When tracking the learning rate, it peaks at step 2500 (completed_steps == 2500), even though training will stop at 200k completed_steps. My guess is that the learning rate is stepped for each of the 8 GPUs, so the warmup is actually only 10k / 8 = 1.25k steps. Multiplied by the 2 gradient accumulation steps, which are likely accounted for by the accumulate wrapper, we end up with 2.5k warmup steps.

I saw @Rasmusafj suggest above that we only step the learning rate when sync_gradients is true, which I believe would solve this issue for me and bring about the expected behavior. I saw that @sgugger recommended against this, however.

I am tracking the learning rate by printing lr_scheduler.get_last_lr()[0] every checkpointing_steps interval.
NOTE: I am using Accelerate with the DeepSpeed plugin.
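
The arithmetic behind that guess can be reproduced directly (this just mirrors the reasoning above; it is not a verified explanation of the scheduler internals):

num_warmup_steps = 10_000        # --num_warmup_steps
num_processes = 8                # GPUs
gradient_accumulation_steps = 2  # --gradient_accumulation_steps

# If the scheduler is advanced once per process, the warmup budget is consumed
# num_processes times too fast...
warmup_after_process_scaling = num_warmup_steps / num_processes  # 1250.0

# ...while the accumulate wrapper stretches it back out by the accumulation factor,
# which lands exactly on the observed peak in terms of completed_steps:
observed_peak = warmup_after_process_scaling * gradient_accumulation_steps  # 2500.0
print(observed_peak)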

@sgugger
Collaborator

sgugger commented Feb 2, 2023

cc @muellerzr so it's on your radar. It's true that when we use a number of steps instead of a number of epochs for a given training, the logic we have for the scheduler fails.

@Hannibal046

I'm hitting the same problem as @sameerreddy13.

@Hannibal046

Maybe we should make it clear what a step means in warmup_steps: one step of fetching data from the dataloader, or one completed_step?

@sameerreddy13

sameerreddy13 commented May 5, 2023

It should always be one gradient update step, because that is the common assumption in the literature and it is tied to the learning rate scheduler. In practice, if we have batch size K and gradient accumulation A, we report the effective batch size as K * A. To fully fix this issue I did the following:

lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
)
...
if step % args.gradient_accumulation_steps != 0:
    # Gradients only accumulate
    with accelerator.no_sync(model):
        outputs = model(**batch)
        accelerator.backward(outputs.loss)
else:
    # Gradients finally sync
    outputs = model(**batch)
    accelerator.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()
    if (
        completed_steps < args.num_warmup_steps
        or lr_scheduler.get_last_lr()[0] > args.min_learning_rate
    ):
        lr_scheduler.step()

@sameerreddy13

sameerreddy13 commented May 5, 2023

It's been a while since I made this change, but I manually used no_sync. IIRC there was some underlying issue with accelerator.accumulate(model). I believe that when I ran a validation loop inside the training loop (say, every K batches you want to get the validation loss), it broke the gradient accumulation, and only one gradient accumulation step would happen regardless of the configured argument. You can see this at a coarse-grained level by putting a validation step inside the train loop, setting grad_accum to something like 4, and observing the training suddenly speed up after the first evaluation.

@Hannibal046

@sameerreddy13, I agree with you. I also wrote a snippet about this at huggingface/accelerate#1382 (comment), with two differences:

  • First, I initialize my lr_scheduler without the * accelerator.num_processes scaling and don't pass it to prepare (see the sketch below); do you think this is equivalent to yours?
  • Second, I still use accelerator.accumulate(model) because I didn't notice the underlying issue; if that is really the case, what about only validating after a certain number of completed steps rather than a certain number of batches?
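
To make the first point concrete, roughly this (a toy, untested sketch; only the scheduler handling is the point, everything else is a stand-in):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

grad_accum = 4
accelerator = Accelerator(gradient_accumulation_steps=grad_accum)

model = torch.nn.Linear(16, 2)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
train_dataloader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Scheduler sized on real optimizer updates and deliberately NOT passed to accelerator.prepare.
num_updates = len(train_dataloader) // grad_accum
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=num_updates)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

for inputs, labels in train_dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    # Step the scheduler only when gradients actually synced, i.e. once per real update.
    if accelerator.sync_gradients:
        lr_scheduler.step()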

@sandeepchittilla

Is this issue still open? Can the relevant people mark which PRs have been merged and which are WIP?

I see there is #18601 from @vedant-z, but it's been closed?

@lzy37ld

lzy37ld commented Oct 28, 2023

Sorry, any update or final answer here?

@Petroncini

Is this issue still open, @muellerzr? I'm looking for a good first issue to tackle as a first-time contributor.
