Backward is not called in custom torch.autograd.Function for exponential moving average

I am working on an implementation that computes the exponential moving average (EMA) used for bias correction in the MINE estimator. However, the backward function in EMALoss is never called and no gradient is computed. Any idea why?

import math

import torch
import torch.nn as nn


class EMALoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, running_ema):
        ctx.save_for_backward(input, running_ema)
        input_log_sum_exp = input.exp().mean().log()

        return input_log_sum_exp

    @staticmethod
    def backward(ctx, grad_output):
        EPS = 1e-6
        input, running_mean = ctx.saved_tensors
        grad = grad_output * input.exp().detach() / \
            (running_mean + EPS) / input.shape[0]
        return grad, None


class MINE(nn.Module):

    def ema(self, mu, alpha, past_ema):
        return alpha * mu + (1.0 - alpha) * past_ema


    def ema_loss(self, x, running_mean, alpha):
        # Batch mean of exp(x), computed stably via logsumexp.
        t_exp = torch.exp(torch.logsumexp(x, 0) - math.log(x.shape[0]))

        if running_mean == 0:
            # First call: initialize the running mean with the batch estimate.
            running_mean = t_exp
        else:
            running_mean = self.ema(t_exp, alpha, running_mean.item())
        t_log = EMALoss.apply(x, running_mean)

        return t_log, running_mean

    def __init__(self, T, alpha=0.01):
        super().__init__()
        self.running_mean = 0
        self.alpha = alpha

    def forward(self, x, z, z_marg=None):
        if z_marg is None:
            z_marg = z[torch.randperm(x.shape[0])]

        t = self.T(x, z).mean()
        t_marg = self.T(x, z_marg)

        second_term, self.running_mean = self.ema_loss(
            t_marg, self.running_mean, self.alpha)

        return -t + second_term

Hi Tassilo!

I’m not sure how you might be calling your custom Function, so
maybe that’s the problem, but, on the surface, your code seems
to be working for me:

>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> class EMALoss(torch.autograd.Function):
...     @staticmethod
...     def forward(ctx, input, running_ema):
...         ctx.save_for_backward(input, running_ema)
...         input_log_sum_exp = input.exp().mean().log()
...
...         return input_log_sum_exp
...
...     @staticmethod
...     def backward(ctx, grad_output):
...         print ('EMALoss::backward: called')
...         EPS = 1e-6
...         input, running_mean = ctx.saved_tensors
...         grad = grad_output * input.exp().detach() / \
...             (running_mean + EPS) / input.shape[0]
...         return grad, None
...
>>> t = torch.randn (3, requires_grad = True)
>>> l = EMALoss().apply (t, torch.zeros (1))
>>> l
tensor(1.4002, grad_fn=<EMALossBackward>)
>>> l.backward()
EMALoss::backward: called
>>> t.grad
tensor([3282087.7500,  633012.5625,  140842.0469])

Best.

K. Frank

Thanks Frank. Well, maybe I should have been more explicit. I use MINE, a mutual information estimator, as a loss (x and z are inputs and T is an embedding network). In this context, EMALoss acts as a helper, and it does not work as expected.

Hi Tassilo!

It is not clear what you are doing here. (For example, the forward()
method of your MINE class uses self.T, but self.T is never defined.)

Could you simplify your code down to the bare essentials that reproduce
your issue and then post the simplified version as a complete, runnable
script?

Best.

K. Frank
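
For reference, a reduced script of the kind requested above might look like the sketch below. It is not the original poster's code: StatNet is a hypothetical stand-in for the undefined T, the tensor sizes are arbitrary, and the running mean is detached and initialized from the first batch as an assumption. It only checks that EMALoss.backward runs and that gradients reach the network's parameters in plain 32-bit training.

```python
import torch
import torch.nn as nn


class EMALoss(torch.autograd.Function):
    # Same forward/backward as in the question, plus a print in backward.
    @staticmethod
    def forward(ctx, input, running_ema):
        ctx.save_for_backward(input, running_ema)
        return input.exp().mean().log()

    @staticmethod
    def backward(ctx, grad_output):
        print("EMALoss::backward: called")
        EPS = 1e-6
        input, running_mean = ctx.saved_tensors
        grad = grad_output * input.exp().detach() / (running_mean + EPS) / input.shape[0]
        return grad, None


class StatNet(nn.Module):
    # Hypothetical statistics network T(x, z); not part of the original post.
    def __init__(self, x_dim, z_dim, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(x_dim + z_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x, z):
        return self.net(torch.cat([x, z], dim=1))


torch.manual_seed(0)
x, z = torch.randn(32, 4), torch.randn(32, 3)
T = StatNet(4, 3)

t = T(x, z).mean()
t_marg = T(x, z[torch.randperm(x.shape[0])])

# Assumption: first step, so the running mean is the detached batch estimate.
running_mean = t_marg.exp().mean().detach()

loss = -t + EMALoss.apply(t_marg, running_mean)
loss.backward()

# Every parameter of T should now carry a gradient.
grads = [p.grad for p in T.parameters()]
```

If the print statement fires here but not in the full training loop, the difference between the two setups (for example, mixed precision or a no_grad context) is the place to look.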

I observed that combining APEX 16-bit mixed precision with a custom torch.autograd.Function for the loss computation produces very weird behavior, e.g., inconsistent matrix caching and missing gradients. Training in 32-bit, or disabling autocasting with autocast(enabled=False), fixed the problem for me…
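
A minimal sketch of the second workaround, assuming PyTorch's native torch.autocast context manager (PyTorch 1.10 or later) rather than APEX itself: the custom Function is called inside a region where autocasting is disabled, with its inputs explicitly cast to float32. The helper name ema_loss_fp32 is hypothetical, and the example runs on CPU with bfloat16 only so that it is reproducible without a GPU.

```python
import torch


class EMALoss(torch.autograd.Function):
    # Same forward/backward as in the question, repeated to keep this self-contained.
    @staticmethod
    def forward(ctx, input, running_ema):
        ctx.save_for_backward(input, running_ema)
        return input.exp().mean().log()

    @staticmethod
    def backward(ctx, grad_output):
        EPS = 1e-6
        input, running_mean = ctx.saved_tensors
        grad = grad_output * input.exp().detach() / (running_mean + EPS) / input.shape[0]
        return grad, None


def ema_loss_fp32(t_marg, running_mean):
    # Hypothetical helper: leave the autocast region and upcast explicitly,
    # so both forward and backward of EMALoss see float32 tensors.
    with torch.autocast(device_type=t_marg.device.type, enabled=False):
        return EMALoss.apply(t_marg.float(), running_mean.float())


torch.manual_seed(0)
lin = torch.nn.Linear(4, 1)   # stand-in for the embedding network
x = torch.randn(8, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    t_marg = lin(x)  # computed in reduced precision under autocast
    running_mean = t_marg.detach().float().exp().mean()
    loss = ema_loss_fp32(t_marg, running_mean)

loss.backward()
```

The rest of the model still benefits from mixed precision; only the numerically sensitive loss term is kept in 32-bit.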