Hi
I’m new to PyTorch and want to try batch whitening (BW) as defined in https://github.com/roysubhankar/dwt-domain-adaptation, training a model from scratch for domain adaptation. I don’t know how to define my BW layer: how do I get the `running_mean` and `running_variance`?
```python
import torch
import torch.nn as nn
# `whitening` is the module shipped with the dwt-domain-adaptation repo;
# the exact import path depends on where it sits in your project.

class whitening_scale_shift(nn.Module):
    def __init__(self, planes, group_size, running_mean, running_variance, track_running_stats=True, affine=True):
        super(whitening_scale_shift, self).__init__()
        self.planes = planes
        self.group_size = group_size
        self.track_running_stats = track_running_stats
        self.affine = affine
        self.running_mean = running_mean
        self.running_variance = running_variance
        # Group whitening of the channels; the running statistics are handed in from outside
        self.wh = whitening.WTransform2d(self.planes, self.group_size,
                                         running_m=self.running_mean,
                                         running_var=self.running_variance,
                                         track_running_stats=self.track_running_stats)
        if self.affine:
            # Learnable per-channel scale and shift, as in BatchNorm
            self.gamma = nn.Parameter(torch.ones(self.planes, 1, 1))
            self.beta = nn.Parameter(torch.zeros(self.planes, 1, 1))

    def forward(self, x):
        out = self.wh(x)
        if self.affine:
            out = out * self.gamma + self.beta
        return out
```
If you pass in `None` for `running_m` and `running_var`, `WTransform2d` will instantiate the right thing:
I’m not quite sure whether the authors of the code want to achieve something specific or whether the interface for optionally passing in the tensors is only accidental (the initialization seems very unidiomatic, too). In theory, the running statistics should just go through regular state dict loading and saving.
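To illustrate the idiomatic route, here is a minimal sketch of a group-whitening layer that creates its own running statistics with `register_buffer` instead of receiving them in `__init__`. It is not the repo’s `WTransform2d`: the class name `GroupWhitening2d`, its buffer names and the Cholesky-based whitening are assumptions for illustration only. In training mode it computes the per-group batch mean and covariance, updates the running estimates with a BatchNorm-style momentum, and whitens with the Cholesky factor; in eval mode it uses the running estimates. Because buffers are part of `state_dict()`, they are saved and restored by ordinary `torch.save` / `load_state_dict` calls (it needs a PyTorch version that has `torch.linalg.solve_triangular`).

```python
import torch
import torch.nn as nn


class GroupWhitening2d(nn.Module):
    """Hypothetical stand-in for whitening.WTransform2d (sketch only)."""

    def __init__(self, num_features, group_size, momentum=0.1, eps=1e-5):
        super().__init__()
        assert num_features % group_size == 0
        self.num_groups = num_features // group_size
        self.group_size = group_size
        self.momentum = momentum
        self.eps = eps
        # Buffers: saved/loaded with state_dict(), moved by .to(), not trained.
        self.register_buffer("running_mean", torch.zeros(self.num_groups, group_size, 1))
        self.register_buffer(
            "running_variance", torch.eye(group_size).repeat(self.num_groups, 1, 1)
        )

    def forward(self, x):
        n, c, h, w = x.shape
        # (N, C, H, W) -> (groups, group_size, N*H*W)
        xg = x.permute(1, 0, 2, 3).reshape(self.num_groups, self.group_size, -1)
        if self.training:
            mean = xg.mean(dim=2, keepdim=True)
            xc = xg - mean
            cov = xc @ xc.transpose(1, 2) / xc.shape[2]
            with torch.no_grad():
                # Exponential moving average, same convention as BatchNorm
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_variance.mul_(1 - self.momentum).add_(self.momentum * cov)
        else:
            mean, cov = self.running_mean, self.running_variance
            xc = xg - mean
        eye = torch.eye(self.group_size, device=x.device, dtype=x.dtype)
        # Whitening: with cov + eps*I = L L^T, return L^{-1} (x - mean)
        chol = torch.linalg.cholesky(cov + self.eps * eye)
        out = torch.linalg.solve_triangular(chol, xc, upper=False)
        return out.reshape(c, n, h, w).permute(1, 0, 2, 3)


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = GroupWhitening2d(num_features=8, group_size=4)
    layer.train()
    layer(torch.randn(16, 8, 5, 5))  # one training step updates the buffers
    # The running statistics travel with the state dict
    print(list(layer.state_dict().keys()))
    restored = GroupWhitening2d(num_features=8, group_size=4)
    restored.load_state_dict(layer.state_dict())
```

With a layer like this, `whitening_scale_shift` would no longer need `running_mean` / `running_variance` constructor arguments: the optimizer updates only `gamma` and `beta`, while the buffers are filled from the data during training and restored with the rest of the model.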
Best regards
Thomas
P.S.: Use triple backticks ``` before and after your code for better formatting.