RuntimeError: element 0 of variables does not require grad and does not have a grad_fn when trying to implement Gumbel-ST sampler

I currently have the following code to implement a straight-through Gumbel softmax sampler:

import torch
import torch.nn.functional as F
from torch.autograd import Variable

def gumbel_ST_sample(logits, temp=1.0):
    # Sample Gumbel(0, 1) noise via -log(-log(U)); eps guards against log(0)
    eps = 1e-8
    noise = torch.rand(logits.size())
    gumbel_noise = Variable(-torch.log(-torch.log(noise + eps) + eps))

    # Soft sample: softmax over the noise-perturbed logits
    gumbel_sample = F.softmax((logits + gumbel_noise) / temp, dim=-1)

    # Hard sample: 1 at the argmax of the soft sample
    # (exact ties would yield multiple ones, but are negligible with continuous noise)
    one_hot_sample = (gumbel_sample == gumbel_sample.max(-1, keepdim=True)[0])

    # Straight-through estimator: the forward pass returns the one-hot sample,
    # while the backward pass routes gradients through the soft sample
    st = (one_hot_sample.float() - gumbel_sample).detach() + gumbel_sample

    return st

I'm trying to test it using the following code:

test_logits = Variable(torch.rand(4, 10))
sample = gumbel_ST_sample(test_logits)
test_embeds = Variable(torch.rand(10, 5))

out = torch.mm(sample, test_embeds).sum()

grad = torch.autograd.grad(out, test_embeds)

But this gives me the error in the title. Can anyone help?

Thanks
Kris

Stupid error:

test_embeds = Variable(torch.randn(10, 5), requires_grad=True)

solves it.
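
For completeness, the full corrected test looks like this (a minimal sketch; the embedding shape stays (10, 5) so the torch.mm call still matches):

    test_logits = Variable(torch.rand(4, 10))
    sample = gumbel_ST_sample(test_logits)
    test_embeds = Variable(torch.randn(10, 5), requires_grad=True)

    out = torch.mm(sample, test_embeds).sum()

    # Works now: test_embeds is a leaf Variable that requires grad,
    # so out has a grad_fn for autograd to walk back through
    grad = torch.autograd.grad(out, test_embeds)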


Hi! Thanks for sharing this. I have a question: gumbel_ST_sample returns the one-hot value in the forward pass, but how does the backward pass return the gradient of the non-one-hot function (i.e., gumbel_sample)?
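
For anyone else wondering: the line st = (one_hot_sample.float() - gumbel_sample).detach() + gumbel_sample is what does this. The detach() removes the hard/soft difference from the autograd graph, so the forward value equals the one-hot sample while the backward gradient flows entirely through gumbel_sample. Here is a minimal sketch that checks this behaviour in isolation (the variable names are illustrative, not from the original post):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    # Toy check of the straight-through trick, independent of the sampler
    x = Variable(torch.randn(3), requires_grad=True)
    soft = F.softmax(x, dim=-1)
    hard = (soft == soft.max()).float()
    st = (hard - soft).detach() + soft     # forward: hard; backward: soft

    w = Variable(torch.rand(3))            # arbitrary downstream weights
    (st * w).sum().backward()

    # For comparison: the same loss computed through the soft sample directly
    x2 = Variable(x.data.clone(), requires_grad=True)
    soft2 = F.softmax(x2, dim=-1)
    (soft2 * w).sum().backward()

    print(x.grad)    # identical to x2.grad
    print(x2.grad)

Both print calls produce the same gradient, confirming that the detached hard part contributes nothing to the backward pass.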