How do we pack the data as input for enumerate?

I have CIFAR-10 input data with shapes x_train: (50000, 3072) and y_train: (50000,). I want to pack x_train and y_train into a trainloader with a batch size of 100, so that I can iterate over it like this:

for batch_idx, (inputs, targets) in enumerate(trainloader):


I have spent a lot of time trying to do this. Can someone help me with it?

You could write a basic custom Dataset and use it with a DataLoader.

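import torch
from torch.utils.data import Dataset, DataLoader

class DS(Dataset):
    def __init__(self, X=None, y=None, mode="train"):
        self.mode = mode
        self.X = X  # maybe do reshaping here, e.g. X.reshape(-1, 3, 32, 32)
        if mode == "train":
            self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = torch.tensor(self.X[idx], dtype=torch.float32)
        if self.mode == "train":
            # long class labels for classification (e.g. CrossEntropyLoss);
            # use dtype=torch.float32 instead depending on the use case
            y = torch.tensor(self.y[idx], dtype=torch.long)
            return x, y
        return x  # inference mode: inputs only

tr_data_setup = DS(x_train, y_train)  # targets batch to shape (100,)
trainloader = DataLoader(tr_data_setup, batch_size=100)  # ...plus any other DataLoader options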
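For plain in-memory arrays like these, PyTorch's built-in torch.utils.data.TensorDataset is a possible alternative to writing a custom class. The sketch below assumes random NumPy arrays as hypothetical stand-ins for the real data, with the question's column layout but only 1,000 samples instead of 50,000 to keep memory small, and shows the enumerate loop from the question working on the result:

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical stand-in data: 1,000 flattened 3x32x32 images and integer
# labels in [0, 10), instead of the real 50,000-sample CIFAR-10 arrays.
rng = np.random.default_rng(0)
x_train = rng.random((1000, 3072), dtype=np.float32)
y_train = rng.integers(0, 10, size=1000)

# Convert once to tensors: float32 inputs, int64 (long) class labels.
inputs_all = torch.from_numpy(x_train)
targets_all = torch.from_numpy(y_train).long()

train_ds = TensorDataset(inputs_all, targets_all)
trainloader = DataLoader(train_ds, batch_size=100, shuffle=True)

num_batches = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
    # Each batch: inputs has shape (100, 3072), targets has shape (100,).
    num_batches += 1
```

TensorDataset simply indexes the given tensors along their first dimension, so it avoids per-sample conversion; a custom Dataset like DS is still the more flexible option when per-sample logic (reshaping, augmentation) is needed.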

You could also extend this to apply augmentations to the images in __getitem__ if necessary.
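As one illustration of that idea, here is a minimal sketch of a random horizontal flip applied per sample. The class name AugDS and the flip_prob parameter are hypothetical, and it assumes each 3072-value row is stored channel-first (1024 red, then 1024 green, then 1024 blue values, 32x32 each), which is how the CIFAR-10 Python version is laid out but may not match every source:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class AugDS(Dataset):
    """Hypothetical variant of DS: reshapes each flattened row to
    (3, 32, 32) and randomly flips it horizontally in train mode."""

    def __init__(self, X, y=None, mode="train", flip_prob=0.5):
        self.X = X
        self.y = y
        self.mode = mode
        self.flip_prob = flip_prob

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Assumes channel-first layout of the flattened 3072 values.
        img = torch.tensor(self.X[idx], dtype=torch.float32).reshape(3, 32, 32)
        if self.mode == "train":
            # Flip along the width axis with probability flip_prob.
            if torch.rand(1).item() < self.flip_prob:
                img = torch.flip(img, dims=[2])
            return img, torch.tensor(self.y[idx], dtype=torch.long)
        return img  # no augmentation outside training

# Small random stand-in data, for illustration only.
rng = np.random.default_rng(0)
x_small = rng.random((200, 3072), dtype=np.float32)
y_small = rng.integers(0, 10, size=200)

loader = DataLoader(AugDS(x_small, y_small), batch_size=100, shuffle=True)
inputs, targets = next(iter(loader))  # inputs: (100, 3, 32, 32), targets: (100,)
```

Because the flip happens inside __getitem__, each epoch sees a freshly randomized version of every image; torchvision's transforms offer ready-made augmentations if you prefer not to write them by hand.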


Thank you, that solved it.