Shuffling Concatenated Datasets

Hello.

I am training a recommender model using DDP and currently have two datasets: one returns a positive interaction (customer ID and item ID), and the other returns a negative example (customer ID and a random item ID drawn from that user's list of candidates).

The data looks like this: [image not available]

The first dataset takes in the user and item IDs and returns them with a label of 1:

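import torch
from torch.utils.data import Dataset


class IntDataset(Dataset):
    """Positive examples: (user ID, item ID, label 1.0)."""

    def __init__(self, user_item_inters):
        # shape (n_interactions, 2): column 0 = user, column 1 = item
        self.inters = torch.as_tensor(user_item_inters, dtype=torch.long)

    def __len__(self):
        return len(self.inters)  # number of interactions

    def __getitem__(self, idx):
        inter = self.inters[idx]
        # scalar float64 label, same dtype as the original DoubleTensor
        return inter[0], inter[1], torch.tensor(1.0, dtype=torch.float64)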

The second dataset takes in the same customers as the first, plus each customer's list of candidate items. It returns the customer ID, one randomly chosen candidate item, and a label of 0:

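class NegDataset(Dataset):
    """Negative examples: (user ID, random candidate item, label 0.0)."""

    def __init__(self, custs, negs, len_negs):
        self.custs = torch.tensor(custs)  # user IDs
        self.negs = torch.tensor(negs)    # candidate items per user, shape (n_users, len_negs)
        self.len_negs = len_negs

    def __len__(self):
        return len(self.custs)  # number of users

    def __getitem__(self, idx):
        user = self.custs[idx]  # user ID
        neg = self.negs[idx]    # this user's candidate items

        # torch's RNG is seeded per DataLoader worker (base_seed + worker_id);
        # NumPy's global RNG was not reseeded per worker in older PyTorch versions,
        # so np.random could give every worker the same "random" negatives.
        j = int(torch.randint(self.len_negs, (1,)))
        neg_item = neg[j]  # 0-d tensor, no need to re-wrap with torch.tensor

        return user, neg_item, torch.tensor(0.0, dtype=torch.float64)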

Then I concatenate them:

neg_dataset = NegDataset(custs=pdf_ints['cust'].values, negs=pdf_ints['wght_new_smpl'], len_negs=args.len_hard_samples)

int_dataset = IntDataset(user_item_inters=pdf_ints[['cust', 'item']].values)

dataset = torch.utils.data.ConcatDataset([int_dataset, neg_dataset])
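For the in-memory case under DDP, shuffling the concatenated dataset is usually done with `DistributedSampler`, which draws one global permutation (identical on every rank for a given seed and epoch) and hands each rank a disjoint slice. The sketch below uses hypothetical toy `TensorDataset`s in place of `IntDataset`/`NegDataset`, and passes `num_replicas`/`rank` explicitly so it runs without an initialized process group.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Hypothetical toy data standing in for IntDataset / NegDataset: (user, item, label).
users = torch.arange(6)
pos = TensorDataset(users, users + 100, torch.ones(6, dtype=torch.float64))
neg = TensorDataset(users, users + 200, torch.zeros(6, dtype=torch.float64))
dataset = ConcatDataset([pos, neg])  # indices 0-5 are positives, 6-11 negatives


def make_sampler(rank, world_size=2):
    # Explicit num_replicas/rank: no process group needed for this demo.
    # In a real DDP job, DistributedSampler(dataset) reads them from torch.distributed.
    return DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0)


sampler0, sampler1 = make_sampler(0), make_sampler(1)
for epoch in range(2):
    # set_epoch changes the permutation each epoch; all ranks stay in sync.
    sampler0.set_epoch(epoch)
    sampler1.set_epoch(epoch)
    r0, r1 = list(sampler0), list(sampler1)

# With a sampler, DataLoader's shuffle argument must be left False.
loader = DataLoader(dataset, batch_size=4, sampler=sampler0)
for user, item, label in loader:
    print(user.tolist(), item.tolist(), label.tolist())
```

Because the permutation is over the indices of the whole `ConcatDataset`, positives and negatives end up mixed within batches even though they are stored back to back.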

If I load all the data into memory and shuffle the concatenated dataset, this seems to work. However, when I can't load everything into memory, I am not sure how to proceed. I could use two iterable datasets reading from disk, but AFAIK these can't be shuffled, and concatenating them would yield all the positives first, followed by all the negatives.
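One common workaround for streamed data is to interleave the two streams randomly and then pass the result through a fixed-size shuffle buffer, which gives an approximate shuffle using only `buffer_size` items of memory. The sketch below is plain Python generators (the function names `interleave` and `shuffle_buffer` are hypothetical); it assumes the positive and negative streams have roughly equal length, as they do here since `NegDataset` yields one negative per row of `pdf_ints`.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def interleave(pos: Iterable[T], neg: Iterable[T], rng: random.Random) -> Iterator[T]:
    """Pick the next item from a randomly chosen stream until both are exhausted."""
    iters = [iter(pos), iter(neg)]
    while iters:
        it = rng.choice(iters)
        try:
            yield next(it)
        except StopIteration:
            iters.remove(it)  # this stream is done; keep drawing from the other


def shuffle_buffer(stream: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Approximate shuffle: hold buffer_size items, emit a random one as each new item arrives."""
    buf: List[T] = []
    for item in stream:
        if len(buf) < buffer_size:
            buf.append(item)
            continue
        j = rng.randrange(buffer_size)
        yield buf[j]
        buf[j] = item  # replace the emitted slot with the incoming item
    rng.shuffle(buf)  # flush what is left in random order
    yield from buf


rng = random.Random(0)
positives = ((u, i, 1.0) for u, i in [(1, 10), (2, 20), (3, 30)])
negatives = ((u, i, 0.0) for u, i in [(1, 99), (2, 98), (3, 97)])
out = list(shuffle_buffer(interleave(positives, negatives, rng), buffer_size=4, rng=rng))
print(out)
```

Inside an `IterableDataset.__iter__`, the two input generators would read rows from disk instead. Under DDP with several workers, each rank and worker should read a disjoint shard (for example using the rank and `torch.utils.data.get_worker_info()`), otherwise every process sees the same data. A larger buffer mixes better at the cost of memory; if the two streams differed a lot in length, choosing a stream in proportion to its remaining count would keep the mix even.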