Backprop through functional

I’m trying to backprop through a higher-order function (a function that takes a function as an argument), specifically a functional (a higher-order function that returns a scalar). Here is a simple example:

import torch

class Functional(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f):
        value = f(2)**2 - f(1)
        ctx.save_for_backward(value)
        return value
    @staticmethod
    def backward(ctx, grad_output):
        value = ctx.saved_tensors
        value.backward(grad_output)
        return None

class Function(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.tensor(0.)
    def forward(self, x):
        return self.a * x

function = Function()
functional = Functional.apply

value = functional(function)
value.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Unfortunately, this throws an error. How can I fix this?

Here is another example, a Monte Carlo estimate of the inner product between two functions on the unit hypercube:

import torch

dim = 2
samples = 100

class InnerProduct(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f, g):
        x = torch.rand(samples, dim)
        f_x = f(x)
        g_x = g(x)
        ctx.save_for_backward(f_x, g_x)
        return (f_x * g_x).mean(-1)
    @staticmethod
    def backward(ctx, grad_output):
        f_x, g_x = ctx.saved_tensors
        for f_x_i, g_x_i in zip(f_x, g_x):
            f_x_i.backward(g_x_i / samples * grad_output)
            g_x_i.backward(f_x_i / samples * grad_output)
        return None, None
inner_product = InnerProduct.apply

class ScalarFunction(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, 10),
            torch.nn.ELU(),
            torch.nn.Linear(10, 10),
            torch.nn.ELU(),
            torch.nn.Linear(10, 1),
            torch.nn.Flatten(-2, -1)
        )
    def forward(self, x):
        return self.net(x)

f = ScalarFunction()
g = ScalarFunction()
a = inner_product(f, g)
print(a)
a.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

There are a few things between your current code and something that works:

  • The autograd engine sees that none of the inputs require a gradient and decides that the entire backward pass can be skipped. You can add a dummy tensor argument with requires_grad=True to avoid that.
  • Inside torch.autograd.Function.forward you are implicitly in a torch.no_grad context, so you need with torch.enable_grad(): to re-enable gradient tracking there.
  • Your Function has no parameter that requires a gradient; without one the exercise seems futile (you would probably want to check for this and raise an exception, or skip the backward call).
  • You want to detach the value returned from forward to prevent bad things from happening™.
  • I will admit that I have reservations about calling .backward inside autograd.Functions, but this is just a feeling and I don’t have an alternative.

With these ideas we can produce the following, which runs and puts a gradient into function.a.grad:

import torch

class Functional(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f, dummy):
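        # forward runs with gradient tracking disabled, so re-enable it here
        # to record a graph through f's parameters for use in backward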
        with torch.enable_grad():
            value = f(2)**2 - f(1)
            ctx.save_for_backward(value)
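        # detach so the graph built above does not leak out through the output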
        return value.detach()
    @staticmethod
    def backward(ctx, grad_output):
        value, = ctx.saved_tensors
        value.backward(grad_output)
        return None, None

class Function(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.tensor(0., requires_grad=True)
    def forward(self, x):
        return self.a * x

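# dummy input that requires grad, so the engine does not skip backward entirely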
dummy = torch.empty((), requires_grad=True)
function = Function()
functional = lambda f: Functional.apply(f, dummy)

value = functional(function)
value.backward() 

The more important question, of course, is why you would want an autograd.Function if you don’t need the gradients of any inputs. The same result could be obtained, without all the complications, by using a plain Python function instead of a torch.autograd.Function:

def functional(f):
    return f(2)**2 - f(1)

would be a more correct, concise and less indirect way to do things.
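
For instance, reusing the fixed Function module from above (the one with requires_grad=True on self.a), plain autograd already delivers the gradient:

function = Function()
value = functional(function)
value.backward()
print(function.a.grad)  # d/da (4*a**2 - a) = 8*a - 1 = -1.0 at a = 0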

(lambda: print("Best regards"))()

Thomas

@tom The reason I’m doing this is that the forward pass takes a parametrized distribution as input and estimates a property of the distribution by sampling from it. This sampling means I can’t backpropagate directly to the parameters of the distribution. However, I know another estimator for the gradient of the property with respect to the distribution’s parameters, and that is what I use in the backward pass.
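
Concretely, a toy version of what I mean (the Gaussian, the property E[x**2], and the score-function/REINFORCE gradient estimator here are just stand-ins for my actual setup), using the dummy recipe from above:

import torch

class SampledExpectation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, distribution, dummy):
        x = distribution.sample((10000,))  # sampling breaks the autograd graph
        f_x = x ** 2
        ctx.save_for_backward(x, f_x)
        ctx.distribution = distribution
        return f_x.mean()
    @staticmethod
    def backward(ctx, grad_output):
        x, f_x = ctx.saved_tensors
        # score-function estimator: grad E[f(x)] = E[f(x) * grad log p(x)]
        with torch.enable_grad():
            surrogate = (f_x * ctx.distribution.log_prob(x)).mean()
        surrogate.backward(grad_output)
        return None, None

mu = torch.tensor(1.0, requires_grad=True)
distribution = torch.distributions.Normal(mu, 1.0)
dummy = torch.empty((), requires_grad=True)

value = SampledExpectation.apply(distribution, dummy)
value.backward()
print(mu.grad)  # noisy estimate of d/dmu E[x**2] = 2*mu = 2.0

The dummy is still needed here because the distribution object hides mu from the autograd engine.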

@tom Is there any way to avoid the dummy input? Perhaps some PyTorch setting to force computation of the backward pass?

You didn’t say what doesn’t work with a plain Python function.

@tom How would the backpropagation phase work through the plain function when the forward pass isn’t differentiable because it requires sampling?

Well, so

  • if you have inputs that are differentiable (and I take it you have), you can use these instead of the dummy parameter (see the sketch below),
  • if you don’t have inputs that are differentiable (which is why you need the dummy), you don’t need an autograd.Function.

What I think might make the examples easier is to move from passing in a module to passing in a (Python) function that doesn’t rely on global state.
Then you can have the parametrized function and its parameters as part of the inputs to functional, and it will work with the recipe above. I’m doing something like that (but not with sampling and taking derivatives of the density or so) in my Implicit Function notebook. (Nowadays you could do the backward computation that I do in the forward there in the backward instead, but I wrote that a looong time ago.)
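
For example, here is a rough sketch of that recipe (my quick illustration, not the notebook code): pass a plain function together with its parameter tensor into the functional, re-enable grad in forward, and use torch.autograd.grad in backward; then the dummy disappears:

import torch

class Functional(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f, a):
        # a is a real tensor input, so autograd will call backward for it
        with torch.enable_grad():
            a_ = a.detach().requires_grad_()
            value = f(a_, 2) ** 2 - f(a_, 1)
        ctx.save_for_backward(value, a_)
        return value.detach()
    @staticmethod
    def backward(ctx, grad_output):
        value, a_ = ctx.saved_tensors
        grad_a, = torch.autograd.grad(value, a_, grad_output)
        return None, grad_a

def f(a, x):  # plain function, parameter passed explicitly
    return a * x

a = torch.tensor(0., requires_grad=True)
value = Functional.apply(f, a)
value.backward()
print(a.grad)  # d/da (4*a**2 - a) = 8*a - 1 = -1.0

Because the parameter is now a real tensor input, the engine has something that requires grad, and backward is no longer skipped.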

Best regards

Thomas