Extracting embeddings from log probabilities

Hi, I am working on an image captioning problem, and at one point I need to extract the embeddings of the predicted tokens. One possible way is to look up the embeddings from the token indices, but I cannot do that because the whole process needs to be differentiable, and the discrete token indices are not. Instead, what I am currently doing is the following:
```python
embeddings = self.model.encoder_decoder.model.tgt_embed[0].lut
embedded_sentence_pred = torch.matmul(seq_logps, embeddings.weight)
```
I am not entirely sure if I am doing it correctly. Can you provide any insight?
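For context, here is a self-contained toy version of what I mean (the shapes and variable names are made up). One thing I am unsure about: since `seq_logps` holds *log*-probabilities, maybe I should exponentiate them first, so the result is a convex combination (an expectation) over the embedding rows rather than a matmul with negative log values:

```python
import torch
import torch.nn as nn

vocab_size, d_model = 10, 4
embedding = nn.Embedding(vocab_size, d_model)  # stand-in for tgt_embed[0].lut

# seq_logps: (batch, seq_len, vocab_size) log-probabilities from the decoder
seq_logps = torch.log_softmax(torch.randn(2, 5, vocab_size), dim=-1)

# Exponentiate so each row sums to 1, then mix the embedding rows;
# this keeps everything differentiable (no argmax / discrete lookup).
probs = seq_logps.exp()
soft_embeddings = torch.matmul(probs, embedding.weight)  # (2, 5, d_model)
```

Is this "expected embedding" the right way to do it, or is using `seq_logps` directly fine?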