I'm building a convolutional variational autoencoder (VAE) with TensorFlow in Python, trained on images I created myself (64x64 pixels).
My problem is that if I use anything other than binary cross-entropy as the reconstruction loss, the model never converges (the loss stays the same across all epochs) and all predictions come out as NaN. Binary cross-entropy is obviously not the optimal loss function here, but for some bizarre reason nothing else works. I hope somebody can help me.
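To be concrete, the only line I change between a run that trains and a run that blows up is the compile call (the loss classes are the standard ones from tf.keras.losses):

vae.compile(optimizer=optimizer, loss=losses.BinaryCrossentropy())  # trains, loss decreases
vae.compile(optimizer=optimizer, loss=losses.MeanSquaredError())     # loss stays flat, predictions become NaN

Here is the full code: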
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers, callbacks
import numpy as np
# Define input dimensions
input_shape = (64, 64, 3) # Change to (64, 64, 1) if grayscale
# Encoder
def build_encoder(input_shape, latent_dim):
    encoder_inputs = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoder_inputs)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)  # 64x64 -> 4x4 after four poolings
    x = layers.Flatten()(x)
    #x = layers.Dense(512, activation='relu')(x)
    #x = layers.Dense(128, activation='relu')(x)
    z_mean = layers.Dense(latent_dim, name='z_mean')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)

    # Reparameterization trick: z = mean + sigma * epsilon
    def sampling(args):
        z_mean, z_log_var = args
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    z = layers.Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
    encoder = models.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
    return encoder
# Decoder
def build_decoder(latent_dim):
    latent_inputs = layers.Input(shape=(latent_dim,))
    #x = layers.Dense(512, activation='relu')(latent_inputs)
    #x = layers.Dense(256, activation='relu')(x)
    x = layers.Dense(4 * 4 * 256, activation='relu')(latent_inputs)
    x = layers.Reshape((4, 4, 256))(x)
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D((2, 2))(x)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
    x = layers.UpSampling2D((2, 2))(x)
    decoder_outputs = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)  # 3 channels for color, change to 1 for grayscale
    decoder = models.Model(latent_inputs, decoder_outputs, name='decoder')
    return decoder
# Define the Variational Autoencoder (VAE) model
class VAE(models.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # KL divergence between the approximate posterior and a standard normal prior
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
        kl_loss *= 0.0  # Set the KL weight here
        self.add_loss(kl_loss)
        return reconstructed
# Parameters
latent_dim = 50 # Adjust as needed
# Build encoder and decoder
encoder = build_encoder(input_shape, latent_dim)
decoder = build_decoder(latent_dim)
# Build VAE
vae = VAE(encoder, decoder)
optimizer = optimizers.Adam(learning_rate=0.001)
#vae.compile(optimizer=optimizer, loss=losses.MeanSquaredError())
vae.compile(optimizer=optimizer, loss=losses.BinaryCrossentropy())
# Print model summaries
encoder.summary()
decoder.summary()
# Ensure data shapes are correct before training
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_val: {X_val.shape}")
early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)
reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-4)
vae.fit(X_train, X_train, epochs=30, batch_size=32, validation_data=(X_val, X_val), callbacks=[early_stopping, reduce_lr])
vae.summary()
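When I then generate predictions, they come back as all NaN (for every loss except binary cross-entropy). Roughly how I check it, simplified:

preds = vae.predict(X_val)
print(np.isnan(preds).any())  # True whenever I train with anything other than BinaryCrossentropy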
I have tried MSE, MAE, KL-divergence loss and many others. Every loss function other than binary cross-entropy gives NaN predictions and nothing converges. For example:
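These are the kinds of compile calls I swapped in (all standard tf.keras.losses classes):

vae.compile(optimizer=optimizer, loss=losses.MeanSquaredError())
vae.compile(optimizer=optimizer, loss=losses.MeanAbsoluteError())
vae.compile(optimizer=optimizer, loss=losses.KLDivergence())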
I have tried changing every hyperparameter I could think of. Nothing changes.
I have looked everywhere on the internet for an answer, but no luck. I really hope somebody has an idea.