Download model .pth files saved in Google Drive to a different environment

I used torch.save() to save both the model and its weights (the state_dict), each to its own .pth file. Those .pth files are stored in Google Drive, and I want to load them in a different environment. How should I do this?

I’m trying to use PyDrive2 to download the files to another environment. This is my code for downloading from the Google Drive where the files are saved:

```python
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

gauth = GoogleAuth()
gauth.LocalWebserverAuth()

drive = GoogleDrive(gauth)

model = drive.CreateFile({'id': '1MbDv8hwtmWS_oV90El_4C9IKypG1yGLS'})
model.GetContentFile('MaskInstanceModel.pth') # Download file as 'MaskInstanceModel.pth'.

modelweights = drive.CreateFile({'id': '1EKaPhcaMmCQVhzCaH0jlaqWscf-JbSKO'})
model.GetContentFile('MaskModelParams.pth')
```
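A minimal sketch of the same download step that keeps each Drive file ID paired with its local filename in a single dict and calls GetContentFile() on the file object created for that ID, so one file's contents cannot end up saved under another file's name. The helper name download_drive_files is hypothetical, and drive is assumed to be an authenticated pydrive2 GoogleDrive instance like the one above; only CreateFile() and GetContentFile(), as used in the question, are called.

```python
def download_drive_files(drive, files):
    """Download Google Drive files to local paths (hypothetical helper).

    drive: an authenticated pydrive2 GoogleDrive instance (assumed).
    files: dict mapping Drive file IDs to local filenames.
    Returns the list of local filenames, in download order.
    """
    for file_id, local_name in files.items():
        # Create the file object for this specific ID...
        drive_file = drive.CreateFile({'id': file_id})
        # ...and download through that same object, so the content
        # written to local_name always comes from file_id.
        drive_file.GetContentFile(local_name)
    return list(files.values())
```

With the question's setup, drive is the GoogleDrive(gauth) object and files maps the two file IDs to 'MaskInstanceModel.pth' and 'MaskModelParams.pth' respectively.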

The files seem to download fine. They’re the right size and there is no message or warning about an error.

In the script that uses the model and weights, this is the snippet of code:

```python
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.io import read_image
from torchvision.transforms.functional import convert_image_dtype
import torchvision.transforms.functional as F

...
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the model from the saved file:
model = torch.load('MaskInstanceModel.pth', map_location=device)

model.load_state_dict(torch.load('MaskModelParams.pth', map_location=device))
```
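Before calling load_state_dict(), it can help to check what each file actually deserializes to. A minimal sketch, assuming the two files were written with torch.save(model, ...) and torch.save(model.state_dict(), ...) respectively: a state_dict loads back as a dict-like OrderedDict of tensors, while a whole model loads back as the nn.Module subclass itself (here MaskRCNN). The helper names checkpoint_kind and inspect_checkpoint are hypothetical. On PyTorch 2.6 and later, torch.load defaults to weights_only=True, which refuses to unpickle a full model, so the sketch passes weights_only=False explicitly; that argument needs PyTorch 1.13 or newer and should only be used on files you trust.

```python
from collections.abc import Mapping


def checkpoint_kind(obj):
    """Classify an object returned by torch.load() (hypothetical helper)."""
    # torch.save(model.state_dict(), path) reloads as an OrderedDict,
    # i.e. a Mapping from parameter/buffer names to tensors.
    if isinstance(obj, Mapping):
        return "state_dict"
    # torch.save(model, path) reloads as the pickled module itself,
    # so report its class name (e.g. "MaskRCNN").
    return type(obj).__name__


def inspect_checkpoint(path, device="cpu"):
    """Load a .pth file and report what kind of object it holds."""
    # Imported here so checkpoint_kind() can be used without PyTorch.
    import torch

    # weights_only=False is needed to unpickle a full nn.Module on
    # PyTorch >= 2.6 (argument exists since 1.13); trusted files only.
    obj = torch.load(path, map_location=device, weights_only=False)
    return checkpoint_kind(obj)
```

For a weights file, inspect_checkpoint('MaskModelParams.pth') should return "state_dict"; if it returns "MaskRCNN", the file holds the whole model rather than its parameters.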

But I get this error:

```text
bash-4.3$ python3 model.py
Traceback (most recent call last):
  File "/home/john/imseg/src/model.py", line 92, in <module>
    model.load_state_dict(torch.load('MaskModelParams.pth', map_location=device))
  File "/home/john/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1379, in load_state_dict
    state_dict = state_dict.copy()
  File "/home/john/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'MaskRCNN' object has no attribute 'copy'
```

I guess this has something to do with how PyDrive downloaded the file from Google Drive. I’m not married to PyDrive; it just seemed like the easiest way to download files from a Google Drive. Is there a better way to move the model files around?

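Based on the error message, it seems that

```python
torch.load('MaskModelParams.pth', map_location=device)
```

returns a MaskRCNN object rather than a state_dict, so load_state_dict() fails as soon as it tries to call .copy() on it. Double-check that MaskModelParams.pth actually contains the state_dict. In the download script, the second file is fetched with model.GetContentFile('MaskModelParams.pth') instead of modelweights.GetContentFile('MaskModelParams.pth'), so the full-model file is downloaded a second time under the weights filename. That would produce exactly this error.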