Hi all,
I am new to PyTorch. As the title says, my question is: how can I find out which function a Functional module wraps in the C++ API? For example, suppose my network has two activation modules:
torch::nn::Functional activ1(torch::relu);
torch::nn::Functional activ2(torch::tanh);
I then pass each of them to another function, say void print_name(torch::nn::Module &m), which prints the name of the module it receives. How can I tell inside that function that activ1 is ReLU and activ2 is Tanh? I tried calling name(), but both modules only return “torch::nn::Functional”.
BTW: Is there any way to override their module names? If so, that would also indirectly solve my problem.
Thanks in advance!
Hi all,
I found a crude but workable solution: write two new activation module classes that inherit from torch::nn::FunctionalImpl, and set the module name (name_) in each constructor by passing it to the torch::nn::Module constructor. The code looks like this:
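#include <torch/torch.h>

// Functional module that wraps torch::tanh and reports a custom name.
struct ActivationTanhImpl : torch::nn::FunctionalImpl {
  // This initializer is only legal if Module is a virtual base of
  // FunctionalImpl; virtual bases are initialized first, so Module is listed first.
  ActivationTanhImpl()
      : torch::nn::Module("torch::nn::TanhImpl"),
        torch::nn::FunctionalImpl(torch::tanh) {}
};
TORCH_MODULE(ActivationTanh);

// Same pattern for torch::relu.
struct ActivationReLUImpl : torch::nn::FunctionalImpl {
  ActivationReLUImpl()
      : torch::nn::Module("torch::nn::ReLUImpl"),
        torch::nn::FunctionalImpl(torch::relu) {}
};
TORCH_MODULE(ActivationReLU);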
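In case it helps anyone reading later, here is my understanding of why this works, as a small sketch that does not use libtorch at all. A subclass can only name an indirect base in its initializer list when that base is a virtual base, so the code above compiling suggests that torch::nn::Module is a virtual base of FunctionalImpl in the version I am using (I have not checked other releases). In that case the most-derived class initializes Module, and its name replaces whatever the intermediate class would have set. Everything in the sketch (the sketch namespace, the stand-in Module and FunctionalImpl, and describe) is made up for illustration and is not the real libtorch API.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

namespace sketch {

// Stand-in for torch::nn::Module (hypothetical): it only stores a name.
class Module {
 public:
  Module() : name_("Module") {}
  explicit Module(std::string name) : name_(std::move(name)) {}
  virtual ~Module() = default;
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

// Stand-in for torch::nn::FunctionalImpl (hypothetical). Module is a
// *virtual* base here, which is the assumption the workaround relies on.
class FunctionalImpl : public virtual Module {
 public:
  using Function = std::function<double(double)>;
  explicit FunctionalImpl(Function fn)
      : Module("Functional"), fn_(std::move(fn)) {}
  double forward(double x) const { return fn_(x); }

 private:
  Function fn_;
};

// The most-derived class initializes the virtual base, so the name given
// here is used and the Module("Functional") initializer above is skipped.
struct ActivationReLUImpl : FunctionalImpl {
  ActivationReLUImpl()
      : Module("ReLU"),
        FunctionalImpl([](double x) { return x > 0.0 ? x : 0.0; }) {}
};

// Plays the role of print_name: it only sees the base-class interface.
inline std::string describe(const Module& m) { return m.name(); }

}  // namespace sketch
```

With these stand-ins, describe() on a plain sketch::FunctionalImpl reports the name chosen by FunctionalImpl itself, while describe() on a sketch::ActivationReLUImpl reports the name chosen by the subclass, even though both are passed as a Module reference. If Module were an ordinary (non-virtual) base, the Module("ReLU") initializer in the subclass would not compile.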
I know this is not an elegant solution, but it is the only workable one I have come up with so far. Please let me know if you have any suggestions. Thanks again!