Keras 3 converts loss dtype in __init__ where keras 2 did not #20060

Closed
markomitos opened this issue Jul 29, 2024 · 4 comments

@markomitos

When trying to create a custom loss with an int64 tensor, the tensor is automatically converted to float32 because of this conversion in __init__:

self._dtype_policy = dtype_policies.get(dtype or backend.floatx())

In Keras 2 there is no such conversion:

def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name=None):

This is my custom class:

import keras
import tensorflow as tf


class _CustomLossRequiringLabelBeInteger(keras.losses.Loss):

    def __init__(self):
        super().__init__(name='custom_loss_requiring_label_be_integer')

    def call(self, y_true, y_pred):
        # Note that this TF function requires that the label `y_true` be of an
        # integer dtype; a TypeError is thrown if `y_true` isn't int32 or int64.
        return tf.nn.sparse_softmax_cross_entropy_with_logits(y_true, y_pred)

This is the batch passed to the forward_pass:

import collections
import numpy as np

batch = collections.OrderedDict(
    x=tf.convert_to_tensor(np.ones((1, 2)), dtype=tf.float32),
    y=tf.convert_to_tensor([0], dtype=tf.int64),
)

@ghsanti commented Jul 29, 2024

Just another user here, but doesn't that imply you can set the dtype to int64 via super().__init__(name="...", dtype=np.int64) when you subclass the loss?
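
A sketch of that suggestion (hypothetical and untested; `_IntDtypeLoss` is an illustrative name, not from the thread). Because the loss's dtype is applied to both inputs, it would cast the float logits too, which is why it doesn't quite solve the problem (see the next comment):

import keras
import tensorflow as tf


class _IntDtypeLoss(keras.losses.Loss):
    def __init__(self):
        # dtype="int64" does make Keras 3 convert `y_true` to int64, but the
        # same policy is applied to `y_pred` as well, so the float logits
        # would also be cast to int64.
        super().__init__(
            name='custom_loss_requiring_label_be_integer', dtype='int64'
        )

    def call(self, y_true, y_pred):
        return tf.nn.sparse_softmax_cross_entropy_with_logits(y_true, y_pred)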

@james77777778 commented Jul 29, 2024

@markomitos This should meet your needs:

import keras


class CustomLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(name="custom_loss", **kwargs)

    def call(self, y_true, y_pred):
        # `sparse_categorical_crossentropy` will cast `target` to `int64`
        # internally.
        return keras.ops.nn.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )


x = keras.ops.ones([1, 2])
y = keras.ops.convert_to_tensor([0], dtype="int64")

loss = CustomLoss()(y, x)
print(loss)  # tf.Tensor(0.6931472, shape=(), dtype=float32)

# Or use built-in loss
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y, x)
print(loss)  # tf.Tensor(0.6931472, shape=(), dtype=float32)

Indeed, Keras 3 converts both y_true and y_pred using the loss's dtype, which is awkward for a loss that expects different dtypes for its inputs.
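
To see that conversion happen, a minimal probe (illustrative only, not from the thread; `DtypeProbe` is a hypothetical name, and it reuses `keras`, `x`, and `y` from the snippet above with the default float32 policy):

class DtypeProbe(keras.losses.Loss):
    def call(self, y_true, y_pred):
        # By the time `call` runs, Loss.__call__ has already converted both
        # inputs to `self.dtype` (float32 by default), even though `y` was
        # created as int64.
        print(y_true.dtype, y_pred.dtype)
        return keras.ops.mean(y_pred, axis=-1)

DtypeProbe()(y, x)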

If you really want to have full control over dtype, you can override __call__ like this:

import keras
from keras.src import ops
from keras.src.losses.loss import reduce_weighted_values


class CustomFullControlLoss(keras.losses.Loss):
    def __init__(self, **kwargs):
        super().__init__(name="custom_loss", **kwargs)

    def __call__(self, y_true, y_pred, sample_weight=None):
        with ops.name_scope(self.name):
            y_pred = keras.tree.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_pred
            )
            # Cast `y_true` to int64 regardless of `self.dtype`.
            y_true = keras.tree.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype="int64"), y_true
            )
            losses = self.call(y_true, y_pred)
            return reduce_weighted_values(
                losses,
                sample_weight=sample_weight,
                reduction=self.reduction,
                dtype=self.dtype,
            )

    def call(self, y_true, y_pred):
        return keras.ops.nn.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )


x = keras.ops.ones([1, 2])
y = keras.ops.convert_to_tensor([0], dtype="int64")

loss = CustomFullControlLoss()(y, x)
print(loss)  # tf.Tensor(0.6931472, shape=(), dtype=float32)

@markomitos (Author)

@james77777778 Thank you, this resolved the issue.
