Custom GhostBatchNorm with channels_last

import torch
import torch.nn as nn
import torch.nn.functional as F


class GhostBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that normalizes each of `num_splits` ghost sub-batches separately during training."""

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits
        # One set of running statistics per ghost split; this replaces the
        # buffers of size num_features registered by nn.BatchNorm2d.
        self.register_buffer("running_mean", torch.zeros(num_features * num_splits))
        self.register_buffer("running_var", torch.ones(num_features * num_splits))
        # The affine scale is frozen; only the bias remains trainable.
        self.weight.requires_grad = False

    def train(self, mode=True, channels_last=False):  # channels_last is accepted but currently unused
        if (self.training is True) and (mode is False):
            # Lazily collate the per-split stats into a single estimate
            # when switching to eval, i.e. when we are going to use them.
            self.running_mean = torch.mean(
                self.running_mean.view(self.num_splits, self.num_features), dim=0
            ).repeat(self.num_splits)
            self.running_var = torch.mean(
                self.running_var.view(self.num_splits, self.num_features), dim=0
            ).repeat(self.num_splits)
        return super().train(mode)

    def forward(self, input):
        n, c, h, w = input.shape

        if self.training or not self.track_running_stats:
            # Fold the splits into the channel dimension so a single batch_norm
            # call normalizes every ghost sub-batch independently.
            # This view() is what raises the RuntimeError below when the input
            # is in channels_last memory format.
            return F.batch_norm(
                input.view(-1, c * self.num_splits, h, w),
                self.running_mean,
                self.running_var,
                self.weight.repeat(self.num_splits),
                self.bias.repeat(self.num_splits),
                True,  # training=True: use batch stats, update running estimates
                self.momentum,
                self.eps,
            ).view(n, c, h, w)
        else:
            # In eval mode every split holds the same collated statistics,
            # so the first num_features entries are enough.
            return F.batch_norm(
                input,
                self.running_mean[: self.num_features],
                self.running_var[: self.num_features],
                self.weight,
                self.bias,
                False,  # training=False: use the running estimates
                self.momentum,
                self.eps,
            )

I’m trying to implement a GhostBatchNorm class with channels_last enabled; however, I keep running into the following error:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

I have checked the shapes of the input and of all the parameters (mean, var, etc.), and they are identical with and without channels_last enabled (set via model.to(memory_format=torch.channels_last)). Does anyone have a suggestion as to what the problem might be? Thanks in advance.
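
For reference, the error reproduces in isolation with nothing but a memory-format change (a minimal sketch, detached from the model above; the shapes are made up):

import torch

x = torch.randn(8, 4, 3, 3)
y = x.view(-1, 8, 3, 3)  # fine on a default (NCHW-contiguous) tensor

x = x.to(memory_format=torch.channels_last)
y = x.view(-1, 8, 3, 3)  # RuntimeError: view size is not compatible with input tensor's size and stride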

I guess the view operation on the output of F.batch_norm might be raising the error. If so, you should call contiguous() on the output before reshaping it back to the channels-first shape.
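
Roughly like this in the training branch (a sketch of the idea, not tested):

out = F.batch_norm(
    input.view(-1, c * self.num_splits, h, w),
    self.running_mean, self.running_var,
    self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
    True, self.momentum, self.eps,
)
# contiguous() materializes a channels-first copy, after which view() is legal again
return out.contiguous().view(n, c, h, w)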


Thank you very much for your help. Your suggestion was correct; I had to call contiguous() on the input passed to F.batch_norm in order for the code to work.
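
Concretely, the first argument to F.batch_norm in the training branch became:

input.contiguous().view(-1, c * self.num_splits, h, w)  # copy back to NCHW-contiguous, then view()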

Do you know why using channels_last requires calling contiguous()? Thank you.