Would be great if torch.load() supported torch.device

It seems that map_location in torch.load() currently doesn't support torch.device. To keep my code device-agnostic, it now looks like this:

if device.type == 'cuda':
    checkpoint = torch.load(model_file)
else:
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
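One way to keep that branch out of calling code is to wrap it in a small helper that takes the device. This is only a sketch: `load_checkpoint` is a hypothetical name, not part of PyTorch, and it simply reproduces the two cases above (default loading on CUDA, remapping every storage to CPU otherwise).

```python
import io
import torch

def load_checkpoint(f, device):
    """Hypothetical helper wrapping the device branch shown above."""
    if device.type == 'cuda':
        # Restore tensors to the devices they were saved from.
        return torch.load(f)
    # Keep every storage where deserialization put it (CPU), ignoring its saved location.
    return torch.load(f, map_location=lambda storage, loc: storage)

# Usage: round-trip a small checkpoint through an in-memory buffer.
buffer = io.BytesIO()
torch.save({'weight': torch.ones(2)}, buffer)
buffer.seek(0)
checkpoint = load_checkpoint(buffer, torch.device('cpu'))
```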

Alternatively, I could pass the device type as a string:

checkpoint = torch.load(model_file, map_location=device.type)

Nevertheless, it would be neater to pass the device object directly:

checkpoint = torch.load(model_file, map_location=device)
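In PyTorch releases whose torch.load() documentation lists torch.device among the accepted map_location types, the requested call works as written. A minimal sketch, assuming such a release, that picks a device at runtime and round-trips a checkpoint through an in-memory buffer instead of `model_file`:

```python
import io
import torch

# Choose the target device once; everything below is device-agnostic.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

buffer = io.BytesIO()
torch.save({'weight': torch.arange(3.0)}, buffer)
buffer.seek(0)

# map_location receives the torch.device object directly.
checkpoint = torch.load(buffer, map_location=device)
```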