I have a bit of a conundrum with data movement and the assumptions inside DataLoader and Dataset that I am not sure how to solve.
My data lives in a storage cluster at X. My GPUs are in a cluster at Y. There is little to no storage at Y: for example, my training dataset is 100 TB, and I have 1 TB available at Y. I would have to incrementally transfer my data from X to Y (let's say via wget) to train my model.
As I understand the current DataLoader and Dataset implementation, I would have to know the total number of training examples to implement Dataset.__len__, would not be able to set shuffle=True, and would have to figure out how to tell the DataLoader to download a new file when it is done with the current one.
Most of what I found about transferring data has been about moving data from RAM to VRAM. What would be the best way to implement the Dataset and DataLoader so that I pass in a list of URIs (https://foo/bar0.pt, https://foo/bar1.pt, etc.) and it moves through the entire list one file at a time, shuffling only within a given file and downloading a new file when the current one is exhausted?
Couldn't you save your URLs to a CSV file, load and index those entries in a custom Dataset, and then call something like the following in __getitem__?
```python
import torch
from torch.utils.data import Dataset
import pandas as pd
from PIL import Image
import requests
from io import BytesIO


class URLImageDataset(Dataset):
    """
    Custom Dataset for loading images from URLs listed in a CSV file.

    CSV format assumptions:
    - Must have a column named 'url' (or adjust the column name below).
    - Optionally, a column named 'label' for classification tasks (int or str).
      If no 'label' column exists, the dataset returns only the image
      (e.g., for unsupervised tasks).
    """

    def __init__(self, csv_file: str, transform_images=None):
        """
        Args:
            csv_file (str): Path to the CSV file.
            transform_images: A torchvision transforms pipeline
                (e.g., transforms.Compose([...])), applied to the
                PIL Image after loading.
        """
        self.df = pd.read_csv(csv_file)
        self.transform_images = transform_images
        # Verify required column exists
        if 'url' not in self.df.columns:
            raise ValueError("CSV must contain a 'url' column with image URLs.")
        self.has_labels = 'label' in self.df.columns

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        row = self.df.iloc[idx]
        url = row['url']
        # Download and load the image
        try:
            response = requests.get(url, timeout=10)  # Timeout to avoid hanging
            response.raise_for_status()  # Raise if HTTP error
            img = Image.open(BytesIO(response.content)).convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Failed to load image from {url}: {e}")
        # Apply transforms if provided
        if self.transform_images:
            img = self.transform_images(img)
        # Return label if present, else just the image
        if self.has_labels:
            label = row['label']
            # If labels are strings, you might want to map them to ints here
            return img, label
        return img
```
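For the original requirement (move through a list of .pt shard URLs one file at a time, shuffle only within the current file, no known total length), an IterableDataset might fit better than a map-style Dataset, since it needs neither __len__ nor shuffle=True. Here is a minimal sketch, assuming each .pt file holds a sequence of examples (tensors or primitives) loadable with torch.load; the shard URLs and class name are made up for illustration:

```python
import io
import random

import requests
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedURLDataset(IterableDataset):
    """Stream shards from a list of URLs, one file at a time.

    Each shard is downloaded only when it is reached, its examples are
    shuffled in memory and yielded, and the shard is then discarded, so
    only one file's worth of data lives at the GPU cluster at a time.
    """

    def __init__(self, urls, shuffle_within_file=True, shuffle_files=True):
        self.urls = list(urls)
        self.shuffle_within_file = shuffle_within_file
        self.shuffle_files = shuffle_files

    def __iter__(self):
        urls = list(self.urls)
        if self.shuffle_files:
            random.shuffle(urls)  # new shard order each epoch
        # With num_workers > 0, give each worker a disjoint slice of shards
        worker = get_worker_info()
        if worker is not None:
            urls = urls[worker.id::worker.num_workers]
        for url in urls:
            resp = requests.get(url, timeout=60)
            resp.raise_for_status()
            # Assumes the shard is a list/tensor of examples; pass
            # weights_only=False if your shards hold arbitrary objects
            examples = torch.load(io.BytesIO(resp.content))
            indices = list(range(len(examples)))
            if self.shuffle_within_file:
                random.shuffle(indices)  # shuffle only within this file
            for i in indices:
                yield examples[i]
            del examples  # free this shard before fetching the next one


# Usage: shuffle stays False in the DataLoader, since shuffling is
# handled inside the dataset itself.
# loader = DataLoader(
#     ShardedURLDataset(["https://foo/bar0.pt", "https://foo/bar1.pt"]),
#     batch_size=32, shuffle=False)
```

This buffers each shard in RAM via BytesIO; if 1 TB of local disk is available at Y, you could instead stream the download to a temp file and torch.load from there, deleting it once the shard is exhausted.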