RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 80, 25]], which is output 0 of ReciprocalBackward0, is at version 1; expected version 0 instead

import torch
import torch.nn as nn
#from torch.autograd import Variable
#from torch.autograd import Function
#from torch.nn.modules.module import Module
#from torch.nn.parameter import Parameter
from torch.nn.functional import conv2d
import torch.nn.functional as F
import numpy as np
import pywt
import scipy
class Kerv2d(nn.Conv2d):
    '''
    Kervolution with the following options:
    kernel_type: [linear, polynomial, gaussian, etc.]
    The default is plain convolution (kernel_type --> linear).
    balance, power, gamma are valid only when a kernel_type is specified.
    If learnable_kernel = True, they are just the initial values of learnable parameters.
    If learnable_kernel = False, they are the fixed values of the kernel's parameters.
    The parameter [power] cannot be learned because it must stay an integer.
    '''
    def __init__(self, in_channels, out_channels, kernel_size,
            stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1, bias=True,
            kernel_type='sigmoid', learnable_kernel=False, kernel_regularizer=False, power=3, gamma=1, kernel_fn=None, balance=1):

        # Pass the Conv2d arguments by keyword: passing them positionally as
        # (stride, dilation, groups, bias, padding_mode) silently maps them onto the wrong Conv2d parameters.
        super(Kerv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
        self.kernel_type = kernel_type
        #self.learnable_kernel, self.kernel_regularizer = learnable_kernel, kernel_regularizer
        self.balance, self.power, self.gamma = balance, power, gamma

        # parameter for kernel type
        #if learnable_kernel == True:
            #self.balance = nn.Parameter(torch.cuda.FloatTensor([balance] * out_channels), requires_grad=True).view(-1, 1)
            #self.gamma   = nn.Parameter(torch.cuda.FloatTensor([gamma]   * out_channels), requires_grad=True).view(-1, 1)
    
    
    def forward(self, input):

        minibatch, in_channels, input_width, input_height = input.size()
        assert(in_channels == self.in_channels)
        # im2col: each column of input_unfold is one flattened receptive field
        input_unfold = F.unfold(input, kernel_size=self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride)
        input_unfold = input_unfold.view(minibatch, 1, self.kernel_size[0]*self.kernel_size[1]*self.in_channels, -1)
        weight_flat  = self.weight.view(self.out_channels, -1, 1)
        output_width = (input_width - self.kernel_size[0] + 2 * self.padding[0]) // self.stride[0] + 1
        output_height = (input_height - self.kernel_size[1] + 2 * self.padding[1]) // self.stride[1] + 1

        if self.kernel_type == 'linear':
            output = (input_unfold * weight_flat).sum(dim=2)

        elif self.kernel_type == 'manhattan':
            output = -((input_unfold - weight_flat).abs().sum(dim=2))

        elif self.kernel_type == 'euclidean':
            output = -(((input_unfold - weight_flat)**2).sum(dim=2))

        elif self.kernel_type == 'polynomial':
            
            output = ((input_unfold * weight_flat).sum(dim=2) + self.balance)**self.power

        elif self.kernel_type == 'gaussian':
            output = (-self.gamma*((input_unfold - weight_flat)**2).sum(dim=2)).exp() + 0

        elif self.kernel_type == 'rbf':
            output = ((-self.gamma*(((input_unfold - weight_flat)**2).abs().sum(dim=2)).sqrt()).exp())

        elif self.kernel_type =='bessel':
            output = torch.special.i0(input_unfold - weight_flat).sum(dim=2)

        elif self.kernel_type =='sigmoid':
            #output = 1/(1 + ((input_unfold - weight_flat).sum(dim=2)).exp())
            
            output = (1 + ((input_unfold - weight_flat).sum(dim=2)).exp()).reciprocal()

        else:
            raise NotImplementedError(self.kernel_type+' kervolution not implemented')

        if self.bias is not None:
            output += self.bias.view(self.out_channels, -1)

        return output.view(minibatch, self.out_channels, output_width, output_height)

Hello friends. Can anyone help me resolve this issue while implementing the sigmoid kernel for a CNN?
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 80, 25]], which is output 0 of ReciprocalBackward0, is at version 1; expected version 0 instead.

Could you replace that with

output = output + self.bias.view(self.out_channels, -1)

Your bias addition is being done in-place (output += ...), which modifies the tensor that autograd saved for the backward pass of reciprocal(), hence the version mismatch in the error.
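
For context, here is a minimal, self-contained sketch (made-up shapes and values, not your module) of why the in-place add trips autograd: reciprocal() saves its output for the backward pass, and += bumps that saved tensor's version counter.

import torch

x = torch.randn(4, 80, 25, requires_grad=True)
out = (1 + x.exp()).reciprocal()   # ReciprocalBackward0 saves this output for backward
bias = torch.randn(25)

# out += bias                      # in-place: raises the "modified by an inplace operation" error on backward
out = out + bias                   # out-of-place: leaves the saved tensor untouched

out.sum().backward()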

Thank you for your reply. I have replaced my code with

output = output + self.bias.view(self.out_channels, -1)

Now I am getting the following error

RuntimeError: Function ‘ExpBackward0’ returned nan values in its 0th output.

Whatever is computed within this expression is causing torch.exp() to return a NaN value. I see you're using sigmoid as your kernel function? If so, you're missing the minus sign on your exponent.

Sigmoid is defined as 1/(1 + torch.exp(-x)), whereas you have 1/(1 + exp(x)) at the moment, where x is defined as ((input_unfold - weight_flat).sum(dim=2)).

You could just pass x into torch.sigmoid directly, via:

output = torch.sigmoid( ((input_unfold - weight_flat).sum(dim=2)) )
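
For what it's worth, here is a small standalone check (made-up values, not your tensors) showing that the reciprocal form and torch.sigmoid agree; torch.sigmoid is also the safer choice numerically, since its backward uses sig * (1 - sig) instead of differentiating through an explicit exp:

import torch

x = torch.tensor([-50.0, 0.0, 50.0])
manual  = (1 + (-x).exp()).reciprocal()   # 1 / (1 + exp(-x))
builtin = torch.sigmoid(x)
print(torch.allclose(manual, builtin))    # True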

I am getting the following error after making the changes as suggested by you.

RuntimeError: Function ‘ExpBackward0’ returned nan values in its 0th output.

Could you try running with torch.autograd.set_detect_anomaly(True), to find where the NaN is appearing? See Automatic differentiation package - torch.autograd — PyTorch 1.10.0 documentation.

As per your suggestion, I have included torch.autograd.set_detect_anomaly(True) in my code. Following are the details of the error,

RuntimeError                              Traceback (most recent call last)
<ipython-input-69-599bd4b36306> in <module>()
     29         # Backward and optimize
     30         optimizer.zero_grad()
---> 31         loss.backward()
     32         optimizer.step()
     33 

1 frames
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    305                 create_graph=create_graph,
    306                 inputs=inputs)
--> 307         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    308 
    309     def register_hook(self, hook):

/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    154     Variable._execution_engine.run_backward(
    155         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
    157 
    158 

RuntimeError: Function 'ExpBackward0' returned nan values in its 0th output.

So it seems that the torch.exp call is returning an invalid gradient. Could you print out the output of the torch.exp(-x) call, where x is defined as ((input_unfold - weight_flat).sum(dim=2)), and possibly input_unfold - weight_flat as well?

I have a feeling torch.exp is returning an infinity which, when passed to the backward pass, causes the NaN.

Did you check that you're using torch.exp(-x) rather than torch.exp(x)? The minus sign in the exponent is needed.

This should be output = (1 + ((input_unfold - weight_flat).sum(dim=2)).mul(-1).exp()).reciprocal(), in order for it to be a sigmoid function!

I am getting the same error for other kernels like rbf and Gaussian. Can you please suggest any modifications to the code? I would also appreciate suggestions for other useful kernels I could add to improve classification accuracy.

The error is emerging from the calculation of the loss. Could you print the loss value (before calling .backward())? And also print out the value of

(input_unfold - weight_flat)**2

to check it's finite. If the loss is infinite, it will become NaN in the backward pass.
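
A minimal check in the training loop could look like this (criterion, outputs, and targets are placeholder names for whatever your loop actually uses):

loss = criterion(outputs, targets)
print(loss.item(), torch.isfinite(loss).item())   # the loss value and whether it is finite
loss.backward()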

As per your suggestion I have printed the value of (input_unfold - weight_flat)**2 for an rbf kernel as given below.

elif self.kernel_type == 'rbf':
    print((input_unfold - weight_flat)**2)
    output = torch.exp((self.gamma*(((input_unfold - weight_flat)**2).abs().sum(dim=2).mul(-1)).sqrt()))

I am getting the following tensor.

tensor([[[[1.5106e-03, 1.5106e-03, 1.5106e-03,  ..., 1.1569e-03,
       3.0368e-04, 6.0389e-03],
      [2.8391e-04, 2.8391e-04, 2.8391e-04,  ..., 5.3499e-03,
       4.8374e-04, 1.9257e-05],
      [1.5675e-03, 1.5675e-03, 1.5675e-03,  ..., 6.1521e-03,
       2.7095e-03, 1.5675e-03],
      ...,
      [7.2640e-03, 8.6801e-01, 3.9451e-01,  ..., 7.2640e-03,
       7.2640e-03, 7.2640e-03],
      [8.6283e-01, 3.9102e-01, 7.7187e-02,  ..., 6.7971e-03,
       6.7971e-03, 6.7971e-03],
      [3.0447e-01, 1.2345e-01, 6.5897e-02,  ..., 7.9489e-05,
       7.9489e-05, 7.9489e-05]],

     [[7.3047e-06, 7.3047e-06, 7.3047e-06,  ..., 5.7090e-05,
       3.4805e-03, 1.3062e-03],
      [2.5936e-03, 2.5936e-03, 2.5936e-03,  ..., 2.8788e-05,
       8.0589e-03, 4.0182e-03],
      [3.9495e-03, 3.9495e-03, 3.9495e-03,  ..., 5.7607e-04,
       2.5385e-03, 3.9495e-03],
      ...,
      [3.7975e-03, 6.1594e-01, 2.3160e-01,  ..., 3.7975e-03,
       3.7975e-03, 3.7975e-03],
      [6.7872e-01, 2.7069e-01, 1.4658e-01,  ..., 5.1044e-04,
       5.1044e-04, 5.1044e-04],
      [3.3737e-01, 1.0388e-01, 5.1827e-02,  ..., 1.4412e-03,
       1.4412e-03, 1.4412e-03]],

     [[8.8988e-04, 8.8988e-04, 8.8988e-04,  ..., 6.2389e-04,
       7.0026e-04, 4.7162e-03],
      [2.9093e-04, 2.9093e-04, 2.9093e-04,  ..., 5.3802e-03,
       4.7467e-04, 2.1118e-05],
      [6.8946e-03, 6.8946e-03, 6.8946e-03,  ..., 1.4854e-02,
       9.1193e-03, 6.8946e-03],
      ...,
      [1.6353e-04, 7.3827e-01, 3.0876e-01,  ..., 1.6353e-04,
       1.6353e-04, 1.6353e-04],
      [7.1580e-01, 2.9429e-01, 1.3008e-01,  ..., 1.5248e-07,
       1.5248e-07, 1.5248e-07],
      [3.6435e-01, 8.9718e-02, 4.1975e-02,  ..., 3.6894e-03,
       3.6894e-03, 3.6894e-03]],

     ...,

     [[3.3553e-03, 3.3553e-03, 3.3553e-03,  ..., 3.9410e-03,
       1.3046e-02, 3.6408e-04],
      [8.0142e-04, 8.0142e-04, 8.0142e-04,  ..., 7.1576e-03,
       1.1097e-04, 2.5116e-04],
      [6.5826e-04, 6.5826e-04, 6.5826e-04,  ..., 4.1603e-03,
       1.4530e-03, 6.5826e-04],
      ...,
      [1.3704e-03, 6.5516e-01, 2.5589e-01,  ..., 1.3704e-03,
       1.3704e-03, 1.3704e-03],
      [8.3160e-01, 3.7010e-01, 8.6902e-02,  ..., 4.2876e-03,
       4.2876e-03, 4.2876e-03],
      [3.8677e-01, 7.9096e-02, 3.4815e-02,  ..., 6.2459e-03,
       6.2459e-03, 6.2459e-03]],

     [[4.5563e-03, 4.5563e-03, 4.5563e-03,  ..., 5.2350e-03,
       1.5325e-02, 8.2121e-04],
      [4.1114e-03, 4.1114e-03, 4.1114e-03,  ..., 6.1266e-05,
       1.0602e-02, 5.8648e-03],
      [1.6336e-03, 1.6336e-03, 1.6336e-03,  ..., 6.2824e-03,
       2.7962e-03, 1.6336e-03],
      ...,
      [7.2603e-03, 8.6797e-01, 3.9449e-01,  ..., 7.2603e-03,
       7.2603e-03, 7.2603e-03],
      [7.4526e-01, 3.1328e-01, 1.1794e-01,  ..., 2.8369e-04,
       2.8369e-04, 2.8369e-04],
      [2.6591e-01, 1.5014e-01, 8.5750e-02,  ..., 7.4053e-04,
       7.4053e-04, 7.4053e-04]],

     [[4.6240e-04, 4.6240e-04, 4.6240e-04,  ..., 2.7724e-04,
       1.2103e-03, 3.6418e-03],
      [6.0864e-03, 6.0864e-03, 6.0864e-03,  ..., 1.8039e-02,
       1.5344e-03, 4.2973e-03],
      [6.5879e-03, 6.5879e-03, 6.5879e-03,  ..., 1.7912e-03,
       4.7203e-03, 6.5879e-03],
      ...,
      [7.2810e-04, 7.6287e-01, 3.2474e-01,  ..., 7.2810e-04,
       7.2810e-04, 7.2810e-04],
      [7.7364e-01, 3.3178e-01, 1.0702e-01,  ..., 1.0974e-03,
       1.0974e-03, 1.0974e-03],
      [2.4592e-01, 1.6585e-01, 9.7714e-02,  ..., 2.2065e-03,
       2.2065e-03, 2.2065e-03]]],


    [[[1.5106e-03, 1.5106e-03, 1.5106e-03,  ..., 1.6799e-04,
       3.5862e-02, 1.3757e-03],
      [2.8391e-04, 2.8391e-04, 2.8391e-04,  ..., 1.7864e-02,
       3.4694e-04, 6.5997e-05],
      [1.5675e-03, 1.5675e-03, 1.5675e-03,  ..., 1.4300e-03,
       4.1687e-03, 1.5675e-03],
      ...,
      [7.2640e-03, 1.0410e+00, 1.0757e-02,  ..., 7.2640e-03,
       7.2640e-03, 7.2640e-03],
      [1.0353e+00, 1.0187e-02, 3.2086e-01,  ..., 6.7971e-03,
       6.7971e-03, 6.7971e-03],
      [7.5077e-04, 2.4296e-01, 4.0355e-02,  ..., 7.9489e-05,
       7.9489e-05, 7.9489e-05]],

     [[7.3047e-06, 7.3047e-06, 7.3047e-06,  ..., 8.1844e-04,
       2.1846e-02, 2.0065e-05],
      [2.5936e-03, 2.5936e-03, 2.5936e-03,  ..., 4.0575e-02,
       2.4158e-03, 5.7610e-03],
      [3.9495e-03, 3.9495e-03, 3.9495e-03,  ..., 4.1760e-03,
       1.4343e-03, 3.9495e-03],
      ...,
      [3.7975e-03, 7.6289e-01, 1.8610e-03,  ..., 3.7975e-03,
       3.7975e-03, 3.7975e-03],
      [8.3260e-01, 1.6878e-05, 2.1289e-01,  ..., 5.1044e-04,
       5.1044e-04, 5.1044e-04],
      [3.1863e-03, 2.7244e-01, 2.9529e-02,  ..., 1.4412e-03,
       1.4412e-03, 1.4412e-03]],

     [[8.8988e-04, 8.8988e-04, 8.8988e-04,  ..., 1.5406e-05,
       3.2521e-02, 7.8703e-04],
      [2.9093e-04, 2.9093e-04, 2.9093e-04,  ..., 1.7808e-02,
       3.5470e-04, 6.2674e-05],
      [6.8946e-03, 6.8946e-03, 6.8946e-03,  ..., 6.6027e-03,
       1.1666e-02, 6.8946e-03],
      ...,
      [1.6353e-04, 8.9842e-01, 9.7796e-04,  ..., 1.6353e-04,
       1.6353e-04, 1.6353e-04],
      [8.7361e-01, 3.2740e-04, 2.3387e-01,  ..., 1.5248e-07,
       1.5248e-07, 1.5248e-07],
      [6.2766e-03, 2.9674e-01, 2.2219e-02,  ..., 3.6894e-03,
       3.6894e-03, 3.6894e-03]],

     ...,

     [[3.3553e-03, 3.3553e-03, 3.3553e-03,  ..., 7.0275e-03,
       8.5711e-03, 3.5642e-03],
      [8.0142e-04, 8.0142e-04, 8.0142e-04,  ..., 1.4932e-02,
       9.0518e-04, 1.1129e-05],
      [6.5826e-04, 6.5826e-04, 6.5826e-04,  ..., 5.7025e-04,
       2.5634e-03, 6.5826e-04],
      ...,
      [1.3704e-03, 8.0648e-01, 3.4350e-04,  ..., 1.3704e-03,
       1.3704e-03, 1.3704e-03],
      [1.0011e+00, 7.0500e-03, 3.0192e-01,  ..., 4.2876e-03,
       4.2876e-03, 4.2876e-03],
      [9.5092e-03, 3.1700e-01, 1.7101e-02,  ..., 6.2459e-03,
       6.2459e-03, 6.2459e-03]],

     [[4.5563e-03, 4.5563e-03, 4.5563e-03,  ..., 8.7247e-03,
       6.8898e-03, 4.7993e-03],
      [4.1114e-03, 4.1114e-03, 4.1114e-03,  ..., 4.6064e-02,
       3.8867e-03, 7.9377e-03],
      [1.6336e-03, 1.6336e-03, 1.6336e-03,  ..., 1.4932e-03,
       4.2760e-03, 1.6336e-03],
      ...,
      [7.2603e-03, 1.0409e+00, 1.0752e-02,  ..., 7.2603e-03,
       7.2603e-03, 7.2603e-03],
      [9.0612e-01, 1.2480e-03, 2.5084e-01,  ..., 2.8369e-04,
       2.8369e-04, 2.8369e-04],
      [7.6180e-05, 2.0865e-01, 5.6176e-02,  ..., 7.4053e-04,
       7.4053e-04, 7.4053e-04]],

     [[4.6240e-04, 4.6240e-04, 4.6240e-04,  ..., 1.9379e-05,
       2.9587e-02, 3.8915e-04],
      [6.0864e-03, 6.0864e-03, 6.0864e-03,  ..., 5.2548e-03,
       6.3668e-03, 2.8134e-03],
      [6.5879e-03, 6.5879e-03, 6.5879e-03,  ..., 6.8795e-03,
       3.1576e-03, 6.5879e-03],
      ...,
      [7.2810e-04, 9.2553e-01, 2.0673e-03,  ..., 7.2810e-04,
       7.2810e-04, 7.2810e-04],
      [9.3739e-01, 2.6637e-03, 2.6742e-01,  ..., 1.0974e-03,
       1.0974e-03, 1.0974e-03],
      [8.1160e-04, 1.9099e-01, 6.5933e-02,  ..., 2.2065e-03,
       2.2065e-03, 2.2065e-03]]],


    [[[1.5106e-03, 1.5106e-03, 1.5106e-03,  ..., 1.1227e-03,
       4.5607e-03, 4.0703e-03],
      [2.8391e-04, 2.8391e-04, 2.8391e-04,  ..., 1.3964e-04,
       6.5322e-05, 6.6743e-04],
      [1.5675e-03, 1.5675e-03, 1.5675e-03,  ..., 4.1633e-03,
       9.3677e-04, 1.5675e-03],
      ...,
      [7.2640e-03, 7.9553e-01, 5.7596e-02,  ..., 7.2640e-03,
       7.2640e-03, 7.2640e-03],
      [7.9057e-01, 5.6267e-02, 5.9802e-02,  ..., 6.7971e-03,
       6.7971e-03, 6.7971e-03],
      [2.6790e-02, 2.9246e-02, 7.7913e-02,  ..., 7.9489e-05,
       7.9489e-05, 7.9489e-05]],

     [[7.3047e-06, 7.3047e-06, 7.3047e-06,  ..., 6.5016e-05,
       6.7411e-04, 4.9413e-04],
      [2.5936e-03, 2.5936e-03, 2.5936e-03,  ..., 6.3352e-03,
       5.7547e-03, 1.7592e-03],
      [3.9495e-03, 3.9495e-03, 3.9495e-03,  ..., 1.4374e-03,
       5.1596e-03, 3.9495e-03],
      ...,
      [3.7975e-03, 5.5513e-01, 8.6746e-03,  ..., 3.7975e-03,
       3.7975e-03, 3.7975e-03],
      [6.1481e-01, 1.7469e-02, 1.9462e-02,  ..., 5.1044e-04,
       5.1044e-04, 5.1044e-04],
      [3.7143e-02, 4.0025e-02, 6.2541e-02,  ..., 1.4412e-03,
       1.4412e-03, 1.4412e-03]],

     [[8.8988e-04, 8.8988e-04, 8.8988e-04,  ..., 5.9879e-04,
       3.4219e-03, 2.9989e-03],
      [2.9093e-04, 2.9093e-04, 2.9093e-04,  ..., 1.3479e-04,
       6.2016e-05, 6.7818e-04],
      [6.8946e-03, 6.8946e-03, 6.8946e-03,  ..., 1.1657e-02,
       5.4832e-03, 6.8946e-03],
      ...,
      [1.6353e-04, 6.7155e-01, 2.8073e-02,  ..., 1.6353e-04,
       1.6353e-04, 1.6353e-04],
      [6.5013e-01, 2.3830e-02, 2.6150e-02,  ..., 1.5248e-07,
       1.5248e-07, 1.5248e-07],
      [4.6441e-02, 4.9657e-02, 5.1667e-02,  ..., 3.6894e-03,
       3.6894e-03, 3.6894e-03]],

     ...,

     [[3.3553e-03, 3.3553e-03, 3.3553e-03,  ..., 4.0050e-03,
       8.5604e-04, 1.0885e-03],
      [8.0142e-04, 8.0142e-04, 8.0142e-04,  ..., 1.2746e-07,
       1.1409e-05, 1.3909e-03],
      [6.5826e-04, 6.5826e-04, 6.5826e-04,  ..., 2.5592e-03,
       2.7794e-04, 6.5826e-04],
      ...,
      [1.3704e-03, 5.9240e-01, 1.3864e-02,  ..., 1.3704e-03,
       1.3704e-03, 1.3704e-03],
      [7.6069e-01, 4.8506e-02, 5.1792e-02,  ..., 4.2876e-03,
       4.2876e-03, 4.2876e-03],
      [5.4659e-02, 5.8144e-02, 4.3687e-02,  ..., 6.2459e-03,
       6.2459e-03, 6.2459e-03]],

     [[4.5563e-03, 4.5563e-03, 4.5563e-03,  ..., 5.3087e-03,
       1.5081e-03, 1.8121e-03],
      [4.1114e-03, 4.1114e-03, 4.1114e-03,  ..., 8.6094e-03,
       7.9303e-03, 3.0399e-03],
      [1.6336e-03, 1.6336e-03, 1.6336e-03,  ..., 4.2706e-03,
       9.8803e-04, 1.6336e-03],
      ...,
      [7.2603e-03, 7.9549e-01, 5.7585e-02,  ..., 7.2603e-03,
       7.2603e-03, 7.2603e-03],
      [6.7821e-01, 2.9448e-02, 3.2020e-02,  ..., 2.8369e-04,
       2.8369e-04, 2.8369e-04],
      [1.6269e-02, 1.8194e-02, 9.9387e-02,  ..., 7.4053e-04,
       7.4053e-04, 7.4053e-04]],

     [[4.6240e-04, 4.6240e-04, 4.6240e-04,  ..., 2.6060e-04,
       2.5170e-03, 2.1562e-03],
      [6.0864e-03, 6.0864e-03, 6.0864e-03,  ..., 2.4353e-03,
       2.8179e-03, 7.5691e-03],
      [6.5879e-03, 6.5879e-03, 6.5879e-03,  ..., 3.1623e-03,
       8.1272e-03, 6.5879e-03],
      ...,
      [7.2810e-04, 6.9502e-01, 3.3031e-02,  ..., 7.2810e-04,
       7.2810e-04, 7.2810e-04],
      [7.0530e-01, 3.5302e-02, 3.8113e-02,  ..., 1.0974e-03,
       1.0974e-03, 1.0974e-03],
      [1.1618e-02, 1.3254e-02, 1.1224e-01,  ..., 2.2065e-03,
       2.2065e-03, 2.2065e-03]]],


    [[[1.5106e-03, 1.5106e-03, 1.5106e-03,  ..., 4.0845e-03,
       3.8873e-02, 2.6158e-03],
      [2.8391e-04, 2.8391e-04, 2.8391e-04,  ..., 2.0007e-02,
       2.0895e-05, 6.4618e-03],
      [1.5675e-03, 1.5675e-03, 1.5675e-03,  ..., 2.6905e-03,
       5.7331e-04, 1.5675e-03],
      ...,
      [7.2640e-03, 1.3240e-01, 2.2788e-02,  ..., 7.2640e-03,
       7.2640e-03, 7.2640e-03],
      [1.3443e-01, 2.1955e-02, 1.1336e-02,  ..., 6.7971e-03,
       6.7971e-03, 6.7971e-03],
      [5.5719e-03, 3.2399e-02, 1.2244e-02,  ..., 7.9489e-05,
       7.9489e-05, 7.9489e-05]],

     [[7.3047e-06, 7.3047e-06, 7.3047e-06,  ..., 4.9911e-04,
       2.4209e-02, 9.1694e-05],
      [2.5936e-03, 2.5936e-03, 2.5936e-03,  ..., 4.3775e-02,
       3.9950e-03, 1.5896e-04],
      [3.9495e-03, 3.9495e-03, 3.9495e-03,  ..., 2.5570e-03,
       1.5972e-02, 3.9495e-03],
      ...,
      [3.7975e-03, 2.6083e-01, 1.6852e-05,  ..., 3.7975e-03,
       3.7975e-03, 3.7975e-03],
      [2.2249e-01, 1.8607e-03, 4.4735e-02,  ..., 5.1044e-04,
       5.1044e-04, 5.1044e-04],
      [1.0752e-02, 2.2786e-02, 6.6593e-03,  ..., 1.4412e-03,
       1.4412e-03, 1.4412e-03]],

     [[8.8988e-04, 8.8988e-04, 8.8988e-04,  ..., 3.0112e-03,
       3.5392e-02, 1.7732e-03],
      [2.9093e-04, 2.9093e-04, 2.9093e-04,  ..., 1.9949e-02,
       2.2832e-05, 6.4951e-03],
      [6.8946e-03, 6.8946e-03, 6.8946e-03,  ..., 9.0845e-03,
       3.8018e-04, 6.8946e-03],
      ...,
      [1.6353e-04, 1.9036e-01, 6.1649e-03,  ..., 1.6353e-04,
       1.6353e-04, 1.6353e-04],
      [2.0204e-01, 4.2691e-03, 3.5836e-02,  ..., 1.5248e-07,
       1.5248e-07, 1.5248e-07],
      [1.5995e-02, 1.6428e-02, 3.4607e-03,  ..., 3.6894e-03,
       3.6894e-03, 3.6894e-03]],

     ...,

     [[3.3553e-03, 3.3553e-03, 3.3553e-03,  ..., 1.0812e-03,
       1.0074e-02, 2.0836e-03],
      [8.0142e-04, 8.0142e-04, 8.0142e-04,  ..., 1.6897e-02,
       2.5699e-04, 8.4355e-03],
      [6.5826e-04, 6.5826e-04, 6.5826e-04,  ..., 1.4391e-03,
       1.4348e-03, 6.5826e-04],
      ...,
      [1.3704e-03, 2.3630e-01, 8.2432e-04,  ..., 1.3704e-03,
       1.3704e-03, 1.3704e-03],
      [1.4716e-01, 1.7216e-02, 1.5236e-02,  ..., 4.2876e-03,
       4.2876e-03, 4.2876e-03],
      [2.0955e-02, 1.2074e-02, 1.6432e-03,  ..., 6.2459e-03,
       6.2459e-03, 6.2459e-03]],

     [[4.5563e-03, 4.5563e-03, 4.5563e-03,  ..., 1.8026e-03,
       8.2438e-03, 3.0495e-03],
      [4.1114e-03, 4.1114e-03, 4.1114e-03,  ..., 4.9469e-02,
       5.8368e-03, 3.4194e-07],
      [1.6336e-03, 1.6336e-03, 1.6336e-03,  ..., 2.7769e-03,
       5.3442e-04, 1.6336e-03],
      ...,
      [7.2603e-03, 1.3241e-01, 2.2782e-02,  ..., 7.2603e-03,
       7.2603e-03, 7.2603e-03],
      [1.8684e-01, 6.8182e-03, 2.9608e-02,  ..., 2.8369e-04,
       2.8369e-04, 2.8369e-04],
      [1.4835e-03, 4.6711e-02, 2.1544e-02,  ..., 7.4053e-04,
       7.4053e-04, 7.4053e-04]],

     [[4.6240e-04, 4.6240e-04, 4.6240e-04,  ..., 2.1666e-03,
       3.2328e-02, 1.1412e-03],
      [6.0864e-03, 6.0864e-03, 6.0864e-03,  ..., 6.4450e-03,
       4.3213e-03, 2.0037e-02],
      [6.5879e-03, 6.5879e-03, 6.5879e-03,  ..., 4.7455e-03,
       2.0939e-02, 6.5879e-03],
      ...,
      [7.2810e-04, 1.7818e-01, 8.5956e-03,  ..., 7.2810e-04,
       7.2810e-04, 7.2810e-04],
      [1.7303e-01, 9.7725e-03, 2.4270e-02,  ..., 1.0974e-03,
       1.0974e-03, 1.0974e-03],
      [3.5179e-04, 5.5643e-02, 2.7736e-02,  ..., 2.2065e-03,
       2.2065e-03, 2.2065e-03]]]], device='cuda:0', grad_fn=<PowBackward0>)

tensor([[[[3.9900e-04, 3.9900e-04, 3.9900e-04,  ...,        nan,
       3.9900e-04,        nan],
      [2.3724e-03, 2.3724e-03, 2.3724e-03,  ..., 2.3724e-03,
              nan,        nan],
      [6.4523e-05, 6.4523e-05, 6.4523e-05,  ...,        nan,
              nan, 6.4523e-05],
      ...,
      [3.4188e-03,        nan,        nan,  ..., 3.4188e-03,
       3.4188e-03, 3.4188e-03],
      [       nan,        nan,        nan,  ..., 1.7611e-04,
       1.7611e-04, 1.7611e-04],
      [       nan,        nan,        nan,  ..., 3.1767e-03,
       3.1767e-03, 3.1767e-03]],

     [[8.7806e-07, 8.7806e-07, 8.7806e-07,  ...,        nan,
       8.7806e-07,        nan],
      [4.7228e-05, 4.7228e-05, 4.7228e-05,  ..., 4.7228e-05,
              nan,        nan],
      [2.5639e-03, 2.5639e-03, 2.5639e-03,  ...,        nan,
              nan, 2.5639e-03],
      ...,
      [4.1269e-08,        nan,        nan,  ..., 4.1269e-08,
       4.1269e-08, 4.1269e-08],
      [       nan,        nan,        nan,  ..., 1.3153e-03,
       1.3153e-03, 1.3153e-03],
      [       nan,        nan,        nan,  ..., 4.0486e-04,
       4.0486e-04, 4.0486e-04]],

     [[3.1932e-03, 3.1932e-03, 3.1932e-03,  ...,        nan,
       3.1932e-03,        nan],
      [2.6909e-03, 2.6909e-03, 2.6909e-03,  ..., 2.6909e-03,
              nan,        nan],
      [9.7583e-05, 9.7583e-05, 9.7583e-05,  ...,        nan,
              nan, 9.7583e-05],
      ...,
      [1.7329e-03,        nan,        nan,  ..., 1.7329e-03,
       1.7329e-03, 1.7329e-03],
      [       nan,        nan,        nan,  ..., 2.8461e-03,
       2.8461e-03, 2.8461e-03],
      [       nan,        nan,        nan,  ..., 1.0612e-05,
       1.0612e-05, 1.0612e-05]],

     ...,

     [[2.9565e-03, 2.9565e-03, 2.9565e-03,  ...,        nan,
       2.9565e-03,        nan],
      [3.2911e-03, 3.2911e-03, 3.2911e-03,  ..., 3.2911e-03,
              nan,        nan],
      [9.1167e-06, 9.1167e-06, 9.1167e-06,  ...,        nan,
              nan, 9.1167e-06],
      ...,
      [3.4284e-03,        nan,        nan,  ..., 3.4284e-03,
       3.4284e-03, 3.4284e-03],
      [       nan,        nan,        nan,  ..., 1.7568e-03,
       1.7568e-03, 1.7568e-03],
      [       nan,        nan,        nan,  ..., 1.5258e-06,
       1.5258e-06, 1.5258e-06]],

     [[5.7720e-04, 5.7720e-04, 5.7720e-04,  ...,        nan,
       5.7720e-04,        nan],
      [1.2828e-05, 1.2828e-05, 1.2828e-05,  ..., 1.2828e-05,
              nan,        nan],
      [1.6647e-03, 1.6647e-03, 1.6647e-03,  ...,        nan,
              nan, 1.6647e-03],
      ...,
      [1.0735e-03,        nan,        nan,  ..., 1.0735e-03,
       1.0735e-03, 1.0735e-03],
      [       nan,        nan,        nan,  ..., 1.4540e-04,
       1.4540e-04, 1.4540e-04],
      [       nan,        nan,        nan,  ..., 1.1156e-03,
       1.1156e-03, 1.1156e-03]],

     [[1.3076e-04, 1.3076e-04, 1.3076e-04,  ...,        nan,
       1.3076e-04,        nan],
      [2.4632e-03, 2.4632e-03, 2.4632e-03,  ..., 2.4632e-03,
              nan,        nan],
      [3.0567e-03, 3.0567e-03, 3.0567e-03,  ...,        nan,
              nan, 3.0567e-03],
      ...,
      [3.0679e-03,        nan,        nan,  ..., 3.0679e-03,
       3.0679e-03, 3.0679e-03],
      [       nan,        nan,        nan,  ..., 2.7698e-03,
       2.7698e-03, 2.7698e-03],
      [       nan,        nan,        nan,  ..., 3.1502e-03,
       3.1502e-03, 3.1502e-03]]],


    [[[3.9900e-04, 3.9900e-04, 3.9900e-04,  ...,        nan,
       3.9900e-04,        nan],
      [2.3724e-03, 2.3724e-03, 2.3724e-03,  ..., 2.3724e-03,
              nan,        nan],
      [6.4523e-05, 6.4523e-05, 6.4523e-05,  ...,        nan,
              nan, 6.4523e-05],
      ...,
      [3.4188e-03,        nan,        nan,  ..., 3.4188e-03,
       3.4188e-03, 3.4188e-03],
      [       nan,        nan,        nan,  ..., 1.7611e-04,
       1.7611e-04, 1.7611e-04],
      [       nan,        nan,        nan,  ..., 3.1767e-03,
       3.1767e-03, 3.1767e-03]],

     [[8.7806e-07, 8.7806e-07, 8.7806e-07,  ...,        nan,
       8.7806e-07,        nan],
      [4.7228e-05, 4.7228e-05, 4.7228e-05,  ..., 4.7228e-05,
              nan,        nan],
      [2.5639e-03, 2.5639e-03, 2.5639e-03,  ...,        nan,
              nan, 2.5639e-03],
      ...,
      [4.1269e-08,        nan,        nan,  ..., 4.1269e-08,
       4.1269e-08, 4.1269e-08],
      [       nan,        nan,        nan,  ..., 1.3153e-03,
       1.3153e-03, 1.3153e-03],
      [       nan,        nan,        nan,  ..., 4.0486e-04,
       4.0486e-04, 4.0486e-04]],

     [[3.1932e-03, 3.1932e-03, 3.1932e-03,  ...,        nan,
       3.1932e-03,        nan],
      [2.6909e-03, 2.6909e-03, 2.6909e-03,  ..., 2.6909e-03,
              nan,        nan],
      [9.7583e-05, 9.7583e-05, 9.7583e-05,  ...,        nan,
              nan, 9.7583e-05],
      ...,
      [1.7329e-03,        nan,        nan,  ..., 1.7329e-03,
       1.7329e-03, 1.7329e-03],
      [       nan,        nan,        nan,  ..., 2.8461e-03,
       2.8461e-03, 2.8461e-03],
      [       nan,        nan,        nan,  ..., 1.0612e-05,
       1.0612e-05, 1.0612e-05]],

     ...,

     [[2.9565e-03, 2.9565e-03, 2.9565e-03,  ...,        nan,
       2.9565e-03,        nan],
      [3.2911e-03, 3.2911e-03, 3.2911e-03,  ..., 3.2911e-03,
              nan,        nan],
      [9.1167e-06, 9.1167e-06, 9.1167e-06,  ...,        nan,
              nan, 9.1167e-06],
      ...,
      [3.4284e-03,        nan,        nan,  ..., 3.4284e-03,
       3.4284e-03, 3.4284e-03],
      [       nan,        nan,        nan,  ..., 1.7568e-03,
       1.7568e-03, 1.7568e-03],
      [       nan,        nan,        nan,  ..., 1.5258e-06,
       1.5258e-06, 1.5258e-06]],

     [[5.7720e-04, 5.7720e-04, 5.7720e-04,  ...,        nan,
       5.7720e-04,        nan],
      [1.2828e-05, 1.2828e-05, 1.2828e-05,  ..., 1.2828e-05,
              nan,        nan],
      [1.6647e-03, 1.6647e-03, 1.6647e-03,  ...,        nan,
              nan, 1.6647e-03],
      ...,
      [1.0735e-03,        nan,        nan,  ..., 1.0735e-03,
       1.0735e-03, 1.0735e-03],
      [       nan,        nan,        nan,  ..., 1.4540e-04,
       1.4540e-04, 1.4540e-04],
      [       nan,        nan,        nan,  ..., 1.1156e-03,
       1.1156e-03, 1.1156e-03]],

     [[1.3076e-04, 1.3076e-04, 1.3076e-04,  ...,        nan,
       1.3076e-04,        nan],
      [2.4632e-03, 2.4632e-03, 2.4632e-03,  ..., 2.4632e-03,
              nan,        nan],
      [3.0567e-03, 3.0567e-03, 3.0567e-03,  ...,        nan,
              nan, 3.0567e-03],
      ...,
      [3.0679e-03,        nan,        nan,  ..., 3.0679e-03,
       3.0679e-03, 3.0679e-03],
      [       nan,        nan,        nan,  ..., 2.7698e-03,
       2.7698e-03, 2.7698e-03],
      [       nan,        nan,        nan,  ..., 3.1502e-03,
       3.1502e-03, 3.1502e-03]]],


    [[[3.9900e-04, 3.9900e-04, 3.9900e-04,  ...,        nan,
              nan,        nan],
      [2.3724e-03, 2.3724e-03, 2.3724e-03,  ...,        nan,
              nan,        nan],
      [6.4523e-05, 6.4523e-05, 6.4523e-05,  ...,        nan,
              nan, 6.4523e-05],
      ...,
      [3.4188e-03,        nan,        nan,  ..., 3.4188e-03,
       3.4188e-03, 3.4188e-03],
      [       nan,        nan,        nan,  ..., 1.7611e-04,
       1.7611e-04, 1.7611e-04],
      [       nan,        nan,        nan,  ..., 3.1767e-03,
       3.1767e-03, 3.1767e-03]],

     [[8.7806e-07, 8.7806e-07, 8.7806e-07,  ...,        nan,
              nan,        nan],
      [4.7228e-05, 4.7228e-05, 4.7228e-05,  ...,        nan,
              nan,        nan],
      [2.5639e-03, 2.5639e-03, 2.5639e-03,  ...,        nan,
              nan, 2.5639e-03],
      ...,
      [4.1269e-08,        nan,        nan,  ..., 4.1269e-08,
       4.1269e-08, 4.1269e-08],
      [       nan,        nan,        nan,  ..., 1.3153e-03,
       1.3153e-03, 1.3153e-03],
      [       nan,        nan,        nan,  ..., 4.0486e-04,
       4.0486e-04, 4.0486e-04]],

     [[3.1932e-03, 3.1932e-03, 3.1932e-03,  ...,        nan,
              nan,        nan],
      [2.6909e-03, 2.6909e-03, 2.6909e-03,  ...,        nan,
              nan,        nan],
      [9.7583e-05, 9.7583e-05, 9.7583e-05,  ...,        nan,
              nan, 9.7583e-05],
      ...,
      [1.7329e-03,        nan,        nan,  ..., 1.7329e-03,
       1.7329e-03, 1.7329e-03],
      [       nan,        nan,        nan,  ..., 2.8461e-03,
       2.8461e-03, 2.8461e-03],
      [       nan,        nan,        nan,  ..., 1.0612e-05,
       1.0612e-05, 1.0612e-05]],

     ...,

     [[2.9565e-03, 2.9565e-03, 2.9565e-03,  ...,        nan,
              nan,        nan],
      [3.2911e-03, 3.2911e-03, 3.2911e-03,  ...,        nan,
              nan,        nan],
      [9.1167e-06, 9.1167e-06, 9.1167e-06,  ...,        nan,
              nan, 9.1167e-06],
      ...,
      [3.4284e-03,        nan,        nan,  ..., 3.4284e-03,
       3.4284e-03, 3.4284e-03],
      [       nan,        nan,        nan,  ..., 1.7568e-03,
       1.7568e-03, 1.7568e-03],
      [       nan,        nan,        nan,  ..., 1.5258e-06,
       1.5258e-06, 1.5258e-06]],

     [[5.7720e-04, 5.7720e-04, 5.7720e-04,  ...,        nan,
              nan,        nan],
      [1.2828e-05, 1.2828e-05, 1.2828e-05,  ...,        nan,
              nan,        nan],
      [1.6647e-03, 1.6647e-03, 1.6647e-03,  ...,        nan,
              nan, 1.6647e-03],
      ...,
      [1.0735e-03,        nan,        nan,  ..., 1.0735e-03,
       1.0735e-03, 1.0735e-03],
      [       nan,        nan,        nan,  ..., 1.4540e-04,
       1.4540e-04, 1.4540e-04],
      [       nan,        nan,        nan,  ..., 1.1156e-03,
       1.1156e-03, 1.1156e-03]],

     [[1.3076e-04, 1.3076e-04, 1.3076e-04,  ...,        nan,
              nan,        nan],
      [2.4632e-03, 2.4632e-03, 2.4632e-03,  ...,        nan,
              nan,        nan],
      [3.0567e-03, 3.0567e-03, 3.0567e-03,  ...,        nan,
              nan, 3.0567e-03],
      ...,
      [3.0679e-03,        nan,        nan,  ..., 3.0679e-03,
       3.0679e-03, 3.0679e-03],
      [       nan,        nan,        nan,  ..., 2.7698e-03,
       2.7698e-03, 2.7698e-03],
      [       nan,        nan,        nan,  ..., 3.1502e-03,
       3.1502e-03, 3.1502e-03]]],


    [[[3.9900e-04, 3.9900e-04, 3.9900e-04,  ...,        nan,
              nan,        nan],
      [2.3724e-03, 2.3724e-03, 2.3724e-03,  ...,        nan,
              nan,        nan],
      [6.4523e-05, 6.4523e-05, 6.4523e-05,  ...,        nan,
              nan, 6.4523e-05],
      ...,
      [3.4188e-03,        nan, 3.4188e-03,  ..., 3.4188e-03,
       3.4188e-03, 3.4188e-03],
      [       nan, 1.7611e-04,        nan,  ..., 1.7611e-04,
       1.7611e-04, 1.7611e-04],
      [3.1767e-03,        nan,        nan,  ..., 3.1767e-03,
       3.1767e-03, 3.1767e-03]],

     [[8.7806e-07, 8.7806e-07, 8.7806e-07,  ...,        nan,
              nan,        nan],
      [4.7228e-05, 4.7228e-05, 4.7228e-05,  ...,        nan,
              nan,        nan],
      [2.5639e-03, 2.5639e-03, 2.5639e-03,  ...,        nan,
              nan, 2.5639e-03],
      ...,
      [4.1269e-08,        nan, 4.1269e-08,  ..., 4.1269e-08,
       4.1269e-08, 4.1269e-08],
      [       nan, 1.3153e-03,        nan,  ..., 1.3153e-03,
       1.3153e-03, 1.3153e-03],
      [4.0486e-04,        nan,        nan,  ..., 4.0486e-04,
       4.0486e-04, 4.0486e-04]],

     [[3.1932e-03, 3.1932e-03, 3.1932e-03,  ...,        nan,
              nan,        nan],
      [2.6909e-03, 2.6909e-03, 2.6909e-03,  ...,        nan,
              nan,        nan],
      [9.7583e-05, 9.7583e-05, 9.7583e-05,  ...,        nan,
              nan, 9.7583e-05],
      ...,
      [1.7329e-03,        nan, 1.7329e-03,  ..., 1.7329e-03,
       1.7329e-03, 1.7329e-03],
      [       nan, 2.8461e-03,        nan,  ..., 2.8461e-03,
       2.8461e-03, 2.8461e-03],
      [1.0612e-05,        nan,        nan,  ..., 1.0612e-05,
       1.0612e-05, 1.0612e-05]],

     ...,

     [[2.9565e-03, 2.9565e-03, 2.9565e-03,  ...,        nan,
              nan,        nan],
      [3.2911e-03, 3.2911e-03, 3.2911e-03,  ...,        nan,
              nan,        nan],
      [9.1167e-06, 9.1167e-06, 9.1167e-06,  ...,        nan,
              nan, 9.1167e-06],
      ...,
      [3.4284e-03,        nan, 3.4284e-03,  ..., 3.4284e-03,
       3.4284e-03, 3.4284e-03],
      [       nan, 1.7568e-03,        nan,  ..., 1.7568e-03,
       1.7568e-03, 1.7568e-03],
      [1.5258e-06,        nan,        nan,  ..., 1.5258e-06,
       1.5258e-06, 1.5258e-06]],

     [[5.7720e-04, 5.7720e-04, 5.7720e-04,  ...,        nan,
              nan,        nan],
      [1.2828e-05, 1.2828e-05, 1.2828e-05,  ...,        nan,
              nan,        nan],
      [1.6647e-03, 1.6647e-03, 1.6647e-03,  ...,        nan,
              nan, 1.6647e-03],
      ...,
      [1.0735e-03,        nan, 1.0735e-03,  ..., 1.0735e-03,
       1.0735e-03, 1.0735e-03],
      [       nan, 1.4540e-04,        nan,  ..., 1.4540e-04,
       1.4540e-04, 1.4540e-04],
      [1.1156e-03,        nan,        nan,  ..., 1.1156e-03,
       1.1156e-03, 1.1156e-03]],

     [[1.3076e-04, 1.3076e-04, 1.3076e-04,  ...,        nan,
              nan,        nan],
      [2.4632e-03, 2.4632e-03, 2.4632e-03,  ...,        nan,
              nan,        nan],
      [3.0567e-03, 3.0567e-03, 3.0567e-03,  ...,        nan,
              nan, 3.0567e-03],
      ...,
      [3.0679e-03,        nan, 3.0679e-03,  ..., 3.0679e-03,
       3.0679e-03, 3.0679e-03],
      [       nan, 2.7698e-03,        nan,  ..., 2.7698e-03,
       2.7698e-03, 2.7698e-03],
      [3.1502e-03,        nan,        nan,  ..., 3.1502e-03,
       3.1502e-03, 3.1502e-03]]]]

How can I avoid ‘nan’ values?

You'll need to check all the input Tensors, so check how input_unfold, weight_flat, and perhaps even self.gamma are defined/computed. I'd assume it's coming from input_unfold, but check them all.

You can do this easily by using torch.autograd.detect_anomaly as a context manager (see Automatic differentiation package - torch.autograd — PyTorch 2.1 documentation).
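
Roughly like this, wrapped around the training step (model, criterion, images, and labels are placeholder names for whatever your loop uses):

import torch

with torch.autograd.detect_anomaly():
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()   # the error now carries a traceback pointing at the forward op that produced the NaN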

input_unfold and weight_flat are as given below.

input_unfold =  tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.9194, 0.0503, 0.9589],
          [0.0000, 0.0000, 0.0000,  ..., 0.0503, 0.9589, 0.6398],
          [0.0000, 0.0000, 0.0000,  ..., 0.9589, 0.6398, 0.0000],
          ...,
          [0.0000, 0.7830, 0.6963,  ..., 0.0000, 0.0000, 0.0000],
          [0.7830, 0.6963, 0.6662,  ..., 0.0000, 0.0000, 0.0000],
          [0.6963, 0.6662, 0.8015,  ..., 0.0000, 0.0000, 0.0000]]],


        [[[0.0000, 0.0000, 0.0000,  ..., 0.7274, 0.0721, 0.7137],
          [0.0000, 0.0000, 0.0000,  ..., 0.0721, 0.7137, 0.8258],
          [0.0000, 0.0000, 0.0000,  ..., 0.7137, 0.8258, 0.0000],
          ...,
          [0.0000, 0.1822, 0.6842,  ..., 0.0000, 0.0000, 0.0000],
          [0.1822, 0.6842, 0.2491,  ..., 0.0000, 0.0000, 0.0000],
          [0.6842, 0.2491, 0.4927,  ..., 0.0000, 0.0000, 0.0000]]]],
       device='cuda:0')
weight_flat =  tensor([[[ 0.0389],
         [-0.0168],
         [ 0.0396],
         ...,
         [ 0.0852],
         [ 0.0824],
         [ 0.0089]],

        [[-0.0027],
         [ 0.0509],
         [-0.0628],
         ...,
         [-0.0616],
         [-0.0226],
         [ 0.0380]],

        [[ 0.0298],
         [-0.0171],
         [ 0.0830],
         ...,
         [ 0.0128],
         [-0.0004],
         [ 0.0607]],

        ...,

        [[-0.0579],
         [-0.0283],
         [ 0.0257],
         ...,
         [-0.0370],
         [ 0.0655],
         [ 0.0790]],

        [[-0.0675],
         [ 0.0641],
         [ 0.0404],
         ...,
         [ 0.0852],
         [ 0.0168],
         [-0.0272]],

        [[ 0.0215],
         [-0.0780],
         [-0.0812],
         ...,
         [ 0.0270],
         [ 0.0331],
         [-0.0470]]], device='cuda:0', grad_fn=<ViewBackward0>)

input_unfold.shape =  torch.Size([2, 1, 135, 225])
weight_flat.shape =  torch.Size([32, 135, 1])

You're taking the square root of a negative number here (your .mul(-1) is applied before the .sqrt()), that's why you're getting NaNs.

Can you please suggest a way to avoid negative numbers in the weights of the model?

output = torch.exp(self.gamma*(((input_unfold - weight_flat)**2).abs().sum(dim=2)).sqrt().mul(-1))
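
In other words, negate after the square root so its argument stays non-negative. A sketch of the corrected rbf branch (same variables as in your forward) would be:

elif self.kernel_type == 'rbf':
    # squared distances are >= 0, so sqrt() is safe; the minus sign goes on the exponent afterwards
    dist = (((input_unfold - weight_flat)**2).sum(dim=2)).sqrt()
    output = torch.exp(-self.gamma * dist)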