I am using the following code to compute reparameterization trick. But it is very slow. Can anyone suggest a better and faster way to do it?
def reparameterize(mux, muy, sx, sy, corr, nodesPresent):
# pdb.set_trace()
o_mux, o_muy, o_sx, o_sy, o_corr = mux, muy, sx, sy, corr
numNodes = mux.size()[1]
seq_len = mux.size()[0]
# pdb.set_trace()
next_values = Variable(torch.zeros(seq_len, numNodes, 2)).cuda()
# next_y = Variable(torch.zeros(numNodes))
for i in range(seq_len):
for node in range(numNodes):
mean = torch.cat((o_mux[i, node], o_muy[i, node]))
sigma_xx = o_sx[i, node]*o_sx[i, node]
sigma_yy = o_sy[i, node]*o_sy[i, node]
sigma_xy = o_corr[i, node]*o_sx[i, node]*o_sy[i, node]
sigma_yx = o_corr[i, node]*o_sy[i, node]*o_sx[i, node]
cov = torch.stack((sigma_xx, sigma_xy, sigma_yx, sigma_yy)).view(2,2)
# mean = Variable(torch.FloatTensor([o_mux[i, node].data[0], o_muy[i, node].data[0]]))
# cov = [[o_sx[i, node].data[0]*o_sx[i, node].data[0], o_corr[i, node].data[0]*o_sx[i, node].data[0]*o_sy[i, node].data[0]],
# [o_corr[i, node].data[0]*o_sx[i, node].data[0]*o_sy[i, node].data[0], o_sy[i, node].data[0]*o_sy[i, node].data[0]]]
# cov = Variable(torch.FloatTensor(cov))
eps = Variable(torch.randn(mean.size()[0])).cuda()
next_values[i, node, :] = mean[:, None] + torch.mm(cov,eps[:, None])
return next_values