Usage of make_grid

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def save_attn_map(maps, imgs, path):
    # append the attention maps to the images along the batch dimension
    img = torch.cat((imgs, maps), 0)
    # make a grid with two rows: the images in one and the attention maps in the other
    grid = make_grid(img, nrow=maps.size(0), padding=10)

    npimg = grid.detach().cpu().numpy()  # to numpy array

    npimg = (npimg * 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(8, 2))
    ax.axis("off")
    # transpose from (C, H, W) to the (H, W, C) layout that matplotlib expects
    ax.imshow(np.transpose(npimg, (1, 2, 0)))

    fig.savefig("{}.pdf".format(path), bbox_inches='tight')
    plt.close(fig)
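
For context, a minimal smoke test of how I call it (the random tensors here are just placeholders, not my real data):

# hypothetical example data, only to show the call
imgs = torch.rand(8, 3, 32, 32)   # images already scaled to [0, 1]
maps = torch.rand(8, 3, 32, 32)   # attention maps in [0, 1], same shape as imgs
save_attn_map(maps, imgs, "attn_grid")  # writes attn_grid.pdf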

I have the function above, where the tensors maps and imgs have the same size (Batch x Channel x H x W). But when I plot the array npimg, the image I get is totally messed up. What could it be? I tried making a grid with imgs alone just in case, but I got the same weird result.

I guess the transformation to uint8 is messing up the visualizations.
What kind of values and shapes do your imgs and maps tensors have?
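
For example, printing something like this right before building the grid would help (names taken from your function):

print(imgs.shape, imgs.dtype, imgs.min().item(), imgs.max().item())
print(maps.shape, maps.dtype, maps.min().item(), maps.max().item())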

Originally, imgs is in [0 … 255] and maps is in [0 … 1]. When I load imgs into a tensor, they become [0 … 1]. I am loading CIFAR-10 images, so, for example, I have a batch of [24, 3, 32, 32] for imgs and another for maps. I tried to do i = img.permute(0,2,3,1) and plot, let's say, i[0], but it doesn't seem to be right.
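
This is roughly what I tried for a single image (a sketch, assuming the batch is a float tensor in [0, 1] called imgs):

import matplotlib.pyplot as plt

i = imgs.permute(0, 2, 3, 1)    # (B, C, H, W) -> (B, H, W, C)
plt.imshow(i[0].cpu().numpy())  # matplotlib accepts float images in [0, 1]
plt.axis("off")
plt.show()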

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class CifarDataset(Dataset):
    """
    :root_indexes:  path to indexes compressed file
    :root_data:     path to data compressed file
    :transforms_:   list of functions from torchvision.transforms
    """

    def __init__(self, root_indexes: str, root_data: str, transforms_: list = []):
        self.transform = transforms.Compose(transforms_)

        r_data = np.load(root_data)

        self.indexes = np.load(root_indexes)['all_indexes']
        self.labels = r_data['all_labels']

        # the line below solved my problem: reshape the flat CIFAR rows to (N, 3, 32, 32)
        # and transpose to channel-last (N, 32, 32, 3) so that ToTensor() receives HWC arrays
        self.data = np.transpose(r_data['all_data'].reshape(-1, 3, 32, 32), (0, 2, 3, 1))

        self.data = torch.stack([self.transform(im) for im in self.data])

    def __getitem__(self, index) -> tuple:
        idx = self.indexes[index]
        return self.data[idx], self.labels[idx]

    def __len__(self) -> int:
        return len(self.indexes)
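
For completeness, this is roughly how I build the loader (the .npz paths below are placeholders for my actual files):

from torch.utils.data import DataLoader
from torchvision import transforms

dataset = CifarDataset("indexes.npz", "data.npz", [transforms.ToTensor()])
loader = DataLoader(dataset, batch_size=24, shuffle=True)
imgs, labels = next(iter(loader))   # imgs: (24, 3, 32, 32), floats in [0, 1]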

It seems the problem was further away from the grid than I thought.