I’m trying to replicate this python code in C++:
class MLPEncoder(nn.Module):
def __init__(self, n_toks, emb_dim, hidden_dim, dropout, bias_offset):
super().__init__()
self.emb = nn.Embedding(n_toks, emb_dim, padding_idx=0)
torch.nn.init.normal_(self.emb.weight.data, 0, 0.01)
self.emb.weight.data[0] = 0
self.act_bn_drop_1 = nn.Sequential(
nn.ReLU(),
nn.BatchNorm1d(emb_dim),
nn.Dropout(dropout),
)
self.bottleneck = nn.Linear(emb_dim, hidden_dim)
self.bottleneck.bias.data.zero_()
self.act_bn_drop_2 = nn.Sequential(
nn.ReLU(),
nn.BatchNorm1d(hidden_dim),
nn.Dropout(dropout),
)
self.output = nn.Linear(hidden_dim, n_toks)
self.output.bias.data.zero_()
self.output.bias.data += bias_offset
def forward(self, x):
x = self.emb(x).sum(dim=1)
x = self.act_bn_drop_1(x)
x = self.bottleneck(x)
x = self.act_bn_drop_2(x)
x = self.output(x)
return x
Here is my C++ code:
struct MLPEncoder : torch::nn::Module
{
MLPEncoder(long n_toks, int64_t emb_dim, int64_t hidden_dim, float dropout, float bias_offset) :
_n_toks(n_toks), _emb_dim(emb_dim), _hidden_dim(hidden_dim),
_dropout(dropout), _bias_offset(bias_offset),
emb(register_module("embedding", torch::nn::Embedding(n_toks, emb_dim))),
act_bn_drop_1(register_module("act_bn_drop_1", torch::nn::Sequential (
torch::nn::Functional(torch::relu),
torch::nn::BatchNorm(_embDim),
torch::nn::Dropout(_dropout)
))),
bottleneck(register_module("bottleneck", torch::nn::Linear(emb_dim, hidden_dim))),
act_bn_drop_2(register_module("act_bn_drop_2", torch::nn::Sequential (
torch::nn::Functional(torch::relu),
torch::nn::BatchNorm(_hidden_dim),
torch::nn::Dropout(_dropout)
))),
output(register_module("output", torch::nn::Linear(hidden_dim, n_toks)))
{
torch::nn::init::normal_(emb.ptr()->weight, 0, 0.01);
emb.ptr()->weight[0] = 0;
torch::NoGradGuard guard; // <- Do I need this?
//ATTEMPT 1:
//torch::Tensor tmp = torch::zeros(n_toks);
//tmp += bias_offset;
//output.ptr()->bias.copy_(tmp);
//ATTEMPT 2:
output.ptr()->bias.zero_();
output.ptr()->bias += bias_offset;
//ATTEMPT 3:
//auto data = output.ptr()->bias.accessor<float,1>();
//for (int i=0; i<output.ptr()->bias.sizes().data()[0]; i++) data[i] = bias_offset;
}
torch::Tensor forward(torch::Tensor& x)
{
x = torch::embedding(emb.ptr()->weight, x, 0); //the third parameter is padding_idx
x = x.sum({1});
x = act_bn_drop_1->forward(x);
x = bottleneck->forward(x);
x = act_bn_drop_2->forward(x);
x = output(x);
return x;
}
long _n_toks;
int64_t _emb_dim, _hidden_dim;
float _dropout, _bias_offset;
torch::nn::Embedding emb;
torch::nn::Sequential act_bn_drop_1;
torch::nn::Linear bottleneck;
torch::nn::Sequential act_bn_drop_2;
torch::nn::Linear output;
};
When I run this, and calculate loss.backward() I get:
#0 __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:51
#1 0x00007fffeb824801 in __GI_abort () at abort.c:79
#2 0x00007fffebe79957 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#3 0x00007fffebe7fab6 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#4 0x00007fffebe7faf1 in std::terminate() () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#5 0x00007fffebe7faaa in std::rethrow_exception(std::__exception_ptr::exception_ptr) () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#6 0x00007ffff6cd5b8a in torch::autograd::Engine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> > const&, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) () from /opt/honeycomb/libtorch/lib/libtorch.so.1
#7 0x00007ffff7317254 in torch::autograd::Variable::backward(c10::optional<at::Tensor>, bool, bool) const () from /opt/honeycomb/libtorch/lib/libtorch.so.1
#8 0x00007ffff7319dc8 in torch::autograd::VariableType::backward(at::Tensor&, c10::optional<at::Tensor>, bool, bool) const () from /opt/honeycomb/libtorch/lib/libtorch.so.1
#9 0x00005555555f6204 in at::Tensor::backward (this=0x7fffffffe288, gradient=..., keep_graph=false, create_graph=false) at /opt/honeycomb/libtorch/include/ATen/core/TensorMethods.h:53
#10 0x00005555555f3bdb in main () at recsys.cpp:471