Gradient computation in custom backward

I’m seeing some unusual output from the code below; could you please have a look?

import torch

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # device used below (the log shows cuda:0)

class Custom_Convolution(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):  #input(from previous layer)'s shape = ([batch_size=100, 96, 8, 8])
        output = torch.nn.functional.conv2d(input, weight, bias, stride, padding)  
        ctx.save_for_backward(input, weight, bias, output)
        return output    # output shape: [batch_size=100, 128, 4, 4]

    @staticmethod
    def backward(ctx, grad_output):    # grad_output size = ([batch_size, 128,4,4])
        
        input, weight, bias, output = ctx.saved_tensors    #input size = ([batch_size, 96,8,8])  
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output) #shape = ([batch_size,96,8,8])
              
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)  #shape = ([128,96,5,5])
                        
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0,2,3))        #shape = ([128])          
           
        with torch.enable_grad():
            feat = output.clone()   # output from forward, shape: [batch_size, 128, 4, 4]

            feat = feat.view(feat.shape[0], feat.shape[1], -1)   # features shape: [batch_size, 128, 16]

            cont_loss = torch.tensor([0.]).to(dev)
            for i in range(feat.shape[0]):
                for f in range(len(feat[i])):
                    Zi_unnormalized = feat[i][f]
                    Zi = torch.nn.functional.normalize(Zi_unnormalized, dim=0)
                    # Zj and Zk are tensors built from feat[i][*] and feat[other than i][*].
                    # Zj and Zk vary for each Zi (i.e. for each f).

                    Zi_Zk = torch.Tensor([0]).to(dev)
                    for k in Zk:
                        k = torch.nn.functional.normalize(k, dim=0)
                        zi_zk = ...
                        Zi_Zk = Zi_Zk.add(zi_zk)

                    # Zi_Zj is computed similarly.
                    # Li = some algebra of Zi_Zj and Zi_Zk
                    # number of Li values = feat.shape[0] * feat.shape[1]
                    cont_loss = cont_loss.add(Li)   # accumulated into a single scalar
            print("\n Loss: ", cont_loss, cont_loss.requires_grad)

### This print of the loss keeps repeating with the same value!


        cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)
        print("Shape:", cont_loss_weight.shape)
        grad_weight += cont_loss_weight

        cont_loss_bias = torch.autograd.grad(outputs=cont_loss, inputs=bias, retain_graph=True)
        grad_bias += cont_loss_bias

        if bias is not None:
            return grad_input, grad_weight, grad_bias, None, None
        else:
            return grad_input, grad_weight, None, None, None
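
For context, the function is applied through a small wrapper module, roughly like the sketch below (a simplified illustration, not my exact code; the CustomConvLayer name, the parameter initialisation, and the stride=1 / padding=0 values are assumptions taken from the shape comments above):

class CustomConvLayer(torch.nn.Module):
    # Hypothetical wrapper, for illustration only.
    def __init__(self):
        super().__init__()
        # weight [128, 96, 5, 5] and bias [128] match the shapes in backward()
        self.weight = torch.nn.Parameter(torch.randn(128, 96, 5, 5) * 0.01)
        self.bias = torch.nn.Parameter(torch.zeros(128))

    def forward(self, x):
        # With a 5x5 kernel, stride 1 and padding 0 map an 8x8 input to a 4x4 output.
        return Custom_Convolution.apply(x, self.weight, self.bias, 1, 0)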

Output:

Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
.
.
.
RuntimeError: CUDA out of memory.

I don’t know why this line keeps repeating without end until CUDA eventually runs out of memory. It should print only once per batch, and I haven’t made any indentation mistake!
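
To be clear about what I expect: backward() runs once per batch, so the loss should be printed once per batch. The training loop looks roughly like this (a simplified sketch; loader, the loss head and the optimizer settings are placeholders, not my exact code):

layer = CustomConvLayer().to(dev)
criterion = torch.nn.CrossEntropyLoss()                    # placeholder loss head
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)   # placeholder optimizer

for inputs, targets in loader:                             # one iteration per batch
    inputs, targets = inputs.to(dev), targets.to(dev)
    optimizer.zero_grad()
    out = layer(inputs)                                    # Custom_Convolution.forward runs here
    loss = criterion(out.mean(dim=(2, 3)), targets)        # placeholder loss on the output
    loss.backward()                                        # Custom_Convolution.backward should run once here
    optimizer.step()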

This line

cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)

gets executed, but it neither raises an error nor returns anything: the line that follows it, print("Shape:", cont_loss_weight.shape), is never printed.
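
For comparison, the same kind of call outside the custom backward returns immediately (a standalone sketch with made-up tensors, just to show what I expect torch.autograd.grad to give back):

w = torch.randn(128, 96, 5, 5, requires_grad=True)
b = torch.zeros(128, requires_grad=True)
x = torch.randn(4, 96, 8, 8)

out = torch.nn.functional.conv2d(x, w, b)                  # shape [4, 128, 4, 4]
toy_loss = out.pow(2).sum()                                # any scalar built from out

grads = torch.autograd.grad(outputs=toy_loss, inputs=w, retain_graph=True)
print(grads[0].shape)                                      # torch.Size([128, 96, 5, 5]), printed right away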