Custom Layer Trainable Weights Return NaNs

Hi Andrei, thank you for suggesting anomaly detection! It really helped. It turns out the issue wasn't in the encoder at all: it was in the decoder, which has a transposed-convolution (Conv2D transpose) layer. I didn't include my loss function in the original post (see below). The fix turned out to be trivial: swapping the input and target arguments in the BCE and MSE losses. In fact, your response in "Conv2d.backwards always results in NaN" is what fixed it. Thank you!
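A minimal sketch of the anomaly-detection workflow, assuming the suggestion was PyTorch's built-in `torch.autograd.detect_anomaly()` (the global switch is `torch.autograd.set_detect_anomaly(True)`). With it enabled, autograd raises a `RuntimeError` naming the backward function whose output contains NaN, which is how you can tell which part of the model is at fault. The toy graph (`sqrt` evaluated at zero) and the helper name `find_nan_source` are hypothetical illustrations, not the model from this thread.

```python
import torch


def find_nan_source() -> str:
    """Run a tiny forward/backward pass that yields a NaN gradient and
    return the message raised by autograd's anomaly detection."""
    # Enable anomaly mode for both passes so the forward trace is recorded too.
    with torch.autograd.detect_anomaly():
        x = torch.zeros(1, requires_grad=True)
        # d/dx sqrt(x) = 1 / (2 * sqrt(x)) is infinite at x = 0; the upstream
        # gradient here is 0 (because of "* 0.0"), so sqrt's backward computes 0 / 0 = NaN.
        loss = (torch.sqrt(x) * 0.0).sum()
        try:
            loss.backward()
        except RuntimeError as err:
            # The message names the offending backward node, e.g. SqrtBackward0.
            return str(err)
    return "no anomaly detected"


if __name__ == "__main__":
    print(find_nan_source())
```

Anomaly mode slows training noticeably, so it is best switched on only while hunting for the NaN and off again afterwards.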

import torch
import torch.nn.functional as F


def loss_func(y_true, y_pred, modelArgs, trainArgs, z_mean, z_log_var):
    y_true_attr = y_true[0]
    y_pred_attr = y_pred[0]

    y_true_a = y_true[1]
    # Drop the first N rows along dim 1, where N is the size of the last dim
    y_true_a = y_true_a[:, y_true_a.shape[-1]:, :]
    y_pred_a = y_pred[1]

    ## ATTR RECONSTRUCTION LOSS_______________________
    ## mean squared error
    # attr_reconstruction_loss = mse(K.flatten(y_true_attr), K.flatten(y_pred_attr))
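    attr_reconstruction_loss = F.mse_loss(torch.flatten(y_pred_attr), torch.flatten(y_true_attr))
    # Rescale the element-wise mean. This presumably mirrors the Keras VAE example,
    # where the per-feature mean is multiplied back by the feature count so the
    # reconstruction term is on a scale comparable to the summed KL term.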
    attr_reconstruction_loss *= modelArgs["input_shape"][0][0]

    ## A RECONSTRUCTION LOSS_______________________
    ## binary cross-entropy

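    # MISTAKE WAS BELOW: true and pred were interchanged. Keras's binary_crossentropy
    # takes (y_true, y_pred), but PyTorch's F.binary_cross_entropy takes (input, target),
    # i.e. predictions first, then labels.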
    # a_reconstruction_loss = binary_crossentropy(K.flatten(y_true_a), K.flatten(y_pred_a))
    a_reconstruction_loss = F.binary_cross_entropy(torch.flatten(y_pred_a), torch.flatten(y_true_a))
    # Rescale the element-wise mean by the number of adjacency entries,
    # analogous to the scaling of the attribute loss above
    a_reconstruction_loss *= (modelArgs["input_shape"][1][0] * modelArgs["input_shape"][1][1])

    ## KL LOSS _____________________________________________
    # Per-sample KL divergence between N(z_mean, exp(z_log_var)) and N(0, I)
    kl_loss = 1 + z_log_var - torch.square(z_mean) - torch.exp(z_log_var)
    kl_loss = torch.sum(kl_loss, dim=-1)
    kl_loss *= -0.5


    #print("ADJ LOSS: ", a_reconstruction_loss)
    #print("ATTR LOSS: ", attr_reconstruction_loss)
    #print("KL LOSS: ", kl_loss)

    ## COMPLETE LOSS __________________________________________________

    loss = torch.mean(
        trainArgs["loss_weights"][0] * a_reconstruction_loss
        + trainArgs["loss_weights"][1] * attr_reconstruction_loss
        + trainArgs["loss_weights"][2] * kl_loss
    )

    return loss
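Why the argument order matters: PyTorch's `F.binary_cross_entropy(input, target)` expects probabilities first and labels second, whereas Keras's `binary_crossentropy(y_true, y_pred)` is the other way round. The sketch below uses made-up tensors (not data from this thread) and a hypothetical helper `compare_argument_order` to show that swapping the arguments turns the 0/1 labels into the "probabilities" passed to `log()`, giving a large, meaningless loss dominated by `log(0)` terms, which PyTorch clamps to -100. MSE, by contrast, is symmetric, so swapping its arguments does not change the loss value.

```python
import torch
import torch.nn.functional as F


def compare_argument_order():
    """Return (correct BCE, swapped BCE, whether MSE is order-independent)."""
    # Hypothetical decoder outputs (probabilities in (0, 1)) and binary labels.
    y_pred = torch.tensor([0.9, 0.2, 0.7])
    y_true = torch.tensor([1.0, 0.0, 1.0])

    # Correct PyTorch order: input = predictions, target = labels.
    correct = F.binary_cross_entropy(y_pred, y_true)

    # Keras-style order: labels are treated as probabilities, so log(0) and
    # log(1 - 1) appear and are clamped to -100 by PyTorch.
    swapped = F.binary_cross_entropy(y_true, y_pred)

    # (a - b)^2 == (b - a)^2, so MSE gives the same value either way.
    mse_symmetric = torch.allclose(F.mse_loss(y_pred, y_true), F.mse_loss(y_true, y_pred))

    return correct.item(), swapped.item(), mse_symmetric


if __name__ == "__main__":
    print(compare_argument_order())
```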