RuntimeError: element 0 of variables does not require grad and does not have a grad_fn when trying to implement Gumbel-ST sampler

I currently have the following code to implement a straight-through Gumbel softmax sampler:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

def gumbel_ST_sample(logits, temp=1.0):
    eps = 1e-8
    # Gumbel(0, 1) noise via the inverse-CDF trick: -log(-log(U))
    noise = torch.rand(logits.size())
    gumbel_noise = Variable(-torch.log(-torch.log(noise + eps) + eps))

    # Soft (relaxed) sample
    gumbel_sample = F.softmax((logits + gumbel_noise) / temp, dim=-1)

    # Hard one-hot sample from the argmax
    one_hot_sample = (gumbel_sample == gumbel_sample.max(1, keepdim=True)[0]).float()

    # Straight-through: the forward pass returns the one-hot sample,
    # but only gumbel_sample stays attached to the autograd graph
    st = (one_hot_sample - gumbel_sample).detach() + gumbel_sample

    return st

Iā€™m trying to test it using the following code:

test_logits = Variable(torch.rand(4, 10))
sample = gumbel_ST_sample(test_logits)
test_embeds = Variable(torch.rand(10, 5))

out =, test_embeds).sum()

grad = torch.autograd.grad(out, test_embeds)

But this gives me the error in the title. Can anyone help?


Stupid error:

test_embeds = Variable(torch.randn(10, 5), requires_grad=True)

solves it.


Hi! Thanks for sharing. I have a question: it seems that gumbel_ST_sample returns the one-hot value in the forward pass, but how does the backward pass return the gradient of the non-one-hot function (i.e., gumbel_sample)?