I need to define a generic (templated) class because it may take different activation and convolution layer types. However, when I put the template declaration above the class definition in the header file, TORCH_MODULE reports the error
“argument list for the class template is missing”.
template<typename ActivationType, typename ConvType, typename... ConvArgs>
class ResDecoderBlockImpl: public nn::Module {
private:
BasicConv2d skip_connection;
nn::Upsample up_sample;
nn::Sequential seq{nullptr};
public:
ResDecoderBlockImpl(int, int, int, int, ConvType, ActivationType, ConvArgs...);
at::Tensor forward(const at::Tensor&, const at::Tensor&);
};
TORCH_MODULE(ResDecoderBlock);