I'm trying to create an ML model in TensorFlow that takes a tensor of shape (128,128,12) as input and outputs a tensor of shape (128,128,3), where the output dimensions are (x, y, sensor_number).

The problem with my training data is that the output is very sparse: I could only take sensor measurements at very few x-y coordinates. Wherever I did take a measurement, though, I always have all 3 sensor readings.

I have created a simple model that takes the input data and a mask as input:

from keras import Input, layers, models

input_data = Input(shape=(128,128,12), name="input_data")
input_mask = Input(shape=(128,128,3), name="input_mask")

output_layer = layers.Conv2D(filters=3, kernel_size=(3,3), padding="same", activation="sigmoid", name="output")(input_data)
output_masked = layers.Multiply(name="masked_output")([output_layer, input_mask])

print(output_masked.shape)
# return input_data, input_mask, output_layer, output_masked

model = models.Model(inputs=input_data, outputs=output_layer)
model_masked = models.Model(inputs=[input_data, input_mask], outputs=output_masked)

model_masked.compile(optimizer="adam", loss="mse")

The mask simply contains ones at the coordinates where I have taken measurements and zeros elsewhere. If helpful, the mask could also be derived from the zeros in y_true, since the actual sensor readings are always greater than 0.

Now the problem is that during training the loss gets tiny, and the model usually predicts zeros. To solve this, I figured I'd need a custom loss function that only calculates the loss for the coordinates where data is available. I've tried this:

from keras import ops
def masked_mse(y_true, y_pred):
    mask_value = 0
    mask = ops.repeat(
        ops.cast(
            ops.any(
                ops.not_equal(y_true, mask_value),
                axis=-1, keepdims=True,
            ),
            "float32"
        ),
        repeats=y_true.shape[-1], axis=-1
    )
    masked_squared_error = ops.square(mask * (y_pred - y_true))
    masked_mse = ops.sum(masked_squared_error, axis=-1) / ops.sum(mask, axis=-1) # results in lots of NaNs
    # masked_mse = ops.sum(masked_squared_error, axis=-1) / ops.maximum(ops.sum(mask, axis=-1), 1) # results in lots of zeros
    return masked_mse

but I don't really understand how this loss function gets applied during training. Note that I've included two variants of the masked_mse calculation. The first one results in NaNs wherever there are no measurements. The output is per x-y coordinate. During training, however, TensorFlow logs the loss as a single value, which is always NaN. I don't understand how this aggregation is calculated.

With the second variant, most values are 0 and TensorFlow logs an actual number during training, but then I'm back to tiny losses and a model that learns to predict mostly zeros.

How do I define a proper loss function for training with sparse output data?

As a side note, I saw that there are two different MSE calculations in Keras: tf.keras.losses.MeanSquaredError and tf.keras.losses.MSE, where the former has the parameter reduction. With reduction=None it calculates the coordinate-wise MSE. With reduction="sum_over_batch_size" (the default) it calculates a single value. Is this how I should define my masked loss function?

1 Answer

Training a TensorFlow model on sparse targets with a standard MSE loss can cause it to predict only zeros. To solve this, you need a custom loss function that considers only the non-zero values in the target tensor. This prevents the loss from being dominated by the abundant zeros and ensures the model learns from the actual sensor measurements. Please refer to this gist, where I have tried implementing such a custom loss function.
