I'm using the torchvision C++ API. Is there a way to use polymorphism through a common parent class for the models in the vision::models namespace? I can't find one when I look at the code.
I want to refer to each of these models through a single parent-class pointer:
```cpp
vision::models::AlexNet alexnet;
vision::models::VGG19 vgg19;
vision::models::ResNet18 resnet18;
vision::models::InceptionV3 inceptionv3; // expects 299x299 input
vision::models::MobileNetV2 mobilenet;
vision::models::ResNext50_32x4d resnext50;
vision::models::WideResNet50_2 wide_resnet50;
```
tom
(Thomas V)
December 8, 2021, 7:41am
Would the ModuleHolder mechanism from LibTorch work?
Best regards
Thomas
I haven't used torch::nn::ModuleHolder, but doesn't the torch::nn::ModuleHolder class take a template parameter (i.e. depend on the concrete model type)? Could you write down the code you have in mind?
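In LibTorch, torch::nn::ModuleHolder is indeed a class template (ModuleHolder<Contained>), so vision::models::AlexNet and vision::models::VGG19 are unrelated types. What they do share is that each one wraps a std::shared_ptr to an Impl class deriving from torch::nn::Module, reachable via ptr(). The sketch below illustrates that relationship with simplified stand-ins; Module, AlexNetImpl, VGG19Impl, ModuleHolder and collect_models here are hypothetical stand-ins written so the example compiles without LibTorch, not the real classes.

```cpp
#include <memory>
#include <string>
#include <vector>

// Stand-in for torch::nn::Module, the common base of every *Impl class.
struct Module {
  virtual ~Module() = default;
  virtual std::string type_name() const = 0;
};

// Hypothetical stand-ins for two torchvision implementation classes.
struct AlexNetImpl : Module {
  std::string type_name() const override { return "AlexNetImpl"; }
};
struct VGG19Impl : Module {
  std::string type_name() const override { return "VGG19Impl"; }
};

// Simplified stand-in for torch::nn::ModuleHolder<Contained>: a thin
// wrapper owning a std::shared_ptr to one concrete Impl type.
template <typename Contained>
class ModuleHolder {
 public:
  ModuleHolder() : impl_(std::make_shared<Contained>()) {}
  Contained* operator->() const { return impl_.get(); }
  const std::shared_ptr<Contained>& ptr() const { return impl_; }

 private:
  std::shared_ptr<Contained> impl_;
};

// Two distinct instantiations: they share no common base class.
using AlexNet = ModuleHolder<AlexNetImpl>;
using VGG19 = ModuleHolder<VGG19Impl>;

// The holders cannot share one container, but their Impl pointers can,
// because std::shared_ptr<Derived> converts to std::shared_ptr<Module>.
std::vector<std::shared_ptr<Module>> collect_models() {
  AlexNet alexnet;
  VGG19 vgg19;
  return {alexnet.ptr(), vgg19.ptr()};
}
```

The catch, discussed later in the thread, is that the real torch::nn::Module declares no virtual forward(), so a base pointer obtained this way gives access to things like parameters() but not to the model's forward pass.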
tom
(Thomas V)
December 8, 2021, 8:10am
Maybe I don’t understand well enough what you want to achieve.
The problem is that I can't use auto as the return type of the function below, because each branch returns a different type. I would like to solve it with a std::unique_ptr to a common base class, or some other way if there is one:
```cpp
// Does not compile as intended: with a deduced `auto` return type, every
// return statement must yield the same type, but each branch returns a
// different model holder (AlexNet, VGG19, ...).
auto create_model(const std::string& model_type, int64_t num_classes)
{
    // std::unique_ptr<torch::nn::Module> res;

    // std::string::find returns std::string::npos (an unsigned value) when
    // there is no match, so `find(...) >= 0` would always be true.
    if (model_type.find("alexnet") != std::string::npos) {
        vision::models::AlexNet alexnet;
        // load ImageNet weights (model_type doubles as the weight file path)
        torch::load(alexnet, model_type);
        // unregister the old "classifier" head before replacing it
        alexnet->unregister_module("classifier");
        // https://github.com/pytorch/vision/blob/main/torchvision/csrc/models/alexnet.cpp
        alexnet->classifier = torch::nn::Sequential(
            torch::nn::Dropout(),
            torch::nn::Linear(256 * 6 * 6, 4096),
            torch::nn::Functional(torch::relu),
            torch::nn::Dropout(),
            torch::nn::Linear(4096, 4096),
            torch::nn::Functional(torch::relu),
            torch::nn::Linear(4096, num_classes)
        );
        // register the new "classifier" head
        alexnet->register_module("classifier", alexnet->classifier);
        return alexnet;
    }
    else if (model_type.find("vgg") != std::string::npos) {
        vision::models::VGG19 vgg19;
        // load ImageNet weights
        torch::load(vgg19, model_type);
        // unregister the old "classifier" head before replacing it
        vgg19->unregister_module("classifier");
        // https://github.com/pytorch/vision/blob/main/torchvision/csrc/models/vgg.cpp
        vgg19->classifier = torch::nn::Sequential(
            torch::nn::Linear(512 * 7 * 7, 4096),
            torch::nn::Functional(vision::models::modelsimpl::relu_),
            torch::nn::Dropout(),
            torch::nn::Linear(4096, 4096),
            torch::nn::Functional(vision::models::modelsimpl::relu_),
            torch::nn::Dropout(),
            torch::nn::Linear(4096, num_classes)  // input is 4096, not 512
        );
        // register the new "classifier" head
        vgg19->register_module("classifier", vgg19->classifier);
        return vgg19;
    }
    // (remaining model types omitted in the original post)
}
```
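One common way around the auto problem is to define your own small interface that declares a virtual forward(), wrap each model in a template adapter, and have the factory return std::unique_ptr to that interface. The sketch below is one possible approach under simplifying assumptions: Tensor is a std::vector<float> stand-in for torch::Tensor, and AlexNetImpl, VGG19Impl, Classifier and ClassifierWrapper are hypothetical stand-ins, not torchvision classes. With LibTorch, Model would be something like vision::models::AlexNetImpl (obtained with alexnet.ptr()), and the wrapper could also keep that pointer as a std::shared_ptr<torch::nn::Module> to reach parameters().

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for torch::Tensor

// Hypothetical stand-ins for two model Impl classes. Neither overrides
// anything: each simply has its own non-virtual forward().
struct AlexNetImpl {
  explicit AlexNetImpl(int num_classes) : num_classes(num_classes) {}
  Tensor forward(const Tensor&) const { return Tensor(num_classes, 1.0f); }
  int num_classes;
};
struct VGG19Impl {
  explicit VGG19Impl(int num_classes) : num_classes(num_classes) {}
  Tensor forward(const Tensor&) const { return Tensor(num_classes, 2.0f); }
  int num_classes;
};

// Our own interface: this is where the virtual forward() lives.
struct Classifier {
  virtual ~Classifier() = default;
  virtual Tensor forward(const Tensor& x) = 0;
};

// One wrapper template adapts any model whose forward(Tensor) returns Tensor.
template <typename Model>
struct ClassifierWrapper : Classifier {
  explicit ClassifierWrapper(std::shared_ptr<Model> m) : model(std::move(m)) {}
  Tensor forward(const Tensor& x) override { return model->forward(x); }
  std::shared_ptr<Model> model;
};

// Every branch now returns the same type, so no deduced `auto` is needed.
std::unique_ptr<Classifier> create_model(const std::string& model_type,
                                         int num_classes) {
  // find() returns std::string::npos when there is no match.
  if (model_type.find("alexnet") != std::string::npos) {
    return std::make_unique<ClassifierWrapper<AlexNetImpl>>(
        std::make_shared<AlexNetImpl>(num_classes));
  }
  if (model_type.find("vgg") != std::string::npos) {
    return std::make_unique<ClassifierWrapper<VGG19Impl>>(
        std::make_shared<VGG19Impl>(num_classes));
  }
  throw std::invalid_argument("unknown model type: " + model_type);
}
```

The cost of this design is one extra virtual call per forward pass, which is negligible next to the model itself; the benefit is that callers only ever see Classifier.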
tom
(Thomas V)
December 8, 2021, 9:12am
Ah, sorry, ModuleHolder isn't the solution, but maybe AnyModule can help.
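torch::nn::AnyModule is a type-erased container: with the real class one would write something like torch::nn::AnyModule model(vision::models::ResNet18()); and then call model.forward(input), while model.ptr() gives a std::shared_ptr<torch::nn::Module>. The sketch below only illustrates the type-erasure idea behind it and is not AnyModule's actual implementation. Tensor, Module, ResNet18Impl, MobileNetV2Impl and AnyModel are hypothetical stand-ins, and, unlike the real class, this sketch fixes the forward signature to Tensor(const Tensor&).

```cpp
#include <functional>
#include <memory>
#include <utility>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for torch::Tensor

// Stand-in for torch::nn::Module: common base, but no virtual forward().
struct Module {
  virtual ~Module() = default;
};

// Hypothetical model implementations with unrelated, non-virtual forward().
struct ResNet18Impl : Module {
  Tensor forward(const Tensor& x) const { return Tensor(x.size(), 0.5f); }
};
struct MobileNetV2Impl : Module {
  Tensor forward(const Tensor& x) const { return Tensor(x.size() * 2, 0.25f); }
};

// Type-erased holder in the spirit of torch::nn::AnyModule: the templated
// constructor records how to call forward() on the concrete type, while
// ptr() still exposes the common Module base (e.g. for parameters()).
class AnyModel {
 public:
  template <typename Impl>
  explicit AnyModel(std::shared_ptr<Impl> impl)
      : module_(impl),
        forward_([impl](const Tensor& x) { return impl->forward(x); }) {}

  Tensor forward(const Tensor& x) const { return forward_(x); }
  const std::shared_ptr<Module>& ptr() const { return module_; }

 private:
  std::shared_ptr<Module> module_;
  std::function<Tensor(const Tensor&)> forward_;
};
```

Because every AnyModel has the same type, a std::vector<AnyModel> or a factory returning AnyModel sidesteps the auto problem from the earlier post.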
Best regards
Thomas
Is there a similar class for the torch::optim case, i.e. a way to hold different optimizers through one type?
tom
(Thomas V)
December 8, 2021, 11:40am
I'm not aware of any, but wouldn't the Optimizer base class work there?
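Unlike modules, LibTorch optimizers share a virtual step() in torch::optim::Optimizer, so a base pointer is enough, e.g. std::unique_ptr<torch::optim::Optimizer> opt = std::make_unique<torch::optim::SGD>(model->parameters(), torch::optim::SGDOptions(0.01)); followed by opt->zero_grad() and opt->step(). The self-contained sketch below mirrors that pattern with hypothetical stand-ins (Param, Optimizer, SGD, MomentumSGD, make_optimizer) so it compiles without LibTorch; the update rules are plain SGD and classic momentum (v = mu * v + g; p -= lr * v).

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for a parameter tensor: a scalar value plus its gradient.
struct Param {
  double value;
  double grad;
};

// Stand-in for torch::optim::Optimizer: step() is virtual with the same
// signature in every subclass, so a base-class pointer suffices.
class Optimizer {
 public:
  explicit Optimizer(std::vector<Param>* params) : params_(params) {}
  virtual ~Optimizer() = default;
  virtual void step() = 0;
  void zero_grad() {
    for (auto& p : *params_) p.grad = 0.0;
  }

 protected:
  std::vector<Param>* params_;  // not owned
};

// Plain SGD: p -= lr * g.
class SGD : public Optimizer {
 public:
  SGD(std::vector<Param>* params, double lr) : Optimizer(params), lr_(lr) {}
  void step() override {
    for (auto& p : *params_) p.value -= lr_ * p.grad;
  }

 private:
  double lr_;
};

// Hypothetical SGD with momentum: v = mu * v + g; p -= lr * v.
class MomentumSGD : public Optimizer {
 public:
  MomentumSGD(std::vector<Param>* params, double lr, double mu)
      : Optimizer(params), lr_(lr), mu_(mu), velocity_(params->size(), 0.0) {}
  void step() override {
    for (std::size_t i = 0; i < params_->size(); ++i) {
      Param& p = (*params_)[i];
      velocity_[i] = mu_ * velocity_[i] + p.grad;
      p.value -= lr_ * velocity_[i];
    }
  }

 private:
  double lr_, mu_;
  std::vector<double> velocity_;
};

// Factory choosing the optimizer at run time and returning the base class.
std::unique_ptr<Optimizer> make_optimizer(const std::string& name,
                                          std::vector<Param>* params,
                                          double lr) {
  if (name == "sgd") return std::make_unique<SGD>(params, lr);
  if (name == "momentum") return std::make_unique<MomentumSGD>(params, lr, 0.9);
  throw std::invalid_argument("unknown optimizer: " + name);
}
```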
As far as I understand, the trouble with modules comes mainly from the fact that they are not proper subclasses: the signatures of forward are incompatible between models. I would naively expect this to work better with optimizers.
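To make the signature problem concrete: as far as I know, torchvision's C++ InceptionV3 returns a small struct with main and auxiliary logits from forward() rather than a bare tensor, so no single virtual forward() on a base class could be overridden by every model. The sketch below shows one workaround under that assumption: overloads that normalize each return type to plain logits, plus a template that builds a uniform callable. ResNet18Impl, InceptionOutput, InceptionV3Impl, to_logits and uniform_forward are hypothetical stand-ins, and Tensor stands in for torch::Tensor.

```cpp
#include <functional>
#include <memory>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for torch::Tensor

// Hypothetical stand-ins whose forward() return types differ, which is why
// a base class cannot declare one virtual forward() that both override.
struct ResNet18Impl {
  Tensor forward(const Tensor& x) const { return Tensor(x.size(), 1.0f); }
};
struct InceptionOutput {  // main logits plus auxiliary logits
  Tensor output;
  Tensor aux;
};
struct InceptionV3Impl {
  InceptionOutput forward(const Tensor& x) const {
    return {Tensor(x.size(), 3.0f), Tensor(x.size(), 4.0f)};
  }
};

// Overloads that normalize each return type to plain logits.
inline Tensor to_logits(const Tensor& t) { return t; }
inline Tensor to_logits(const InceptionOutput& o) { return o.output; }

// Build a uniform Tensor(const Tensor&) callable for any model type.
template <typename Impl>
std::function<Tensor(const Tensor&)> uniform_forward(std::shared_ptr<Impl> impl) {
  return [impl](const Tensor& x) { return to_logits(impl->forward(x)); };
}
```

Overload resolution picks the right to_logits at compile time for each model, so adding a model with yet another return type only needs one more overload.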
Best regards
Thomas