Reverse nn.Embedding

How do you go back from the output of nn.Embedding to the original discrete values?


If you have the feature vector, you could calculate the difference against each row of the embedding weight matrix and select the index with a zero (or close-to-zero) difference.

I’m not sure I understand what you mean. Say you have some discrete-valued vector that you want to represent in the latent space; how would I get the feature vector? Do you mean taking the feature vectors extracted from a given model and computing differences?

import torch

a = torch.nn.Embedding(10, 50)
# Discrete values: the weight matrix, one row per index
a.weight

# Index lookup: returns rows 1, 1, and 0 of the weight matrix
a(torch.LongTensor([1, 1, 0]))

That still gives a real-valued vector. I want to reverse this operation, something akin to an encoder-decoder.

The feature vector would be the output of the embedding layer, and you could calculate the difference afterwards to get the index back:

import torch

emb = torch.nn.Embedding(10, 50)
x = torch.tensor([5])

out = emb(x)
out.shape         # torch.Size([1, 50])
emb.weight.shape  # torch.Size([10, 50])

# Broadcast out against every row of the weight matrix and pick the
# row whose absolute difference is (numerically) zero
rev = ((out - emb.weight).abs().sum(1) < 1e-6).nonzero()
print(rev)
# > tensor([[5]])
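
This near-exact check assumes the vector is an unmodified row of the weight matrix. If it has been perturbed, e.g. after passing through other layers, a nearest-neighbour lookup is more forgiving. A minimal sketch of that variant (my own addition, not from the reply above), using torch.cdist:

import torch

emb = torch.nn.Embedding(10, 50)
out = emb(torch.tensor([5]))

# Pairwise L2 distances between the output vector(s) and every embedding row
dists = torch.cdist(out, emb.weight)  # shape: [1, 10]
rev = dists.argmin(dim=1)
print(rev)
# > tensor([5])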

@safin_salih Understood

You can use this:

import torch

a = torch.nn.Embedding(10, 50)
b = torch.LongTensor([2, 8])
results = a(b)

def get_embedding_index(x):
    # Indices of the weight rows that match x in every position
    matches = torch.where((a.weight == x).sum(dim=1) == x.numel())[0]
    if len(matches) == 0:
        return None
    return matches[0].item()

indices = torch.Tensor(list(map(get_embedding_index, results)))
print(indices)
# > tensor([2., 8.])
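
As an aside (my own sketch, not part of the reply above), the Python-level map can be replaced by a single broadcasted comparison over all rows at once:

# All-pairs exact-match test: [2, 1, 50] vs [1, 10, 50] -> [2, 10] boolean matrix
matches = (results.unsqueeze(1) == a.weight.unsqueeze(0)).all(dim=2)
indices = matches.long().argmax(dim=1)
print(indices)
# > tensor([2, 8])

Note that argmax returns 0 when no row matches, unlike the None handling above.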

UPDATE: I did a revamp on this to make it more memory efficient via .expand(), and consequently faster.

Just a slightly more robust definition that handles cases where x.dim() > 0:

def emb2indices(output, emb_layer):
    # output: [batch, sequence, emb_length], emb_layer.weight: [num_tokens, emb_length]
    emb_weights = emb_layer.weight

    # Expand (without copying) so every output vector is compared against
    # every embedding row, then take the argmin of the L1 distance
    emb_size = output.size(0), output.size(1), -1, -1
    out_size = -1, -1, emb_weights.size(0), -1
    out_indices = torch.argmin(
        (output.unsqueeze(2).expand(out_size) -
         emb_weights.unsqueeze(0).unsqueeze(0).expand(emb_size)).abs().sum(dim=3),
        dim=2)
    return out_indices

For a test:

import torch
import torch.nn as nn

emb = nn.Embedding(300000, 100)
test_indices = torch.randint(300000, (1, 100))  # batch 1, length 100
print(test_indices)
test_embed = emb(test_indices)
print(test_embed.size())  # torch.Size([1, 100, 100])

indices = emb2indices(test_embed, emb)
print(indices)

It’s more memory efficient than using torch.stack. It can be modified to use even less memory by iterating over the embeddings one by one.
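
A minimal sketch of that iterative idea (the name emb2indices_chunked and its chunk_size parameter are my own, and it assumes the same L1 distance as above): scanning the table one chunk at a time avoids materialising the full [batch, seq, num_tokens] distance tensor.

import torch

def emb2indices_chunked(output, emb_layer, chunk_size=10000):
    # Low-memory variant: compare against the embedding table in chunks
    emb_weights = emb_layer.weight              # [num_tokens, emb_length]
    flat = output.reshape(-1, output.size(-1))  # [batch*seq, emb_length]
    best_dist = torch.full((flat.size(0),), float("inf"))
    best_idx = torch.zeros(flat.size(0), dtype=torch.long)
    with torch.no_grad():  # a pure lookup, no gradients needed
        for start in range(0, emb_weights.size(0), chunk_size):
            chunk = emb_weights[start:start + chunk_size]
            # L1 distance from every output vector to every row of this chunk
            d = (flat.unsqueeze(1) - chunk.unsqueeze(0)).abs().sum(dim=2)
            min_d, min_i = d.min(dim=1)
            better = min_d < best_dist
            best_dist[better] = min_d[better]
            best_idx[better] = min_i[better] + start
    return best_idx.reshape(output.shape[:-1])

With chunk_size equal to num_tokens it does the same work as the expanded version; smaller chunks trade speed for memory.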
