Input to a transformer: vector or matrix?


I am very new to transformers, and I am working with the FNet implementation by Rishikksh20.

I was watching some videos about transformers, and the input after the embedding layer is usually shown as a matrix. However, this implementation uses a Linear layer rather than a 2D convolution. So is the input usually a vector, or is there a way to input a matrix? This is what I am doing:

data_model = PositionalEncoding(data_model)
data_model = data_model.flatten()

Let me know if I’m doing it correctly.

The nn.Linear layer accepts a tensor with any number of dimensions, as the docs show (Linear — PyTorch 1.9.0 documentation): the input has shape (N, *, H_in), where H_in is the number of input features of the layer (in_features) and * stands for any number of additional dimensions. The layer is applied only to the last dimension, so you don't need to flatten the matrix. Check the shape of data_model (data_model.shape); its last dimension, data_model.shape[-1], is the number of input features you need.
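A minimal sketch of that behaviour. The shapes here (a batch of 2 sequences of 14 embeddings of size 10, projected to 32 features) are illustrative assumptions, not taken from FNet. It shows that nn.Linear works on a 3D input and gives the same result as flattening the leading dimensions, applying the layer, and reshaping back.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative shapes (assumed): batch N=2, sequence length 14, embedding size 10
x = torch.randn(2, 14, 10)

# in_features must match the LAST dimension of the input (x.shape[-1])
layer = nn.Linear(in_features=x.shape[-1], out_features=32)

y = layer(x)  # applied independently to every 10-dim vector
print(y.shape)  # torch.Size([2, 14, 32])

# Same result as flattening the leading dims, applying the layer, and reshaping back
y_flat = layer(x.reshape(-1, 10)).reshape(2, 14, 32)
print(torch.allclose(y, y_flat, atol=1e-6))  # True
```

By contrast, flattening the whole (14, 10) matrix into a single 140-element vector would require in_features=140 and would make one dense layer mix all positions together, which is a different model.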

So I am a little confused by the PyTorch implementation of positional encoding.

Let's say I have embeddings of dimension 10, and I have 14 of them, so my matrix is 14 × 10. To get the positional encoding of this matrix, I create an object and apply it like this:

embed = PositionalEncoding(d_model=1)
res = embed(data)

but when I look at the shape of the result, I get:

torch.Size([10, 10, 14])

So how do I ensure that I’m using the positional embedding correctly?
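A minimal sketch of a sinusoidal positional encoding that keeps the (14, 10) shape. The class name SinusoidalPositionalEncoding is hypothetical: it follows the sine/cosine formula used in the PyTorch nn.Transformer tutorial but is not that class (dropout is left out, and it expects (seq_len, d_model) or (batch, seq_len, d_model) rather than the tutorial's sequence-first layout). The main points are that d_model should equal the embedding size (10 here, not 1) and that positions run along the dimension of length 14. The last line shows that the reported torch.Size([10, 10, 14]) is consistent with broadcasting a (10, 1, 1) slice of a d_model=1 buffer against a (10, 14) input. That is what would happen if the data were laid out as (10, 14) and the buffer used the tutorial's (max_len, 1, d_model) layout, but this is only inferred from the shapes.

```python
import math
import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    """Hypothetical sinusoidal positional encoding (sketch, no dropout).

    Expects x of shape (seq_len, d_model) or (batch, seq_len, d_model).
    d_model must equal the embedding size and be even.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # (d_model // 2,)
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even feature columns
        pe[:, 1::2] = torch.cos(position * div_term)  # odd feature columns
        self.register_buffer("pe", pe)  # stored with the module, not trained

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(-2)  # positions run along the second-to-last dim
        return x + self.pe[:seq_len]  # (seq_len, d_model) broadcasts over batch


data = torch.randn(14, 10)  # 14 embeddings of size 10
pos_enc = SinusoidalPositionalEncoding(d_model=10)  # d_model = embedding size
out = pos_enc(data)
print(out.shape)  # torch.Size([14, 10])

# Broadcasting consistent with the shape in the question:
print((torch.zeros(10, 1, 1) + torch.zeros(10, 14)).shape)  # torch.Size([10, 10, 14])
```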

So I was able to solve the embedding dimension issue. For the training, I have written this:

import torch
from tqdm import tqdm

for epoch in range(epochs):
    total_loss = 0.0
    # train_loader returns the signal to reconstruct and its embeddings
    for signal, encoded_data in tqdm(train_loader):
        signal = signal.squeeze()

        # The embeddings have different shapes, e.g. (128, 10) or (499, 10),
        # so I pass them in chunks of 10 embeddings
        for i in range(0, encoded_data.shape[1], 10):

            signal_train = torch.reshape()  # arguments omitted in the original post
            input_data = encoded_data[:, i:i+10, :].squeeze().cuda()
            signal_train = signal_train.cuda()
            signal_recon = train_model(input_data)

            loss = criterion(signal_recon, signal_train)
            total_loss += loss.item()  # accumulate separately instead of overwriting `loss`

Am I doing this correctly, and would you recommend any changes for faster or more accurate training?
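A minimal, self-contained sketch of a chunked training loop, under several assumptions: the toy data, the stand-in nn.Linear model, MSE loss, and Adam are all placeholders for the real FNet setup, and it assumes each chunk of 10 embedding rows is trained against the matching 10 rows of the target signal (the reshape target in the post is not given). The parts that carry over are calling optimizer.zero_grad(), loss.backward(), and optimizer.step() for each chunk (without backward and step nothing is learned), keeping the running loss in a separate float, and using squeeze(0) so only the batch dimension is removed.

```python
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical toy data: 8 items of (signal, embeddings), each (seq_len, emb_dim)
seq_len, emb_dim, chunk = 40, 10, 10
embeddings = torch.randn(8, seq_len, emb_dim)
signals = 2 * embeddings  # learnable toy target (assumption)
train_loader = DataLoader(TensorDataset(signals, embeddings), batch_size=1)

# Stand-in for the real model (assumption): maps (chunk, emb_dim) -> (chunk, emb_dim)
train_model = nn.Linear(emb_dim, emb_dim).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(train_model.parameters(), lr=1e-2)

epoch_losses = []
for epoch in range(3):
    train_model.train()
    running_loss = 0.0  # plain float, separate from the per-step loss tensor
    n_steps = 0
    for signal, encoded_data in train_loader:
        signal = signal.squeeze(0).to(device)  # drop only the batch dim
        encoded_data = encoded_data.squeeze(0).to(device)
        for i in range(0, encoded_data.shape[0], chunk):
            input_data = encoded_data[i:i + chunk]  # last chunk may be shorter
            signal_train = signal[i:i + chunk]
            optimizer.zero_grad()  # clear gradients from the previous step
            signal_recon = train_model(input_data)
            loss = criterion(signal_recon, signal_train)
            loss.backward()  # compute gradients
            optimizer.step()  # update the weights
            running_loss += loss.item()
            n_steps += 1
    epoch_losses.append(running_loss / n_steps)
    print(f"epoch {epoch}: mean loss {epoch_losses[-1]:.4f}")
```

Because each optimizer step here sees only 10 rows, one option for faster training is to process several chunks per step (for example by padding sequences to a common length and stacking chunks into a batch); whether that also improves accuracy depends on the model and data.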