Unbatched input for nn.Transformer module possible?

Hi everyone,

I am experimenting with the Transformer model of PyTorch to implement an autoencoder for multivariate time series data. I currently however struggle to feed a single unbatched input sequence into the model.
In the documentation (Transformer — PyTorch 1.12 documentation) it is written that the dimension of src should be (S, E) for an unbatched input or (N, S, E) if batch_first=True, where N = batch size, S = sequence length and E = feature number / model dimension.

While the model runs without any problems if I include the batch dimension, I get the error:
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
if I leave it out. I looked at the implementation of the forward()-function of the nn.Transformer module and realized, that there actually is no option given for an unbatched input of size (S, E). Is this actually not supported (contrastive to the documentation)?

Below, you can find a full example code and an example multivariate input sequence consisting of three different sine waves. To change the input from batched to unbatched, you may just comment out lines 106, 107, 171, 172 and 173 (every line which contains np.newaxis)

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(0)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def padding(src: Tensor, maxlen: int):
    '''
    Padding is done to ensure the same time series length per batch.
    Shorter times series are padded to the max length with zeros,
    which are later masked in the attention mechanism.
    '''
    src_pad = torch.zeros([maxlen,src.shape[1]])
    src_pad[:src.shape[0],:] = src
    return src_pad

def create_masks(src: Tensor, tgt: Tensor, maxlen: int):
    '''
    src_mask:
        This is just a dummy mask, as all positions are marked with
        zero.
    src_padding_mask:
        The positions with the value of True will be ignored
        while the position with the value of False will be unchanged.
    '''
    src_mask = torch.zeros((maxlen, maxlen),device=DEVICE).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt.shape[0])
   
    src_padding_mask = torch.ones(maxlen, dtype=torch.bool)
    src_padding_mask[:src.shape[0]] = False
   
    tgt_padding_mask = torch.ones(maxlen, dtype=torch.bool)
    tgt_padding_mask[:tgt.shape[0]] = False
   
    memory_key_padding_mask = src_padding_mask
   
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask

class VarSeriesEmbedding(nn.Module):
    '''
    This class scales an input sequence, which is a tensor
    of one time step, containing all variables at this position,
    to the model dimension of the transformer (e.g. 512).
    This scaling is done with learnable parameters of a
    1-dimensional feed forward neural network.
    '''
    def  __init__(self,
                  num_variables: int,
                  emb_size: int):
        super(VarSeriesEmbedding, self).__init__()
        self.linear = nn.Linear(num_variables, emb_size, bias=True)
       
    def forward(self, varSeries: Tensor):
        return self.linear(varSeries)
   

class TransformerModel(nn.Module):
    def __init__(self,
                 emb_size: int,
                 num_src_variables: int,
                 num_tgt_variables: int,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 nhead: int = 8,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1):
        super(TransformerModel, self).__init__()

        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)
       
        self.src_embedding = VarSeriesEmbedding(num_src_variables, emb_size)
        self.tgt_embedding = VarSeriesEmbedding(num_tgt_variables, emb_size)

    def init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self,
                src_pad: Tensor,
                tgt_pad: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
       
        src_emb = torch.stack([self.src_embedding(i) for i in src_pad])
        tgt_emb = torch.stack([self.tgt_embedding(i) for i in tgt_pad])
               
        #src_emb = src_emb[np.newaxis,:,:]
        #tgt_emb = tgt_emb[np.newaxis,:,:]

        outs = self.transformer(src_emb,
                                tgt_emb,
                                src_mask,
                                tgt_mask,
                                None, # memory_mask
                                src_padding_mask,
                                tgt_padding_mask,
                                memory_key_padding_mask)
        return outs
   
## Create an example multivariate timeseries sample
# (with 3 variables and length 200)
t = np.arange(0, 200, 1)
T_1 = 50
T_2 = 30
T_3 = 10
x_1 = np.sin(t/T_1)
x_2 = np.sin(t/T_2)
x_3 = np.sin(t/T_3)
src = np.stack((x_1, x_2, x_3)).transpose()
src = torch.from_numpy(src)
src = src.type(torch.FloatTensor)

# Plot the timeseries
fig, (ax1, ax2, ax3) = plt.subplots(3,1)
fig.subplots_adjust(hspace=0.5)
ax1.plot(src[:,0])
ax2.plot(src[:,1])
ax3.plot(src[:,2])
plt.show()

# Add padding to the example timeseries
maxlen = 250
src_pad = padding(src, maxlen)

# Create the target sequence, shifted to the right
tgt = src[:-1,:]
tgt_pad = padding(tgt, maxlen)

# Define hyperparameters
emb_size = 512
num_src_variables = 3
num_tgt_variables = 3
nhead = 8
ffn_hidden_dim = 2048
batch_size = 128
num_encoder_layers = 6
num_decoder_layers = 6

# Initialize the model
transformer = TransformerModel(emb_size,
                               num_src_variables,
                               num_tgt_variables,
                               num_encoder_layers,
                               num_decoder_layers,
                               nhead,
                               ffn_hidden_dim)
transformer.init_parameters()

# Create the masks
src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = create_masks(src_pad, tgt_pad, maxlen)

#src_padding_mask = src_padding_mask[np.newaxis,:]
#tgt_padding_mask = tgt_padding_mask[np.newaxis,:]
#memory_key_padding_mask = memory_key_padding_mask[np.newaxis,:]

# Feed example sequence into the transformer
output = transformer(src_pad, tgt_pad,
                     src_mask, tgt_mask,
                     src_padding_mask, tgt_padding_mask,
                     memory_key_padding_mask)

print("Done.")

Any help and clarification appreciated. Thank you in advance!

1 Like