[Solved] How do I display a grayscale image?

I have a problem with the normalization of the grayscale image (CT).
My code is…

import numpy as np
import torch
import torch.utils.data
from torchvision import transforms
import matplotlib.pyplot as plt

class trainDataset(torch.utils.data.Dataset):
    def __init__(self, data, target, transform=None):
        self.data = data
        self.target = target
        self.transform = transform
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        
        if self.transform:
            x = self.transform(x)
            y = self.transform(y)
            
        return x, y
    
    def __len__(self):
        return len(self.data)

traindataset = trainDataset(numpy_data, numpy_target,
                       transform = transforms.Compose([transforms.ToPILImage(mode=None),
                                                       transforms.Grayscale(num_output_channels=1),
                                                       transforms.Resize(256),
                                                       transforms.ToTensor(),
                                                       transforms.Normalize((0.1307,), (0.3081,))
                                                       ]))
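As an aside, (0.1307,) and (0.3081,) are the usual MNIST statistics; for CT data it is often better to compute the dataset's own mean and std and pass those to Normalize instead. A quick sketch (using a random stand-in array in place of the real `numpy_data`):

```python
import numpy as np

# Random stand-in for the real CT array; replace with your own data.
numpy_data = np.random.rand(10, 256, 256).astype(np.float32)

mean = float(numpy_data.mean())
std = float(numpy_data.std())

# These values would then replace the hard-coded MNIST ones:
# transforms.Normalize((mean,), (std,))
normalized = (numpy_data - mean) / std
```

After this, `normalized` has mean ≈ 0 and std ≈ 1, which is what Normalize is meant to achieve.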

I encountered a problem when I used ‘imshow’.

def imshow(inp):
    inp = inp.numpy()[0]
    mean = 0.1307
    std = 0.3081
    inp = ((mean * inp) + std)
    plt.imshow(inp, cmap='gray')

imshow(traindataset[80][0])

No matter what values I put in for mean and std, I get the left side of the following picture,
but what I expected was the right side.

(image: left — the incorrect output I get; right — the expected grayscale CT)

Which part of the code should I modify?
And how can I display a target(=masked ct)?

Shouldn’t the quoted line above be this way:

inp = inp*std + mean
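That is, Normalize computes `(x - mean) / std`, so undoing it is multiply-by-std then add-mean. A quick numpy check with stand-in data:

```python
import numpy as np

mean, std = 0.1307, 0.3081

# Stand-in for the (1, H, W) tensor returned by the dataset.
img = np.random.rand(1, 256, 256).astype(np.float32)

norm = (img - mean) / std      # what transforms.Normalize does
denorm = norm[0] * std + mean  # undo it: multiply by std, then add mean

# denorm now matches img[0] again, so plt.imshow(denorm, cmap='gray')
# would display the raw pixel values.
```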

Thanks for the reply. :blush:
But the left image still appears. :sob:

I’ve solved the problem.
It was in ‘transforms.Grayscale(num_output_channels=1)’.
I deleted ‘transforms.Grayscale(num_output_channels=1)’ and converted the arrays to ‘np.float32’ instead, as shown below.

import numpy as np
import torch
import torch.utils.data

class trainDataset(torch.utils.data.Dataset):
    def __init__(self, data, target, transform=None):
        self.data = data.astype(np.float32)
        self.target = target.astype(np.float32)
        self.transform = transform
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        
        if self.transform:
            x = self.transform(x)
            y = self.transform(y)
            
        return x, y
    
    def __len__(self):
        return len(self.data)
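To display the target (the masked CT), the same denormalize-and-imshow approach works on `traindataset[i][1]`, since the target goes through the same transform pipeline. A minimal sketch (the function name `imshow_pair` and the random stand-in arrays are just for illustration):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs in a script
import matplotlib.pyplot as plt

mean, std = 0.1307, 0.3081

def imshow_pair(x, y):
    """Show a normalized CT slice and its mask side by side.

    x, y: (1, H, W) arrays or tensors, as returned by trainDataset.
    Assumes both went through the same transforms.Normalize.
    """
    x = np.asarray(x)[0] * std + mean  # denormalize the CT slice
    y = np.asarray(y)[0] * std + mean  # same for the mask
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(x, cmap='gray')
    ax1.set_title('CT')
    ax2.imshow(y, cmap='gray')
    ax2.set_title('mask')
    return fig

# Usage with the dataset would be:
# imshow_pair(traindataset[80][0], traindataset[80][1])
```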