I am new to pytorch and I am implementing the paper below. The main algorithm works fine but I am struggling to implement the gradient bias correction in section 3.2. Writing a custom torch.autograd.Function and adding the running exponential moving average to the context seems to be way to do it but I am getting the following error on the backward call “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”.
class LogMeanExp_Unbiased(torch.autograd.Function):
@staticmethod
def forward(ctx, input, running_ema, alpha=0.1):
yt = torch.exp(input).mean()
yt.requires_grad = True
if running_ema == 0:
running_ema = yt
else:
running_ema = alpha * yt + (1-alpha)*running_ema.item()
ctx.input = input
ctx.yt = yt
ctx.running_ema = running_ema
return yt.log()
@staticmethod
def backward(ctx, grad_output):
return (ctx.input * ctx.yt).backward() / ctx.running_ema, None, None