Will the backward work for multiple neural network calls in the same loss?

Hi. I want to implement a Runge-Kutta discretization for my loss. To do so, I have to call the same layers four times within a single forward pass of an nn.Module. My question is whether the following code works when it comes to the backward pass of the loss function. If I just return the output of function, everything seems all right. However, when I add the rk4_step, the results seem incorrect.

# Imports assumed by the methods below.
import torch
import torch.nn.functional as F
from torch.autograd import grad

def forward(self, x):
    with torch.set_grad_enabled(True):
        qqd = x.requires_grad_(True)
        time_step = 0.01
        out = self._rk4_step(qqd, time_step)
        return out

def function(self, qqd):
    # The first n columns of qqd are treated as q, the last n as qd.
    self.n = n = qqd.shape[1] // 2
    L = self._lagrangian(qqd).sum()
    # First derivatives of the Lagrangian with respect to (q, qd).
    J = grad(L, qqd, create_graph=True)[0]
    DL_q, DL_qd = J[:, :n], J[:, n:]
    # Second derivatives, built one component of dL/dqd at a time.
    DDL_qd = []
    for i in range(n):
        J_qd_i = DL_qd[:, i][:, None]
        H_i = grad(J_qd_i.sum(), qqd, create_graph=True)[0][:, :, None]
        DDL_qd.append(H_i)
    DDL_qd = torch.cat(DDL_qd, 2)
    DDL_qqd, DDL_qdqd = DDL_qd[:, :n, :], DDL_qd[:, n:, :]
    T = torch.einsum('ijk, ij -> ik', DDL_qqd, qqd[:, n:])
    qdd = torch.einsum('ijk, ij -> ik', DDL_qdqd.pinverse(), DL_q - T)

    # Note: the indices below hard-code n == 2.
    return torch.stack((qqd[:, 2], qqd[:, 3], qdd[:, 0], qdd[:, 1]), 1)

def _lagrangian(self, qqd):
    x = F.softplus(self.fc1(qqd))
    x = F.softplus(self.fc2(x))
    x = F.softplus(self.fc3(x))
    L = self.fc_last(x)
    return L

def _rk4_step(self, qqd, h):
    # One step of Runge-Kutta integration with step size h.
    k1 = h * self.function(qqd)
    k2 = h * self.function(qqd + k1 / 2)
    k3 = h * self.function(qqd + k2 / 2)
    k4 = h * self.function(qqd + k3)
    return qqd + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

Hi,

If you use the same layer, with its parameters, multiple times, the gradient for these weights will be the sum of the gradients for each use of it (as you would expect mathematically).
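As a small check of that statement, here is a minimal sketch (the layer sizes and the names `first` and `second` are made up for illustration, they are not from the code above). It uses one layer twice, then compares its gradient with the sum of gradients from two separate layers holding identical weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 3)
x = torch.randn(4, 3)

# Use the same layer twice in one forward pass, then backpropagate once.
layer(layer(x)).sum().backward()
grad_shared = layer.weight.grad.clone()

# Reference: two separate layers with identical weights, one per call site.
first, second = nn.Linear(3, 3), nn.Linear(3, 3)
first.load_state_dict(layer.state_dict())
second.load_state_dict(layer.state_dict())
second(first(x)).sum().backward()
grad_sum = first.weight.grad + second.weight.grad

# The shared-layer gradient should match the sum over both uses.
print(torch.allclose(grad_shared, grad_sum, atol=1e-6))
```

The same reasoning should carry over to the four calls inside an RK4 step: each call contributes its own term to the parameter gradients.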

Your function looks OK from the outside. What “seems incorrect”? Do you have a sample input and expected output that shows the issue?

Hi,

Thank you for your fast answer. I could send snippets of the code where I’m creating the dataset, but that’s too large. For instance, though, I get the same test loss independently of the number of training points: using 100 or 1000 gives almost exactly the same test loss, and it always converges after the first epoch. I also noticed when printing the gradients that the gradient of the last layer’s bias, fc_last.bias, is always None, which I don’t think is correct.

You might want to check that you do get proper gradients for all the parameters (for a single step).
And make sure your model is not getting stuck in a bad local minimum as well.
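One way to run that single-step check is sketched below. The helper name `params_without_grad` and the toy model are hypothetical, chosen only for illustration. It lists every parameter whose `.grad` is still None or entirely zero after one backward pass:

```python
import torch
import torch.nn as nn

def params_without_grad(model):
    """Return names of parameters whose .grad is None or all zeros."""
    missing = []
    for name, p in model.named_parameters():
        if p.grad is None or not bool((p.grad != 0).any()):
            missing.append(name)
    return missing

torch.manual_seed(0)
# Toy model standing in for the real one (an assumption for this sketch).
model = nn.Sequential(nn.Linear(4, 8), nn.Softplus(), nn.Linear(8, 1))
model(torch.randn(16, 4)).pow(2).mean().backward()

# An empty list means every parameter received a nonzero gradient.
print(params_without_grad(model))
```

Running it right after `loss.backward()` on the real model would show which parameters the loss does not reach.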

What could be the reason for getting a None gradient just for the biases of the last layer? I’ve also checked that, but I can’t find an explanation.

None in autograd can mean that the output is independent of that parameter or that the gradient is just all 0s.
So in this case, it is either because this bias is not used or because its contribution is completely ignored.
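One possible explanation for the "independent" case, offered as a sketch rather than a diagnosis of the code above: if the loss only uses derivatives of the network output with respect to its input, an additive bias in the final layer is a constant that drops out on differentiation. The class `TinyLagrangian` below is a hypothetical, smaller stand-in for the model in the question:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

class TinyLagrangian(nn.Module):  # hypothetical stand-in, not the original model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc_last = nn.Linear(8, 1)

    def forward(self, qqd):
        return self.fc_last(F.softplus(self.fc1(qqd)))

torch.manual_seed(0)
model = TinyLagrangian()
qqd = torch.randn(5, 4, requires_grad=True)

# The loss depends only on dL/d(qqd), never on the value of L itself.
L = model(qqd).sum()
J = grad(L, qqd, create_graph=True)[0]
J.pow(2).sum().backward()

# Print, for each parameter, whether its gradient is None.
for name, p in model.named_parameters():
    print(name, p.grad is None)
```

In this sketch only `fc_last.bias` ends up with a None gradient, because the derivative of the output with respect to the input does not involve that bias at all.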

Also, to avoid all issues, make sure that all your computations only deal with Tensors and use PyTorch’s ops. Objects that are not Tensors cannot have gradients computed for them or flowing through them.

Just to make sure: is this operation differentiable in PyTorch? It is the only one I have serious doubts about.

Yes, this is just indexing. Nothing wrong with that.
If you want to be extra sure, you can check that if the input has requires_grad=True, then the output also has requires_grad=True.
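A minimal sketch of that check, using an indexing-and-stack pattern similar to the return statement in the question (the tensor shapes and the factor 2.0 are arbitrary choices for illustration):

```python
import torch

x = torch.randn(3, 4, requires_grad=True)

# Indexing and torch.stack are ordinary differentiable ops.
out = torch.stack((x[:, 2], x[:, 3], 2.0 * x[:, 0], 2.0 * x[:, 1]), 1)
print(out.requires_grad, out.grad_fn is not None)

# Gradients flow back to the indexed columns of x.
out.sum().backward()
print(x.grad[0])

# By contrast, a round trip through a plain Python list cuts the graph.
broken = torch.tensor(x[:, 0].tolist())
print(broken.requires_grad)
```

If the first check ever reports False for a tensor that should depend on the input, some step before it has left the autograd graph.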

But if there were an error in computing the gradients, wouldn’t the gradients for all layers be None, and not only the gradient for the last layer’s bias tensor? Honestly, I can’t find any reason for it to be None, except the independence and always-zero cases.

I’m not sure I understand your question here.

Yeah, you are right. It was not well explained, sorry. So if I had a non-differentiable op somewhere, wouldn’t the gradients before it all be None, since it is a continuous chain? In my case I get None only for the gradient of the bias of the last layer.

Oh, in this case yes.
I was just mentioning it as a general thing to check when you see this kind of issue.
