Fastest way to process images for Resnet

I have a dataset of (N, H, W, C) images with values in [0, 255]. Some are saved as np.uint8, some as torch.uint8.
I need to process them to pass them to Resnet, and I need to do it efficiently.
What I currently do is

import torch
import torch.nn as nn
import torchvision.transforms as T

transforms = nn.Sequential(
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC, antialias=None),  # interpolation=3 == bicubic
    T.CenterCrop(224),
    T.ConvertImageDtype(torch.float),  # also divides by 255 if input is uint8
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
)

# img is (N, H, W, C); the torchvision transforms expect (N, C, H, W)
img4resnet = transforms(img.permute(0, 3, 1, 2).contiguous())

This is slow, especially when I need to pass many images at once.
There is ToTensor, which is faster than permute + contiguous and accepts an np.array, but it only works on one image at a time (I get an error saying I passed a 4D input where a 3D one was expected). If I have many images, I am not sure whether looping over them would actually be faster…
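To make the comparison concrete, this is roughly what the looping variant would look like (just a sketch with made-up shapes; note that ToTensor also scales to [0, 1], so ConvertImageDtype would become redundant):

import numpy as np
import torch
import torchvision.transforms as T

to_tensor = T.ToTensor()  # expects a single (H, W, C) np.uint8 image

# made-up batch of N images as np.uint8
imgs = np.random.randint(0, 256, size=(8, 300, 300, 3), dtype=np.uint8)

# one ToTensor call per image, then stack back into a (N, C, H, W) float batch in [0, 1]
batch = torch.stack([to_tensor(im) for im in imgs])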

I have found other old issues asking something similar, but they are all years old.

Finally, I see that the latest version of torchvision can now provide the preprocessing transforms directly from the weights. Are they faster? Can I just replace my current transform with them?
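For reference, this is the API I mean (a sketch assuming torchvision >= 0.13, where the preprocessing comes bundled with the weights enum):

import torch
from torchvision.models import resnet18, ResNet18_Weights

weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights).eval()
preprocess = weights.transforms()  # resize, crop, dtype conversion and normalization in one call

# as far as I can tell it also accepts batched (N, C, H, W) uint8 tensors
x = torch.randint(0, 256, (4, 3, 300, 300), dtype=torch.uint8)
with torch.no_grad():
    out = model(preprocess(x))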

I’m not sure the recent torchvision changes actually make the transforms faster rather than just more convenient, but it could still be worth updating to see whether the implementations have become more efficient.

However, on recent generations of GPUs with tensor cores, a channels-last memory layout (NHWC) can actually yield faster execution in convolutional models such as ResNet, so I would also check whether leaving the images in channels-last format speeds things up, in addition to removing a preprocessing step.

Thanks for your reply, but torchvision's ResNet must get channels-first input because of conv2d. Is there a way to pass it channels-last data?

from torchvision import models
import torch

model = models.resnet18(pretrained=True, progress=False)
x = torch.rand(1, 32, 32, 3)  # NHWC-shaped input
model(x)

[...]
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 32, 32, 3] to have 3 channels, but got 32 channels instead

Does converting the model to torch.channels_last work for you?

>>> import torch
>>> import torchvision
>>> model = torchvision.models.resnet50().to(memory_format=torch.channels_last)
>>> input = torch.randn(4, 3, 224, 224).to(memory_format=torch.channels_last)
>>> model(input)
tensor([[ 1.0137,  0.3668, -0.4023,  ..., -0.4069,  0.9567, -0.3115],
        [ 0.9884,  0.3564, -0.6024,  ..., -0.5426,  0.8428, -0.3755],
        [ 1.0849,  0.2691, -0.3835,  ..., -0.3646,  0.9449, -0.1472],
        [ 0.8947,  0.1875, -0.3154,  ..., -0.4034,  0.7230, -0.4413]],
       grad_fn=<AddmmBackward0>)

Also note that channels-last inputs would still have shape [N, C, H, W]:

>>> input.shape
torch.Size([4, 3, 224, 224])
>>> input.is_contiguous()
False
>>> input.is_contiguous(memory_format=torch.channels_last)
True

You can verify this by checking, e.g., img.permute(0, 3, 1, 2).is_contiguous(memory_format=torch.channels_last) without calling .contiguous() (which would copy the data back into the default channels-first memory format).
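Putting that together with your original pipeline, this is roughly what I mean (just a sketch; the shapes are made up, and whether the intermediate transforms preserve the channels-last layout may depend on the torchvision version):

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

model = torchvision.models.resnet18().to(memory_format=torch.channels_last).eval()

transforms = nn.Sequential(
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC, antialias=None),
    T.CenterCrop(224),
    T.ConvertImageDtype(torch.float),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
)

# made-up NHWC uint8 batch (contiguous, as it would come from numpy)
img = torch.randint(0, 256, (4, 300, 300, 3), dtype=torch.uint8)

# permute gives an [N, C, H, W] view whose memory is already channels-last,
# so no .contiguous() copy is needed
x = img.permute(0, 3, 1, 2)
assert x.is_contiguous(memory_format=torch.channels_last)

with torch.no_grad():
    out = model(transforms(x))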

I didn’t know about memory_format=torch.channels_last, thanks!
I tried it but the speed seems exactly the same.

What GPU are you running on?

You can also look at torch.compile as a means of increasing performance:
https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html
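Roughly like this (a sketch; torch.compile needs PyTorch 2.x, and the first call is slow because it triggers compilation):

import torch
import torchvision

model = torchvision.models.resnet50().cuda().eval()
compiled_model = torch.compile(model)

x = torch.randn(4, 3, 224, 224, device="cuda")
with torch.no_grad():
    out = compiled_model(x)  # first call compiles, subsequent calls should be faster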

Nit: channels-last should be used together with AMP (mixed precision); I would not expect to see any major perf. improvement in pure FP32.
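For example, something along these lines (a sketch combining channels-last with autocast on a CUDA device):

import torch
import torchvision

model = torchvision.models.resnet50().cuda().to(memory_format=torch.channels_last).eval()
x = torch.randn(4, 3, 224, 224, device="cuda").to(memory_format=torch.channels_last)

# the fp16 convolutions run on tensor cores, which is where channels-last pays off
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)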