Per Image Normalization

How can I perform normalization similar to `tf.image.per_image_standardization` in PyTorch?

Is there any way to achieve this?

Could you explain what this method is doing? The docs don’t seem to give much information:

Linearly scales each image in `image` to have mean 0 and variance 1.

For each 3-D image `x` in `image`, computes `(x - mean) / adjusted_stddev`, where
* `mean` is the average of all values in `x`
* `adjusted_stddev = max(stddev, 1.0/sqrt(N))` is capped away from 0 to protect against division by 0 when handling uniform images
  * `N` is the number of elements in `x`
  * `stddev` is the standard deviation of all values in `x`
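Written out for a single 3-D image `x` with N elements x_1, …, x_N, the description above amounts to the following. This assumes the population standard deviation (dividing by N), which is how TensorFlow's `tf.math.reduce_std` computes it:

$$
\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
\sigma = \sqrt{\frac{1}{N}\sum_{i=1}^{N} (x_i - \mu)^2}, \qquad
\hat{\sigma} = \max\!\left(\sigma, \frac{1}{\sqrt{N}}\right), \qquad
y_i = \frac{x_i - \mu}{\hat{\sigma}}
$$

For a uniform image every x_i equals μ, so σ = 0 and σ̂ = 1/√N; each y_i is then 0 rather than the NaN a plain division by σ would produce.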

I don’t understand what the “per image” part of this normalization refers to, if the mean is the “average of all values in `x`” and the stddev also seems to use all elements.

This is a replica we created in PyTorch to use as a lambda function in our transforms. We compared its results with the TensorFlow implementation, and they appear to match.
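One detail matters when comparing against TensorFlow: `torch.std` applies Bessel's correction by default (dividing by N - 1), while `tf.math.reduce_std` divides by N. A minimal sketch on a random 3 x 4 x 4 image (the shape is arbitrary) shows the two differ by a factor of sqrt(N / (N - 1)), which is small but visible in an element-wise comparison:

```python
import torch

torch.manual_seed(0)
image = torch.rand(3, 4, 4)  # one C x H x W image
n = image.numel()            # N = 48 elements

# PyTorch default: Bessel's correction, divides by N - 1
std_unbiased = torch.std(image, dim=(-3, -2, -1))
# Population standard deviation, divides by N (like tf.math.reduce_std)
std_population = torch.std(image, dim=(-3, -2, -1), unbiased=False)

# The ratio equals sqrt(N / (N - 1)), slightly above 1
ratio = std_unbiased / std_population
print(ratio.item(), (n / (n - 1)) ** 0.5)
```

Passing `unbiased=False` (or `correction=0` on recent PyTorch versions) gives the population value.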

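I believe this computes the statistics over the pixels of a single image: `x` refers to one 3-D image (C × H × W), so “per image” means each image in a batch is standardized with its own mean and standard deviation, rather than with statistics shared across the whole batch.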
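To make the “per image” part concrete, here is a small sketch using a hypothetical batch of two images with very different brightness. Reducing over the last three dims with `keepdim=True` yields one mean and one std per image, and each image ends up with mean 0 and std 1 on its own; a single batch-wide mean/std does not:

```python
import torch

torch.manual_seed(0)
# Hypothetical batch: a dark image in [0, 1] and a bright one in [10, 15]
batch = torch.stack([torch.rand(3, 8, 8), 10.0 + 5.0 * torch.rand(3, 8, 8)])

dims = (-3, -2, -1)  # reduce over C, H, W: one statistic per image
per_image_mean = batch.mean(dim=dims, keepdim=True)               # shape (2, 1, 1, 1)
per_image_std = batch.std(dim=dims, unbiased=False, keepdim=True)  # shape (2, 1, 1, 1)

num_elements = batch[0].numel()  # N = 3 * 8 * 8 = 192 per image
adjusted_std = per_image_std.clamp(min=num_elements ** -0.5)

standardized = (batch - per_image_mean) / adjusted_std
print(standardized.mean(dim=dims))                 # close to 0 for both images
print(standardized.std(dim=dims, unbiased=False))  # close to 1 for both images

# By contrast, one mean/std for the whole batch leaves the images far apart
batch_wide = (batch - batch.mean()) / batch.std(unbiased=False)
print(batch_wide.mean(dim=dims))
```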

```python
import math

import torch


def per_image_standardization(image):
    """
    This function creates a custom per image standardization
    transform which is used in the data transform pipeline.
    params:
        - image (torch Tensor): Float image tensor of shape (..., C, H, W)
          that needs to be standardized.

    returns:
        - image (torch Tensor): Image Tensor post standardization.
    """
    # get original data type
    orig_dtype = image.dtype

    # statistics are computed over the last three dims (C, H, W) of each image;
    # keepdim=True keeps them broadcastable when a batch (..., C, H, W) is passed
    dims = (-3, -2, -1)

    # compute image mean
    image_mean = torch.mean(image, dim=dims, keepdim=True)

    # compute image standard deviation (unbiased=False divides by N, like TensorFlow)
    stddev = torch.std(image, dim=dims, unbiased=False, keepdim=True)

    # compute number of elements per image (not per batch)
    num_pixels = image.shape[-3] * image.shape[-2] * image.shape[-1]

    # compute minimum standard deviation
    min_stddev = 1.0 / math.sqrt(num_pixels)

    # compute adjusted standard deviation
    adjusted_stddev = torch.clamp(stddev, min=min_stddev)

    # normalize image (out of place, so the caller's tensor is not modified)
    image = (image - image_mean) / adjusted_stddev

    # make sure that image output dtype == input dtype
    assert image.dtype == orig_dtype

    return image
```
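A quick way to exercise the two edge cases is a condensed, self-contained variant of the function above (named `standardize` here only for this sketch; it assumes float input of shape (..., C, H, W) and uses the population std). A uniform image should hit the `1/sqrt(N)` floor and come out as zeros rather than NaN, and a batch should give the same result as standardizing each image separately:

```python
import math

import torch


def standardize(image: torch.Tensor) -> torch.Tensor:
    """Condensed per-image standardization over the last three dims (C, H, W)."""
    dims = (-3, -2, -1)
    n = image.shape[-3] * image.shape[-2] * image.shape[-1]
    mean = image.mean(dim=dims, keepdim=True)
    std = image.std(dim=dims, unbiased=False, keepdim=True)
    return (image - mean) / std.clamp(min=1.0 / math.sqrt(n))


# Uniform image: std is 0, so the floor 1/sqrt(N) is used and no NaN appears
uniform = torch.full((3, 4, 4), 0.5)
out_uniform = standardize(uniform)
print(torch.isnan(out_uniform).any().item(), out_uniform.abs().max().item())

# Batch of images: each image is normalized with its own statistics
torch.manual_seed(0)
batch = torch.rand(4, 3, 16, 16)
out_batch = standardize(batch)
out_single = torch.stack([standardize(img) for img in batch])
print(torch.allclose(out_batch, out_single, atol=1e-5))
```

In a torchvision pipeline, such a function can be wrapped with `transforms.Lambda(...)` after `transforms.ToTensor()`, which produces the float C x H x W tensor it expects.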