How to sample minibatches without using a for loop?

Hello, everyone! How can I define a function that takes a dataset and an index list as input and returns a tensor?

For example, given a dataset defined as below:

```python
import torch
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
dataroot = "datasets/celebA/"
image_size = 64
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
```

I want to define a function like this:

```python
def minibatch(dataset, indices):
    """
    :param dataset: a map-style dataset, e.g. the ImageFolder dataset above
    :param indices: an index list
    :return: an N*C*H*W tensor that contains the feature vectors of the samples with these indices
    """
```

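You can order the indices however you want and feed them to a typical PyTorch DataLoader (for example through its sampler, or by wrapping the dataset in a subset). The loader will then sequentially load batches containing exactly the mix of samples you want.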
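A minimal sketch of that idea, using a small `TensorDataset` of random tensors as a hypothetical stand-in for the ImageFolder dataset (the labels are just `0..9`, so you can see which samples end up in each batch). It shows two ways to feed chosen indices to a DataLoader: passing them as the `sampler` (recent PyTorch versions accept any iterable with a length there) and wrapping the dataset in `torch.utils.data.Subset`.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for the ImageFolder dataset: 10 random 3x64x64
# "images" whose labels are simply their positions 0..9.
images = torch.randn(10, 3, 64, 64)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# The samples we want, in the order we want them (repeats are allowed).
wanted = [7, 2, 2, 5, 0, 9]

# Option 1: use the index list as the sampler; batches follow its order.
loader = DataLoader(dataset, batch_size=3, sampler=wanted)

# Option 2: restrict the dataset to those indices and read it sequentially.
subset_loader = DataLoader(Subset(dataset, wanted), batch_size=3, shuffle=False)

for x, y in loader:
    # each x has shape (3, 3, 64, 64); y holds the labels of the chosen samples
    print(x.shape, y.tolist())
```

Both loaders yield the same two batches here; the `Subset` version also works on older PyTorch releases that only accept `Sampler` objects as the `sampler` argument.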

I found out this is actually very easy :sweat_smile:


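```python
import numpy as np
import torch

# Assumes `dataset`, `batch_size`, `n_sample` (= len(dataset)), `nc` (number of
# channels) and `image_size` are already defined, e.g. as in the question.
def batch(batch_size=batch_size, lower=0, upper=n_sample):
    # Draw random indices in [lower, upper), with replacement
    indices = np.random.randint(low=lower, high=upper, size=batch_size)
    batch_data = torch.empty(batch_size, nc, image_size, image_size)
    for i in range(batch_size):
        sample, target = dataset[int(indices[i])]  # ImageFolder returns (image, label)
        batch_data[i] = sample
    return batch_data
```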
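The same thing can be written with the signature from the question using `torch.stack`, which avoids preallocating the output tensor. This is a sketch assuming each dataset item is an `(image, label)` pair, as with ImageFolder. Note that a map-style dataset still loads one item per `__getitem__` call, so some Python-level iteration (here a list comprehension) remains; only when the images are already held in a single tensor can the batch be gathered with one indexing call, as the hypothetical `minibatch_preloaded` helper shows.

```python
import torch
from torch.utils.data import TensorDataset


def minibatch(dataset, indices):
    """Return an N x C x H x W tensor holding the images at `indices`.

    Assumes dataset[i] returns an (image, label) pair, as ImageFolder does.
    """
    # One __getitem__ call per index, then join the C x H x W images
    # along a new leading batch dimension.
    return torch.stack([dataset[int(i)][0] for i in indices])


def minibatch_preloaded(images, indices):
    """Loop-free variant (hypothetical helper) for images already stored
    in one N x C x H x W tensor."""
    # Advanced indexing with a LongTensor gathers all rows at once.
    return images[torch.as_tensor(indices, dtype=torch.long)]


# Usage with a small random stand-in dataset
images = torch.randn(10, 3, 64, 64)
dataset = TensorDataset(images, torch.arange(10))
batch = minibatch(dataset, [4, 1, 8])
print(batch.shape)  # torch.Size([3, 3, 64, 64])
```

To reproduce the random draw from `batch()` above, pass `np.random.randint(low=lower, high=upper, size=batch_size)` as `indices`.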

Once you have loaded the dataset, create a DataLoader for it; since a DataLoader is iterable, you can then do something like:

```python
a = iter(dataloader)
b = next(a)
```

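where `b` is a minibatch (for the ImageFolder dataset above, it holds a batch of images together with the corresponding batch of labels).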
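A self-contained sketch of this pattern, assuming a hypothetical `TensorDataset` of 8 random 3x64x64 images in place of the ImageFolder dataset. With the default collation, each batch of `(image, label)` pairs comes back as `[images, labels]`, and repeated `next()` calls on the same iterator walk through one (shuffled) epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the ImageFolder dataset: 8 random images, labels 0..7.
dataset = TensorDataset(torch.randn(8, 3, 64, 64), torch.arange(8))
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

a = iter(dataloader)
b = next(a)          # first minibatch of this epoch
images, labels = b   # default collation stacks images and labels separately
print(images.shape)  # torch.Size([4, 3, 64, 64])
c = next(a)          # second (and last) minibatch of the same epoch
```

Calling `next(a)` a third time here raises `StopIteration`; calling `iter(dataloader)` again starts a new, reshuffled epoch.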

An example:

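```python
import torch

# VQA_dataset, ROOT_DIR and DUMMY_DATA are assumed to be defined elsewhere
# in the poster's project.
dataset = VQA_dataset(ROOT_DIR, train=False)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=3,
    shuffle=True,
)
data = next(iter(loader))  # one minibatch of 3 samples
v, bb, spat, obs, a, q, q_len, item = data  # unpack the collated fields

# Save the batch so it can be reused later as dummy input
torch.save(data, DUMMY_DATA)
```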
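The saved batch can later be reloaded without touching the dataset at all, which is handy for quick tests. A minimal round-trip sketch, assuming a small hypothetical `TensorDataset` in place of `VQA_dataset` and a temporary file in place of `DUMMY_DATA`:

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for VQA_dataset: 6 samples with 5 features each.
dataset = TensorDataset(torch.randn(6, 5), torch.arange(6))
loader = DataLoader(dataset, batch_size=3, shuffle=True)
data = next(iter(loader))

# A temporary path stands in for DUMMY_DATA.
path = os.path.join(tempfile.mkdtemp(), "dummy_batch.pt")
torch.save(data, path)

# Later, e.g. in a unit test: reload exactly the same batch.
restored = torch.load(path)
x, y = restored
```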
