tf.keras.losses.Loss

Loss base class.

To be implemented by subclasses:

 • call(): Contains the logic for loss calculation using y_true, y_pred.

Example subclass implementation:

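import tensorflow as tf


class MeanSquaredError(tf.keras.losses.Loss):

    def call(self, y_true, y_pred):
        # Convert inputs to tensors and match dtypes before computing the loss.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        # Averaging over the last axis yields one loss value per sample.
        return tf.reduce_mean(tf.square(y_pred - y_true), axis=-1)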
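To make the division of labor between call() and the base class concrete, here is a framework-free sketch. SimpleLoss and SimpleMeanSquaredError are hypothetical names, not Keras classes: the subclass only returns per-sample values from call(), and the base class combines them (here always by averaging over the batch, the SUM_OVER_BATCH_SIZE behavior). It models the pattern only and is not how tf.keras.losses.Loss is implemented internally.

```python
class SimpleLoss:
    """Hypothetical base class: subclasses must implement call()."""

    def call(self, y_true, y_pred):
        raise NotImplementedError("Subclasses must implement call().")

    def __call__(self, y_true, y_pred):
        # The subclass returns one loss value per sample; the base class
        # then averages them over the batch (SUM_OVER_BATCH_SIZE style).
        per_sample = self.call(y_true, y_pred)
        return sum(per_sample) / len(per_sample)


class SimpleMeanSquaredError(SimpleLoss):
    def call(self, y_true, y_pred):
        # Mean of squared differences over the last axis of each sample,
        # mirroring tf.reduce_mean(..., axis=-1) in the TensorFlow example.
        return [
            sum((p - t) ** 2 for t, p in zip(row_t, row_p)) / len(row_p)
            for row_t, row_p in zip(y_true, y_pred)
        ]


y_true = [[0.0, 1.0], [0.0, 0.0]]
y_pred = [[1.0, 1.0], [1.0, 0.0]]
mse = SimpleMeanSquaredError()
print(mse.call(y_true, y_pred))  # [0.5, 0.5], one value per sample
print(mse(y_true, y_pred))       # 0.5, averaged over the batch of 2
```

Keeping reduction in the base class means a new loss only has to describe the per-sample computation.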

When used with tf.distribute.Strategy outside of built-in training loops such as tf.keras compile() and fit(), use the 'SUM' or 'NONE' reduction types and reduce the losses explicitly in your training loop. Using 'AUTO' or 'SUM_OVER_BATCH_SIZE' will raise an error.

See the custom training tutorial for more details.
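The reduction rules described in this section (and in the description of the reduction argument) can be modeled in a few lines of plain Python. This is a hypothetical sketch, not Keras code: reduce_losses and its custom_distributed_loop flag are invented for illustration, and the lowercase strings are assumed to mirror the values of tf.keras.losses.Reduction.

```python
def reduce_losses(per_sample, reduction="auto", custom_distributed_loop=False):
    """Hypothetical model: reduce a list of per-sample losses."""
    if custom_distributed_loop and reduction in ("auto", "sum_over_batch_size"):
        # Mirrors the documented error for custom tf.distribute.Strategy loops.
        raise ValueError(
            "Use 'sum' or 'none' and reduce explicitly in the training loop."
        )
    if reduction == "auto":
        # AUTO depends on context; in almost all cases it means this:
        reduction = "sum_over_batch_size"
    if reduction == "none":
        return list(per_sample)  # one value per sample, nothing combined
    if reduction == "sum":
        return sum(per_sample)
    if reduction == "sum_over_batch_size":
        return sum(per_sample) / len(per_sample)
    raise ValueError(f"Unknown reduction: {reduction!r}")


losses = [0.5, 1.5, 2.0, 4.0]
print(reduce_losses(losses, "none"))  # [0.5, 1.5, 2.0, 4.0]
print(reduce_losses(losses, "sum"))   # 8.0
print(reduce_losses(losses))          # 2.0 (AUTO -> SUM_OVER_BATCH_SIZE)
```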

You can implement 'SUM_OVER_BATCH_SIZE' yourself by summing the per-example losses and dividing by the global batch size, for example:

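with strategy.scope():
    # NONE keeps one loss value per example so the loop can reduce them itself.
    loss_obj = tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    ...
    # Sum this replica's per-example losses, then scale by the global batch size.
    loss = (tf.reduce_sum(loss_obj(labels, predictions)) *
            (1. / global_batch_size))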
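The arithmetic behind dividing by the global batch size can be checked without TensorFlow. The sketch below uses made-up per-sample losses split across two replicas, and assumes the per-replica values are summed across replicas (the usual motivation for this pattern). Scaling each replica's sum by the global batch size recovers the mean over the whole batch, while summing per-replica means overshoots by the number of replicas.

```python
# Made-up per-sample losses for a global batch of 4, split over 2 replicas.
per_sample = [1.0, 2.0, 3.0, 6.0]
global_batch_size = len(per_sample)
replicas = [per_sample[:2], per_sample[2:]]

# Pattern above: each replica sums its losses and scales by the GLOBAL
# batch size; summing the replica results gives the mean over all samples.
scaled = [sum(r) * (1.0 / global_batch_size) for r in replicas]
correct_total = sum(scaled)

# Per-replica SUM_OVER_BATCH_SIZE: each replica averages over its LOCAL
# batch, and summing those averages counts the mean once per replica.
local_means = [sum(r) / len(r) for r in replicas]
wrong_total = sum(local_means)

print(correct_total)  # 3.0, the mean of all 4 losses
print(wrong_total)    # 6.0, twice the mean with 2 replicas
```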

Args
reduction (Optional) Type of tf.keras.losses.Reduction to apply to the loss. Default value is AUTO, which means the reduction option is determined by the usage context; in almost all cases this resolves to SUM_OVER_BATCH_SIZE. When used with tf.distribute.Strategy outside of built-in training loops such as tf.keras compile() and fit(), using AUTO or SUM_OVER_BATCH_SIZE will raise an error. See the custom training tutorial for more details.
name (Optional) Name for the op.

Methods

call

Invokes the Loss instance.

Args
y_true Ground truth values, with shape [batch_size, d0, ..., dN], except for sparse loss functions such as sparse categorical crossentropy, where shape = [batch_size, d0, ..., dN-1].
y_pred The predicted values, with shape [batch_size, d0, ..., dN].

Returns
Loss values with shape [batch_size, d0, ..., dN-1].
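The shape rules above can be illustrated with a simplified, framework-free crossentropy. Both functions are hypothetical sketches: they assume y_pred already holds probabilities and skip the clipping and logits handling that real Keras losses perform. Dense labels have the same shape as y_pred ([batch_size, num_classes]), sparse labels drop the last dimension ([batch_size]), and both return one value per sample.

```python
import math


def categorical_crossentropy(y_true_onehot, y_pred):
    # y_true: [batch_size, num_classes] one-hot rows; returns [batch_size].
    return [
        -sum(t * math.log(p) for t, p in zip(row_t, row_p))
        for row_t, row_p in zip(y_true_onehot, y_pred)
    ]


def sparse_categorical_crossentropy(y_true_ids, y_pred):
    # y_true: [batch_size] integer class ids, one dimension fewer than y_pred.
    return [-math.log(row_p[i]) for i, row_p in zip(y_true_ids, y_pred)]


y_pred = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
dense = categorical_crossentropy([[1, 0, 0], [0, 0, 1]], y_pred)
sparse = sparse_categorical_crossentropy([0, 2], y_pred)
print(dense)   # two values: -log(0.7) and -log(0.8)
print(sparse)  # the same two values from integer labels
```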

from_config

Instantiates a Loss from its config (output of get_config()).

Args
config Output of get_config().

Returns
A Loss instance.
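A config round trip can be sketched in plain Python. ConfigurableLoss and ScaledLoss are hypothetical classes; our understanding is that the Keras base class stores reduction and name in its config and that from_config passes the dict back to the constructor, but treat this as a model rather than the library's code. The subclass shows why a loss with extra constructor arguments needs to extend get_config(): otherwise from_config() cannot rebuild them.

```python
class ConfigurableLoss:
    """Hypothetical stand-in for a loss with a get_config/from_config pair."""

    def __init__(self, reduction="auto", name=None):
        self.reduction = reduction
        self.name = name

    def get_config(self):
        return {"reduction": self.reduction, "name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild an instance by passing the config back to the constructor.
        return cls(**config)


class ScaledLoss(ConfigurableLoss):
    """Hypothetical subclass with an extra constructor argument."""

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def get_config(self):
        # Extend the base config so from_config() can restore `scale` too.
        config = super().get_config()
        config["scale"] = self.scale
        return config


original = ScaledLoss(scale=2.0, reduction="sum", name="scaled")
restored = ScaledLoss.from_config(original.get_config())
print(restored.get_config())  # {'reduction': 'sum', 'name': 'scaled', 'scale': 2.0}
```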

get_config
