Stop gradients (for ST gumbel softmax)

Hi,

Given the soft implementation of Gumbel Softmax:

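import torch
import torch.nn.functional as F
from torch.autograd import Variable

def sample_gumbel(input):
    # Gumbel(0, 1) noise: g = -log(-log(U)), U ~ Uniform(0, 1)
    noise = torch.rand(input.size())
    eps = 1e-20  # guards against log(0)
    noise.add_(eps).log_().neg_()  # -log(U)
    noise.add_(eps).log_().neg_()  # -log(-log(U)); the second pass is intentional
    return Variable(noise)

def gumbel_softmax_sample(input):
    temperature = 1
    noise = sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.log_softmax(x, dim=-1)  # log-probabilities of the relaxed sample
    return x.view_as(input)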
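The two log/neg passes in sample_gumbel turn uniform noise into Gumbel(0, 1) noise. One way to sanity-check that noise is the Gumbel-max trick: argmax(logits + g) should be distributed as softmax(logits). A minimal sketch, where sample_gumbel_like and empirical_argmax_freq are hypothetical helper names:

```python
import torch

def sample_gumbel_like(logits, eps=1e-20):
    # Same transform as sample_gumbel above: g = -log(-log(U + eps) + eps)
    u = torch.rand_like(logits)
    return -torch.log(-torch.log(u + eps) + eps)

def empirical_argmax_freq(logits, n=200_000, seed=0):
    # Gumbel-max trick: argmax(logits + g) ~ Categorical(softmax(logits))
    torch.manual_seed(seed)
    batch = logits.expand(n, -1)          # n copies of the same logits
    g = sample_gumbel_like(batch)         # independent noise for every copy
    idx = torch.argmax(batch + g, dim=-1)
    return torch.bincount(idx, minlength=logits.numel()).float() / n

logits = torch.tensor([1.0, 2.0, 0.5])
freq = empirical_argmax_freq(logits)
probs = torch.softmax(logits, dim=-1)
print(freq, probs)  # the two should agree to roughly two decimal places
```

The soft (and later hard) Gumbel-Softmax is a relaxation of exactly this argmax sample.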

I’d like to implement the hard version, as here. Any idea how to do that? Maybe this example in torch helps?


I implemented stop_gradient(x) in PyTorch as Variable(x.data), but I’m looking for a better alternative.

Wouldn’t .detach() solve this? It returns a new Variable that doesn’t require gradients, so nothing is propagated through it in the backward pass.
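A small illustration of the difference, assuming a recent PyTorch where Variable and Tensor are merged: both x.data and x.detach() give a tensor with no gradient history that shares storage with x, but in-place changes made through .detach() are still caught by autograd’s version counter, whereas changes through .data are not.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Old-style stop_gradient: raw data, invisible to autograd
y_data = x.data
# Preferred: detach() returns a new tensor that shares storage but is cut from the graph
y_detached = x.detach()

# Only the non-detached term contributes to x.grad
loss = (x * 2).sum() + (y_detached * 100).sum()
loss.backward()
print(x.grad)  # tensor([2., 2., 2.])
```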


Thank you both, I’ll try it out!

I am also interested in implementing ST Gumbel-Softmax with a stop gradient, so I wrote the script below.

def gumbel_softmax_sample(self, input, hard=False):
    temperature = 1
    noise = self.sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.softmax(x, dim=-1)

    if hard:
        # Mark the row-wise max (ties can give more than one 1 per row)
        max_val, _ = torch.max(x, x.dim() - 1, keepdim=True)
        x_hard = x == max_val.expand_as(x)
        # Straight-through: value of x_hard, gradient of the soft x
        tmp = (x_hard.float() - x)
        tmp2 = tmp.clone()
        tmp2.detach_()
        x = tmp2 + x

    return x.view_as(input)
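Why the comparison against the row max can leave more than one ‘1’ in a row: any exact tie is marked. With continuous Gumbel noise exact ties are rare, but scatter_ at the argmax index guarantees a single 1. A minimal sketch on a hypothetical, hand-picked softmax output with a tie in the first row:

```python
import torch

# Hypothetical softmax output; row 0 has a tie for the max
x = torch.tensor([[0.4, 0.4, 0.2],
                  [0.1, 0.7, 0.2]])

# Compare against the row max: ties yield more than one 1 per row
max_val, _ = torch.max(x, dim=-1, keepdim=True)
hard_eq = (x == max_val).float()

# Scatter a 1 at the argmax index: exactly one 1 per row
_, max_idx = torch.max(x, dim=-1, keepdim=True)
hard_scatter = torch.zeros_like(x).scatter_(-1, max_idx, 1.0)

print(hard_eq.sum(dim=-1))       # tensor([2., 1.])
print(hard_scatter.sum(dim=-1))  # tensor([1., 1.])
```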

I then modified it so that x_hard has exactly one ‘1’ per row:

def gumbel_softmax_sample(self, input, hard=False):
    temperature = 1
    noise = self.sample_gumbel(input)
    x = (input + noise) / temperature
    x = F.softmax(x, dim=-1)

    if hard:
        # One-hot at the argmax: exactly one 1 per row, on x's device
        _, max_inx = torch.max(x, x.dim() - 1, keepdim=True)
        x_hard = torch.zeros_like(x.data).scatter_(x.dim() - 1, max_inx.data, 1.0)
        x2 = x.clone()
        # tmp carries no gradient; only the soft x added below does
        tmp = Variable(x_hard - x2.data)
        tmp.detach_()

        x = tmp + x

    return x.view_as(input)

These two scripts run without errors, but I am not sure they are equivalent to the original ST Gumbel-Softmax.

I would greatly appreciate it if someone could review them.

Could anyone help me confirm that these two scripts are equivalent to the original TensorFlow implementation?

@Seungyoung_Park Have you tried A/B testing your scripts with the TF implementation to see if the outputs are any different?

I haven’t compared them yet, because I don’t know how to set the parameters so that they match the original.
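One way to make an A/B test deterministic is to pass the logits and the Gumbel noise in explicitly, so that every implementation sees identical inputs. The sketch below does not run TensorFlow; it compares a PyTorch transcription of the gist’s formula with a variant that mirrors the first script above. The names st_reference, st_candidate and compare are hypothetical.

```python
import torch
import torch.nn.functional as F

def st_reference(logits, noise, temperature=1.0):
    # Transcription of the gist's formula: y = stop_gradient(y_hard - y) + y
    y = F.softmax((logits + noise) / temperature, dim=-1)
    idx = y.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(-1, idx, 1.0)
    return (y_hard - y).detach() + y

def st_candidate(logits, noise, temperature=1.0):
    # Variant under test, mirroring the first script (compare against the row max)
    y = F.softmax((logits + noise) / temperature, dim=-1)
    max_val, _ = torch.max(y, dim=-1, keepdim=True)
    y_hard = (y == max_val).float()
    tmp = (y_hard - y).clone()
    tmp.detach_()
    return tmp + y

def compare(fn_a, fn_b, shape=(4, 5), seed=0):
    # Fix logits, Gumbel noise and the upstream gradient for both functions
    torch.manual_seed(seed)
    logits = torch.randn(shape)
    noise = -torch.log(-torch.log(torch.rand(shape) + 1e-20) + 1e-20)
    upstream = torch.randn(shape)
    results = []
    for fn in (fn_a, fn_b):
        leaf = logits.clone().requires_grad_(True)
        out = fn(leaf, noise)
        (out * upstream).sum().backward()
        results.append((out.detach(), leaf.grad))
    (out_a, grad_a), (out_b, grad_b) = results
    return torch.allclose(out_a, out_b), torch.allclose(grad_a, grad_b)

same_values, same_grads = compare(st_reference, st_candidate)
print(same_values, same_grads)
```

To compare against TensorFlow itself, the same logits and noise arrays would have to be fed into the TF graph in place of its own random sampling.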

Hey, have you verified the correctness of your code?
Two things:

  1. I don’t know the difference between .detach_() and .detach().
  2. It seems like it’s also correct even if you don’t use tmp.detach_(), because tmp was created as a new Variable from .data, as @Ilya_Kostrikov suggested.
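On both points, a small sketch assuming a recent PyTorch: .detach() is out-of-place and returns a new tensor, leaving the original in the graph, while .detach_() cuts the tensor it is called on from the graph in place. A tensor computed from .data already has no history, so calling detach_() on it changes nothing.

```python
import torch

x = torch.tensor([0.2, 0.8], requires_grad=True)
y = x * 3                      # non-leaf, part of the graph

d = y.detach()                 # out-of-place: new tensor, y unchanged
print(y.requires_grad, d.requires_grad)   # True False

z = y.clone()
z.detach_()                    # in-place: z itself is cut from the graph
print(z.requires_grad, z.grad_fn)         # False None

# A tensor built from .data already has no history, so detach_() adds nothing
t = torch.ones(2) - x.data
print(t.requires_grad)                    # False
```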

From https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349#file-gumbel-softmax-py-L27

y = tf.stop_gradient(y_hard - y) + y
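A PyTorch analogue of this line is (y_hard - y).detach() + y. Below is a minimal sketch of a full straight-through Gumbel-Softmax built around it, loosely following the structure of the gist; the function names and the temperature default are assumptions, not the gist’s exact code.

```python
import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20, device=None):
    # g = -log(-log(U)), U ~ Uniform(0, 1)
    u = torch.rand(shape, device=device)
    return -torch.log(-torch.log(u + eps) + eps)

def gumbel_softmax(logits, temperature=1.0, hard=False):
    # Soft sample: softmax((logits + g) / temperature) over the last dim
    g = sample_gumbel(logits.shape, device=logits.device)
    y = F.softmax((logits + g) / temperature, dim=-1)
    if not hard:
        return y
    # One-hot of the argmax (exactly one 1 per row)
    idx = y.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y).scatter_(-1, idx, 1.0)
    # PyTorch analogue of: y = tf.stop_gradient(y_hard - y) + y
    return (y_hard - y).detach() + y

torch.manual_seed(0)
logits = torch.randn(3, 4, requires_grad=True)
y = gumbel_softmax(logits, temperature=0.5, hard=True)
y.sum(dim=0).pow(2).sum().backward()   # any differentiable downstream loss
```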

Whoa, that’s so clever. I had to stare at it for ages before I finally figured it out. So cool :) So, the result of this is:

  • y is pure one-hot in terms of value (since we add the soft y and then subtract it again)
  • the gradients are those of soft y (since all the other terms in this expression have their gradient stripped)
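Both properties can be checked numerically. A minimal sketch with fixed logits and no noise (so the result is deterministic), comparing the gradient of the straight-through output with that of the plain softmax under the same arbitrary upstream weights w:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 3.0, 0.5]], requires_grad=True)
w = torch.tensor([[0.3, -1.0, 2.0]])   # arbitrary upstream gradient

# Soft path only
y_soft = F.softmax(logits, dim=-1)
(y_soft * w).sum().backward()
grad_soft = logits.grad.clone()
logits.grad = None

# Straight-through path: y = stop_gradient(y_hard - y) + y
y = F.softmax(logits, dim=-1)
y_hard = torch.zeros_like(y).scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.0)
y_st = (y_hard - y).detach() + y
(y_st * w).sum().backward()
grad_st = logits.grad.clone()

print(y_st.detach())                       # one-hot in value
print(torch.allclose(grad_soft, grad_st))  # True: gradients of the soft y
```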

Created a PR at https://github.com/pytorch/pytorch/pull/3341, based on the implementation above.
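For later readers: current PyTorch releases ship torch.nn.functional.gumbel_softmax with a hard flag that uses the same straight-through trick. Whether that function came from the PR above is not confirmed here. A short usage sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)

# hard=True: one-hot in the forward pass, soft-sample gradients in the backward pass
y = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
loss = (y * torch.arange(5.0)).sum()
loss.backward()
print(y)
print(logits.grad)
```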
