Your explanation is correct and you can also verify it using this simple example:
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3)
optimizer = torch.optim.Adam(conv.parameters())
x = torch.randn(1, 1, 24, 24)
print(conv.weight)
Parameter containing:
tensor([[[[-0.1785,  0.2719, -0.1714],
          [-0.3003,  0.1127, -0.2276],
          [-0.0389,  0.2258,  0.2369]]]], requires_grad=True)
mask = torch.randint(0, 2, (1, 1, 3, 3))
print(mask)
tensor([[[[1, 1, 0],
          [0, 0, 0],
          [1, 1, 1]]]])
with torch.no_grad():
    conv.weight.mul_(mask)
print(conv.weight)
Parameter containing:
tensor([[[[-0.1785,  0.2719, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [-0.0389,  0.2258,  0.2369]]]], requires_grad=True)
out = conv(x)
out.mean().backward()
print(conv.weight.grad)
tensor([[[[-0.0447, -0.0345, -0.0218],
          [-0.0624, -0.0594, -0.0353],
          [-0.0708, -0.0702, -0.0538]]]])
optimizer.step()
print(conv.weight)
Parameter containing:
tensor([[[[-0.1775,  0.2729,  0.0010],
          [ 0.0010,  0.0010,  0.0010],
          [-0.0379,  0.2268,  0.2379]]]], requires_grad=True)
As you can see, the masked (zeroed) weights still receive valid gradients and are updated again in the optimizer.step() call (they moved from 0. to 0.0010), so zeroing the weights alone does not keep them at zero.
To entirely remove some elements from the computation graph, you could try to split the tensor into a trainable parameter and a static tensor. In the forward method you could then concatenate or stack the different parts and use the functional API via F.conv2d.
This post gives you an example using a linear layer.
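A minimal sketch of this idea for the conv case could look like the code below. The split into the first two kernel rows as trainable and the last row as static, as well as the module name PartiallyFrozenConv, are just made up for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PartiallyFrozenConv(nn.Module):
    def __init__(self):
        super().__init__()
        # trainable part: first two rows of the 3x3 kernel (illustrative split)
        self.weight_trainable = nn.Parameter(torch.randn(1, 1, 2, 3))
        # static part: last row of the kernel, registered as a buffer so it is
        # not returned by parameters() and never receives gradients
        self.register_buffer('weight_frozen', torch.zeros(1, 1, 1, 3))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # reassemble the full 3x3 kernel from both parts and use the functional API
        weight = torch.cat([self.weight_trainable, self.weight_frozen], dim=2)
        return F.conv2d(x, weight, self.bias)

model = PartiallyFrozenConv()
optimizer = torch.optim.Adam(model.parameters())  # only sees weight_trainable and bias
out = model(torch.randn(1, 1, 24, 24))
out.mean().backward()
print(model.weight_trainable.grad)  # gradients only for the trainable rows
print(model.weight_frozen.grad)     # None, since the frozen part is not a Parameter
optimizer.step()                    # the frozen row stays exactly at zero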
I'm sure there might be a more elegant approach now, but I would need to play around with it a bit more.