Custom Layer Trainable Weights Return NaNs

I’m trying to write a custom layer, but its weights become NaN (and their grads are all zero) after the first training step. I suspect the trainable weights are not being updated on the backward pass and are somehow getting detached from the computational graph, but I don’t know if this is true. Am I creating the weights and biases correctly?

This is my custom layer:

import torch
import torch.nn as nn
import torch.nn.functional as F

# graph_conv_op is defined elsewhere (see the replies below)

class MultiGraphCNN(nn.Module):

    def __init__(self,
                 input_dim,
                 output_dim,
                 num_filters,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(MultiGraphCNN, self).__init__()

        self.output_dim = output_dim
        self.num_filters = num_filters
        self.activation = activation
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.input_dim = input_dim
        #self.input_shape = input_shape


        #if self.num_filters != int(input_shape[1][-2]/input_shape[1][-1]):
            #raise ValueError('num_filters does not match with graph_conv_filters dimensions.')

        #self.input_dim = self.input_shape[0][-1]
        # #self.input_dim = 1 #Ensures kernel shape is (2,100) as before
        kernel_shape = (self.num_filters * self.input_dim, self.output_dim)

        # NOTE: torch.empty only allocates memory; the values are arbitrary
        # until an initializer runs (reset_parameters below is commented out)
        self.kernel = nn.Parameter(torch.empty(kernel_shape), requires_grad=True)

        # self.kernel = self.add_weight(shape=kernel_shape,
        #                               initializer=self.kernel_initializer,
        #                               name='kernel',
        #                               regularizer=self.kernel_regularizer,
        #                               constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(self.output_dim))  # also never initialized
            #nn.init.zeros_(self.bias)
            # self.bias = self.add_weight(shape=(self.output_dim,),
            #                             initializer=self.bias_initializer,
            #                             name='bias',
            #                             regularizer=self.bias_regularizer,
            #                             constraint=self.bias_constraint)
        else:
            self.bias = None

    #     self.reset_parameters()
    #
    # def reset_parameters(self) -> None:
    #     # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
    #     # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
    #     # https://github.com/pytorch/pytorch/issues/57109
    #     nn.init.xavier_uniform_(self.kernel)
    #     #init.kaiming_uniform_(self.weight, a=math.sqrt(5))
    #     if self.bias is not None:
    #         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.kernel)
    #         bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    #         nn.init.uniform_(self.bias, -bound, bound)
    #
    #     self.built = True


    def forward(self, inputs):
        # inputs = [node_features, graph_conv_filters]
        output = graph_conv_op(inputs[0], self.num_filters, inputs[1], self.kernel)
        if self.bias is not None:
            output = output + self.bias
        # activation is hardcoded here; self.activation is currently unused
        output = F.elu(output)
        return output
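For reference, this is the pattern I believe I’m supposed to follow, a stripped-down sketch based on how nn.Linear registers and initializes its parameters (the initializers are my guess at the PyTorch equivalents of the Keras glorot_uniform/zeros defaults):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGraphLayer(nn.Module):
    """Sketch only: parameters must be registered as nn.Parameter AND
    explicitly initialized, since torch.empty leaves the memory as-is."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.kernel = nn.Parameter(torch.empty(in_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(out_dim))
        nn.init.xavier_uniform_(self.kernel)  # ~ Keras glorot_uniform
        nn.init.zeros_(self.bias)             # ~ Keras zeros

    def forward(self, x):
        return F.elu(x @ self.kernel + self.bias)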

And the Encoder that uses this layer looks like this:

class Encoder(nn.Module):
    def __init__(self, hidden_dim, in_features, out_features,num_filters, graph_conv_filters):
        super(Encoder, self).__init__()
        self.out_features = out_features
        self.hidden_dim = hidden_dim
        self.num_filters = num_filters
        self.graph_conv_filters = graph_conv_filters
        self.in_features = in_features

        self.MultiGraphCNN_1 = MultiGraphCNN(input_dim=1, output_dim=100, num_filters=self.num_filters,
                                             activation='elu')
        #self.MultiGraphCNN_1.weight = nn.Parameter(torch.empty(2, 100))
        #torch.nn.init.xavier_uniform_(self.MultiGraphCNN_1.weight)
        #self.MultiGraphCNN_1.bias = nn.Parameter(torch.empty(100, ))
        #self.MultiGraphCNN_1.bias.data.fill_(0.01)

        self.MultiGraphCNN_2 = MultiGraphCNN(input_dim=100, output_dim=100, num_filters=self.num_filters,
                                             activation='elu')
        #self.MultiGraphCNN_2.weight = nn.Parameter(torch.empty(200, 100))
        #torch.nn.init.xavier_uniform_(self.MultiGraphCNN_2.weight)
        #self.MultiGraphCNN_2.bias = nn.Parameter(torch.empty(100, ))
        #self.MultiGraphCNN_2.bias.data.fill_(0.01)

        self.fc1 = nn.Linear(in_features=self.in_features, out_features=self.out_features[0])
        self.fc2 = nn.Linear(in_features=self.out_features[0], out_features=self.out_features[1])

        self.fc_mean = nn.Linear(in_features=self.out_features[1], out_features=self.hidden_dim)
        self.fc_var = nn.Linear(in_features=self.out_features[1], out_features=self.hidden_dim)

    def sampling(self, args):
        """Reparameterization trick by sampling fr an isotropic unit Gaussian.
        # Arguments
            args (tensor): mean and log of variance of Q(z|X)
        # Returns
            z (tensor): sampled latent vector
        """

        z_mean, z_log_var = args
        # z_log_var is log(sigma^2), so the standard deviation is exp(0.5 * z_log_var);
        # randn_like keeps epsilon on the same device/dtype as z_mean
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

    def forward(self, inputs):

        #inputs = [Attr_train, graph_conv_filters]

        x = self.MultiGraphCNN_1(inputs)
        x = F.dropout(x, p=0.1, training=self.training)  # functional dropout respects train/eval mode
        x = self.MultiGraphCNN_2([x, inputs[1]])
        x = F.dropout(x, p=0.1, training=self.training)
        # Node-invariant readout: averaging over the node dimension makes sure
        # the output does not depend on the node order in the graph.
        x = Lambda(lambda x: torch.mean(x, dim=1))(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)

        # z_mean = Dense(modelArgs["latent_dim"], name='z_mean')(x)
        z_mean = self.fc_mean(x)
        # z_log_var = Dense(modelArgs["latent_dim"], name='z_log_var')(x)
        z_log_var = self.fc_var(x)

        # use reparameterization trick to push the sampling out as input
        # note that "output_shape" isn't necessary with the TensorFlow backend
        z = Lambda(self.sampling)([z_mean, z_log_var])
        # z = Lambda(self.sampling, output_shape=(modelArgs["latent_dim"],), name='z')([z_mean, z_log_var])

        #latent_inputs = Input(shape=(modelArgs["latent_dim"],), name='z_sampling')
        return z, z_mean, z_log_var

The first pass works fine, but after loss.backward() is applied, I find that self.Encoder.MultiGraphCNN_1.kernel is a tensor of NaNs, while self.Encoder.MultiGraphCNN_1.kernel.grad is a tensor of zeros. Any help would be really appreciated. Thank you 🙂

Without replicating the whole thing yet, my eyes jumped to this capital-L Lambda function you use. What is that, and where is it defined? It feels like that might be breaking things.

       x = Lambda(lambda x: torch.mean(x, dim=1))(x)

Hi, Lambda is part of the torchvision transforms, defined like so:

class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        if not callable(lambd):
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
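So Lambda(f)(x) just returns f(x). As far as I can tell, it doesn’t detach anything:

import torch
from torchvision.transforms import Lambda

f = Lambda(lambda t: torch.mean(t, dim=1))
x = torch.ones(2, 3, requires_grad=True)
y = f(x)          # identical to calling torch.mean(x, dim=1) directly
print(y.grad_fn)  # a MeanBackward node, so y is still on the graph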

Got it. That wrapper just applies the callable, so it shouldn’t detach anything by itself. You can run anomaly detection (torch.autograd.set_detect_anomaly(True)) to get a better idea of what’s causing the issue.
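Toy example of what it reports (wrap your forward pass and loss.backward() the same way):

import torch

torch.autograd.set_detect_anomaly(True)

x = torch.tensor([-1.0], requires_grad=True)
y = torch.sqrt(x)  # the forward pass already produces NaN here
y.backward()       # raises RuntimeError naming SqrtBackward and prints a
                   # traceback of the forward call that created the NaN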

If that fails, can you please provide the missing code and an example input to make your issue reproducible? e.g. definition of graph_conv_op and whatever else might be required to reproduce.

Hi Andrei, thank you for the suggestion of anomaly detection! It really helped me out. It turns out the issue wasn’t with the encoder at all; it was with the decoder, which has a Conv2D transpose layer. I did not attach my loss function in the original post (see below). It turned out to be a trivial fix: interchanging the input and target variables in the BCE and MSE losses. In fact, your response here (Conv2d.backwards always results in NaN) fixed it. Thank you 🙂
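For anyone else hitting this: both functional losses take the prediction first and the target second. A quick reminder snippet (the tensor names are just for illustration):

import torch
import torch.nn.functional as F

pred = torch.rand(4, requires_grad=True)  # model output, values in (0, 1)
true = torch.randint(0, 2, (4,)).float()  # hard 0/1 labels

loss = F.mse_loss(pred, true)              # F.mse_loss(input, target)
loss = F.binary_cross_entropy(pred, true)  # same order; input must lie in [0, 1]

With the arguments swapped, the hard 0/1 labels end up in the input slot, and the gradient of BCE with respect to its input, (x - t) / (x * (1 - x)), diverges at exactly 0 and 1, which is presumably where the NaNs came from. The corrected loss function: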

def loss_func(y_true, y_pred, modelArgs, trainArgs, z_mean, z_log_var):
    y_true_attr = y_true[0]
    y_pred_attr = y_pred[0]

    y_true_a = y_true[1]
    y_true_a = y_true_a[:, y_true_a.shape[-1]:, :]
    y_pred_a = y_pred[1]

    ## ATTR RECONSTRUCTION LOSS_______________________
    ## mean squared error
    # attr_reconstruction_loss = mse(K.flatten(y_true_attr), K.flatten(y_pred_attr))
    attr_reconstruction_loss = F.mse_loss(torch.flatten(y_pred_attr), torch.flatten(y_true_attr))
    # Scaling below by input shape (Why?)
    attr_reconstruction_loss *= modelArgs["input_shape"][0][0]

    ## A RECONSTRUCTION LOSS_______________________
    ## binary cross-entropy

    # MISTAKE WAS BELOW: TRUE AND PRED WERE INTERCHANGED
    # a_reconstruction_loss = binary_crossentropy(K.flatten(y_true_a), K.flatten(y_pred_a))
    a_reconstruction_loss = F.binary_cross_entropy(torch.flatten(y_pred_a), torch.flatten(y_true_a))
    # Scaling below by input shape (Why?)
    a_reconstruction_loss *= (modelArgs["input_shape"][1][0] * modelArgs["input_shape"][1][1])

    ## KL LOSS _____________________________________________
    kl_loss = 1 + z_log_var - torch.square(z_mean) - torch.exp(z_log_var)
    kl_loss = torch.sum(kl_loss, dim=-1)
    kl_loss *= -0.5


    #print("ADJ LOSS: ", a_reconstruction_loss)
    #print("ATTR LOSS: ", attr_reconstruction_loss)
    #print("KL LOSS: ", kl_loss)

    ## COMPLETE LOSS __________________________________________________

    loss = torch.mean(trainArgs["loss_weights"][0] * a_reconstruction_loss
                      + trainArgs["loss_weights"][1] * attr_reconstruction_loss
                      + trainArgs["loss_weights"][2] * kl_loss)

    return loss