I have this simple network and I want to try a very simple idea. That is, let’s suppose that the input to the network is a pair of tensors of size `(batch_size, num_channels, dim, dim)` that represent the mean values and variances of `N = batch_size * num_channels * dim * dim` univariate normal distributions. In the forward function, I’m interested in producing a new tensor `z` of the same size (i.e., `(batch_size, num_channels, dim, dim)`), where each element is drawn from a univariate normal distribution whose mean and variance are given by the corresponding elements of the two input tensors. After that, I want to use the variable `z` just like I would use the standard input variable `x` (add some convolution layers, etc.).
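For reference, `torch.distributions` can express this element-wise sampling directly; a minimal sketch (the shapes and values below are just placeholders, and `Normal` takes a standard deviation, hence the `sqrt()`):

```python
import torch
from torch.distributions import Normal

x_mean = torch.zeros(1, 16, 4, 4, requires_grad=True)
x_var = torch.full((1, 16, 4, 4), 0.5, requires_grad=True)

dist = Normal(x_mean, x_var.sqrt())  # Normal expects std, not variance

z = dist.rsample()  # reparameterized draw: stays connected to the graph
s = dist.sample()   # plain draw: detached from the graph

print(z.grad_fn is not None)  # True
print(s.grad_fn)              # None
```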
I’m doing this as shown below. Even though it seems to work as expected, I’m not sure what will happen during back-propagation. That is, at the point where I create the variable `z` (should a `.requires_grad_(True)` be added there?), it seems that the connections with the inputs `x_mean` and `x_var` might be broken. After all, sampling from a distribution is not a differentiable operation…

I wonder if this is going to work during back-propagation, and, if not, whether you have any ideas/insights on how to implement it so that it does.
```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(in_channels=16, out_channels=32,
                              kernel_size=3, padding=1)

    def forward(self, x_mean, x_var):
        # Draw z element-wise: z = mean + sqrt(var) * eps, with eps ~ N(0, 1).
        # randn_like keeps the noise on the same device/dtype as the inputs.
        z = x_var.sqrt() * torch.randn_like(x_mean) + x_mean
        return self.conv(z)


net = Net()
x_mean = torch.randn(1, 16, 300, 300).requires_grad_(True)
# Variances must be non-negative, otherwise sqrt() produces NaNs;
# rand() samples from [0, 1) instead of a standard normal.
x_var = torch.rand(1, 16, 300, 300).requires_grad_(True)
z = net(x_mean, x_var)
```
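For what it’s worth, one way to check this empirically is to backpropagate a dummy scalar loss and see whether gradients actually arrive at the inputs; a minimal sketch, continuing from the snippet above (the `.sum()` loss is just a placeholder):

```python
# Continuing from the snippet above: backpropagate a dummy scalar loss.
loss = z.sum()
loss.backward()

# If the sampling step kept the graph intact, both gradients exist.
print(x_mean.grad is not None)  # gradient w.r.t. the means
print(x_var.grad is not None)   # gradient w.r.t. the variances
```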