PR #17651: Add CategoricalFocalCrossentropy to Losses API
Imported from GitHub PR #17651

Implements the `CategoricalFocalCrossentropy()` loss based on the paper [Focal Loss for Dense Object Detection](https://arxiv.org/pdf/1708.02002.pdf) (Lin et al., 2018).

Feature request was made in #17583.
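
Once merged, the loss should be usable like any other Keras loss. A minimal usage sketch, assuming the class is exposed as `tf.keras.losses.CategoricalFocalCrossentropy` with the defaults (`alpha=0.25`, `gamma=2.0`) used by the backend function added below:

```python
import tensorflow as tf

# Hypothetical standalone example; class name and defaults follow this PR.
y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.05, 0.90, 0.05], [0.10, 0.80, 0.10]])

cfce = tf.keras.losses.CategoricalFocalCrossentropy(alpha=0.25, gamma=2.0)
loss = cfce(y_true, y_pred)  # scalar: mean focal loss over the batch
print(float(loss))
```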
Copybara import of the project:

--
5696b5a by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Add pure logic of CFCE

--
40e547f by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Add support for ragged tensors

--
d3dd32f by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Make sure output sum equals 1

--
16adf85 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Raise shape mismatch / update tests

--
bc38e33 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Add categorical_focal_loss tests

--
363baaf by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Add documentation / minor fix.

--
c267fa0 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Reformatting after focal loss implementation

--
3c33117 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Fix linting.

--
3538622 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Fix docstring style.

--
6b4fa6b by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Update the docstrings.

--
49c03a2 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Fix linting issues

--
f560336 by Kaan Bıçakcı <kaan.dvlpr@gmail.com>:

Address comments from code-review.

Merging this change closes #17651

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17651 from Frightera:frightera_categorical_focal_loss_v2 f560336
PiperOrigin-RevId: 518880861
tensorflower-gardener committed Mar 27, 2023
1 parent 1a8c6c1 commit 9f54275
Showing 4 changed files with 621 additions and 0 deletions.
keras/backend.py: 97 additions & 0 deletions
@@ -5574,6 +5574,103 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    return -tf.reduce_sum(target * tf.math.log(output), axis)


@keras_export("keras.backend.categorical_focal_crossentropy")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
def categorical_focal_crossentropy(
    target,
    output,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    axis=-1,
):
"""Computes the alpha balanced focal crossentropy loss.
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
helps to apply a focal factor to down-weight easy examples and focus more on
hard examples. The general formula for the focal loss (FL)
is as follows:
`FL(p_t) = (1 − p_t)^gamma * log(p_t)`
where `p_t` is defined as follows:
`p_t = output if y_true == 1, else 1 - output`
`(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
`gamma` reduces the importance given to simple examples in a smooth manner.
The authors use alpha-balanced variant of focal loss (FL) in the paper:
`FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
where `alpha` is the weight factor for the classes. If `alpha` = 1, the
loss won't be able to handle class imbalance properly as all
classes will have the same weight. This can be a constant or a list of
constants. If alpha is a list, it must have the same length as the number
of classes.
The formula above can be generalized to:
`FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(target, output)`
where minus comes from `CrossEntropy(target, output)` (CE).
Extending this to multi-class case is straightforward:
`FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(target, output)`
Args:
target: Ground truth values from the dataset.
output: Predictions of the model.
alpha: A weight balancing factor for all classes, default is `0.25` as
mentioned in the reference. It can be a list of floats or a scalar.
In the multi-class case, alpha may be set by inverse class
frequency by using `compute_class_weight` from `sklearn.utils`.
gamma: A focusing parameter, default is `2.0` as mentioned in the
reference. It helps to gradually reduce the importance given to
simple examples in a smooth manner.
from_logits: Whether `output` is expected to be a logits tensor. By
default, we consider that `output` encodes a probability
distribution.
axis: Int specifying the channels axis. `axis=-1` corresponds to data
format `channels_last`, and `axis=1` corresponds to data format
`channels_first`.
Returns:
A tensor.
"""
    target = tf.convert_to_tensor(target)
    output = tf.convert_to_tensor(output)
    target.shape.assert_is_compatible_with(output.shape)

    output, from_logits = _get_logits(
        output, from_logits, "Softmax", "categorical_focal_crossentropy"
    )

    if from_logits:
        output = tf.nn.softmax(output, axis=axis)

    # Adjust the predictions so that the probability of each class for
    # every sample adds up to 1. This is needed to ensure that the cross
    # entropy is computed correctly.
    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)

    # Clip away from 0 and 1 to avoid log(0).
    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)

    # Calculate the per-class cross entropy.
    cce = -target * tf.math.log(output)

    # Calculate the modulating factor (1 - p_t)^gamma and scale it by alpha.
    modulating_factor = tf.pow(1.0 - output, gamma)
    weighting_factor = tf.multiply(modulating_factor, alpha)

    # Apply the weighting factor and sum over the class axis.
    focal_cce = tf.multiply(weighting_factor, cce)
    focal_cce = tf.reduce_sum(focal_cce, axis=axis)
    return focal_cce


@keras_export("keras.backend.sparse_categorical_crossentropy")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
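To make the docstring's formula concrete, here is a small worked example (a sketch in plain TensorFlow ops, mirroring the steps of the function above but omitting its epsilon clipping and renormalization) that reproduces the `1.747` value asserted in the tests below:

```python
import tensorflow as tf

target = tf.constant([[0.0, 1.0, 0.0]])
logits = tf.constant([[8.0, 1.0, 1.0]])
output = tf.nn.softmax(logits, axis=-1)  # ~[0.9982, 0.0009, 0.0009]

alpha, gamma = 0.25, 2.0
cce = -target * tf.math.log(output)              # per-class cross entropy
modulating_factor = tf.pow(1.0 - output, gamma)  # (1 - p_t)^gamma
focal_cce = tf.reduce_sum(alpha * modulating_factor * cce, axis=-1)
print(float(focal_cce[0]))  # ~1.747
```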
keras/backend_test.py: 47 additions & 0 deletions
@@ -2244,6 +2244,19 @@ def test_binary_focal_crossentropy_with_sigmoid(self):
        )
        self.assertArrayNear(result[0], [7.995, 0.022, 0.701], 1e-3)

    @test_combinations.generate(
        test_combinations.combine(mode=["graph", "eager"])
    )
    def test_categorical_focal_crossentropy_with_softmax(self):
        t = backend.constant([[0, 1, 0]])
        logits = backend.constant([[8.0, 1.0, 1.0]])
        p = backend.softmax(logits)
        p = tf.identity(tf.identity(p))
        result = self.evaluate(
            backend.categorical_focal_crossentropy(t, p, gamma=2.0)
        )
        self.assertArrayNear(result, [1.747], 1e-3)

    @test_combinations.generate(
        test_combinations.combine(mode=["graph", "eager"])
    )
@@ -2260,6 +2273,21 @@ def test_binary_focal_crossentropy_from_logits(self):
        )
        self.assertArrayNear(result[0], [7.995, 0.022, 0.701], 1e-3)

    @test_combinations.generate(
        test_combinations.combine(mode=["graph", "eager"])
    )
    def test_categorical_focal_crossentropy_from_logits(self):
        t = backend.constant([[0, 1, 0]])
        logits = backend.constant([[8.0, 1.0, 1.0]])
        result = self.evaluate(
            backend.categorical_focal_crossentropy(
                target=t,
                output=logits,
                from_logits=True,
            )
        )
        self.assertArrayNear(result, [1.7472], 1e-3)

    @test_combinations.generate(
        test_combinations.combine(mode=["graph", "eager"])
    )
@@ -2279,6 +2307,25 @@ def test_binary_focal_crossentropy_no_focal_effect_with_zero_gamma(self):
        non_focal_result = self.evaluate(backend.binary_crossentropy(t, p))
        self.assertArrayNear(focal_result[0], non_focal_result[0], 1e-3)

    @test_combinations.generate(
        test_combinations.combine(mode=["graph", "eager"])
    )
    def test_categorical_focal_crossentropy_no_focal_effect(self):
        t = backend.constant([[0, 1, 0]])
        logits = backend.constant([[8.0, 1.0, 1.0]])
        p = backend.softmax(logits)
        p = tf.identity(tf.identity(p))
        focal_result = self.evaluate(
            backend.categorical_focal_crossentropy(
                target=t,
                output=p,
                gamma=0.0,
                alpha=1.0,
            )
        )
        non_focal_result = self.evaluate(
            backend.categorical_crossentropy(t, p)
        )
        self.assertArrayNear(focal_result, non_focal_result, 1e-3)

    @test_combinations.generate(
        test_combinations.combine(mode=["graph", "eager"])
    )
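A quick sanity check on the last test above: with `gamma=0.0` the modulating factor `(1 - p_t)^0` is identically 1, and with `alpha=1.0` the weighting factor drops out, so the focal loss collapses to plain categorical crossentropy. A standalone sketch of the same equivalence (not part of the diff):

```python
import tensorflow as tf

t = tf.constant([[0.0, 1.0, 0.0]])
p = tf.nn.softmax(tf.constant([[8.0, 1.0, 1.0]]), axis=-1)

focal = tf.reduce_sum(
    1.0 * tf.pow(1.0 - p, 0.0) * (-t * tf.math.log(p)), axis=-1
)
plain = tf.reduce_sum(-t * tf.math.log(p), axis=-1)
print(float(focal[0]), float(plain[0]))  # both ~7.002
```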
