Inverse embedding layer does not update weights

In this post (deep learning - How to invert a PyTorch Embedding? - Stack Overflow) I found a short, simple way to invert an embedding layer.
I used that inverse embedding layer, but the weights in my network are not updated. The proposed inverse embedding layer, copied from that post, is shown below:

import torch

embeddings = torch.nn.Embedding(1000, 100)
my_sample = torch.randn(1, 100)
distance = torch.norm(embeddings.weight.data - my_sample, dim=1)
nearest = torch.argmin(distance)

What I did:
I used an embedding layer that takes one input and generates a 16D output. Then I added two hidden dense layers (64D, then 16D) and one inverse embedding layer. In short:

X → embedding → dense layer (64D) → dense layer (16D) → inverse embedding → X’

X and X’ are integer numbers.

To compute the loss, I used torch.norm(X - X’), but the weights are not updated. I cannot figure out why.

A short implementation is shown below:

# lS_o = Offset, lS_i = input number
optimizer = opts['sgd'](parameters, lr=args.learning_rate)
#--------------------------------------

# method of the model class
def forward(self, lS_o, lS_i):
    out_emb1 = self.embl_inp(lS_o, lS_i)  # 16D == embedding layer
    out_dl1 = self.DLyr1(out_emb1)        # 64D == Dense Layer 1
    out_dl2 = self.DLyr2(out_dl1)         # 16D == Dense Layer 2
    ly = out_dl2
    out = []
    for i in range(ly.shape[0]):
        # subtract row i of ly from every row of the weight matrix
        distance = self.emb_out.weight.data - ly[i, None]
        out.append(torch.argmin(torch.norm(distance, dim=1), dim=0))
    return torch.stack(out)


train_ds = Dataset(...
train_ld = DataLoader(train_ds, ...

pbar = tq.tqdm(enumerate(train_ld), total=len(train_ld))
for j, inputBatch in pbar:
    lS_o, lS_i = unpack_batch(inputBatch)
    ae_out = model(lS_o, lS_i, use_gpu=True)
    loss = torch.norm(ae_out - lS_i)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

torch.argmin is not differentiable: it returns integer indices, so no gradient can flow back through it to the weights.
You can try looking at Gumbel-Softmax or REINFORCE if you want something like a “differentiable argmin”.
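
To see where the graph is cut, here is a minimal sketch. The layer sizes and the names emb_out and ly are stand-ins I made up to mirror the question, not the original model. It also shows that reading the table through `.weight.data` gives a tensor detached from autograd, which by itself would keep that table from receiving gradients.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a 1000 x 16 output table and a batch of 4 vectors
# playing the role of the last dense layer's output.
emb_out = torch.nn.Embedding(1000, 16)
ly = torch.randn(4, 16, requires_grad=True)

# Pairwise Euclidean distances, shape (4, 1000). Using .weight (not .weight.data)
# keeps the table inside the autograd graph.
distance = torch.cdist(ly, emb_out.weight)
print("distance tracks gradients:", distance.requires_grad)

# argmin returns integer indices with no grad_fn, so the graph ends here.
nearest = torch.argmin(distance, dim=1)
print("argmin output:", nearest.dtype, "tracks gradients:", nearest.requires_grad)

# .data is a view of the weights that autograd does not track.
print(".weight.data tracks gradients:", emb_out.weight.data.requires_grad)
```

Any loss built only from `nearest` has no path back to the parameters, so optimizer.step() has nothing to apply.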

@Varal7 Thanks for your reply, it is very helpful. I have studied gumbel_softmax and changed the code in the model's forward function as follows:

distance = torch.norm(distance, dim=1)
out = 1 - torch.nn.functional.gumbel_softmax(distance)

Because I want the minimum, I subtract the output of gumbel_softmax from 1.

I am not sure my implementation is correct. Could you give me more hints to help me figure out the problem?

I think it would be better to negate the input of gumbel_softmax instead.
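
Roughly what I mean, as a sketch rather than a drop-in fix: the sizes and the names emb_out and ly are assumptions standing in for your layers, and I use torch.cdist to get all distances at once. Negating the distances makes the smallest distance the largest logit, and hard=True returns one-hot rows in the forward pass while still passing soft gradients backward.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

emb_out = torch.nn.Embedding(1000, 16)       # hypothetical output table
ly = torch.randn(4, 16, requires_grad=True)  # stand-in for the dense-layer output

distance = torch.cdist(ly, emb_out.weight)   # shape (4, 1000)

# Smallest distance -> largest logit.
one_hot = F.gumbel_softmax(-distance, tau=1.0, hard=True, dim=-1)

# Differentiable replacement for an index lookup: one-hot rows times the table.
picked = one_hot @ emb_out.weight            # shape (4, 16)

picked.sum().backward()
print("table got gradients:", emb_out.weight.grad is not None)
print("dense output got gradients:", ly.grad is not None)
```

Note that gumbel_softmax samples, so the selected row is not always the nearest one, and that `.weight.data` has to become `.weight` for the table to receive gradients.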

I did as you said, thanks. I checked that the distance is less than one and moved the subtraction inside the function argument. But the weights are still not updated.

Sorry, I re-read your original post.
I think instead of using a loss of X - X’, you should use a loss that computes some distance between the “embedding” and the “inverse embedding”. This will also make more sense in terms of learning.
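
One possible reading of this, sketched under my own assumptions (the layer names, the ReLU between the dense layers, the sizes, and the choice of MSE are all hypothetical): compare the predicted 16D vector with the row of the output table that belongs to X, and keep argmin only for decoding X’ outside the loss.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the layers in the question.
emb_in = torch.nn.Embedding(1000, 16)
dense = torch.nn.Sequential(
    torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16)
)
emb_out = torch.nn.Embedding(1000, 16)

params = (
    list(emb_in.parameters()) + list(dense.parameters()) + list(emb_out.parameters())
)
optimizer = torch.optim.SGD(params, lr=0.1)

x = torch.randint(0, 1000, (8,))   # batch of integer inputs X
ly = dense(emb_in(x))              # predicted vectors, shape (8, 16)
target = emb_out(x)                # rows of the output table for X, shape (8, 16)

# Distance in embedding space instead of torch.norm(X - X').
loss = torch.nn.functional.mse_loss(ly, target)

before = emb_out.weight.detach().clone()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("weights changed:", not torch.equal(before, emb_out.weight.detach()))

# argmin is used only to decode X', with no gradient needed.
with torch.no_grad():
    x_pred = torch.cdist(dense(emb_in(x)), emb_out.weight).argmin(dim=1)
```

A plain vector distance like this can be minimized trivially if all rows drift to the same point, so a classification-style loss over the negated distances may be worth trying as well.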