I want to create an architecture in which the gradient flow is blocked from a certain layer backward, so that the earlier layers are not updated. I don't want to freeze that part. I want the gradient of the intended layer to be zero, and to block the gradient flow so that the earlier layers are not updated through this path. Is there any solution? I tried manually setting the grad of the intended layer's output to zero in the forward pass of the model, but the layers before it still get gradient values.
If you know during the forward pass which part you want to block the gradients from, you can call .detach() on the output of this block to exclude it from the backward pass.
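Here is a minimal sketch of the .detach() approach. It assumes a hypothetical toy model (the names shared, head_a and head_b are illustrative) in which one shared layer feeds two heads and only one path should be blocked. Detaching the input of head_b stops that path's gradient at head_b. The shared layer still learns through head_a, so nothing is frozen:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: one shared layer feeding two heads.
shared = nn.Linear(4, 4)
head_a = nn.Linear(4, 1)  # trains `shared` through the normal path
head_b = nn.Linear(4, 1)  # sees a detached copy, so it cannot update `shared`

x = torch.randn(8, 4)
h = shared(x)

# Path B: detach() cuts the graph, so backward stops at head_b.
loss_b = head_b(h.detach()).pow(2).mean()
loss_b.backward()
shared_grad_after_b = shared.weight.grad          # None: nothing reached `shared`
head_b_has_grad = head_b.weight.grad is not None  # True: head_b itself still learns

# Path A: the same shared layer still gets a gradient through the non-detached path.
loss_a = head_a(h).pow(2).mean()
loss_a.backward()
shared_has_grad = shared.weight.grad is not None  # True

print(shared_grad_after_b, head_b_has_grad, shared_has_grad)  # None True True
```

Setting requires_grad=False on the shared layer would stop its updates from every path. By contrast, .detach() only cuts the one edge of the graph where it is applied.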
If you only know after the forward pass which part you want to block, you will need to add a hook to the Tensor that is the output of your block, as in:

```python
out.register_hook(lambda grad: torch.zeros_like(grad))
```
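Here is a minimal sketch of the hook approach on a hypothetical two-layer model (layer1 and layer2 are illustrative names). The hook replaces the gradient that arrives at out with zeros. As a result, layer1 still takes part in the backward pass, but it receives an all-zero gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer model; `out` is the tensor whose gradient we want to block.
layer1 = nn.Linear(4, 4)
layer2 = nn.Linear(4, 1)

x = torch.randn(8, 4)
out = layer1(x)
# During backward, replace whatever gradient arrives at `out` with zeros.
out.register_hook(lambda grad: torch.zeros_like(grad))

loss = layer2(out).pow(2).mean()
loss.backward()

print(layer2.weight.grad)  # a real (nonzero) gradient: layer2 sits after the hook
print(layer1.weight.grad)  # all zeros, but a tensor rather than None
```

This differs from .detach() in one respect: layer1.weight.grad is a zero tensor rather than None. An optimizer that applies weight decay or momentum can therefore still change layer1 even though its gradient is zero. Adam behaves the same way once it has running averages from earlier steps.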
Thanks for your reply. Yes, I know in advance where the gradient should be blocked.
I tried the following code, but when I compare the weights before and after the update, none of them seem to have received a gradient (none of them change).
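```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class simm(nn.Module):
    def __init__(self, input_channels=1, output_channels=32):
        super().__init__()
        self.layer1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=(3, 3, 3), padding=1)
        self.layer2 = nn.Conv3d(in_channels=output_channels, out_channels=2 * output_channels, kernel_size=(3, 3, 3), padding=1)
        # +1 channel: the raw (single-channel) input is concatenated to layer2's output
        self.layer3 = nn.Conv3d(in_channels=2 * output_channels + 1, out_channels=1, kernel_size=(3, 3, 3), padding=1)
        self.layer4 = nn.ReLU(inplace=False)  # defined but not used in forward() as posted

    def forward(self, x):
        o1 = self.layer1(x)
        o2 = self.layer2(o1)
        x1 = o2.detach()            # cut the graph: no gradient flows back into layer2/layer1
        x2 = torch.cat((x1, x), 1)  # concatenate along the channel dimension
        o3 = self.layer3(x2)
        return o3

criterion = nn.MSELoss()
model = simm().to(device)
# .data shares storage with the parameter, so clone() is needed for a real "before" snapshot.
# Without it, the snapshot changes together with the weights and every difference prints 0.
w1 = model.layer1.weight.data.clone()
w2 = model.layer2.weight.data.clone()
w3 = model.layer3.weight.data.clone()
optimizer = optim.Adam(model.parameters(), lr=0.1)

data = torch.rand(5, 1, 10, 12, 13).to(device)
output_b = model(data)
label = 10 * torch.ones_like(output_b)
loss = criterion(output_b, label)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.mean(abs(model.layer1.weight.data - w1)))  # 0: layer1 got no gradient
print(torch.mean(abs(model.layer2.weight.data - w2)))  # 0: layer2 got no gradient
print(torch.mean(abs(model.layer3.weight.data - w3)))  # > 0: layer3 was updated
```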
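Comparing weights before and after optimizer.step() works once the snapshot is a real copy. A more direct check is to look at which parameters received a .grad after backward(). The sketch below uses a hypothetical small stand-in for the model above (TwoStage and params_with_grad are illustrative names) that follows the same detach-then-concatenate pattern:

```python
import torch
import torch.nn as nn

def params_with_grad(model):
    """Return names of parameters whose .grad was populated by backward()."""
    return [name for name, p in model.named_parameters() if p.grad is not None]

# Hypothetical stand-in for the thread's model: the first stage's output is detached.
class TwoStage(nn.Module):
    def __init__(self):
        super().__init__()
        self.first = nn.Conv3d(1, 4, kernel_size=3, padding=1)
        self.second = nn.Conv3d(4 + 1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        h = self.first(x).detach()                    # block gradient into `first`
        return self.second(torch.cat((h, x), dim=1))  # channel-wise concat with the raw input

model = TwoStage()
x = torch.rand(2, 1, 6, 6, 6)
out = model(x)
loss = nn.functional.mse_loss(out, 10 * torch.ones_like(out))
loss.backward()
print(params_with_grad(model))  # ['second.weight', 'second.bias']
```

Parameters cut off by .detach() keep .grad equal to None. For the same reason, the standard torch.optim optimizers skip them in step().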
I think the weights of the last layer should change; am I right?
Thanks for your help. You are right.