It seems that map_location in torch.load() currently doesn't support a torch.device. To be device-agnostic, my code now looks like this:
if device.type == 'cuda':
    checkpoint = torch.load(model_file)
else:
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
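Here is a small sketch of what the lambda does. It assumes a recent PyTorch and uses an in-memory io.BytesIO buffer in place of model_file. torch.load calls the map_location callable once for each saved storage. It passes the storage (already deserialized on the CPU) and its original location tag, such as 'cuda:0'. Returning the storage unchanged therefore keeps every tensor on the CPU. load_checkpoint is a hypothetical helper that wraps the if/else above.

```python
import io

import torch
import torch.nn as nn


def load_checkpoint(source, device):
    """Hypothetical helper mirroring the if/else above."""
    if device.type == 'cuda':
        # No remapping: tensors return to the device they were saved from.
        return torch.load(source)
    # The callable gets each storage (deserialized on the CPU) plus its
    # original location tag; returning it unchanged keeps it on the CPU.
    return torch.load(source, map_location=lambda storage, loc: storage)


# Save a tiny model's weights to a buffer that stands in for model_file.
model = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

state = load_checkpoint(buffer, torch.device('cpu'))
print(sorted(state.keys()))
```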
Alternatively, I could do this:
checkpoint = torch.load(model_file, map_location=device.type)
However, it would be neater to do this:
checkpoint = torch.load(model_file, map_location=device)
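Here is a minimal round trip that assumes a PyTorch release whose torch.load accepts a torch.device for map_location (current releases do). As before, a hypothetical io.BytesIO buffer stands in for model_file. device.type is only the string 'cuda' or 'cpu'. The device object can also carry an index such as cuda:1, so passing it directly sends the checkpoint to that specific GPU.

```python
import io

import torch
import torch.nn as nn

# Choose the target device at runtime; fall back to the CPU without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Save a tiny model's weights to a buffer that stands in for model_file.
model = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Passing the torch.device itself remaps every saved tensor onto it.
checkpoint = torch.load(buffer, map_location=device)

# Load the weights into a fresh model that already lives on the same device.
restored = nn.Linear(4, 2).to(device)
restored.load_state_dict(checkpoint)
print(next(restored.parameters()).device)
```

The same call works unchanged on CPU-only and GPU machines, which avoids the branching in the first version.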