Custom backward of a torch.autograd.Function fails on CUDA but works on CPU

Hi everyone, I am implementing a custom activation function as a torch.autograd.Function with @staticmethod forward and backward methods. When I use the CPU as the device, the code goes through both my custom forward and my custom backward pass. When I use CUDA, it only uses my custom forward and never my custom backward. I do move the model to the device with .to(device) beforehand. I am sure this question has been asked before, so please point me to the relevant page. Any help will be appreciated.


That sounds weird. Could you share a code sample that reproduces this behavior?

Sure. For simplicity I will use a custom leaky ReLU as the example. I am using conda with PyTorch 1.7.1.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Function # import Function to create custom activations

import numpy.matlib

from torch.autograd import gradcheck

import random
from torch.autograd import Variable


# Device configuration: use the first CUDA device when present, otherwise fall
# back to the CPU so the script still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# NOTE: on Ampere GPUs, TF32 matmuls can make CUDA results differ slightly
# from CPU results; set torch.backends.cuda.matmul.allow_tf32 = False to
# compare the two bit-for-bit more closely.

# Hyper-parameters
# NOTE(review): these are not used by the loaders below (which use 64) nor by
# train_model (which hard-codes its own epochs/learning rate).
num_epochs = 5
batch_size = 32
learning_rate = 0.001
L_relu_lambda = 0.02

# FashionMNIST yields PIL images; ToTensor() must run before Normalize(),
# which only operates on tensors. Mean/std are 1-tuples for the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                                  download=True, transform=transform)

test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                                 download=True, transform=transform)

# Seed NumPy once here. The original passed ``worker_init_fn=np.random.seed(12)``,
# which *calls* np.random.seed immediately and hands its return value (None)
# to the DataLoader, so no per-worker initialisation ever happened.
# Shuffling uses torch's RNG; call torch.manual_seed for reproducible order.
np.random.seed(12)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64,
                                          shuffle=False)

def train_model(model, trainloader, device=None):
    """Train ``model`` on ``trainloader`` and print the training loss per epoch.

    The loss is ``nn.NLLLoss``, so ``model`` must return log-probabilities
    (e.g. end with ``F.log_softmax``). The optimizer (Adam, lr=0.003) and the
    number of epochs (5) are fixed inside the function.

    Args:
        model: the ``nn.Module`` to train; its parameters are updated in place.
        trainloader: iterable of ``(images, labels)`` batches. Images are
            flattened to ``(batch, -1)`` before being fed to the model.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, so inputs always land where the model is.

    Returns:
        list[float]: the summed training loss of each epoch, in order.
    """
    if device is None:
        device = next(model.parameters()).device

    criterion = nn.NLLLoss()
    learning_rate = 0.003
    epochs = 5
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    print('Training the model. Make sure that loss decreases after each epoch.\n')
    epoch_losses = []
    for e in range(epochs):
        running_loss = 0.0
        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.to(device)
            images = images.view(images.shape[0], -1)

            optimizer.zero_grad()
            log_ps = model(images)
            loss = criterion(log_ps, labels)
            # backward() is what triggers every autograd Function's custom
            # backward; without it no gradient (custom or built-in) is computed.
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        # Report once per epoch, matching the message printed above.
        print(f"Training loss: {running_loss}")
        epoch_losses.append(running_loss)
    return epoch_losses

class Lrelu_custom(Function):
    """Leaky ReLU implemented as a custom autograd Function.

    Call as ``Lrelu_custom.apply(input, alpha)`` where ``alpha`` is a Python
    number giving the slope for negative inputs.
    """

    @staticmethod
    def forward(ctx, input, alpha):
        # Save the input: backward needs its sign to pick the slope.
        ctx.save_for_backward(input)
        ctx.slope = alpha
        # Return a plain tensor. autograd.Function.apply attaches the grad_fn
        # that routes gradients to backward(). Wrapping the result in a new
        # leaf Variable(requires_grad=True) would break that link.
        return input.clamp(min=0) + input.clamp(max=0) * alpha

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors  # restore input from context
        # d/dx = 1 for x > 0, slope otherwise (same convention as
        # F.leaky_relu). torch.where keeps grad_output's dtype and device.
        grad_input = torch.where(input > 0, grad_output, grad_output * ctx.slope)
        # alpha is a plain number, so it receives no gradient.
        return grad_input, None
class Classifierleakyrelu(nn.Module):
    """Simple fully-connected classifier model to demonstrate leakyrelu activation.

    The input is flattened and passed through ``nn.Linear(784, 256)``, then
    the custom ``Lrelu_custom`` activation, then ``log_softmax`` over the 256
    outputs.

    NOTE(review): there is no final ``Linear(256, num_classes)`` layer, so the
    log-probabilities span 256 entries. Confirm that this is intended for a
    10-class dataset.

    Args:
        negative_slope: slope applied to negative pre-activations. It defaults
            to 0.001, the value previously hard-coded in ``forward``.
    """

    def __init__(self, negative_slope=0.001):
        super(Classifierleakyrelu, self).__init__()
        self.negative_slope = negative_slope

        # initialize layers
        self.fc1 = nn.Linear(784, 256)
        # shortcut to the custom leaky relu autograd Function
        self.a1 = Lrelu_custom.apply

    def forward(self, x):
        # make sure the input tensor is flattened to (batch, features)
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        # apply the custom leaky relu
        x = self.a1(x, self.negative_slope)
        return F.log_softmax(x, dim=1)
# Build the classifier, then move its parameters onto the selected device.
model = Classifierleakyrelu()
model.to(device)  # nn.Module.to moves parameters in place (and returns the module)

The CPU and GPU results differed for my custom function. So I tested this simple leaky ReLU to check whether the code ever reaches the custom backward during the backward pass. I put breakpoints in it in Spyder, but execution never stopped at them. Thanks again.


Running this code in a brand new colab notebook seems to work as expected: the backward is called and the loss is decreasing on a cuda device.

I tried again in Spyder with breakpoints and verified that execution enters only the forward function, never the backward function. I uploaded a video to the drive showing this: link →
In the video you can see, using breakpoints, that execution never enters the backward function.
I have also attached a link to a video where, with device = cpu, execution enters both the forward and the backward function.

It looks like PyTorch's own gradient computation is used for the backward pass instead of the custom backward defined in the leaky ReLU Function. Is there a way to check whether PyTorch's built-in gradient computation or my custom backward is being used? Thanks again.

Hi Alban,
I want to ask how to print the learnable parameters of a custom activation function.
When I save the state_dict of the trained model and try to print the parameters by name (e.g. by iterating over model.state_dict().items()), the parameters of the custom activation function are not included. My custom activation function is shown below (threshold is the learnable parameter):

class Surrogate_BP_Function(torch.autograd.Function):
    """Heaviside spike function with a triangular surrogate gradient.

    forward:  1.0 where ``input > 0``, else 0.0 (same dtype/device as input).
    backward: ``grad_output * 0.3 * max(1 - |input|, 0)``.
    """

    @staticmethod
    def forward(ctx, input):
        # backward needs the pre-activation. The original never saved it, so
        # unpacking ctx.saved_tensors in backward would fail.
        ctx.save_for_backward(input)
        # zeros_like already allocates on input's device; a hard-coded
        # .cuda() would break CPU execution.
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # F.threshold(v, 0, 0) keeps v where v > 0 and writes 0 elsewhere,
        # i.e. a triangle of height 1 and half-width 1 around input == 0.
        return grad_output * 0.3 * F.threshold(1.0 - torch.abs(input), 0, 0)
class SNN_VGG9_BNTT(nn.Module):
    """VGG9-style spiking network with one BatchNorm per layer per time step.

    Seven 3x3 conv layers and one hidden fully-connected layer keep membrane
    potentials across ``num_steps`` time steps. At each step a potential is
    multiplied by ``leak_mem``, increased by the BatchNorm'd layer input, and,
    where it exceeds the layer's ``threshold``, emits a spike through
    ``Surrogate_BP_Function`` and is reset by subtracting the threshold. The
    network output is the accumulated ``fc2`` potential divided by ``num_steps``.

    Args:
        num_steps: number of simulated time steps (also the number of
            BatchNorm layers per layer).
        leak_mem: leak factor applied to membrane potentials each step.
        img_size: height/width of the square 3-channel input.
        num_cls: number of output classes.
        learnable_threshold: if True, each conv/linear firing threshold is an
            ``nn.Parameter`` (initialised to 1.0), so it is trained and appears
            in ``parameters()`` / ``state_dict()``. The default False keeps
            plain float thresholds as before.
    """

    def __init__(self, num_steps, leak_mem=0.95, img_size=32, num_cls=10,
                 learnable_threshold=False):
        super(SNN_VGG9_BNTT, self).__init__()

        self.img_size = img_size
        self.num_cls = num_cls
        self.num_steps = num_steps
        self.spike_fn = Surrogate_BP_Function.apply
        self.leak_mem = leak_mem
        self.batch_num = self.num_steps

        print (">>>>>>>>>>>>>>>>>>> VGG 9 >>>>>>>>>>>>>>>>>>>>>>")
        # The original format string had no "{}" placeholder, so batch_num
        # was silently dropped from the message.
        print ("***** time step per batchnorm: {}".format(self.batch_num))
        print (">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

        affine_flag = True
        bias_flag = False

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt1 = self._make_bntt(nn.BatchNorm2d, 64, affine_flag)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt2 = self._make_bntt(nn.BatchNorm2d, 64, affine_flag)
        self.pool1 = nn.AvgPool2d(kernel_size=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt3 = self._make_bntt(nn.BatchNorm2d, 128, affine_flag)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt4 = self._make_bntt(nn.BatchNorm2d, 128, affine_flag)
        self.pool2 = nn.AvgPool2d(kernel_size=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt5 = self._make_bntt(nn.BatchNorm2d, 256, affine_flag)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt6 = self._make_bntt(nn.BatchNorm2d, 256, affine_flag)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=bias_flag)
        self.bntt7 = self._make_bntt(nn.BatchNorm2d, 256, affine_flag)
        self.pool3 = nn.AvgPool2d(kernel_size=2)

        self.fc1 = nn.Linear((self.img_size//8)*(self.img_size//8)*256, 1024, bias=bias_flag)
        self.bntt_fc = self._make_bntt(nn.BatchNorm1d, 1024, affine_flag)
        self.fc2 = nn.Linear(1024, self.num_cls, bias=bias_flag)

        self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7]
        self.bntt_list = [self.bntt1, self.bntt2, self.bntt3, self.bntt4, self.bntt5, self.bntt6, self.bntt7, self.bntt_fc]
        # False means "no pooling after this conv layer".
        self.pool_list = [False, self.pool1, False, self.pool2, False, False, self.pool3]

        # Turn off bias of BNTT
        for bn_list in self.bntt_list:
            for bn_temp in bn_list:
                bn_temp.bias = None

        # Initialize the firing thresholds and weights of all conv/linear layers.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # A plain float attribute is not registered with the module, so
                # it is neither optimized nor saved; an nn.Parameter is both.
                if learnable_threshold:
                    m.threshold = nn.Parameter(torch.tensor(1.0))
                else:
                    m.threshold = 1.0
                torch.nn.init.xavier_uniform_(m.weight, gain=2)

    def _make_bntt(self, bn_cls, num_features, affine):
        """Return ``batch_num`` independent BatchNorm layers, one per time step."""
        return nn.ModuleList([bn_cls(num_features, eps=1e-4, momentum=0.1, affine=affine)
                              for _ in range(self.batch_num)])

    def _fire(self, mem, threshold):
        """Spike where ``mem`` exceeds ``threshold``; return ``(spikes, reset_mem)``."""
        mem_thr = (mem / threshold) - 1.0
        spikes = self.spike_fn(mem_thr)
        # Subtract the threshold where a spike fired. This matches the original
        # zeros + masked assignment, but it does not hard-code a device and it
        # stays differentiable w.r.t. a learnable threshold.
        rst = (mem_thr > 0).to(mem.dtype) * threshold
        return spikes, mem - rst

    def forward(self, inp):
        batch_size = inp.size(0)
        device = inp.device
        # (channels, spatial downsampling factor) of each conv layer's
        # membrane potential; pooling after conv2, conv4 and conv7.
        conv_states = [(64, 1), (64, 1), (128, 2), (128, 2), (256, 4), (256, 4), (256, 4)]
        mem_conv_list = [
            torch.zeros(batch_size, ch, self.img_size // d, self.img_size // d, device=device)
            for ch, d in conv_states
        ]
        mem_fc1 = torch.zeros(batch_size, 1024, device=device)
        mem_fc2 = torch.zeros(batch_size, self.num_cls, device=device)

        for t in range(self.num_steps):
            # Each conv layer consumes the previous layer's spikes. The original
            # fed ``inp`` to every layer, which fails from conv2 on (3 vs 64 channels).
            out_prev = inp
            for i, conv in enumerate(self.conv_list):
                mem_conv_list[i] = self.leak_mem * mem_conv_list[i] + self.bntt_list[i][t](conv(out_prev))
                out_prev, mem_conv_list[i] = self._fire(mem_conv_list[i], conv.threshold)
                if self.pool_list[i] is not False:
                    out_prev = self.pool_list[i](out_prev)

            out_prev = out_prev.reshape(batch_size, -1)
            mem_fc1 = self.leak_mem * mem_fc1 + self.bntt_fc[t](self.fc1(out_prev))
            out, mem_fc1 = self._fire(mem_fc1, self.fc1.threshold)
            mem_fc2 = mem_fc2 + self.fc2(out)

        out_voltage = mem_fc2 / self.num_steps
        return out_voltage

Thank you so much


I am not sure which activation you mean here.
But for your parameter to show up there, make sure it is an attribute of an nn.Module and is an nn.Parameter.

1 Like

Thank you for your kind reply.
The activation is self.spike_fn = Surrogate_BP_Function.apply, and the threshold is set like this:

        # Excerpt of the threshold initialisation in SNN_VGG9_BNTT.__init__.
        # Assigning a plain Python float only creates an ordinary attribute, so
        # the threshold is not an nn.Parameter and will not be listed by
        # parameters() or state_dict(); nn.Parameter(torch.tensor(1.0)) would be.
        for m in self.modules():
            if (isinstance(m, nn.Conv2d)):
                m.threshold = 1.0

As far as I can tell, the threshold is an attribute of an nn.Module. I have tried every method I know to retrieve it, but none of them worked.

Best regards

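It is in the nn.Module, but it is not a Parameter. Assigning a plain float only creates an ordinary Python attribute, which parameters() and state_dict() ignore. You want to do m.threshold = nn.Parameter(torch.tensor(1.0))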