Hi there, I noticed a class imbalance in my dataset and wanted to counter it, so I decided to use WeightedRandomSampler instead of random shuffling in my data loader.
Here is an example of the code I changed. Maybe the Subset class does not have __len__ properly implemented by default?
Before:
```python
import torch

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
# train_dataset is a subset of the whole dataset
```
After:
```python
import torch
from torch.utils.data import WeightedRandomSampler

# class_weights.txt contains one float weight per line
with open("class_weights.txt", "r") as f:
    weights = [float(i) for i in f.readlines()]
sampler = WeightedRandomSampler(weights=weights, num_samples=5856, replacement=True)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, sampler=sampler, num_workers=0)
# train_dataset is a subset of the whole dataset
```
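WeightedRandomSampler expects one weight per sample of the dataset handed to the DataLoader (here the Subset), not one weight per class, and it draws indices from range(len(weights)). Below is a minimal sketch that turns labels into inverse-frequency per-sample weights. It assumes you can obtain the label of each sample in train_dataset, in Subset order. The names make_sample_weights, subset_labels and full_labels are hypothetical.

```python
from collections import Counter


def make_sample_weights(labels):
    """Return one inverse-frequency weight per sample, in the same order as labels."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Hypothetical labels for the samples in train_dataset, in Subset order,
# e.g. [full_labels[i] for i in train_dataset.indices] (full_labels is hypothetical)
subset_labels = [0, 0, 0, 1]
sample_weights = make_sample_weights(subset_labels)
print(sample_weights)
```

The resulting list could then be passed as WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True). Within each class the weights sum to 1, so both classes should be drawn at roughly equal rates, and every drawn index stays below len(train_dataset).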
After this change, I get the following error:
```
Traceback (most recent call last):
  File "C:/Users/Matej/PycharmProjects/pneumonia_detection/normal_neural_network.py", line 153, in <module>
    train_net(net, trainloader, valloader, device, 30)
  File "C:/Users/Matej/PycharmProjects/pneumonia_detection/normal_neural_network.py", line 93, in train_net
    for i, data in enumerate(trainloader, 0):
  File "C:\Users\Matej\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
    data = self._next_data()
  File "C:\Users\Matej\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataloader.py", line 385, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\Matej\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\Matej\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\Matej\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\data\dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
IndexError: list index out of range
```
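One plausible cause assumes that class_weights.txt holds one weight per image of the full dataset, which would match num_samples=5856. WeightedRandomSampler draws indices from range(len(weights)). The last frame of the traceback, however, looks each index up in self.indices, which only has len(train_dataset) entries. Note that torch's Subset does define __len__ as len(self.indices). The pure-Python sketch below mimics this with a hypothetical MiniSubset class, using random.choices in place of the sampler. It is an illustration, not torch code.

```python
import random


class MiniSubset:
    """Hypothetical stand-in that mirrors the Subset indexing line in the traceback."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Same lookup as dataset.py in the traceback: idx must be < len(self.indices)
        return self.dataset[self.indices[idx]]


def draw_indices(weights, num_samples, seed=0):
    """Weighted sampling with replacement over range(len(weights))."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


full_dataset = list(range(10))                                # 10 hypothetical samples
train_subset = MiniSubset(full_dataset, [0, 1, 2, 3, 4, 5])  # a subset of 6 of them

# Weights sized for the FULL dataset: drawn indices 6..9 are out of range for the subset
full_weights = [1.0] * len(full_dataset)
try:
    [train_subset[i] for i in draw_indices(full_weights, num_samples=1000)]
    failed_with_full_weights = False
except IndexError:
    failed_with_full_weights = True

# Weights sized for the subset: every drawn index is valid
subset_weights = [1.0] * len(train_subset)
ok_batch = [train_subset[i] for i in draw_indices(subset_weights, num_samples=1000)]
print(failed_with_full_weights, len(ok_batch))
```

If this assumption holds, building the weight list with exactly len(train_dataset) entries should avoid the IndexError.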
Thanks in advance for your answers and time.