Guidance on Training a Model with PyTorch's ImageNet Dataset

Hello PyTorch community,

I’m seeking guidance on using PyTorch’s torchvision.datasets.ImageNet class to train a model. Specifically, I’d like to understand how to make effective use of the functionality this class provides for training. My goal is to train a CNN on the ImageNet dataset.

Your insights and guidance would be highly appreciated. Thank you for your assistance!

Best,
Hassan Bin Haroon

@ptrblck @albanD Please provide your insights. Thanks!

I don’t understand the question or which functionality you are asking about. The dataset accepts the following input arguments:

  • root (string) – Root directory of the ImageNet Dataset.
  • split (string, optional) – The dataset split, supports train or val.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • loader – A function to load an image given its path.

which are used to properly load and transform the samples.

@ptrblck Let me specify the functionality.

  1. How do I use torchvision.datasets.ImageNet to access the images and corresponding labels in a PyTorch training loop? I have the ILSVRC 2012 dataset downloaded, but I’d like a simple example that shows how to use torchvision.datasets.ImageNet correctly.

In other words, how can I train a model on ILSVRC 2012, which I have already downloaded, while relying only on PyTorch code for parsing and other preprocessing?

Since PyTorch already provides some of the ImageNet parsing code, I don’t intend to write a custom dataset class from scratch. Instead, I’d like to use what PyTorch provides.

To elaborate on the question further:
import torch
import torchvision
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True, num_workers=args.nThreads)

How can we get the images and labels from data_loader? Is it the same as with other datasets?

I hope this clarifies my request.

This tutorial shows an example of training a model using a dataset class with CIFAR10. It shouldn’t be too different for any other dataset.

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html


@J_Johnson Thanks for the CIFAR10 example. You’re right; it shouldn’t be too different. Since the ImageNet dataset isn’t freely downloadable (access requires registration), users need to download it manually before instantiating the dataset class. This differs from datasets like CIFAR10, which PyTorch can download (download=True) and preprocess entirely on its own.
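Because the archives have to be fetched manually, a quick pre-flight check can save a confusing error at construction time. The sketch below is hypothetical (`missing_imagenet_files` is not a torchvision function); the archive names and the `meta.bin`/split-folder behavior reflect torchvision’s `ImageNet` source at the time of writing and may differ between versions, so verify them against your install. On first use, torchvision parses the devkit and extracts the split archive into `root`, so later runs only need the extracted files.

```python
import os

# Archive names torchvision's ImageNet class looks for in `root` (assumed; check your version).
DEVKIT_ARCHIVE = "ILSVRC2012_devkit_t12.tar.gz"
SPLIT_ARCHIVES = {"train": "ILSVRC2012_img_train.tar", "val": "ILSVRC2012_img_val.tar"}
META_FILE = "meta.bin"  # written by torchvision after it parses the devkit


def missing_imagenet_files(root, split="train"):
    """Hypothetical helper: list files still needed in `root` before ImageNet(root, split) can load."""
    missing = []
    # Class metadata: either the parsed meta.bin or the devkit archive must be present.
    has_meta = os.path.isfile(os.path.join(root, META_FILE))
    if not has_meta and not os.path.isfile(os.path.join(root, DEVKIT_ARCHIVE)):
        missing.append(DEVKIT_ARCHIVE)
    # Images: either the extracted split folder or the split archive must be present.
    has_split_dir = os.path.isdir(os.path.join(root, split))
    if not has_split_dir and not os.path.isfile(os.path.join(root, SPLIT_ARCHIVES[split])):
        missing.append(SPLIT_ARCHIVES[split])
    return missing


if __name__ == "__main__":
    print(missing_imagenet_files("path/to/imagenet_root/", "val"))
```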

For additional details, please see this PyTorch example that uses ImageNet:


Useful resource. I’m looking at this implementation now. Thanks @J_Johnson!