Help with Efficient Negative Sampling

I'm experimenting with negative sampling for a recommendation (ranking) task, and I'm hoping for some advice on how to implement it most efficiently in PyTorch.

Is the data loader the best place to create the negative samples?

My current attempt is below.

array: the raw interaction data, i.e. the observed (user id, item id) pairs. There are around 5 million rows, covering around 600,000 users and 1,500 items.

sparse_user_item: a sparse matrix of the same interaction data. I use it to look up a user and get the indices of the items they have interacted with.

max_user: the highest user index. Because of how the model is set up (users and items share a single embedding matrix), I need to add it to the item ids.
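For the sparse_user_item lookup, a CSR layout lets you read a user's items as a slice of two flat arrays instead of building a one-row sparse matrix and calling .nonzero() for every sample. scipy.sparse.csr_matrix keeps these arrays as .indptr and .indices (each row's indices are sorted after .sort_indices()), so with the existing matrix the lookup is roughly m.indices[m.indptr[u]:m.indptr[u + 1]]. The sketch below builds the same two arrays with NumPy only so that it runs on its own. The helper names are hypothetical, and it assumes raw item ids start at 0 and are shifted by max_user + 1 into the shared embedding.

```python
import numpy as np

def build_user_index(users, items, n_users):
    """Group raw item ids by user in CSR layout (like scipy's indptr/indices)."""
    users = np.asarray(users, dtype=np.int64)
    items = np.asarray(items, dtype=np.int64)
    order = np.lexsort((items, users))  # primary key: user, secondary key: item
    u, it = users[order], items[order]
    keep = np.ones(len(u), dtype=bool)
    keep[1:] = (u[1:] != u[:-1]) | (it[1:] != it[:-1])  # drop repeated (user, item) pairs
    u, it = u[keep], it[keep]
    indptr = np.zeros(n_users + 1, dtype=np.int64)
    np.cumsum(np.bincount(u, minlength=n_users), out=indptr[1:])
    return indptr, it

def user_item_rows(user, indptr, indices, max_user):
    """Sorted embedding rows of the items `user` interacted with."""
    # a plain slice: no one-row sparse matrix and no .nonzero() call
    return indices[indptr[user]:indptr[user + 1]] + max_user + 1

# hypothetical toy data: users 0..2, with (0, 1) appearing twice
indptr, indices = build_user_index([0, 2, 0, 1, 0, 0], [5, 3, 1, 4, 3, 1], n_users=3)
print(user_item_rows(0, indptr, indices, max_user=2))  # [4 6 8]
```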

Is there anything glaring that could be improved, either in how the dataset is set up or, at a higher level, in how the sampling is done?

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class BPR_NegSampleDataset(Dataset):
    def __init__(self, array, min_item_id, max_item_id, check_neg=False,
                 sparse_user_item=None, max_user=None):
        self.x = torch.as_tensor(array, dtype=torch.long)  # (n, 2): user id, item id
        self.check_neg = check_neg
        # half-open range: negatives are drawn from [min_item_id, max_item_id)
        self.min_item_id = min_item_id
        self.max_item_id = max_item_id
        self.array_items = np.arange(min_item_id, max_item_id)
        self.sparse_user_item = sparse_user_item  # scipy CSR matrix: users x raw item ids
        self.max_user = max_user

    def __len__(self):
        return len(self.x)  # number of interactions

    def negsamp_vectorized_bsearch_preverif(self, pos_inds, n_samp=1):
        """Pre-verified with binary search.

        `pos_inds` is assumed to be sorted and unique.
        """
        raw_samp = np.random.randint(self.min_item_id, self.max_item_id - len(pos_inds), size=n_samp)
        pos_inds_adj = pos_inds - np.arange(len(pos_inds))
        neg_inds = raw_samp + np.searchsorted(pos_inds_adj, raw_samp, side='right')
        return neg_inds

    def __getitem__(self, idx):
        pos_interaction = self.x[idx]  # user id and item id that was purchased
        user = int(pos_interaction[0])  # plain int for indexing the scipy matrix
        if self.check_neg:
            # the user's item ids, shifted into the shared embedding range
            pos_indxs = self.sparse_user_item[user].nonzero()[1] + self.max_user + 1
            neg_item = int(np.random.choice(np.setdiff1d(self.array_items, pos_indxs)))
        else:
            # unchecked: may occasionally return one of the user's positives
            neg_item = int(torch.randint(self.min_item_id, self.max_item_id, size=(1,)))
        neg_interaction = torch.tensor([user, neg_item], dtype=torch.long)

        return pos_interaction, neg_interaction
```
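The binary-search helper can be checked in isolation. Below is a standalone NumPy version of the same idea; the function name and the use of np.random.default_rng are my own choices, not part of the dataset above. It assumes pos_inds is sorted, unique and inside [lo, hi): it draws from a range shortened by the number of positives, then shifts each draw past every positive at or below it, so a positive can never come out and no redraw is needed.

```python
import numpy as np

def sample_negatives(pos_inds, lo, hi, n_samp, rng):
    """Draw n_samp ids uniformly from [lo, hi) minus pos_inds, without rejection.

    pos_inds must be sorted, unique, and inside [lo, hi).
    """
    pos_inds = np.asarray(pos_inds, dtype=np.int64)
    # there are (hi - lo - len(pos_inds)) valid negatives; draw one of them by rank
    raw = rng.integers(lo, hi - len(pos_inds), size=n_samp)
    # each positive k with pos_inds[k] - k <= raw pushes the draw up by one
    adj = pos_inds - np.arange(len(pos_inds))
    return raw + np.searchsorted(adj, raw, side="right")

rng = np.random.default_rng(0)
draws = sample_negatives([2, 3, 7], lo=0, hi=10, n_samp=10_000, rng=rng)
print(sorted(set(draws.tolist())))
```

With pos_inds = [2, 3, 7] in [0, 10), the draws should cover exactly 0, 1, 4, 5, 6, 8 and 9. Repeated entries in pos_inds would break the shift, and a user who has interacted with every item leaves an empty range for rng.integers.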

It's very slow to confirm that the negatives are not actually positives for the user.
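With the numbers above, the average user has about 5,000,000 / 600,000 ≈ 8.3 positives out of 1,500 items, so a uniform draw is a positive only about 8.3 / 1,500 ≈ 0.6% of the time (averaged over users; heavy users collide more often). That suggests drawing blindly, checking, and redrawing only the rare collisions, done for a whole batch at once rather than per sample. One way to make the check vectorised is to encode every observed (user, item) pair as a single int64 key (about 40 MB for 5 million pairs), keep the keys sorted, and test all candidates with one np.searchsorted call. This is a sketch under assumptions: raw item ids are in [0, n_items) (add the embedding offset afterwards), and the helper names and max_rounds cap are hypothetical.

```python
import numpy as np

def build_pos_keys(users, items, n_items):
    """Sorted, unique int64 key (user * n_items + item) for every observed pair."""
    users = np.asarray(users, dtype=np.int64)
    items = np.asarray(items, dtype=np.int64)
    return np.unique(users * n_items + items)

def is_positive(users, cand, pos_keys, n_items):
    """Boolean mask: True where (users[i], cand[i]) is an observed interaction."""
    keys = np.asarray(users, dtype=np.int64) * n_items + cand
    idx = np.minimum(np.searchsorted(pos_keys, keys), len(pos_keys) - 1)
    return pos_keys[idx] == keys

def sample_batch_negatives(users, pos_keys, n_items, rng, max_rounds=20):
    """One negative per entry of `users`, redrawing only the slots that collide."""
    users = np.asarray(users, dtype=np.int64)
    cand = rng.integers(0, n_items, size=len(users))
    todo = np.flatnonzero(is_positive(users, cand, pos_keys, n_items))
    for _ in range(max_rounds):
        if todo.size == 0:
            break
        cand[todo] = rng.integers(0, n_items, size=todo.size)
        # re-check only the slots that were just redrawn
        todo = todo[is_positive(users[todo], cand[todo], pos_keys, n_items)]
    # users who own (nearly) every item can still collide after max_rounds
    return cand
```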

But even without checking, it's slow.
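Without the check, a likely cost is per-sample overhead: for every single interaction, __getitem__ runs Python code, makes its own torch.randint call and builds small tensors, which the default collate then stacks. A common alternative is to let __getitem__ receive a whole batch of indices and draw all negatives with one vectorised call. Below is a NumPy-only sketch of such a dataset (the class name is hypothetical); like the unchecked branch, it does not exclude positives.

```python
import numpy as np

class BatchedInteractions:
    """Hypothetical map-style dataset whose __getitem__ takes a whole batch of indices."""

    def __init__(self, users, items, min_item_id, max_item_id, seed=0):
        self.users = np.asarray(users, dtype=np.int64)
        self.items = np.asarray(items, dtype=np.int64)
        self.min_item_id = min_item_id
        self.max_item_id = max_item_id  # exclusive upper bound
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        idx = np.asarray(idx)  # a list/array of row indices, not a single int
        users, pos = self.users[idx], self.items[idx]
        # one RNG call for the whole batch instead of one torch.randint per row
        neg = self.rng.integers(self.min_item_id, self.max_item_id, size=len(idx))
        return users, pos, neg

ds = BatchedInteractions(users=[0, 1, 1, 2], items=[10, 11, 12, 10], min_item_id=10, max_item_id=13)
u, p, n = ds[[0, 2, 3]]
```

To plug something like this into PyTorch, one option is DataLoader(ds, sampler=BatchSampler(RandomSampler(ds), batch_size=1024, drop_last=False), batch_size=None): with automatic batching disabled, each index list from the BatchSampler is passed straight to __getitem__, and the default collate_fn only converts the NumPy arrays to tensors. With num_workers > 0, give each worker its own seed (for example via worker_init_fn); otherwise the copied generators produce identical negatives.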