Hi,
I am reviewing the gumbel_softmax implementation in PyTorch (torch.nn.functional — PyTorch 2.1 documentation). I am not able to understand what this line of code is trying to accomplish:

`ret = y_hard - y_soft.detach() + y_soft`
The forward pass will return y_hard, since `- y_soft.detach() + y_soft` evaluates to 0:
```python
# forward
ret = y_hard - y_soft.detach() + y_soft  # y_soft.detach() and y_soft cancel out in value
ret = y_hard + 0
ret = y_hard
```
while the backward pass will go through y_soft, since y_hard is not attached to the computation graph: it is created via torch.zeros_like().scatter_(), which records no gradient history. The second term, y_soft.detach(), is explicitly detached and thus won’t pass any gradients to previous layers, so only the third term, y_soft, is used in the backward pass.
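Here is a minimal, self-contained sketch of this straight-through pattern. It uses a plain softmax in place of the Gumbel-perturbed softmax for brevity, and the weights tensor is just a hypothetical downstream loss to make the gradients non-trivial (a bare sum() would give zero gradients, since each softmax row sums to 1):

```python
import torch

logits = torch.randn(3, 5, requires_grad=True)

# Soft sample (the real implementation perturbs the logits with Gumbel noise first).
y_soft = torch.softmax(logits, dim=-1)

# Hard one-hot sample, built outside the autograd graph:
# zeros_like() + scatter_() record no gradient history.
index = y_soft.argmax(dim=-1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

# Straight-through: the value is y_hard (the two y_soft terms cancel),
# but the gradient is exactly that of y_soft.
ret = y_hard - y_soft.detach() + y_soft

# Hypothetical downstream loss to produce non-trivial gradients.
weights = torch.randn_like(ret)
(ret * weights).sum().backward()

print(ret)          # one-hot rows, identical in value to y_hard
print(logits.grad)  # non-zero: gradients flowed through y_soft only
```

Calling F.gumbel_softmax(logits, tau=1.0, hard=True) gives the same behavior end to end: one-hot outputs in the forward pass, with gradients computed as if the soft sample had been returned.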