eliabruni
(Elia Bruni)
1
Hi,
Given the soft implementation of Gumbel Softmax:
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def sample_gumbel(input):
    # Gumbel(0, 1) noise: -log(-log(U + eps) + eps), with U ~ Uniform(0, 1)
    noise = torch.rand(input.size())
    eps = 1e-20
    noise.add_(eps).log_().neg_()
    noise.add_(eps).log_().neg_()
    return Variable(noise)

def gumbel_softmax_sample(input):
    temperature = 1
    noise = sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.log_softmax(x)   # log of a soft (relaxed) sample
    return x.view_as(input)
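For reference, a call might look like this (just a sketch; I'm assuming input holds logits, i.e. unnormalized log-probabilities, over the last dimension, and the names are only illustrative):
logits = Variable(torch.randn(4, 10))
y = gumbel_softmax_sample(logits)   # log of a relaxed sample, same shape as logits
print(y.exp().sum(1))               # each row sums to 1 once exponentiated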
I’d like to implement the hard version, as here. Any idea how to do that? Maybe this example in torch helps?
4 Likes
I did stop_gradient(x) in PyTorch as Variable(x.data). However, I’m looking for a better alternative.
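For context, a minimal sketch of that workaround (variable names are illustrative only):
x = Variable(torch.randn(3, 5), requires_grad=True)
x_sg = Variable(x.data)   # same underlying data, but the graph is cut: no gradient flows back to x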
apaszke
(Adam Paszke)
3
Wouldn’t .detach() solve this? It will return a new Variable that doesn’t require gradient, and won’t propagate anything in the backward phase.
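A minimal sketch of .detach() acting as a stop-gradient (names are illustrative only):
x = Variable(torch.randn(3, 5), requires_grad=True)
y = x * 2
out = (y.detach() * x).sum()   # y.detach() is treated as a constant in the backward pass
out.backward()
print(x.grad)                  # equals 2 * x; nothing was propagated through the detached y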
5 Likes
eliabruni
(Elia Bruni)
4
Thank you both, I’ll try it out!
I am also interested in implementing ST Gumbel-Softmax with stop gradient, so I wrote the script below.
def gumbel_softmax_sample(self, input, hard=True):
    temperature = 1
    noise = self.sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.softmax(x)   # soft (relaxed) sample
    if hard:
        # x_hard is 1 at the row-wise maxima, 0 elsewhere; keepdim so expand_as works
        max_val, _ = torch.max(x, x.dim() - 1, keepdim=True)
        x_hard = x == max_val.expand_as(x)
        # straight-through: forward value is x_hard, gradient is that of the soft x
        tmp = (x_hard.float() - x)
        tmp2 = tmp.clone()
        tmp2.detach_()
        x = tmp2 + x
    return x.view_as(input)
But I also tried to modify it so that only one ‘1’ exists in each row of x_hard, as below.
def gumbel_softmax_sample(self, input, hard=True):
    temperature = 1
    noise = self.sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.softmax(x)   # soft (relaxed) sample
    if hard:
        # scatter a single 1 per row at the argmax, so ties cannot produce multiple ones
        _, max_inx = torch.max(x, x.dim() - 1, keepdim=True)
        x_hard = torch.cuda.FloatTensor(x.size()).zero_().scatter_(x.dim() - 1, max_inx.data, 1.0)
        # straight-through: forward value is x_hard, gradient is that of the soft x
        x2 = x.clone()
        tmp = Variable(x_hard - x2.data)
        tmp.detach_()
        x = tmp + x
    return x.view_as(input)
These two scripts work without error, but I am not sure they are equivalent to the original ST Gumbel-Softmax.
I would greatly appreciate it if someone could review them.
Could anyone help me confirm that these two scripts are equivalent to the original TensorFlow implementation?
@Seungyoung_Park Have you tried A/B testing your scripts with the TF implementation to see if the outputs are any different?
Because I do not know how to make the parameters the same as those of the original, I have not compared them yet.
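One way to sidestep the parameter-matching problem (just a sketch, and I'm assuming the first script above is exposed as a plain function gumbel_softmax_sample(logits, hard=True) rather than a method) is to check the two properties that define ST Gumbel-Softmax directly: the forward value should be (numerically) one-hot, and gradients should still flow through the soft relaxation:
logits = Variable(torch.randn(4, 10), requires_grad=True)
y = gumbel_softmax_sample(logits, hard=True)
print(y.data.sum(1))      # ~1 per row
print(y.data.max(1)[0])   # ~1 per row: one entry carries all the mass (up to rounding)
w = Variable(torch.randn(4, 10))
(y * w).sum().backward()
print(logits.grad.abs().sum())   # non-zero: the gradient comes from the soft softmax, not the argmax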
Qi_Zhang
(Qi Zhang)
9
Hey, have you verified the correctness of your code?
Two things:
- I don’t see the difference between .detach_() and .detach().
- It seems like it’s also correct even if you don’t use tmp.detach_(), because tmp was created as a new Variable, as @Ilya_Kostrikov suggested.
hughperkins
(Hugh Perkins)
10
From https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349#file-gumbel-softmax-py-L27
y = tf.stop_gradient(y_hard - y) + y
Whoa, that’s so clever. I had to stare at that for ages before finally figuring it out. So cool. So, the result of this is:
- y is pure one-hot, in terms of value (since we add the soft y, and then subtract it again)
- the gradients are those of soft y (since all the other terms in this expression have their gradient stripped)
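In PyTorch the same trick can be written with .detach() (a sketch; y_hard is the one-hot tensor and y the soft sample from the scripts above):
y = (y_hard - y).detach() + y   # forward value: y_hard; backward gradient: that of the soft y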
7 Likes
Created a PR at https://github.com/pytorch/pytorch/pull/3341, based on the above implementation.
1 Like