Suppose the weight tensor of a layer has shape [32, 64, 4, 2]. Is it possible to freeze its first filter while keeping the other 31 filters trainable?
Unfortunately, you cannot freeze only part of a weight tensor through `requires_grad`, because that flag applies to the complete tensor. You can, however, effectively freeze a particular part of it by zeroing out its gradients.
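The most direct version of this is to zero the relevant slice of `.grad` after `backward()` and before `optimizer.step()`. Below is a minimal sketch, assuming plain `torch.optim.SGD` with no momentum and no weight decay; the learning rate and input sizes are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

net = nn.Conv2d(64, 32, (4, 2))  # weight shape: [32, 64, 4, 2]
# Plain SGD: no momentum, no weight decay
opt = torch.optim.SGD(net.parameters(), lr=0.1)

before = net.weight.detach().clone()

x = torch.randn(1, 64, 10, 10)
opt.zero_grad()
net(x).sum().backward()

# Zero the gradient of the first filter only; the other 31 filters keep theirs
net.weight.grad[0].zero_()
opt.step()

print(torch.equal(net.weight[0], before[0]))  # True: filter 0 did not move
print(torch.equal(net.weight[1], before[1]))  # False: filter 1 was updated
```

This has to be repeated after every `backward()` call, which is easy to forget; the approaches below hook the zeroing into autograd itself.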
Using a custom autograd function is one way to do this. Here is the definition of a custom function that zeroes out the gradients:
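```python
import torch
from torch.autograd import Function


class Freeze(Function):
    """Identity in the forward pass, zero gradient in the backward pass."""

    @staticmethod
    def forward(ctx, w):
        # Return a copy so the output is a new tensor rather than the input itself
        return w.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Block the gradient: the frozen part receives an all-zero gradient
        return torch.zeros_like(grad_output)


def freeze(w):
    return Freeze.apply(w)
```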
Here is an example for the Conv2d case (it uses `freeze` as defined above):
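```python
import torch
from torch.nn import Conv2d
import torch.nn.functional as F

x = torch.randn(1, 64, 10, 10)  # Variable is no longer needed; plain tensors track gradients
net = Conv2d(64, 32, (4, 2))    # weight shape: [32, 64, 4, 2]

# Build the weight used in the forward pass: filter 0 goes through freeze(),
# filters 1..31 are used as-is and stay trainable
w = torch.cat((freeze(net.weight[:1]), net.weight[1:]))
loss = F.conv2d(x, weight=w, bias=net.bias, stride=net.stride, padding=net.padding).sum()
loss.backward()

print(net.weight.grad[0])  # all zeros
```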
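A similar effect can be had without rebuilding the weight on every forward pass, by registering a gradient hook on the parameter that multiplies its gradient by a 0/1 mask. This is a sketch using `Tensor.register_hook`; the `mask` and `handle` names are just for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

net = nn.Conv2d(64, 32, (4, 2))  # weight shape: [32, 64, 4, 2]

# Mask that is 0 for the first filter and 1 for the remaining 31
mask = torch.ones_like(net.weight)
mask[0] = 0

# The hook runs every time the gradient w.r.t. net.weight is computed
# and replaces it with the masked gradient before it is stored in .grad
handle = net.weight.register_hook(lambda grad: grad * mask)

x = torch.randn(1, 64, 10, 10)
net(x).sum().backward()

print(net.weight.grad[0].abs().sum().item())      # 0.0
print(net.weight.grad[1].abs().sum().item() > 0)  # True
```

Calling `handle.remove()` makes the first filter trainable again. The caveat about optimizer-side updates still applies, since the hook only changes the gradient.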
Note: as far as I understand, weight decay is still applied even when the gradients are zeroed, so the "frozen" filter can still change. Please be careful about this.
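The sketch below illustrates this with `torch.optim.SGD` and `weight_decay=0.01` (values chosen arbitrarily): although filter 0's gradient is zeroed, SGD adds `weight_decay * w` to it, so the step still shrinks the filter. It also shows one possible workaround, saving the frozen values and copying them back after every `optimizer.step()`; this is an assumption about what suits your setup, not the only option:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

net = nn.Conv2d(64, 32, (4, 2))  # weight shape: [32, 64, 4, 2]
opt = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=0.01)

before = net.weight[0].detach().clone()

opt.zero_grad()
net(torch.randn(1, 64, 10, 10)).sum().backward()
net.weight.grad[0].zero_()  # "freeze" filter 0 by zeroing its gradient
opt.step()

# SGD uses grad + weight_decay * w, so filter 0 still moves:
# w0 <- w0 - lr * (0 + weight_decay * w0)
decayed = not torch.equal(net.weight[0], before)
print(decayed)  # True

# Workaround sketch: copy the saved values back after each step
with torch.no_grad():
    net.weight[0].copy_(before)
print(torch.equal(net.weight[0], before))  # True
```

Optimizer state carried over from earlier training (for example SGD momentum buffers or Adam's running averages) can likewise move a parameter whose current gradient is zero; restoring the values after each step covers that case as well.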