Hi. The indexes I forgot to reset were the row indexes of my DataFrame (df).
I was splitting my data like this:
if folds == 0:
    # Train and validation
    train, val = train_test_split(
        train_val, test_size=0.2, random_state=seed, shuffle=True
    )
    # The split subsets keep the row labels of train_val, so reset them
    # (dropping the old labels) to get a fresh 0..len-1 index on each part.
    # This is the step that was originally forgotten.
    train = train.reset_index(drop=True)
    val = val.reset_index(drop=True)
And then I created this custom Dataset (consumed by the DataLoader) that also returns the index of each sample.
# Custom PyTorch Dataset for this data (wrap it in a DataLoader to iterate)
class Derm(Dataset):
    """
    Read a pandas dataframe with
    images paths and labels.

    The dataframe is expected to provide the columns ``'filenames'``,
    ``'label_code'`` and ``'id'``. Rows are addressed by *position*
    (0 .. len(df) - 1), so the dataset works whether or not the
    dataframe index was reset after splitting.
    """

    def __init__(self, df, transform=None):
        """Store the dataframe and an optional callable applied to each image."""
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.df)

    def _cell(self, column, index):
        """Return the value of ``column`` at row position ``index``."""
        # .iloc is positional: unlike df[column][index] it does not depend on
        # the index labels, which are not 0..n-1 after a shuffled split.
        return self.df[column].iloc[index]

    def __getitem__(self, index):
        """
        Load one sample.

        Returns a tuple ``(index, X, y)``: the requested position, the RGB
        image (after ``transform`` when one was given) and the label tensor.

        Raises IOError if the image file cannot be opened.
        """
        # Load image data and get label
        path = self._cell('filenames', index)
        try:
            X = Image.open(path).convert('RGB')
        except IOError as err:
            # Fail here with a useful message instead of silently continuing
            # and crashing later on an undefined image variable.
            raise IOError(
                f"Could not load image {path!r} (position {index})"
            ) from err
        y = torch.tensor(self._cell('label_code', index))
        if self.transform is not None:
            X = self.transform(X)
        # Sanity check
        print('id:', self._cell('id', index), 'label', y)
        return index, X, y
And then in the train/eval loop:
# Each batch unpacks into three parts because the dataset's __getitem__
# returns (index, X, y).
for index, inputs, labels in dataloader:
    ...