Hi. I want to implement a Runge–Kutta discretization for my loss. To do so, I have to call the same layers four times within a single forward pass of an nn.Module. My question is whether the following code behaves correctly with respect to the backward pass of the loss function. If I just return the output of `function`, everything seems fine; however, when I add the `_rk4_step`, the results appear to be incorrect.
def forward(self, x, time_step=0.01):
    """Advance the state one RK4 step.

    Args:
        x: state tensor of shape (batch, 2*n), assumed to hold
           [q, qd] concatenated along dim 1 — TODO confirm against
           the layout `function` expects.
        time_step: integration step size. Defaults to 0.01, the
           value that was previously hard-coded here.

    Returns:
        Tensor of the same shape as ``x``: the state after one step.
    """
    # Gradients must be enabled even inside eval()/no_grad contexts,
    # because `function` differentiates the Lagrangian w.r.t. its input.
    with torch.set_grad_enabled(True):
        # NOTE(review): requires_grad_ mutates the caller's tensor in
        # place; pass a detached clone upstream if that is unwanted.
        qqd = x.requires_grad_(True)
        return self._rk4_step(qqd, time_step)
def function(self, qqd):
    """Time derivative of the state under the Euler-Lagrange equations.

    Given the state ``qqd = [q, qd]`` of shape (batch, 2*n), solves
    ``qdd = pinv(d2L/dqd2) @ (dL/dq - d2L/(dq dqd) @ qd)`` and returns
    ``d/dt [q, qd] = [qd, qdd]`` with the same shape as the input.

    Requires ``qqd.requires_grad`` to be True.
    """
    self.n = n = qqd.shape[1] // 2
    # Sum the scalar Lagrangian over the batch so one backward pass
    # yields per-sample gradients.
    L = self._lagrangian(qqd).sum()
    J = grad(L, qqd, create_graph=True)[0]
    DL_q, DL_qd = J[:, :n], J[:, n:]
    # Build the Hessian blocks row by row: one extra gradient pass per
    # velocity coordinate (create_graph keeps them differentiable).
    DDL_qd = []
    for i in range(n):
        J_qd_i = DL_qd[:, i][:, None]
        H_i = grad(J_qd_i.sum(), qqd, create_graph=True)[0][:, :, None]
        DDL_qd.append(H_i)
    DDL_qd = torch.cat(DDL_qd, 2)
    DDL_qqd, DDL_qdqd = DDL_qd[:, :n, :], DDL_qd[:, n:, :]
    # T = d2L/(dq dqd) @ qd, batched over the first index.
    T = torch.einsum('ijk, ij -> ik', DDL_qqd, qqd[:, n:])
    # pinverse is used instead of inverse so a singular mass matrix
    # does not crash the solve.
    qdd = torch.einsum('ijk, ij -> ik', DDL_qdqd.pinverse(), DL_q - T)
    # Generalized from the hard-coded n == 2 form
    # torch.stack((qqd[:,2], qqd[:,3], qdd[:,0], qdd[:,1]), 1):
    # identical result for n == 2, correct for any n.
    return torch.cat((qqd[:, n:], qdd), 1)
def _lagrangian(self, qqd):
x = F.softplus(self.fc1(qqd))
x = F.softplus(self.fc2(x))
x = F.softplus(self.fc3(x))
L = self.fc_last(x)
return L
def _rk4_step(self, qqd, h=None):
# one step of Runge-Kutta integration
k1 = h * self.function(qqd)
k2 = h * self.function(qqd + k1/2)
k3 = h * self.function(qqd + k2/2)
k4 = h *self.function(qqd + k3)
return qqd + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4)