
How To Create A Loss Function With Mse That Uses Tf.where() To Ignore Certain Elements

Here is the function currently. Here, it removes from the MSE any values where y_true is less than a threshold (here, it is 0.1). def my_loss(y_true,y_pred): loss = tf.square(y

Solution 1:

Using Python's short-circuit "and" operator in code that is converted to graph mode usually leads to unexpected behaviour or errors, because "and" cannot be overloaded for tensors. For an element-wise AND on tensors, use tf.math.logical_and.
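As a minimal sketch of that advice, assuming y_true and y_pred are float32 tensors of the same shape, the tf.where version of the question's loss can build its "ignore" condition with tf.math.logical_and instead of "and". The function name where_masked_mse and the threshold parameter are hypothetical, not taken from the question:

```python
import tensorflow as tf


@tf.function
def where_masked_mse(y_true, y_pred, threshold=0.1):
    """Hypothetical helper: MSE that ignores elements where BOTH values are < threshold."""
    sq_err = tf.square(y_true - y_pred)
    # Element-wise AND that works in graph mode, unlike Python's short-circuit `and`.
    ignore = tf.math.logical_and(y_true < threshold, y_pred < threshold)
    # Replace ignored elements with 0 and keep the rest unchanged.
    kept = tf.where(ignore, tf.zeros_like(sq_err), sq_err)
    # Average over the kept elements only; divide_no_nan returns 0 if nothing is kept.
    n_kept = tf.reduce_sum(tf.cast(tf.logical_not(ignore), sq_err.dtype))
    return tf.math.divide_no_nan(tf.reduce_sum(kept), n_kept)
```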

Besides, tf.where is not necessary here and is likely to be slower; multiplying the loss by a mask is preferred. Example code:

import tensorflow as tf

@tf.function
def better_loss(y_true, y_pred):
  loss = tf.square(y_true - y_pred)
  # ignore elements where BOTH y_true and y_pred < 0.1
  mask = tf.cast(tf.logical_or(y_true >= 0.1, y_pred >= 0.1), tf.float32)
  loss *= mask
  # divide_no_nan returns 0 instead of NaN if every element is masked out
  return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))
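One way to sanity-check the masking idea is to compare it with a version that physically drops the ignored elements using tf.boolean_mask and then takes a plain mean; both should give the same number. This is a sketch: the function names are hypothetical, and the multiply-by-mask version uses tf.math.divide_no_nan so that a batch in which every element is ignored returns 0 rather than NaN.

```python
import tensorflow as tf


def mask_multiply_mse(y_true, y_pred, threshold=0.1):
    """Hypothetical: multiply-by-mask MSE, the approach of Solution 1."""
    loss = tf.square(y_true - y_pred)
    # 1.0 where at least one of y_true, y_pred reaches the threshold, else 0.0.
    mask = tf.cast(tf.logical_or(y_true >= threshold, y_pred >= threshold), loss.dtype)
    return tf.math.divide_no_nan(tf.reduce_sum(loss * mask), tf.reduce_sum(mask))


def boolean_mask_mse(y_true, y_pred, threshold=0.1):
    """Hypothetical reference: drop the ignored elements, then take a plain mean."""
    keep = tf.logical_or(y_true >= threshold, y_pred >= threshold)
    return tf.reduce_mean(tf.boolean_mask(tf.square(y_true - y_pred), keep))


y_true = tf.constant([[0.0, 0.5], [0.05, 0.2]])
y_pred = tf.constant([[0.05, 0.3], [0.8, 0.0]])
print(float(mask_multiply_mse(y_true, y_pred)), float(boolean_mask_mse(y_true, y_pred)))
```

tf.boolean_mask is handy for checking, but it changes the tensor's shape and its mean over an empty selection is NaN, which is one reason the multiply-by-mask form is convenient inside a loss.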

Solution 2:

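You seem to be confused about how tf.where works. As the documentation shows, tf.where takes a condition plus two optional tensors x and y, and it only selects values element-wise when both x and y are given. With the condition alone, it returns the indices of the True elements instead: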

tf.where(
    condition, x=None, y=None, name=None
)

That is likely why your loss isn't helping the model learn anything: a value built from those integer indices has no differentiable path back to y_pred, so no gradient reaches the model.
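The difference between the one-argument and three-argument forms is easy to see on a small tensor. In this sketch (the variable names are illustrative), the one-argument call yields int64 coordinates, and a tf.GradientTape shows that a value computed from them carries no gradient back to y_pred:

```python
import tensorflow as tf

y_true = tf.constant([0.0, 0.5, 0.05, 0.2])

# One argument: int64 coordinates of the True elements, shape (num_true, rank).
indices = tf.where(y_true < 0.1)

# Three arguments: element-wise choice, x where the condition holds, y elsewhere.
selected = tf.where(y_true < 0.1, tf.zeros_like(y_true), y_true)

# A "loss" built from the indices is disconnected from y_pred.
y_pred = tf.Variable([0.05, 0.3, 0.8, 0.0])
with tf.GradientTape() as tape:
    bad_loss = tf.reduce_sum(tf.cast(tf.where(y_pred < 0.1), tf.float32))
grad = tape.gradient(bad_loss, y_pred)  # no differentiable path, so this is None
```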

If you want to check both conditions and then apply a loss, this is how you can do it.

Say you want to apply some_loss1 where y_pred < 0.1 but y_true >= 0.1, and some_loss2 where y_pred >= 0.1, while elements where both are below 0.1 get zero loss. The total loss can then be computed by nesting tf.where:

some_loss1 = tf.constant(1000.0)  # example penalty value
some_loss2 = tf.constant(1000.0)  # example penalty value
loss = tf.where(y_pred < 0.1, tf.where(y_true < 0.1, tf.constant(0.0), some_loss1), some_loss2)

This penalizes each element according to both y_pred and y_true.
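One caveat: with constant values such as some_loss1 and some_loss2, the loss has the intended value in each region, but its gradient with respect to y_pred is zero everywhere (a constant does not change when y_pred does), so it cannot steer training on its own. A minimal sketch that keeps the same nested tf.where structure but uses per-element penalties derived from the squared error follows; the factor 10.0 and the name nested_where_loss are arbitrary, hypothetical choices:

```python
import tensorflow as tf


def nested_where_loss(y_true, y_pred, threshold=0.1):
    """Hypothetical: nested tf.where with differentiable per-element penalties."""
    sq_err = tf.square(y_true - y_pred)
    loss_true_only = 10.0 * sq_err  # used where y_pred < threshold and y_true >= threshold
    loss_pred_high = sq_err  # used where y_pred >= threshold
    per_elem = tf.where(
        y_pred < threshold,
        tf.where(y_true < threshold, tf.zeros_like(sq_err), loss_true_only),
        loss_pred_high,
    )
    return tf.reduce_mean(per_elem)


y_true = tf.constant([0.0, 0.5, 0.05])
y_pred = tf.Variable([0.05, 0.0, 0.8])
with tf.GradientTape() as tape:
    loss = nested_where_loss(y_true, y_pred)
# Gradient is 0 for the ignored element and non-zero for the penalized ones.
grad = tape.gradient(loss, y_pred)
```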

Also, if you want to add this loss to your MSE loss, give it a different variable name: as written, the assignment to loss overwrites the MSE value already stored there with this mask loss.
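Keeping the two terms under separate names might look like the following sketch. It reuses the answer's placeholder constants; the function name mse_plus_mask_loss and reducing the mask loss with tf.reduce_mean before adding it are assumptions for illustration:

```python
import tensorflow as tf


def mse_plus_mask_loss(y_true, y_pred):
    """Hypothetical: MSE and the nested-tf.where mask loss kept in separate variables."""
    mse = tf.reduce_mean(tf.square(y_true - y_pred))
    some_loss1 = tf.constant(1000.0)  # placeholder penalty, as in the answer
    some_loss2 = tf.constant(1000.0)  # placeholder penalty, as in the answer
    mask_loss = tf.where(
        y_pred < 0.1,
        tf.where(y_true < 0.1, tf.constant(0.0), some_loss1),
        some_loss2,
    )
    # `mse` was never reassigned, so both terms reach the total.
    return mse + tf.reduce_mean(mask_loss)
```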
