Faster way to compute Reparamterization tirck

I am using the following code to compute reparameterization trick. But it is very slow. Can anyone suggest a better and faster way to do it?

def reparameterize(mux, muy, sx, sy, corr, nodesPresent):
    # pdb.set_trace()

    o_mux, o_muy, o_sx, o_sy, o_corr = mux, muy, sx, sy, corr


    numNodes = mux.size()[1]
    seq_len = mux.size()[0]
    # pdb.set_trace()

    next_values = Variable(torch.zeros(seq_len, numNodes, 2)).cuda()
    # next_y = Variable(torch.zeros(numNodes))


    for i in range(seq_len):
        for node in range(numNodes):
            mean = torch.cat((o_mux[i, node], o_muy[i, node]))
            sigma_xx = o_sx[i, node]*o_sx[i, node]
            sigma_yy = o_sy[i, node]*o_sy[i, node] 
            sigma_xy = o_corr[i, node]*o_sx[i, node]*o_sy[i, node]
            sigma_yx = o_corr[i, node]*o_sy[i, node]*o_sx[i, node]

            cov =  torch.stack((sigma_xx, sigma_xy, sigma_yx, sigma_yy)).view(2,2)

        
            # mean = Variable(torch.FloatTensor([o_mux[i, node].data[0], o_muy[i, node].data[0]]))
            # cov = [[o_sx[i, node].data[0]*o_sx[i, node].data[0], o_corr[i, node].data[0]*o_sx[i, node].data[0]*o_sy[i, node].data[0]],
            # [o_corr[i, node].data[0]*o_sx[i, node].data[0]*o_sy[i, node].data[0], o_sy[i, node].data[0]*o_sy[i, node].data[0]]]
            # cov = Variable(torch.FloatTensor(cov))
        
            eps = Variable(torch.randn(mean.size()[0])).cuda()

            next_values[i, node, :] = mean[:, None] + torch.mm(cov,eps[:, None])

    return next_values

Hi,

The obvious thing you can do to speed things up is to remove the two for loops and make all operations work directly on tensors:

for i in range(seq_len):
        for node in range(numNodes):
            mean = torch.cat((o_mux[i, node], o_muy[i, node]))
            sigma_xx = o_sx[i, node]*o_sx[i, node]
            sigma_yy = o_sy[i, node]*o_sy[i, node] 
            sigma_xy = o_corr[i, node]*o_sx[i, node]*o_sy[i, node]
            sigma_yx = o_corr[i, node]*o_sy[i, node]*o_sx[i, node]

Can just become

mean = torch.cat((o_mux, o_muy))
sigma_xx = o_sx *o_sx
sigma_yy = o_sy * o_sy 
sigma_xy = o_corr * o_sx * o_sy
sigma_yx = o_corr * o_sy * o_sx