Batch normalization code optimization?

Hi, I am trying to implement Deep Quaternion Networks. I was able to implement the batch normalization technique. But it requires a lot of GPU memory. Is there any way I can optimize the code provided below?

class MyQuaternionBatchNorm2d(torch.nn.Module):
  def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
      super(MyQuaternionBatchNorm2d, self).__init__()
      self.num_features = num_features
      self.qnum_features = num_features//4
      self.eps = eps
      self.momentum = momentum
      self.affine = affine
      self.track_running_stats = track_running_stats
      
      if self.affine:
          self.weight = torch.nn.Parameter(torch.Tensor(self.qnum_features, 10))
          self.bias = torch.nn.Parameter(torch.Tensor(num_features))
      else:
          self.register_parameter('weight', None)
          self.register_parameter('bias', None)
          
      if self.track_running_stats:
          self.register_buffer('running_mean', torch.zeros(self.qnum_features,4))
          self.register_buffer('running_covar', torch.zeros(self.qnum_features,10))
          self.running_covar[:,0] = 1/ np.sqrt(4)
          self.running_covar[:,1] = 1/ np.sqrt(4)
          self.running_covar[:,2] = 1/ np.sqrt(4)
          self.running_covar[:,3] = 1/ np.sqrt(4)
          self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
      else:
          self.register_buffer('running_mean',None)
          self.register_buffer('running_covar', None)
          self.register_parameter('num_batches_tracked', None)
      self.reset_parameters()
      
  def reset_running_stats(self):
      if self.track_running_stats:
          self.running_mean.zero_()
          self.running_covar.zero_()
          self.running_covar[:,0] = 1/ np.sqrt(4)
          self.running_covar[:,1] = 1/ np.sqrt(4)
          self.running_covar[:,2] = 1/ np.sqrt(4)
          self.running_covar[:,3] = 1/ np.sqrt(4)
          self.num_batches_tracked.zero_()

  def reset_parameters(self):
      self.reset_running_stats()
      
      if self.affine:
          torch.nn.init.zeros_(self.weight)
          torch.nn.init.constant_(self.weight[:,0], 1/ np.sqrt(4))
          torch.nn.init.constant_(self.weight[:,4], 1/ np.sqrt(4))
          torch.nn.init.constant_(self.weight[:,7], 1/ np.sqrt(4))
          torch.nn.init.constant_(self.weight[:,9], 1/ np.sqrt(4))
          torch.nn.init.zeros_(self.bias)
          
  def _check_input_dim(self, input):
      if input.dim() != 4:
          raise ValueError('expected 4D input (got {}D input)'
                           .format(input.dim()))
  
  @staticmethod
  def _decomposition_v1(r,i,j,k,Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk):
      Wrr = torch.sqrt(Vrr)
      Wri = (1.0 / Wrr) * (Vri)
      Wii = torch.sqrt((Vii - (Wri.pow(2))))
      Wrj = (1.0 / Wrr) * (Vrj)
      Wij = (1.0 / Wii) * (Vij - (Wri*Wrj))
      Wjj = torch.sqrt((Vjj - (Wij.pow(2) + Wrj.pow(2))))
      Wrk = (1.0 / Wrr) * (Vrk)
      Wik = (1.0 / Wii) * (Vik - (Wri*Wrk))
      Wjk = (1.0 / Wjj) * (Vjk - (Wij*Wik + Wrj*Wrk))
      Wkk = torch.sqrt((Vkk - (Wjk.pow(2) + Wik.pow(2) + Wrk.pow(2))))
      
      cat_W_1 = torch.cat([Wrr, Wri, Wrj, Wrk])
      cat_W_2 = torch.cat([Wri,Wii, Wij, Wik])
      cat_W_3 = torch.cat([Wrj, Wij, Wjj, Wjk])
      cat_W_4 = torch.cat([Wrk, Wik, Wjk, Wkk])
      
      output =  cat_W_1[None,:,None,None]  *  r.repeat(1,4,1,1) + cat_W_2[None,:,None,None] *   i.repeat(1,4,1,1) \
                  + cat_W_3[None,:,None,None]  *   j.repeat(1,4,1,1) +  cat_W_4[None,:,None,None]  *  k.repeat(1,4,1,1)

      return output
  
  def forward(self, input):
      self._check_input_dim(input)
      r,i,j,k = torch.chunk(input, 4, dim=1)
      
      exponential_average_factor = 0.0

      if self.training and self.track_running_stats:
          if self.num_batches_tracked is not None:
              self.num_batches_tracked += 1
              if self.momentum is None:  # use cumulative moving average
                  exponential_average_factor = 1.0 / float(self.num_batches_tracked)
              else:  # use exponential moving average
                  exponential_average_factor = self.momentum

      # calculate running estimates
      if self.training:
          mean_r, mean_i, mean_j, mean_k = r.mean([0, 2, 3]),i.mean([0, 2, 3]),j.mean([0, 2, 3]),k.mean([0, 2, 3])
          n = input.numel() / input.size(1)
          mean = torch.stack((mean_r, mean_i, mean_j, mean_k), dim=1)
          # update running mean
          with torch.no_grad():
              self.running_mean = exponential_average_factor * mean + (1 - exponential_average_factor) * self.running_mean
                  
          r = r-mean_r[None, :, None, None]
          i = i-mean_i[None, :, None, None]
          j = j-mean_j[None, :, None, None]
          k = k-mean_k[None, :, None, None]
          
          Vrr = (r.pow(2).mean([0, 2, 3])) + self.eps
          Vii = (i.pow(2).mean([0, 2, 3])) + self.eps
          Vjj = (j.pow(2).mean([0, 2, 3])) + self.eps
          Vkk = (k.pow(2).mean([0, 2, 3])) + self.eps
          Vri = ((r*i).mean([0, 2, 3]))
          Vrj = ((r*j).mean([0, 2, 3]))
          Vrk = ((r*k).mean([0, 2, 3]))
          Vij = ((i*j).mean([0, 2, 3]))
          Vik = ((i*k).mean([0, 2, 3]))
          Vjk = ((j*k).mean([0, 2, 3])) 

          with torch.no_grad():
              self.running_covar[:,0] = exponential_average_factor * Vrr * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,0]
              self.running_covar[:,1] = exponential_average_factor * Vii * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,1]
              self.running_covar[:,2] = exponential_average_factor * Vjj * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,2]
              self.running_covar[:,3] = exponential_average_factor * Vkk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,3]
              
              self.running_covar[:,4] = exponential_average_factor * Vri * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,4]
              self.running_covar[:,5] = exponential_average_factor * Vrj * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,5]
              self.running_covar[:,6] = exponential_average_factor * Vrk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,6]
              self.running_covar[:,7] = exponential_average_factor * Vij * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,7]
              self.running_covar[:,8] = exponential_average_factor * Vik * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,8]
              self.running_covar[:,9] = exponential_average_factor * Vjk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,9]
      else:
          mean = self.running_mean
          Vrr = self.running_covar[:,0]+self.eps
          Vii = self.running_covar[:,1]+self.eps
          Vjj = self.running_covar[:,2]+self.eps
          Vkk = self.running_covar[:,3]+self.eps
          
          Vri = self.running_covar[:,4]+self.eps
          Vrj = self.running_covar[:,5]+self.eps
          Vrk = self.running_covar[:,6]+self.eps
          Vij = self.running_covar[:,7]+self.eps
          Vik = self.running_covar[:,8]+self.eps
          Vjk = self.running_covar[:,9]+self.eps
         
          r = r-mean[None,:,0,None,None]
          i = i-mean[None,:,1,None,None]
          j = j-mean[None,:,2,None,None]
          k = k-mean[None,:,3,None,None]
          
      # standardized_output
      input = self._decomposition_v1(r,i,j,k, Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk)
      
      if self.affine:
          r,i,j,k = torch.chunk(input, 4, dim=1)
          
          cat_gamma_1 = torch.cat([self.weight[:,0], self.weight[:,1], self.weight[:,2], self.weight[:,3]])
          cat_gamma_2 = torch.cat([self.weight[:,1], self.weight[:,4], self.weight[:,5], self.weight[:,6]])
          cat_gamma_3 = torch.cat([self.weight[:,2], self.weight[:,5], self.weight[:,7], self.weight[:,8]])
          cat_gamma_4 = torch.cat([self.weight[:,3], self.weight[:,6], self.weight[:,8], self.weight[:,9]])


          input =  cat_gamma_1[None,:,None,None] * r.repeat(1,4,1,1) \
                  + cat_gamma_2[None,:,None,None] * i.repeat(1,4,1,1) \
                  + cat_gamma_3[None,:,None,None] * j.repeat(1,4,1,1) \
                  + cat_gamma_4[None,:,None,None] * k.repeat(1,4,1,1) \
                  + self.bias[None, :, None, None]
      return input
  

Thank you,
Best,
Shreyas Kamath