I have created a simple example of mnist dataset for better understanding.
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
class MyDataset(torch.utils.data.Dataset):
def __init__(self,dataset = None, transform= None):
self.MNIST = dataset
print(self.MNIST)
self.transform = transform
def __getitem__(self, index):
data, target = self.MNIST[index]
if self.transform is not None:
#data_new = TF.to_pil_image(data)
#data_new = self.transform(data_new)
#data_new = TF.to_tensor(data_new)
data2 = TF.to_pil_image(data)
data2 = self.transform(data2)
data2 = TF.to_tensor(data2)
return data, data2, target, index
def __len__(self):
return len(self.MNIST)
here is the dataset code
train_dataset = datasets.MNIST(root=’./data’, train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]))
Applying simple horizontal flip on the dataset
transform_train = transforms.Compose([ transforms.RandomHorizontalFlip(p=1)])
now pass the the transform_train into MyDataset class.
trainset = MyDataset(dataset=train_dataset, transform= transform_train )
trainloader = torch.utils.data.DataLoader(trainset, batch_size=24, shuffle=True, num_workers=2)
data, tr_data, target, index = iter(trainloader).next()
Image plot
def imshow(inp, title=None):
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.5])
std = np.array([0.5])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
###Make a grid from batch
out = torchvision.utils.make_grid(data)
tr_out = torchvision.utils.make_grid(tr_data)
imshow(out, title=[target[x] for x in target])
imshow(tr_out, title=[target[x] for x in target])
Both plot’s colors are not matching. Am I doing something wrong? Please check it.