I have a network that trains for a while. Then I call requires_grad_(False) on some layers, and later call requires_grad_() on those layers and requires_grad_(False) on the other layers. At that point training completely stops. Following Gradients exist but weights not updating, I have the code:
if epoch == test:
    # snapshot the parameters before the update
    before = [p.clone() for p in cnn.parameters()]
loss.backward()
optimizer.step()
if epoch == test:
    after = list(cnn.parameters())
    for i in range(len(before)):
        print(torch.equal(before[i].data, after[i].data))
        if after[i].grad is not None:
            print(after[i].grad.data.max())
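For context, the freeze/unfreeze switch between the two phases looks roughly like this (a minimal sketch; the two-block Sequential stands in for my actual CNN, and the block indices are placeholders):

```python
import torch.nn as nn

# toy stand-in for the real CNN: two "layers" to freeze/unfreeze
cnn = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# phase 1: freeze the first block, train only the second
for p in cnn[0].parameters():
    p.requires_grad_(False)

# phase 2 (later): unfreeze the first block, freeze the second
for p in cnn[0].parameters():
    p.requires_grad_()          # defaults to requires_grad=True
for p in cnn[1].parameters():
    p.requires_grad_(False)

print([p.requires_grad for p in cnn[0].parameters()])  # [True, True]
```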
On epoch 0, before the requires_grad_ switch, it works:
False
tensor(7.1678e-08, device='cuda:0', dtype=torch.float64)
False
tensor(3.6259e-06, device='cuda:0', dtype=torch.float64)
True
True
False
tensor(1.1820e-05, device='cuda:0', dtype=torch.float64)
False
tensor(2.0888e-05, device='cuda:0', dtype=torch.float64)
True
True
False
tensor(8.6558e-05, device='cuda:0', dtype=torch.float64)
False
tensor(0.0002, device='cuda:0', dtype=torch.float64)
True
True
False
tensor(0.0020, device='cuda:0', dtype=torch.float64)
False
tensor(0.0041, device='cuda:0', dtype=torch.float64)
True
True
But then on epoch 2, after the requires_grad_ switch, the gradients are still calculated but the weights don't get updated:
True
tensor(-6.2089e-07, device='cuda:0', dtype=torch.float64)
True
tensor(1.1862e-06, device='cuda:0', dtype=torch.float64)
True
tensor(-1.1782e-06, device='cuda:0', dtype=torch.float64)
True
True
True
tensor(2.5829e-06, device='cuda:0', dtype=torch.float64)
True
tensor(4.0316e-06, device='cuda:0', dtype=torch.float64)
True
tensor(5.7105e-06, device='cuda:0', dtype=torch.float64)
True
True
True
tensor(3.4637e-05, device='cuda:0', dtype=torch.float64)
True
tensor(5.0992e-05, device='cuda:0', dtype=torch.float64)
True
tensor(8.0335e-05, device='cuda:0', dtype=torch.float64)
True
True
True
tensor(0.0016, device='cuda:0', dtype=torch.float64)
True
tensor(0.0009, device='cuda:0', dtype=torch.float64)
True
tensor(0.0019, device='cuda:0', dtype=torch.float64)
True
True
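In case it helps, here is a standalone sketch of the comparison logic I'm using, stripped down to a toy model on CPU (the Linear model, SGD optimizer, and random input are placeholders for my actual setup, where the snapshot/compare behaves as expected when no parameters are frozen):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                      # toy stand-in for the CNN
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()

# snapshot parameters, then take one optimizer step
before = [p.clone() for p in model.parameters()]
loss.backward()
opt.step()

# compare snapshot vs. updated parameters, and show the grads exist
for b, p in zip(before, model.parameters()):
    print(torch.equal(b, p), None if p.grad is None else p.grad.max())
```

With nothing frozen, every parameter prints False (i.e. it changed) alongside a nonzero gradient, which is the epoch-0 behavior above.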