
add gradient accumulator #2525

Closed
wants to merge 24 commits into from

Conversation

@fsx950223 (Member) commented Jul 15, 2021

Description

Brief Description of the PR:

Fixes keras-team/keras#2260

Type of change

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running Black + Flake8
    • By running pre-commit hooks
  • This PR addresses an already submitted issue for TensorFlow Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • This PR contains modifications to C++ custom-ops

How Has This Been Tested?

If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes:

When I used Python control flow I ran into "OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.", so I use tf.cond instead.
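
For illustration, here is a minimal sketch of the substitution (standalone toy values, not code from this PR). A Python-level `if` on a symbolic tensor raises the error above when AutoGraph cannot rewrite it, while tf.cond builds both branches into the graph explicitly:

import tensorflow as tf

iterations = tf.constant(8, dtype=tf.int64)
accum_steps = tf.constant(4, dtype=tf.int64)

# In a traced graph, `if iterations % accum_steps == 0:` can raise
# OperatorNotAllowedInGraphError; tf.cond expresses the branch explicitly.
result = tf.cond(
    iterations % accum_steps == 0,
    true_fn=lambda: tf.constant("apply"),
    false_fn=lambda: tf.constant("skip"),
)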

@stefan-falk

Hi! I tried my own implementation for GA (see issue keras-team/keras#14829) but I wasn't able to make it work in combination with mixed-precision training.

Did you test your implementation for mixed-precision?

@fsx950223 (Member, Author)

Hi! I tried my own implementation for GA (see issue keras-team/keras#14829) but I wasn't able to make it work in combination with mixed-precision training.

Did you test your implementation for mixed-precision?

Yes, I have tested your example.

@stefan-falk

@fsx950223 Just saw it! Very nice work. Thanks a lot :)

@bhack (Contributor) commented Jul 15, 2021

@fsx950223 Do you know what are these in TPU c++ and if they are related https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc#L189:L200

@stefan-falk commented Jul 15, 2021

I just tried to run the MNIST example again on multiple GPUs. Am I doing it wrong or is there a potential issue?

The only change I made was using the MirroredStrategy and more than one GPU:

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = get_ffn_model(input_size=input_size, output_size=output_size, hidden_size=8)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    optimizer = GradientAccumulator(optimizer)
Full code:
import tensorflow as tf
from tensorflow_addons.utils import types
from typeguard import typechecked
import numpy as np


@tf.keras.utils.register_keras_serializable(package="Addons")
class GradientAccumulator(tf.keras.optimizers.Optimizer):
    """Optimizer wrapper for gradient accumulation."""

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        accum_steps: types.TensorLike = 4,
        name: str = "GradientAccumulator",
        **kwargs,
    ):
        r"""Construct a new GradientAccumulator optimizer.

        Args:
            optimizer: str or `tf.keras.optimizers.Optimizer` that will be
                used to compute and apply gradients.
            accum_steps: int > 0. Update gradient in every accumulation steps.
            name: Optional name for the operations created when applying
                gradients. Defaults to "GradientAccumulator".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility, recommended to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)
        self._optimizer = tf.keras.optimizers.get(optimizer)
        self._gradients = []
        self._accum_steps = accum_steps

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
        for var in var_list:
            self.add_slot(var, "ga")

        self._gradients = [self.get_slot(var, "ga") for var in var_list]

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(
            gradient.read_value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        self._optimizer._iterations = self.iterations
        return super().apply_gradients(grads_and_vars, name, **kwargs)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            accum_gradient.assign_add(
                grad, use_locking=self._use_locking, read_value=False
            )

        def _apply():
            if "apply_state" in self._optimizer._dense_apply_args:
                train_op = self._optimizer._resource_apply_dense(
                    accum_gradient.read_value(), var, apply_state=apply_state
                )
            else:
                train_op = self._optimizer._resource_apply_dense(
                    accum_gradient.read_value(), var
                )
            reset_op = accum_gradient.assign(
                tf.zeros_like(accum_gradient),
                use_locking=self._use_locking,
                read_value=False,
            )
            return tf.group(train_op, reset_op)

        apply_op = tf.cond(
            (self.iterations + 1) % self._accum_steps == 0, _apply, lambda: tf.no_op()
        )
        return apply_op

    def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            self._resource_scatter_add(accum_gradient, indices, grad)

        def _apply():
            if "apply_state" in self._optimizer._sparse_apply_args:
                train_op = self._optimizer._resource_apply_sparse(
                    accum_gradient.sparse_read(indices),
                    var,
                    indices,
                    apply_state=apply_state,
                )
            else:
                train_op = self._optimizer._resource_apply_sparse(
                    accum_gradient.sparse_read(indices), var, indices
                )
            reset_op = accum_gradient.assign(
                tf.zeros_like(accum_gradient),
                use_locking=self._use_locking,
                read_value=False,
            )
            return tf.group(train_op, reset_op)

        apply_op = tf.cond(
            (self.iterations + 1) % self._accum_steps == 0, _apply, lambda: tf.no_op()
        )
        return apply_op

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        assign_ops = []
        if not self._gradients:
            return assign_ops

        for gradient in self._gradients:
            if gradient is not None:
                assign_ops.append(
                    gradient.assign(
                        tf.zeros_like(gradient),
                        use_locking=self._use_locking,
                        read_value=False,
                    )
                )

        return tf.group(assign_ops)

    @property
    def lr(self):
        return self._optimizer._get_hyper("learning_rate")

    @lr.setter
    def lr(self, lr):
        self._optimizer._set_hyper("learning_rate", lr)

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)

    def get_config(self):
        config = {
            "accum_steps": self._accum_steps,
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)


def get_ffn_model(input_size: int, output_size: int, hidden_size: int = 64) -> tf.keras.Model:
    inputs = tf.keras.layers.Input(shape=(input_size,))
    x = inputs
    x = tf.keras.layers.Dense(units=hidden_size, activation='tanh')(x)
    x = tf.keras.layers.Dense(units=hidden_size, activation='tanh')(x)
    x = tf.keras.layers.Dense(units=output_size, activation='softmax')(x)
    return tf.keras.Model(inputs=inputs, outputs=x)


def make_dataset(inputs, targets, batch_size: int, split: str, limit: int = None):
    def sample_generator_():
        while True:
            idx = np.random.randint(0, len(inputs))
            yield inputs[idx].flatten(), tf.one_hot(targets[idx], depth=num_classes)

    assert split in ('train', 'test', 'dev'), \
        f'Split must be one of "train", "test" or "dev". Got: {split}'

    inputs = inputs.astype(np.float32) / 255.0
    inputs = np.expand_dims(inputs, axis=-1)
    num_classes = len(set(targets))

    input_shape = (np.prod(inputs[0].shape),)
    target_shape = (num_classes,)

    dataset = tf.data.Dataset.from_generator(
        lambda: sample_generator_(),
        output_types=(tf.float32, tf.float32),
        output_shapes=(input_shape, target_shape)
    )

    is_training = split == 'train'

    if is_training:
        dataset = dataset.repeat()

    if limit:
        dataset = dataset.take(limit)

    return dataset.padded_batch(batch_size)


def main():
    train_batch_size = 1
    valid_batch_size = 10
    grad_acc_n = 4
    steps_per_epoch = 1000 * grad_acc_n  # Make sure we have the same number of updates

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    train_data = make_dataset(x_train, y_train, batch_size=train_batch_size, split='train')
    valid_data = make_dataset(x_test, y_test, batch_size=valid_batch_size, split='dev', limit=500)

    input_size = train_data.element_spec[0].shape[-1]
    output_size = train_data.element_spec[1].shape[-1]

    epochs = 2

    for precision_policy in ['float32', 'mixed_float16']:
        print('#' * 72)
        print(f'Setting precision-policy to "{precision_policy}"')

        tf.keras.mixed_precision.set_global_policy(precision_policy)

        strategy = tf.distribute.MirroredStrategy()

        with strategy.scope():
            model = get_ffn_model(input_size=input_size, output_size=output_size, hidden_size=8)
            optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
            optimizer = GradientAccumulator(optimizer)

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

        model.fit(
            train_data,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch // train_batch_size,
            validation_data=valid_data,
            validation_steps=10
        )

        loss, accuracy = model.evaluate(valid_data)

        print(f'Evaluation')
        print(f'  - Loss:     {loss:.4f}')
        print(f'  - Accuracy: {accuracy:.4f}')


if __name__ == '__main__':
    main()

Error log

2021-07-15 12:29:21.499653: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2021-07-15 12:29:21.682201: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
   2/4000 [..............................] - ETA: 15:38 - loss: 2.4943 - accuracy: 0.0000e+00  Traceback (most recent call last):
  File "/home/sfalk/tmp/speech-v2/asr/bin/configs/mnist.py", line 275, in <module>
    main()
  File "/home/sfalk/tmp/speech-v2/asr/bin/configs/mnist.py", line 259, in main
    model.fit(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py", line 1183, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 917, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3023, in __call__
    return graph_function._call_flat(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 591, in call
    outputs = execute.execute(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Could not find variable _AnonymousVar66. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/_AnonymousVar66/N10tensorflow3VarE does not exist.
	 [[{{node GradientAccumulator/GradientAccumulator/update_1/update_0/cond/then/_32/GradientAccumulator/GradientAccumulator/update_1/update_0/cond/Cast_3/ReadVariableOp}}]]
	 [[GradientAccumulator/GradientAccumulator/group_deps/NoOp/_131]]
  (1) Failed precondition:  Could not find variable _AnonymousVar66. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/_AnonymousVar66/N10tensorflow3VarE does not exist.
	 [[{{node GradientAccumulator/GradientAccumulator/update_1/update_0/cond/then/_32/GradientAccumulator/GradientAccumulator/update_1/update_0/cond/Cast_3/ReadVariableOp}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_3410]

Function call stack:
train_function -> train_function

2021-07-15 12:29:22.029065: W tensorflow/core/kernels/data/generator_dataset_op.cc:107] Error occurred when finalizing GeneratorDataset iterator: Failed precondition: Python interpreter state is not initialized. The process may be terminated.
	 [[{{node PyFunc}}]]

@fsx950223 (Member, Author)

@fsx950223 Do you know what are these in TPU c++ and if they are related https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc#L189:L200

It seems TPUs have gradient accumulation implemented there, but it doesn't work for GPUs.

@fsx950223 (Member, Author)

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = get_ffn_model(input_size=input_size, output_size=output_size, hidden_size=8)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    optimizer = GradientAccumulator(optimizer)

Move compile and fit into strategy.scope().
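
In other words, something like this (a sketch based on the snippet above):

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = get_ffn_model(input_size=input_size, output_size=output_size, hidden_size=8)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    optimizer = GradientAccumulator(optimizer)

    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    model.fit(
        train_data,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch // train_batch_size
    )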

@bhack (Contributor) commented Jul 15, 2021

@qlzh727 Would you like to have this in Keras directly?

@stefan-falk

@fsx950223 This should not be necessary (see) but I get the same error in both cases.

Does it work for you on multiple GPUs?

@fsx950223 (Member, Author)

@fsx950223 This should not be necessary (see) but I get the same error in both cases.

Does it work for you on multiple GPUs?

Fixed

@stefan-falk commented Jul 15, 2021

Nicely done!

I am able to train the MNIST example but if I use GradientAccumulator on my actual task I am getting this error:

/home/sfalk/tmp/speech-v2/asr/training/gradient_accumulation.py:112 _resource_apply_sparse
    apply_op = tf.cond(
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
    return target(*args, **kwargs)
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/control_flow_ops.py:1438 cond_for_tf_v2
    return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
    return target(*args, **kwargs)
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py:535 new_func
    return func(*args, **kwargs)
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/cond_v2.py:83 cond_v2
    true_graph = func_graph_module.func_graph_from_py_func(
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:999 func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
/home/sfalk/tmp/speech-v2/asr/training/gradient_accumulation.py:96 _apply
    accum_gradient.sparse_read(indices),
/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/variables.py:945 sparse_read
    raise AttributeError

Any idea why this might be?

stefan-falk previously approved these changes Jul 15, 2021
def _apply():
    if "apply_state" in self._optimizer._sparse_apply_args:
        train_op = self._optimizer._resource_apply_sparse(
            accum_gradient.sparse_read(indices),


I don't know if it's on my side, but if I run this in my own project, the call to accum_gradient.sparse_read hits the unimplemented method tf.Variable#sparse_read, which in turn throws an AttributeError.

@@ -108,14 +108,14 @@ def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
     def _apply():
         if "apply_state" in self._optimizer._sparse_apply_args:
             train_op = self._optimizer._resource_apply_sparse(
-                accum_gradient.sparse_read(indices),
+                accum_gradient.read_value(),


It seems one does not have to call read_value() here. Just passing in accum_gradient works for me as well.

def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        self._resource_scatter_add(accum_gradient, indices, grad)


The MNIST example is still working but in my main project I am getting:

  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  shape of indices ([187]) is not compatible with the shape of updates ([1004,256])
	 [[{{node cond_1/then/_631/cond_1/GradientAccumulator/GradientAccumulator/update_298/update_0/ResourceScatterAdd}}]]
	 [[cond_1/then/_631/cond_1/GradientAccumulator/GradientAccumulator/update_133/update_0/ReadVariableOp/_514]]
  (1) Invalid argument:  shape of indices ([187]) is not compatible with the shape of updates ([1004,256])
	 [[{{node cond_1/then/_631/cond_1/GradientAccumulator/GradientAccumulator/update_298/update_0/ResourceScatterAdd}}]]

I think it might be the call to self._resource_scatter_add(accum_gradient, indices, grad)?
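
For reference, a minimal sketch of the shape contract that ResourceScatterAdd enforces (hypothetical shapes chosen to mirror the log above):

import tensorflow as tf

var = tf.Variable(tf.zeros([1004, 256]))   # accumulator slot for an embedding
indices = tf.constant([3, 7, 42])          # rows touched by the sparse gradient
updates = tf.random.normal([3, 256])       # must be [len(indices)] + var.shape[1:]

# OK: the leading dimension of `updates` matches `indices`.
var.scatter_add(tf.IndexedSlices(updates, indices))

# Passing a dense [1004, 256] tensor together with 187 indices violates this
# contract and produces the "shape of indices ... is not compatible" error.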

Comment on lines 75 to 128
def _resource_apply_dense(self, grad, var, apply_state=None):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        accum_gradient.assign_add(
            grad, use_locking=self._use_locking, read_value=False
        )

    def _apply():
        if "apply_state" in self._optimizer._dense_apply_args:
            train_op = self._optimizer._resource_apply_dense(
                accum_gradient, var, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_dense(accum_gradient, var)
        reset_op = accum_gradient.assign(
            tf.zeros_like(accum_gradient),
            use_locking=self._use_locking,
            read_value=False,
        )
        return tf.group(train_op, reset_op)

    apply_op = tf.cond(
        self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op()
    )
    return apply_op

def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        self._resource_scatter_add(accum_gradient, indices, grad)

    def _apply():
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse(
                accum_gradient,
                var,
                indices,
                apply_state=apply_state,
            )
        else:
            train_op = self._optimizer._resource_apply_sparse(
                accum_gradient, var, indices
            )
        reset_op = accum_gradient.assign(
            tf.zeros_like(accum_gradient),
            use_locking=self._use_locking,
            read_value=False,
        )
        return tf.group(train_op, reset_op)

    apply_op = tf.cond(
        self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op()
    )
    return apply_op
@stefan-falk commented Jul 15, 2021


I think we can simplify this a bit:

Suggested change, from:

def _resource_apply_dense(self, grad, var, apply_state=None):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        accum_gradient.assign_add(
            grad, use_locking=self._use_locking, read_value=False
        )

    def _apply():
        if "apply_state" in self._optimizer._dense_apply_args:
            train_op = self._optimizer._resource_apply_dense(
                accum_gradient, var, apply_state=apply_state
            )
        else:
            train_op = self._optimizer._resource_apply_dense(accum_gradient, var)
        reset_op = accum_gradient.assign(
            tf.zeros_like(accum_gradient),
            use_locking=self._use_locking,
            read_value=False,
        )
        return tf.group(train_op, reset_op)

    apply_op = tf.cond(
        self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op()
    )
    return apply_op

def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        self._resource_scatter_add(accum_gradient, indices, grad)

    def _apply():
        if "apply_state" in self._optimizer._sparse_apply_args:
            train_op = self._optimizer._resource_apply_sparse(
                accum_gradient,
                var,
                indices,
                apply_state=apply_state,
            )
        else:
            train_op = self._optimizer._resource_apply_sparse(
                accum_gradient, var, indices
            )
        reset_op = accum_gradient.assign(
            tf.zeros_like(accum_gradient),
            use_locking=self._use_locking,
            read_value=False,
        )
        return tf.group(train_op, reset_op)

    apply_op = tf.cond(
        self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op()
    )
    return apply_op

to:

def _resource_apply_dense(self, grad, var, apply_state=None):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        accum_gradient.assign_add(
            grad, use_locking=self._use_locking, read_value=False
        )

    def _apply():
        train_op = self._optimizer._resource_apply_dense(
            accum_gradient,
            var,
            apply_state=apply_state
        )
        reset_op = accum_gradient.assign(
            tf.zeros_like(accum_gradient),
            use_locking=self._use_locking,
            read_value=False,
        )
        return tf.group(train_op, reset_op)

    return self._apply_op(_apply)

def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
    accum_gradient = self.get_slot(var, "ga")
    if accum_gradient is not None and grad is not None:
        self._resource_scatter_add(accum_gradient, indices, grad)

    def _apply():
        train_op = self._optimizer._resource_apply_sparse(
            accum_gradient,
            var,
            indices,
            apply_state=apply_state,
        )
        reset_op = accum_gradient.assign(
            tf.zeros_like(accum_gradient),
            use_locking=self._use_locking,
            read_value=False,
        )
        return tf.group(train_op, reset_op)

    return self._apply_op(_apply)

def _apply_op(self, apply_fn):
    return tf.cond(
        self.iterations % self._accum_steps == 0,
        true_fn=apply_fn,
        false_fn=lambda: tf.no_op()
    )

@fsx950223 (Member, Author)


accum_grad must be reset after the grad is applied

@stefan-falk commented Jul 15, 2021


Yes, indeed. 👍 I re-added the reset_op.

@stefan-falk commented Jul 15, 2021

@fsx950223 I was able to reproduce and track down the issue. It appears that it is coming from the tf.keras.layers.Embedding layer - or at least something in there does not seem to work with the current version of the GradientAccumulator.

Could you try this on your side?

Example

Code:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow_addons.utils import types
from typeguard import typechecked


class GradientAccumulator(tf.keras.optimizers.Optimizer):
    """Optimizer wrapper for gradient accumulation."""

    @typechecked
    def __init__(
        self,
        optimizer: types.Optimizer,
        accum_steps: types.TensorLike = 4,
        name: str = "GradientAccumulator",
        **kwargs,
    ):
        r"""Construct a new GradientAccumulator optimizer.

        Args:
            optimizer: str or `tf.keras.optimizers.Optimizer` that will be
                used to compute and apply gradients.
            accum_steps: int > 0. Update gradient in every accumulation steps.
            name: Optional name for the operations created when applying
                gradients. Defaults to "GradientAccumulator".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility, recommended to use `learning_rate` instead.
        """
        super().__init__(name, **kwargs)
        self._optimizer = tf.keras.optimizers.get(optimizer)
        self._gradients = []
        self._accum_steps = accum_steps

    def _create_slots(self, var_list):
        self._optimizer._create_slots(var_list=var_list)
        for var in var_list:
            self.add_slot(var, "ga")

        self._gradients = [self.get_slot(var, "ga") for var in var_list]

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(
            gradient.read_value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        self._optimizer._iterations = self.iterations
        return super().apply_gradients(grads_and_vars, name, **kwargs)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            accum_gradient.assign_add(
                grad, use_locking=self._use_locking, read_value=False
            )

        def _apply():
            train_op = self._optimizer._resource_apply_dense(
                accum_gradient,
                var,
                apply_state=apply_state
            )
            reset_op = accum_gradient.assign(
                tf.zeros_like(accum_gradient),
                use_locking=self._use_locking,
                read_value=False,
            )
            return tf.group(train_op, reset_op)

        return self._apply_op(_apply)

    def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            self._resource_scatter_add(accum_gradient, indices, grad)

        def _apply():
            train_op = self._optimizer._resource_apply_sparse(
                accum_gradient,
                var,
                indices,
                apply_state=apply_state,
            )
            reset_op = accum_gradient.assign(
                tf.zeros_like(accum_gradient),
                use_locking=self._use_locking,
                read_value=False,
            )
            return tf.group(train_op, reset_op)

        return self._apply_op(_apply)

    def _apply_op(self, apply_fn):
        return tf.cond(
            self.iterations % self._accum_steps == 0,
            true_fn=apply_fn,
            false_fn=lambda: tf.no_op()
        )

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        assign_ops = []
        if not self._gradients:
            return assign_ops

        for gradient in self._gradients:
            if gradient is not None:
                assign_ops.append(
                    gradient.assign(
                        tf.zeros_like(gradient),
                        use_locking=self._use_locking,
                        read_value=False,
                    )
                )

        return tf.group(assign_ops)

    @property
    def lr(self):
        return self._optimizer._get_hyper("learning_rate")

    @lr.setter
    def lr(self, lr):
        self._optimizer._set_hyper("learning_rate", lr)

    @property
    def learning_rate(self):
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):
        self._optimizer._set_hyper("learning_rate", learning_rate)

    def get_config(self):
        config = {
            "accum_steps": self._accum_steps,
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)


def get_dataset(vocab_size: int, batch_size: int = 10):
    def generator():
        size = np.random.randint(5, 10)
        x = np.random.randint(low=0, high=vocab_size, size=size)
        y = np.asarray([np.random.rand()])
        yield x, y

    return tf.data.Dataset.from_generator(
        generator=generator,
        output_types=(tf.int32, tf.float32),
        output_shapes=((None,), (1,))
    ).padded_batch(batch_size)


def main():
    vocab_size = 10

    inputs = layers.Input(shape=(None,), dtype=tf.int32)
    x = inputs
    x = layers.Embedding(input_dim=vocab_size, output_dim=8)(x)
    x = layers.Dense(1)(x)
    model = tf.keras.models.Model(inputs=inputs, outputs=x)

    optimizer = GradientAccumulator(optimizer='adam')

    model.compile(
        optimizer=optimizer,
        loss='mse'
    )

    train_data = get_dataset(vocab_size).repeat()

    model.fit(
        train_data,
        epochs=100,
        steps_per_epoch=1000
    )


if __name__ == '__main__':
    main()

Error log

    model.fit(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py", line 1183, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 950, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3023, in __call__
    return graph_function._call_flat(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 591, in call
    outputs = execute.execute(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError:  shape of indices ([6]) is not compatible with the shape of updates ([10,8])
	 [[node GradientAccumulator/GradientAccumulator/update/ResourceScatterAdd (defined at /tmp/speech-v2/asr/bin/tmp.py:87) ]] [Op:__inference_train_function_1050]

Errors may have originated from an input operation.
Input Source operations connected to node GradientAccumulator/GradientAccumulator/update/ResourceScatterAdd:
 GradientAccumulator/GradientAccumulator/update/UnsortedSegmentSum (defined at /tmp/speech-v2/asr/bin/tmp.py:60)

Function call stack:
train_function

@stefan-falk

Awesome. With 2760fad I am able to use the GradientAccumulator in my project! 👍

@stefan-falk

I've set CUDA_VISIBLE_DEVICES=1 but it writes:

Process finished with exit code 0
SKIPPED [ 50%]
Skipped: The gpu is not available.
SKIPPED [100%]
Skipped: The gpu is not available.

I use virtual devices

Ah, that's probably an option too. How can I do that?

@fsx950223 (Member, Author)

I've set CUDA_VISIBLE_DEVICES=1 but it writes:

Process finished with exit code 0
SKIPPED [ 50%]
Skipped: The gpu is not available.
SKIPPED [100%]
Skipped: The gpu is not available.

I use virtual devices

Ah, that's probably an option too. How can I do that?

Take a look at tensorflow_addons/utils/test_utils.py

@fsx950223 (Member, Author)

I've set CUDA_VISIBLE_DEVICES=1 but it writes:

Process finished with exit code 0
SKIPPED [ 50%]
Skipped: The gpu is not available.
SKIPPED [100%]
Skipped: The gpu is not available.

I use virtual devices

Ah, that's probably an option too. How can I do that?

For testing the RNN layer with multiple GPUs you can't use virtual devices; that doesn't seem to be supported yet.

@stefan-falk

Ah okay, I see. They rely on one physical device and split it into multiple virtual GPUs. For some reason it does not find my GPUs though - I guess the problem is on my side here.

@mlyg commented Jul 25, 2021

@fsx950223 Thanks for creating this.

Does this gradient accumulator also work for SGD? I tested it using Adam with no issues, but when I switch to SGD it does not appear to work.

@fsx950223 (Member, Author)

Some models in my test cases use SGD as the optimizer.

@mlyg commented Jul 25, 2021

@fsx950223 Thanks for the quick reply.

I tested on the MNIST dataset with a simple example:

# Set a seed value
seed_value = 12321
# 1. Set `PYTHONHASHSEED` environment variable at a fixed value
import os
os.environ['PYTHONHASHSEED']=str(seed_value)
# 2. Set `python` built-in pseudo-random generator at a fixed value
import random
random.seed(seed_value)
# 3. Set `numpy` pseudo-random generator at a fixed value
import numpy as np
np.random.seed(seed_value)
# 4. Set `tensorflow` pseudo-random generator at a fixed value
import tensorflow as tf
tf.random.set_seed(seed_value)

# Model 
input = tf.keras.Input(shape=(28, 28))
base_maps = tf.keras.layers.Flatten(input_shape=(28, 28))(input)
base_maps = tf.keras.layers.Dense(128, activation='relu')(base_maps)
base_maps = tf.keras.layers.Dense(units=10, activation='softmax', name='primary')(base_maps) 
model = tf.keras.Model(inputs=[input], outputs=[base_maps])

# bind all
model.compile(
    loss = tf.keras.losses.CategoricalCrossentropy(),
    metrics = ['accuracy'],
    optimizer = GradientAccumulator(tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, nesterov=True), accum_steps=1))

# data 
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = tf.divide(x_train, 255)
y_train = tf.one_hot(y_train , depth=10) 
    
# customized fit 
model.fit(x_train, y_train, batch_size=512, epochs=3, verbose = 1)

I expected that changing the batch size to 512 with accum_steps=1 should give the same results as batch size 256 with accum_steps=2. This is the case when I use Adam, but I get different results with SGD. Any ideas?
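
One hypothetical way to isolate this (not part of the PR's test suite) is to feed identical synthetic gradients to a wrapped and an unwrapped optimizer and compare the resulting variable:

import tensorflow as tf

def run(opt, grads):
    var = tf.Variable([1.0])
    for g in grads:
        opt.apply_gradients([(tf.constant([g]), var)])
    return var.numpy()

micro_grads = [0.1, 0.3, 0.2, 0.4]               # per-micro-batch gradients
mean_grad = sum(micro_grads) / len(micro_grads)  # large-batch equivalent

plain = run(tf.keras.optimizers.SGD(0.1, momentum=0.9, nesterov=True), [mean_grad])
wrapped = run(
    GradientAccumulator(tf.keras.optimizers.SGD(0.1, momentum=0.9, nesterov=True),
                        accum_steps=4),
    micro_grads,
)
print(plain, wrapped)  # these match only if the wrapper averages, not sums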

@fsx950223 (Member, Author) commented Jul 25, 2021 via email

@mlyg commented Jul 25, 2021

@fsx950223 I think the cause is the momentum term in SGD. If I set momentum to zero I get the same results (I think your test cases have momentum = 0). Is there anything that can be done to account for updates with momentum?

@stefan-falk commented Jul 26, 2021

I just tested the current implementation in my ASR project. Unfortunately, I am still getting errors. 😞

    self.model.fit(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py", line 1183, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 917, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3023, in __call__
    return graph_function._call_flat(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 591, in call
    outputs = execute.execute(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.NotFoundError: 5 root error(s) found.
  (0) Not found:  Resource localhost/_AnonymousVar1328/N10tensorflow3VarE does not exist.
	 [[{{node cond_1/then/_2500/cond_1/update_0/AssignAddVariableOp}}]]
  (1) Not found:  Resource localhost/_AnonymousVar1328/N10tensorflow3VarE does not exist.
	 [[{{node cond_1/then/_2500/cond_1/update_0/AssignAddVariableOp}}]]
	 [[cond_1/then/_2500/cond_1/GradientAccumulator/ReadVariableOp_17/_3060]]
  (2) Not found:  Resource localhost/_AnonymousVar1328/N10tensorflow3VarE does not exist.
	 [[{{node cond_1/then/_2500/cond_1/update_0/AssignAddVariableOp}}]]
	 [[cond_1/then/_2500/cond_1/replica_3/GradientAccumulator/mod_220/_6276]]
  (3) Not found:  Resource localhost/_AnonymousVar1330/N10tensorflow3VarE does not exist.
	 [[{{node cond_1/then/_2500/cond_1/update_2/AssignAddVariableOp}}]]
  (4) Not found:  Resource localhost/_AnonymousVar1328/N10tensorflow3VarE does not exist.
	 [[{{node cond_1/then/_2500/cond_1/update_0/AssignAddVariableOp}}]]
	 [[Func/cond_1/then/_2500/input/_27328/_1669]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_387453]

Function call stack:
train_function -> train_function -> train_function -> train_function -> train_function

@danyang-rainbow commented Jul 30, 2021

When I use the latest version of this module, I get the following warning. Is there a loop that should be replaced with a TF loop?

Large unrolled loop detected. Did you mean to use a TF loop? The following ops were created after iteration 3002: (<tf.Operation 'VarIsInitializedOp_3000/resource' type=Placeholder>, <tf.Operation 'VarIsInitializedOp_3000' type=VarIsInitializedOp>, <tf.Operation 'LogicalAnd_3000' type=LogicalAnd>)
See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/common_errors.md#warning-large-unrolled-loop-detected
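
For context, this warning usually means a Python-level loop inside a traced function creates new ops on every iteration, so the whole loop is unrolled into the graph. A hedged illustration of the pattern (toy code, unrelated to this PR):

import tensorflow as tf

@tf.function
def unrolled_sum():
    total = tf.constant(0.0)
    # A Python loop is unrolled at trace time: each iteration adds fresh ops,
    # which for thousands of iterations triggers the warning above.
    # tf.while_loop (or vectorized ops) keeps the graph size constant instead.
    for _ in range(5000):
        total += 1.0
    return total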

@john32979

@fsx950223 I think you may have an error in _accum_grad(). In my runs the "ga" slot is always 0 until I added control dependencies on handle.scatter_add() and handle.assign_add(). It looks like the add to the gradient accumulators gets lost in the graph optimization. Please see below for changes I made to make your code work.

def _accum_grad(grads_and_vars):
    new_grads_and_vars = []
    for grad, var in grads_and_vars:
        handle = self.get_slot(var, "ga")
    
        if isinstance(grad, tf.IndexedSlices):
            accum_op = handle.scatter_add(grad)
    
            def _get_grad():
                new_grad = handle.read_value()
                if self._reduction == "MEAN":
                    new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                indices = tf.squeeze(
                    tf.where(
                        tf.reduce_sum(
                            new_grad, axis=list(range(len(new_grad.shape))[1:])
                        )
                        != 0
                    ),
                    axis=-1,
                )
    
                values = tf.gather(new_grad, indices)
                dense_shape = tf.constant(new_grad.shape.as_list())
                handle.assign(
                    tf.zeros_like(handle),
                    use_locking=self._use_locking,
                    read_value=False,
                )
                return values, tf.cast(indices, grad.indices.dtype), dense_shape
    
            with tf.control_dependencies([accum_op]):
                values, indices, dense_shape = tf.cond(
                    tf.equal(tf.experimental.numpy.remainder(self.step, self._accum_steps), tf.constant(0, dtype=tf.int64)),
                    _get_grad,
                    lambda: (
                        tf.zeros_like(grad.values),
                        grad.indices,
                        grad.dense_shape,
                    ),
                )
                new_grad = tf.IndexedSlices(values, indices, dense_shape)
                new_grads_and_vars.append((new_grad, var))
        else:
            accum_op = handle.assign_add(
                grad, use_locking=self._use_locking, read_value=False
            )
    
            def _get_grad():
                new_grad = handle.read_value()
                if self._reduction == "MEAN":
                    new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                handle.assign(
                    tf.zeros_like(handle),
                    use_locking=self._use_locking,
                    read_value=False,
                )
                return new_grad
    
            with tf.control_dependencies([accum_op]):
                new_grad = tf.cond(
                    tf.equal(tf.experimental.numpy.remainder(self.step, self._accum_steps), tf.constant(0, dtype=tf.int64)),
                    _get_grad,
                    lambda: tf.zeros_like(grad),
                )
                new_grads_and_vars.append((new_grad, var))
    return new_grads_and_vars

@danyang-rainbow commented Aug 2, 2021

@fsx950223 I think I am able to reproduce this. The problem arises when I am using the tf.keras.layers.LSTM layer like so:

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Import GradientAccumulator here ..


def get_dataset(vocab_size: int, batch_size: int = 10):
    def generator():
        size = np.random.randint(5, 10)
        x = np.random.randint(low=0, high=vocab_size, size=size)
        y = np.asarray([np.random.rand()])
        yield x, y

    return tf.data.Dataset.from_generator(
        generator=generator,
        output_types=(tf.int32, tf.float32),
        output_shapes=((None,), (1,))
    ).padded_batch(batch_size)


def main():
    vocab_size = 10

    strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
        inputs = layers.Input(shape=(None,), dtype=tf.int32)
        x = inputs
        x = layers.Embedding(input_dim=vocab_size, output_dim=8)(x)
        x = layers.LSTM(8)(x)
        x = layers.Dense(1)(x)
        model = tf.keras.models.Model(inputs=inputs, outputs=x)

        optimizer = GradientAccumulator(optimizer='adam')

    model.compile(
        optimizer=optimizer,
        loss='mse'
    )

    train_data = get_dataset(vocab_size).repeat()

    model.fit(
        train_data,
        epochs=100,
        steps_per_epoch=1000
    )


if __name__ == '__main__':
    main()

Error log

    model.fit(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py", line 1183, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py", line 1017, in _call
    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 3023, in __call__
    return graph_function._call_flat(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 591, in call
    outputs = execute.execute(
  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.CancelledError:  [_Derived_]RecvAsync is cancelled.
	 [[{{node cond/else/_1/cond/StatefulPartitionedCall/cond_1/then/_507/cond_1/GradientAccumulator/GradientAccumulator/update/update_0/mod/_400}}]]
	 [[cond/else/_1/cond/StatefulPartitionedCall/div_no_nan/ReadVariableOp_2/_174]] [Op:__inference_fn_with_cond_8652]

Function call stack:
fn_with_cond

I think this is indeed the same issue I had before I implemented it the way I did in keras-team/keras#14829

Try GradientAccumulatorV2:

class GradientAccumulatorV2:
    def __init__(self, var_list, accum_steps=4):
        self._refs = {var.ref(): tf.Variable(tf.zeros_like(var),
                                             trainable=False,
                                             synchronization=tf.VariableSynchronization.ON_READ) for var in var_list}
        self._step = tf.Variable(0, dtype=tf.int64, trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, name='ga_steps')
        self._accum_steps = accum_steps

    def __call__(self, grads_and_vars):
        grads, vars = zip(*grads_and_vars)
        self._step.assign_add(1)
        for grad, var in grads_and_vars:
            self._refs[var.ref()].assign_add(grad)

        def _true_fn():
            return [self._refs[v.ref()] for v in vars]

        def _false_fn():
            return [tf.zeros_like(v) for v in vars]

        new_grads = tf.cond(self._step.value() % self._accum_steps == 0, _true_fn, _false_fn)
        return list(zip(new_grads, vars))

optimizer = tf.keras.optimizers.Adam(gradient_transformers=[GradientAccumulatorV2(model.trainable_weights, 4)])

Hi, shouldn't there be a grad reset op in V2?
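
For what it's worth, one way the reset could look (a sketch building on the V2 code above, untested in this thread; it reads the accumulators before zeroing them via control dependencies):

import tensorflow as tf

class GradientAccumulatorV2WithReset:
    """Hypothetical variant of GradientAccumulatorV2 that zeroes the
    accumulators once their values have been handed to the optimizer."""

    def __init__(self, var_list, accum_steps=4):
        self._refs = {var.ref(): tf.Variable(tf.zeros_like(var), trainable=False)
                      for var in var_list}
        self._step = tf.Variable(0, dtype=tf.int64, trainable=False,
                                 aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        self._accum_steps = accum_steps

    def __call__(self, grads_and_vars):
        grads, variables = zip(*grads_and_vars)
        self._step.assign_add(1)
        for grad, var in grads_and_vars:
            self._refs[var.ref()].assign_add(grad)

        def _true_fn():
            # Read the accumulated values first, then zero the accumulators.
            out = [tf.identity(self._refs[v.ref()]) for v in variables]
            with tf.control_dependencies(out):
                for v in variables:
                    self._refs[v.ref()].assign(tf.zeros_like(v))
            return out

        def _false_fn():
            return [tf.zeros_like(v) for v in variables]

        new_grads = tf.cond(self._step % self._accum_steps == 0, _true_fn, _false_fn)
        return list(zip(new_grads, variables))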

@fsx950223 fsx950223 closed this Mar 2, 2022
@andreped commented May 16, 2022

Was there a reason why this PR was closed? Having a generic method for this in TF would be of great interest. I have a working solution in TF1. At least it works fine for my applications, but I have not found a stable solution for this in TF2.

Anyway, I was testing the current implementation in TF2 and ran into issues using a TimeDistributed layer. Has anyone experienced the same?


EDIT: Nevermind. I got the same issue whether I used TimeDistributed or not. It runs fine if the GPU is disabled, but whenever it is enabled, I get all sorts of fun prints. I'm using TensorFlow 2.8 and Python 3.8.10. It crashes at the start of training, using TF Keras and model.fit(). A snippet of the error log I get:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot assign a device for operation model/Conv1/Conv2D/ReadVariableOp: Could not satisfy explicit device specification '' because the node {{colocation_node model/Conv1/Conv2D/ReadVariableOp}} was colocated with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/device:GPU:0'. All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0].
Colocation Debug Info:
Colocation group had the following types and supported devices:
Root Member(assigned_device_name_index_=2 requested_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' assigned_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' resource_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' supported_device_types_=[CPU] possible_devices_=[]
NoOp: GPU CPU
AssignVariableOp: GPU CPU
AssignAddVariableOp: GPU CPU
Merge: GPU CPU
RealDiv: GPU CPU
_Arg: GPU CPU
FloorMod: CPU
ReadVariableOp: GPU CPU
Const: GPU CPU
Sub: GPU CPU
Equal: GPU CPU
Identity: GPU CPU
AddV2: GPU CPU
Pow: GPU CPU
Cast: GPU CPU
Switch: GPU CPU
Sqrt: GPU CPU
ResourceApplyAdam: GPU CPU
Mul: GPU CPU

Colocation members, user-requested devices, and framework assigned devices, if any:
  model_conv1_conv2d_readvariableop_resource (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  gradientaccumulator_gradientaccumulator_update_assignaddvariableop_resource (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  gradientaccumulator_gradientaccumulator_update_cond_input_1 (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  gradientaccumulator_gradientaccumulator_update_cond_input_3 (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  gradientaccumulator_gradientaccumulator_update_cond_input_4 (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  gradientaccumulator_gradientaccumulator_update_cond_input_6 (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  gradientaccumulator_gradientaccumulator_update_cond_input_7 (_Arg)  framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  model/Conv1/Conv2D/ReadVariableOp (ReadVariableOp)
  GradientAccumulator/GradientAccumulator/update/AssignAddVariableOp (AssignAddVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
  GradientAccumulator/GradientAccumulator/update/add/y (Const) /job:localhost/replica:0/task:0/device:GPU:0
  GradientAccumulator/GradientAccumulator/update/add (AddV2) /job:localhost/replica:0/task:0/device:GPU:0
  GradientAccumulator/GradientAccumulator/update/mod/y (Const) /job:localhost/replica:0/task:0/device:GPU:0
  GradientAccumulator/GradientAccumulator/update/mod (FloorMod) /job:localhost/replica:0/task:0/device:GPU:0
  GradientAccumulator/GradientAccumulator/update/Equal/y (Const) /job:localhost/replica:0/task:0/device:GPU:0
[...]

@andreped

Managed to get it working by compiling the model with run_eagerly=True. It is suboptimal to have eager mode enabled during training, but at least it runs.
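
For reference, that is just the compile flag:

model.compile(
    optimizer=optimizer,
    loss='mse',
    run_eagerly=True,  # disables graph tracing; slower, but sidesteps the colocation error
)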

@andreped

@john32979 have you used the most recent implementation on your own data and found that it is working as intended? I am not considering multi-GPU strategy yet. Just want to know if it is working as intended in the single-GPU case.

@john32979 commented May 31, 2022 via email

@andreped

@john32979 OK, no worries. I will run some benchmarks myself. At least it seems to work well in the single-GPU scenario.

@andreped commented Jun 1, 2022

@fsx950223, @john32979 Just ran a benchmark to see if accumulated gradients actually work as intended, and I got some surprising results. I get different results even with the global seed set. From looking at the training log, it seems like the weights might have been updated for every single batch, independent of the optimizer wrapper and accum_steps. Perhaps we can only enforce an update after accum_steps, but updates will still happen as usual? Not sure...

To reproduce my experiments, you can use this script:
https://github.com/andreped/GradientAccumulator/blob/main/benchmark.py


EDIT: Made some progress, but not quite ready/bug-free yet. See this thread to track the issue. Will update you if I manage to get it working properly.

@andreped commented Jun 1, 2022

Update: I managed to get it working. I swapped the optimizer wrapping solution with a different Model train_step overloading solution, which seems to yield expected results!

I have released a version you could test, if you'd like:
https://github.com/andreped/GradientAccumulator

I will continue to try to debug the optimizer wrapper, as it was a cool idea, but if it does not lead anywhere, I will rather put my energy into the other solution, which is likely more in line with TF2 idioms. There are lots of places that can be improved and generalized, but at least now it is actually simulating batch training as I wanted.

I have a longer discussion with results and whatnot here, if anyone is interested in discussing this further: andreped/GradientAccumulator#2

@Mddct commented Oct 11, 2022

any update on this recently?

@andreped commented Oct 11, 2022

any update on this recently?

@Mddct If you are talking about the PR, it has been closed and there will likely not be made a new effort to add such a method to tf-addons anytime soon. However, I believe someone in Keras is currently working on adding it there, but not sure. You can track the issue here: keras-team/tf-keras#107

As it seems to be going quite slow, I went further and made my tool stable. It has been benchmarked to yield the same results as regular batch training, with the expected linear cost in training runtime. As it is working, I want others to test it and get started using gradient accumulation, as it is an extremely useful strategy and has helped me in various projects to get better results with limited resources. I have therefore released it to PyPI: https://pypi.org/project/gradient-accumulator/

To use, simply:

pip install gradient-accumulator

and then:

from gradient_accumulator.GAModelWrapper import GAModelWrapper
from tensorflow.keras.models import Model

model = Model(...)
model = GAModelWrapper(accum_steps=4, inputs=model.input, outputs=model.output)

That's it! You have enabled gradient accumulation, and have artificially increased your batch size by 4x without increasing GPU memory usage!

I have also added mixed precision support, both for GPUs and TPUs, and added an alternative to batch normalisation called adaptive gradient clipping. For more information, see here: https://github.com/andreped/GradientAccumulator#usage

Successfully merging this pull request may close these issues:

categorical accuracy doesn't account for mask (w/ proposed solution)

8 participants