How to get the correct tensor shape in a custom Dataset

Hello,

I am using a custom Dataset class, but when I get the data from the DataLoader, the tensor has a different shape than the one I want.

shape that I get: torch.Size([1, 56, 128, 128])
shape that I want: torch.Size([1, 56, 1, 128, 128])
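Since the DataLoader adds the leading batch dimension itself, the target per-sample shape returned by `__getitem__` is (56, 1, 128, 128), not (1, 56, 1, 128, 128). The numpy sketch below assumes one sample is a (56, 128, 128) array and uses `np.stack` to approximate what the default collate function does with `torch.stack` for batch_size=1:

```python
import numpy as np

# One sample as loaded from disk (shape assumed from the post)
sample = np.zeros((56, 128, 128), dtype=np.float32)

# Insert a singleton axis at position 1: (56, 128, 128) -> (56, 1, 128, 128)
wanted_sample = np.expand_dims(sample, axis=1)

# Default collation stacks samples along a new axis 0,
# so batch_size=1 adds exactly one leading dimension.
batch = np.stack([wanted_sample])

print(sample.shape, wanted_sample.shape, batch.shape)
# (56, 128, 128) (56, 1, 128, 128) (1, 56, 1, 128, 128)
```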

My approach was to:

  1. apply numpy.expand_dims to the array, which gives torch.Size([1, 1, 56, 128, 128])

  2. then apply np.transpose to the array to get the shape I want, torch.Size([1, 56, 1, 128, 128])
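Inside `__getitem__` there is no batch dimension yet, so the two steps above correspond to `expand_dims(axis=0)` followed by swapping the first two axes. This sketch (assuming a (56, 128, 128) sample) checks that the arithmetic itself works and that a single `expand_dims(axis=1)` gives the same result:

```python
import numpy as np

image = np.random.rand(56, 128, 128).astype(np.float32)  # one sample, assumed shape

# Step 1: add a leading axis -> (1, 56, 128, 128)
step1 = np.expand_dims(image, axis=0)
# Step 2: swap the first two axes -> (56, 1, 128, 128)
step2 = np.transpose(step1, axes=(1, 0, 2, 3))

# Equivalent single step: insert the new axis at position 1 directly
direct = np.expand_dims(image, axis=1)

print(step2.shape, np.array_equal(step2, direct))  # (56, 1, 128, 128) True
```

So the shape manipulation is fine on its own; the errors below come from where it happens relative to the transform.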

After the first step, I get this error:

raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
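This message (and the TypeError further down) matches torchvision's `ToTensor`, which is assumed here to be the `transform` passed to the dataset. For ndarrays it only accepts 2-D (H, W) or 3-D (H, W, C) input and reorders it to (C, H, W). The helper below, `to_tensor_shape_sketch`, is hypothetical: a shapes-only approximation of those checks, not torchvision's code (no value rescaling, no PIL support):

```python
import numpy as np


def to_tensor_shape_sketch(pic):
    """Hypothetical stand-in for ToTensor's ndarray path (shapes only)."""
    if not isinstance(pic, np.ndarray):
        raise TypeError(f"pic should be PIL Image or ndarray. Got {type(pic)}")
    if pic.ndim not in (2, 3):
        raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
    if pic.ndim == 2:
        pic = pic[:, :, np.newaxis]  # (H, W) -> (H, W, 1)
    return pic.transpose(2, 0, 1)    # (H, W, C) -> (C, H, W)


hwc = np.zeros((128, 128, 56), dtype=np.float32)
print(to_tensor_shape_sketch(hwc).shape)  # (56, 128, 128)

try:
    to_tensor_shape_sketch(np.expand_dims(hwc, axis=0))  # 4-D input
except ValueError as err:
    print(err)  # pic should be 2/3 dimensional. Got 4 dimensions.
```

In other words, any extra axis added before the transform runs produces a 4-D array and trips the dimension check.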

If I do the transpose first, no axis order I pass to np.transpose (e.g. axes=(1, 2, 0)) yields the shape torch.Size([56, 1, 128, 128]).
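That is expected: `np.transpose` only reorders existing axes, so a 3-D array stays 3-D under every permutation. A quick check, assuming a (56, 128, 128) sample:

```python
import itertools

import numpy as np

image = np.zeros((56, 128, 128), dtype=np.float32)  # one sample, assumed shape

# Try every axis order np.transpose accepts for a 3-D array.
shapes = {np.transpose(image, axes=p).shape for p in itertools.permutations(range(3))}
print(sorted(shapes))  # [(56, 128, 128), (128, 56, 128), (128, 128, 56)]

# Adding a dimension needs np.newaxis indexing (or np.expand_dims / reshape).
added = image[:, np.newaxis, :, :]
print(added.shape)  # (56, 1, 128, 128)
```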

If I convert the array to a Tensor first and then call torch.unsqueeze, I get this error:


raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
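`ToTensor` rejects anything that is not a PIL image or ndarray, so it cannot be applied after the conversion to `torch.Tensor`. If the data is already numeric, one option is to skip `ToTensor` altogether and build the tensor directly. The sketch below assumes the file holds a (56, 128, 128) array; `to_slice_tensor` is a hypothetical helper name, and unlike `ToTensor` it does not rescale uint8 data to [0, 1]:

```python
import numpy as np
import torch


def to_slice_tensor(image: np.ndarray) -> torch.Tensor:
    """Turn a (56, 128, 128) ndarray into a (56, 1, 128, 128) float32 tensor.

    No (H, W, C) transpose is needed because ToTensor is not involved.
    """
    tensor = torch.from_numpy(image).float()  # copies only if a dtype cast is needed
    return tensor.unsqueeze(1)                # insert the singleton axis at position 1


sample = np.random.rand(56, 128, 128)  # float64, assumed sample shape
print(to_slice_tensor(sample).shape)   # torch.Size([56, 1, 128, 128])
```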

Here is my code:

import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class patientdataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        # np_load_old = np.load
        # np.load = lambda *a, **k: np_load_old(*a, allow_pickle=True, **k)

        image = np.asarray(np.load(img_path))  # assumed (56, 128, 128) per file

        image = np.transpose(image, axes=(1, 2, 0))  # -> (128, 128, 56)
        image = torch.Tensor(image)
        image = torch.unsqueeze(image, dim=1)  # -> (128, 1, 128, 56)

        y_label = torch.tensor(np.asarray(self.annotations.iloc[index, 1]))

        if self.transform:
            image = self.transform(image)  # image is a torch.Tensor at this point

        return (image, y_label)

Your code should be fine. Can you provide the call stack for the error?
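For what it's worth, both tracebacks are consistent with the transform being torchvision's `ToTensor`, which only accepts a 2-D/3-D ndarray (or PIL image). One way to get torch.Size([1, 56, 1, 128, 128]) out of a DataLoader with batch_size=1 is to run the transform on the (128, 128, 56) ndarray first and add the singleton axis afterwards. A sketch of `__getitem__` under those assumptions (each .npy file holds a (56, 128, 128) array; `transform`, if given, maps an (H, W, C) ndarray to a (C, H, W) tensor the way `ToTensor` does):

```python
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class patientdataset(Dataset):
    """Sketch: returns images of shape (56, 1, 128, 128) per sample."""

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
        image = np.load(img_path)  # assumed (56, 128, 128)

        if self.transform:
            # ToTensor-style transforms want a 2-D/3-D (H, W, C) ndarray,
            # so pass the array before any extra dimension is added.
            image = self.transform(np.transpose(image, (1, 2, 0)))  # -> (56, 128, 128)
        else:
            image = torch.from_numpy(image)  # already (56, 128, 128)

        # Add the singleton axis last: (56, 128, 128) -> (56, 1, 128, 128)
        image = image.unsqueeze(1)

        y_label = torch.tensor(np.asarray(self.annotations.iloc[index, 1]))
        return image, y_label
```

The DataLoader then prepends the batch dimension, giving (1, 56, 1, 128, 128) for batch_size=1.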