eliabruni
(Elia Bruni)
#1
Hi,

Given the soft implementation of Gumbel Softmax:

```
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def sample_gumbel(input):
    # Gumbel(0, 1) noise: -log(-log(U + eps) + eps), with U ~ Uniform(0, 1)
    noise = torch.rand(input.size())
    eps = 1e-20
    noise.add_(eps).log_().neg_()
    noise.add_(eps).log_().neg_()
    return Variable(noise)

def gumbel_softmax_sample(input):
    temperature = 1
    noise = sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.log_softmax(x, dim=-1)
    return x.view_as(input)
```

I’d like to implement the hard version, as here. Any idea how to do that? Maybe this example in torch helps?

4 Likes

I did stop_gradient(x) in PyTorch as Variable(x.data). However, I’m looking for a better alternative.

apaszke
(Adam Paszke)
#3
Wouldn’t `.detach()` solve this? It will return a new Variable that doesn’t require gradient and won’t propagate anything in the backward pass.
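As a quick illustration of the effect (written against the modern tensor API rather than `Variable`; the toy tensors are made up for the example): a detached branch contributes nothing to the gradient, while the non-detached branch behaves as usual.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2

# The detached copy carries the same values but is cut out of the graph,
# so only the second (non-detached) y contributes to the gradient.
out = (y.detach() + y).sum()
out.backward()

print(x.grad)  # gradient of sum(2 * x) only: tensor([2., 2., 2.])
```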

5 Likes

eliabruni
(Elia Bruni)
#4
Thank you both, I’ll try it out!

I am also interested in implementing ST gumbel softmax with stop gradient.

So, I wrote the script as below.

```
def gumbel_softmax_sample(self, input, hard=False):
    temperature = 1
    noise = self.sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.softmax(x, dim=-1)
    if hard:
        # mark every entry equal to the row maximum
        max_val, _ = torch.max(x, x.dim() - 1, keepdim=True)
        x_hard = x == max_val.expand_as(x)
        # straight-through: forward pass sees x_hard, backward sees x
        tmp = x_hard.float() - x
        tmp2 = tmp.clone()
        tmp2.detach_()
        x = tmp2 + x
    return x.view_as(input)
```

But, I tried to modify it such that only one ‘1’ exists at each row in x_hard as below.

```
def gumbel_softmax_sample(self, input, hard=False):
    temperature = 1
    noise = self.sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.softmax(x, dim=-1)
    if hard:
        # scatter a single 1 at the argmax of each row
        _, max_inx = torch.max(x, x.dim() - 1, keepdim=True)
        x_hard = torch.cuda.FloatTensor(x.size()).zero_().scatter_(
            x.dim() - 1, max_inx.data, 1.0)
        x2 = x.clone()
        # straight-through: forward pass sees x_hard, backward sees x
        tmp = Variable(x_hard - x2.data)
        tmp.detach_()
        x = tmp + x
    return x.view_as(input)
```

Both scripts run without error, but I am not sure they are equivalent to the original ST Gumbel-Softmax.
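One observable difference between the two variants shows up when the softmax output contains ties: the comparison-based version can mark several entries per row, while the scatter-based version marks exactly one. A small check (modern tensor API; the tied input is made up for illustration):

```python
import torch

x = torch.tensor([[0.4, 0.4, 0.2]])  # tie between the first two entries

# Comparison-based hard assignment: every entry equal to the max gets a 1
max_val, _ = torch.max(x, dim=-1, keepdim=True)
hard_eq = (x == max_val).float()

# Scatter-based hard assignment: exactly one 1 per row, at the argmax
_, max_idx = torch.max(x, dim=-1, keepdim=True)
hard_scatter = torch.zeros_like(x).scatter_(-1, max_idx, 1.0)

print(hard_eq.sum().item())       # 2.0 -- both tied entries are marked
print(hard_scatter.sum().item())  # 1.0 -- a single entry is marked
```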

I would greatly appreciate it if someone could review them.

Could anyone help me confirm that these two scripts are equivalent to the original TensorFlow implementation?

@Seungyoung_Park Have you tried A/B testing your scripts with the TF implementation to see if the outputs are any different?

Because I do not know how to make the parameters match those of the original, I have not compared them yet.

Qi_Zhang
(Qi Zhang)
#9
Hey, have you verified the correctness of your code?

Two things:

- I don’t see the difference between `.detach_()` and `.detach()`.
- It seems correct even without `tmp.detach_()`, because `tmp` was created as a new Variable, as @Ilya_Kostrikov suggested.

hughperkins
(Hugh Perkins)
#10
From https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349#file-gumbel-softmax-py-L27

```
y = tf.stop_gradient(y_hard - y) + y
```

Whoa, that’s so clever. I had to stare at that for ages before finally figuring it out. So cool. So, the result of this is:

- y is pure one-hot in terms of value (since we add the soft y and then subtract it again)
- the gradients are those of soft y (since all the other terms in this expression have their gradient stripped)
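That one-liner translates directly to PyTorch with `.detach()`. A sketch of the full hard (straight-through) Gumbel-Softmax in the modern tensor API; the function name and default arguments here are my own, not from the thread:

```python
import torch
import torch.nn.functional as F

def gumbel_softmax(logits, temperature=1.0, hard=True, eps=1e-20):
    # Sample Gumbel(0, 1) noise and form the soft sample
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + eps) + eps)
    y = F.softmax((logits + gumbel) / temperature, dim=-1)
    if hard:
        # One-hot at the argmax of each row
        index = y.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y).scatter_(-1, index, 1.0)
        # Straight-through: forward value is y_hard, gradient is that of y
        y = (y_hard - y).detach() + y
    return y
```

In the forward pass the detached term turns the value into `y_hard`; in the backward pass it contributes nothing, so gradients flow through the soft `y` alone.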

7 Likes

hughperkins
(Hugh Perkins)
#11
Created a PR at https://github.com/pytorch/pytorch/pull/3341, based on the above implementation.

1 Like