Built-in way to update only the weights corresponding to certain outputs

I’ve written a very simple CBOW model, like so:

import torch
import torch.nn as nn

class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim=200):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lin = nn.Linear(embedding_dim, vocab_size)
        self.activation = nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        out = torch.mean(embeds, dim=1)
        out = self.activation(self.lin(out))
        return out

model = CBOW(len(dataset.word2index))
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # No Variable wrapper needed (PyTorch >= 0.4): tensors are passed to the model directly.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

if __name__ == '__main__':
    training_epochs = 10

    for epoch in range(training_epochs):
        train(epoch)

Now I want to implement negative sampling, so I need a way to update only the weights corresponding to certain rows of the output. I know that one way to do this is to manually zero the gradients for the output rows whose weights I don't want updated, as suggested here:

However, someone told me there is an easier way to do this that is "built-in" to PyTorch. Is that true? Is there a built-in way to choose which weights get updated?
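For reference, the manual approach described above can be sketched with tensor hooks on the output layer's parameters. This is a minimal sketch, not the linked suggestion itself: the layer sizes, the random batch, and `rows_to_update` are hypothetical. One caveat worth knowing: with `momentum` > 0 (the question uses 0.5), a row whose gradient is zeroed in this step can still move if it received gradient in an earlier step, because SGD keeps applying its momentum buffer. The sketch therefore uses plain SGD.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embedding_dim = 10, 4
lin = nn.Linear(embedding_dim, vocab_size)

# Hypothetical set of output rows (vocabulary indices) allowed to change this step.
rows_to_update = torch.tensor([1, 3, 7])
row_mask = torch.zeros(vocab_size, 1)
row_mask[rows_to_update] = 1.0

# A tensor hook runs during backward; the tensor it returns replaces the gradient.
# Weight grad has shape (vocab_size, embedding_dim), bias grad has shape (vocab_size,).
w_handle = lin.weight.register_hook(lambda grad: grad * row_mask)
b_handle = lin.bias.register_hook(lambda grad: grad * row_mask.squeeze(1))

# Plain SGD (no momentum) so masked rows stay exactly unchanged.
optimizer = torch.optim.SGD(lin.parameters(), lr=0.1)

weight_before = lin.weight.detach().clone()
hidden = torch.randn(5, embedding_dim)
target = torch.randint(0, vocab_size, (5,))
loss = nn.functional.nll_loss(torch.log_softmax(lin(hidden), dim=1), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Rows whose weights actually changed: only indices from rows_to_update can appear.
changed_rows = (lin.weight.detach() != weight_before).any(dim=1).nonzero().flatten()
print(changed_rows.tolist())

# Remove the hooks once they are no longer wanted.
w_handle.remove()
b_handle.remove()
```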

1 Like

Hi,

The proposed method is actually the right way to do it.
Another approach, which avoids hooks, is to not include these rows in the loss computation at all, so that there is no need to hook into .backward().
A few examples below:

output = model(data)
to_keep_indices = get_indices_to_keep(output)
output = output.index_select(0, to_keep_indices)
target = target.index_select(0, to_keep_indices)
# You can also do the same with a 0/1 mask and the masked_select function
loss = criterion(output, target)

Or

# Create your criterion as:
criterion = nn.NLLLoss(reduction='none') # EDITED: wrong keyword argument (older PyTorch: reduce=False)

# In your training loop
loss = criterion(output, target)
to_keep_indices = get_indices_to_keep(output, loss)
final_loss = loss.index_select(0, to_keep_indices)
# You can also do the same with a 0/1 mask and the masked_select function
# Now average the loss of each of the samples you want to keep and backward.
final_loss.mean().backward()
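A runnable version of the second example, again with random stand-ins for the model output and a hypothetical fixed mask. It uses `masked_select` on the per-sample losses and checks that averaging them matches `NLLLoss` with its default `'mean'` reduction applied to the kept samples only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, vocab_size = 6, 10
# Stand-ins for model(data) and target from the question's training loop.
logits = torch.randn(batch_size, vocab_size, requires_grad=True)
output = torch.log_softmax(logits, dim=1)
target = torch.randint(0, vocab_size, (batch_size,))

# reduction='none' returns one loss per sample, shape (batch_size,).
criterion = nn.NLLLoss(reduction='none')
loss = criterion(output, target)

# Hypothetical 0/1 mask of samples to keep.
keep = torch.tensor([True, False, True, True, False, True])
final_loss = loss.masked_select(keep)  # 1-D tensor holding only the kept losses
final_loss.mean().backward()

# Same value as the default 'mean' reduction over just the kept samples.
reference = nn.NLLLoss()(output[keep].detach(), target[keep])
print(torch.allclose(final_loss.mean().detach(), reference))  # True
```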
1 Like

Wouldn't the second proposal require passing reduce=False to NLLLoss?

Oh yes, my bad, wrong argument :smiley:
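For the negative-sampling use case in the original question, a further option (not proposed in this thread) is to compute scores only for the true target and a few sampled negatives by indexing rows of the output layer's weight. Autograd then produces nonzero gradients only in those rows. This is a minimal sketch under stated assumptions: it swaps the full `LogSoftmax`/`NLLLoss` for the logistic negative-sampling loss used in word2vec, and the context, target, and `negatives` tensors are hypothetical (word2vec normally samples negatives from a unigram distribution raised to the 0.75 power). The momentum caveat above also applies here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, embedding_dim = 10, 5
embeddings = nn.Embedding(vocab_size, embedding_dim)
lin = nn.Linear(embedding_dim, vocab_size)

context = torch.tensor([[0, 2, 5, 6]])  # one window of context word indices
target = torch.tensor([4])
# Hypothetical negatives for this window (normally drawn at random).
negatives = torch.tensor([[1, 8, 9]])

hidden = embeddings(context).mean(dim=1)  # (1, embedding_dim), as in CBOW.forward
rows = torch.cat([target.unsqueeze(1), negatives], dim=1)  # (1, 1 + k)
w = lin.weight[rows]  # (1, 1 + k, embedding_dim): only these rows enter the graph
b = lin.bias[rows]  # (1, 1 + k)
scores = torch.einsum('bd,bkd->bk', hidden, w) + b

# Label 1 for the true target (column 0), 0 for the negatives.
labels = torch.zeros_like(scores)
labels[:, 0] = 1.0
loss = F.binary_cross_entropy_with_logits(scores, labels)
loss.backward()

# Output-layer rows with a nonzero gradient: only target and negative indices.
print(lin.weight.grad.abs().sum(dim=1).nonzero().flatten().tolist())
```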

1 Like