In PyTorch, I want to apply data augmentation to only one class label.
Could somebody tell me how to do this, please?
For example,
class label 0 <- use data augmentation
class label 1 <- no augmentation
You could add a condition in your Dataset's __getitem__:
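from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, target, transform=None):
        ...  # store the samples and labels, e.g. self.data = data; self.target = target
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        # only augment samples of class 0 (and only if a transform was passed)
        if y == 0 and self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.data)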
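A minimal runnable sketch of the same idea, assuming the data is a float tensor of shape [N, C, H, W] and the labels are a 1-D tensor. The class name ConditionalAugDataset, the aug_label parameter, and the use of a horizontal flip (torch.flip along the width dimension) as the augmentation are all illustrative assumptions, not part of the original answer.

```python
import torch
from torch.utils.data import Dataset


class ConditionalAugDataset(Dataset):
    """Applies `transform` only to samples whose label equals `aug_label` (hypothetical helper)."""

    def __init__(self, data, target, transform=None, aug_label=0):
        self.data = data
        self.target = target
        self.transform = transform
        self.aug_label = aug_label

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        # Augment only the selected class; all other samples pass through unchanged
        if self.transform is not None and y == self.aug_label:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.data)


def hflip(img):
    # Stand-in augmentation (assumption): flip along the last (width) dimension
    return torch.flip(img, dims=[-1])


# Toy data: four single-channel 2x2 "images", labels alternate 0/1
data = torch.arange(16, dtype=torch.float32).view(4, 1, 2, 2)
target = torch.tensor([0, 1, 0, 1])

dataset = ConditionalAugDataset(data, target, transform=hflip, aug_label=0)
x0, y0 = dataset[0]  # label 0 -> flipped
x1, y1 = dataset[1]  # label 1 -> returned as-is
```

Wrapping this dataset in a DataLoader works as usual, since the condition is evaluated per sample inside __getitem__.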
You could also pass different transformations and apply them based on the current target label.
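One possible way to do that, sketched under the assumption that labels are integers: pass a dict mapping each label to its own transform and look it up in __getitem__, leaving labels without an entry untouched. The class name PerClassTransformDataset, the transforms_per_class parameter, and the flip-based transforms are hypothetical choices for illustration.

```python
import torch
from torch.utils.data import Dataset


class PerClassTransformDataset(Dataset):
    """Looks up a transform by label; labels missing from the dict are left unchanged."""

    def __init__(self, data, target, transforms_per_class=None):
        self.data = data
        self.target = target
        self.transforms_per_class = transforms_per_class or {}

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        # Tensors hash by identity, so convert the label to a Python int before the lookup
        transform = self.transforms_per_class.get(int(y))
        if transform is not None:
            x = transform(x)
        return x, y

    def __len__(self):
        return len(self.data)


def hflip(img):
    # Flip along the width dimension
    return torch.flip(img, dims=[-1])


def vflip(img):
    # Flip along the height dimension
    return torch.flip(img, dims=[-2])


# Toy data: three single-channel 2x2 "images" with labels 0, 1, 2
data = torch.arange(12, dtype=torch.float32).view(3, 1, 2, 2)
target = torch.tensor([0, 1, 2])

# Class 0 -> horizontal flip, class 1 -> vertical flip, class 2 -> no augmentation
dataset = PerClassTransformDataset(data, target, {0: hflip, 1: vflip})
x0, _ = dataset[0]
x1, _ = dataset[1]
x2, _ = dataset[2]
```

A dict keeps the per-class policy outside the Dataset class, so changing which classes are augmented does not require editing __getitem__.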
Dear ptrblck,
What beautiful code!
I learned a lot.
Thanks for your help!