My model has its tensors on cuda, and its state is saved as such. When I load the model state for inference with `checkpoint = torch.load(weights_fpath)`, the loaded state keeps the same device configuration, which adds overhead to `torch.load` (1-2 seconds in my case). Calling `checkpoint = torch.load(weights_fpath, "cpu")` takes around 1 millisecond, however. I don't really understand the point of this behavior, because when I then call `model.load_state_dict(checkpoint["model_state"])`, all my tensors end up on cpu anyway, regardless of the device specified in the model state.
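To make the difference concrete, here is roughly what I'm measuring (`weights_fpath`, `model`, and the `"model_state"` key are from my own setup; the timings are from my machine):

```python
import time

import torch

# Default load: tensors are deserialized onto the device they were
# saved from (cuda here), which is the slow path for me (~1-2 s).
start = time.perf_counter()
checkpoint = torch.load(weights_fpath)
print(f"default load: {time.perf_counter() - start:.3f}s")

# Remap every tensor to cpu during deserialization: ~1 ms for me.
start = time.perf_counter()
checkpoint = torch.load(weights_fpath, map_location="cpu")
print(f"cpu load: {time.perf_counter() - start:.3f}s")

# Either way, load_state_dict copies the values into the model's own
# parameters, so they end up on whatever device the model is on.
model.load_state_dict(checkpoint["model_state"])
```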
- What's the point of setting the device of the tensors in the model state, and what's the use of the `map_location` argument of `torch.load`?
- Is there a way to load cuda tensors directly when calling `model.load_state_dict()` that would be faster than loading cpu tensors and then calling `model.cuda()`? (See the sketch below.)
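For reference, a minimal sketch of the two paths I'm comparing (`MyModel` is a stand-in for my actual model class, and `weights_fpath` comes from my setup):

```python
import torch

# Path 1: deserialize straight to cpu, then move the whole model to cuda.
model = MyModel()
checkpoint = torch.load(weights_fpath, map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
model.cuda()

# Path 2 (what I'm asking about): put the model on cuda first and
# deserialize the tensors directly onto the gpu, so that load_state_dict
# could copy device-to-device without a cpu round trip.
model = MyModel().cuda()
checkpoint = torch.load(weights_fpath, map_location="cuda")
model.load_state_dict(checkpoint["model_state"])
```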