diff --git a/keras/backend.py b/keras/backend.py
index 071e2e9cbc5..63e7bcd20bf 100644
--- a/keras/backend.py
+++ b/keras/backend.py
@@ -5574,6 +5574,103 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
     return -tf.reduce_sum(target * tf.math.log(output), axis)


+@keras_export("keras.backend.categorical_focal_crossentropy")
+@tf.__internal__.dispatch.add_dispatch_support
+@doc_controls.do_not_generate_docs
+def categorical_focal_crossentropy(
+    target,
+    output,
+    alpha=0.25,
+    gamma=2.0,
+    from_logits=False,
+    axis=-1,
+):
+    """Computes the alpha balanced focal crossentropy loss.
+
+    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
+    helps to apply a focal factor to down-weight easy examples and focus more on
+    hard examples. The general formula for the focal loss (FL)
+    is as follows:
+
+    `FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
+
+    where `p_t` is defined as follows:
+    `p_t = output if y_true == 1, else 1 - output`
+
+    `(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
+    parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
+    `gamma` reduces the importance given to simple examples in a smooth manner.
+
+    The authors use the alpha-balanced variant of focal loss (FL) in the paper:
+    `FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
+
+    where `alpha` is the weight factor for the classes. If `alpha` = 1, the
+    loss won't be able to handle class imbalance properly as all
+    classes will have the same weight. This can be a constant or a list of
+    constants. If alpha is a list, it must have the same length as the number
+    of classes.
+
+    The formula above can be generalized to:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(target, output)`
+
+    where the minus sign comes from `CrossEntropy(target, output)` (CE).
+
+    Extending this to the multi-class case is straightforward:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(target, output)`
+
+    Args:
+        target: Ground truth values from the dataset.
+        output: Predictions of the model.
+        alpha: A weight balancing factor for all classes, default is `0.25` as
+            mentioned in the reference. It can be a list of floats or a scalar.
+            In the multi-class case, alpha may be set by inverse class
+            frequency by using `compute_class_weight` from `sklearn.utils`.
+        gamma: A focusing parameter, default is `2.0` as mentioned in the
+            reference. It helps to gradually reduce the importance given to
+            simple examples in a smooth manner.
+        from_logits: Whether `output` is expected to be a logits tensor. By
+            default, we consider that `output` encodes a probability
+            distribution.
+        axis: Int specifying the channels axis. `axis=-1` corresponds to data
+            format `channels_last`, and `axis=1` corresponds to data format
+            `channels_first`.
+
+    Returns:
+        A tensor.
+    """
+    target = tf.convert_to_tensor(target)
+    output = tf.convert_to_tensor(output)
+    target.shape.assert_is_compatible_with(output.shape)
+
+    output, from_logits = _get_logits(
+        output, from_logits, "Softmax", "categorical_focal_crossentropy"
+    )
+
+    if from_logits:
+        output = tf.nn.softmax(output, axis=axis)
+
+    # Adjust the predictions so that the probability of
+    # each class for every sample adds up to 1.
+    # This is needed to ensure that the cross entropy is
+    # computed correctly.
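+    # (If `output` already comes from a softmax, it sums to 1 along `axis`
+    # and this division is essentially a no-op; it only matters when
+    # unnormalized probabilities are passed in.)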
+    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)
+
+    epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
+    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
+
+    # Calculate cross entropy
+    cce = -target * tf.math.log(output)
+
+    # Calculate factors
+    modulating_factor = tf.pow(1.0 - output, gamma)
+    weighting_factor = tf.multiply(modulating_factor, alpha)
+
+    # Apply weighting factor
+    focal_cce = tf.multiply(weighting_factor, cce)
+    focal_cce = tf.reduce_sum(focal_cce, axis=axis)
+    return focal_cce
+
+
 @keras_export("keras.backend.sparse_categorical_crossentropy")
 @tf.__internal__.dispatch.add_dispatch_support
 @doc_controls.do_not_generate_docs
diff --git a/keras/backend_test.py b/keras/backend_test.py
index 89497676244..28384bc21de 100644
--- a/keras/backend_test.py
+++ b/keras/backend_test.py
@@ -2244,6 +2244,19 @@ def test_binary_focal_crossentropy_with_sigmoid(self):
         )
         self.assertArrayNear(result[0], [7.995, 0.022, 0.701], 1e-3)

+    @test_combinations.generate(
+        test_combinations.combine(mode=["graph", "eager"])
+    )
+    def test_categorical_focal_crossentropy_with_softmax(self):
+        t = backend.constant([[0, 1, 0]])
+        logits = backend.constant([[8.0, 1.0, 1.0]])
+        p = backend.softmax(logits)
+        p = tf.identity(tf.identity(p))
+        result = self.evaluate(
+            backend.categorical_focal_crossentropy(t, p, gamma=2.0)
+        )
+        self.assertArrayNear(result, [1.747], 1e-3)
+
     @test_combinations.generate(
         test_combinations.combine(mode=["graph", "eager"])
     )
@@ -2260,6 +2273,21 @@ def test_binary_focal_crossentropy_from_logits(self):
         )
         self.assertArrayNear(result[0], [7.995, 0.022, 0.701], 1e-3)

+    @test_combinations.generate(
+        test_combinations.combine(mode=["graph", "eager"])
+    )
+    def test_categorical_focal_crossentropy_from_logits(self):
+        t = backend.constant([[0, 1, 0]])
+        logits = backend.constant([[8.0, 1.0, 1.0]])
+        result = self.evaluate(
+            backend.categorical_focal_crossentropy(
+                target=t,
+                output=logits,
+                from_logits=True,
+            )
+        )
+        self.assertArrayNear(result, [1.7472], 1e-3)
+
     @test_combinations.generate(
         test_combinations.combine(mode=["graph", "eager"])
     )
@@ -2279,6 +2307,25 @@ def test_binary_focal_crossentropy_no_focal_effect_with_zero_gamma(self):
         non_focal_result = self.evaluate(backend.binary_crossentropy(t, p))
         self.assertArrayNear(focal_result[0], non_focal_result[0], 1e-3)

+    @test_combinations.generate(
+        test_combinations.combine(mode=["graph", "eager"])
+    )
+    def test_categorical_focal_crossentropy_no_focal_effect(self):
+        t = backend.constant([[0, 1, 0]])
+        logits = backend.constant([[8.0, 1.0, 1.0]])
+        p = backend.softmax(logits)
+        p = tf.identity(tf.identity(p))
+        focal_result = self.evaluate(
+            backend.categorical_focal_crossentropy(
+                target=t,
+                output=p,
+                gamma=0.0,
+                alpha=1.0,
+            )
+        )
+        non_focal_result = self.evaluate(backend.categorical_crossentropy(t, p))
+        self.assertArrayNear(focal_result, non_focal_result, 1e-3)
+
     @test_combinations.generate(
         test_combinations.combine(mode=["graph", "eager"])
     )
diff --git a/keras/losses.py b/keras/losses.py
index ebb850c4a4a..adf918a5102 100644
--- a/keras/losses.py
+++ b/keras/losses.py
@@ -922,6 +922,144 @@ def __init__(
         )

+
+@keras_export("keras.losses.CategoricalFocalCrossentropy")
+class CategoricalFocalCrossentropy(LossFunctionWrapper):
+    """Computes the alpha balanced focal crossentropy loss.
+
+    Use this crossentropy loss function when there are two or more label
+    classes and if you want to handle class imbalance without using
+    `class_weights`.
+    We expect labels to be provided in a `one_hot` representation.
+
+    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
+    helps to apply a focal factor to down-weight easy examples and focus more on
+    hard examples. The general formula for the focal loss (FL)
+    is as follows:
+
+    `FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
+
+    where `p_t` is defined as follows:
+    `p_t = output if y_true == 1, else 1 - output`
+
+    `(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
+    parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
+    `gamma` reduces the importance given to simple examples in a smooth manner.
+
+    The authors use the alpha-balanced variant of focal loss (FL) in the paper:
+    `FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
+
+    where `alpha` is the weight factor for the classes. If `alpha` = 1, the
+    loss won't be able to handle class imbalance properly as all
+    classes will have the same weight. This can be a constant or a list of
+    constants. If alpha is a list, it must have the same length as the number
+    of classes.
+
+    The formula above can be generalized to:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(y_true, y_pred)`
+
+    where the minus sign comes from `CrossEntropy(y_true, y_pred)` (CE).
+
+    Extending this to the multi-class case is straightforward:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(y_true, y_pred)`
+
+    In the snippet below, there are `# classes` floating point values per
+    example. The shapes of both `y_pred` and `y_true` are
+    `[batch_size, num_classes]`.
+
+    Standalone usage:
+
+    >>> y_true = [[0., 1., 0.], [0., 0., 1.]]
+    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
+    >>> # Using 'auto'/'sum_over_batch_size' reduction type.
+    >>> cce = tf.keras.losses.CategoricalFocalCrossentropy()
+    >>> cce(y_true, y_pred).numpy()
+    0.23315276
+
+    >>> # Calling with 'sample_weight'.
+    >>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
+    0.1632
+
+    >>> # Using 'sum' reduction type.
+    >>> cce = tf.keras.losses.CategoricalFocalCrossentropy(
+    ...     reduction=tf.keras.losses.Reduction.SUM)
+    >>> cce(y_true, y_pred).numpy()
+    0.46631
+
+    >>> # Using 'none' reduction type.
+    >>> cce = tf.keras.losses.CategoricalFocalCrossentropy(
+    ...     reduction=tf.keras.losses.Reduction.NONE)
+    >>> cce(y_true, y_pred).numpy()
+    array([3.2058331e-05, 4.6627346e-01], dtype=float32)
+
+    Usage with the `compile()` API:
+    ```python
+    model.compile(optimizer='adam',
+                  loss=tf.keras.losses.CategoricalFocalCrossentropy())
+    ```
+
+    Args:
+        alpha: A weight balancing factor for all classes, default is `0.25` as
+            mentioned in the reference. It can be a list of floats or a scalar.
+            In the multi-class case, alpha may be set by inverse class
+            frequency by using `compute_class_weight` from `sklearn.utils`.
+        gamma: A focusing parameter, default is `2.0` as mentioned in the
+            reference. It helps to gradually reduce the importance given to
+            simple (easy) examples in a smooth manner.
+        from_logits: Whether `output` is expected to be a logits tensor. By
+            default, we consider that `output` encodes a probability
+            distribution.
+        label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
+            meaning the confidence on label values is relaxed. For example, if
+            `0.1`, use `0.1 / num_classes` for non-target labels and
+            `0.9 + 0.1 / num_classes` for target labels.
+        axis: The axis along which to compute crossentropy (the features
+            axis). Defaults to -1.
+        reduction: Type of `tf.keras.losses.Reduction` to apply to
+            loss.
+            Default value is `AUTO`. `AUTO` indicates that the reduction
+            option will be determined by the usage context. For almost all cases
+            this defaults to `SUM_OVER_BATCH_SIZE`. When used under a
+            `tf.distribute.Strategy`, except via `Model.compile()` and
+            `Model.fit()`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
+            will raise an error. Please see this custom training [tutorial](
+            https://www.tensorflow.org/tutorials/distribute/custom_training)
+            for more details.
+        name: Optional name for the instance.
+            Defaults to 'categorical_focal_crossentropy'.
+    """
+
+    def __init__(
+        self,
+        alpha=0.25,
+        gamma=2.0,
+        from_logits=False,
+        label_smoothing=0.0,
+        axis=-1,
+        reduction=losses_utils.ReductionV2.AUTO,
+        name="categorical_focal_crossentropy",
+    ):
+        """Initializes `CategoricalFocalCrossentropy` instance."""
+        super().__init__(
+            categorical_focal_crossentropy,
+            alpha=alpha,
+            gamma=gamma,
+            name=name,
+            reduction=reduction,
+            from_logits=from_logits,
+            label_smoothing=label_smoothing,
+            axis=axis,
+        )
+        self.from_logits = from_logits
+        self.alpha = alpha
+        self.gamma = gamma
+
+    def get_config(self):
+        config = {
+            "alpha": self.alpha,
+            "gamma": self.gamma,
+        }
+        base_config = super().get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+
 @keras_export("keras.losses.SparseCategoricalCrossentropy")
 class SparseCategoricalCrossentropy(LossFunctionWrapper):
     """Computes the crossentropy loss between the labels and predictions.
@@ -2025,6 +2163,144 @@ def _ragged_tensor_categorical_crossentropy(
     return _ragged_tensor_apply_loss(fn, y_true, y_pred)

+
+@keras_export(
+    "keras.metrics.categorical_focal_crossentropy",
+    "keras.losses.categorical_focal_crossentropy",
+)
+@tf.__internal__.dispatch.add_dispatch_support
+def categorical_focal_crossentropy(
+    y_true,
+    y_pred,
+    alpha=0.25,
+    gamma=2.0,
+    from_logits=False,
+    label_smoothing=0.0,
+    axis=-1,
+):
+    """Computes the categorical focal crossentropy loss.
+
+    Standalone usage:
+    >>> y_true = [[0, 1, 0], [0, 0, 1]]
+    >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]]
+    >>> loss = tf.keras.losses.categorical_focal_crossentropy(y_true, y_pred)
+    >>> assert loss.shape == (2,)
+    >>> loss.numpy()
+    array([2.63401289e-04, 6.75912094e-01], dtype=float32)
+
+    Args:
+        y_true: Tensor of one-hot true targets.
+        y_pred: Tensor of predicted targets.
+        alpha: A weight balancing factor for all classes, default is `0.25` as
+            mentioned in the reference. It can be a list of floats or a scalar.
+            In the multi-class case, alpha may be set by inverse class
+            frequency by using `compute_class_weight` from `sklearn.utils`.
+        gamma: A focusing parameter, default is `2.0` as mentioned in the
+            reference. It helps to gradually reduce the importance given to
+            simple examples in a smooth manner. When `gamma` = 0, there is
+            no focal effect on the categorical crossentropy.
+        from_logits: Whether `y_pred` is expected to be a logits tensor. By
+            default, we assume that `y_pred` encodes a probability
+            distribution.
+        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
+            example, if `0.1`, use `0.1 / num_classes` for non-target labels
+            and `0.9 + 0.1 / num_classes` for target labels.
+        axis: Defaults to -1. The dimension along which the entropy is
+            computed.
+
+    Returns:
+        Categorical focal crossentropy loss value.
+    """
+    if isinstance(axis, bool):
+        raise ValueError(
+            "`axis` must be of type `int`. "
+            f"Received: axis={axis} of type {type(axis)}"
+        )
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = tf.cast(y_true, y_pred.dtype)
+    label_smoothing = tf.convert_to_tensor(label_smoothing, dtype=y_pred.dtype)
+
+    if y_pred.shape[-1] == 1:
+        warnings.warn(
+            "In loss categorical_focal_crossentropy, expected "
+            "y_pred.shape to be (batch_size, num_classes) "
+            f"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. "
+            "Consider using 'binary_crossentropy' if you only have 2 classes.",
+            SyntaxWarning,
+            stacklevel=2,
+        )
+
+    def _smooth_labels():
+        num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype)
+        return y_true * (1.0 - label_smoothing) + (
+            label_smoothing / num_classes
+        )
+
+    y_true = tf.__internal__.smart_cond.smart_cond(
+        label_smoothing, _smooth_labels, lambda: y_true
+    )
+
+    return backend.categorical_focal_crossentropy(
+        target=y_true,
+        output=y_pred,
+        alpha=alpha,
+        gamma=gamma,
+        from_logits=from_logits,
+        axis=axis,
+    )
+
+
+@dispatch.dispatch_for_types(categorical_focal_crossentropy, tf.RaggedTensor)
+def _ragged_tensor_categorical_focal_crossentropy(
+    y_true,
+    y_pred,
+    alpha=0.25,
+    gamma=2.0,
+    from_logits=False,
+    label_smoothing=0.0,
+    axis=-1,
+):
+    """Implements support for handling RaggedTensors.
+
+    Expected shape: (batch, sequence_len, n_classes) with sequence_len
+    being variable per batch.
+    Return shape: (batch, sequence_len).
+    When used by CategoricalFocalCrossentropy() with the default reduction
+    (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
+    number of elements independent of the batch. E.g. if the RaggedTensor
+    has 2 batches with [2, 1] values respectively the resulting loss is
+    the sum of the individual loss values divided by 3.
+
+    Args:
+        alpha: A weight balancing factor for all classes, default is `0.25` as
+            mentioned in the reference. It can be a list of floats or a scalar.
+            In the multi-class case, alpha may be set by inverse class
+            frequency by using `compute_class_weight` from `sklearn.utils`.
+        gamma: A focusing parameter, default is `2.0` as mentioned in the
+            reference. It helps to gradually reduce the importance given to
+            simple examples in a smooth manner. When `gamma` = 0, there is
+            no focal effect on the categorical crossentropy.
+        from_logits: Whether `y_pred` is expected to be a logits tensor. By
+            default, we assume that `y_pred` encodes a probability distribution.
+        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
+            example, if `0.1`, use `0.1 / num_classes` for non-target labels
+            and `0.9 + 0.1 / num_classes` for target labels.
+        axis: Defaults to -1. The dimension along which the entropy is
+            computed.
+
+    Returns:
+        Categorical focal crossentropy loss value.
+ """ + fn = functools.partial( + categorical_focal_crossentropy, + alpha=alpha, + gamma=gamma, + from_logits=from_logits, + label_smoothing=label_smoothing, + axis=axis, + ) + return _ragged_tensor_apply_loss(fn, y_true, y_pred) + + @keras_export( "keras.metrics.sparse_categorical_crossentropy", "keras.losses.sparse_categorical_crossentropy", diff --git a/keras/losses_test.py b/keras/losses_test.py index b7e1b523b5b..9700f1ed280 100644 --- a/keras/losses_test.py +++ b/keras/losses_test.py @@ -1810,6 +1810,207 @@ def test_binary_labels(self): ) +@test_combinations.generate(test_combinations.combine(mode=["graph", "eager"])) +class CategoricalFocalCrossentropyTest(tf.test.TestCase): + def test_config(self): + + cce_obj = losses.CategoricalFocalCrossentropy( + name="focal_cce", + reduction=losses_utils.ReductionV2.SUM, + alpha=0.25, + gamma=2.0, + ) + self.assertEqual(cce_obj.name, "focal_cce") + self.assertEqual(cce_obj.reduction, losses_utils.ReductionV2.SUM) + self.assertEqual(cce_obj.alpha, 0.25) + self.assertEqual(cce_obj.gamma, 2.0) + + # Test alpha as a list + cce_obj = losses.CategoricalFocalCrossentropy(alpha=[0.25, 0.5, 0.75]) + self.assertEqual(cce_obj.alpha, [0.25, 0.5, 0.75]) + + def test_all_correct_unweighted(self): + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int64) + y_pred = tf.constant( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype=tf.float32, + ) + cce_obj = losses.CategoricalFocalCrossentropy(alpha=0.25, gamma=2.0) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) + + # Test with logits. + logits = tf.constant( + [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]] + ) + cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(self.evaluate(loss), 0.0, 3) + + def test_unweighted(self): + cce_obj = losses.CategoricalFocalCrossentropy() + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = tf.constant( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype=tf.float32, + ) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(self.evaluate(loss), 0.02059, 3) + + # Test with logits. + logits = tf.constant( + [[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]] + ) + cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True) + loss = cce_obj(y_true, logits) + self.assertAlmostEqual(self.evaluate(loss), 0.000345, 3) + + def test_scalar_weighted(self): + cce_obj = losses.CategoricalFocalCrossentropy() + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + y_pred = tf.constant( + [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]], + dtype=tf.float32, + ) + loss = cce_obj(y_true, y_pred, sample_weight=2.3) + self.assertAlmostEqual(self.evaluate(loss), 0.047368, 3) + + # Test with logits. 
+        logits = tf.constant(
+            [[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]
+        )
+        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)
+        loss = cce_obj(y_true, logits, sample_weight=2.3)
+        self.assertAlmostEqual(self.evaluate(loss), 0.000794, 4)
+
+    def test_sample_weighted(self):
+        cce_obj = losses.CategoricalFocalCrossentropy()
+        y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+        y_pred = tf.constant(
+            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],
+            dtype=tf.float32,
+        )
+        sample_weight = tf.constant([[1.2], [3.4], [5.6]], shape=(3, 1))
+        loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAlmostEqual(self.evaluate(loss), 0.06987, 3)
+
+        # Test with logits.
+        logits = tf.constant(
+            [[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]
+        )
+        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)
+        loss = cce_obj(y_true, logits, sample_weight=sample_weight)
+        self.assertAlmostEqual(self.evaluate(loss), 0.001933, 3)
+
+    def test_no_reduction(self):
+        y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+        logits = tf.constant(
+            [[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]]
+        )
+        cce_obj = losses.CategoricalFocalCrossentropy(
+            from_logits=True, reduction=losses_utils.ReductionV2.NONE
+        )
+        loss = cce_obj(y_true, logits)
+        self.assertAllClose(
+            (1.5096224e-09, 2.4136547e-11, 1.0360638e-03),
+            self.evaluate(loss),
+            3,
+        )
+
+    def test_label_smoothing(self):
+        logits = tf.constant([[4.9, -0.5, 2.05]])
+        y_true = tf.constant([[1, 0, 0]])
+        label_smoothing = 0.1
+
+        cce_obj = losses.CategoricalFocalCrossentropy(
+            from_logits=True, label_smoothing=label_smoothing
+        )
+        loss = cce_obj(y_true, logits)
+
+        expected_value = 0.06685
+        self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
+
+    def test_label_smoothing_ndarray(self):
+        logits = np.asarray([[4.9, -0.5, 2.05]])
+        y_true = np.asarray([[1, 0, 0]])
+        label_smoothing = 0.1
+
+        cce_obj = losses.CategoricalFocalCrossentropy(
+            from_logits=True, label_smoothing=label_smoothing
+        )
+        loss = cce_obj(y_true, logits)
+
+        expected_value = 0.06685
+        self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
+
+    def test_shape_mismatch(self):
+        y_true = tf.constant([[0], [1], [2]])
+        y_pred = tf.constant(
+            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]]
+        )
+
+        cce_obj = losses.CategoricalFocalCrossentropy()
+        with self.assertRaisesRegex(ValueError, "Shapes .+ are incompatible"):
+            cce_obj(y_true, y_pred)
+
+    def test_ragged_tensors(self):
+        cce_obj = losses.CategoricalFocalCrossentropy()
+        y_true = tf.ragged.constant([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]])
+        y_pred = tf.ragged.constant(
+            [[[0.9, 0.05, 0.05], [0.5, 0.89, 0.6]], [[0.05, 0.01, 0.94]]],
+            dtype=tf.float32,
+        )
+        # Per-element categorical crossentropy (before focal weighting):
+        # [[0.1054, 0.8047], [0.0619]]
+        sample_weight = tf.constant([[1.2], [3.4]], shape=(2, 1))
+        loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+
+        self.assertAlmostEqual(self.evaluate(loss), 0.024754, 3)
+
+        # Test with logits.
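+        # Same ragged targets as above, but fed as logits so the loss
+        # applies softmax internally (from_logits=True) before the focal
+        # weighting.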
+        logits = tf.ragged.constant(
+            [[[8.0, 1.0, 1.0], [0.0, 9.0, 1.0]], [[2.0, 3.0, 5.0]]]
+        )
+        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)
+
+        loss = cce_obj(y_true, logits, sample_weight=sample_weight)
+        self.assertAlmostEqual(self.evaluate(loss), 0.00117, 3)
+
+    def test_ragged_tensors_ragged_sample_weights(self):
+        cce_obj = losses.CategoricalFocalCrossentropy()
+        y_true = tf.ragged.constant([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1]]])
+        y_pred = tf.ragged.constant(
+            [[[0.9, 0.05, 0.05], [0.05, 0.89, 0.06]], [[0.05, 0.01, 0.94]]],
+            dtype=tf.float32,
+        )
+        sample_weight = tf.ragged.constant(
+            [[1.2, 3.4], [5.6]], dtype=tf.float32
+        )
+        loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+        self.assertAlmostEqual(self.evaluate(loss), 0.0006088, 4)
+
+        # Test with logits.
+        logits = tf.ragged.constant(
+            [[[8.0, 1.0, 1.0], [0.0, 9.0, 1.0]], [[2.0, 3.0, 5.0]]]
+        )
+        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)
+
+        loss = cce_obj(y_true, logits, sample_weight=sample_weight)
+        self.assertAlmostEqual(self.evaluate(loss), 0.001933, 3)
+
+    def test_binary_labels(self):
+        # Raise a warning if the shapes of y_true and y_pred are both
+        # (None, 1): categorical_focal_crossentropy shouldn't be used with
+        # binary labels.
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            cce_obj = losses.CategoricalFocalCrossentropy()
+            cce_obj(tf.constant([[1.0], [0.0]]), tf.constant([[1.0], [1.0]]))
+            self.assertIs(w[-1].category, SyntaxWarning)
+            self.assertIn(
+                "In loss categorical_focal_crossentropy, expected ",
+                str(w[-1].message),
+            )
+
+
 @test_combinations.generate(test_combinations.combine(mode=["graph", "eager"]))
 class SparseCategoricalCrossentropyTest(tf.test.TestCase):
     def test_config(self):