Calculate gradients when network is set to .eval()

I have 2 networks (call them preprocess and final). I’m trying to feed the output of preprocess into final, where preprocess is trained based on a loss computed from final’s output, while final is evaluated but not trained (i.e. it is in eval() mode, so dropout behaves correctly, its weights are frozen, etc.).

Something like this:

# Freeze final so its weights receive no gradients
for p in final.parameters():
	p.requires_grad = False

final.eval()        # final: evaluation behaviour, not trained
preprocess.train()  # preprocess: being trained

output = final(preprocess(input))
l = loss(output)    # loss is a callable criterion

l.backward()        # gradients flow through final into preprocess
optimizer.step()    # optimizer holds preprocess's parameters

I’m getting the error “RuntimeError: cudnn RNN backward can only be called in training mode”. What’s the idiomatic way of using train and eval networks together like this?

Notes:

  • Both networks are actually being trained in an alternating fashion
  • final actually has two parts; the input to preprocess is an intermediate output of final, which is later fed back into a different part of final

This is a limitation of the CuDNN implementation (combined with the layering of autograd over ATen, which does not allow the CuDNN implementation in ATen to see whether gradients are required). An easy way to fix this is to set the entire network to eval but then put the RNNs into train() (note that this re-enables any dropout you might have in them, though). Another way is to disable the CuDNN backend (with torch.backends.cudnn.flags(enabled=False): ...) or use JITed RNNs.
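For concreteness, here is a minimal sketch of both workarounds, reusing the preprocess / final / loss / optimizer names from the question and assuming final contains CuDNN-backed RNN modules (nn.RNN, nn.LSTM, or nn.GRU):

import torch

# Workaround 1: keep final in eval() overall, but switch its RNN submodules
# back to train() so the CuDNN backward can run. Caveat: this re-enables
# any dropout configured inside those RNN modules.
final.eval()
for m in final.modules():
	if isinstance(m, torch.nn.RNNBase):  # covers nn.RNN, nn.LSTM, nn.GRU
		m.train()

output = final(preprocess(input))
l = loss(output)
l.backward()
optimizer.step()

# Workaround 2: disable the CuDNN backend for this forward/backward pass,
# so the native implementation is used and the restriction does not apply.
final.eval()
with torch.backends.cudnn.flags(enabled=False):
	output = final(preprocess(input))
	l = loss(output)
	l.backward()
optimizer.step()

The second variant typically trades some speed for generality, since the native RNN implementation is slower than CuDNN.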

Best regards

Thomas
