How to reimplement a regular linear layer using a custom module and function in PyTorch C++ (libtorch)

I’m trying to implement a custom module that calls a custom function.

I'm trying to reimplement a simple linear layer in C++ to get everything working, following this Python PyTorch example:

https://pytorch.org/docs/master/notes/extending.html

However, I'm having trouble with the declarations, e.g. whether I should declare
torch::nn::ModuleHolder<LinearModule> fc1{nullptr};
or use
TORCH_MODULE(LinearModule)

The code I modified below compiles and runs, but it doesn't produce the right output (it outputs nan).
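For reference, here is a minimal, self-contained sketch of the two declaration styles I'm asking about, if I understand the libtorch conventions correctly. The names LinearModuleA, LinearModuleBImpl and Demo are just placeholders, and I've used torch::linear instead of my custom autograd function to keep it short:

#include <torch/torch.h>

// Option A: hold the module type directly in a torch::nn::ModuleHolder and
// register it with the templated register_module overload.
struct LinearModuleA : torch::nn::Module {
  LinearModuleA(int64_t in, int64_t out) {
    weight = register_parameter("weight", torch::randn({out, in}));
    bias = register_parameter("bias", torch::zeros({out}));
  }
  torch::Tensor forward(torch::Tensor input) {
    return torch::linear(input, weight, bias);
  }
  torch::Tensor weight, bias;
};

// Option B: name the implementation "...Impl" and let TORCH_MODULE generate a
// holder class (LinearModuleB) wrapping ModuleHolder<LinearModuleBImpl>.
struct LinearModuleBImpl : torch::nn::Module {
  LinearModuleBImpl(int64_t in, int64_t out) {
    weight = register_parameter("weight", torch::randn({out, in}));
    bias = register_parameter("bias", torch::zeros({out}));
  }
  torch::Tensor forward(torch::Tensor input) {
    return torch::linear(input, weight, bias);
  }
  torch::Tensor weight, bias;
};
TORCH_MODULE(LinearModuleB);

struct Demo : torch::nn::Module {
  torch::nn::ModuleHolder<LinearModuleA> fc_a{nullptr};  // option A
  LinearModuleB fc_b{nullptr};                           // option B

  Demo() {
    fc_a = register_module<LinearModuleA>("fc_a", std::make_shared<LinearModuleA>(320, 50));
    fc_b = register_module("fc_b", LinearModuleB(50, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    return fc_b->forward(fc_a->forward(x));
  }
};

Both variants register the submodule and its parameters; as far as I can tell the difference is only whether the ModuleHolder is spelled out explicitly or generated by the macro.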

/*
  Author: Kushashwa Ravi Shrimali (krshrimali)
  Reference:
*/

#include <torch/torch.h>

using namespace torch::autograd;

// Inherit from torch::autograd::Function via CRTP
class LinearFunction : public Function<LinearFunction> {
public:
  // Note that both forward and backward are static functions

  // bias is an optional argument
  static torch::Tensor forward(
      AutogradContext *ctx,
      torch::Tensor input,
      torch::Tensor weight,
      torch::Tensor bias = torch::Tensor()) {

    ctx->save_for_backward({input, weight, bias});

    //*
    std::cout << "Input" << std::endl;
    std::cout << input << std::endl;
    std::cout << "weight" << std::endl;
    std::cout << weight << std::endl;
    std::cout << "DIM input" << std::endl;
    std::cout << input.dim() << std::endl;
    std::cout << "DIM weight" << std::endl;
    std::cout << weight.dim() << std::endl;

    std::cout << "in dim 0: " << input.sizes()[0] << std::endl;
    std::cout << "in dim 1: " << input.sizes()[1] << std::endl;

    std::cout << "weight dim 0: " << weight.sizes()[0] << std::endl;
    std::cout << "weight dim 1: " << weight.sizes()[1] << std::endl;
    //*/

    getchar();

    auto output = input.t().mm(weight.t());
    if (bias.defined()) {

      std::cout << "bias.unsqueeze(0).expand_as(output)" << std::endl;
      std::cout << bias.unsqueeze(0).expand_as(output) << std::endl;
      getchar();

      output += bias.unsqueeze(0).expand_as(output);
    }
    return output;
  }

  static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto bias = saved[2];

    auto grad_output = grad_outputs[0];

    //std::cout << grad_output << std::endl;
    //std::cout << grad_outputs[0] << std::endl;
    //std::cout << grad_outputs[0][0] << std::endl;
    //std::cout << "-------" << std::endl;

    auto grad_input = grad_output.mm(weight);
    auto grad_weight = grad_output.t().mm(input);
    auto grad_bias = torch::Tensor();
    if (bias.defined()) {
      grad_bias = grad_output.sum(0);
    }

    return {grad_input, grad_weight, grad_bias};
  }
};

/* Custom linear module that calls LinearFunction */
struct LinearModule : torch::nn::Module {
  LinearModule() {}
  LinearModule(int64_t in_features, int64_t out_features) {
    //weight = register_parameter("weight", torch::empty({in_features, out_features}));
    weight = register_parameter("weight", torch::empty({in_features, out_features}));
    bias = register_parameter("bias", torch::empty({out_features}));
  }

  // Implement Algorithm
  torch::Tensor forward(torch::Tensor input) {
    // std::cout << x.size(0) << ", " << 784 << std::endl;
    auto x = LinearFunction::apply(
        input,
        weight
        //, bias
    );
    return x;
  }

  torch::Tensor weight;
  torch::Tensor bias;
};

/* Sample code for training a FCN on MNIST dataset using PyTorch C++ API */
struct Net : torch::nn::Module {

	//std::shared_ptr<LinearModule> fc1;
	//std::shared_ptr<LinearModule> fc2;

	torch::nn::ModuleHolder<LinearModule> fc1{nullptr};
	torch::nn::ModuleHolder<LinearModule> fc2{nullptr};

	Net() {
	// Initialize CNN
	conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 10, 5)));
	conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(10, 20, 5)));
	conv2_drop = register_module("conv2_drop", torch::nn::Dropout());
	//fc1 = register_module("fc1", torch::nn::Linear(320, 50));
	//fc2 = register_module("fc2", torch::nn::Linear(50, 10));

	//fc1 = register_module("fc1", LinearModule(320, 50));
	//fc2 = register_module("fc2", LinearModule(50, 10));
	fc1 = register_module<LinearModule>(
		"fc1", std::make_shared<LinearModule>(320,50));

	fc2 = register_module<LinearModule>(
		"fc2", std::make_shared<LinearModule>(50,10));

	//auto x = torch::randn({320, 50}).requires_grad_();
	//auto weight = torch::randn({320, 50}).requires_grad_();
	//fc1 = LinearFunction::apply(x, weight);
	

	//auto x2 = torch::randn({50,10}).requires_grad_();
	//auto weight2 = torch::randn({50, 10}).requires_grad_();
	//fc2 = LinearFunction::apply(x2, weight2);
	}

	// Implement Algorithm
	torch::Tensor forward(torch::Tensor x) {
	// std::cout << x.size(0) << ", " << 784 << std::endl;
	x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
	x = torch::relu(torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
	x = x.view({-1, 320});
	x = torch::relu(fc1->forward(x));
	x = torch::dropout(x, 0.5, is_training());
	x = fc2->forward(x);

	return torch::log_softmax(x, 1);
	}

	torch::nn::Conv2d conv1{nullptr};
	torch::nn::Conv2d conv2{nullptr};
	torch::nn::Dropout conv2_drop{nullptr};
	//torch::nn::Linear fc1{nullptr}, fc2{nullptr};

	//LinearModule fc1():weight(torch::Tensor()),bias(torch::Tensor()){};
	//LinearModule fc2():weight(torch::Tensor()),bias(torch::Tensor()){};

	//LinearModule fc1{torch::Tensor(),torch::Tensor()};
	//LinearModule fc2{torch::Tensor(),torch::Tensor()};
	//LinearModule *fc1;
	//LinearModule *fc2;
};

int main() {
if (0) {
	auto x = torch::randn({2, 3}).requires_grad_();
	auto weight = torch::randn({4, 3}).requires_grad_();
	auto y = LinearFunction::apply(x, weight);
	y.sum().backward();

	std::cout << x.grad() << std::endl;
	std::cout << weight.grad() << std::endl;

	return 0;
}

auto net = std::make_shared<Net>();

// Create multi-threaded data loader for MNIST data
auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
		std::move(torch::data::datasets::MNIST("../data").map(torch::data::transforms::Normalize<>(0.13707, 0.3081)).map(
			torch::data::transforms::Stack<>())), 64);
torch::optim::SGD optimizer(net->parameters(), 0.01); // Learning Rate 0.01

// net.train();

for(size_t epoch=1; epoch<=10; ++epoch) {
	size_t batch_index = 0;
	// Iterate data loader to yield batches from the dataset
	for (auto& batch: *data_loader) {
		// Reset gradients
		optimizer.zero_grad();
		// Execute the model
		torch::Tensor prediction = net->forward(batch.data);
		// Compute loss value
		torch::Tensor loss = torch::nll_loss(prediction, batch.target);
		// Compute gradients
		loss.backward();
		// Update the parameters
		optimizer.step();

		// Output the loss and checkpoint every 100 batches
		if (++batch_index % 100 == 0) {
			std::cout << "Epoch: " << epoch << " | Batch: " << batch_index 
				<< " | Loss: " << loss.item<float>() << std::endl;
			torch::save(net, "net.pt");
		}
	}
}

}