Hi,
I’m trying to create a simple transformer.
Why can’t this code work I get an error here
// Define the Transformer model
#ifndef TRANSFORMER_MODEL_H
#define TRANSFORMER_MODEL_H
#include <torch/torch.h>
struct TransformerModel : torch::nn::Module
{
TransformerModel(int input_size, int num_heads, int num_layers, int hidden_dim)
{
// Define the input and output embeddings
embed = register_module("embed", torch::nn::Linear(input_size, hidden_dim));
// Create a multi-head self-attention layer
self_attn = register_module("self_attn", torch::nn::MultiheadAttention(hidden_dim, num_heads));
// Create a feedforward neural network
feedforward = register_module("feedforward", torch::nn::Sequential(
torch::nn::Linear(hidden_dim, hidden_dim),
torch::nn::ReLU(),
torch::nn::Linear(hidden_dim, hidden_dim)));
// Create the transformer layers
transformer_layers = register_module("transformer_layers", torch::nn::TransformerEncoderLayer(hidden_dim, num_heads));
// Create the transformer encoder
transformer_encoder = register_module("transformer_encoder", torch::nn::TransformerEncoder(transformer_layers, num_layers));
// Define the output layer
output_layer = register_module("output_layer", torch::nn::Linear(hidden_dim, 1));
}
torch::Tensor forward(torch::Tensor x)
{
// Embed the input data
x = embed->forward(x);
// Perform multi-head self-attention
//ERROR IS HERE
///(38,9): error C2679: binary '=': no operator found which takes a right-hand operand of type 'std::tuple<at::Tensor,at::Tensor>' (or there is no acceptable conversion)
x = self_attn->forward(x, x, x);
// Apply feedforward network
x = feedforward->forward(x);
// Apply the transformer encoder
x = transformer_encoder->forward(x);
// Apply the output layer
x = output_layer->forward(x);
return x;
}
torch::nn::Linear embed{nullptr};
torch::nn::MultiheadAttention self_attn = nullptr;
torch::nn::Sequential feedforward{nullptr};
torch::nn::TransformerEncoderLayer transformer_layers{nullptr};
torch::nn::TransformerEncoder transformer_encoder{nullptr};
torch::nn::Linear output_layer{nullptr};
};
#endif
Thanks in advance