
Gradient accumulate optimizer #2260

Closed
dathudeptrai opened this issue Dec 1, 2020 · 44 comments

@dathudeptrai

Describe the feature and the current behavior/state.

Hi, I think it would be good if someone could add a gradient-accumulation optimizer to this repo. This feature is really helpful for those who train large models, such as BERT, with limited resources. The usage should be similar to tfa.optimizers.SWA:

opt = ...
accumulate_opt = tfa.optimizers.AccumulationOptimizer(opt, accumulate_steps=5)

There is an implementation of a gradient accumulator here (link), but for a custom training loop rather than Keras `model.fit`.

Relevant information

  • Are you willing to contribute it (yes/no): no
  • Are you willing to maintain it going forward? (yes/no): no
  • Is there a relevant academic paper? (if so, where):
  • Is there already an implementation in another framework? (if so, where): here, but for a custom training loop.
  • Was it part of tf.contrib? (if so, where): no

Which API type would this fall under (layer, metric, optimizer, etc.)?
optimizer
Who will benefit from this feature?
All TensorFlow users.
Any other info.

@seanpmorgan
Member

@tomerk please bring this up in ecosystem review, though I don't expect any conflicts. For the future, do you need us to tag anyone, or is the ecosystem-review label sufficient?

@bhack
Contributor

bhack commented Dec 8, 2020

@dathudeptrai
Author

@tomerk @bhack @seanpmorgan any update? :D

@dathudeptrai
Author

@bhack just a reminder :D

@bhack
Contributor

bhack commented Dec 17, 2020

Check #2196 (comment)

@tomerk
Contributor

tomerk commented Dec 17, 2020

Notes from ecosystem review:

Rather than an optimizer that wraps another optimizer, we think this might actually make sense as an object to use as a gradient_transformer in optimizers (a new feature added after an optimizer refactoring earlier this year):

https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer

Looping in @omalleyt12 who might have more insight / suggestions on how to do this.

Depending on how it comes out it might make sense in core, but addons seems like a good initial spot.
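
For illustration, the hook referred to here looks roughly like this (a minimal sketch; the halving transform is only a placeholder to show the hook's shape, not an accumulation implementation, and halve_grads is an illustrative name):

import tensorflow as tf

# `gradient_transformers` takes callables that map a list of
# (gradient, variable) pairs to a transformed list; they run after
# gradient aggregation and before the optimizer's update step.
def halve_grads(grads_and_vars):
    return [(g * 0.5, v) for g, v in grads_and_vars]

opt = tf.keras.optimizers.Adam(gradient_transformers=[halve_grads])

An accumulation transformer would additionally need to keep state (running gradient sums and a step counter) across calls, which is what makes this non-trivial.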

@dathudeptrai
Author

@omalleyt12 can you take a look?

@dathudeptrai
Author

@tomerk @bhack @omalleyt12 a gentle ping in case you missed my previous comment :D

@bhack
Contributor

bhack commented Dec 24, 2020

We have a near-duplicate ticket at tensorflow/tensorflow#32176

@dathudeptrai
Author

@bhack that was a long time ago; is there any plan for this feature in TF2?

@chn-lee-yumi

It would be great to have such a param:

From a user perspective, having a parameter (e.g. in .fit) like gradient_accumulate_batch_frequency=1 that updates gradients every N batches would be perfect (with a default of 1 to update every batch, as it is at the moment).

@innat

innat commented Mar 4, 2021

IMO, this feature is not needed, as we can implement gradient accumulation in a custom training loop in TF 2. Link.

@dathudeptrai
Author

@innat we need it when we want to use the tf.keras model.fit function :))).

@innat

innat commented Mar 5, 2021

That should also be possible to achieve by overriding the train_step method, customizing the .fit function.
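
A minimal sketch of that idea (essentially the pattern from the widely shared Stack Overflow answer linked later in this thread; class and argument names are illustrative, and gradients are summed, so divide by n_gradients if you want the mean):

import tensorflow as tf

class CustomTrainStep(tf.keras.Model):
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # One accumulator per trainable variable (requires the model to be
        # built at construction time, e.g. via the functional API below).
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v), trainable=False)
            for v in self.trainable_variables
        ]

    def train_step(self, data):
        self.n_acum_step.assign_add(1)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
        # Apply the accumulated gradients only every n_gradients batches.
        tf.cond(
            tf.equal(self.n_acum_step, self.n_gradients),
            self.apply_accu_gradients,
            lambda: None,
        )
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables)
        )
        self.n_acum_step.assign(0)
        for accum in self.gradient_accumulation:
            accum.assign(tf.zeros_like(accum))

# Usage: build through the functional API so variables exist in __init__.
inputs = tf.keras.Input(shape=(28, 28))
x = tf.keras.layers.Flatten()(inputs)
x = tf.keras.layers.Dense(128, activation="relu")(x)
outputs = tf.keras.layers.Dense(10)(x)
model = CustomTrainStep(n_gradients=4, inputs=inputs, outputs=outputs)
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))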

@dathudeptrai
Author

That should also be possible to achieve by overriding the train_step method, customizing the .fit function.

Overriding train_step can be considered a custom training loop. We need a plug-and-play module, and it should apply to all models without a custom train_step.

@innat

innat commented Mar 5, 2021

If you want plug and play, then try this: https://github.com/CyberZHG/keras-gradient-accumulation

@innat

innat commented Mar 5, 2021

Overriding the train_step doesn't necessarily mean a custom training loop (link). This is the leverage we get from the new tf.keras, and it should be adopted. Too much plug-and-play machinery can cause a lot of breakage when mid- or high-level packages update.

@dathudeptrai
Author

dathudeptrai commented Mar 5, 2021

@innat I know the train_step function in tf.keras. I have also implemented gradient accumulation in my framework (https://github.com/TensorSpeech/TensorFlowTTS). I'm not talking about how to implement it; I'm saying that we need this module in this repo (easy to maintain, stable, and adapted to new TF versions). I can implement GA in both train_step and a custom training loop, but a GA wrapper module for all base optimizers would be better :D.

@innat

innat commented Mar 5, 2021

@dathudeptrai understood. It sounds great then. However, I'm facing some issues implementing GA by customizing .fit. In case you're interested, please have a look at tensorflow/tensorflow#47578.

update

Solved: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c

@ajakoby

ajakoby commented Jul 14, 2021

The above gradient accumulation implementation doesn't work with TF 2.5 with a multi-GPU distribution strategy:
strategy = tf.distribute.MultiWorkerMirroredStrategy()

@fsx950223
Member

fsx950223 commented Jul 15, 2021

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from tensorflow_addons.utils import types
from typeguard import typechecked


class GradientAccumulator(tf.keras.optimizers.Optimizer):
    """Optimizer wrapper for gradient accumulation."""
    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        accum_steps: types.TensorLike = 4,
        name: str = "GradientAccumulator",
        **kwargs,
    ):
        r"""Construct a new GradientAccumulator optimizer.

        Args:
            optimizer: str or `tf.keras.optimizers.Optimizer` that will be
                used to compute and apply gradients.
            accum_steps: int > 0. Update gradient in every accumulation steps.
            name: Optional name for the operations created when applying
                gradients. Defaults to "GradientAccumulator".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility, recommended to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)
        self._optimizer = tf.keras.optimizers.get(optimizer)
        self._gradients = []
        self._accum_steps = accum_steps

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
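        # One "ga" (gradient accumulation) slot per trainable variable.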
        for var in var_list:
            self.add_slot(var, "ga")

        self._gradients = [self.get_slot(var, "ga") for var in var_list]

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(
            gradient.read_value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
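        # Keep the inner optimizer's iteration counter in sync with this wrapper's.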
        self._optimizer._iterations = self.iterations
        return super().apply_gradients(grads_and_vars, name, **kwargs)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            accum_gradient.assign_add(
                grad, use_locking=self._use_locking, read_value=False
            )

        def _apply():
            if "apply_state" in self._optimizer._dense_apply_args:
                train_op = self._optimizer._resource_apply_dense(
                    accum_gradient.read_value(), var, apply_state=apply_state
                )
            else:
                train_op = self._optimizer._resource_apply_dense(
                    accum_gradient.read_value(), var
                )
            reset_op = accum_gradient.assign(
                tf.zeros_like(accum_gradient),
                use_locking=self._use_locking,
                read_value=False,
            )
            return tf.group(train_op, reset_op)

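        # Run _apply (update + reset) every `accum_steps` iterations;
        # otherwise this step is a no-op and the gradient keeps accumulating.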
        apply_op = tf.cond(
            (self.iterations+1) % self._accum_steps == 0, _apply, lambda: tf.no_op()
        )
        return apply_op

    def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            self._resource_scatter_add(accum_gradient, indices, grad)

        def _apply():
            if "apply_state" in self._optimizer._sparse_apply_args:
                train_op = self._optimizer._resource_apply_sparse(
                    accum_gradient.sparse_read(indices),
                    var,
                    indices,
                    apply_state=apply_state,
                )
            else:
                train_op = self._optimizer._resource_apply_sparse(
                    accum_gradient.sparse_read(indices), var, indices
                )
            reset_op = accum_gradient.assign(
                tf.zeros_like(accum_gradient),
                use_locking=self._use_locking,
                read_value=False,
            )
            return tf.group(train_op, reset_op)

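        # Same schedule as the dense case: apply and reset every `accum_steps` iterations.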
        apply_op = tf.cond(
            (self.iterations+1) % self._accum_steps == 0, _apply, lambda: tf.no_op()
        )
        return apply_op

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        assign_ops = []
        if not self._gradients:
            return assign_ops

        for gradient in self._gradients:
            if gradient is not None:
                assign_ops.append(
                    gradient.assign(
                        tf.zeros_like(gradient),
                        use_locking=self._use_locking,
                        read_value=False,
                    )
                )

        return tf.group(assign_ops)

    @property
    def lr(self):
        return self._optimizer._get_hyper("learning_rate")

    @lr.setter
    def lr(self, lr):
        self._optimizer._set_hyper("learning_rate", lr)

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)

    def get_config(self):
        config = {"accum_steps": self._accum_steps}
        base_config = super().get_config()
        return {**base_config, **config}



mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=GradientAccumulator(tf.keras.optimizers.Adam(), accum_steps=4),
              loss=loss_fn,
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)

Here is my implementation

@dathudeptrai
Author

@fsx950223 Many thanks. If it's stable, could you make a pull request to support this feature in tensorflow_addons?

@fsx950223
Member

@fsx950223 Many thanks. If it's stable, could you make a pull request to support this feature in tensorflow_addons?

I'm not sure, maybe you could test it.

@ajakoby

ajakoby commented Jul 15, 2021

import tensorflow as tf
from tensorflow_addons.utils import types
from typeguard import typechecked

class GradientAccumulator(tf.keras.optimizers.Optimizer):
  """Gradient accumulation utility.
  When used with a distribution strategy, the accumulator should be called in a
  replica context. Gradients will be accumulated locally on each replica and
  without synchronization. Users should then call ``.gradients``, scale the
  gradients if required, and pass the result to ``apply_gradients``.
  """
  @typechecked
  def __init__(self, optimizer: types.Optimizer, accum_steps: types.TensorLike = 4, name: str = 'gradient_accumulator',
               **kwargs):
    """Initializes the accumulator."""
    super().__init__(name, **kwargs)
    self._optimizer = tf.keras.optimizers.get(optimizer)
    self._gradients = []
    self._accum_steps = accum_steps

  def _create_slots(self, var_list):
    self._optimizer._create_slots(var_list=var_list)
    for var in var_list:
      self.add_slot(var, "ga")

    self._gradients = [self.get_slot(var, "ga") for var in var_list]

  @property
  def gradients(self):
    """The accumulated gradients on the current replica."""
    if not self._gradients:
      raise ValueError(
        "The accumulator should be called first to initialize the gradients"
      )
    return list(
      gradient.value() if gradient is not None else gradient
      for gradient in self._gradients
    )

  def apply_gradients(self, grads_and_vars, name=None, **kwargs):
    self._optimizer._iterations = self.iterations
    return super().apply_gradients(grads_and_vars, name, **kwargs)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    accum_gradient = self.get_slot(var, 'ga')
    if accum_gradient is not None and grad is not None:
      accum_gradient.assign_add(grad, use_locking=self._use_locking, read_value=False)
    def _apply():
      if "apply_state" in self._optimizer._dense_apply_args:
        train_op = self._optimizer._resource_apply_dense(
          accum_gradient, var, apply_state=apply_state
        )
      else:
        train_op = self._optimizer._resource_apply_dense(accum_gradient, var)
      reset_op = self.reset()
      return tf.group(train_op, reset_op)
    apply_op = tf.cond(self.iterations % self._accum_steps == 0, _apply,
                       lambda: tf.no_op())
    return apply_op

  def reset(self):
    """Resets the accumulated gradients on the current replica."""
    assign_ops = []
    if not self._gradients:
      return assign_ops

    for gradient in self._gradients:
      if gradient is not None:
        assign_ops.append(gradient.assign(tf.zeros_like(gradient), use_locking=self._use_locking, read_value=False))

    return tf.group(assign_ops)

  @property
  def lr(self):
    return self._optimizer._get_hyper("learning_rate")

  @lr.setter
  def lr(self, lr):
    self._optimizer._set_hyper("learning_rate", lr)

  @property
  def learning_rate(self):
    return self._optimizer._get_hyper("learning_rate")

  @learning_rate.setter
  def learning_rate(self, learning_rate):
    self._optimizer._set_hyper("learning_rate", learning_rate)

  def get_config(self):
    config = {
      "accum_steps": self.accum_steps
    }
    base_config = super().get_config()
    return {**base_config, **config}


mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])
predictions = model(x_train[:1]).numpy()
# tf.nn.softmax(predictions).numpy()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=GradientAccumulator(tf.keras.optimizers.Adam()),
              loss=loss_fn,
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)

Here is my implementation

When running this code on TF 2.5 Keras (run_eagerly=False) with MultiWorkerMirroredStrategy on 8 GPUs, training is ~2× slower than running without gradient accumulation (i.e., using the class with accum_steps=1). Do you know the reason for this 2× slowdown? The training time should have stayed the same.

@fsx950223
Member

fsx950223 commented Jul 15, 2021

Could you provide TensorFlow profiles?

@fsx950223
Member

I found the default setting is faster than accum_steps=1 on a single device.

fsx950223 mentioned this issue Jul 15, 2021
@fsx950223
Member

You could test the code in my PR, where I have fixed several bugs.

@andreped

(Quoting @fsx950223's GradientAccumulator implementation and MNIST example from above.)

This seems to work just fine for training, but I see some strange behaviour when loading the trained model using tf.keras.models.load_model with compile=True.

After training, I just load the model like so:

model = tf.keras.models.load_model(
        model_path,
        custom_objects={
            "GradientAccumulator": GradientAccumulator(optimizer=Adam(ret.learning_rate), accum_steps=ret.accumsteps),
        },
        compile=True,
    )

This results in:

TypeError: missing a required argument: 'optimizer'

Probably something silly I am missing here? Any ideas, @fsx950223? I just trained a simple CNN classifier, nothing fancy. It works fine if regular Adam is used. Tested with TF 2.8 and Python 3.8.10.

@andreped

I managed to get it working by making the following modifications.

It seems the optimizer argument is missing from get_config. I rewrote the config variable to:

config = {"optimizer": self.optimizer, "accum_steps": self._accum_steps}

and added self.optimizer = optimizer to __init__().

Then, when feeding the Adam optimizer, I got:

TypeError: type of argument "optimizer" must be one of (keras.optimizer_v2.optimizer_v2.OptimizerV2, str); got dict instead

I believe that is because of the type hint optimizer: types.Optimizer in __init__(). If you pass the Adam optimizer (for instance) as input, it will fail, since it will be stored as a dict. However, that is the expected format, as __init__() does self._optimizer = tf.keras.optimizers.get(optimizer), and the get method expects either a string or a dict. Hence, I removed the type hint and left the argument as just optimizer. There might be a preferable type hint, but I am not aware of one.

Lastly, in my example, the custom object should be custom_objects={"GradientAccumulator": GradientAccumulator}.

I could post the final, clean version if it's of interest. If you are planning on merging this into TF-Addons, I could contribute to the PR.
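
For reference, a minimal sketch of the modifications described above, applied to the GradientAccumulator class posted earlier in this thread (the extra self.optimizer attribute exists only so serialization can round-trip):

    def __init__(self, optimizer, accum_steps=4, name="GradientAccumulator", **kwargs):
        super().__init__(name, **kwargs)
        self.optimizer = optimizer  # keep the raw argument so get_config can return it
        self._optimizer = tf.keras.optimizers.get(optimizer)
        self._gradients = []
        self._accum_steps = accum_steps

    def get_config(self):
        config = {"optimizer": self.optimizer, "accum_steps": self._accum_steps}
        base_config = super().get_config()
        return {**base_config, **config}

Loading then works with custom_objects={"GradientAccumulator": GradientAccumulator}, i.e. passing the class itself rather than an instance.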

@dathudeptrai
Author

Hi @andreped, I just saw your implementation here (https://github.com/andreped/GradientAccumulator). Great work! I hope it will work fine with a multi-GPU strategy :D.

@andreped

andreped commented May 31, 2022

@dathudeptrai I figured it was time to make a solution available in TF 2, for myself and others who have been using similar solutions for accumulated gradients in TF 1.

My implementation is derived from @fsx950223's, which in turn is a modified version of the one mentioned above. Hence, all credit to him and the people who contributed to PR #2525.

His PR seems to have been closed without the solution being made available for people to test, debug, and expand upon. I took up the challenge of working on this further, to see if I can get it working the way I want.

I would love for this to be added to TF/TF-Addons in the future (once we have a working solution). Accumulated gradients are something I use in almost all my projects.


EDIT: Currently, I am seeing some strange behaviour where I get worse results with accumulated gradients (mini-batch size 2, accum_steps 8 vs mini-batch size 16) than with regular batch training. I will try to solve that first, before getting multi-GPU training working properly. A BatchNormalization layer that works with accumulated gradients would also be a great contribution; maybe it is possible to make a wrapper around BatchNormalization similar to what was done here.

@innat

innat commented May 31, 2022

Is there any known reason why we aren't asking for this feature directly in TensorFlow or Keras? It's a very useful and well-recognized feature to have.

It was asked HERE before but closed without any valid reason.

@andreped

@innat Good question. I think they concluded that Keras is the best place to add it; however, to add it there, it needs to work as intended for all relevant use cases and be stable. I'm not sure that is the case for the current version. AFAIK, for one, it does not work in multi-GPU strategy scenarios.

However, that's why I believe TF-Addons is a good alternative. Here one can place somewhat experimental solutions, which at least satisfy most users. I believe a lot of people use gradient accumulation in single-GPU scenarios: working with large images/volumes/data in general, or on low-end hardware.

We have created our own wrapper-like implementations since TF 1.13.1, and they have worked wonders for our use cases.

I have already asked why they closed PR #2525 (comment). No response yet. I believe the PR was closed in the end because they were looking into a different way of solving the problem, one that avoids the wrap-an-optimizer solution (even though it appears to work in single-GPU scenarios).

But at least for now, I have made the current implementation easily available to the public here, at least until a proper solution is integrated into TF-Addons/Keras.

@innat

innat commented May 31, 2022

I still believe this technique is a good fit for the core Keras API. It would be nice to create a ticket in Keras (HERE) and discuss it further.

@andreped

andreped commented Jun 1, 2022

@innat I noticed that you asked a question regarding GA on Stack Overflow and were given an answer:
https://stackoverflow.com/a/66524901

I just tested this implementation, and I get identical results with the proposed train_step-override approach and the GradientAccumulator wrapping approach (which should be a lot easier to use):
https://github.com/andreped/GradientAccumulator/blob/832f769bbddae320d4975d796b5a30524e195085/GradientAccumulator/accumulator.py

However, in my benchmark experiments I am unable to get the same test performance for the same number of updates with and without accumulated gradients, so I believe something is fundamentally wrong.

I made a comment about this in a different thread (see here), with a link showing how to reproduce this weird issue. Perhaps you could take a look? Any idea what is wrong?


EDIT: See this thread to track the issue. I will update you if I manage to solve it.

@andreped

andreped commented Jun 1, 2022

@innat I have gotten the solution you were given on Stack Overflow working as expected. It yields the expected behaviour, unlike the gradient-wrapping solution (though I might be close to getting that one working as well). I have published a wheel you could try, if GA is still of interest to you. I will continue to expand on this idea to make it more generic.

See here for more information: #2525 (comment) and https://github.com/andreped/GradientAccumulator

@innat

innat commented Jun 2, 2022

I have published a wheel you could try.

@andreped I will. Thanks. (following)

if GA is still of interest to you. I will continue to expand on this idea to make it more generic.

Let's create a ticket on Keras if you want to contribute, or provide a starter file for an interested contributor.

I noticed that you asked a question regarding GA on Stack Overflow and were given an answer:
stackoverflow.com/a/66524901

Have you also tested it on a multi-GPU setup? Is it working properly?

cc @bhack what are your thoughts on adding this (GA) technique to core Keras?

@bhack
Contributor

bhack commented Jun 2, 2022

As Keras is undergoing quite an important optimizer refactoring (#2706), it is better to open a ticket about GA in Keras.

Probably Graphcore could be interested in contributing something there... who knows?

@andreped

andreped commented Jun 2, 2022

Have you also tested it on a multi-GPU setup? Is it working properly?

The optimizer-wrapper solution does not, but I have not gotten around to testing the train_step-override approach. It might work; I can report what I find.

it is better to open a ticket about GA in Keras.

Yeah, I think so too. I will open a ticket soon to track this feature request.

@andreped

andreped commented Jun 2, 2022

I have opened a ticket now: keras-team/tf-keras#107

Let's move the discussion over there.

@andreped

andreped commented Jun 3, 2022

I can implement GA in both train_step and a custom training loop, but a GA wrapper module for all base optimizers would be better

@dathudeptrai I just saw that you mentioned this, which is exactly what I did here. Did you have any problems using it? Currently, it seems to work quite well for my applications, but I have not checked very advanced situations and edge cases.

It can also be noted that our ticket at Keras has gotten positive feedback, and it seems like they will add an API for gradient accumulation very soon :)
keras-team/tf-keras#107

@seanpmorgan
Member

TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision:
TensorFlow Addons Wind Down

Please consider sending feature requests / contributions to other repositories in the TF community with similar charters to TFA:
Keras
Keras-CV
Keras-NLP
