I have a dataset of image sequences: with batch size 1, my __getitem__ method returns a sequence of 9 images and a sequence of labels of size [9]. I am trying to use the WeightedRandomSampler, but I am running into an error in __getitem__ that does not occur without the sampler.
I first used this approach:
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

class_sample_count = [6935, 805]
weights = 1 / torch.Tensor(class_sample_count)
# load labels once to later get weights
targets = []
for _, labels in train_set:
    targets.append(np.array([weights[t] for t in labels]))
# transform to numpy array
targets = np.asarray(targets)
samples_weight = torch.from_numpy(targets)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
samples_weight had shape torch.Size([860, 9]) and I got the following error:
list indices must be integers or slices, not list
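A small illustration of where this error may come from, assuming the sampler draws its indices by calling torch.multinomial on the weights tensor. For 2D weights, torch.multinomial samples per row and returns a 2D tensor, so each "index" handed to the dataset would be a list rather than an int. The tensor sizes and the data list below are toy stand-ins, not the original setup, and recent PyTorch versions may reject non-1D weights outright instead.

```python
import torch

torch.manual_seed(0)

# Toy stand-in for the [860, 9] weight tensor (sizes are illustrative only).
weights_2d = torch.ones(4, 9, dtype=torch.double)

# With 2D input, multinomial draws num_samples indices for EACH row.
drawn = torch.multinomial(weights_2d, num_samples=4, replacement=True)
print(drawn.shape)  # torch.Size([4, 4])

# Iterating over drawn.tolist() therefore yields lists, not ints.
first_index = drawn.tolist()[0]
print(type(first_index))  # <class 'list'>

# Indexing a plain Python list with such an "index" reproduces the message.
data = list(range(4))
try:
    data[first_index]
except TypeError as err:
    print(err)  # list indices must be integers or slices, not list
```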
I then did the same but first flattened the 2D list to a 1D list before transforming it into a NumPy array:
from itertools import chain

class_sample_count = [6935, 805]
weights = 1 / torch.Tensor(class_sample_count)
# load labels once to later get weights
targets = []
for _, labels in train_set:
    targets.append(np.array([weights[t] for t in labels]))
# flatten list (2D to 1D)
targets = list(chain.from_iterable(targets))
# transform to numpy array
targets = np.asarray(targets)
samples_weight = torch.from_numpy(targets)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
but then got the following error:
list index out of range
In my __getitem__ method, I am getting a sequence of images with seq = self.data[idx].
How can I handle sequences with the WeightedRandomSampler?
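One possible direction, as a minimal sketch rather than a definitive answer: the sampler draws indices in range(len(weights)), so the weights need exactly one entry per dataset item (860 here), not one per frame (860 * 9 = 7740). The sketch below reduces the per-frame weights to one weight per sequence by taking the mean. That reduction is an assumption, and ToySequenceDataset is a hypothetical stand-in for train_set with random data and made-up image sizes.

```python
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler


class ToySequenceDataset(Dataset):
    """Hypothetical stand-in for train_set: items are (image sequence, label sequence)."""

    def __init__(self, num_sequences=860, seq_len=9):
        gen = torch.Generator().manual_seed(0)
        # Random tiny "images"; the real dataset would load actual frames.
        self.data = torch.randn(num_sequences, seq_len, 3, 4, 4, generator=gen)
        # Roughly 10% positive labels to imitate the class imbalance.
        self.labels = (torch.rand(num_sequences, seq_len, generator=gen) < 0.1).long()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


train_set = ToySequenceDataset()

class_sample_count = torch.tensor([6935.0, 805.0])
class_weights = 1.0 / class_sample_count

# Per-frame weights, shape [860, 9].
all_labels = torch.stack([train_set[i][1] for i in range(len(train_set))])
frame_weights = class_weights[all_labels]

# Assumption: one weight per sequence, taken as the mean of its frame weights.
samples_weight = frame_weights.mean(dim=1).double()  # shape [860]

sampler = WeightedRandomSampler(
    samples_weight, num_samples=len(samples_weight), replacement=True
)
loader = DataLoader(train_set, batch_size=1, sampler=sampler)

images, labels = next(iter(loader))
print(images.shape)  # torch.Size([1, 9, 3, 4, 4])
print(labels.shape)  # torch.Size([1, 9])
```

Other reductions (for example the maximum frame weight, so that any sequence containing a rare-class frame is favored) would fit the same pattern; which one is appropriate depends on how the sequences should be balanced.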