I want to implement a PyTorch `Function` that takes another function and a tensor as arguments, applies the function, and on the backward pass matches the gradient variances between the input and the output. Something like the following code:

```
import torch
from torch.autograd import Function

def f(x):
    # Is meant to be arbitrary.
    return x * 2

class VarianceMatch(Function):
    @staticmethod
    def forward(ctx, x):
        y = f(x)
        y_grad_fn = y.grad_fn  # Results in None
        ctx.save_for_backward(x, y, y_grad_fn)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, y, y_grad_fn = ctx.saved_tensors
        y_grad_fn(ctx, grad_y)
        # Rescale the input gradient so its second moment matches grad_y's.
        x.grad *= torch.sqrt(torch.square(grad_y).sum(1) / torch.square(x.grad).sum(1))
        return None

i = torch.scalar_tensor(2, requires_grad=True)
x = VarianceMatch.apply(i)
x.backward(torch.scalar_tensor(1))
print(i.grad)
```

The trouble is that during the forward pass the backward function never gets instantiated, so I cannot extract it and save it for the backward pass. Is it possible to enable this somehow? And what would be the ideal way to implement something like this?
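One workaround I have been experimenting with (not sure it is the idiomatic way): since autograd is disabled inside `forward`, re-enable it with `torch.enable_grad()` on a detached copy of the input, save the resulting graph, and then call `torch.autograd.grad` in `backward` instead of a stored `grad_fn`. A minimal sketch, using `.sum()` rather than `.sum(1)` since the toy example is a scalar:

```
import torch
from torch.autograd import Function

def f(x):
    # Stands in for the arbitrary wrapped function.
    return x * 2

class VarianceMatch(Function):
    @staticmethod
    def forward(ctx, x):
        # Autograd is off inside forward, so build the graph explicitly
        # on a detached copy of the input.
        with torch.enable_grad():
            x_det = x.detach().requires_grad_(True)
            y = f(x_det)
        ctx.save_for_backward(x_det, y)
        # Return a detached tensor so the inner graph stays private.
        return y.detach()

    @staticmethod
    def backward(ctx, grad_y):
        x_det, y = ctx.saved_tensors
        # Differentiate the saved graph instead of calling a grad_fn directly.
        (grad_x,) = torch.autograd.grad(y, x_det, grad_y)
        # Rescale so the input gradient's second moment matches grad_y's.
        scale = torch.sqrt(torch.square(grad_y).sum() / torch.square(grad_x).sum())
        return grad_x * scale

i = torch.scalar_tensor(2.0, requires_grad=True)
x = VarianceMatch.apply(i)
x.backward(torch.scalar_tensor(1.0))
print(i.grad)
```

For `f(x) = 2 * x` with an upstream gradient of `1`, the raw input gradient is `2`, and the rescaling by `sqrt(1/4)` brings it back to `1`, so `i.grad` comes out as `tensor(1.)`. I am not certain this is the intended pattern, though, or whether keeping the inner graph alive via `save_for_backward` has drawbacks.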

This question is related to the SO one I asked recently.