My model has its tensors on cuda, and its state is saved as such. When I load the model state for inference with checkpoint = torch.load(weights_fpath), the tensors in the checkpoint keep the same device configuration, which adds overhead to torch.load (1-2 seconds in my case). Calling checkpoint = torch.load(weights_fpath, map_location="cpu") instead takes around 1 millisecond. I don’t really understand the point of all this, because when I then call model.load_state_dict(checkpoint["model_state"]), all my tensors end up on cpu anyway, regardless of the device recorded in the model state (presumably because load_state_dict copies the values into the model’s existing parameters, which live on cpu at that point).
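
For reference, here is roughly how I measured the difference (weights_fpath and the timing harness below are a minimal sketch, not my actual code):

```python
import time
import torch

weights_fpath = "model.pt"  # hypothetical path standing in for my checkpoint file

# Deserialize onto the devices recorded in the checkpoint (cuda in my case)
start = time.perf_counter()
checkpoint = torch.load(weights_fpath)
print(f"default load: {time.perf_counter() - start:.4f} s")

# Remap every tensor to cpu while deserializing
start = time.perf_counter()
checkpoint = torch.load(weights_fpath, map_location="cpu")
print(f"cpu load: {time.perf_counter() - start:.4f} s")
```
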
What’s the point of setting the device of the tensors in the model state, and what is the map_location argument of torch.load actually for?
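
My current understanding is that map_location lets you remap storages to a different device at load time, for example when the checkpoint was saved on a GPU that the loading machine doesn’t have (a sketch, file name hypothetical):

```python
import torch

# Remap tensors saved on cuda:1 onto cuda:0
checkpoint = torch.load("model.pt", map_location={"cuda:1": "cuda:0"})

# Or force everything onto cpu, e.g. on a machine without a GPU
checkpoint = torch.load("model.pt", map_location=torch.device("cpu"))
```
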
Is there a way to load cuda tensors directly when doing model.load_state_dict() that would be faster than loading cpu tensors and then calling model.cuda()?
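
In other words, I’m wondering whether something like this (a sketch with a hypothetical stand-in model and checkpoint layout) would beat the cpu round trip:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for my actual model
def make_model():
    return nn.Linear(10, 10)

# Option A: deserialize straight onto the GPU and copy device-to-device
model_a = make_model().cuda()
checkpoint = torch.load("model.pt", map_location="cuda")
model_a.load_state_dict(checkpoint["model_state"])

# Option B: deserialize onto cpu, load into a cpu model, then move it to the GPU
model_b = make_model()
checkpoint = torch.load("model.pt", map_location="cpu")
model_b.load_state_dict(checkpoint["model_state"])
model_b.cuda()
```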