Change labels of Subset

Hi,

I need help finding a way to change the labels of my data (e.g., change every 5 into a 7).

I am using the MNIST dataset, which I split into several subsets that I save to files and load later. For my experiments I cannot use the Subsets directly; I have to save them to files and load them afterwards. I also cannot change the labels before saving the data.

This is how I load my data:

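import torch
from torch.utils.data import Dataset

class SubsetLoader(Dataset):
    def __init__(self, filename):
        # Each file holds one pickled torch.utils.data.Subset.
        self.data = torch.load(filename)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)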

Among other things, I tried this after loading the data: dataset.targets[dataset.targets == 5] = 7
But it fails: Subset has no attribute ‘targets’, and ‘SubsetLoader’ object has no attribute ‘targets’.

Is there a way to change the labels of my subsets, then?

Please post a code snippet showing how you saved those subsets.

This is how I save the data:

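import random as ran  # assumed: the original snippet calls ran.random() without showing this import

import torch
from torch.utils.data import random_split
from torchvision import datasets

# `transform` and `path` are defined elsewhere in the original script (not shown).
mnist_dataset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)

def random_split_strat(clients):
    nb_clients = clients

    # Draw one random weight per client and normalise so the fractions sum to 1.
    a = [ran.random() for _ in range(nb_clients)]
    s = sum(a)
    r = [i / s for i in a]

    # Split MNIST into one Subset per client, using the fractions as relative sizes.
    mnist_subsets = random_split(mnist_dataset, r, generator=torch.Generator().manual_seed(42))

    # Save each Subset (its underlying dataset plus its indices) to its own file.
    for i, subset in enumerate(mnist_subsets):
        filename = f'mnist_subset_{i}.pt'
        torch.save(subset, path + filename)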

Have you tried the following before saving the subsets?

mnist_dataset.targets[mnist_dataset.targets == 5] = 7

Unfortunately, I have to save the data without modifying the labels and only change the labels when I load them.

Here’s some sample code:

f = torch.load("mnist_subset_0.pt")
f.dataset.targets[f.dataset.targets == 5] = 7

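This should work: the loaded object is a Subset, which keeps the original MNIST dataset (and its targets) in its dataset attribute.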
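
If you would rather leave the stored targets untouched, another option is to remap the labels at the moment each item is read. Below is a minimal sketch of that idea, not something taken from the code in this thread: RelabeledDataset and label_map are hypothetical names, and the list of (sample, label) pairs only stands in for a loaded subset. The wrapper needs nothing beyond __getitem__ and __len__, so it should be usable anywhere a map-style dataset is expected.

```python
class RelabeledDataset:
    """Hypothetical wrapper: remaps labels of a map-style dataset on access."""

    def __init__(self, base, label_map):
        # base: any indexable object yielding (sample, label) pairs,
        #       e.g. a Subset loaded from one of the saved files.
        # label_map: dict from old label to new label, e.g. {5: 7}.
        self.base = base
        self.label_map = dict(label_map)

    def __getitem__(self, index):
        sample, label = self.base[index]
        label = int(label)  # also covers labels stored as 0-dim tensors
        # Labels missing from the map are returned unchanged.
        return sample, self.label_map.get(label, label)

    def __len__(self):
        return len(self.base)


# Stand-in for a loaded subset: plain (sample, label) pairs.
toy_subset = [("img_a", 5), ("img_b", 3), ("img_c", 5)]

relabeled = RelabeledDataset(toy_subset, {5: 7})
labels = [relabeled[i][1] for i in range(len(relabeled))]
print(labels)  # [7, 3, 7]
```

With the files from this thread, the call would presumably look like RelabeledDataset(torch.load("mnist_subset_0.pt"), {5: 7}). On recent PyTorch versions, torch.load may need weights_only=False to unpickle a whole Subset object.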


This works! Thank you :)
