nn.functional.interpolate: unexpected behaviour

Hi folks, a while ago I built myself a translation of the torchvision.transforms library that operates on tensors.
To replace the PIL (or accimage) resize() I use torch.nn.functional.interpolate(). At first sight it works in linear mode, but the results differ slightly where neighbouring pixels have high contrast: the mean is the same, but the std is different.
In cubic mode it is a bit stranger: min and max go beyond the original value range, e.g. instead of the 0-1 range I get something like -0.1 to 1.15.
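For context on the cubic overshoot, here is a minimal sketch of a 1-D cubic convolution kernel. The kernel has negative lobes, so interpolating right next to a sharp edge can land outside the input range. The parameter values are assumptions on my part: as far as I know PyTorch uses a = -0.75 and Pillow uses a = -0.5, and Pillow writes to an 8-bit image that saturates at 0 and 255, which would explain why its result stays in range. The function names are illustrative, not library APIs.

```python
def cubic_weight(x, a=-0.75):
    """Cubic convolution kernel; `a` is the free parameter (assumed value)."""
    x = abs(x)
    if x <= 1.0:
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0


def cubic_interp(p, t, a=-0.75):
    """Interpolate at fraction t in [0, 1) between p[1] and p[2],
    using the four neighbouring samples p[0]..p[3]."""
    distances = (1.0 + t, t, 1.0 - t, 2.0 - t)
    return sum(cubic_weight(d, a) * v for d, v in zip(distances, p))


# Samples just to the bright side of a dark-to-bright edge, all within [0, 1].
edge = [0.0, 1.0, 1.0, 1.0]

overshoot_a075 = cubic_interp(edge, 0.5, a=-0.75)  # 1.09375
overshoot_a050 = cubic_interp(edge, 0.5, a=-0.5)   # 1.0625

# Clamping brings the float result back into the original range.
clamped = min(max(overshoot_a075, 0.0), 1.0)

print(overshoot_a075, overshoot_a050, clamped)
```

Both parameter choices overshoot on this edge, so the out-of-range values are a property of cubic interpolation on float data rather than a bug.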

Is this particularly important? No, a model can still be trained on the data, but the benchmarking results are different.
Why I did this: tensor augmentation is valuable, as demonstrated recently in Nvidia’s StyleGAN2-ADA.

I am curious where the differences come from. I dug into the code, and while PIL is easy to read, the torch implementation requires a bit more time devoted to it.

I get the same results on CPU and CUDA, both equally different from PIL and accimage.
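One plausible source of the "same mean, different std" pattern is antialiasing. To my understanding, Pillow widens its filter support by the downscale factor, while plain linear interpolation only ever blends the two nearest source pixels. The sketch below is a simplified 1-D model of both approaches, not the actual code of either library; the function names are made up for illustration.

```python
import math
import statistics


def resize_plain_linear(signal, out_len):
    """Linear interpolation with half-pixel centres, two taps per output."""
    n = len(signal)
    scale = n / out_len
    out = []
    for i in range(out_len):
        x = min(max((i + 0.5) * scale - 0.5, 0.0), n - 1)
        x0 = int(math.floor(x))
        x1 = min(x0 + 1, n - 1)
        t = x - x0
        out.append((1.0 - t) * signal[x0] + t * signal[x1])
    return out


def resize_antialiased_linear(signal, out_len):
    """Triangle filter whose half-width grows with the downscale factor."""
    n = len(signal)
    scale = n / out_len
    support = max(scale, 1.0)
    out = []
    for i in range(out_len):
        center = (i + 0.5) * scale
        lo = max(int(center - support + 0.5), 0)
        hi = min(int(center + support + 0.5), n)
        taps = range(lo, hi)
        weights = [max(0.0, 1.0 - abs((j + 0.5 - center) / support)) for j in taps]
        total = sum(weights)
        out.append(sum(w * signal[j] for w, j in zip(weights, taps)) / total)
    return out


# One bright pixel on a dark background, downscaled by a factor of 2.
signal = [0.0] * 7 + [1.0] + [0.0] * 8
plain = resize_plain_linear(signal, 8)
smooth = resize_antialiased_linear(signal, 8)

print("plain : mean %.4f std %.4f" % (statistics.mean(plain), statistics.pstdev(plain)))
print("smooth: mean %.4f std %.4f" % (statistics.mean(smooth), statistics.pstdev(smooth)))
```

In this toy case both outputs have the same mean, but the antialiased one spreads the bright pixel over two outputs and so has a lower std. Newer PyTorch releases expose an antialias flag on interpolate(), as far as I know, which may be worth trying when matching PIL.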

Example. The differences appear minor, but they change benchmarking results:

![difpytorch_pil|690x268](upload://lFVUHsLWLC5JzwxOvPmZznqUsum.jpeg)

import io
import requests
import numpy as np
import PIL.Image
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import interpolate

def test_dif(dtype="float64", mode="bilinear", show=False):

    sig = lambda x, msg="": print("%s\tmin: %.4f, max: %.4f mean: %.4f, std: %.4f \tshape: %s, dtype: %s"%(msg, x.min(), x.max(), x.mean(), x.std(), str(tuple(x.shape)), str(x.dtype)))
    url = ("https://ichef.bbci.co.uk/news/976/cpsprodpb/10207/production/_116155066_campercats.png")
    pimg = PIL.Image.open(io.BytesIO(requests.get(url).content))
    tensor = torch.from_numpy((np.array(pimg)/255).astype(dtype)).permute(2, 0, 1).contiguous()
    tensor = tensor.view(1, *tensor.shape)
    print(dtype, mode, "Resize test vs interpolate, ",  pimg.size)

    new_size = [512, (512*pimg.size[0])//pimg.size[1]]
    resample = {"bilinear":PIL.Image.BILINEAR, "bicubic":PIL.Image.BICUBIC}

    pimg_sz = (np.array(pimg.resize(size=new_size[::-1], resample=resample[mode]))/255).astype(dtype)
    ptensor = interpolate(tensor, size=new_size, mode=mode, align_corners=False)
    
    sig(ptensor, msg="  interpolate() ")
    sig(pimg_sz, msg="  Image.resize()")
    if show:
        ntensor = ptensor[0].numpy().transpose(1, 2, 0)
        diff = ntensor - pimg_sz
        diff = (diff - diff.min())/(diff.max() - diff.min())
        plt.figure(figsize=(18, 7))
        # plt.subplot(131)
        # plt.imshow(pimg_sz)
        # plt.subplot(132)
        # plt.imshow(ntensor)
        # plt.subplot(133)
        plt.imshow(diff)
        plt.tight_layout()
        plt.show()

if __name__ == "__main__":
    test_dif("float64", "bicubic")
    test_dif("float32", "bicubic")
    test_dif("float64", "bilinear")
    test_dif("float32", "bilinear", show=True)

Result

float64 bicubic Resize test vs interpolate,  (976, 549)
  interpolate()         min: -0.0877, max: 1.0536 mean: 0.3292, std: 0.2270     shape: (1, 3, 512, 910), dtype: torch.float64
  Image.resize()        min: 0.0000, max: 1.0000 mean: 0.3292, std: 0.2266      shape: (512, 910, 3), dtype: float64
float32 bicubic Resize test vs interpolate,  (976, 549)
  interpolate()         min: -0.0877, max: 1.0536 mean: 0.3292, std: 0.2270     shape: (1, 3, 512, 910), dtype: torch.float32
  Image.resize()        min: 0.0000, max: 1.0000 mean: 0.3292, std: 0.2266      shape: (512, 910, 3), dtype: float32
float64 bilinear Resize test vs interpolate,  (976, 549)
  interpolate()         min: 0.0000, max: 1.0000 mean: 0.3292, std: 0.2261      shape: (1, 3, 512, 910), dtype: torch.float64
  Image.resize()        min: 0.0000, max: 1.0000 mean: 0.3292, std: 0.2260      shape: (512, 910, 3), dtype: float64
float32 bilinear Resize test vs interpolate,  (976, 549)
  interpolate()         min: 0.0000, max: 1.0000 mean: 0.3292, std: 0.2261      shape: (1, 3, 512, 910), dtype: torch.float32
  Image.resize()        min: 0.0000, max: 1.0000 mean: 0.3292, std: 0.2260      shape: (512, 910, 3), dtype: float32

Ignore my comment on bicubic interpolate(). I noticed that I just have to clamp the output.
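A minimal sketch of that clamp, assuming the input lives in the 0-1 range. The synthetic step-edge tensor is only there to provoke the overshoot; it is not the image from the script above.

```python
import torch
from torch.nn.functional import interpolate

# Synthetic 8x8 image with a vertical step edge, values in [0, 1].
x = torch.zeros(1, 1, 8, 8)
x[..., 4:] = 1.0

# Bicubic upsampling overshoots on both sides of the edge.
y = interpolate(x, size=(16, 16), mode="bicubic", align_corners=False)

# Clamp back to the known input range.
y_clamped = y.clamp(0.0, 1.0)

print("raw    : min %.4f max %.4f" % (y.min().item(), y.max().item()))
print("clamped: min %.4f max %.4f" % (y_clamped.min().item(), y_clamped.max().item()))
```

Clamping only removes the out-of-range values; it does not make the result identical to PIL.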

That sounds cool, but note that torchvision is now adding tensor support to its transformations step by step, so you might consider using it directly or contributing to it instead.

E.g. resize seems to work on tensors already:

import torch
import torchvision.transforms.functional as TF

out = TF.resize(torch.randn(3, 24, 24), (12, 12))
print(out.shape)
> torch.Size([3, 12, 12])

Thank you ptrblck! I'm not actively developing my transformations. I just happen to use them, and that is when I noticed slightly odd benchmarking numbers.
But I'll look at the stack and see if there's anything I can contribute to the project. Appreciated!

@ptrblck, for reference, just to see if anything changed, I installed the current torchvision and modified my script to

from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
ttensor = resize(img=tensor, size=new_size, interpolation=InterpolationMode.BILINEAR)
# tensor carries a batch dimension, so drop it before converting to HWC
diff = ttensor[0].numpy().transpose(1, 2, 0) - pimg_sz
#...

It does seem that resize uses interpolate() just like I did, or the same resampling algorithm, so the cat still has the shadow. It isn't important, as bilinear mode isn't lossless anyway.

Nice to see functional.affine()!

Just to close the discussion with a bit of info.
I tried torchvision.transforms.functional.resize(), torch.nn.functional.interpolate(), PIL.Image.resize(), cv2.resize(), and tensorflow.image.resize().

torchvision's resize() and torch's interpolate() return an image diff of 0.
torch 1.5 and tensorflow 2.2 have a slight difference (image attached).
torch 1.8 and tensorflow 2.2 are identical.

All the others are different from each other! Curiously, since models tend to learn surface features, they will all return slightly different benchmarks, which is what led me to look at this.

gist: cv2, PIL.Image, torch and tensorflow have slightly different interpolation algorithms; curious (GitHub)