Using pretrained model

Hi all,
I'm wondering how to select a pretrained model based on an argument. For example, I have a function train(model); when the user passes a model name, the function should download the corresponding pretrained model. I know I can use 'if', but there are many pretrained models, so is there a more elegant way to do this?

I tried the following:
m = 'vgg16'
model = models.m(pretrained=True)
but it raises an error.

Yeah, the following should work:

from torchvision import models
m = 'vgg16'
my_model = getattr(models, m)(pretrained=True)
print(my_model)

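The issue with your version is that vgg16 is the name of a function in the models module, while your m is a string. Writing models.m looks for an attribute literally called m, which does not exist; what you were trying to do is something like models.'vgg16', which is not valid Python. getattr(models, m) looks up the attribute whose name is the value of the string and returns the function, which you can then call.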
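If the name comes from user input, it may help to wrap the lookup so that a typo gives a clear error instead of an AttributeError. Below is a minimal sketch of that idea; resolve_factory and the train(model_name) signature are hypothetical names chosen for illustration, not part of torchvision. It only checks that the name is public and callable, so a class name such as VGG would still pass the check and fail later when called.

```python
def resolve_factory(module, name):
    """Return the public callable named `name` in `module`, or raise ValueError."""
    # Reject private and dunder names so user input cannot reach internals.
    if not name or name.startswith("_"):
        raise ValueError(f"invalid model name: {name!r}")
    # getattr with a default returns None instead of raising AttributeError.
    factory = getattr(module, name, None)
    if not callable(factory):
        raise ValueError(f"unknown model name: {name!r}")
    return factory


def train(model_name):
    # Hypothetical entry point; torchvision is imported lazily so the helper
    # above can be used (and tested) without it installed.
    from torchvision import models

    # pretrained=True follows the thread; newer torchvision releases
    # deprecate it in favour of a `weights` argument.
    model = resolve_factory(models, model_name)(pretrained=True)
    # ... the training loop would go here ...
    return model
```

Calling train('vgg16') would then behave like the getattr one-liner above, while train('vgg61') raises a ValueError that names the bad input.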


Got it, thanks a lot, @justusschock.