Custom gradient bias correction based on exponential moving average

I am new to PyTorch and I am implementing the paper below. The main algorithm works fine, but I am struggling to implement the gradient bias correction in section 3.2. Writing a custom torch.autograd.Function and adding the running exponential moving average to the context seems to be the way to do it, but I am getting the following error on the backward call: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn".

class LogMeanExp_Unbiased(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, running_ema, alpha=0.1):
        yt = torch.exp(input).mean()
        yt.requires_grad = True
        if running_ema == 0:
            running_ema = yt
        else:
            running_ema = alpha * yt + (1-alpha)*running_ema.item()

        ctx.input = input
        ctx.yt = yt
        ctx.running_ema = running_ema
        return yt.log()

    @staticmethod
    def backward(ctx, grad_output):
        return (ctx.input * ctx.yt).backward() / ctx.running_ema, None, None

I think you do not need the .backward() call inside the backward method; backward should just return the gradient tensors.

    @staticmethod
    def backward(ctx, grad_output):
        return (ctx.input * ctx.yt) / ctx.running_ema, None, None

I did not read the paper, but this will at least resolve the error.

Thank you for your reply, InnovArul.

Edit: In case someone else is looking to do something similar, the solution seems to be:

import torch


class ExpMeanLog(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, running_ema):
        # Save the input and the externally maintained EMA of exp(input).mean()
        # so the backward pass can use the EMA as the denominator.
        ctx.save_for_backward(input, running_ema)
        return input.exp().mean().log()

    @staticmethod
    def backward(ctx, grad_output):
        input, running_ema = ctx.saved_tensors
        # Bias-corrected gradient: exp(input) is divided by the running EMA
        # instead of the current batch mean, then averaged over the batch.
        return grad_output * input.exp() / running_ema / input.shape[0], None
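
For anyone copying this, here is a minimal sketch of how the Function could be wired into a training step, assuming the EMA is updated (detached) outside the Function before each forward call; the network, optimizer, alpha, and loss below are illustrative placeholders, not the full objective from the paper:

import torch

# Illustrative setup: a small critic and optimizer stand in for the actual model.
net = torch.nn.Linear(10, 1)
opt = torch.optim.SGD(net.parameters(), lr=1e-3)

running_ema = torch.tensor(0.0)
alpha = 0.1

for _ in range(100):
    x = torch.randn(64, 10)
    scores = net(x).squeeze(-1)

    # Update the EMA of exp(scores).mean() outside the Function, detached so
    # the EMA itself does not become part of the autograd graph.
    batch_mean = scores.exp().mean().detach()
    if running_ema == 0:
        running_ema = batch_mean
    else:
        running_ema = alpha * batch_mean + (1 - alpha) * running_ema

    loss = -ExpMeanLog.apply(scores, running_ema)

    opt.zero_grad()
    loss.backward()
    opt.step()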