Implementing negative sampling using the scatter function

I want to implement negative sampling. The tensor x is very sparse, containing only 1s and 0s.
I draw 100 negative samples per row and want to compute the loss using only the positive entries and those 100 negative samples.

I included "sampling" when computing the loss, and it made backpropagation impossible, but I don't know how to fix it.

I got this error:
RuntimeError: Function ScatterBackward0 returned an invalid gradient at index 1 - got [5, 100] but expected shape compatible with [5, 16736]

I attach my snippet below.
Thanks in advance.

where_zeros = torch.zeros(x.shape).to(device)
pos_score = torch.where(x > 0, recon_x, where_zeros)
false_score = torch.where(x == 0.0, recon_x, where_zeros)
for i in range(x.shape[0]):
    negs = false_score[i].nonzero().view(1, -1)   # indices of the zero entries
    weight = torch.ones(negs.shape).to(device)
    neg_samples = torch.multinomial(weight, 100)  # sample 100 negatives per row
    if i == 0:
        neg_ids = torch.tensor(neg_samples, dtype=torch.long).unsqueeze(0).to(device)
    else:
        neg_ids = torch.cat((neg_ids, torch.tensor(neg_samples, dtype=torch.long).unsqueeze(0).to(device)), 0)

neg_ids = neg_ids.squeeze()

neg_x = torch.zeros(x.shape).to(device).scatter(1, neg_ids, x)            # targets at the sampled indices
neg_score = torch.zeros(x.shape).to(device).scatter(1, neg_ids, recon_x)  # predictions at the sampled indices
Recon_loss_pos = F.mse_loss(F.sigmoid(pos_score), x)
Recon_loss_neg = F.mse_loss(F.sigmoid(neg_score), neg_x)
Recon_loss = Recon_loss_neg + Recon_loss_pos
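For comparison, here is a minimal sketch of the same sampled-loss idea that avoids scatter entirely: gathering the sampled scores with torch.gather is differentiable and its backward produces a gradient shaped like recon_x, so there is no ScatterBackward shape mismatch. This uses toy shapes (5 rows, 200 items instead of 16736) and stand-in tensors, not your actual model:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins (assumed shapes): 5 rows, 200 items, 100 negatives per row.
B, N, K = 5, 200, 100
x = (torch.rand(B, N) < 0.05).float()           # sparse 0/1 target
recon_x = torch.rand(B, N, requires_grad=True)  # model scores

# Sample K indices per row from the zero entries of x.
rows = []
for i in range(B):
    zeros = (x[i] == 0).nonzero().view(-1)      # positions of the zeros
    pick = torch.multinomial(torch.ones(zeros.numel()), K, replacement=True)
    rows.append(zeros[pick])
neg_ids = torch.stack(rows, 0)                  # (B, K)

# gather pulls the sampled scores out instead of scattering them back,
# so the gradient w.r.t. recon_x keeps recon_x's shape.
neg_score = torch.gather(recon_x, 1, neg_ids)   # (B, K)
pos_score = torch.where(x > 0, recon_x, torch.zeros_like(x))

loss = F.mse_loss(torch.sigmoid(pos_score), x) \
     + F.mse_loss(torch.sigmoid(neg_score), torch.zeros_like(neg_score))
loss.backward()
print(recon_x.grad.shape)  # torch.Size([5, 200])
```

The key design choice is the direction of indexing: scatter writes a small tensor into a large one, while gather reads a small tensor out of a large one, and only the latter keeps the gradient shape tied to recon_x.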

I'm getting the same error. Have you solved this problem?