Hi, I am trying to implement Deep Quaternion Networks. I was able to implement the batch normalization technique. But it requires a lot of GPU memory. Is there any way I can optimize the code provided below?
class MyQuaternionBatchNorm2d(torch.nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
super(MyQuaternionBatchNorm2d, self).__init__()
self.num_features = num_features
self.qnum_features = num_features//4
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = torch.nn.Parameter(torch.Tensor(self.qnum_features, 10))
self.bias = torch.nn.Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(self.qnum_features,4))
self.register_buffer('running_covar', torch.zeros(self.qnum_features,10))
self.running_covar[:,0] = 1/ np.sqrt(4)
self.running_covar[:,1] = 1/ np.sqrt(4)
self.running_covar[:,2] = 1/ np.sqrt(4)
self.running_covar[:,3] = 1/ np.sqrt(4)
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
else:
self.register_buffer('running_mean',None)
self.register_buffer('running_covar', None)
self.register_parameter('num_batches_tracked', None)
self.reset_parameters()
def reset_running_stats(self):
if self.track_running_stats:
self.running_mean.zero_()
self.running_covar.zero_()
self.running_covar[:,0] = 1/ np.sqrt(4)
self.running_covar[:,1] = 1/ np.sqrt(4)
self.running_covar[:,2] = 1/ np.sqrt(4)
self.running_covar[:,3] = 1/ np.sqrt(4)
self.num_batches_tracked.zero_()
def reset_parameters(self):
self.reset_running_stats()
if self.affine:
torch.nn.init.zeros_(self.weight)
torch.nn.init.constant_(self.weight[:,0], 1/ np.sqrt(4))
torch.nn.init.constant_(self.weight[:,4], 1/ np.sqrt(4))
torch.nn.init.constant_(self.weight[:,7], 1/ np.sqrt(4))
torch.nn.init.constant_(self.weight[:,9], 1/ np.sqrt(4))
torch.nn.init.zeros_(self.bias)
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
@staticmethod
def _decomposition_v1(r,i,j,k,Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk):
Wrr = torch.sqrt(Vrr)
Wri = (1.0 / Wrr) * (Vri)
Wii = torch.sqrt((Vii - (Wri.pow(2))))
Wrj = (1.0 / Wrr) * (Vrj)
Wij = (1.0 / Wii) * (Vij - (Wri*Wrj))
Wjj = torch.sqrt((Vjj - (Wij.pow(2) + Wrj.pow(2))))
Wrk = (1.0 / Wrr) * (Vrk)
Wik = (1.0 / Wii) * (Vik - (Wri*Wrk))
Wjk = (1.0 / Wjj) * (Vjk - (Wij*Wik + Wrj*Wrk))
Wkk = torch.sqrt((Vkk - (Wjk.pow(2) + Wik.pow(2) + Wrk.pow(2))))
cat_W_1 = torch.cat([Wrr, Wri, Wrj, Wrk])
cat_W_2 = torch.cat([Wri,Wii, Wij, Wik])
cat_W_3 = torch.cat([Wrj, Wij, Wjj, Wjk])
cat_W_4 = torch.cat([Wrk, Wik, Wjk, Wkk])
output = cat_W_1[None,:,None,None] * r.repeat(1,4,1,1) + cat_W_2[None,:,None,None] * i.repeat(1,4,1,1) \
+ cat_W_3[None,:,None,None] * j.repeat(1,4,1,1) + cat_W_4[None,:,None,None] * k.repeat(1,4,1,1)
return output
def forward(self, input):
self._check_input_dim(input)
r,i,j,k = torch.chunk(input, 4, dim=1)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# calculate running estimates
if self.training:
mean_r, mean_i, mean_j, mean_k = r.mean([0, 2, 3]),i.mean([0, 2, 3]),j.mean([0, 2, 3]),k.mean([0, 2, 3])
n = input.numel() / input.size(1)
mean = torch.stack((mean_r, mean_i, mean_j, mean_k), dim=1)
# update running mean
with torch.no_grad():
self.running_mean = exponential_average_factor * mean + (1 - exponential_average_factor) * self.running_mean
r = r-mean_r[None, :, None, None]
i = i-mean_i[None, :, None, None]
j = j-mean_j[None, :, None, None]
k = k-mean_k[None, :, None, None]
Vrr = (r.pow(2).mean([0, 2, 3])) + self.eps
Vii = (i.pow(2).mean([0, 2, 3])) + self.eps
Vjj = (j.pow(2).mean([0, 2, 3])) + self.eps
Vkk = (k.pow(2).mean([0, 2, 3])) + self.eps
Vri = ((r*i).mean([0, 2, 3]))
Vrj = ((r*j).mean([0, 2, 3]))
Vrk = ((r*k).mean([0, 2, 3]))
Vij = ((i*j).mean([0, 2, 3]))
Vik = ((i*k).mean([0, 2, 3]))
Vjk = ((j*k).mean([0, 2, 3]))
with torch.no_grad():
self.running_covar[:,0] = exponential_average_factor * Vrr * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,0]
self.running_covar[:,1] = exponential_average_factor * Vii * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,1]
self.running_covar[:,2] = exponential_average_factor * Vjj * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,2]
self.running_covar[:,3] = exponential_average_factor * Vkk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,3]
self.running_covar[:,4] = exponential_average_factor * Vri * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,4]
self.running_covar[:,5] = exponential_average_factor * Vrj * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,5]
self.running_covar[:,6] = exponential_average_factor * Vrk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,6]
self.running_covar[:,7] = exponential_average_factor * Vij * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,7]
self.running_covar[:,8] = exponential_average_factor * Vik * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,8]
self.running_covar[:,9] = exponential_average_factor * Vjk * n / (n - 1) + (1 - exponential_average_factor) * self.running_covar[:,9]
else:
mean = self.running_mean
Vrr = self.running_covar[:,0]+self.eps
Vii = self.running_covar[:,1]+self.eps
Vjj = self.running_covar[:,2]+self.eps
Vkk = self.running_covar[:,3]+self.eps
Vri = self.running_covar[:,4]+self.eps
Vrj = self.running_covar[:,5]+self.eps
Vrk = self.running_covar[:,6]+self.eps
Vij = self.running_covar[:,7]+self.eps
Vik = self.running_covar[:,8]+self.eps
Vjk = self.running_covar[:,9]+self.eps
r = r-mean[None,:,0,None,None]
i = i-mean[None,:,1,None,None]
j = j-mean[None,:,2,None,None]
k = k-mean[None,:,3,None,None]
# standardized_output
input = self._decomposition_v1(r,i,j,k, Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk)
if self.affine:
r,i,j,k = torch.chunk(input, 4, dim=1)
cat_gamma_1 = torch.cat([self.weight[:,0], self.weight[:,1], self.weight[:,2], self.weight[:,3]])
cat_gamma_2 = torch.cat([self.weight[:,1], self.weight[:,4], self.weight[:,5], self.weight[:,6]])
cat_gamma_3 = torch.cat([self.weight[:,2], self.weight[:,5], self.weight[:,7], self.weight[:,8]])
cat_gamma_4 = torch.cat([self.weight[:,3], self.weight[:,6], self.weight[:,8], self.weight[:,9]])
input = cat_gamma_1[None,:,None,None] * r.repeat(1,4,1,1) \
+ cat_gamma_2[None,:,None,None] * i.repeat(1,4,1,1) \
+ cat_gamma_3[None,:,None,None] * j.repeat(1,4,1,1) \
+ cat_gamma_4[None,:,None,None] * k.repeat(1,4,1,1) \
+ self.bias[None, :, None, None]
return input
Thank you,
Best,
Shreyas Kamath