Hi,
I was trying to use the new torchvision.io + torch.jit to speed up my data loading when I noticed some weird behavior: 1) transforms.Resize is slow (for upsampling) when combined with transforms.ConvertImageDtype, and 2) io.read_image + transforms.ConvertImageDtype does not return the same output as reading with PIL (pil_loader below) and applying transforms.ToTensor.
Below is the code to reproduce this (excuse the formatting, it was written in a notebook). It expects a 64x64 RGB image called “1.jpg” in the same directory as this file:
import torch
import torchvision.transforms as T
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import torchvision.io as io
def pil_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def torch_loader(path):
    mode = io.ImageReadMode.RGB
    return io.read_image(path, mode)

class myDataset(Dataset):
    def __init__(self, num_images=10_000, aug=None, loader=pil_loader):
        self.aug = aug
        self.num_images = num_images
        self.img_path = "1.jpg"
        self.loader = loader

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Every index loads the same file, so only the load + transform cost is measured.
        img = self.loader(self.img_path)
        if self.aug is not None:
            return self.aug(img)
        return img

img_size = 224
mean = (0.1, 0.1, 0.1)
std = (0.2, 0.2, 0.2)
aug1 = T.Compose([T.Resize((img_size, img_size)), T.ToTensor(), T.Normalize(mean, std)])
aug2 = T.Compose([T.ToTensor(), T.Resize((img_size, img_size)), T.Normalize(mean, std)])
aug3 = nn.Sequential(T.ConvertImageDtype(torch.float), T.Resize((img_size, img_size)), T.Normalize(mean, std))
aug4 = torch.jit.script(aug3)
d1 = myDataset(aug=aug1)
d2 = myDataset(aug=aug2)
d3 = myDataset(aug=aug3, loader=torch_loader)
d4 = myDataset(aug=aug4, loader=torch_loader)
# Test loop speed
for i in tqdm(range(10_000)):
    img = d1[i]
for i in tqdm(range(10_000)):
    img = d2[i]
for i in tqdm(range(10_000)):
    img = d3[i]
for i in tqdm(range(10_000)):
    img = d4[i]
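For anyone who would rather not read the rate off the tqdm bar, a small helper along these lines could measure it directly. This is only a sketch: `iterations_per_second` is a hypothetical name, and the lambda is a stand-in workload so the snippet runs on its own.

```python
import time

def iterations_per_second(get_item, num_items):
    """Call get_item(i) for i in range(num_items) and return items per second."""
    start = time.perf_counter()
    for i in range(num_items):
        get_item(i)
    elapsed = time.perf_counter() - start
    return num_items / elapsed

# Stand-in workload so the sketch is self-contained; with the datasets above
# one would pass e.g. d1.__getitem__ and 10_000 instead.
rate = iterations_per_second(lambda i: sum(range(100)), 1_000)
print(f"{rate:.0f} it/s")
```

The numbers reported in this post come from the tqdm loops above, not from this helper.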
The results are shown below:
| Method | Iterations/s |
|---|---|
| PIL + Resize + ToTensor + Normalize | 761 |
| PIL + ToTensor + Resize + Normalize | 1161 |
| IO + ConvertDtype + Resize + Normalize | 555 |
| IO + ConvertDtype + Resize + Normalize + JIT | 556 |
As you can see, PIL → ToTensor → Resize is the fastest. In other words, reading an image with PIL, converting it to a float tensor in the range [0, 1], and THEN resizing (upsampling) the tensor is faster than resizing the PIL image first.
Then how come the torchvision.io → ConvertDtype pipeline is so much slower? (I’ve tried resizing before the type conversion, but that is even slower than 555 it/s.) Maybe it is because torchvision.io + ConvertDtype is slower than PIL + ToTensor?
Spoiler: It’s not.
Add the following lines of code:
aug5 = T.ToTensor()
aug6 = T.ConvertImageDtype(torch.float32)
d5 = myDataset(aug=aug5)
d6 = myDataset(aug=aug6, loader=torch_loader)
for i in tqdm(range(10_000)):
    img1 = d5[i]
for i in tqdm(range(10_000)):
    img2 = d6[i]
The result we get now is:
| Method | Iterations/s |
|---|---|
| PIL + ToTensor | 3410 |
| IO + ConvertDtype | 5967 |
So something doesn’t add up. IO + ConvertDtype takes an image as input, returns a torch.float tensor, and is faster than PIL + ToTensor, which takes the same input and returns the same type of tensor. But once you add transforms.Resize, PIL + ToTensor suddenly becomes faster than IO + ConvertDtype. Does anyone know why? The issue seems to occur mostly when upsampling with Resize, but still, given the same input type (and the same source → target image size), the Resize step should add the same cost to both pipelines, shouldn’t it?
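One way to narrow this down might be to take file reading out of the picture and time only the upsampling step on in-memory tensors, while also printing each tensor's dtype, shape, strides, and contiguity. The sketch below does that on synthetic data rather than the real image. `describe` and `time_resize` are hypothetical helper names, and `F.interpolate` in bilinear mode is used as a stand-in for what T.Resize does on tensors, which is an assumption about the internals rather than something confirmed here. Whether the two loaders actually return tensors with different memory layouts is likewise an assumption to check, for example by calling `describe` on d5[0] and d6[0].

```python
import time
import torch
import torch.nn.functional as F

def describe(t):
    """Summarize tensor properties that could influence kernel speed."""
    return {
        "dtype": t.dtype,
        "shape": tuple(t.shape),
        "stride": t.stride(),
        "contiguous": t.is_contiguous(),
    }

def time_resize(t, size=(224, 224), repeats=50):
    """Return average seconds per bilinear upsample of a CHW float tensor."""
    batch = t.unsqueeze(0)  # interpolate expects an NCHW batch
    start = time.perf_counter()
    for _ in range(repeats):
        F.interpolate(batch, size=size, mode="bilinear", align_corners=False)
    return (time.perf_counter() - start) / repeats

# Synthetic stand-in for a decoded 64x64 RGB image (NOT the real "1.jpg").
hwc = torch.randint(0, 256, (64, 64, 3), dtype=torch.uint8)
chw_view = hwc.float().div(255).permute(2, 0, 1)  # CHW view over HWC memory
chw_packed = chw_view.contiguous()                # same values, packed CHW memory

for name, t in [("view", chw_view), ("packed", chw_packed)]:
    print(name, describe(t), f"{time_resize(t) * 1e6:.1f} us per resize")
```

If the two layouts time differently on a given machine, that would suggest comparing the strides of the real loader outputs next; if they do not, layout can probably be ruled out.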
My guess is that it may be related to the fact that PIL + ToTensor doesn’t return exactly the same values as IO + ConvertDtype:
print(img1.equal(img2))
# False
print(f"Same/All: {img1.eq(img2).sum()}/{img1.numel()}")
# Same/All: 10267/12288
So that’s another weird behavior that I don’t fully understand. Converting the two tensors back to PIL images gives what looks like the same image (as far as my eyes can tell), but it is strange that two image-reading methods return different tensors, and that this seems to affect the speed of the subsequent Resize operation.
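To put a number on "seemingly the same image", the mismatch could be measured on the 8-bit scale instead of only counting equal elements. The sketch below assumes both tensors are float images in [0, 1]; `diff_summary` is a hypothetical helper, and the two small tensors are synthetic stand-ins for img1 and img2.

```python
import torch

def diff_summary(a, b):
    """Compare two float images in [0, 1] after mapping back to 8-bit levels."""
    a8 = (a * 255).round().to(torch.int16)  # int16 so the subtraction cannot wrap
    b8 = (b * 255).round().to(torch.int16)
    delta = (a8 - b8).abs()
    return {
        "differing": int((delta > 0).sum()),
        "total": delta.numel(),
        "max_abs_diff_8bit": int(delta.max()),
    }

# Tiny synthetic example; with the tensors above one would call
# diff_summary(img1, img2) instead.
a = torch.tensor([0, 10, 200], dtype=torch.uint8).float() / 255
b = torch.tensor([0, 11, 198], dtype=torch.uint8).float() / 255
print(diff_summary(a, b))
```

A maximum difference of only a level or two would at least show that the mismatch is a rounding-level discrepancy rather than something larger.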
Edit: For anyone wondering, I’m using torch==1.10.0 and torchvision==0.11.1.