IndexError: too many indices for tensor of dimension 3. Unable to apply a custom transform in transforms.Compose

This is the custom transform that I want to apply to the whole dataset.

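import torch

class normalization(object):
    def __call__(self, sample):
        # Expects a dict with 'image' and 'labels' keys
        image, label = sample['image'], sample['labels']
        # Shift so the minimum is 0, then scale so the maximum is 255
        fmin = torch.min(image)
        fm = image - fmin
        image = 255 * fm / torch.max(fm)
        return {'Normalized image': image, 'Labels': label}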
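
For comparison, here is a minimal sketch of the same min-max rescaling written to accept a bare tensor instead of a dict. It rests on an assumption not confirmed anywhere above: that each transform in transforms.Compose receives only the image produced by the previous step, with the label handled separately by the dataset. The class name MinMaxTo255 and the eps guard are hypothetical, added only for illustration.

```python
import torch


class MinMaxTo255:
    """Rescale a tensor image linearly so its values span [0, 255].

    Hypothetical tensor-only variant of the transform above: it takes the
    image tensor itself rather than a dict with 'image' and 'labels' keys.
    """

    def __init__(self, eps=1e-12):
        # eps is an added guard (not in the original transform) against
        # dividing by zero when every pixel has the same value.
        self.eps = eps

    def __call__(self, image):
        shifted = image - torch.min(image)  # minimum becomes 0
        scale = torch.clamp(torch.max(shifted), min=self.eps)
        return 255 * shifted / scale  # maximum becomes (about) 255


# Quick check on a random single-channel "image".
sample = torch.rand(1, 4, 4)
out = MinMaxTo255()(sample)
print(type(out), out.shape)
```

If that assumption holds, an instance of this class could take the place of normalization() as the last entry in the Compose list, and data_set[i] should then yield a tensor image rather than a dict.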

I add the custom transform to transforms.Compose using the code below.

batch_size = 100
size = 299, 299
data_transforms = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
    transforms.Grayscale(num_output_channels=1),
    normalization()
])

Data loading is done using:


data_set = torchvision.datasets.ImageFolder(root=data_dir, transform=data_transforms)

train_set_size = int(len(data_set) * 0.6)
test_set_size = len(data_set) - train_set_size

train_set, test_set = torch.utils.data.random_split(
    data_set, [train_set_size, test_set_size], 
    generator=torch.Generator().manual_seed(72))

data_loader = DataLoader(data_set, batch_size=batch_size, shuffle=False)

train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

The code block below causes the error:

    
for i in range(len(data_set)):
    image, _ = data_set[i]
    print('types:', type(image))
    print(i, image.size())
    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))

    if i == 3:
        plt.show()
        break

Error:
IndexError: too many indices for tensor of dimension 3

I don't understand what is causing the issue in my custom transform. Also, 'image' should be a tensor, but its type is currently 'dict'.

Any help is appreciated!