Skip to content
This repository was archived by the owner on Mar 11, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tensorflow_addons.optimizers.lamb import LAMB
from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
from tensorflow_addons.optimizers.gradient_accumulator import GradientAccumulator
from tensorflow_addons.optimizers.moving_average import MovingAverage
from tensorflow_addons.optimizers.novograd import NovoGrad
from tensorflow_addons.optimizers.proximal_adagrad import ProximalAdagrad
Expand Down
231 changes: 231 additions & 0 deletions tensorflow_addons/optimizers/gradient_accumulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from tensorflow_addons.utils import types
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class GradientAccumulator(tf.keras.optimizers.Optimizer):
"""Optimizer wrapper for gradient accumulation."""

@typechecked
def __init__(
self,
inner_optimizer: types.Optimizer,
accum_steps: types.TensorLike = 4,
reduction: str = "SUM",
name: str = "GradientAccumulator",
**kwargs,
):
r"""Construct a new GradientAccumulator optimizer.

Args:
inner_optimizer: str or `tf.keras.optimizers.Optimizer` that will be
used to compute and apply gradients.
accum_steps: int > 0. Update gradient in every accumulation steps.
reduction: str, Reduction method ['SUM', 'MEAN']
name: Optional name for the operations created when applying
gradients. Defaults to "GradientAccumulator".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
norm; `clipvalue` is clip gradients by value, `decay` is
included for backward compatibility to allow time inverse
decay of learning rate. `lr` is included for backward
compatibility, recommended to use `learning_rate` instead.
"""
super().__init__(name, **kwargs)
self._optimizer = tf.keras.optimizers.get(inner_optimizer)
self._step = None
self._accum_steps = accum_steps
self._reduction = reduction

def _accum_grad(grads_and_vars):
new_grads_and_vars = []
for grad, var in grads_and_vars:
handle = self.get_slot(var, "ga")

if isinstance(grad, tf.IndexedSlices):
handle.scatter_add(grad)

def _get_grad():
new_grad = handle.read_value()
if self._reduction == "MEAN":
new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
indices = tf.squeeze(
tf.where(
tf.reduce_sum(
new_grad, axis=list(range(len(new_grad.shape))[1:])
)
!= 0
),
axis=-1,
)

values = tf.gather(new_grad, indices)
dense_shape = tf.constant(new_grad.shape.as_list())
handle.assign(
tf.zeros_like(handle),
use_locking=self._use_locking,
read_value=False,
)
return values, tf.cast(indices, grad.indices.dtype), dense_shape

values, indices, dense_shape = tf.cond(
self.step % self._accum_steps == 0,
_get_grad,
lambda: (
tf.zeros_like(grad.values),
grad.indices,
grad.dense_shape,
),
)
new_grad = tf.IndexedSlices(values, indices, dense_shape)
new_grads_and_vars.append((new_grad, var))
else:
handle.assign_add(
grad, use_locking=self._use_locking, read_value=False
)

def _get_grad():
new_grad = handle.read_value()
if self._reduction == "MEAN":
new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
handle.assign(
tf.zeros_like(handle),
use_locking=self._use_locking,
read_value=False,
)
return new_grad

new_grad = tf.cond(
self.step % self._accum_steps == 0,
_get_grad,
lambda: tf.zeros_like(grad),
)
new_grads_and_vars.append((new_grad, var))
return new_grads_and_vars

self.gradient_transformers.append(_accum_grad)
self._iterations = self._optimizer.iterations

def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list)
for var in var_list:
self.add_slot(var, "ga")

def _resource_apply_dense(self, grad, handle, apply_state):
if "apply_state" in self._optimizer._dense_apply_args:
return self.inner_optimizer._resource_apply_dense(grad, handle, apply_state)
else:
return self.inner_optimizer._resource_apply_dense(grad, handle)

def _resource_apply_sparse(self, grad, handle, indices, apply_state):
if "apply_state" in self._optimizer._sparse_apply_args:
return self.inner_optimizer._resource_apply_sparse(
grad, handle, indices, apply_state=apply_state
)
else:
return self.inner_optimizer._resource_apply_sparse(grad, handle, indices)

def _resource_apply_sparse_duplicate_indices(
self, grad, handle, indices, apply_state=None
):
if "apply_state" in self._optimizer._sparse_apply_args:
return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
grad, handle, indices, apply_state=apply_state
)
else:
return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
grad, handle, indices
)

@property
def step(self):
"""Variable. The number of training steps this Optimizer has run."""
if self._step is None:
with self._distribution_strategy_scope():
self._step = self.add_weight(
"iter",
shape=[],
dtype=tf.int64,
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
)
self._weights.append(self._step)
return self._step

@step.setter
def step(self, variable):
if self._step is not None:
raise RuntimeError(
"Cannot set `step` to a new Variable after "
"the Optimizer weights have been created"
)
self._step = variable
self._weights.append(self._step)

def apply_gradients(self, grads_and_vars, name=None, **kwargs):
with tf.control_dependencies([self.step.assign_add(1, read_value=False)]):
train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
with tf.control_dependencies([train_op]):
return self.iterations.assign_sub(
tf.cast(self.step % self._accum_steps != 0, tf.int64),
read_value=False,
)

@property
def inner_optimizer(self):
"""The optimizer that this LossScaleOptimizer is wrapping."""
return self._optimizer

@property
def iterations(self):
return self._optimizer.iterations

@iterations.setter
def iterations(self, variable):
self._optimizer.iterations = variable

@property
def lr(self):
return self._optimizer._get_hyper("learning_rate")

@lr.setter
def lr(self, lr):
self._optimizer._set_hyper("learning_rate", lr) #

@property
def learning_rate(self):
return self._optimizer._get_hyper("learning_rate")

@learning_rate.setter
def learning_rate(self, learning_rate):
self._optimizer._set_hyper("learning_rate", learning_rate)

def get_config(self):
config = {
"accum_steps": self._accum_steps,
"optimizer": tf.keras.optimizers.serialize(self._optimizer),
}
base_config = super().get_config()
return {**base_config, **config}

@classmethod
def from_config(cls, config, custom_objects=None):
optimizer = tf.keras.optimizers.deserialize(
config.pop("optimizer"), custom_objects=custom_objects
)
return cls(optimizer, **config)
159 changes: 159 additions & 0 deletions tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GradientAccumulator optimizers."""

import numpy as np
import pytest
import tensorflow as tf

from tensorflow_addons.optimizers import GradientAccumulator


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
def test_run():
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])
accum_steps = 4

grads0 = tf.constant([0.1, 0.1])
grads1 = tf.constant([0.01, 0.01])

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps)

strategy = tf.distribute.get_strategy()
for _ in range(accum_steps + 1):
strategy.run(opt.apply_gradients, [grads_and_vars])

np.testing.assert_allclose(var0.read_value(), [0.6, 1.6])
np.testing.assert_allclose(var1.read_value(), [2.96, 3.96])
np.testing.assert_allclose(opt.iterations.read_value(), 1)
np.testing.assert_allclose(opt.step.read_value(), accum_steps + 1)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
def test_sparse():
var0 = tf.Variable([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]])
var1 = tf.Variable([[3.0, 4.0, 0.0]])

grads0 = tf.IndexedSlices(
tf.constant([[0.1, 0.1, 0.0]]),
tf.constant([1]),
tf.constant([1, 3]),
)
grads1 = tf.IndexedSlices(
tf.constant([[0.01, 0.01, 0.0]]),
tf.constant([0]),
tf.constant([1, 3]),
)

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
strategy = tf.distribute.get_strategy()
for _ in range(8):
strategy.run(opt.apply_gradients, [grads_and_vars])
np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0], [0.2, 1.2, 0.0]])
np.testing.assert_allclose(var1.read_value(), [[2.92, 3.92, 0.0]])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_dense():
grad = tf.Variable([[0.1]])
model = tf.keras.Sequential(
[
tf.keras.layers.Dense(
1,
kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
use_bias=False,
)
]
)
model.build(input_shape=[1, 1])

opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=2.0), accum_steps=2)
_ = opt.apply_gradients(list(zip([grad], model.variables)))
np.testing.assert_allclose(model.variables[0].read_value(), [[1.0]])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_optimizer_string():
_ = GradientAccumulator("adam")


def test_config():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
accum_steps = 4
opt = GradientAccumulator(sgd_opt, accum_steps=accum_steps)
config = opt.get_config()

assert config["accum_steps"] == accum_steps

new_opt = GradientAccumulator.from_config(config)
old_sgd_config = opt._optimizer.get_config()
new_sgd_config = new_opt._optimizer.get_config()

for k1, k2 in zip(old_sgd_config, new_sgd_config):
assert old_sgd_config[k1] == new_sgd_config[k2]


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.with_device([tf.distribute.MirroredStrategy])
def test_fit_simple_linear_model():
seed = 0x2019
np.random.seed(seed)
tf.random.set_seed(seed)
num_examples = 5000
x = np.random.standard_normal((num_examples, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))

opt = GradientAccumulator("sgd")
model.compile(opt, loss="mse")

model.fit(x, y, epochs=5)

x = np.random.standard_normal((100, 3))
y = np.dot(x, w)

predicted = model.predict(x)

max_abs_diff = np.max(np.abs(predicted - y))
assert max_abs_diff < 5e-3


def test_serialization():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
optimizer = GradientAccumulator(sgd_opt)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.usefixtures("run_with_mixed_precision_policy")
def test_model_mixed_precision():
x = np.random.standard_normal((10000, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
model.compile(GradientAccumulator("sgd"), loss="mse")
model.fit(x, y, epochs=3)
Loading