How to implement "Shake Shake"

Hi,

I think you might want something like

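import torch

x = torch.randn(5, requires_grad=True)
y = torch.bernoulli(torch.full((5,), 0.5))  # random 0/1 mask, one entry per element
# The forward value equals x*y; the backward pass only sees the 0.5*x term.
z = 0.5*x + (x*(y-0.5)).detach()
z.sum().backward()  # x.grad is 0.5 everywhere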

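The trick is to put the mean, 0.5*x, in the normal flow that autograd tracks (so it is what the backward uses), and to add the difference between the random forward value x*y and that mean, x*(y-0.5), as a detached term so that no gradient flows through it. The forward result is then x*y, while the gradient with respect to x is 0.5.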
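To check that the detach trick behaves as described, here is a small sketch that compares the forward value against x*y and the gradient against the mean. The seed and the tensor size are arbitrary choices of mine, and it assumes a reasonably recent PyTorch where torch.full with a float fill value returns a float tensor.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(5, requires_grad=True)
y = torch.bernoulli(torch.full((5,), 0.5))  # random 0/1 mask

# Mean term is tracked by autograd; the deviation from the mean is detached.
z = 0.5 * x + (x * (y - 0.5)).detach()
z.sum().backward()

# Forward pass: z should equal the randomly masked input x*y.
forward_matches = torch.allclose(z, x * y)
# Backward pass: the gradient should be the mean 0.5, independent of y.
grad_is_mean = torch.allclose(x.grad, torch.full((5,), 0.5))
print(forward_matches, grad_is_mean)
```

If both checks hold, the random mask only affects the forward pass and the backward pass uses the mean.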

Note that implementing your own autograd.Function would likely be somewhat more efficient computationally (it avoids computing the extra product x*(y-0.5)).
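A rough sketch of what such a custom autograd.Function could look like is below. The class name MaskForwardMeanBackward is something I made up for illustration, and the hard-coded 0.5 assumes the mask has mean 0.5 as in the snippet above. I have not benchmarked it against the detach version.

```python
import torch

class MaskForwardMeanBackward(torch.autograd.Function):
    """Hypothetical helper: multiply by a random mask in forward, by its mean in backward."""

    @staticmethod
    def forward(ctx, x, mask):
        # Nothing is saved for backward: the gradient only needs the constant 0.5.
        return x * mask

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. x uses the mean of the mask (assumed 0.5);
        # the mask itself gets no gradient, hence None.
        return 0.5 * grad_output, None

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(5, requires_grad=True)
mask = torch.bernoulli(torch.full((5,), 0.5))

z = MaskForwardMeanBackward.apply(x, mask)
z.sum().backward()
```

Here the forward computes a single product x*mask, and the backward never touches the mask.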

I saw something similar in the discussion of Gumbel-softmax by Hugh Perkins.

Best regards

Thomas
