Image classification with PyTorch

Hi,

I’m trying to create my first image classificator for college project. I tried to follow this tutorial:

And got it working - model accuracy is ~85%. But when I try to classificate random images from the web (pretty similiar) I don’t get accurate predictions. Can You please help me to understand why all predictions with photos from the web are wrong? How to correct my code to achieve valid classifications?

Here’s my code:

import torch
import torchvision
import torch.nn as nn
import kagglehub
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import os
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torchvision.utils import make_grid
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
import matplotlib.pyplot as plt
# %matplotlib inline

path = kagglehub.dataset_download("moltean/fruits")

print("Path to dataset files:", path)

newPath = "/root/.cache/kagglehub/datasets/moltean/fruits/versions/11/fruits-360_dataset_100x100/fruits-360"

# newPath = "/root/.cache/kagglehub/datasets/moltean/fruits/versions/11/fruits-360_dataset_original-size/fruits-360-original-size"

print(newPath)

if os.path.exists(newPath):
    print(f"Zawartość folderu {newPath}:")
    for item in os.listdir(newPath):
        print(item)
else:
    print(f"Ścieżka {newPath} nie istnieje!")

data_dir = "/root/.cache/kagglehub/datasets/moltean/fruits/versions/11/fruits-360_dataset_100x100/fruits-360"
print('Folders :', os.listdir(data_dir))
classes = os.listdir(data_dir + "/Training")
print('131 classes :', classes)

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
])

dataset = ImageFolder(data_dir + '/Training', transform=transform)
print('Size of training dataset :', len(dataset))
test = ImageFolder(data_dir + '/Test', transform=transform)
print('Size of test dataset :', len(test))

img, label = dataset[101]
print(img.shape)

def show_image(img, label):
    print('Label: ', dataset.classes[label], "("+str(label)+")")
    plt.imshow(img.permute(1, 2, 0))

show_image(*dataset[101])

torch.manual_seed(20)
val_size = len(dataset)//10
train_size = len(dataset) - val_size

train_ds, val_ds = random_split(dataset, [train_size, val_size])
len(train_ds), len(val_ds) # train_ds length = dataset length - val_ds length

batch_size = 64
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=2, pin_memory=True)
test_loader = DataLoader(test, batch_size*2, num_workers=2, pin_memory=True)

for images, labels in train_loader:
    fig, ax = plt.subplots(figsize=(18,10))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images, nrow=16).permute(1, 2, 0))
    break

def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')

def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)

        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))

class CnnModel(ImageClassificationBase):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 141)
        )

    def forward(self, xb):
        return self.network(xb)

model = CnnModel()
# model.cuda()

for images, labels in train_loader:
    print('images.shape:', images.shape)
    out = model(images)
    print('out.shape:', out.shape)
    #print('out[0]:', out[0])
    break

device = get_default_device()
device

train_dl = DeviceDataLoader(train_loader, device)
val_dl = DeviceDataLoader(val_loader, device)
to_device(model, device)

@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):

        model.train()
        train_losses = []
        for batch in tqdm(train_loader):
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)
    return history

model = to_device(CnnModel(), device)

history=[evaluate(model, val_loader)]
history

num_epochs = 1
opt_func = torch.optim.Adam
lr = 0.001

history+= fit(num_epochs, lr, model, train_dl, val_dl, opt_func)

history+= fit(num_epochs, lr/10, model, train_dl, val_dl, opt_func)

def plot_accuracies(history):
    accuracies = [x['val_acc'] for x in history]
    plt.plot(accuracies, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. No. of epochs')
    plt.show()


def plot_losses(history):
    train_losses = [x.get('train_loss') for x in history]
    val_losses = [x['val_loss'] for x in history]
    plt.plot(train_losses, '-bx')
    plt.plot(val_losses, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs')
    plt.show()

plot_accuracies(history)

plot_losses(history)

evaluate(model, test_loader)

torch.save(model.state_dict(), 'cnn_model.pth')

torch.save(model, 'cnn_model_full.pth')

import random
import torch
import matplotlib.pyplot as plt

def show_random_images_with_predictions(model, test_loader, classes, device, num_images=5):

    model.eval()


    images, labels = next(iter(test_loader))


    images, labels = images.to(device), labels.to(device)


    with torch.no_grad():
        outputs = model(images)
        _, preds = torch.max(outputs, dim=1)


    fig, axes = plt.subplots(1, num_images, figsize=(15, 8))
    for i in range(num_images):
        ax = axes[i]
        ax.imshow(images[i].permute(1, 2, 0).cpu())
        ax.set_title(f'Pred: {classes[preds[i]]}\nTrue: {classes[labels[i]]}')
        ax.axis('off')
    plt.show()


show_random_images_with_predictions(model, test_loader, dataset.classes, device, num_images=5)

print('131 classes :', classes)

model = torch.load('cnn_model_full.pth')
model.eval()

import os
from PIL import Image
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor()
])

model = model.to(device)

def show_predictions_for_sample_images(model, folder_path, classes, device, transform=None):
    model.eval()

    images = os.listdir(folder_path)


    fig, axes = plt.subplots(3, 4, figsize=(12, 8))
    axes = axes.flatten()

    for i, image_name in enumerate(images):
        image_path = os.path.join(folder_path, image_name)


        img = Image.open(image_path).convert('RGB')
        if transform:
            img = transform(img)

        img = img.unsqueeze(0).to(device)


        with torch.no_grad():
            output = model(img)
            _, pred = torch.max(output, dim=1)


        ax = axes[i]
        ax.imshow(img.squeeze().permute(1, 2, 0).cpu())
        ax.set_title(f'Pred: {classes[pred]}')

    plt.tight_layout()
    plt.show()


folder_path = "/content/sampleFruits"


show_predictions_for_sample_images(model, folder_path, dataset.classes, device, transform)

And here are the predictions i get:

Thanks in advance for any suggestions !

Your model might have learned specifics from your training dataset (e.g. the image encoding). You could try to add more samples into your training set to diversify the images more.

in this dataset they are all single fruits brightly illuminatd with blank background. The model will be unhappy with real varied images with varied lighting contrast and backgrounds.
One answer is to add augmentation to the dataset,

[/quote]