Backward not called

Hi all,

I’ve been trying to write my own Function because I need some operations that autograd cannot currently differentiate.
Then I wrap this Function in a Module.
The forward pass works fine, but the backward pass never seems to be called: execution never enters it.
Here is a quick outline of my code:

class ICA_3D_MM(Function):

    def __init__(self):
        super(ICA_3D_MM, self).__init__()

    @staticmethod
    def forward(ctx, mu, sigma, landa, eps, dimX, dimY, dimZ):
        #some ops
        return A, KL

    def backward(grad_1, grad_2):
        #some ops
        return grad_mu, grad_sigma, grad_landa, grad_eps, grad_dimX, grad_dimY, grad_dim_Z

Then I wrap it in a Module:

class Ica_3D_layer(nn.Module):

    def __init__(self, n_comp, n_voxels, dimX, dimY, dimZ):
        super(Ica_3D_layer, self).__init__()

        self.mu = nn.Parameter(torch.zeros(n_comp, n_voxels), requires_grad=True)
        self.sigma = nn.Parameter(torch.zeros(n_comp, 1), requires_grad=True)
        self.landa = nn.Parameter(torch.zeros(n_comp, 1), requires_grad=True)

        self.A = None
        self.KL_variable = None

        self.dimX = dimX
        self.dimY = dimY
        self.dimZ = dimZ

    def forward(self, x):
        eps = torch.randn(self.dimX*self.dimY*self.dimZ, 1)

        self.A, self.KL_variable = ICA_3D_MM.apply(self.mu.data, self.sigma.data, self.landa.data, eps, self.dimX, self.dimY, self.dimZ)

        return torch.mm(x, self.A)

The forward pass works fine, but when I call loss.backward() the gradients of my Ica_3D_layer parameters remain None. Does anyone have an idea what could explain this behaviour?

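Your Function class is not correct. Both forward and backward must be @staticmethods taking ctx as their first argument, and you don't need an __init__. I think it should be: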

class ICA_3D_MM(Function):

    @staticmethod
    def forward(ctx, mu, sigma, landa, eps, dimX, dimY, dimZ):
        #some ops
        return A, KL

    @staticmethod
    def backward(ctx, grad_1, grad_2):
        #some ops
        return grad_mu, grad_sigma, grad_landa, grad_eps, grad_dimX, grad_dimY, grad_dim_Z
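A minimal runnable sketch of that corrected pattern, assuming PyTorch 0.4 or later (where Variable and Tensor are merged). TwoOutputOp and its math are hypothetical stand-ins for the elided "some ops". The point it illustrates: backward receives one gradient per forward output and must return one value per forward input, with None for non-tensor inputs such as dimX.

```python
import torch
from torch.autograd import Function


class TwoOutputOp(Function):
    """Hypothetical stand-in op: returns (x * scale, (x * scale).sum())."""

    @staticmethod
    def forward(ctx, x, scale):
        # scale is a plain Python float, so it is kept on ctx
        ctx.scale = scale
        out = x * scale
        return out, out.sum()

    @staticmethod
    def backward(ctx, grad_out, grad_sum):
        # One incoming gradient per forward output.
        # Both outputs depend on x with slope `scale` per element.
        grad_x = (grad_out + grad_sum) * ctx.scale
        # One return value per forward input; None for the float `scale`.
        return grad_x, None


x = torch.randn(3, requires_grad=True)
out, total = TwoOutputOp.apply(x, 2.0)
(out.sum() + total).backward()
print(x.grad)  # every entry is (1 + 1) * 2.0 = 4.0
```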

I tried your modification, but I still have the same problem: my gradients are None. It’s as if the backward pass is never called.

eps = torch.randn(self.dimX*self.dimY*self.dimZ, 1)
self.A, self.KL_variable = ICA_3D_MM.apply(self.mu.data, self.sigma.data, self.landa.data, eps, self.dimX, self.dimY, self.dimZ)

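You should pass the Variables themselves to the Function, not their .data. Using .data detaches the tensors from the autograd graph, so autograd does not record the call and backward is never reached. These lines should be: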

eps = Variable(torch.randn(self.dimX*self.dimY*self.dimZ, 1))
self.A, self.KL_variable = ICA_3D_MM.apply(self.mu, self.sigma, self.landa, eps, self.dimX, self.dimY, self.dimZ)
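A small sketch of why .data breaks things, assuming PyTorch 0.4 or later. Double is a hypothetical stand-in Function; the comparison shows that an output built from .data is not attached to the graph, while an output built from the Parameter itself is.

```python
import torch
from torch.autograd import Function


class Double(Function):
    """Hypothetical stand-in op: y = 2 * x."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y * 2


mu = torch.nn.Parameter(torch.ones(3))

# Passing .data: the input does not require grad, so the output is
# not part of the graph and Double.backward can never run.
detached = Double.apply(mu.data)
print(detached.requires_grad)  # False

# Passing the Parameter itself: the output is part of the graph.
attached = Double.apply(mu)
print(attached.requires_grad)  # True
attached.sum().backward()
print(mu.grad)  # all entries equal 2.0
```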

The problem is that when I do as you suggest, I get the following error:

 self.A, self.KL_variable = ICA_3D_MM.apply(self.mu, self.sigma, self.landa, eps, self.dimX, self.dimY, self.dimZ)
RuntimeError: save_for_backward can only save input or output tensors, but argument 3 doesn't satisfy this condition 

The only workaround I found was to use .data to avoid this.

I guess you use save_for_backward in your code; it should only be called on tensors that are inputs or outputs of the function.
Also, I assume that dimX, dimY and dimZ are scalars, right?

Yes, I use save_for_backward:

    @staticmethod
    def forward(ctx, mu, sigma, landa, eps, dimX, dimY, dimZ):
        # import pdb
        # pdb.set_trace()

        ctx.save_for_backward(mu, sigma, landa, eps, dimX, dimY, dimZ)

dimX, dimY and dimZ are integers. But as you can see, mu, sigma, landa and eps are inputs of my Function, so I should be allowed to save them for backward.

You cannot use it on dimX, dimY and dimZ, because they are not tensors.
If you need to keep non-Tensor/Variable values for backward, you can simply store them as attributes on the context:

ctx.dimX = dimX
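A sketch of that split, assuming PyTorch 0.4 or later: tensors go through save_for_backward, integers are stored on ctx, and backward returns None for the integer arguments. ScaledNoise and its formula A = mu + sigma * eps^T are hypothetical placeholders for the ICA operation. torch.autograd.gradcheck compares the hand-written backward against finite differences and needs double-precision inputs.

```python
import torch
from torch.autograd import Function, gradcheck


class ScaledNoise(Function):
    """Hypothetical stand-in op: A = mu + sigma * eps^T."""

    @staticmethod
    def forward(ctx, mu, sigma, eps, dimX, dimY, dimZ):
        # Tensors (here: inputs) go through save_for_backward ...
        ctx.save_for_backward(sigma, eps)
        # ... while plain Python ints are stored as attributes on ctx.
        ctx.dimX, ctx.dimY, ctx.dimZ = dimX, dimY, dimZ
        assert eps.shape == (dimX * dimY * dimZ, 1)
        # mu: (n_comp, n_voxels), sigma: (n_comp, 1), eps.t(): (1, n_voxels)
        return mu + sigma * eps.t()

    @staticmethod
    def backward(ctx, grad_A):
        sigma, eps = ctx.saved_tensors
        # The ints are available again through ctx.
        assert grad_A.shape[1] == ctx.dimX * ctx.dimY * ctx.dimZ
        grad_mu = grad_A
        # Sum over voxels -> (n_comp, 1)
        grad_sigma = (grad_A * eps.t()).sum(dim=1, keepdim=True)
        # Sum over components -> (1, n_voxels), transposed to (n_voxels, 1)
        grad_eps = (grad_A * sigma).sum(dim=0, keepdim=True).t()
        # No gradient for dimX, dimY, dimZ.
        return grad_mu, grad_sigma, grad_eps, None, None, None


dimX, dimY, dimZ = 2, 2, 1
n_comp, n_voxels = 3, dimX * dimY * dimZ
mu = torch.randn(n_comp, n_voxels, dtype=torch.double, requires_grad=True)
sigma = torch.randn(n_comp, 1, dtype=torch.double, requires_grad=True)
eps = torch.randn(n_voxels, 1, dtype=torch.double, requires_grad=True)
print(gradcheck(ScaledNoise.apply, (mu, sigma, eps, dimX, dimY, dimZ)))
```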


OK, I understand. Storing the integers on ctx and passing only Variables to save_for_backward solved the problem: it now enters the backward method.
Thank you very much for your help.
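Putting the fixes from this thread together at the module level, a minimal sketch assuming PyTorch 0.4 or later. ScaleRows, ScaledLayer and the A = mu * sigma math are hypothetical placeholders for the ICA operation. What matters is that the nn.Parameters are passed to apply directly, so their .grad is populated after loss.backward(). nn.Parameter already requires grad, so requires_grad=True is not needed.

```python
import torch
import torch.nn as nn
from torch.autograd import Function


class ScaleRows(Function):
    """Hypothetical stand-in op: A = mu * sigma, sigma broadcast over columns."""

    @staticmethod
    def forward(ctx, mu, sigma):
        ctx.save_for_backward(mu, sigma)  # both are inputs, so saving is fine
        return mu * sigma

    @staticmethod
    def backward(ctx, grad_A):
        mu, sigma = ctx.saved_tensors
        grad_mu = grad_A * sigma
        grad_sigma = (grad_A * mu).sum(dim=1, keepdim=True)
        return grad_mu, grad_sigma


class ScaledLayer(nn.Module):
    def __init__(self, n_comp, n_voxels):
        super().__init__()
        self.mu = nn.Parameter(torch.randn(n_comp, n_voxels))
        self.sigma = nn.Parameter(torch.ones(n_comp, 1))

    def forward(self, x):
        # Pass the Parameters themselves (not .data) so autograd records the call.
        A = ScaleRows.apply(self.mu, self.sigma)
        return torch.mm(x, A)


layer = ScaledLayer(n_comp=3, n_voxels=4)
loss = layer(torch.randn(5, 3)).sum()
loss.backward()
print(layer.mu.grad.shape, layer.sigma.grad.shape)  # torch.Size([3, 4]) torch.Size([3, 1])
```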