Hi, I'm trying to make a custom data loader that can apply transforms to my .npy image files. I've set the batch size to the length of the array so that all the samples come in one batch. However, when I print the shape of the torch tensor after iterating through the train loader, I get `torch.Size([3, 3, 64, 64])`. I actually want it to be `torch.Size([61578, 3, 64, 64])`.
import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
# Images: load the training array and view it as (N, 3, 64, 64).
# NOTE(review): reshape only reinterprets the existing memory order. If the
# .npy file is stored channels-last (N, 64, 64, 3), pixels get scrambled and
# X_train.permute(0, 3, 1, 2) is what is needed instead; confirm the layout.
X_train = torch.tensor(np.load('/content/galaxy_zoo/galaxy_zoo_train.npy'))
X_train = X_train.reshape(len(X_train), 3, 64, 64)
class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms.

    Wraps a single tensor whose first dimension indexes samples and applies an
    optional transform to each sample when it is accessed.

    Args:
        tensors: tensor (or array) of samples, indexed along dimension 0,
            e.g. shape ``(N, 3, 64, 64)``.
        transform: optional callable applied to each sample. ``None`` returns
            the sample unchanged.
    """

    def __init__(self, tensors, transform=None):
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[index]
        # transform defaults to None, so only call it when one was supplied.
        if self.transform is not None:
            x = self.transform(x)
        return x

    def __len__(self):
        # The dataset length is the number of samples along dimension 0.
        # The previous ``self.tensors[0].size(0)`` measured the first *sample*
        # (its channel count, 3). The DataLoader therefore only ever drew 3
        # indices and produced batches of shape (3, 3, 64, 64).
        return len(self.tensors)
# Let's add some transforms.
# Each pipeline converts the sample to a PIL image, optionally applies a random
# affine warp, then converts back with ToTensor.
# NOTE(review): how ToPILImage scales values depends on the input dtype, which
# is not visible here; confirm X_train's dtype/value range.


def _pil_roundtrip(*middle):
    """Return Compose([ToPILImage(), *middle, ToTensor()])."""
    return transforms.Compose(
        [transforms.ToPILImage(), *middle, transforms.ToTensor()]
    )


# Random rotation in [-90, 90] degrees, translation fixed at zero.
so2_transform = _pil_roundtrip(
    transforms.RandomAffine(90, translate=(0., 0.))
)

# Random rotation in [-90, 90] degrees plus a random shift of up to 25% of the
# image width/height.
se2_transform = _pil_roundtrip(
    transforms.RandomAffine(90, translate=(0.25, 0.25))
)

# No augmentation: only the PIL round trip.
vanilla_transform = _pil_roundtrip()

# Normalization is deliberately disabled. The previously commented-out
# Normalize((0.1307,), (0.3081,)) used single-channel statistics (they look
# like the usual MNIST values). Compute per-channel statistics for this
# 3-channel data before enabling it.
# Training dataset augmented with se2_transform (random rotation + translation).
# NOTE(review): the original comment said "w/o any transformations" and the
# variable is named *_normal, yet se2_transform is applied. If no augmentation
# was intended, pass vanilla_transform instead.
train_dataset_normal = CustomTensorDataset(tensors=X_train, transform=se2_transform)
# Size the batch from the dataset itself so the single batch holds every sample.
# At the target shape (61578, 3, 64, 64) in float32 (ToTensor's output type),
# this one batch is about 3 GB, and every sample goes through PIL one by one.
train_loader = torch.utils.data.DataLoader(
    train_dataset_normal, batch_size=len(train_dataset_normal)
)
The dataset can be downloaded and extracted with:
!wget http://bergerlab-downloads.csail.mit.edu/spatial-vae/galaxy_zoo.tar.gz
!tar -xf /content/galaxy_zoo.tar.gz