Using GPU to classify an image with a pretrained model

After training a model with the timm library, I wrote my own script to classify a single picture. However, it took 7 seconds to classify one 224x224 image. I thought I was using the GPU for the prediction, but Ross mentioned that I was still using the CPU for the computation, which is why it took so long.

Isn't tensor.cuda() the solution?

My code can be found below:

Pastebin: Classify picture - Pastebin.com


Full code: 

import timm
import torch
import pandas as pd
import torch.nn.functional as F
from torch.autograd import Variable
from PIL import Image
from torchvision.transforms import transforms
import time
 
model = timm.create_model('efficientnet_v2s', num_classes=7)
checkpoint = torch.load("models/model_best.pth-d043d179.pth")
model.load_state_dict(checkpoint)
model.eval()
 
 
def predict_image(path):
    print("Prediction in progress")
    image = Image.open(path)
 
    transformation = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
 
    img_tensor = transformation(image).float()
    img_tensor = img_tensor.unsqueeze_(0)
 
    if torch.cuda.is_available():
        img_tensor.cuda()
 
    input = Variable(img_tensor)
    output = model(input)
    index = output.data.numpy().argmax()
    # prob = F.softmax(output, dim=1)
    return index
 
 
if __name__ == "__main__":
    imgpath = "data/2.jpg"
    start = time.time()
    index = predict_image(imgpath)
    df = pd.read_csv("file_index.csv")
    pred = df.iloc[index][1]
    print("Predicted Class ", pred)
    end = time.time()
 
    print("[INFO] Prediction took {:.5f} seconds".format(
        end - start))

You have to reassign the tensor, because cuda() and to() return a new tensor instead of modifying the original one in place:

    if torch.cuda.is_available():
        img_tensor = img_tensor.cuda()
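The non-in-place behavior can be checked on any machine, even without a GPU. The sketch below uses a dtype conversion, which follows the same rule as cuda()/to('cuda'): the call returns a tensor and leaves the original variable untouched unless you reassign it. The device variable is a common device-agnostic pattern and an assumption here, not part of the original script.

```python
import torch

# Pick the GPU when present, otherwise fall back to the CPU (assumed pattern)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

img_tensor = torch.zeros(1, 3, 224, 224)  # default dtype is float32

# Result discarded, like img_tensor.cuda() in the question: img_tensor is unchanged
img_tensor.to(torch.float64)
unchanged_dtype = img_tensor.dtype

# Result kept: the returned tensor carries the conversion
converted = img_tensor.to(torch.float64)
new_dtype = converted.dtype

# Same rule for devices: reassign to keep the moved tensor
img_tensor = img_tensor.to(device)
print(unchanged_dtype, new_dtype, img_tensor.device)
```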

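You also need to call cuda() or to('cuda') on the model. Unlike tensors, modules are moved in place, so model.cuda() does not need to be reassigned. Once the output is on the GPU, output.data.numpy() will raise an error, so move it back to the CPU first, e.g. with output.detach().cpu().numpy().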
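A small sketch of the module side, assuming a hypothetical nn.Linear stand-in for the timm model (7 outputs to mirror num_classes=7). It shows that nn.Module.to() returns the same module object and changes its parameters in place, and how to check which device the weights actually live on.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for timm.create_model('efficientnet_v2s', num_classes=7)
model = nn.Linear(10, 7)

# nn.Module.to() moves parameters in place and returns the module itself
returned = model.to(device)
same_object = returned is model

# In-place dtype change without reassignment still affects the parameters
model.to(torch.float64)
param_dtype = next(model.parameters()).dtype

# Inspect where the weights are stored
param_device = next(model.parameters()).device
print(same_object, param_dtype, param_device)
```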

Note that Variable has been deprecated since PyTorch 0.4; in newer versions you can pass tensors to the model directly.
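Putting these points together, here is a minimal sketch of the inference step without Variable. The small nn.Sequential model and the predict_tensor function are hypothetical stand-ins so the example runs on its own; in the real script the model would come from timm.create_model plus load_state_dict, and the input from the torchvision transforms. Wrapping the forward pass in torch.no_grad() also skips building the autograd graph, which prediction does not need.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the trained 7-class timm model
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 7),
)
model.to(device)  # modules are moved in place
model.eval()


def predict_tensor(img_tensor: torch.Tensor) -> int:
    """Return the argmax class index for one preprocessed CHW image tensor."""
    # Add the batch dimension and move the input to the model's device (reassigned)
    batch = img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():  # no gradients are needed for inference
        output = model(batch)  # plain tensor in, plain tensor out
    # .item() copies the scalar back to the host, so no .cpu().numpy() is needed
    return output.argmax(dim=1).item()


index = predict_tensor(torch.rand(3, 224, 224))
print("Predicted index", index)
```

The returned index can then be used with the CSV lookup from the original script.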