Imbalanced dataset

Hi, I'm trying to improve my model's performance on an imbalanced dataset with 5 classes. I found the ImbalancedDatasetSampler from torchsampler, but when I try to use it in the DataLoader:

train_loader = DataLoader(train_set, batch_size=32, sampler=ImbalancedDatasetSampler(train_set))

I get this attribute error:

AttributeError Traceback (most recent call last)
<ipython-input-...> in <module>()
1 from torchsampler import ImbalancedDatasetSampler
2 #TRAIN LOADER
----> 3 train_loader = DataLoader(train_set, batch_size=32, sampler=ImbalancedDatasetSampler(train_set))
4 #VALID LOADER
5 valid_loader = DataLoader(valid_set, batch_size = 8, shuffle = True)

2 frames
/usr/local/lib/python3.7/dist-packages/torchsampler/imbalanced.py in __init__(self, dataset, indices, num_samples, callback_get_label)
28 # distribution of classes in the dataset
29 df = pd.DataFrame()
---> 30 df["label"] = self._get_labels(dataset)
31 df.index = self.indices
32 df = df.sort_index()

/usr/local/lib/python3.7/dist-packages/torchsampler/imbalanced.py in _get_labels(self, dataset)
48 return dataset.samples[:][1]
49 elif isinstance(dataset, torch.utils.data.Subset):
---> 50 return dataset.dataset.imgs[:][1]
51 elif isinstance(dataset, torch.utils.data.Dataset):
52 return dataset.get_labels()

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataset.py in __getattr__(self, attribute_name)
81 return function
82 else:
---> 83 raise AttributeError
84
85 @classmethod

AttributeError:

Does anybody know why I get this error?

My custom dataset code to get the images (DICOM files) and the labels is:

import glob

import pandas as pd
import torch
from torch.utils.data import Dataset

class MyCustomDataset(Dataset):
    def __init__(self, path="/content/drive/MyDrive/subjects"):
        # Map each zero-padded subject id to its KL grade.
        dataframe = pd.read_csv("/content/drive/MyDrive/KL_grade.csv", sep=";")
        self.labels = {}
        subject_ids = list(dataframe["id"])
        grades = list(dataframe["grade"])
        for i, g in zip(subject_ids, grades):
            self.labels[str(i).zfill(5)] = g

        # One folder per subject; the folder name is the zero-padded id.
        self.ids = [a.split("/")[-1] for a in sorted(glob.glob(path + "/*"))]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        # load_3d_dicom_images is a helper defined elsewhere in my notebook.
        imgs = load_3d_dicom_images(self.ids[idx])
        label = self.labels[self.ids[idx]]
        return torch.tensor(imgs, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

Based on the error message it seems torchsampler tries to call dataset.dataset.imgs, which apparently isn’t a valid attribute. I’m not familiar with the implementation of torchsampler and also cannot find a repository of it to check what might be failing.
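Judging from the traceback, _get_labels falls into the torch.utils.data.Subset branch (so train_set is presumably a Subset, e.g. from random_split) and then looks for an imgs attribute that your custom dataset doesn't have. The __init__ signature in the traceback also shows a callback_get_label argument; assuming it is called with the dataset and is expected to return one label per sample (I haven't checked this against the installed torchsampler source, so treat it as a sketch), you could try supplying the labels yourself:

from torch.utils.data import DataLoader
from torchsampler import ImbalancedDatasetSampler

# Unverified sketch: return one integer label per sample of the Subset by
# indexing the wrapped MyCustomDataset (labels/ids come from your dataset code).
def get_subset_labels(dataset):
    base = dataset.dataset  # the MyCustomDataset wrapped by the Subset
    return [base.labels[base.ids[i]] for i in dataset.indices]

train_loader = DataLoader(
    train_set,
    batch_size=32,
    sampler=ImbalancedDatasetSampler(train_set, callback_get_label=get_subset_labels),
)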

Hi Peter, thanks for the response. Do you know of any techniques I could use to overcome this issue? I've trained some models but I keep running into overfitting: the validation loss and accuracy saturate while the training loss keeps decreasing and the training accuracy climbs above 90%.
I tried adding class weights to the cross-entropy loss (sketched below) and adding a sampler when defining the train loader, but neither of these approaches worked.
I am working with 435 DICOM volumes (160x384x384) and I have 5 classes.
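The weighted loss was set up along these lines (simplified; the per-class counts here are only placeholders, not my real class distribution):

import torch
import torch.nn as nn

# Placeholder per-class sample counts for the 5 classes (made-up split of the
# 435 samples, for illustration only). Weight each class by inverse frequency.
class_counts = torch.tensor([200.0, 120.0, 60.0, 35.0, 20.0])
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

criterion = nn.CrossEntropyLoss(weight=class_weights)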

I don't know what's causing the error; I would need to see the torchsampler implementation to figure out what's failing and how to fix it.
Generally, you could use the PyTorch WeightedRandomSampler as described e.g. here.
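A minimal sketch of that approach, assuming train_set is a Subset wrapping your MyCustomDataset (adjust the label lookup to however you can get one integer label per training sample):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Collect one integer class label per sample in train_set.
targets = torch.tensor(
    [train_set.dataset.labels[train_set.dataset.ids[i]] for i in train_set.indices]
)

# Weight each class by the inverse of its frequency ...
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.float()

# ... and give every sample the weight of its class.
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

train_loader = DataLoader(train_set, batch_size=32, sampler=sampler)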