Can I use torchvision Dataset and DataLoader with AWS S3?

So basically, I’d like to use an ImageFolder dataset with the path data_path = s3://bucketname/image_folder, together with a DataLoader. There should be examples of something like that lying around, but I just can’t find anything very useful.

Were you able to solve this?

Not possible with ImageFolder, but you can do basically anything you want in the __getitem__ method, including fetching images from S3. The problem is network speed, though: it’s really hard to make it fast enough, even with several DataLoader workers or preprocessing threads.
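To make the __getitem__ approach concrete, here is a minimal sketch of an ImageFolder-like dataset that lists the keys under an S3 prefix once and downloads each image on access with boto3. It assumes the bucket is laid out as `<prefix>/<class>/<image>` and that boto3 credentials are already configured. `list_keys`, `index_s3_keys`, `decode_rgb`, `S3ImageFolder` and its `client_factory`/`decode` parameters are hypothetical names for illustration, not part of torchvision. PyTorch treats any object with `__getitem__` and `__len__` as a map-style dataset, so the class does not need to subclass `torch.utils.data.Dataset`.

```python
import io
import os
from typing import Callable, Dict, List, Optional, Tuple


def list_keys(client, bucket: str, prefix: str) -> List[str]:
    """Return every object key under `prefix` (list_objects_v2 returns at most 1000 keys per page)."""
    keys: List[str] = []
    for page in client.get_paginator("list_objects_v2").paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj["Key"] for obj in page.get("Contents", []))
    return keys


def index_s3_keys(keys: List[str], prefix: str,
                  extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
                  ) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Build ImageFolder-style (key, label) pairs from keys laid out as <prefix>/<class>/.../<file>."""
    prefix = prefix.rstrip("/") + "/"
    by_class: Dict[str, List[str]] = {}
    for key in keys:
        if not key.startswith(prefix) or not key.lower().endswith(extensions):
            continue
        parts = key[len(prefix):].split("/")
        if len(parts) < 2:  # a file directly under the prefix has no class folder
            continue
        by_class.setdefault(parts[0], []).append(key)
    # Like ImageFolder, class indices follow the sorted class-folder names
    class_to_idx = {name: i for i, name in enumerate(sorted(by_class))}
    samples = [(key, class_to_idx[name]) for name in sorted(by_class) for key in sorted(by_class[name])]
    return samples, class_to_idx


def decode_rgb(data: bytes):
    """Default decoder: raw bytes -> RGB PIL image (Pillow is imported lazily)."""
    from PIL import Image
    return Image.open(io.BytesIO(data)).convert("RGB")


class S3ImageFolder:
    """Map-style dataset that downloads one S3 object per __getitem__ call."""

    def __init__(self, bucket: str, samples: List[Tuple[str, int]], transform: Optional[Callable] = None,
                 decode: Callable = decode_rgb, client_factory: Optional[Callable] = None):
        self.bucket = bucket
        self.samples = samples
        self.transform = transform
        self.decode = decode
        self.client_factory = client_factory
        self._client = None
        self._client_pid: Optional[int] = None

    def __getstate__(self):
        # Don't pickle the client into worker processes; each process builds its own
        state = self.__dict__.copy()
        state["_client"] = None
        return state

    def _get_client(self):
        # boto3 clients must not be shared across processes, so (re)create one per process
        if self._client is None or self._client_pid != os.getpid():
            if self.client_factory is not None:
                self._client = self.client_factory()
            else:
                import boto3
                self._client = boto3.client("s3")
            self._client_pid = os.getpid()
        return self._client

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        key, label = self.samples[idx]
        data = self._get_client().get_object(Bucket=self.bucket, Key=key)["Body"].read()
        img = self.decode(data)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
```

Usage would look like `samples, class_to_idx = index_s3_keys(list_keys(boto3.client("s3"), "bucketname", "image_folder/"), "image_folder")`, then `DataLoader(S3ImageFolder("bucketname", samples, transform=...), batch_size=32, num_workers=8)`. Every `__getitem__` is a separate GET request, which is why adding workers helps hide latency, but it still tends to lag behind a local disk.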

You can use Alluxio and FUSE to mount an S3 directory locally and access it as a local image folder at the path /mnt/fuse/bucket-name/path-in-s3.

alluxio-start.sh local NoMount

alluxio fs mount /bucket-name s3a://bucket-name --readonly

# run at every reboot
sudo umount /mnt/fuse || /bin/true
alluxio-fuse mount /mnt/fuse /
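Once the mount is up, the S3 prefix is an ordinary directory, so the stock datasets.ImageFolder works on it unchanged. Because the mount has to be recreated after every reboot, a quick check before building the dataset gives a clearer error than ImageFolder’s own. A sketch, assuming the mount path above and the usual `root/<class>/<image>` layout; `check_mounted_imagefolder` and `make_loader` are hypothetical helper names.

```python
import os
from typing import Callable, List, Optional


def check_mounted_imagefolder(root: str) -> List[str]:
    """Fail fast if the FUSE mount is missing or empty (e.g. after a reboot); return the class folders."""
    if not os.path.isdir(root):
        raise FileNotFoundError(f"{root} not found; is the Alluxio FUSE mount active?")
    with os.scandir(root) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise RuntimeError(f"no class folders under {root}; ImageFolder expects root/<class>/<image>")
    return classes


def make_loader(root: str, transform: Optional[Callable] = None,
                batch_size: int = 32, num_workers: int = 4):
    # Imported here so the layout check above runs even where torchvision isn't installed
    from torch.utils.data import DataLoader
    from torchvision import datasets

    check_mounted_imagefolder(root)
    # The mounted prefix is a regular directory tree, so ImageFolder needs no changes
    dataset = datasets.ImageFolder(root, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
```

For example, `make_loader("/mnt/fuse/bucket-name/path-in-s3", transform=...)`.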

Hi,

I wanted to ask whether there is a feasible (in terms of speed) way to use datasets.ImageFolder with a path to an S3 directory containing all the images. For my project, I want to use the PyTorch estimator from sagemaker.pytorch to create a training job that trains a custom CNN.

I saw one potential solution here, but I’m not sure whether it is performant.

I’d be grateful for any tips!


FYI, I’m now doing basically this on GCS, and before that I did it on Azure. When the data is in the same region as the training machine, I typically get about 50% of the speed I’d get with the data on a local (fast) SSD. I’m also caching the dataset as it’s downloaded, so only the first epoch is slower.

EDIT: here https://gist.github.com/harpone/1ce4c775ff63e22bc5228c4c77b48604
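The download-and-cache idea can be sketched as a small read-through disk cache that a dataset’s __getitem__ calls instead of going to object storage directly. This is not the code from the gist; `DiskCache` and its `fetch` callable are hypothetical, and the sketch assumes the local disk has room for the whole dataset.

```python
import hashlib
import os
import tempfile
from typing import Callable


class DiskCache:
    """Read-through cache: the first read of a key calls `fetch` (an S3/GCS download, say)
    and stores the bytes locally; later reads, i.e. epochs 2+, come from local disk."""

    def __init__(self, cache_dir: str, fetch: Callable[[str], bytes]):
        self.cache_dir = cache_dir
        self.fetch = fetch
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, key: str) -> str:
        # Hash the key so arbitrary object names map to flat, filesystem-safe file names
        return os.path.join(self.cache_dir, hashlib.sha1(key.encode()).hexdigest())

    def get(self, key: str) -> bytes:
        path = self._path(key)
        if os.path.exists(path):
            with open(path, "rb") as f:
                return f.read()
        data = self.fetch(key)
        # Write to a temp file, then rename atomically, so concurrent DataLoader
        # workers never read a half-written cache entry
        fd, tmp = tempfile.mkstemp(dir=self.cache_dir)
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp, path)
        return data
```

Inside __getitem__ this becomes something like `Image.open(io.BytesIO(cache.get(key)))`, with `fetch` wrapping e.g. boto3’s `get_object(...)["Body"].read()`. If DataLoader workers are started with spawn, make `fetch` a module-level function rather than a lambda so the dataset can be pickled.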


Hi, I recently read the AWS Machine Learning Blog post "Announcing the Amazon S3 plugin for PyTorch", which answers your question. However, if you are not comfortable with the plugin, you can treat your data as a stream and follow "Handling streaming data".
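For the streaming route, PyTorch’s mechanism is torch.utils.data.IterableDataset. A minimal sketch, assuming a `fetch(key) -> bytes` callable (for example one wrapping boto3’s `get_object`) and a list of (key, label) pairs; `S3StreamDataset` and `shard_for_worker` are hypothetical names. With multiple workers, each worker gets its own copy of an iterable dataset, so without sharding every worker would yield the full dataset and samples would be duplicated.

```python
import random
from typing import Callable, Iterator, List, Optional, Tuple, TypeVar

from torch.utils.data import IterableDataset, get_worker_info

T = TypeVar("T")


def shard_for_worker(items: List[T], worker_id: int, num_workers: int) -> List[T]:
    """Worker k of n takes items k, k+n, k+2n, ... so shards are disjoint and cover everything."""
    return items[worker_id::num_workers]


def _identity(x):
    return x


class S3StreamDataset(IterableDataset):
    """Yields (decoded sample, label) pairs while streaming objects from storage."""

    def __init__(self, samples: List[Tuple[str, int]], fetch: Callable[[str], bytes],
                 decode: Callable = _identity, transform: Optional[Callable] = None,
                 shuffle: bool = False):
        self.samples = samples
        self.fetch = fetch
        self.decode = decode
        self.transform = transform
        self.shuffle = shuffle

    def __iter__(self) -> Iterator:
        info = get_worker_info()
        if info is None:  # num_workers=0: the main process reads everything
            shard = list(self.samples)
        else:
            shard = shard_for_worker(self.samples, info.id, info.num_workers)
        if self.shuffle:
            # Shuffles a copy; DataLoader seeds Python's `random` separately in each worker
            random.shuffle(shard)
        for key, label in shard:
            item = self.decode(self.fetch(key))
            if self.transform is not None:
                item = self.transform(item)
            yield item, label
```

DataLoader rejects shuffle=True for iterable datasets, which is why the sketch shuffles within each worker’s shard instead; wrap it as `DataLoader(ds, batch_size=32, num_workers=4)`.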

This code may help you.

import io

import torch
from awsio.python.lib.io.s3.s3dataset import S3Dataset
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

# Maps each class name to its label index; change it to match the classes in your dataset
classes_name = {"akiec": 0, "bcc": 1, "bkl": 2, "df": 3, "mel": 4, "nv": 5, "vasc": 6}
class S3ImageSet(S3Dataset):
    def __init__(self, urls, transform=None):
        super().__init__(urls)
        self.transform = transform

    def __getitem__(self, idx):
        # S3Dataset returns the object's name and its raw bytes
        img_name, img = super().__getitem__(idx)
        # Convert bytes object to an RGB image
        img = Image.open(io.BytesIO(img)).convert('RGB')
        # Label = first class name that appears (as a substring) in the object name
        lbl = None
        for k in classes_name.keys():
            if k in img_name:
                lbl = classes_name[k]
                break
        if lbl is None:
            raise ValueError(f"no class name from classes_name found in {img_name}")
        # Apply preprocessing functions on data
        if self.transform is not None:
            img = self.transform(img)
        return img, lbl

batch_size = 4

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
urls = ['s3://ham1000/Cancer/train', "s3://ham1000/Cancer/val"]
key_name = ["train", "test"]  # the val split uses the 'test' transforms
dataloaders = {}
dataset_sizes = {}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
for url, key in zip(urls, key_name):
    dataset = S3ImageSet(url, transform=data_transforms[key])
    dataset_sizes[key] = len(dataset)
    dataloaders[key] = DataLoader(dataset,
                                  batch_size=batch_size,
                                  num_workers=4,
                                  shuffle=(key == "train"))  # only shuffle the training split