How do you go back from nn.Embedding to the original discrete values?
If you have the feature vector, you could calculate the difference to each row of the embedding weight matrix and select the index with a zero (or close to zero) difference.
I’m not sure I understand what you mean. Let’s say I have some discrete-valued vector that I want to represent in the latent space; how would I get the feature vector? Do you mean the feature vectors from a given model, extracted so I can take differences?
import torch

a = torch.nn.Embedding(10, 50)
# The discrete values live as rows of the weight matrix
a.weight
# Look up the embeddings for the discrete indices 1, 1, 0
a(torch.LongTensor([1, 1, 0]))
That still gives a real-valued vector. I want to reverse this operation, something akin to an encoder/decoder.
The feature vector would be the output of the embedding layer, and you could calculate the difference afterwards to get the index back:
import torch

emb = torch.nn.Embedding(10, 50)
x = torch.tensor([5])
out = emb(x)
print(out.shape)         # > torch.Size([1, 50])
print(emb.weight.shape)  # > torch.Size([10, 50])
# the matching row has an (almost) zero summed absolute difference
rev = ((out - emb.weight).abs().sum(1) < 1e-6).nonzero()
print(rev)
# > tensor([[5]])
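For a whole batch of embedded vectors, the same difference trick can be vectorized by taking the argmin over the weight rows; a minimal sketch along those lines (the variable names here are my own):

import torch

emb = torch.nn.Embedding(10, 50)
out = emb(torch.tensor([5, 2, 7]))  # [3, 50]
# L1 distance from every output row to every weight row: [3, 10]
dists = (out.unsqueeze(1) - emb.weight.unsqueeze(0)).abs().sum(-1)
print(dists.argmin(dim=1))
# > tensor([5, 2, 7])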
@safin_salih Understood
You can use this:
import torch

a = torch.nn.Embedding(10, 50)
b = torch.LongTensor([2, 8])
results = a(b)

def get_embedding_index(x):
    # rows of the weight matrix that match x in every position
    matches = torch.where((a.weight == x).sum(dim=1) == x.numel())[0]
    if len(matches) == 0:
        return None  # x is not a row of the embedding matrix
    return matches[0].item()

indices = torch.tensor(list(map(get_embedding_index, results)))
print(indices)
# > tensor([2, 8])
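One caveat with the exact == match: it only finds rows that are bitwise-identical, i.e. unmodified outputs of the embedding layer. If the vectors may have been perturbed at all, a tolerance-based comparison is safer; a sketch of that variant (get_embedding_index_tol and atol are names I’m introducing):

def get_embedding_index_tol(x, atol=1e-6):
    # rows of the weight matrix that match x within a tolerance
    matches = torch.where(torch.isclose(a.weight, x, atol=atol).all(dim=1))[0]
    return matches[0].item() if len(matches) > 0 else None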
UPDATE: I did a revamp on this to make it more memory efficient via .expand(), and consequently faster. Here is a slightly more robust definition that handles cases of x.dim > 0:
def emb2indices(output, emb_layer):
    # output size: [batch, sequence, emb_length]; emb_layer.weight size: [num_tokens, emb_length]
    emb_weights = emb_layer.weight
    # broadcast both tensors to [batch, sequence, num_tokens, emb_length] via expand (views),
    # then recover each index as the argmin of the L1 distance over the token axis
    emb_size = output.size(0), output.size(1), -1, -1
    out_size = -1, -1, emb_weights.size(0), -1
    out_indices = torch.argmin(torch.abs(output.unsqueeze(2).expand(out_size) -
                                         emb_weights.unsqueeze(0).unsqueeze(0).expand(emb_size)).sum(dim=3),
                               dim=2)
    return out_indices
For a test:
import torch
import torch.nn as nn

emb = nn.Embedding(300000, 100)
test_indices = torch.randint(300000, (1, 100))  # batch 1, length 100
print(test_indices)
test_embed = emb(test_indices)
print(test_embed.size())  # > torch.Size([1, 100, 100])
indices = emb2indices(test_embed, emb)  # note: materializes a [1, 100, 300000, 100] difference tensor
print(indices)
It’s more memory efficient than using torch.stack. However, it can be modified to use even less memory by iterating over the embeddings one by one.
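A minimal sketch of that lower-memory variant, assuming we scan the embedding table in chunks rather than strictly one row at a time (emb2indices_chunked and chunk_size are names of my own, not from the post above):

import torch

def emb2indices_chunked(output, emb_layer, chunk_size=1000):
    # Same L1 nearest-row recovery as emb2indices, but scans the embedding
    # table in chunks so the full [batch, sequence, num_tokens, emb_length]
    # difference tensor is never materialized.
    with torch.no_grad():
        emb_weights = emb_layer.weight              # [num_tokens, emb_length]
        flat = output.reshape(-1, output.size(-1))  # [batch*sequence, emb_length]
        best_dist = torch.full((flat.size(0),), float("inf"))
        best_idx = torch.zeros(flat.size(0), dtype=torch.long)
        for start in range(0, emb_weights.size(0), chunk_size):
            chunk = emb_weights[start:start + chunk_size]  # [c, emb_length]
            dist = torch.cdist(flat, chunk, p=1)           # [batch*sequence, c]
            min_dist, min_idx = dist.min(dim=1)
            better = min_dist < best_dist
            best_dist[better] = min_dist[better]
            best_idx[better] = min_idx[better] + start
        return best_idx.view(output.shape[:-1])

On the test above this keeps the peak intermediate at roughly [batch*sequence, chunk_size] per step instead of the full four-dimensional tensor.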