This is very helpful!
I would like to provide a tiny failing example, though I'm not fully convinced that this gradient checking isn't overkill.
import torch
import torch.nn as nn
import torch.nn.functional as F

enc = nn.Conv2d(3, 5, 3, 1, 1)
dec = nn.Conv2d(5, 3, 3, 1, 1)
e_opt = torch.optim.Adam(enc.parameters(), lr=1e-4)
g_opt = torch.optim.Adam(dec.parameters(), lr=1e-4)

N = 1
H = W = 8
for i in range(2):
    g_opt.zero_grad()
    e_opt.zero_grad()
    image = torch.randn(N, 3, H, W)
    z = enc(image)
    recon = dec(z)
    c1_loss = F.l1_loss(recon, image)

    # encoder: backprop the shared loss and update
    c1_loss.backward(retain_graph=True)
    e_opt.step()

    # decoder: clear its gradients and backprop the same loss again
    g_opt.zero_grad()
    # must do this line instead:
    # c1_loss = F.l1_loss(dec(z.detach()), image)
    c1_loss.backward()
    g_opt.step()
It’s a simple (placeholder) encoder-decoder network. Here I have two separate optimizers, one for the encoder and one for the decoder, and both are optimizing c1_loss. I first backprop c1_loss and update the encoder. Then I clear the decoder's gradients and backprop c1_loss a second time. This triggers the same runtime error. However, shouldn't the resulting gradients still be correct in this case?
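For reference, this is how I would apply the commented line, i.e. replace the decoder update with a fresh forward pass on a detached z so the second backward only touches the decoder (just a sketch of what I have in mind, reusing enc, dec, image, z, and the optimizers from above; retain_graph=True would then no longer be needed on the encoder backward):

    # decoder update recomputed on a detached z
    g_opt.zero_grad()
    c1_loss_dec = F.l1_loss(dec(z.detach()), image)  # new graph involving only dec
    c1_loss_dec.backward()
    g_opt.step()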
Of course, I wrote some unnecessary steps for this simple example, but it is pretty common to optimize the encoder with c1_loss + enc_loss and the decoder with c1_loss + dec_loss.
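Concretely, the pattern I have in mind looks roughly like this (a rough sketch; enc_loss and dec_loss are placeholder terms I made up purely for illustration):

for i in range(2):
    image = torch.randn(N, 3, H, W)
    z = enc(image)
    recon = dec(z)

    # encoder: shared reconstruction loss plus an encoder-only term
    e_opt.zero_grad()
    enc_loss = z.pow(2).mean()                       # placeholder encoder-only term
    (F.l1_loss(recon, image) + enc_loss).backward()
    e_opt.step()

    # decoder: shared reconstruction loss plus a decoder-only term,
    # recomputed on a detached z so this backward never goes through enc
    g_opt.zero_grad()
    recon_d = dec(z.detach())
    dec_loss = recon_d.pow(2).mean()                 # placeholder decoder-only term
    (F.l1_loss(recon_d, image) + dec_loss).backward()
    g_opt.step()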