PyTorch code runs faster on the CPU than on the GPU

Because the update steps are computed iteratively, the following code runs faster on the CPU than on the GPU.
I also noticed that both the update step and inference run faster with NumPy than with PyTorch. Is there a way to optimize it?

import torch
device = torch.device('cpu')
import pdb

class MesNet(torch.nn.Module):
    """One double-precision linear layer mapping 6 input features to 5 outputs."""

    def __init__(self):
        super().__init__()
        # Kept inside a Sequential so the layer stays addressable as cov_lin[0].
        linear_stack = torch.nn.Sequential(torch.nn.Linear(6, 5))
        self.cov_lin = linear_stack.double()

    def forward(self, u):
        """Apply the linear layer to ``u`` after rearranging its axes.

        Axes 0 and 2 of ``u`` are swapped and a trailing size-1 axis is
        dropped, so a ``(1, 6, N)`` input becomes ``(N, 6)`` and the result
        has shape ``(N, 5)``.
        """
        rows = torch.squeeze(torch.transpose(u, 0, 2), -1)
        return self.cov_lin(rows)

# Iterative update model: run_KF() draws random inputs, obtains per-step
# values from MesNet, and calls update_pos() once per step.
# NOTE(review): this class is incomplete as written -- two assignments in
# update_pos() have no right-hand side, so the file does not parse.
class UpdateModel(torch.nn.Module):

    # NOTE(review): torch.nn.Module.__init__ is never called here (no
    # super().__init__()), so the module is not fully initialised -- confirm
    # whether this class needs to be a Module at all.
    def __init__(self):
        # Size used for the second dimension of H in update_pos().
        self.P_dim = 18
        # 3x3 double-precision identity, reused by trans().
        self.Id3 = torch.eye(3).double()
    # Runs N - 1 update steps on random data and evaluates a summed MSE loss.
    def run_KF(self):
        N = 10
        u = torch.randn(N, 6).double()
        # v starts with a single zero 3-vector.
        v = [torch.zeros(3).double()]
        # A fresh MesNet is constructed on every call.
        model = MesNet()
        # u is (N, 6); after t() and unsqueeze(0) the model input is (1, 6, N).
        measurements_covs_l = model(u.t().unsqueeze(0))
        # remember to remove this afterwards
        # NOTE(review): v_new is never appended to v, so v keeps length 1 and
        # v[i-1] raises IndexError once i reaches 2.
        for i in range(1, N):
            v_new = self.update_pos(v[i-1].detach(), measurements_covs_l[i-1])

        criterion = torch.nn.MSELoss(reduction="sum")
        # NOTE(review): the target row count is the literal 10 rather than N.
        targ = torch.rand(10, 3).double()
        # NOTE(review): v is a Python list here, not a tensor; it presumably
        # needs to be stacked before being passed to the loss -- confirm.
        loss = criterion(v, targ)
        # With reduction="sum" the loss is already a scalar when this runs.
        loss = torch.mean(loss)
        # NOTE(review): p is not defined anywhere in this method (NameError),
        # and loss is computed but not returned.
        return v, p

    # Single update step for v given one row of MesNet output.
    # NOTE(review): the expressions for dx and v_up are missing in this copy,
    # so the intended computation cannot be documented.
    def update_pos(self, v, measurement_cov):
        # NOTE(review): Omega is not used in the visible code.
        Omega = torch.eye(3).double() 
        # H is an all-ones (5, P_dim) matrix.
        H = torch.ones((5, self.P_dim)).double()
        # R is the diagonal matrix built from measurement_cov.
        R = torch.diag(measurement_cov)
        # Kt = H^T R^-1, shape (P_dim, 5).
        Kt = H.t().mm(torch.inverse(R))
        # The author reports an in-place (autograd) error even when the
        # inverse is dropped, i.e. with:
        # Kt = H.t().mm(R)
        # NOTE(review): right-hand side missing -- syntax error.
        dx =
        # The first nine entries of dx are passed to trans().
        dR = self.trans(dx[:9].clone())
        # NOTE(review): right-hand side missing -- syntax error.
        v_up =
        return v_up

    # Builds a 3x3 matrix from the first three entries of xi.
    def trans(self, xi):
        phi = xi[:3].clone()
        angle = torch.norm(phi.clone())

        # NOTE(review): Rot is assigned only inside this branch, so any input
        # with angle >= 1e-10 reaches `return Rot` with Rot unbound
        # (UnboundLocalError). The statements from `axis = ...` onwards look
        # like they belong to a missing `else:` branch -- confirm.
        if angle.abs().lt(1e-10):

            skew_phi = torch.eye(3).double()
            # NOTE(review): J is computed but not used.
            J = self.Id3 + 0.5 * skew_phi
            # This value is overwritten by the later assignment to Rot.
            Rot = self.Id3 + skew_phi
            # NOTE(review): angle is below 1e-10 here, so this divides by a
            # value at or near zero; axis, skew_axis and s are not used.
            axis = phi / angle
            skew_axis = torch.eye(3).double()
            s = torch.sin(angle)
            c = torch.cos(angle)

            Rot = c * self.Id3
        return Rot

# Instantiate the model at module level; no method is invoked on it here.
net =  UpdateModel()