This is my custom transform, which I want to apply to the whole dataset:
import torch

class normalization(object):
    def __call__(self, sample):
        image, label = sample['image'], sample['labels']
        fmin = torch.min(image)
        fm = image - fmin
        image = 255 * fm / torch.max(fm)
        return {'Normalized image': image, 'Labels': label}
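transforms.Compose calls each transform on whatever the previous transform returned, so after ToTensor and Grayscale the custom transform receives a single image tensor, not a dict with 'image' and 'labels' keys. ImageFolder applies transform to the image alone and handles the label separately. Below is a minimal sketch of a tensor-in, tensor-out version, assuming the goal is only to rescale pixel values to [0, 255]; the class name MinMaxScale and the divide-by-zero guard are illustrative choices, not part of the original code.

```python
import torch


class MinMaxScale:
    """Hypothetical tensor-in/tensor-out replacement for `normalization`.

    Shifts the tensor so its minimum is 0, then scales it so its maximum is
    `max_value`. A constant image (max == min) comes out as all zeros.
    """

    def __init__(self, max_value: float = 255.0):
        self.max_value = max_value

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        shifted = image - image.min()
        # clamp avoids dividing by zero when every pixel has the same value
        return self.max_value * shifted / shifted.max().clamp(min=1e-12)


x = torch.tensor([[-1.0, 0.0], [1.0, 3.0]])
print(MinMaxScale()(x))
```

Because it returns a plain tensor, an instance of this class can sit at the end of transforms.Compose in place of normalization(), and ImageFolder will keep returning (tensor, label) pairs.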
I add the custom transform to transforms.Compose with the following code:
from torchvision import transforms

batch_size = 100
size = (299, 299)
data_transforms = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
    transforms.Grayscale(num_output_channels=1),
    normalization(),
])
I load the data with:
import torch
import torchvision
from torch.utils.data import DataLoader

data_set = torchvision.datasets.ImageFolder(root=data_dir, transform=data_transforms)
train_set_size = int(len(data_set) * 0.6)
test_set_size = len(data_set) - train_set_size
train_set, test_set = torch.utils.data.random_split(
    data_set, [train_set_size, test_set_size],
    generator=torch.Generator().manual_seed(72))
data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=False)
train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
The following code block raises the error:
import matplotlib.pyplot as plt

for i in range(len(data_set)):
    image, _ = data_set[i]
    print('types:', type(image))
    print(i, image.size())
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    if i == 3:
        plt.show()
        break
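The loop above sets subplot titles but never draws the images, and image.size() is a tensor method, so it fails with an AttributeError when image is a dict. A hedged sketch of a plotting helper follows; show_samples is a hypothetical name, and it assumes each sample is (tensor of shape (1, H, W), label), as a pipeline ending in a single-channel tensor would produce. It returns the figure so the caller decides when to call plt.show().

```python
import matplotlib.pyplot as plt
import torch


def show_samples(dataset, n=4):
    """Hypothetical helper: draw the first n (image, label) pairs side by side.

    Assumes each image is a tensor of shape (1, H, W), i.e. a single-channel
    output such as the Grayscale pipeline in the question produces.
    """
    fig, axes = plt.subplots(1, n, figsize=(3 * n, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        image, label = dataset[i]
        ax.imshow(image.squeeze(0).numpy(), cmap="gray")  # (1, H, W) -> (H, W)
        ax.set_title("Sample #{} (label {})".format(i, label))
        ax.axis("off")
    fig.tight_layout()
    return fig


# Usage with a stand-in dataset: a list of (tensor, label) pairs.
fake_dataset = [(torch.rand(1, 8, 8), k % 2) for k in range(4)]
fig = show_samples(fake_dataset)
# plt.show()  # uncomment to open the window in an interactive session
```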
Error:
I don’t understand what is causing the issue in my custom transform function. Also, ‘image’ should be a tensor, but currently its type is ‘dict’.
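One way to check both symptoms without the image folder is to call the original transform directly. In the sketch below, which needs only torch, a tensor input (what Compose supplies after ToTensor and Grayscale) makes the string lookup sample['image'] raise, and a dict input makes the transform return a dict, which is then what ImageFolder would hand back in the image slot of each (image, label) pair.

```python
import torch


class normalization(object):
    """Copy of the transform from the question, unchanged apart from layout."""

    def __call__(self, sample):
        image, label = sample['image'], sample['labels']
        fmin = torch.min(image)
        fm = image - fmin
        image = 255 * fm / torch.max(fm)
        return {'Normalized image': image, 'Labels': label}


# Case 1: Compose passes a tensor, and a tensor cannot be indexed with a string.
try:
    normalization()(torch.rand(1, 4, 4))
    tensor_input_failed = False
except (IndexError, TypeError) as exc:
    tensor_input_failed = True
    print("tensor input:", type(exc).__name__)

# Case 2: even with a dict input, the output is a dict, not a tensor.
out = normalization()({'image': torch.rand(1, 4, 4), 'labels': 0})
print("dict input returns:", type(out).__name__, sorted(out))
```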
Any help is appreciated!