Thank you very much for this great code to reproduce this issue!
Indeed, the memory usage is growing with each epoch.
After looking into the code, I think the reason is that you are unintentionally tracking the computation graph in self.running_mean and self.running_covar.
This happens when you assign a value that still carries a grad_fn to these tensors:

self.running_mean = exponential_average_factor * mean \
    + (1 - exponential_average_factor) * self.running_mean  # mean is holding onto the computation graph
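To see the effect in isolation, here is a minimal sketch of the mechanism (a plain tensor stands in for the buffer; the shapes and the momentum factor are made up, not taken from your module):

import torch

running_mean = torch.zeros(4)

for step in range(3):
    x = torch.randn(8, 4, requires_grad=True)
    mean = x.mean(dim=0)                            # mean carries a grad_fn
    running_mean = 0.1 * mean + 0.9 * running_mean  # so the new running_mean does, too
    print(step, running_mean.grad_fn)               # prints e.g. <AddBackward0 object ...>

Since every update references the previous running_mean, the graphs of all past batches are chained together and can never be freed, so the memory usage grows with every iteration.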
If you wrap the update code in a torch.no_grad() block, the memory footprint stays constant:
class ComplexBatchNorm2D(_ComplexBatchNorm):
    def forward(self, input_r, input_i):
        assert(input_r.size() == input_i.size())
        assert(len(input_r.shape) == 4)
        # self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training:
            # calculate mean of real and imaginary part
            mean_r = input_r.mean([0, 2, 3])
            mean_i = input_i.mean([0, 2, 3])
            mean = torch.stack((mean_r, mean_i), dim=1)

            with torch.no_grad():
                # update running mean
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean

            # works for 2d
            input_r = input_r - mean_r[None, :, None, None]
            input_i = input_i - mean_i[None, :, None, None]

            # Elements of the covariance matrix (biased for train)
            n = input_r.numel() / input_r.size(1)
            Crr = 1. / n * input_r.pow(2).sum(dim=[0, 2, 3]) + self.eps
            Cii = 1. / n * input_i.pow(2).sum(dim=[0, 2, 3]) + self.eps
            Cri = (input_r.mul(input_i)).mean(dim=[0, 2, 3])

            with torch.no_grad():
                # update running covariance
                self.running_covar[:, 0] = exponential_average_factor * Crr * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_covar[:, 0]
                self.running_covar[:, 1] = exponential_average_factor * Cii * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_covar[:, 1]
                self.running_covar[:, 2] = exponential_average_factor * Cri * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_covar[:, 2]
        else:
            mean = self.running_mean
            Crr = self.running_covar[:, 0] + self.eps
            Cii = self.running_covar[:, 1] + self.eps
            Cri = self.running_covar[:, 2]  # + self.eps
            input_r = input_r - mean[None, :, 0, None, None]
            input_i = input_i - mean[None, :, 1, None, None]

        # calculate the inverse square root of the covariance matrix
        det = Crr * Cii - Cri.pow(2)
        s = torch.sqrt(det)
        t = torch.sqrt(Cii + Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st

        input_r, input_i = Rrr[None, :, None, None] * input_r + Rri[None, :, None, None] * input_i, \
                           Rii[None, :, None, None] * input_i + Rri[None, :, None, None] * input_r

        if self.affine:
            input_r, input_i = (self.weight[None, :, 0, None, None] * input_r
                                + self.weight[None, :, 2, None, None] * input_i
                                + self.bias[None, :, 0, None, None]), \
                               (self.weight[None, :, 2, None, None] * input_r
                                + self.weight[None, :, 1, None, None] * input_i
                                + self.bias[None, :, 1, None, None])

        return input_r, input_i
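Detaching the right-hand side before the assignment (e.g. (...).detach()) would avoid the graph tracking as well. For completeness, here is a minimal sketch (a made-up RunningMean module, not your layer) of roughly the pattern the built-in batchnorm layers use: register the statistic as a buffer and update it in-place inside torch.no_grad(), so it never stores a grad_fn:

import torch
import torch.nn as nn

class RunningMean(nn.Module):
    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.momentum = momentum
        self.register_buffer("running_mean", torch.zeros(num_features))

    def forward(self, x):
        if self.training:
            batch_mean = x.mean(dim=0)
            with torch.no_grad():
                # in-place update: the buffer keeps its storage, dtype and device
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
            return x - batch_mean
        return x - self.running_mean

m = RunningMean(4).train()
for _ in range(3):
    m(torch.randn(8, 4, requires_grad=True)).sum().backward()
print(m.running_mean.grad_fn)  # None -> no graph is kept alive between iterations

The assignment inside torch.no_grad() as in your code works just as well; the in-place version only avoids creating a new tensor in every forward pass.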