-
Notifications
You must be signed in to change notification settings - Fork 19.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Keras 3 converts loss dtype in __init__ where keras 2 did not #20060
Comments
Just another user here, but doesn't that imply you can set the …
@markomitos
import keras
class CustomLoss(keras.losses.Loss):
    """Sparse categorical cross-entropy on logits, named "custom_loss" by default.

    Keras 3 converts the inputs to the loss's dtype before `call` runs, so
    integer labels arrive here as floats. That is harmless for this loss
    because `sparse_categorical_crossentropy` casts the target back to an
    integer type internally.

    Args:
        name: Name of the loss. Defaults to "custom_loss".
        **kwargs: Any other keyword arguments, forwarded unchanged to
            `keras.losses.Loss`.
    """

    def __init__(self, name="custom_loss", **kwargs):
        # `name` is an explicit, defaulted parameter rather than a literal
        # passed alongside `**kwargs`. A config produced by `get_config()`
        # already contains "name", and `from_config(config)` calls
        # `cls(**config)`. Hard-coding the name next to `**kwargs` would make
        # that call fail with "got multiple values for keyword argument 'name'".
        super().__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        """Return the per-sample cross-entropy of logits `y_pred` vs. `y_true`."""
        # `sparse_categorical_crossentropy` will cast `target` to `int64`
        # internally.
        return keras.ops.nn.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
# A single sample with two equal logits, and its integer (int64) class label.
logits = keras.ops.ones([1, 2])
labels = keras.ops.convert_to_tensor([0], dtype="int64")

custom_loss_fn = CustomLoss()
loss = custom_loss_fn(labels, logits)
print(loss)  # tf.Tensor(0.6931472, shape=(), dtype=float32)

# Or use built-in loss
builtin_loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss = builtin_loss_fn(labels, logits)
print(loss) # tf.Tensor(0.6931472, shape=(), dtype=float32)
Indeed, Keras 3 converts `y_true` to `self.dtype`. If you really want to have full control over the dtype of `y_true`, you can override `__call__`:
import keras
from keras.src import ops
from keras.src.losses.loss import reduce_weighted_values
class CustomFullControlLoss(keras.losses.Loss):
    """Sparse categorical cross-entropy that always feeds `call` int64 labels.

    `__call__` is overridden so that only `y_pred` is converted to
    `self.dtype`, while `y_true` is cast to int64 regardless of `self.dtype`.

    Args:
        name: Name of the loss. Defaults to "custom_loss".
        **kwargs: Any other keyword arguments, forwarded unchanged to
            `keras.losses.Loss`.
    """

    def __init__(self, name="custom_loss", **kwargs):
        # `name` is an explicit, defaulted parameter rather than a literal
        # passed alongside `**kwargs`. A config produced by `get_config()`
        # already contains "name", and `from_config(config)` calls
        # `cls(**config)`. Hard-coding the name next to `**kwargs` would make
        # that call fail with "got multiple values for keyword argument 'name'".
        super().__init__(name=name, **kwargs)

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Convert the inputs, compute per-sample losses and reduce them.

        Args:
            y_true: Labels (possibly a nested structure); cast to int64.
            y_pred: Logits (possibly a nested structure); converted to
                `self.dtype`.
            sample_weight: Optional weights passed to
                `reduce_weighted_values`.

        Returns:
            The losses from `self.call`, reduced according to
            `self.reduction` and returned in `self.dtype`.
        """
        # NOTE(review): this replaces the base `Loss.__call__` wholesale, so
        # any additional handling the base class performs (for example,
        # Keras mask propagation, depending on the Keras version) is skipped.
        # Confirm this is acceptable for your inputs.
        with ops.name_scope(self.name):
            y_pred = keras.tree.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_pred
            )
            # Cast `y_true` to int64 regardless of `self.dtype`.
            y_true = keras.tree.map_structure(
                lambda x: ops.convert_to_tensor(x, dtype="int64"), y_true
            )
            losses = self.call(y_true, y_pred)
            return reduce_weighted_values(
                losses,
                sample_weight=sample_weight,
                reduction=self.reduction,
                dtype=self.dtype,
            )

    def call(self, y_true, y_pred):
        """Return the per-sample cross-entropy of logits `y_pred` vs. `y_true`."""
        return keras.ops.nn.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True
        )
# Same single-sample example: two equal logits and an int64 class label.
logits = keras.ops.ones([1, 2])
labels = keras.ops.convert_to_tensor([0], dtype="int64")

full_control_loss_fn = CustomFullControlLoss()
loss = full_control_loss_fn(labels, logits)
print(loss)  # tf.Tensor(0.6931472, shape=(), dtype=float32)
@james77777778 Thank you, this resolved the issue. |
When I create a custom loss that receives an `int64` tensor, the tensor is automatically converted to `float32` because of this conversion:
keras/keras/src/losses/loss.py
Line 42 in 25391fe
In Keras 2, there is no such conversion:
keras/keras/losses.py
Line 68 in 601488f
This is my custom class:
This is the batch passed to the forward_pass:
The text was updated successfully, but these errors were encountered: