The argmax function is discrete and non-differentiable, and it breaks the back-propagation path during training.
Therefore, I want to use gumbel-softmax instead of argmax. However, my PyTorch version is 0.3, which does not yet ship a gumbel_softmax function, so I ported the implementation from the PyTorch source on GitHub into my own code. I am not sure my reproduction is completely correct. Could anyone give me some suggestions?
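To illustrate the problem I am trying to solve, here is a minimal sketch (my own, in PyTorch 0.3 style) showing that the index returned by max/argmax carries no gradient:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4, 3), requires_grad=True)
_, idx = x.max(-1)        # idx is a discrete index Variable
print(idx.requires_grad)  # False -- the graph is cut at the argmax

My code is as follows: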
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def _sample_gumbel(shape, eps=1e-10, out=None):
"""
Sample from Gumbel(0, 1)
based on
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
(MIT license)
"""
    U = out.resize_(shape).uniform_() if out is not None else torch.rand(shape)
    # computes -log(-log(U)); eps keeps both logs away from log(0)
    return - torch.log(- torch.log(U + eps) + eps)
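# Sanity check I added while debugging (my own addition, not from the
# PyTorch source): Gumbel(0, 1) has mean equal to the Euler-Mascheroni
# constant (~0.5772) and variance pi^2 / 6 (~1.6449), so the empirical
# moments of a large sample should land close to those values.
def _check_gumbel_moments(n=100000):
    g = _sample_gumbel(torch.Size([n]))
    print('gumbel mean: %.4f (expect ~0.5772)' % g.mean())
    print('gumbel var:  %.4f (expect ~1.6449)' % g.var())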
def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
"""
Draw a sample from the Gumbel-Softmax distribution
based on
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
(MIT license)
"""
dims = logits.dim()
gumbel_noise = Variable(_sample_gumbel(logits.size(), eps=eps, out=logits.data.new()))
y = logits + gumbel_noise
return F.softmax(y / tau, dims - 1)
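# Quick behavioral check (my own addition): per the Gumbel-Softmax paper,
# as tau -> 0 the samples should approach one-hot vectors, and as
# tau -> infinity they should flatten toward the uniform distribution.
def _check_temperature(logits):
    print(_gumbel_softmax_sample(logits, tau=0.1))   # close to one-hot
    print(_gumbel_softmax_sample(logits, tau=10.0))  # close to uniform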
def gumbel_softmax(logits, tau=0.8, hard=False, eps=1e-10):
"""
Sample from the Gumbel-Softmax distribution and optionally discretize.
Args:
logits: `[batch_size, n_class]` unnormalized log-probs
tau: non-negative scalar temperature
hard: if ``True``, take `argmax`, but differentiate w.r.t. soft sample y
Returns:
[batch_size, n_class] sample from the Gumbel-Softmax distribution.
If hard=True, then the returned sample will be one-hot, otherwise it will
be a probability distribution that sums to 1 across classes
Constraints:
- this implementation only works on batch_size x num_features tensor for now
based on
https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
(MIT license)
"""
shape = logits.size()
assert len(shape) == 2
y_soft = _gumbel_softmax_sample(logits, tau=tau, eps=eps)
if hard:
_, k = y_soft.max(-1)
# this bit is based on
# https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
y_hard = torch.zeros_like(logits).scatter_(-1, k.view(-1, 1), 1.0)
# this cool bit of code achieves two things:
# - makes the output value exactly one-hot (since we add then
# subtract y_soft value)
# - makes the gradient equal to y_soft gradient (since we strip
# all other gradients)
y = y_hard - y_soft.detach() + y_soft
else:
y = y_soft
return y
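# Gradient check for the straight-through path (my own addition): with
# hard=True the forward output is exactly one-hot, but backward still
# reaches the logits because the gradient is routed through y_soft.
# A plain argmax + one-hot construction would leave logits.grad at None.
def _check_straight_through():
    logits = Variable(torch.randn(4, 5), requires_grad=True)
    y = gumbel_softmax(logits, hard=True)
    print(y)  # one-hot rows
    loss = (y * Variable(torch.randn(4, 5))).sum()
    loss.backward()
    print(logits.grad)  # populated, so back-propagation is not broken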
if __name__ == "__main__":
    # gumbel_softmax expects unnormalized log-probs, so I pass log-softmax
    # output rather than softmax probabilities
    logits = F.log_softmax(Variable(torch.randn(10, 5)), dim=1)
    print(logits)
    y_draw = gumbel_softmax(logits, hard=False)
    print(y_draw)