No, you can’t. But I don’t think you need to, since the number of workers mainly matters when you fetch data from disk or perform data augmentation (or other pre-processing).
A TensorDataset is just a dataset built directly from tensors, without writing a custom torch.utils.data.Dataset subclass. I made an example that you might find useful. Note that this example doesn’t handle datasets whose length isn’t evenly divisible by the batch size: with a dataset of 14 and a batch size of 4, you get 4+4+4=12, and the 2 remaining items aren’t enough for a full batch. That is easily fixable with an if statement.
Also note that I hardcoded the number of labels to 1 in two places, and that more_itertools needs to be pip installed.
import torch
from torch.utils.data import TensorDataset
import random
import more_itertools


def load_data():
    # Fake data. You can also load your images and convert them into tensors.
    number_images = 100
    images = torch.randn(number_images, 3, 2, 2)
    labels = torch.ones(number_images, 1)
    return TensorDataset(images, labels)


def get_batch(dataset, batch_idx):
    '''Returns the data items for the given batch indexes.'''
    # Set up the data structures
    im_size = dataset[0][0].size()
    batch_size = len(batch_idx)
    batch_data = torch.empty((batch_size, *im_size))
    batch_labels = torch.empty((batch_size, 1))

    # Fill the data structures
    for i, data_idx in enumerate(batch_idx):
        data, label = dataset[data_idx]
        batch_data[i] = data
        batch_labels[i] = label
    return batch_data, batch_labels


dataset = load_data()
data_length = len(dataset)
batch_size = 10
n_epochs = 10

for epoch in range(n_epochs):
    # Create indexes, shuffle them and split them into batches
    indexes = list(range(data_length))
    random.shuffle(indexes)
    indexes = more_itertools.chunked(indexes, batch_size)
    for batch_idx in indexes:
        images, labels = get_batch(dataset, batch_idx)
        # You can now work with your data
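As a sketch of the "if statement" fix mentioned above: more_itertools.chunked yields the leftover items as a shorter final chunk, so if you only want full batches you can filter on chunk length (this assumes you are happy to skip the leftover items for that epoch; the shuffling means different items are skipped each time):

```python
import random
import more_itertools

data_length = 14  # the 14-item example from above
batch_size = 4

indexes = list(range(data_length))
random.shuffle(indexes)

# chunked() yields [4, 4, 4, 2]; the if statement drops the short final chunk
batches = [b for b in more_itertools.chunked(indexes, batch_size)
           if len(b) == batch_size]
# -> 3 full batches of 4; the 2 leftover indexes are skipped this epoch
```

Similarly, the hardcoded label width in get_batch could be inferred from the data with `dataset[0][1].size()`, the same way im_size is inferred for the images.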