I have a sequence of images, so with batch size 1 my __getitem__ output ends up with size [1, 9, 3, 256, 256], where 9 is the sequence length.
The method currently looks like this:
def __getitem__(self, idx):
    seq = self.data[idx]
    ims = np.empty((self.seq_length, 3, 256, 256), dtype=np.uint8)
    labels = np.empty((self.seq_length), dtype=int)
    for i, case in enumerate(seq.cases):
        ims[i] = imread(case.path)
        labels[i] = case.cls
    # transform to tensors
    ims = torch.from_numpy(ims)
    labels = torch.from_numpy(labels)
    if self.transform:
        ims = self.transform(ims)
    return ims, labels
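For context, this is roughly how the batch shape comes about with batch size 1. The dataset below is only a stand-in with dummy tensors to illustrate the shapes; my real __getitem__ is the one above:

import torch
from torch.utils.data import Dataset, DataLoader

class DummySeqDataset(Dataset):
    # stand-in for my real dataset, just to show the shapes
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        ims = torch.zeros(9, 3, 256, 256)              # one sequence of 9 images
        labels = torch.zeros(9, dtype=torch.long)       # one label per image
        return ims, labels

loader = DataLoader(DummySeqDataset(), batch_size=1)
ims, labels = next(iter(loader))
print(ims.shape)     # torch.Size([1, 9, 3, 256, 256])
print(labels.shape)  # torch.Size([1, 9])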
Since my data is highly imbalanced, I'd like to apply the transformations only to the minority class (labels == 1). I know I have to check for labels[i] == 1, but the problem is that within a single sequence the labels can be a mix of 0 and 1. I'm not 100% sure how torchvision transforms work: will they change the length of my sequence, or can I do it like this:
for i in range(len(ims)):
    if (labels[i] == 1) and self.transform:  # check for minority class
        ims[i] = self.transform(ims[i])
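To make the idea concrete, this is the kind of per-image transform I have in mind (the flips and the dummy labels are just placeholders, not my actual augmentation pipeline), here run outside the dataset class:

import torch
from torchvision import transforms

# example transform applied to a single [3, 256, 256] tensor image
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
])

ims = torch.zeros(9, 3, 256, 256, dtype=torch.uint8)
labels = torch.tensor([0, 1, 1, 0, 0, 1, 0, 0, 0])

for i in range(len(ims)):
    if labels[i] == 1:              # only augment minority-class frames
        ims[i] = transform(ims[i])

# I expect ims to still be [9, 3, 256, 256] afterwards,
# but I'm not sure that holds for every transform
print(ims.shape)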