ConcatDataset behavior on internal lists

I have two datasets that I combine using ConcatDataset. On every epoch I run a clustering step and need to update the per-item cluster assignments that each dataset stores.

My question: how is the clusters list expected to behave after concatenation? Does an external update like data_concat.clusters = new_clusters propagate to the underlying datasets?

Here is a minimal complete example that reproduces the problem:

# --- minimal code to replicate the problem:
import torch.utils.data as data_utils
import torch
import numpy as np
import random
class Data(data_utils.Dataset):
    def __init__(self,cluster_val):
        i = random.randint(0, 10)
        self.dataset = data_utils.TensorDataset(torch.arange(i*10, (i+1)*10))
        # Initialize every item's cluster label with cluster_val
        self.clusters = np.ones(len(self.dataset)) * cluster_val
    def __len__(self): 
        return len(self.dataset)
    def __getitem__(self, idx):
        return self.dataset[idx], self.clusters[idx]

data_concat = data_utils.ConcatDataset([Data(0), Data(5)])

# -- Modify the cluster values externally. This sets a new attribute on the
# ConcatDataset object, but Data.__getitem__ still reads each dataset's own
# self.clusters, so the line below has no effect on the items returned:
data_concat.clusters = [random.randint(0, 20) for _ in range(0, len(data_concat))]

for i, (data, clusters) in enumerate(data_concat): 
    print(data, clusters)

I don’t get the expected behavior. Any suggestions on how to solve this problem?

ConcatDataset is intended for cases where each item should look the same as it does in the individual datasets. You could either modify the individual datasets directly, or go through data_concat.datasets (a list of the underlying datasets).
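To illustrate the second suggestion, here is a minimal sketch of pushing new cluster assignments into the underlying datasets via data_concat.datasets. It reuses the Data class from the question, but with a fixed tensor instead of the random offset so the output is deterministic; the slicing by offset is my own bookkeeping, not something ConcatDataset does for you:

```python
import numpy as np
import torch
import torch.utils.data as data_utils

class Data(data_utils.Dataset):
    def __init__(self, cluster_val):
        self.dataset = data_utils.TensorDataset(torch.arange(10))
        # Initialize every item's cluster label with cluster_val
        self.clusters = np.ones(len(self.dataset)) * cluster_val

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx], self.clusters[idx]

data_concat = data_utils.ConcatDataset([Data(0), Data(5)])

# New per-item assignments for the whole concatenated range
new_clusters = np.arange(len(data_concat))

# Slice the flat array and assign each piece to its underlying dataset;
# __getitem__ reads self.clusters, so the update is visible immediately.
offset = 0
for ds in data_concat.datasets:
    ds.clusters = new_clusters[offset:offset + len(ds)]
    offset += len(ds)
```

After this loop, iterating data_concat yields the updated cluster values rather than the ones set in __init__.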

If you need to do operations across datasets, you’re better off writing a custom wrapper rather than using ConcatDataset.
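One possible shape for such a wrapper, keeping a single clusters array for all items so it can be reassigned wholesale after each clustering epoch. The class name and offset bookkeeping are mine, not from the thread; this is a sketch, not the only way to do it:

```python
import numpy as np
import torch
import torch.utils.data as data_utils

class ClusteredConcat(data_utils.Dataset):
    """Concatenates datasets but keeps ONE shared clusters array,
    so new assignments can be set in a single step per epoch."""

    def __init__(self, datasets):
        self.datasets = datasets
        # Cumulative start index of each dataset within the flat range
        self.offsets = []
        total = 0
        for ds in datasets:
            self.offsets.append(total)
            total += len(ds)
        self.length = total
        self.clusters = np.zeros(total)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Walk datasets from the back to find the one containing idx
        for ds, start in zip(reversed(self.datasets), reversed(self.offsets)):
            if idx >= start:
                return ds[idx - start], self.clusters[idx]

a = data_utils.TensorDataset(torch.arange(0, 10))
b = data_utils.TensorDataset(torch.arange(50, 60))
dc = ClusteredConcat([a, b])

# After a clustering epoch, replace all assignments at once:
dc.clusters = np.arange(len(dc))
item, cluster = dc[15]  # item comes from b, cluster from the shared array
```

Because the wrapper owns the clusters array, assigning dc.clusters actually changes what __getitem__ returns, which is the behavior the original ConcatDataset attempt was hoping for.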

Best regards

Thomas

Understood. Thanks. I will do the custom wrapper.