I’ve trained a model using PyTorch and saved its state dict to a file. I am trying to load the trained model with the code below, but I get an "invalid syntax" error. Do I need to replace **kwargs with something else? Many thanks in advance.
File "load_model_delete.py", line 63
model=VGG((*args, **kwargs))
^
SyntaxError: invalid syntax
I am following the instructions on this page: https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
Many thanks
import argparse
import datetime
import glob
import os
import random
import shutil
import time
from os.path import join
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch.optim as optim
from convnet3 import Convnet
from dataset2 import CellsDataset
from convnet3 import Convnet
from VGG import VGG
from dataset2 import CellsDataset
from torchvision import models
from Conv import Conv2d
# Command-line interface. --weight_decay and --lr are accepted but do not
# appear to be used by the inference code in this file.
parser = argparse.ArgumentParser('Predicting hits from pixels')
parser.add_argument('name', type=str, help='Name of experiment')
parser.add_argument('data_dir', type=str, help='Path to data directory containing images and gt.csv')
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay coefficient (something like 10^-5)')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
# The checkpoint location used to be hard-coded; it is now configurable, with
# the previous path kept as the default so existing invocations still work.
parser.add_argument(
    '--model_path',
    type=str,
    default='/Users/nubstech/Documents/GitHub/CellCountingDirectCount/VGG_model_V1/checkpoints/checkpoint.pth',
    help='Path to the trained state-dict checkpoint to evaluate',
)
args = parser.parse_args()

# Ground-truth table from <data_dir>/gt.csv, indexed by the 'filename' column
# so a count can be looked up from an image's file name.
metadata = pd.read_csv(join(args.data_dir, 'gt.csv'))
metadata.set_index('filename', inplace=True)

# create datasets:
dataset = CellsDataset(args.data_dir, transform=ToTensor(), return_filenames=True)
# NOTE: `dataset` is rebound to the DataLoader wrapping the dataset (default
# batch size). Pinned memory only speeds up host->GPU copies, so request it
# only when CUDA is present; on CPU-only machines PyTorch just warns about it.
dataset = DataLoader(dataset, num_workers=4, pin_memory=torch.cuda.is_available())

# Module-level name kept because the model code reads `model_path`.
model_path = args.model_path
# NOTE(review): this local definition shadows the `VGG` imported from the
# project's VGG module at the top of the file. It is presumably meant to be the
# same architecture the checkpoint was trained with — TODO confirm they match.
class VGG(nn.Module):
    """VGG16 convolutional trunk followed by a small conv head producing a 1-channel map.

    The trunk is ``torchvision.models.vgg16().features[0:23]``. The head is two
    project ``Conv2d`` layers (presumably 512 -> 128 -> 1 channels, kernel size
    1, ``NL='relu'``). After the layers are built, the trained weights of the
    WHOLE network (trunk and head) are restored from a state-dict checkpoint.

    Args:
        pretrained: Initialise the trunk from torchvision's ImageNet weights
            before the checkpoint is loaded. A strict ``load_state_dict``
            overwrites every parameter anyway, so ``False`` only skips the
            download.
        checkpoint_path: State-dict file to load. ``None`` means the
            module-level ``model_path``.
        map_location: Forwarded to ``torch.load`` so a checkpoint saved on a
            GPU can be loaded on a CPU-only machine.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint keys do not
            match this module's ``features4.*`` / ``de_pred.*`` parameters.
    """

    def __init__(self, pretrained=True, checkpoint_path=None, map_location='cpu'):
        super().__init__()
        vgg = models.vgg16(pretrained=pretrained)
        features = list(vgg.features.children())
        self.features4 = nn.Sequential(*features[0:23])
        self.de_pred = nn.Sequential(
            Conv2d(512, 128, 1, same_padding=True, NL='relu'),
            Conv2d(128, 1, 1, same_padding=True, NL='relu'),
        )

        # Load into `self`, not into the torchvision `vgg`. `vgg` has no
        # `de_pred`, and loading before the head exists would leave the head
        # randomly initialised. Done after every layer is created so a strict
        # load covers all parameters.
        # Assumes the file is a plain state_dict saved from this class, e.g.
        # torch.save(model.state_dict(), path). If it is a dict with extra
        # entries (e.g. 'model_state_dict', optimizer state) or keys prefixed
        # with 'module.' (DataParallel), unwrap it first — TODO confirm
        # against the training script.
        if checkpoint_path is None:
            checkpoint_path = model_path
        state_dict = torch.load(checkpoint_path, map_location=map_location)
        self.load_state_dict(state_dict)

    def forward(self, x):
        """Run the trunk then the head and return the single-channel output map."""
        x = self.features4(x)
        return self.de_pred(x)
def main(out_csv='model.csv'):
    """Evaluate the trained VGG counter on every image and append results to a CSV.

    For each batch from the module-level ``dataset`` loader:
    - looks up the ground-truth counts in ``metadata`` by file name;
    - sums the model's output map into one predicted count per image;
    - prints the batch and appends (Image_Name, Target, Prediction) rows to
      ``out_csv``.

    Args:
        out_csv: CSV file to append to. A header row is written only when the
            file does not exist yet, instead of once per batch.
    """
    model = VGG()
    model.eval()

    # Rows are appended across batches (and across runs), so write the header
    # only if the file is new; to_csv(mode='a') would otherwise repeat it.
    write_header = not os.path.exists(out_csv)

    # Inference only: no autograd graph is needed, which also saves memory.
    with torch.no_grad():
        for images, paths in tqdm(dataset):
            names = [os.path.split(path)[-1] for path in paths]
            targets = torch.tensor(
                [metadata.loc[name, 'count'] for name in names],
                dtype=torch.float32,
            )  # B

            output = model(images)  # B x 1 x h x w (analogous to a heatmap)
            preds = output.sum(dim=[1, 2, 3])  # predicted cell counts (vector of length B)
            print(preds)

            names_preds = list(paths)
            print(names_preds)

            # Convert tensors explicitly so pandas receives plain numeric arrays.
            df = pd.DataFrame({
                'Image_Name': names_preds,
                'Target': targets.cpu().numpy(),
                'Prediction': preds.cpu().numpy(),
            })
            print(df)
            # save image name, targets, and predictions
            df.to_csv(out_csv, index=False, mode='a', header=write_header)
            write_header = False


# The DataLoader uses worker processes (num_workers=4). Under the "spawn" start
# method (default on macOS and Windows) each worker re-imports this script, so
# the work must only run when executed directly.
if __name__ == '__main__':
    main()