TypeError: forward() missing 1 required positional argument: 'trg'

Below is my complete code. I am working on a multi-step classification problem, using a Transformer to predict stock market movement.


```python
import torch
import torch.nn as nn
import math
import numpy as np
from skorch import NeuralNetClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        #A query represents the current token that we want to focus on. 
        #It is used to calculate the similarity between this token and other tokens in the sequence. 
        #In the scaled dot-product attention, the similarity is measured as the dot product between the query and the keys.
        
        #A key represents other tokens in the sequence with respect to the current token (the query). 
        #The similarity between the query and each key is used to compute attention weights, 
        #which determine how much attention the current token should pay to other tokens in the sequence.
        
        #Value: A value represents the content of the other tokens in the sequence. 
        #Once the attention weights are computed using the queries and keys, they are used to compute a weighted sum of the value embeddings. 
        #This weighted sum represents the output of the attention mechanism for the current token (the query) and 
        #can be thought of as a summary of the most relevant information from other tokens in the sequence.
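        # At this point values, keys, and query each have shape
        # (N, value_len/key_len/query_len, self.heads, self.head_dim),
        # so the linear layers below are applied to every head independently.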
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        #einsum: a powerful function for performing tensor operations with a compact notation.
        # n: batch size  q: query sequence length  h: number of attention heads  d: head dimension
        # k: key sequence length  l: value sequence length
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
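        # For example, with the sizes used later in this script (heads=5, head_dim=10,
        # sequence length 50), queries and keys are (N, 50, 5, 10) and the einsum sums
        # over d, giving energy of shape (N, heads, query_len, key_len) = (N, 5, 50, 50):
        # one query-by-key score matrix per head.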
        
        #Padding tokens are special tokens added to a sequence to make all sequences in a batch have the same length. 
        #In natural language processing, sequences (sentences or documents) often have varying lengths, 
        #which can cause issues when processing them in a batch using deep learning models. 
        #To deal with this, we pad shorter sequences with padding tokens to make all sequences in the batch have the same length as the longest sequence.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
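            # Positions where mask == 0 receive a score of -1e20, so their softmax
            # weight below becomes effectively zero.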

        # The softmax function is used to convert a vector of real numbers into a probability distribution. It ensures that the sum of the output elements is 1, and each element is in the range (0, 1).
        # dim=3 means that the softmax is applied along the last dimension of the input tensor
        # softmax(x_i) = exp(x_i) / Σ(exp(x_j))
        # here the attention weights for each query token sum to 1 across all key tokens.
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out
# The TransformerBlock class represents a single layer of the Transformer architecture:
# a multi-head self-attention mechanism followed by a position-wise feed-forward network.
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, is_decoder =False):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        
        # layer normalization is used to stabilize the training process and improve the model's performance.
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        if is_decoder:
            self.enc_dec_attention = SelfAttention(embed_size, heads)
            self.norm3 = nn.LayerNorm(embed_size)
        
        
        # Expanding the input size allows the feed-forward network to learn more complex patterns and capture higher-level features in the data.
        # When we say that the input size is expanded by a factor of forward_expansion, it means that the hidden layer of the feed-forward network 
        # has forward_expansion times more neurons than the input layer. For example, if the input size is 512 and the forward_expansion factor is 4, 
        # the hidden layer will have 512 * 4 = 2048 neurons. 
        
        # ReLU (rectified linear unit)
        # ReLU(x) = max(0, x), which replaces all negative values in the tensor with zeros.
        # introduces non-linearity into the model, allowing it to learn complex relationships between the input features
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
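        # For example, with the values used later in this script (embed_size=50,
        # forward_expansion=2), the hidden layer has 50 * 2 = 100 neurons.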

        self.dropout = nn.Dropout(dropout)
        
        self.is_decoder = is_decoder
    #In a residual connection (also called a skip connection or shortcut connection), the output of a layer (or a group of layers) is added to the input of the same layer
    #(or group of layers) before being passed to the next layer. This allows the network to learn residual functions,
    #i.e., the difference between the input and the output, which can help alleviate the vanishing gradient problem and improve the model's performance.
    def forward(self, value, key, query, trg_mask = None, src_mask = None):
        if self.is_decoder:
            attention = self.attention(query, query, query, trg_mask)
        else:
            attention = self.attention(value, key, query, trg_mask)
            
        x = self.dropout(self.norm1(attention + query)) #Residual Connections
        
        if self.is_decoder:
            enc_dec_attention = self.enc_dec_attention(value, key, x, src_mask)
            x = self.dropout(self.norm3(enc_dec_attention + x))
        
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x)) #Residual Connections
        return out
# to capture the relative positions of the tokens in the input sequences.
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_length, embed_size)
        for pos in range(max_length):
            for i in range(0, embed_size, 2):
                #The (10000 ** ((2 * i) / embed_size)) term acts as a scaling factor to ensure that the positional encodings 
                #have a smooth distribution across the range of positions and dimensions.
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / embed_size)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / embed_size)))
        #This is done to make the positional encodings compatible with the input embeddings tensor during the forward pass.
        self.pe = pe.unsqueeze(0).requires_grad_(False)
        #Positional encodings are not learnable parameters.
    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
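        # self.pe has shape (1, max_length, embed_size); slicing it to x.size(1) lets it
        # broadcast over the batch for any sequence length up to max_length.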
        return x
# responsible for processing the input sequences and generating a continuous representation that can be used by the decoder 
# to generate the target sequences.
class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        #this layer converts input tokens into continuous vectors.
        self.src_word_embedding = nn.Embedding(src_vocab_size, embed_size)
        
        #this line creates an instance of the PositionalEncoding class to add positional information to the
        #input embeddings.
        self.src_position_embedding = PositionalEncoding(embed_size, max_length)

        #a list of transformerblock layers
        # each layer contains a self-attention mechanism and a position-wise feed-forward network
        self.encoder_layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        x = self.dropout(self.src_position_embedding(self.src_word_embedding(x)))

        for layer in self.encoder_layers:
            x = layer(x, x, x, mask)

        return x
#The decoder is responsible for generating the target sequences, 
#given the continuous representation produced by the encoder. 
#The decoder consists of a stack of identical layers, each containing a self-attention mechanism, 
#an encoder-decoder attention mechanism, 
#and a position-wise feed-forward network.
class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Decoder, self).__init__()
        #This creates an embedding layer for the target tokens.
        #this layer converts target tokens into continuous vectors.
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        
        #this line creates an instance of the PositionalEncoding class to add positional
        #information to the target embeddings.
        self.trg_position_embedding = PositionalEncoding(embed_size, max_length)

        self.decoder_layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                    is_decoder=True,
                )
                for _ in range(num_layers)
            ]
        )
        #this line creates a linear layer that projects the output of the final decoder layer to the
        #target vocabulary size. 
        #This layer is used to generate the probability distribution over the target vocabulary for each position in the output sequence.
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        #At each layer, the self-attention mechanism computes the attention scores between the target tokens and 
        #generates a new representation of the target sequence. The encoder-decoder attention mechanism computes 
        #the attention scores between this new representation and the output of the encoder to incorporate information 
        #from the input sequence. 
        x = self.dropout(self.trg_position_embedding(self.trg_word_embedding(x)))
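        # In the loop below, enc_out is passed as the value and key and x as the query,
        # so the self-attention uses trg_mask while the encoder-decoder attention
        # attends over the encoder output with src_mask.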

        for layer in self.decoder_layers:
            x = layer(enc_out, enc_out, x, trg_mask, src_mask)

        out = self.fc_out(x)
        return out

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0.1,
        max_length=100,
    ):
        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
        )
        
        #padding indices for the source and target sequences. these are used to create masks
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
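        # Shape: (N, 1, 1, src_len), True wherever the token is not padding, so it
        # broadcasts across heads and query positions in the attention scores.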
        return src_mask

    #The make_trg_mask function creates a mask for the target sequence. 
    #It first retrieves the shape of trg to obtain the batch size N and target length trg_len. 
    #It then creates a lower triangular matrix of ones with dimensions (trg_len, trg_len) and expands it to have the same batch size and extra dimension for compatibility with the attention mechanism.
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
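        # For example, with trg_len = 3 each sample's mask is
        #   [[1, 0, 0],
        #    [1, 1, 0],
        #    [1, 1, 1]]
        # so position i can only attend to positions <= i.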

        return trg_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)

        return out
#%%
class TransformerClassification(nn.Module):
    def __init__(
        self,
        num_classes,
        src_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0.1,
        max_length=100,
    ):
        super().__init__()
        self.transformer = Transformer(
            src_vocab_size,
            num_classes,
            src_pad_idx,
            trg_pad_idx,
            embed_size,
            num_layers,
            forward_expansion,
            heads,
            dropout,
            max_length,
        )
        self.fc_out = nn.Linear(embed_size, num_classes)

    def forward(self, src, trg):
        out = self.transformer(src, trg)
        out = self.fc_out(out)
        return out

num_samples = 1000
sequence_length = 50
num_features = 5
num_classes = 3
output_sequence_length = 3

X = np.random.random((num_samples, sequence_length, num_features))
y = np.random.randint(0, num_classes, (num_samples, output_sequence_length))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create dataset
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerClassification(
    num_classes,
    num_features,
    src_pad_idx=0,
    trg_pad_idx=0,
    embed_size=50,
    num_layers=2,
    forward_expansion=2,
    heads=5,
    dropout=0.1,
    max_length=sequence_length,
).to(device)

class Seq2SeqClassifier(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        y_true = y_true.view(-1)
        y_pred = y_pred.view(-1, y_pred.shape[-1])
        return super().get_loss(y_pred, y_true, X=X, training=training)

net = Seq2SeqClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    train_split=None,
    verbose=1,
    device=device,
)
# Add the start token to the beginning of each target sequence
y_train_input = torch.cat([torch.zeros((y_train.shape[0], 1), dtype=torch.long), y_train[:, :-1]], dim=1)
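# e.g. a target sequence [2, 1, 0] becomes the decoder input [0, 2, 1]:
# shifted right by one step, with 0 acting as the start token.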
net.fit(X_train, {'y_true': y_train, 'y_input': y_train_input})
# Predictions
y_test_input = torch.cat([torch.zeros((y_test.shape[0], 1), dtype=torch.long), y_test[:, :-1]], dim=1)
y_test_input = y_test_input.to(device)
y_pred = net.predict(X_test, y_input=y_test_input)

accuracy = accuracy_score(y_test.view(-1).cpu().numpy(), y_pred.view(-1).cpu().numpy())
print(f"Accuracy: {accuracy:.2f}")
```

But I encountered the following error: TypeError: forward() missing 1 required positional argument: 'trg'. Below is the entire traceback:

```
Traceback (most recent call last):

  File "C:\Users\86189\.spyder-py3\Transformer4_30.py", line 395, in <module>
    net.fit(X_train, y_train_input)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\classifier.py", line 141, in fit
    return super(NeuralNetClassifier, self).fit(X, y, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1230, in fit
    self.partial_fit(X, y, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1189, in partial_fit
    self.fit_loop(X, y, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1101, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1137, in run_single_epoch
    step = step_fn(batch, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1016, in train_step
    self._step_optimizer(step_fn)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 972, in _step_optimizer
    optimizer.step(step_fn)

  File "C:\Users\86189\anaconda3\lib\site-packages\torch\optim\optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)

  File "C:\Users\86189\anaconda3\lib\site-packages\torch\optim\optimizer.py", line 23, in _use_grad
    ret = func(self, *args, **kwargs)

  File "C:\Users\86189\anaconda3\lib\site-packages\torch\optim\adam.py", line 183, in step
    loss = closure()

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1006, in step_fn
    step = self.train_step_single(batch, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 905, in train_step_single
    y_pred = self.infer(Xi, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\skorch\net.py", line 1427, in infer
    return self.module_(x, **fit_params)

  File "C:\Users\86189\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)

TypeError: forward() missing 1 required positional argument: 'trg'
```

Please post code snippets by wrapping them into three backticks ``` instead of posting screenshots, as we won't be able to copy/paste the code to debug it, and your code won't be indexed by the forum either.

Sorry about the inconvenience, I have edited it again.