Saving weights using pickle or state_dict

Hi, I have a model whose weights I need to save during training (for example, 20 to 50 weight samples, taken every 50 epochs). At test time I load the model, run inference with each of these weight samples, and then average the predictions. But when I use pickle and then load the weights on a different GPU, I get the following error:

Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 2. Please use torch.load with map_location to map your storages to an existing device.

During training, I collect the weight samples like this:

import copy
weight_set_samples = []
# inside the training loop, once per sampling epoch:
weight_set_samples.append(copy.deepcopy(model.state_dict()))
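
As an aside to the collection code above: `copy.deepcopy(model.state_dict())` keeps every tensor on the GPU it currently lives on, and that device is recorded when the list is later serialized. A minimal sketch, assuming the samples are only needed for later inference, is to store CPU copies instead. The helper name `cpu_state_dict_copy`, the tiny `nn.Linear` model and the sampling loop are hypothetical stand-ins for the real training code.

```python
import torch
import torch.nn as nn


def cpu_state_dict_copy(model: nn.Module) -> dict:
    """Return a copy of model.state_dict() with every tensor on the CPU."""
    # detach() drops autograd history, cpu() moves off any GPU,
    # clone() ensures the copy never shares memory with the live model.
    return {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}


# Hypothetical stand-in for the real model and training loop.
model = nn.Linear(4, 2)
weight_set_samples = []
for epoch in range(1, 151):
    # ... one epoch of training would go here ...
    if epoch % 50 == 0:  # e.g. take one sample every 50 epochs
        weight_set_samples.append(cpu_state_dict_copy(model))
```

Samples stored this way carry no GPU index, so loading them later does not depend on which GPUs the test machine has.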

When training is finished, I save the collected weights with:

with open(model_dir + '/state_dicts.pkl', 'wb') as f:
    pickle.dump(net.weight_set_samples, f, pickle.HIGHEST_PROTOCOL)

But when I load the weights with the following code:

with open(model_dir + '/state_dicts.pkl', 'rb') as weights:
    weight_set_samples = pickle.load(weights)

I run into this problem whenever the GPU used at test time is not the same as the one used during training.
I read the pickle documentation, but there is nothing like map_location in pickle that would solve this.
I would appreciate any ideas for solving this problem, since the training runs took a long time and repeating them would be very time-consuming.
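
For files that were already written with plain pickle, one possible workaround that avoids retraining is a custom `pickle.Unpickler` whose `find_class` redirects PyTorch's storage-rebuilding hook to `torch.load` with `map_location="cpu"`. This sketch relies on a private PyTorch detail, namely that pickled tensor storages are restored through `torch.storage._load_from_bytes`; that helper is internal and may change between versions, so treat this as a one-off recovery tool rather than a storage format. The names `CPUUnpickler` and `load_pickled_state_dicts` are hypothetical.

```python
import io
import pickle

import torch


class CPUUnpickler(pickle.Unpickler):
    """Unpickler that restores every pickled torch storage onto the CPU.

    Assumption: PyTorch pickles each storage as a call to the private helper
    torch.storage._load_from_bytes, which runs torch.load without map_location;
    that call is what raises the device error on a machine with fewer GPUs.
    """

    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            # Replacement hook: the same torch.load call, plus map_location="cpu".
            # weights_only=False matches the original hook (needs PyTorch >= 1.13).
            return lambda b: torch.load(io.BytesIO(b), map_location="cpu", weights_only=False)
        return super().find_class(module, name)


def load_pickled_state_dicts(path):
    """Read a list of state_dicts written with plain pickle.dump, onto the CPU."""
    with open(path, "rb") as f:
        return CPUUnpickler(f).load()


# CPU-only round trip, just to exercise the code path without a GPU.
buf = io.BytesIO()
pickle.dump([{"weight": torch.ones(2, 3)}], buf, pickle.HIGHEST_PROTOCOL)
buf.seek(0)
restored = CPUUnpickler(buf).load()
```

On the test machine, `load_pickled_state_dicts(model_dir + '/state_dicts.pkl')` would then return the samples on the CPU; each one can be passed straight to `model.load_state_dict`, which copies the values into the model's existing parameters on whatever device they are on.
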

My second question: how can I do this without pickle, so that I don't run into this problem at all? I want to save the weights as an array or list so that it is easy to load them into the model and run inference quickly.
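
One pickle-free option for the second question, sketched under the assumption that every sample contains only tensors with NumPy-compatible dtypes (true for typical float32 state_dicts, but not for e.g. bfloat16), is to convert each tensor to a NumPy array on the CPU and store all samples in a single `.npz` archive. Loading with `allow_pickle=False` guarantees that no pickled objects are involved, and plain arrays carry no device information. The helper names and the `"<sample index>.<parameter name>"` key scheme are illustrative, not a standard format.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


def save_samples_npz(samples, path):
    """Store a list of state_dicts as plain arrays (no pickle) in one .npz file."""
    arrays = {}
    for i, sd in enumerate(samples):
        for name, t in sd.items():
            # Illustrative key scheme: "<sample index>.<parameter name>".
            arrays[f"{i}.{name}"] = t.detach().cpu().numpy()
    np.savez(path, **arrays)


def load_samples_npz(path, device="cpu"):
    """Rebuild the list of state_dicts, placing every tensor on `device`."""
    samples = {}
    with np.load(path, allow_pickle=False) as data:
        for key in data.files:
            idx, name = key.split(".", 1)  # the index itself never contains a dot
            samples.setdefault(int(idx), {})[name] = torch.from_numpy(data[key]).to(device)
    return [samples[i] for i in sorted(samples)]


# Usage with a hypothetical small model and two distinct weight samples.
model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
samples = [{k: v + i for k, v in model.state_dict().items()} for i in range(2)]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state_dicts.npz")
    save_samples_npz(samples, path)
    loaded = load_samples_npz(path)

model.load_state_dict(loaded[1])
```

The trade-off is that only tensors survive the round trip; anything else stored alongside the weights would need its own format.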

The error is raised by PyTorch's serialization module, and the map_location argument can be passed to torch.load. I don't know how pickle can be used directly to avoid this error, but I would recommend sticking to torch.save/torch.load and specifying the pickle_module argument if necessary.
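
Following that suggestion, a minimal sketch of the whole round trip with `torch.save`/`torch.load`: the list of state_dicts is saved once, loaded with `map_location` pointing at whichever device exists at test time, and each sample is loaded into the model in turn so its predictions can be averaged. The small `nn.Linear` model, the file name `state_dicts.pt` and the helper `averaged_prediction` are placeholders for the real code, and averaging the raw outputs is an assumption.

```python
import os
import tempfile

import torch
import torch.nn as nn


def averaged_prediction(model, weight_set_samples, x):
    """Run x through the model once per weight sample and average the outputs."""
    model.eval()
    preds = []
    with torch.no_grad():
        for sd in weight_set_samples:
            model.load_state_dict(sd)  # copies values into the existing parameters
            preds.append(model(x))
    return torch.stack(preds).mean(dim=0)


# Whatever device exists at test time, independent of the GPU used in training.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the trained model and its collected samples.
model = nn.Linear(4, 2)
base = model.state_dict()
weight_set_samples = [{k: v + i for k, v in base.items()} for i in range(3)]

with tempfile.TemporaryDirectory() as model_dir:
    path = os.path.join(model_dir, "state_dicts.pt")
    torch.save(weight_set_samples, path)
    # map_location moves every stored tensor onto `device` while loading.
    loaded = torch.load(path, map_location=device)

model = model.to(device)
x = torch.randn(5, 4, device=device)
y = averaged_prediction(model, loaded, x)
```

`map_location` also accepts strings such as `'cpu'` or `'cuda:0'`. For a classifier, averaging softmax probabilities instead of raw outputs is a common alternative.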

Ok, thank you for your answer.