class GhostBatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, num_splits, **kw):
super().__init__(num_features, **kw)
running_mean = torch.zeros(num_features * num_splits)
running_var = torch.ones(num_features * num_splits)
self.weight.requires_grad = False
self.num_splits = num_splits
self.register_buffer("running_mean", running_mean)
self.register_buffer("running_var", running_var)
def train(self, mode=True, channels_last=False):
if (self.training is True) and (mode is False):
# lazily collate stats when we are going to use them
self.running_mean = torch.mean(
self.running_mean.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
self.running_var = torch.mean(
self.running_var.view(self.num_splits, self.num_features), dim=0
).repeat(self.num_splits)
return super().train(mode)
def forward(self, input):
n, c, h, w = input.shape
if self.training or not self.track_running_stats:
return F.batch_norm(
input.view(-1, c * self.num_splits, h, w),
self.running_mean,
self.running_var,
self.weight.repeat(self.num_splits),
self.bias.repeat(self.num_splits),
True,
self.momentum,
self.eps,
).view(n, c, h, w)
else:
return F.batch_norm(
input,
self.running_mean[: self.num_features],
self.running_var[: self.num_features],
self.weight,
self.bias,
False,
self.momentum,
self.eps,
)
I’m trying to implement a GhostBatchNorm class with channels_last enabled; however, I kept experiencing the following error:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
I have checked the shape of input and all parameters (mean, var, etc…) but all of them seem to be the same with or without channels_last enabled (using model.to(memory_format=torch.channels_last)
). Does anyone have a suggestion on what kind of problems I am facing? Thanks in advance.