Because the update steps are computed iteratively, the following code runs faster on the CPU.
I also noticed that the update step and inference run faster with NumPy than with PyTorch. Is there a way to optimize this?
import torch
device = torch.device('cpu')
import pdb
class MesNet(torch.nn.Module):
    """Small network producing 5 outputs per sample from 6 input features.

    The single linear layer is wrapped in a ``Sequential`` and cast to
    double precision.
    """

    def __init__(self):
        super().__init__()
        layers = torch.nn.Sequential(torch.nn.Linear(6, 5))
        self.cov_lin = layers.double()

    def forward(self, u):
        """Apply the linear layer to ``u`` after rearranging its axes.

        Axes 0 and 2 of ``u`` are swapped and a trailing singleton axis
        (if present) is dropped before the layer is applied. The caller
        in this file passes a ``(1, 6, N)`` tensor, which becomes
        ``(N, 6)`` and yields an ``(N, 5)`` result.
        """
        features = torch.transpose(u, 0, 2).squeeze(dim=-1)
        return self.cov_lin(features)
class UpdateModel(torch.nn.Module):
    """Iterative state-update loop driven by learned measurement covariances.

    ``run_KF`` builds random inputs, obtains one covariance vector per step
    from ``MesNet``, propagates a 3-vector through ``update_pos`` step by
    step, and back-propagates a summed squared error against random targets.
    """

    def __init__(self):
        super().__init__()
        self.P_dim = 18
        # Plain attribute (not a registered buffer), so it is absent from
        # ``state_dict`` exactly as before.
        self.Id3 = torch.eye(3, dtype=torch.float64)

    def run_KF(self):
        """Run one forward/backward pass over ``N = 10`` steps.

        Returns:
            A tuple ``(v, loss)`` where ``v`` is the ``(N, 3)`` tensor of
            stacked per-step states and ``loss`` is the scalar summed
            squared error whose gradients have already been back-propagated.

        Note:
            Anomaly detection and the ``pdb`` breakpoint were debugging aids
            and are no longer enabled here; anomaly mode in particular slows
            autograd noticeably. Wrap the call in
            ``torch.autograd.detect_anomaly()`` when debugging is needed.
        """
        N = 10
        u = torch.randn(N, 6, dtype=torch.float64)
        model = MesNet()
        measurements_covs_l = model(u.t().unsqueeze(0))

        states = [torch.zeros(3, dtype=torch.float64)]
        for i in range(1, N):
            # detach() truncates back-propagation through earlier steps;
            # gradients still reach the network via the covariance of step i-1.
            states.append(
                self.update_pos(states[i - 1].detach(), measurements_covs_l[i - 1])
            )

        # MSELoss needs a tensor, not a Python list of tensors.
        v = torch.stack(states)
        targ = torch.rand(N, 3, dtype=torch.float64)
        criterion = torch.nn.MSELoss(reduction="sum")
        loss = criterion(v, targ)  # already a scalar; no further reduction
        loss.backward()
        return v, loss

    def update_pos(self, v, measurement_cov):
        """Apply one update to ``v`` using a diagonal measurement covariance.

        Args:
            v: 3-element double tensor holding the current state.
            measurement_cov: 5-element double tensor with the diagonal of
                the measurement covariance.

        Returns:
            The updated 3-element tensor.

        Note:
            The covariance matrix is diagonal, so its inverse is applied as
            an element-wise reciprocal instead of building the matrix and
            calling ``torch.inverse``. A zero entry therefore produces
            inf/NaN values rather than raising a singular-matrix error.
        """
        H = torch.ones((5, self.P_dim), dtype=torch.float64)
        # H^T @ diag(1 / cov) is a column-wise scaling of H^T.
        Kt = H.t() * measurement_cov.reciprocal()
        dx = Kt.mv(torch.ones(5, dtype=torch.float64))
        dR = self.trans(dx[:9])
        return dR.mv(v)

    def trans(self, xi):
        """Build a 3x3 matrix from the first three entries of ``xi``.

        Returns ``2 * I`` when the norm of ``xi[:3]`` is below ``1e-10`` and
        ``cos(norm) * I`` otherwise.

        NOTE(review): both branches look like placeholders for a full
        rotation-vector-to-matrix conversion (an identity matrix stands in
        where a skew-symmetric matrix would be expected) - confirm the
        intended formula before relying on the result as a rotation.
        """
        angle = torch.norm(xi[:3])
        if angle.abs().lt(1e-10):
            return self.Id3 + torch.eye(3, dtype=torch.float64)
        return torch.cos(angle) * self.Id3
# Script driver: build the update model and call run_KF once.
net = UpdateModel()
net.run_KF()