I’m trying to replicate TextGAN in PyTorch, and I’m new to PyTorch. Right now I’m working on the generator loss L_G (eq. 7, page 3), and here’s my current code:
import torch
from torch import autograd

def JSDLoss(batch_size, f_real, f_synt):
    f_num_features = f_real.data.numpy().shape[1]
    # small multiple of the identity, added to each covariance before inverting
    identity = autograd.Variable(torch.eye(f_num_features)*0.1)
    # per-feature means over the batch
    f_real_mean = torch.mean(f_real, 0, keepdim=True)
    f_synt_mean = torch.mean(f_synt, 0, keepdim=True)
    # deviations from the mean
    dev_f_real = f_real - f_real_mean.expand(batch_size,f_num_features)
    dev_f_synt = f_synt - f_synt_mean.expand(batch_size,f_num_features)
    f_real_xx = torch.mm(torch.t(dev_f_real), dev_f_real)
    f_synt_xx = torch.mm(torch.t(dev_f_synt), dev_f_synt)
    # covariance matrices and their inverses
    cov_mat_f_real = (f_real_xx / batch_size) + identity
    cov_mat_f_synt = (f_synt_xx / batch_size) + identity
    cov_mat_f_real_inv = torch.inverse(cov_mat_f_real)
    cov_mat_f_synt_inv = torch.inverse(cov_mat_f_synt)
    # trace term plus quadratic term in the difference of means
    loss_g = torch.trace(torch.mm(cov_mat_f_synt_inv, cov_mat_f_real) + torch.mm(cov_mat_f_real_inv, cov_mat_f_synt)) + torch.mm(torch.mm((f_synt_mean - f_real_mean), (cov_mat_f_synt_inv + cov_mat_f_real_inv)), torch.t(f_synt_mean - f_real_mean))
    return loss_g
I got an error like this:
RuntimeError: matrices expected, got 1D, 2D tensors at /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1288
Does anyone know what to do? Any help is really appreciated. Thanks in advance!
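The message says that torch.mm received a 1D tensor as its first argument, and torch.mm only accepts 2D matrices. One possible cause, which the traceback alone does not confirm, is that the mean vectors come out 1D on the installed PyTorch version, so the quadratic term at the end fails. Below is a minimal sketch of the same computation that forces every operand to be 2D. It assumes a recent PyTorch release (plain tensors instead of autograd.Variable, and the @ operator), and the names jsd_loss_sketch and eps are made up for illustration, not taken from TextGAN.

```python
import torch

def jsd_loss_sketch(f_real, f_synt, eps=0.1):
    # Assumption: f_real and f_synt are 2D tensors of shape
    # (batch_size, num_features) with the same batch size.
    batch_size, num_features = f_real.shape
    identity = eps * torch.eye(num_features, dtype=f_real.dtype, device=f_real.device)

    # keepdim=True keeps the means 2D, shape (1, num_features)
    mu_real = f_real.mean(dim=0, keepdim=True)
    mu_synt = f_synt.mean(dim=0, keepdim=True)

    # broadcasting subtracts the mean row from every sample
    dev_real = f_real - mu_real
    dev_synt = f_synt - mu_synt

    # regularized covariance matrices, shape (num_features, num_features)
    cov_real = dev_real.t() @ dev_real / batch_size + identity
    cov_synt = dev_synt.t() @ dev_synt / batch_size + identity
    cov_real_inv = torch.inverse(cov_real)
    cov_synt_inv = torch.inverse(cov_synt)

    # trace term of the loss
    trace_term = torch.trace(cov_synt_inv @ cov_real + cov_real_inv @ cov_synt)

    # reshape guarantees a (1, num_features) row vector even if a mean were 1D
    diff = (mu_synt - mu_real).reshape(1, -1)
    quad_term = (diff @ (cov_synt_inv + cov_real_inv) @ diff.t()).squeeze()

    return trace_term + quad_term
```

As a quick sanity check, passing the same features for both arguments should give a value close to 2 * num_features, since the trace term reduces to the trace of twice the identity and the quadratic term vanishes. Printing the shapes of f_real_mean and f_synt_mean in the original function is a cheap way to check whether this is really the cause.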