Can you differentiate MC estimates?


(Manuel) #1

Let’s say I want to perform MC integration to estimate the integral of a 2D function (fn) (i.e. a grayscale image). For reference: https://en.wikipedia.org/wiki/Monte_Carlo_integration. I have a set of random points with shape (number samples x 2) (the 2 comes from doing MC integration over a 2D function). I would like to differentiate with respect to this variable.

points = ... # random set of points
points.requires_grad=True
fn = ... # some image

In order to get Monte Carlo estimates, we have to warp the original samples into the domain of fn. This can easily be done by multiplying by the shape of the 2D function.

fn_shape = (torch.tensor(fn.shape) - 1).view(1, -1)
warped_points = (points * fn_shape).long() # here, warped_points no longer requires grad, because of the .long() cast
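As a minimal self-contained illustration (with random toy data standing in for the real `points` and `fn`), the integer cast is exactly where the autograd graph breaks:

```python
import torch

# Toy stand-ins for the real data: 4 random points in [0, 1)^2 and a 3x3 "image".
points = torch.rand(4, 2, requires_grad=True)
fn = torch.rand(3, 3)

fn_shape = (torch.tensor(fn.shape) - 1).view(1, -1)
warped_points = (points * fn_shape).long()

# The cast to long detaches the result from the graph:
# integer tensors cannot carry gradients.
print(points.requires_grad)         # True
print(warped_points.requires_grad)  # False
```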

Then, in order to get our estimates, we can just read the pixel indexed by each warped point:

mc_samples = fn[warped_points[:, 0], warped_points[:, 1]]

Now we integrate both the original function and the estimates, and compute the squared error to measure how good our approximation of the original function is:

fn_integral = fn.sum() / fn.numel()
mc_integral = mc_samples.sum() / mc_samples.numel()
err = (mc_integral - fn_integral) ** 2

The source of the error seems clear. If you think of an image as a function, it is a discontinuous (piecewise-constant) function: it has some value at pixel (x, y) and at pixel (x+1, y), but between x and x+1 there are infinitely many points where it is simply not defined. At warping time, we pick the pixel to access by calling (points * fn_shape).long(), discarding the fractional part of the number, so no gradient can flow back through that step.

My main question is: is there a way to make this problem differentiable? One way could be to forget about images and just use functions of two variables (e.g. x^2 + y^2). But I am particularly interested in treating images as 2D functions.
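One direction that seems plausible is to replace the hard integer lookup with bilinear interpolation, e.g. via `torch.nn.functional.grid_sample`, which is differentiable with respect to the sampling locations. A minimal sketch with hypothetical toy data (the real `points` and `fn` would go in their place):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
points = torch.rand(128, 2, requires_grad=True)  # samples in [0, 1)^2
fn = torch.rand(32, 32)                          # the "image" to integrate

# grid_sample expects coordinates in [-1, 1], shape (N, H_out, W_out, 2),
# ordered (x, y); our points are (row, col), so flip the last dimension.
grid = (points * 2 - 1).flip(-1).view(1, -1, 1, 2)
mc_samples = F.grid_sample(
    fn.view(1, 1, *fn.shape), grid,
    mode='bilinear', align_corners=True,
).view(-1)

mc_integral = mc_samples.mean()
fn_integral = fn.mean()
err = (mc_integral - fn_integral) ** 2
err.backward()  # succeeds: points.grad is now populated
```

This treats the image as a continuous function via interpolation between neighboring pixels, so the sampling step no longer kills the gradient.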

All the code together:

points = ... # random set of points
points.requires_grad=True
fn = ... # some image

fn_shape = (torch.tensor(fn.shape) - 1).view(1, -1)

## warp_samples
warped_points = (points * fn_shape).long() # here, warped_points will not require grad due to the .long() call

## get monte carlo estimates for the current function
mc_samples = fn[warped_points[:, 0], warped_points[:, 1]]

## integrate by monte carlo
mc_integral = mc_samples.sum() / mc_samples.numel()
fn_integral = fn.sum() / fn.numel()
err = (mc_integral - fn_integral) ** 2

# obviously, here it breaks:
#  "element 0 of tensors does not require grad and does not have a grad_fn"
err.backward()
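A concrete reproduction (with random stand-ins for the elided `points` and `fn`) confirms the failure:

```python
import torch

# Random stand-ins for the elided data, just to reproduce the error.
points = torch.rand(64, 2, requires_grad=True)
fn = torch.rand(16, 16)

fn_shape = (torch.tensor(fn.shape) - 1).view(1, -1)
warped_points = (points * fn_shape).long()
mc_samples = fn[warped_points[:, 0], warped_points[:, 1]]

mc_integral = mc_samples.sum() / mc_samples.numel()
fn_integral = fn.sum() / fn.numel()
err = (mc_integral - fn_integral) ** 2

failed = False
try:
    err.backward()
except RuntimeError as exc:
    failed = True
    print(exc)  # element 0 of tensors does not require grad ...
print(failed)  # True
```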