Mean shift function for gray image

The `MeanShift` module given below adds or subtracts a channel-wise mean (for 3-channel RGB input) from images. If the code needs to be rewritten for 1-channel (grayscale) images, is it still necessary to apply this module? If so, how should the code be changed, and how should the mean value be calculated for grayscale images?

class MeanShift(nn.Module):
    """Frozen 1x1 convolution that adds or subtracts a per-channel mean.

    The convolution weight is an identity matrix over the channels and the
    bias holds the signed channel means, so ``forward`` returns ``x - mean``
    when ``sub`` is true and ``x + mean`` otherwise. The layer's parameters
    are frozen and never receive gradients.

    Args:
        mean_rgb: Per-channel mean values. A 3-element sequence gives the
            usual RGB layer; a 1-element sequence (or a bare number) builds a
            single-channel layer for grayscale input. The number of entries
            determines the number of channels.
        sub: If true the mean is subtracted, otherwise it is added.

    Raises:
        ValueError: If ``mean_rgb`` contains no values.
    """

    def __init__(self, mean_rgb, sub):
        super(MeanShift, self).__init__()

        # A bare number is shorthand for a one-channel (grayscale) mean.
        if isinstance(mean_rgb, (int, float)):
            mean_rgb = (mean_rgb,)

        sign = -1.0 if sub else 1.0
        shifts = [float(value) * sign for value in mean_rgb]
        if not shifts:
            raise ValueError("mean_rgb must contain at least one channel mean")
        channels = len(shifts)

        # Conv2d positional order: in_channels, out_channels, kernel_size,
        # stride, padding. A 1x1 kernel with stride 1 and no padding keeps
        # the spatial size unchanged.
        self.shifter = nn.Conv2d(channels, channels, 1, 1, 0)

        # Identity weight passes every channel through untouched; the bias
        # then applies the signed mean to each channel.
        with torch.no_grad():
            self.shifter.weight.copy_(
                torch.eye(channels).view(channels, channels, 1, 1)
            )
            self.shifter.bias.copy_(torch.tensor(shifts))

        # Freeze the mean shift layer
        for params in self.shifter.parameters():
            params.requires_grad = False

    def forward(self, x):
        """Return ``x`` with the configured channel means added or subtracted."""
        return self.shifter(x)
class Net(nn.Module):
    """Network skeleton that owns a pair of mean-shift layers.

    ``sub_mean`` is built with ``sub=True`` and ``add_mean`` with
    ``sub=False``; both receive the same three channel means.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # One shared tuple keeps the two layers' means identical.
        # NOTE(review): presumably the RGB means of the training dataset —
        # confirm against the data pipeline.
        channel_means = (0.4488, 0.4371, 0.4040)
        self.sub_mean = MeanShift(channel_means, sub=True)
        self.add_mean = MeanShift(channel_means, sub=False)

If the dataset contains the gray images,