Freezing some parameters in a layer

I'm trying to train a network with a particular strategy and need to keep some parameters in one layer fixed during training. For example, in a conv layer with 6 channels and 3x3 kernels, only the parameters in channel 1 (or even just the top 2x2 region of the kernels in channel 1) should be updated. I'm wondering how to achieve this.
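Whichever freezing approach is used, it helps to first express "which entries are trainable" as a 0/1 mask with the same shape as the weight. A minimal sketch, assuming "channel 1" means output channel index 1 (0-based) of an `nn.Conv2d(3, 6, 3)`, whose weight has shape `(out_channels, in_channels, kH, kW)`; the names `mask_channel` and `mask_corner` are illustrative only.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 6, kernel_size=3, padding=1)

# weight shape is (out_channels, in_channels, kH, kW) = (6, 3, 3, 3)
# 1.0 marks entries that should be updated, 0.0 marks frozen entries
mask_channel = torch.zeros_like(conv.weight)
mask_channel[1] = 1.0  # assumption: output channel index 1 is the trainable one

mask_corner = torch.zeros_like(conv.weight)
mask_corner[1, :, :2, :2] = 1.0  # only the top-left 2x2 of each kernel in output channel 1
```

The bias (shape `(6,)`) is a separate parameter and would need its own mask, or `requires_grad_(False)`, if it should be frozen too.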

I think the easiest way would be to zero out the gradients of all frozen parameters before calling optimizer.step().
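A minimal sketch of the manual approach, assuming plain SGD without weight decay and that only output channel 1 is trainable. Note the caveat: optimizers that apply weight decay (e.g. `SGD(weight_decay>0)` or `AdamW`) can still change "frozen" entries even when their gradient is zero, and so can momentum if freezing starts after the momentum buffers are already nonzero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 6, kernel_size=3, padding=1)
mask = torch.zeros_like(conv.weight)
mask[1] = 1.0  # assumption: only output channel 1 is trainable

optimizer = torch.optim.SGD(conv.parameters(), lr=0.1)  # no weight decay
before = conv.weight.detach().clone()

optimizer.zero_grad()
loss = conv(torch.randn(2, 3, 8, 8)).pow(2).mean()
loss.backward()
conv.weight.grad.mul_(mask)  # zero the gradients of frozen entries in place
optimizer.step()

frozen = mask == 0
```

The `mul_` line has to be repeated after every `backward()` call, which is what makes this approach easy to forget in a larger training loop.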

Is there any way to automate this process? I saw people mentioning backward hooks, but they seem to have a lot of issues. Is there a way to get them to work, or a simple workaround?

You could register a hook directly on the parameter, e.g.:

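import torch
import torch.nn as nn

model = nn.Conv2d(3, 6, 3, 1, 1)
# 1 = trainable, 0 = frozen; same shape as model.weight: (6, 3, 3, 3)
mask = torch.randint(0, 2, (6, 3, 3, 3)).float()
model.weight.register_hook(lambda grad: grad * mask)

model(torch.randn(1, 3, 4, 4)).mean().backward()
# model.weight.grad is now zero wherever mask == 0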
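A sketch of the hook inside a short training loop, assuming plain SGD without weight decay and a mask that trains only the top-left 2x2 of output channel 1. Because the hook fires on every backward pass, no extra step is needed in the loop. The mask is captured by the lambda, so it has to live on the same device as the weight (move it too if the model is moved to the GPU).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Conv2d(3, 6, 3, 1, 1)
mask = torch.zeros_like(model.weight)
mask[1, :, :2, :2] = 1.0  # assumption: train only the top-left 2x2 of output channel 1

# The hook masks the gradient on every backward pass automatically.
handle = model.weight.register_hook(lambda grad: grad * mask)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # no weight decay
before = model.weight.detach().clone()

for _ in range(3):
    optimizer.zero_grad()
    model(torch.randn(2, 3, 8, 8)).pow(2).mean().backward()
    optimizer.step()

frozen = mask == 0
# handle.remove() detaches the hook if the layer should become fully trainable again
```

The same weight-decay caveat as for manual zeroing applies: the hook only changes the gradient, not what the optimizer does on its own.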