autograd.Function for nested modules

Hi everyone.

I am trying to define a new class using autograd.Function (say f(x, g)).
This function takes a tensor x and an nn.Module (say g, with parameters theta).
I compute the output by applying some variation of g to x, and return it.

I am a little fuzzy on how to define the backward for this function.
How should I backpropagate through g, and what needs to be returned by the backward function?

class F(autograd.Function):

    @staticmethod
    def forward(ctx, x, g):
        with torch.enable_grad():
            x = x.clone().requires_grad_(True)
            z = g(x)
            ctx.save_for_backward(x, z)
            ctx._function = g
            
            return z
    
    @staticmethod
    def backward(ctx, output_grad):
        x, z = ctx.saved_tensors
        g = ctx._function
        
        with torch.enable_grad():
            z.backward(??)
            
            return x.grad*output_grad, ??

I know this setup may seem a bit absurd (e.g., why don't I just apply g directly to x and then backpropagate?), but it is helpful for my use case.

Thanks.

I would appreciate it if someone could help me out.

I was interested in this myself; the code below seems to work:

import torch
from torch import autograd, nn

class F(autograd.Function):

    @staticmethod
    def forward(ctx, x, g):
        with torch.enable_grad():
            x = x.clone().requires_grad_(True)
            z = g(x)
            ctx.save_for_backward(x, z)
            return z.detach()  # infinite backprop without detach()

    @staticmethod
    def backward(ctx, output_grad):
        x, z = ctx.saved_tensors
        g_x, = autograd.grad(z, x, output_grad)
        return g_x, None

x0 = torch.ones(3, requires_grad=True)
g = nn.Softplus()
y = F.apply(x0, g)
y.mean().backward()
print(x0.grad)

# sanity check against applying softplus directly
x2 = x0.detach().requires_grad_()
y2 = nn.functional.softplus(x2)
y2.mean().backward()
print(x2.grad)

Thank you very much for your response.
So, as far as I understand, your solution assumes the function g(.) is fixed, since you return None as its gradient. How can I also obtain gradients for g(.), assuming it has trainable parameters? (What should be returned instead of None, given that g(.) is a module, not a Tensor?)

Thanks again.

I forgot about that aspect. The hacky code below may work:

class F(autograd.Function):

    @staticmethod
    def forward(ctx, x, g):
        with torch.enable_grad():
            x = x.clone().requires_grad_(True)
            z = g(x)
            ctx.save_for_backward(x, z)
            ctx._module = g
            return z.detach()

    @staticmethod
    def backward(ctx, output_grad):
        x, z = ctx.saved_tensors
        module_pars = list(ctx._module.parameters())
        g_x, *g_pars = autograd.grad(z, [x] + module_pars, output_grad)
        # accumulate the parameter gradients by hand, mimicking what
        # autograd would do if the parameters were inputs to the Function
        for p, g_p in zip(module_pars, g_pars):
            if p.grad is None:
                p.grad = g_p
            else:
                p.grad += g_p
        return g_x, None

A safer alternative is to extract parameters() outside and pass them in via *args; it is a bit more work to write.
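To make that concrete, here is one possible sketch of the *args variant (the choice of nn.Linear for g is just an illustration). The parameters become explicit inputs to the Function, so backward can return their gradients and let autograd handle the accumulation instead of writing to p.grad by hand:

```python
import torch
from torch import autograd, nn

class F(autograd.Function):

    @staticmethod
    def forward(ctx, x, g, *params):
        # build a graph for g even if we are inside a no_grad context
        with torch.enable_grad():
            x = x.clone().requires_grad_(True)
            z = g(x)
        # params are now inputs to the Function, so saving them is safe
        ctx.save_for_backward(x, z, *params)
        return z.detach()

    @staticmethod
    def backward(ctx, output_grad):
        x, z, *params = ctx.saved_tensors
        grads = autograd.grad(z, [x] + list(params), output_grad)
        # one gradient per forward input: x, then g (a module, so None),
        # then one per parameter; autograd accumulates these into .grad
        return (grads[0], None, *grads[1:])

x0 = torch.ones(3, requires_grad=True)
g = nn.Linear(3, 3)
y = F.apply(x0, g, *g.parameters())
y.mean().backward()
print(x0.grad)        # gradient w.r.t. the input
print(g.weight.grad)  # filled in by autograd, no manual accumulation
```

Since the parameters are ordinary leaf inputs here, this version also plays well with things like gradient checking and double backward, which the p.grad-writing version does not.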