Back-propagation on variable sampled by random distribution

I have this simple network and I want to try a very simple idea. Suppose that the input to the network is a pair of tensors of size (batch_size, num_channels, dim, dim) that represent the means and variances of N = batch_size * num_channels * dim * dim univariate normal distributions.

Inside the forward function, I want to produce a new tensor z of the same size (batch_size, num_channels, dim, dim), in which each element is drawn from a univariate normal distribution whose mean and variance are given by the corresponding elements of the input tensors above. After that, I want to use z just as I would use a standard input x (add some convolution layers, etc.).

I’m doing this as shown below. Even though it seems to work as expected, I am not sure what will happen during back-propagation. At the point where I create z (should a .requires_grad_(True) be added there?), it seems that the connection to the inputs x_mean and x_var might be broken. After all, sampling from a distribution is not a differentiable operation…

I wonder whether this is going to work during back-propagation and, if not, whether you have any ideas or insight on how to implement it so that it does.

Many thanks!

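import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)

    def forward(self, x_mean, x_var):
        # Draw z ~ N(x_mean, x_var) elementwise; randn_like matches the shape, dtype and device of x_mean
        z = x_var.sqrt() * torch.randn_like(x_mean) + x_mean
        return self.conv(z)

net = Net()

x_mean = torch.randn(1, 16, 300, 300).requires_grad_(True)
# Variances must be non-negative: torch.randn would produce negative values and sqrt() would return NaN
x_var = torch.rand(1, 16, 300, 300).requires_grad_(True)

out = net(x_mean, x_var)
out.sum().backward()  # populates x_mean.grad and x_var.grad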

Hi,

Sampling is not differentiable with respect to the distribution's parameters, so the sampling op itself will not propagate any gradients.
But what you do here is the reparametrisation trick, right? Gradients will not flow back through the torch.randn op (but that’s OK), and they will flow as expected to x_var and x_mean.
If x_mean or x_var requires gradient, then z will require gradients as well. If they don’t, z won’t either.
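A small check of that last point, as a sketch (the shapes are arbitrary, and torch.rand is used only to keep the variances non-negative): the noise tensor has no autograd history, while z built from it inherits requires_grad from x_mean and x_var, so no extra .requires_grad_(True) on z is needed.

```python
import torch

x_mean = torch.randn(2, 3, requires_grad=True)
x_var = torch.rand(2, 3, requires_grad=True)   # non-negative variances

eps = torch.randn_like(x_mean)      # plain noise tensor, not part of the graph
z = x_var.sqrt() * eps + x_mean     # differentiable w.r.t. x_mean and x_var

print(eps.requires_grad)       # False
print(z.requires_grad)         # True
print(z.grad_fn is not None)   # True

# If the inputs do not require grad, z does not either
z_nograd = x_var.detach().sqrt() * eps + x_mean.detach()
print(z_nograd.requires_grad)  # False
```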

Hi @albanD, thanks for your response!

To be honest, I am a bit confused by your answer. If gradients do not flow back through the torch.randn op, how will they propagate back to the inputs x_mean and x_var? Keep in mind that before the sampling step, x_mean and x_var might be passed through many layers (convolutions, pooling, non-linearities), so what I was wondering is whether gradients will propagate through these layers or not. In any case, I’m just experimenting, so I’ll probably drop this idea altogether, but I’m curious whether it could work in PyTorch.

Your sampling is done as: z = x_var.sqrt() * torch.randn(x_mean.size()) + x_mean. As you can see, the torch.randn() part does not depend on x_var or x_mean at all; it is just a constant noise tensor. So even though you cannot backprop through the randn op, you can backprop through the sqrt and multiply ops to x_var, and through the addition to x_mean. If I’m not mistaken, this is known in the literature as the reparametrisation trick.
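Written out, with eps ~ N(0, 1) drawn independently, z = sqrt(v) * eps + mu, so dz/dmu = 1 and dz/dv = eps / (2 * sqrt(v)); eps is just a constant during the backward pass. The sketch below assumes a hypothetical encoder (the layers encoder, to_mean, to_logvar and head are illustrative, not from this thread) that produces the mean and variance maps, predicting the log-variance so that the variance stays positive. It checks that the layers before the sampling step receive gradients, and that autograd's gradients match the chain rule above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical layers placed *before* the sampling step
encoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
to_mean = nn.Conv2d(16, 16, kernel_size=1)
to_logvar = nn.Conv2d(16, 16, kernel_size=1)   # log-variance keeps the variance positive
head = nn.Conv2d(16, 32, kernel_size=3, padding=1)

x = torch.randn(1, 3, 8, 8)
h = encoder(x)
x_mean = to_mean(h)
x_var = to_logvar(h).exp()
x_mean.retain_grad()   # keep .grad on these non-leaf tensors for the checks below
x_var.retain_grad()

eps = torch.randn_like(x_mean)    # noise sampled outside the graph
z = x_var.sqrt() * eps + x_mean   # reparameterised sample
z.retain_grad()

loss = head(z).sum()
loss.backward()

# 1) Gradients reach the layers that produced x_mean and x_var
print(encoder[0].weight.grad is not None)   # True

# 2) Chain rule through z: dL/dmu = dL/dz and dL/dv = dL/dz * eps / (2 * sqrt(v))
expected_var_grad = z.grad * eps / (2 * x_var.detach().sqrt())
print(torch.allclose(x_mean.grad, z.grad))                           # True
print(torch.allclose(x_var.grad, expected_var_grad, atol=1e-6))      # True
```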

Hmm, I see; yes, of course you’re right… Interesting that there’s a name for this – I wasn’t aware.

Many thanks again :slight_smile:

This is quite “simple” to see for a Gaussian distribution with a custom mean and variance. But it can be trickier to do for more complex distributions.
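For what it’s worth, torch.distributions makes this distinction explicit, as far as I know: distributions that support a reparameterised sample implement rsample() and have has_rsample set to True, while sample() returns a tensor detached from the parameters. A minimal sketch (shapes and values are arbitrary; note that Normal takes the standard deviation, not the variance):

```python
import torch
from torch.distributions import Normal, Bernoulli

mu = torch.zeros(4, requires_grad=True)
var = torch.ones(4, requires_grad=True)

dist = Normal(loc=mu, scale=var.sqrt())   # scale is the standard deviation

z_r = dist.rsample()   # reparameterised: loc + scale * eps, keeps the graph
z_s = dist.sample()    # plain sample, no gradient path to mu or var

print(z_r.requires_grad, z_s.requires_grad)   # True False

z_r.sum().backward()
print(mu.grad)   # tensor([1., 1., 1., 1.])

# Not every distribution offers a reparameterised sampler
print(Normal.has_rsample, Bernoulli.has_rsample)   # True False
```

Discrete distributions such as Bernoulli are where it gets trickier: there is no rsample(), so gradients with respect to their parameters need some other estimator.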

True that. Cheers :slight_smile: