How to implement a function whose input and output vectors have the same gradient variance?

I want to implement a PyTorch Function that takes another function and a tensor as arguments, applies the function to the tensor on the forward pass, and on the backward pass rescales the gradient so that the gradient variances of the input and the output match. Something like the following code.

import torch
from torch.autograd import Function

# f is meant to be an arbitrary function.
def f(x):
    return x * 2

class VarianceMatch(Function):
    @staticmethod
    def forward(ctx, x):
        y = f(x)
        y_grad_fn = y.grad_fn # Results in None; autograd is disabled inside forward.
        ctx.save_for_backward(x, y, y_grad_fn)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, y, y_grad_fn = ctx.saved_tensors
        y_grad_fn(ctx, grad_y)
        # Rescale so the variance of the input gradient matches that of the output gradient.
        x.grad *= torch.sqrt(torch.square(grad_y).sum(1) / torch.square(x.grad).sum(1))
        return None

i = torch.scalar_tensor(2, requires_grad=True)
x = VarianceMatch.apply(i)
x.backward(torch.scalar_tensor(1))
print(i.grad)

The trouble is that during the forward pass of a custom Function the backward graph does not get built, so y.grad_fn is None and I cannot extract and save it for the backward pass. Would it be possible to enable it somehow? Also, what would be the ideal way to implement something like this?
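For what it is worth, one direction that seems like it might get around this is to re-run f under torch.enable_grad() inside forward, so that a graph for it exists, and then call torch.autograd.grad on that saved graph in backward. A rough sketch of that idea (untested beyond the toy case at the bottom; the x_detached name is just mine, and I am assuming the variances should be matched per row over dim 1):

import torch
from torch.autograd import Function

# f is meant to be an arbitrary function.
def f(x):
    return x * 2

class VarianceMatch(Function):
    @staticmethod
    def forward(ctx, x):
        # Autograd is disabled while forward runs, so re-enable it on a
        # detached copy of x to build a graph for f that backward can reuse.
        with torch.enable_grad():
            x_detached = x.detach().requires_grad_()
            y = f(x_detached)
        ctx.save_for_backward(x_detached, y)
        return y.detach()

    @staticmethod
    def backward(ctx, grad_y):
        x_detached, y = ctx.saved_tensors
        # Differentiate through the saved graph of f to get the raw input gradient.
        grad_x, = torch.autograd.grad(y, x_detached, grad_y)
        # Rescale it so its per-row variance matches that of grad_y.
        scale = torch.sqrt(torch.square(grad_y).sum(1, keepdim=True)
                           / torch.square(grad_x).sum(1, keepdim=True))
        return grad_x * scale

i = torch.randn(3, 4, requires_grad=True)
y = VarianceMatch.apply(i)
y.backward(torch.ones_like(y))
print(i.grad)

Is something along these lines the intended pattern, or is there a cleaner way?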

This question is related to the SO one I asked recently. My second attempt tries to hack around the problem with hooks:

import torch
from torch.autograd import Function

class VarianceMatch(Function):
    @staticmethod
    def forward(ctx, x):
        y: torch.Tensor = x * 2
        y.requires_grad_()
        y.retain_grad() # Does not work.
        def h(grad): # y.grad is None so this gives me an exception.
            return grad * torch.sqrt(torch.square(y.grad).sum(1) / torch.square(grad).sum(1))
        x.register_hook(h)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        return grad_y * 2

i = torch.scalar_tensor(2, requires_grad=True)
x = VarianceMatch.apply(i)
x.backward(torch.scalar_tensor(1))
print(i.grad)

This hook hack is not working for me either: y.grad never gets populated, so the hook on x throws an exception.
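The only other variant I can think of is to drop the custom Function entirely and do everything with hooks on the plain autograd graph: a hook on y captures grad_y first, and a hook on x then rescales the incoming gradient. A minimal sketch of that idea (also untested beyond the toy case; the variance_match name and the captured dict are just mine, again assuming per-row sums over dim 1):

import torch

def f(x):
    return x * 2

def variance_match(f, x):
    # Let autograd build the graph for f normally, then fix up the
    # gradient variance with tensor hooks.
    captured = {}
    y = f(x)

    def capture(grad_y):
        # Fires before the hook on x, since backward reaches y first.
        captured['grad_y'] = grad_y

    def rescale(grad_x):
        grad_y = captured['grad_y']
        # Rescale so the per-row variance of grad_x matches that of grad_y.
        return grad_x * torch.sqrt(torch.square(grad_y).sum(1, keepdim=True)
                                   / torch.square(grad_x).sum(1, keepdim=True))

    y.register_hook(capture)
    x.register_hook(rescale)
    return y

i = torch.randn(3, 4, requires_grad=True)
y = variance_match(f, i)
y.backward(torch.ones_like(y))
print(i.grad)

Even if this works, as far as I can tell the hook on x would rescale every gradient contribution flowing into x, not just the one coming through f, and I am not sure it is the right way to do this. At this point I am out of ideas.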