In the C++ API, how do I get the function type of a Functional module?

Hi all,

I am new to PyTorch. As the title says, my question is: how can I get the function type of a Functional module in the C++ API? For example, suppose my network has two activation modules:

torch::nn::Functional activ1(torch::relu);
torch::nn::Functional activ2(torch::tanh);

I then pass each of them to another procedure, say void print_name(torch::nn::Module &m), which prints the name of the module it receives. How can that procedure tell that activ1 is a ReLU and activ2 is a Tanh? I tried calling name(), but both of them just return “torch::nn::Functional”.

BTW: is there any way to override their module names? That would also indirectly solve my problem.

Thanks in advance!

Hi all,

I found a clumsy but workable solution: write two new activation module classes that inherit from torch::nn::FunctionalImpl, and set their module name in the constructor by passing it to the torch::nn::Module base. The code looks like this:

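#include <torch/torch.h>

// Module is initialized directly here, which C++ only allows because it is a
// virtual base of FunctionalImpl. Virtual bases are constructed first, so it
// is listed first to match the real initialization order (avoids -Wreorder).
struct ActivationTanhImpl : torch::nn::FunctionalImpl {
	ActivationTanhImpl() :
		torch::nn::Module("torch::nn::TanhImpl"),
		torch::nn::FunctionalImpl(torch::tanh) {}
};
TORCH_MODULE(ActivationTanh);

struct ActivationReLUImpl : torch::nn::FunctionalImpl {
	ActivationReLUImpl() :
		torch::nn::Module("torch::nn::ReLUImpl"),
		torch::nn::FunctionalImpl(torch::relu) {}
};
TORCH_MODULE(ActivationReLU);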
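Why the torch::nn::Module("...") initializer has any effect is a C++ detail worth spelling out. The sketch below is plain C++17 and does not use libtorch: Module, Cloneable and FunctionalImpl here are hypothetical stand-ins named after the real classes, and it assumes (as the fact that the code above compiles suggests) that Module is a virtual base of FunctionalImpl. For a virtual base, only the most-derived class's initializer runs, so ActivationTanhImpl gets to pick the name; when no name is given, name() falls back to the object's type, which is why every plain Functional reports the same thing regardless of the function it wraps.

```cpp
#include <cmath>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <typeinfo>
#include <utility>

// Hypothetical stand-in for torch::nn::Module. An explicit name wins; otherwise
// name() falls back to the dynamic type (libtorch demangles it, this sketch
// just returns the raw, compiler-specific typeid name).
struct Module {
  Module() = default;
  explicit Module(std::string name) : name_(std::move(name)) {}
  virtual ~Module() = default;

  std::string name() const {
    return name_ ? *name_ : std::string(typeid(*this).name());
  }

 private:
  std::optional<std::string> name_;
};

// Hypothetical stand-in for Cloneable<...>: the key detail is *virtual* inheritance.
struct Cloneable : virtual Module {};

// Hypothetical stand-in for FunctionalImpl: it wraps a callable but never
// names itself, so every instance reports the same type-based name.
struct FunctionalImpl : Cloneable {
  explicit FunctionalImpl(std::function<double(double)> f) : fn(std::move(f)) {}
  std::function<double(double)> fn;
};

// Mirrors ActivationTanhImpl: the most-derived class initializes the virtual
// base Module directly, and that initializer is the one that takes effect.
struct ActivationTanhImpl : FunctionalImpl {
  ActivationTanhImpl()
      : Module("torch::nn::TanhImpl"),
        FunctionalImpl([](double x) { return std::tanh(x); }) {}
};

// Mirrors ActivationReLUImpl.
struct ActivationReLUImpl : FunctionalImpl {
  ActivationReLUImpl()
      : Module("torch::nn::ReLUImpl"),
        FunctionalImpl([](double x) { return x > 0.0 ? x : 0.0; }) {}
};

// Same shape as the print_name(torch::nn::Module&) procedure from the question.
void print_name(const Module& m) { std::cout << m.name() << '\n'; }
```

Calling print_name on an ActivationTanhImpl prints torch::nn::TanhImpl even though it only sees a Module reference. With the real libtorch classes, print_name could alternatively check the concrete type instead of the name, e.g. m.as<ActivationReLU>() returns a non-null pointer only when m really is an ActivationReLUImpl.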

I know this is not an elegant solution, but it is the only workable one I could come up with. Please let me know if you have any suggestions. Thanks again!