Efficient way of selectively replacing vectors from a tensor in pytorch

Given a batch of text sequences, I convert it into a tensor in which each word is represented by a 300-dimensional word embedding. I need to selectively replace the vectors for certain specific words with a new set of embeddings. Further, the replacement should not happen for every occurrence of those words, only for a random subset of them. Currently, I use the code below to achieve this. It traverses every word with two for loops and checks whether the word is in a specified list, splIndices. If it is, it then checks whether that particular occurrence should be replaced, based on the True/False value at the current position in selected_.
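Both checks described above (membership in the special list, then a random per-occurrence decision) can be expressed as boolean masks over the index tensor instead of per-element tests. A minimal sketch, assuming PyTorch 1.10 or later (for `torch.isin`) and a toy batch standing in for the real data; `special_mask` and `chosen_mask` are hypothetical names introduced only for illustration:

```python
import torch

# Toy batch of word indices: 2 sequences x 4 words (hypothetical data)
tempVals = torch.tensor([[45, 7, 62, 45],
                         [3, 762, 9, 62]])
splIndices = torch.tensor([45, 62, 2983, 456, 762])

# True wherever the word is one of the special indices
# (the "is the word in splIndices" check, done for the whole batch at once)
special_mask = torch.isin(tempVals, splIndices)

# One True/False draw per special occurrence. Boolean indexing walks the
# tensor in row-major order, the same order the nested loops visit words.
n_occ = int(special_mask.sum())
selected_ = torch.rand(n_occ) < 0.2  # roughly 20% True

# Scatter the per-occurrence draws back to the positions they belong to
chosen_mask = torch.zeros_like(special_mask)
chosen_mask[special_mask] = selected_
```

Here `chosen_mask` is True exactly at the positions whose vectors would be replaced, so no Python-level loop over words is needed to decide which ones to touch.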

But could this be done in a more efficient manner?

The code below may not be a true MWE; I have simplified it by removing the specifics so as to focus on the problem. Please ignore the semantics or purpose of the code, as they may not be represented accurately in this snippet. The question is only about improving performance.

import torch
import torch.nn as nn

splIndices = [45, 62, 2983, 456, 762]  # vocabulary indices whose vectors may be replaced
splFreqs = 2000  # assume the words in splIndices occur 2000 times in total
selected_ = torch.rand(splFreqs) < 0.2  # bool tensor, roughly 20% of the entries True
replIndexCtr = 0  # counter into selected_, one entry per occurrence of a special word
# Dictionary of replacement vectors. This is a dummy; the original values
# depend on some property of the word.
diffVector = {k: torch.rand(300) for k in splIndices}

# Hypothetical placeholders so the snippet runs on its own
embedding_matrix = torch.rand(3000, 300)  # [vocab_size, 300]
x = torch.randint(0, 3000, (32, 41))  # batch of 32 sequences, 41 word indices each

embeding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
tempVals = x  # shape [32, 41]: keep the raw vocabulary indices
x = embeding(x)  # shape [32, 41, 300]: each index replaced by its embedding

for i, item in enumerate(x):  # iterate over the sequences in the batch
    for j, _ in enumerate(item):  # iterate over the words in a sequence
        idx = tempVals[i][j].item()
        if idx in splIndices:
            if selected_[replIndexCtr]:
                x[i, j] = diffVector[idx]
            replIndexCtr += 1  # advance on every occurrence, replaced or not
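The nested loops above can likely be replaced by a handful of batched tensor operations. A minimal, unbenchmarked sketch, assuming PyTorch 1.10 or later, that selected_ holds one draw per occurrence of a special word in row-major order (i.e. the counter advances on every occurrence, not only on replacements), and that the diffVector dictionary can be packed into a single tensor. `replace_selected` and `loop_reference` are hypothetical helper names; the second exists only to check that the vectorized version gives the same result as a loop:

```python
import torch
import torch.nn as nn


def replace_selected(emb, idx, spl_indices, diff_vectors, selected):
    """Vectorized version of the nested-loop replacement.

    emb:          [B, L, D] embeddings
    idx:          [B, L] LongTensor of vocabulary indices
    spl_indices:  list of special vocabulary indices
    diff_vectors: dict mapping each special index to a [D] replacement vector
    selected:     bool tensor with one entry per special occurrence, row-major
    """
    spl = torch.tensor(spl_indices, device=idx.device)
    special_mask = torch.isin(idx, spl)  # [B, L]: special words
    n_occ = int(special_mask.sum())
    chosen = torch.zeros_like(special_mask)
    chosen[special_mask] = selected[:n_occ].to(idx.device)  # [B, L]: words to replace

    # Stack the replacement vectors in the same order as spl_indices
    table = torch.stack([diff_vectors[k] for k in spl_indices]).to(emb)
    # For every chosen word, find its row in `table`. Each row of the
    # comparison has exactly one True, so nonzero()[:, 1] is that column.
    rows = (idx[chosen].unsqueeze(1) == spl).nonzero()[:, 1]

    out = emb.clone()          # clone so emb itself is left untouched
    out[chosen] = table[rows]  # one batched write instead of B * L Python steps
    return out


def loop_reference(emb, idx, spl_indices, diff_vectors, selected):
    """Plain double loop, with the counter advancing on every occurrence."""
    out = emb.clone()
    ctr = 0
    for i in range(idx.size(0)):
        for j in range(idx.size(1)):
            k = idx[i, j].item()
            if k in spl_indices:
                if selected[ctr]:
                    out[i, j] = diff_vectors[k]
                ctr += 1
    return out


# Usage with hypothetical data
torch.manual_seed(0)
vocab, dim = 3000, 300
spl_indices = [45, 62, 2983, 456, 762]
embedding = nn.Embedding.from_pretrained(torch.rand(vocab, dim), freeze=False)
idx = torch.randint(0, vocab, (32, 41))
# Force some special words into every sequence so there is something to replace
idx[:, ::5] = torch.tensor(spl_indices)[torch.randint(0, 5, (32, 9))]
diff_vectors = {k: torch.rand(dim) for k in spl_indices}
selected = torch.rand(2000) < 0.2

emb = embedding(idx)
fast = replace_selected(emb, idx, spl_indices, diff_vectors, selected)
slow = loop_reference(emb, idx, spl_indices, diff_vectors, selected)
print(torch.equal(fast, slow))
```

Under these assumptions the only remaining Python-level work is stacking the small replacement table, whose size depends on the number of special words rather than on the batch size. The write into a clone of emb keeps the operation compatible with autograd when the embedding is not frozen. Whether this is measurably faster on a particular setup has not been tested here.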