Create multiple DataLoaders with different transforms using a for loop

Hi guys. I’m trying to create multiple dataloaders using a for loop, and each of them uses a different transform. However, I notice that they always pick the transform of the last iteration.

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

train_loader = np.empty((2), dtype=object)
for i in range(2):
    train_loader[i]= torch.utils.data.DataLoader(
      torchvision.datasets.MNIST('./mnist/train/', train=True, download=True,
                                 transform=torchvision.transforms.Compose([
                                    torchvision.transforms.Lambda(lambda x:
                                        torchvision.transforms.functional.affine(
                                            img = x,
                                            angle = 90* i,
                                            translate = (0,0),
                                            scale = 1,
                                            shear = 0
                                        )
                                    ),
                                    torchvision.transforms.ToTensor(),
                                 ])
                                ),
      batch_size=1, shuffle=True)

Here I want train_loader[0] to apply no rotation and train_loader[1] to apply a 90-degree rotation.
But if I plot a sample from each of the two dataloaders

X, y = next(iter(train_loader[0]))
plt.imshow(X[0].squeeze(), vmin=0, vmax=1)
plt.show()
X, y = next(iter(train_loader[1]))
plt.imshow(X[0].squeeze(), vmin=0, vmax=1)
plt.show()

both images are rotated by 90 degrees. Am I doing something wrong?

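This issue is caused by how i is used inside the lambda: the lambda captures the variable i itself, not its value at the time the lambda is created. Python only looks up i when the transform is actually called, which happens after the loop has finished, so both transforms see i = 1. You need to bind the current value of i to each created function, for example with a default argument (lambda x, i=i: ...), which is evaluated once, when the lambda is defined.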
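The same pitfall can be reproduced in plain Python, without torchvision. A minimal sketch (the lists late and bound are illustrative names, not part of the original code):

```python
late, bound = [], []
for i in range(2):
    # Closes over the variable i; its value is looked up only when called.
    late.append(lambda: 90 * i)
    # The default argument is evaluated now, freezing the current value of i.
    bound.append(lambda i=i: 90 * i)

# By the time the functions run, the loop is over and i == 1.
print([f() for f in late])   # [90, 90]
print([f() for f in bound])  # [0, 90]
```

This mirrors the two dataloaders: the unbound version gives every transform the angle from the last iteration.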
This should solve the issue:

train_loader = np.empty((2), dtype=object)  # np.object was removed in NumPy 1.24
for i in range(2):
    train_loader[i] = torch.utils.data.DataLoader(
      torchvision.datasets.MNIST('PATH', train=True, download=False,
                                 transform=torchvision.transforms.Compose([
                                    torchvision.transforms.Lambda(lambda x, i=i:  # i=i binds the current loop value
                                        torchvision.transforms.functional.affine(
                                            img = x,
                                            angle = 90* i,
                                            translate = (0,0),
                                            scale = 1,
                                            shear = 0
                                        )
                                    ),
                                    torchvision.transforms.ToTensor(),
                                 ])
                                ),
      batch_size=1, shuffle=False)  # with shuffle=False both loaders return the same first sample
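Another way to freeze the value of i is functools.partial, which stores its arguments when it is created. A possible side benefit (worth verifying for your own setup): a partial wrapping a module-level function can be pickled, while a lambda cannot, which matters if the DataLoader uses num_workers > 0 with the "spawn" start method (the default on Windows and macOS). The sketch below uses a hypothetical stand-in, affine_angle, instead of torchvision.transforms.functional.affine so that it runs without torchvision:

```python
import functools
import pickle


def affine_angle(img, angle):
    # Hypothetical stand-in for torchvision.transforms.functional.affine:
    # it only reports which angle it would rotate `img` by.
    return (img, angle)


# functools.partial stores angle=90 * i at creation time, so each
# transform keeps its own angle even though i changes on every iteration.
transforms = []
for i in range(2):
    transforms.append(functools.partial(affine_angle, angle=90 * i))

print([t("img") for t in transforms])  # [('img', 0), ('img', 90)]

# A partial wrapping a module-level function survives a pickle round trip...
restored = pickle.loads(pickle.dumps(transforms[1]))
print(restored("img"))  # ('img', 90)

# ...whereas a lambda cannot be pickled by the standard pickle module.
try:
    pickle.dumps(lambda x, i=1: affine_angle(x, 90 * i))
    lambda_picklable = True
except (pickle.PicklingError, AttributeError):
    lambda_picklable = False
print(lambda_picklable)  # False
```

In the loop above, the equivalent would be torchvision.transforms.Lambda(functools.partial(rotate, angle=90 * i)), where rotate is a module-level function you define yourself (a hypothetical name) that calls torchvision.transforms.functional.affine(img, angle=angle, translate=(0, 0), scale=1, shear=0).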