Error with custom batch norm

Hi everyone,

I am trying to implement a custom version of BatchNorm with a hand-written backward pass, but when I use it during training I get the error “Trying to backward through the graph a second time”. A simplified version of my training loop is at the bottom of the post.

Could anyone help me with this?

My autograd Function:

import torch
from torch.autograd import Function


class _batch_norm_function(Function):

    @staticmethod
    def forward(x, mean, variance):
        # Normalize x with the statistics passed in from the module.
        EPS = 1e-12
        return (x - mean) / torch.sqrt(variance + EPS)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, mean, variance = inputs
        ctx.save_for_backward(x, mean, variance)

    @staticmethod
    def backward(ctx, grad_output):
        EPS = 1e-12
        inp, mean, var = ctx.saved_tensors
        grad_input = None

        if ctx.needs_input_grad[0]:
            N, D = inp.shape
            std = torch.sqrt(var + EPS)
            xmu = inp - mean

            # Direct path through the normalized output.
            dxmu1 = grad_output / std
            # Path through the (biased) batch variance.
            dxmu2 = 2 * xmu * (1 / N) * (1 / (2 * std)) * torch.sum(-xmu / std ** 2 * grad_output, dim=0)
            dx1 = dxmu1 + dxmu2
            # Path through the batch mean.
            dmu = -torch.sum(dxmu1 + dxmu2, dim=0)
            dx2 = dmu / N
            grad_input = dx1 + dx2

        # No gradients for mean and variance.
        return grad_input, None, None


def bn_function(x, mean, variance):
    return _batch_norm_function.apply(x, mean, variance)
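
For context, the backward is meant to be the standard batch-norm gradient, treating mean and variance as the (biased) batch statistics of x. If I did the algebra right, per feature it should reduce to the usual closed form

$$
\frac{\partial L}{\partial x_i} = \frac{1}{N\sigma}\Big(N g_i - \sum_{j=1}^{N} g_j - \hat{x}_i \sum_{j=1}^{N} g_j \hat{x}_j\Big),
\qquad \hat{x}_i = \frac{x_i - \mu}{\sigma},\quad \sigma = \sqrt{\mathrm{Var}(x) + \epsilon},\quad g_i = \texttt{grad\_output}_i .
$$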

My Module:

import torch
import torch.nn as nn
import torch.distributed as dist


class MySyncBatchNorm1D(nn.Module):
    def __init__(self, alpha=0.0):
        super().__init__()
        self.training = True
        self.alpha = alpha
        self.moving_mean = None
        self.moving_variance = None

    def sync(self):
        # Average the running statistics across all processes.
        if self.training and dist.is_available() and dist.is_initialized():
            size = float(dist.get_world_size())
            dist.all_reduce(self.moving_mean, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.moving_variance, op=dist.ReduceOp.SUM)
            self.moving_variance /= size
            self.moving_mean /= size
        else:
            raise RuntimeError("sync() is only supported during training with distributed training initialized")

    def forward(self, x):
        n_samples, dim = x.shape

        # Per-feature statistics of the current batch.
        mean = torch.mean(x, dim=0)
        variance = torch.var(x, dim=0)

        if self.moving_mean is None or self.moving_variance is None:
            # First batch: initialize the running statistics.
            self.moving_mean = mean
            self.moving_variance = variance
        elif self.training:
            # Exponential moving average of the batch statistics.
            self.moving_mean = self.alpha * self.moving_mean + (1 - self.alpha) * mean
            self.moving_variance = self.alpha * self.moving_variance + (1 - self.alpha) * variance

        return bn_function(x, self.moving_mean, self.moving_variance)

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)
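
And here is roughly how I am using it, stripped down to a toy single-process example (the linear layer, the optimizer, and the random data are just placeholders for my real setup):

# Minimal sketch of my training loop (placeholder model and data).
encoder = nn.Linear(16, 8)
bn = MySyncBatchNorm1D(alpha=0.9)
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)

bn.train()
for step in range(5):
    x = torch.randn(32, 16)   # placeholder batch
    out = bn(encoder(x))      # normalize the encoder output
    loss = out.pow(2).mean()  # placeholder loss

    optimizer.zero_grad()
    loss.backward()           # the error is raised here on the second iteration
    optimizer.step()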