How to use vision::models polymorphically?

I’m using the torchvision C++ API. Is there a common parent class in the vision::models namespace that allows polymorphism? I can’t find one when I look at the code.

I want to refer to these models through a parent-class pointer (roughly as in the sketch after the list):

vision::models::AlexNet         alexnet;
vision::models::VGG19           vgg19;
vision::models::ResNet18        resnet18;
vision::models::InceptionV3     inceptionv3; // expects 299x299 input
vision::models::MobileNetV2     mobilenet;
vision::models::ResNext50_32x4d resnext50;
vision::models::WideResNet50_2  wide_resnet50;
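
For example, something along these lines is what I’d like to write (a hypothetical sketch; use_alexnet stands in for whatever selection logic, and ptr() returns the holder’s underlying shared_ptr):

std::shared_ptr<torch::nn::Module> model;
if (use_alexnet)
	model = vision::models::AlexNet().ptr();
else
	model = vision::models::VGG19().ptr();
// but torch::nn::Module has no forward(), so nothing useful
// can be called through the base pointer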

Would the ModuleHolder mechanism from LibTorch work?

Best regards

Thomas

I haven’t used torch::nn::ModuleHolder, but isn’t ModuleHolder a class template that depends on the concrete model type? Could you write down the code you have in mind?

Maybe I don’t understand well enough what you want to achieve.

It is not possible to return auto from the code below, because the branches return different types. I would like to solve it with a unique_ptr.

Or is there another way?

auto create_model(const std::string& model_type, int64_t num_classes)
{
	// std::unique_ptr<torch::nn::Module> res;
	
	vision::models::AlexNet         alexnet;
	vision::models::VGG19           vgg19;
	vision::models::ResNet18        resnet18;
	vision::models::InceptionV3     inceptionv3; // expects 299x299 input
	vision::models::MobileNetV2     mobilenet;
	vision::models::ResNext50_32x4d resnext50;
	vision::models::WideResNet50_2  wide_resnet50;

	// find() returns size_t, so ">= 0" is always true; compare to npos instead
	if (model_type.find("alexnet") != std::string::npos) {
		alexnet = vision::models::AlexNet();

		// load imagenet weight
		torch::load(alexnet, model_type);

		// unregister "classifier"
		alexnet->unregister_module("classifier");

		// https://github.com/pytorch/vision/blob/main/torchvision/csrc/models/alexnet.cpp
		alexnet->classifier = torch::nn::Sequential(
			torch::nn::Dropout(),
			torch::nn::Linear(256 * 6 * 6, 4096),
			torch::nn::Functional(torch::relu),
			torch::nn::Dropout(),
			torch::nn::Linear(4096, 4096),
			torch::nn::Functional(torch::relu),
			torch::nn::Linear(4096, num_classes)
		);

		// register "classifier"
		alexnet->register_module("classifier", alexnet->classifier);

		return alexnet;
	}
	else if (model_type.find("vgg") != std::string::npos) {
		vgg19 = vision::models::VGG19();

		// load imagenet weight
		torch::load(vgg19, model_type);

		// unregister "classifier"
		vgg19->unregister_module("classifier");

		// https://github.com/pytorch/vision/blob/main/torchvision/csrc/models/vgg.cpp
		vgg19->classifier = torch::nn::Sequential(
			torch::nn::Linear(512 * 7 * 7, 4096),
			torch::nn::Functional(vision::models::modelsimpl::relu_),
			torch::nn::Dropout(),
			torch::nn::Linear(4096, 4096),
			torch::nn::Functional(vision::models::modelsimpl::relu_),
			torch::nn::Dropout(),
			torch::nn::Linear(4096, num_classes)
		);

		// register "classifier"
		vgg19->register_module("classifier", vgg19->classifier);

		return vgg19;
	}

	// ... remaining models elided; the auto return type cannot be
	// deduced here, since the branches already return different types.
}

Ah, sorry, ModuleHolder isn’t the solution, but maybe AnyModule can help.
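
Something like this might work (an untested sketch; I’m assuming the vision::models constructors take num_classes, that the headers are installed under torchvision/, and that AnyModule accepts the torchvision module holders):

#include <stdexcept>
#include <torch/torch.h>
#include <torchvision/models/models.h>

torch::nn::AnyModule create_model(const std::string& model_type, int64_t num_classes)
{
	if (model_type.find("alexnet") != std::string::npos)
		return torch::nn::AnyModule(vision::models::AlexNet(num_classes));
	if (model_type.find("vgg") != std::string::npos)
		return torch::nn::AnyModule(vision::models::VGG19(num_classes));
	throw std::runtime_error("unknown model type: " + model_type);
}

// AnyModule type-erases the incompatible forward() signatures:
// auto model = create_model("alexnet", 10);
// torch::Tensor out = model.forward(torch::randn({1, 3, 224, 224}));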

Best regards

Thomas

Thank you! I solved it.

Is there a class that supports a similar use case in torch::optim?

I’m not aware of any, but wouldn’t the Optimizer base class work there?
As far as I understand, the trouble with modules comes mainly from the fact that they are not proper subclasses of a common base, because the signatures of forward are incompatible. I would naively expect this to work better with optimizers.
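
For example, a minimal sketch (SGD and Adam are arbitrary choices here, and the params argument would come from whatever model you created):

#include <memory>
#include <torch/torch.h>

std::unique_ptr<torch::optim::Optimizer> create_optimizer(
		const std::string& opt_type, std::vector<torch::Tensor> params, double lr)
{
	if (opt_type == "sgd")
		return std::make_unique<torch::optim::SGD>(params, torch::optim::SGDOptions(lr));
	return std::make_unique<torch::optim::Adam>(params, torch::optim::AdamOptions(lr));
}

// step() and zero_grad() are virtual on torch::optim::Optimizer,
// so they work through the base-class pointer:
// auto optimizer = create_optimizer("sgd", model.ptr()->parameters(), 0.01);
// optimizer->zero_grad();
// ... forward / backward ...
// optimizer->step();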

Best regards

Thomas
