Hi, I am trying to improve my model's performance on an imbalanced dataset with 5 classes. I found the ImbalancedDatasetSampler
from torchsampler,
but when I try to use it with the DataLoader:
train_loader = DataLoader(train_set, batch_size=32, sampler=ImbalancedDatasetSampler(train_set))
I get this attribute error:
AttributeError Traceback (most recent call last)
in ()
1 from torchsampler import ImbalancedDatasetSampler
2 #TRAIN LOADER
----> 3 train_loader = DataLoader(train_set, batch_size=32, sampler=ImbalancedDatasetSampler(train_set))
4 #VALID LOADER
5 valid_loader = DataLoader(valid_set, batch_size = 8, shuffle = True)
2 frames
/usr/local/lib/python3.7/dist-packages/torchsampler/imbalanced.py in __init__(self, dataset, indices, num_samples, callback_get_label)
28 # distribution of classes in the dataset
29 df = pd.DataFrame()
---> 30 df["label"] = self._get_labels(dataset)
31 df.index = self.indices
32 df = df.sort_index()
/usr/local/lib/python3.7/dist-packages/torchsampler/imbalanced.py in _get_labels(self, dataset)
48 return dataset.samples[:][1]
49 elif isinstance(dataset, torch.utils.data.Subset):
---> 50 return dataset.dataset.imgs[:][1]
51 elif isinstance(dataset, torch.utils.data.Dataset):
52 return dataset.get_labels()
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataset.py in __getattr__(self, attribute_name)
81 return function
82 else:
---> 83 raise AttributeError
84
85 @classmethod
AttributeError:
Does anybody know why I get this error?
My custom dataset code to get the images (DICOM files) and the labels is:
class MyCustomDataset(Dataset):
    """Dataset pairing each subject folder with its grade from a CSV file.

    Subject ids are the names of the entries found directly under ``path``;
    grades are looked up in a ``;``-separated CSV with ``id`` and ``grade``
    columns, where ids are zero-padded to five characters.
    """

    def __init__(
        self,
        path="/content/drive/MyDrive/subjects",
        labels_csv="/content/drive/MyDrive/KL_grade.csv",
    ):
        """Read the grade table and list the subject entries.

        Args:
            path: Directory whose direct children are the subjects.
            labels_csv: ``;``-separated CSV file with ``id`` and ``grade``
                columns.
        """
        super().__init__()
        dataframe = pd.read_csv(labels_csv, sep=";")
        # Keys are zero-padded so they match the subject entry names.
        self.labels = {
            str(subject_id).zfill(5): grade
            for subject_id, grade in zip(dataframe["id"], dataframe["grade"])
        }
        # Honour ``path`` instead of a hard-coded directory; sorting keeps
        # the index -> subject mapping deterministic.
        self.ids = [
            entry.split("/")[-1] for entry in sorted(glob.glob(f"{path}/*"))
        ]

    def __len__(self):
        """Return the number of subject entries found under ``path``."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(images, label)`` tensors for the subject at ``idx``.

        Raises:
            KeyError: If the subject id has no row in the grade table.
        """
        subject_id = self.ids[idx]
        # NOTE(review): only the subject id (not the full path) is passed to
        # the loader; presumably it resolves the directory itself -- confirm.
        imgs = load_3d_dicom_images(subject_id)
        label = self.labels[subject_id]
        return (
            torch.tensor(imgs, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )

    def get_labels(self):
        """Return the grade of every sample, in index order.

        Class-balancing samplers such as torchsampler's
        ``ImbalancedDatasetSampler`` call ``dataset.get_labels()`` on generic
        ``Dataset`` objects. NOTE(review): when this dataset is wrapped in a
        ``torch.utils.data.Subset`` the sampler may take a different code
        path -- confirm against the installed torchsampler version.

        Raises:
            KeyError: If a subject id has no row in the grade table.
        """
        return [self.labels[subject_id] for subject_id in self.ids]