Dataset loader not loading entire dataset into one batch

Hi, I’m trying to make a custom data loader that can apply transforms on my .npy image files. I’ve set the batch size to length of the array to get all the samples in one batch. However when I’m printing shape of the torch tensor after iterating throught the train loader, I’m getting the output as torch.Size([3, 3, 64, 64]) . I actually want it to be ‘torch.Size([61578, 3, 64, 64])’

import numpy as np
import torch
from import Dataset, TensorDataset

import torchvision
import torchvision.transforms as transforms

from torchvision import datasets, transforms
from PIL import Image

import matplotlib.pyplot as plt

# Images
X_train = np.load('/content/galaxy_zoo/galaxy_zoo_train.npy')

X_train = torch.tensor(X_train)
X_train = torch.reshape(X_train, (len(X_train),3,64, 64))

class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms.
    def __init__(self, tensors, transform=None):
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self,index):
        x = self.tensors[index]

        x = self.transform(x)
        return x

    def __len__(self):
        return self.tensors[0].size(0)

# Let's add some transforms

so2_transform = transforms.Compose([
    transforms.RandomAffine(90, translate=(0., 0.)),
    #transforms.Normalize((0.1307,), (0.3081,))]  # mean and standard deviation, respectively, for normalization

se2_transform = transforms.Compose([
    transforms.RandomAffine(90, translate=(0.25, 0.25)),
    #transforms.Normalize((0.1307,), (0.3081,))]  # mean and standard deviation, respectively, for normalization
vanilla_transform = transforms.Compose([
    #transforms.Normalize((0.1307,), (0.3081,))])

# Dataset w/o any tranformations
train_dataset_normal = CustomTensorDataset(tensors=X_train, transform=se2_transform)
train_loader =, batch_size=len(X_train))

The dataset can be found here:

!tar -xf /content/galaxy_zoo.tar.gz

For your __len__ method, should it be:
return self.tensors.size(0) # 61578
instead of:
return self.tensors[0].size(0) # 3