Try to remove the `.cuda()` calls and replace them with `.to(device)`, so that you can write device-agnostic code.
Setting `device = 'cpu'` in your script should then give you a CPU-only run.
Once this is done, you might get a better error message for the initial issue.
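A minimal sketch of the pattern, assuming a standard PyTorch setup. The `nn.Linear` model, the tensor shapes and the cross-entropy loss are placeholders chosen only for illustration and stand in for whatever your script actually uses. The point is that the device is chosen in exactly one place, and everything else refers to it through `.to(device)` or the `device=` argument:

```python
import torch
import torch.nn as nn

# Choose the device once. Replace this line with torch.device("cpu")
# to force the whole script onto the CPU while debugging.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model and data (hypothetical, for illustration only).
model = nn.Linear(10, 2).to(device)                  # instead of model.cuda()
inputs = torch.randn(4, 10).to(device)               # instead of inputs.cuda()
targets = torch.randint(0, 2, (4,), device=device)   # or create tensors directly on the device

criterion = nn.CrossEntropyLoss()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()

print(outputs.device, loss.item())
```

The reason a CPU run often helps: CUDA kernels are launched asynchronously, so a device-side error is frequently reported at a later, unrelated line, whereas on the CPU the same bug usually raises at the operation that caused it.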