Update only sub-elements of weights


(Jan Olle) #1

Hi everyone.

So I’m trying to implement a deep neural network composed of a few linear layers. Each layer has a 2x2 weight matrix and no bias. The peculiarity of my network is that I would like to train only one element of each weight matrix and leave the other three untouched.

To make things more concrete, every weight looks like:

W = torch.tensor([[number1, number2], [number3, number4]])  # number4 (bottom-right) is the only entry I want to train
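A weight with this structure can also be built so that only the bottom-right entry is an `nn.Parameter` and the other three are stored as a non-trainable buffer. This is a different technique from masking gradients: the optimizer never sees the frozen entries at all. The class `PartiallyTrainableLinear`, the method `full_weight`, the starting values and the toy loss below are hypothetical choices for illustration, not code from this thread.

```python
import torch
import torch.nn as nn


class PartiallyTrainableLinear(nn.Module):
    """Hypothetical bias-free 2x2 layer in which only W[1, 1] is trainable."""

    def __init__(self, initial_weight: torch.Tensor):
        super().__init__()
        # Fixed entries are stored as a buffer: saved in state_dict, ignored by optimizers.
        self.register_buffer("fixed_weight", initial_weight.clone())
        # Mask marking the trainable position (bottom-right entry).
        mask = torch.zeros(2, 2)
        mask[1, 1] = 1.0
        self.register_buffer("mask", mask)
        # The only trainable value: a scalar parameter initialised from W[1, 1].
        self.w22 = nn.Parameter(initial_weight[1, 1].clone())

    def full_weight(self) -> torch.Tensor:
        # Fixed entries everywhere except [1, 1], which comes from the parameter.
        return self.fixed_weight * (1 - self.mask) + self.w22 * self.mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same convention as nn.Linear without bias: y = x @ W^T.
        return x @ self.full_weight().t()


torch.manual_seed(0)
layer = PartiallyTrainableLinear(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# One toy training step on random data.
x = torch.randn(5, 2)
loss = layer(x).pow(2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print([name for name, _ in layer.named_parameters()])  # ['w22']
print(layer.full_weight().detach())
```

Because the frozen entries are not parameters, this stays correct with any optimizer setting, including weight decay or momentum.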

One thing I’ve considered is to proceed as usual (i.e. compute the loss and then use some optimizer) and somehow zero the gradients of the entries I don’t want to touch, but I haven’t managed to get that working. Any ideas? I can share parts of the code if that helps :smile:
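The approach described above can work if the masking happens after `loss.backward()` has filled `weight.grad` and before `optimizer.step()` reads it. A minimal sketch, assuming a single bias-free `nn.Linear(2, 2)`, plain SGD and an MSE loss on random data (all illustrative choices):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(2, 2, bias=False)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# 1 where the weight may change (bottom-right entry), 0 elsewhere.
mask = torch.zeros(2, 2)
mask[1, 1] = 1.0

x = torch.randn(8, 2)
target = torch.randn(8, 2)
w_before = layer.weight.detach().clone()

optimizer.zero_grad()
loss = nn.functional.mse_loss(layer(x), target)
loss.backward()
# Zero the gradient of every frozen entry before the optimizer uses it.
with torch.no_grad():
    layer.weight.grad.mul_(mask)
optimizer.step()

w_after = layer.weight.detach().clone()
print(w_before != w_after)
```

This relies on the optimizer turning a zero gradient into a zero update. Plain SGD does; options such as `weight_decay` add a term proportional to the weight itself, so the masked entries would still move.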
Thanks!


#2

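You could register a hook on the weight tensor that multiplies its gradient by a mask, so the gradients of all other entries are zeroed before the optimizer uses them.
Here is a small example for a simple model: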

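import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(2, 2, bias=False),  # no bias, as in the question
    nn.Sigmoid(),
    nn.Linear(2, 2, bias=False)
)
# Create gradient mask: 1 for the entry that should be trained, 0 elsewhere.
# Index [1, 1] is the bottom-right element from the question.
gradient_mask = torch.zeros(2, 2)
gradient_mask[1, 1] = 1.0
# The hook receives the weight's gradient during backward() and returns a masked
# copy (hooks should return a new tensor rather than modify grad in place).
model[0].weight.register_hook(lambda grad: grad * gradient_mask)

optimizer = optim.SGD(model.parameters(), lr=1.0)
criterion = nn.CrossEntropyLoss()

batch_size = 10
x = torch.randn(batch_size, 2)
target = torch.randint(0, 2, (batch_size,))

# Get weight before training
w0 = model[0].weight.detach().clone()

# Single training iteration
optimizer.zero_grad()
output = model(x)
loss = criterion(output, target)
loss.backward()
print('Gradient: ', model[0].weight.grad)
optimizer.step()

# Compare weight update: only the masked-in entry should have changed
w1 = model[0].weight.detach().clone()
print('Weights updated ', w0 != w1)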
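The example above only masks the first layer. Since the question asks for every layer, one way to extend it is to register the same hook on each `nn.Linear` in the model. A minimal sketch, assuming a three-layer bias-free network, plain SGD and a few iterations on random data (the depth, learning rate and iteration count are arbitrary illustrative choices):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

model = nn.Sequential(
    nn.Linear(2, 2, bias=False),
    nn.Sigmoid(),
    nn.Linear(2, 2, bias=False),
    nn.Sigmoid(),
    nn.Linear(2, 2, bias=False),
)

# Same mask for every layer: only W[1, 1] is trainable.
gradient_mask = torch.zeros(2, 2)
gradient_mask[1, 1] = 1.0

# One hook per linear layer; each returns a masked copy of the gradient.
for module in model:
    if isinstance(module, nn.Linear):
        module.weight.register_hook(lambda grad: grad * gradient_mask)

optimizer = optim.SGD(model.parameters(), lr=0.5)
criterion = nn.CrossEntropyLoss()

x = torch.randn(10, 2)
target = torch.randint(0, 2, (10,))

initial = [m.weight.detach().clone() for m in model if isinstance(m, nn.Linear)]

for _ in range(20):
    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()

final = [m.weight.detach().clone() for m in model if isinstance(m, nn.Linear)]
for w0, w1 in zip(initial, final):
    print(w0 != w1)
```

As with any gradient masking, the frozen entries stay fixed only while the optimizer maps a zero gradient to a zero update; with `weight_decay` enabled they would still shrink.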