Custom loss function yields poor results

Hi, basically I’m creating a custom loss function in the forward method of an nn.Module. However, when I add the function _rk4_step the results are absurd; without it they are fine. For instance, if I use a single training point I get essentially the same test loss as with 100 or 400 training points, which doesn’t make much sense. So I think the problem may be related to the loss.backward() call in the training loop not computing the gradients properly. Is everything in _rk4_step compatible with backward?

# assumes: import torch, import torch.nn.functional as F, from torch.autograd import grad
def forward(self, x):
    with torch.set_grad_enabled(True):
        time_step = torch.tensor(0.01)
        out = self._rk4_step(self.function, x, 0, time_step)
    return out
def function(self, x, t):
    # ODE right-hand side: given the state [q, qd], return [qd, qdd], where qdd
    # comes from the learned Lagrangian via the Euler-Lagrange equations:
    # qdd = (d2L/dqd dqd)^-1 (dL/dq - (d2L/dq dqd) qd)
    self.n = n = x.shape[1] // 2
    qqd = x.requires_grad_(True)
    L = self._lagrangian(qqd).sum()
    J = grad(L, qqd, create_graph=True)[0]
    DL_q, DL_qd = J[:, :n], J[:, n:]
    DDL_qd = []
    for i in range(n):
        J_qd_i = DL_qd[:, i][:, None]
        H_i = grad(J_qd_i.sum(), qqd, create_graph=True)[0][:, :, None]
        DDL_qd.append(H_i)
    DDL_qd = torch.cat(DDL_qd, 2)
    DDL_qqd, DDL_qdqd = DDL_qd[:, :n, :], DDL_qd[:, n:, :]
    T = torch.einsum('ijk, ij -> ik', DDL_qqd, qqd[:, n:])
    qdd = torch.einsum('ijk, ij -> ik', DDL_qdqd.inverse(), DL_q - T)

    return torch.cat([qqd[:, self.n:], qdd], 1)
        
def _lagrangian(self, qqd):
    # scalar Lagrangian parameterized by a small softplus MLP
    x = F.softplus(self.fc1(qqd))
    x = F.softplus(self.fc2(x))
    # x = F.softplus(self.fc3(x))
    L = self.fc_last(x)
    return L
def _rk4_step(self, f, x, t, h):
    # one step of classical fourth-order Runge-Kutta integration
    k1 = torch.mul(f(x, t), h)
    k2 = torch.mul(f(x + k1 / 2, t + h / 2), h)
    k3 = torch.mul(f(x + k2 / 2, t + h / 2), h)
    k4 = torch.mul(f(x + k3, t + h), h)
    return x + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
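
For reference, a generic RK4 step built only from differentiable tensor ops should propagate gradients. Here is a minimal, self-contained sketch (independent of my class; toy_f is just a stand-in for self.function) that checks this with torch.autograd.gradcheck in double precision:

import torch

def rk4_step(f, x, t, h):
    # classical RK4 step; every operation here is differentiable
    k1 = h * f(x, t)
    k2 = h * f(x + k1 / 2, t + h / 2)
    k3 = h * f(x + k2 / 2, t + h / 2)
    k4 = h * f(x + k3, t + h)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) / 6

def toy_f(x, t):
    # hypothetical smooth dynamics, only used for the gradient check
    return torch.sin(x) + t

x0 = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
h = torch.tensor(0.01, dtype=torch.double)

# gradcheck compares analytic and numerical Jacobians; it needs double precision
print(torch.autograd.gradcheck(lambda x: rk4_step(toy_f, x, 0.0, h), (x0,)))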

Just a quick update: I realized that fc_last.bias is always None, which makes me think there is indeed an error.
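
(As far as I know, nn.Linear only reports .bias as None when it was constructed with bias=False, so that on its own may not be the bug.) To see whether gradients actually reach the parameters, I could print every parameter's gradient right after loss.backward(); any parameter whose .grad is still None is not being reached by backward. A self-contained sketch with a placeholder model, not my actual network:

import torch
import torch.nn as nn

# toy stand-in for my network (placeholder layers, not my actual model)
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Softplus(),
    nn.Linear(8, 1, bias=False),  # bias=False is why .bias shows up as None
)

x = torch.randn(16, 4)
loss = model(x).pow(2).mean()  # dummy scalar loss
loss.backward()

# after backward, any parameter whose .grad is still None never received a gradient
for name, param in model.named_parameters():
    print(name, None if param.grad is None else param.grad.norm().item())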