I got a similar error and found that calling torch.cuda.set_device(0) before loading the model fixed it:
import torch
torch.cuda.set_device(0) # adding this line fixed it
model = torch.load('mod.pth')
Loading the model onto the CPU with map_location also worked every time:
m = torch.load('mod.pth', map_location=torch.device('cpu'))