Batch Whitening (BW) instead of Batch Normalization

Hi,
I'm new to PyTorch and I want to test batch whitening (BW) as defined in https://github.com/roysubhankar/dwt-domain-adaptation to train a model from scratch for some domain adaptation. I don't know how to define my BW layer: how do I get the `running_mean` and `running_variance`?

```python
import torch
import torch.nn as nn

import whitening  # whitening.py from the dwt-domain-adaptation repo


class whitening_scale_shift(nn.Module):
    def __init__(self, planes, group_size, running_mean, running_variance,
                 track_running_stats=True, affine=True):
        super(whitening_scale_shift, self).__init__()
        self.planes = planes
        self.group_size = group_size
        self.track_running_stats = track_running_stats
        self.affine = affine
        self.running_mean = running_mean
        self.running_variance = running_variance

        # the actual whitening transform from the repo
        self.wh = whitening.WTransform2d(self.planes, self.group_size,
                                         running_m=self.running_mean,
                                         running_var=self.running_variance,
                                         track_running_stats=self.track_running_stats)
        if self.affine:
            # learnable per-channel scale and shift, like the affine part of BatchNorm
            self.gamma = nn.Parameter(torch.ones(self.planes, 1, 1))
            self.beta = nn.Parameter(torch.zeros(self.planes, 1, 1))

    def forward(self, x):
        out = self.wh(x)
        if self.affine:
            out = out * self.gamma + self.beta
        return out
```

Could someone help me?

If you pass in `None` for `running_mean` and `running_variance`, it will instantiate the right thing:
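For example, a minimal sketch (the `planes` and `group_size` values here are just placeholders):

```python
# Passing None for the running statistics lets the underlying
# WTransform2d create its own buffers:
bw = whitening_scale_shift(planes=64, group_size=8,
                           running_mean=None, running_variance=None)
out = bw(torch.randn(16, 64, 32, 32))  # input is (N, C, H, W) with C == planes
```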

I'm not quite sure whether the authors of the code want to achieve something specific or whether the interface of optionally passing in the tensors is merely accidental (the initialization seems very unidiomatic, too); in theory, the running statistics should just go through regular state dict loading and saving.
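For instance, a sketch of round-tripping the statistics (assuming the repo's `WTransform2d` registers them as buffers, which is what makes them show up in the state dict):

```python
# The running statistics are saved and restored with the module's state dict:
sd = bw.state_dict()
bw2 = whitening_scale_shift(planes=64, group_size=8,
                            running_mean=None, running_variance=None)
bw2.load_state_dict(sd)  # restores the running mean / variance buffers
```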

Best regards

Thomas

P.S.: Use triple backticks ``` before and after your code to have better formatting.

Thank you @tom, I did as you said and it works!