Converting BatchNorm2d to BatchNorm3d

I have a pretrained network containing BatchNorm2d layers. I want to inflate the network to 3D, i.e. replicate the spatial filters along the temporal dimension to convert the 2D CNN into a 3D CNN, and similarly convert each BatchNorm2d into a BatchNorm3d.

Basically, you can assume that if the output of the model on a single image is ‘x’, then the output on k stacked copies of that image should be ‘x’ stacked k times.

I inflated conv layers by manually copying the weights and it worked pretty well using the function below

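def inflateconv(conv3d, conv):
	# Repeat the 2D kernel along the temporal axis and divide by the temporal
	# kernel size, so that identical stacked frames give the 2D response.
	kT = conv3d.weight.data.shape[2]
	conv3d.weight.data = (conv.weight.data[:,:,None,:,:].expand(conv3d.weight.data.size()) / kT).contiguous()
	# Copy the bias only if the 2D layer has one (bias=False leaves it as None).
	if conv.bias is not None:
		conv3d.bias.data = conv.bias.data.clone()
	return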
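
A quick way to check the inflation is to feed the 3D layer k copies of the same frame and compare against the 2D layer. The following is a minimal sketch, assuming PyTorch is installed and a temporal kernel of size 3 with no temporal padding. `inflate_conv_weights` is a hypothetical helper that mirrors the function above; it is not the original code.

```python
import torch
import torch.nn as nn

def inflate_conv_weights(conv3d, conv2d):
    # Hypothetical helper: repeat the 2D kernel along time, scaled by 1/k_t.
    k_t = conv3d.weight.shape[2]
    with torch.no_grad():
        inflated = conv2d.weight[:, :, None, :, :].expand_as(conv3d.weight) / k_t
        conv3d.weight.copy_(inflated)
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)

torch.manual_seed(0)
conv2d = nn.Conv2d(3, 4, kernel_size=3, padding=1)
# Assumption: pad only the spatial axes, not the temporal one.
conv3d = nn.Conv3d(3, 4, kernel_size=(3, 3, 3), padding=(0, 1, 1))
inflate_conv_weights(conv3d, conv2d)

frame = torch.randn(1, 3, 8, 8)
video = frame.unsqueeze(2).repeat(1, 1, 5, 1, 1)  # 5 identical frames: (1, 3, 5, 8, 8)

with torch.no_grad():
    out2d = conv2d(frame)  # (1, 4, 8, 8)
    out3d = conv3d(video)  # (1, 4, 3, 8, 8)

# Every temporal slice of the 3D output should match the 2D output.
max_diff = (out3d - out2d.unsqueeze(2)).abs().max().item()
print(max_diff)
```

If the 3D layer zero-pads the temporal axis, the first and last output frames also see padded frames and will not match the 2D output.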

I wrote a similar function for BatchNorm, but it doesn't work as expected:

def inflatebn(bn3d, bn):
	bn3d.weight.data = bn.weight.data
	bn3d.bias.data = bn.bias.data
	bn3d.running_mean = bn.running_mean
	bn3d.running_var = bn.running_var
	bn3d.weight.data = bn3d.weight.data.contiguous()
	bn3d.bias.data = bn3d.bias.data.contiguous()
	bn3d.running_mean = bn3d.running_mean.contiguous()
	bn3d.running_var = bn3d.running_var.contiguous()
	return

But the BatchNorm2d output differs from the BatchNorm3d output!
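
For comparison, here is a minimal sketch of one way to copy a BatchNorm2d into a BatchNorm3d and compare the two in eval mode on k copies of one frame. It assumes PyTorch is installed and that both layers keep the default `track_running_stats=True`; `inflate_bn` is a hypothetical name. Besides the parameters and running statistics, it also copies `eps` and `momentum`, which are plain attributes and not part of the state dict. Whether any of this explains the mismatch described above is not established here.

```python
import torch
import torch.nn as nn

def inflate_bn(bn2d):
    # Hypothetical helper: build a BatchNorm3d with the same hyperparameters,
    # then copy weight, bias, running_mean, running_var via the state dict.
    bn3d = nn.BatchNorm3d(bn2d.num_features, eps=bn2d.eps,
                          momentum=bn2d.momentum, affine=bn2d.affine)
    bn3d.load_state_dict(bn2d.state_dict())
    return bn3d

torch.manual_seed(0)
bn2d = nn.BatchNorm2d(4)
# Give the 2D layer non-trivial state, standing in for a pretrained layer.
with torch.no_grad():
    bn2d.weight.uniform_(0.5, 1.5)
    bn2d.bias.uniform_(-1.0, 1.0)
    bn2d.running_mean.uniform_(-1.0, 1.0)
    bn2d.running_var.uniform_(0.5, 2.0)

bn3d = inflate_bn(bn2d)
bn2d.eval()
bn3d.eval()

frame = torch.randn(2, 4, 8, 8)
video = frame.unsqueeze(2).repeat(1, 1, 5, 1, 1)  # (2, 4, 5, 8, 8)

with torch.no_grad():
    out2d = bn2d(frame)
    out3d = bn3d(video)

# In eval mode each temporal slice should equal the 2D output.
max_diff = (out3d - out2d.unsqueeze(2)).abs().max().item()
print(max_diff)
```

In training mode both layers normalize with the statistics of the current batch instead of the stored running estimates, so a comparison like this is only meaningful after calling `.eval()` on both.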

In case someone stumbles upon this, you may find this discussion useful.

Also, I finally switched to reshaping the input to a 4D tensor for the original BatchNorm2d and reshaping it back to 5D afterwards.

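import torch.nn as nn

class myBatchNorm3D(nn.Module):
	"""Applies a BatchNorm2d to every frame of a 5D (N,C,D,H,W) input."""
	def __init__(self, inChannels):
		super(myBatchNorm3D, self).__init__()
		self.inChannels = inChannels
		self.bn = nn.BatchNorm2d(self.inChannels)

	def forward(self, input):
		N,C,D,H,W = input.size()
		# Fold the temporal axis into the batch axis: (N,C,D,H,W) -> (N*D,C,H,W).
		# permute is used because Tensor.t() rejects tensors with more than 2 dimensions.
		out = input.permute(0,2,1,3,4).reshape(N*D,C,H,W)
		out = self.bn(out)
		# Restore the temporal axis: (N*D,C,H,W) -> (N,C,D,H,W).
		out = out.reshape(N,D,C,H,W).permute(0,2,1,3,4).contiguous()
		return out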

If it is part of some ‘model’ as member bn, just do model.bn.bn = ‘old bn to inflate’, i.e. assign the original BatchNorm2d layer itself. When I only copied the weights, biases and running estimates of the 2D layer, the values were off by a margin.
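
The assignment described above could look like the following minimal sketch. It assumes PyTorch is installed; `FrameWiseBatchNorm` and `pretrained_bn` are hypothetical names standing in for the wrapper and for a layer taken from the pretrained 2D network. The wrapper folds the temporal axis into the batch axis, so it is a variant of the idea and not the exact code from this thread.

```python
import torch
import torch.nn as nn

class FrameWiseBatchNorm(nn.Module):
    """Hypothetical wrapper: runs a BatchNorm2d on each frame of (N,C,D,H,W)."""
    def __init__(self, bn2d):
        super().__init__()
        self.bn = bn2d  # reuse the pretrained layer object itself

    def forward(self, x):
        n, c, d, h, w = x.shape
        # (N,C,D,H,W) -> (N*D,C,H,W): treat every frame as a batch element.
        out = x.permute(0, 2, 1, 3, 4).reshape(n * d, c, h, w)
        out = self.bn(out)
        # (N*D,C,H,W) -> (N,C,D,H,W)
        return out.reshape(n, d, c, h, w).permute(0, 2, 1, 3, 4).contiguous()

torch.manual_seed(0)
pretrained_bn = nn.BatchNorm2d(4)  # stand-in for a layer from a pretrained model
with torch.no_grad():
    pretrained_bn.running_mean.uniform_(-1.0, 1.0)
    pretrained_bn.running_var.uniform_(0.5, 2.0)

wrapper = FrameWiseBatchNorm(pretrained_bn)
wrapper.eval()  # also puts the wrapped BatchNorm2d in eval mode

video = torch.randn(2, 4, 5, 8, 8)
with torch.no_grad():
    out = wrapper(video)
    # Reference: apply the 2D layer to each frame separately and restack.
    per_frame = torch.stack(
        [pretrained_bn(video[:, :, t]) for t in range(video.shape[2])], dim=2
    )

max_diff = (out - per_frame).abs().max().item()
print(max_diff)
```

Because the wrapper holds the original layer object, its `eps`, `momentum` and running statistics come along without any copying.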