Concatenate tensors of different shape for LSTM training

I am struggling with the following situation: I have to train an LSTM to generate series of bank transactions, and to do that I would also like to feed the LSTM some information about the subject performing the operations. My ultimate goal, after training, is to feed the LSTM a vector containing the info about a subject, possibly together with a first operation, and then generate a sequence of operations.

Now, my question is the following: the information about the subject is a 1-row tensor, while the sequence of operations (of variable length) has multiple rows and different features, so the two tensors have different shapes. How can they be concatenated, and how should they be fed into the network?

Let’s say I have:


subj_info = torch.tensor([26., 0., 1., 0.])   # tensor containing a bunch of info about the user

operation_series = torch.tensor([[0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [0.0000e+00, 4.6638e-04, 2.2581e-02, 0.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [1.0000e+00, 2.6664e-03, 0.0000e+00, 1.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [1.0000e+00, 1.9997e-03, 0.0000e+00, 1.0000e+00],
                                 [1.0000e+00, 3.4416e-04, 0.0000e+00, 1.0000e+00],
                                 [1.0000e+00, 6.6638e-04, 0.0000e+00, 1.0000e+00],
                                 [1.0000e+00, 9.9972e-04, 0.0000e+00, 1.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                                 [0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00]])
# 2D tensor of shape (l, n): l operations, each with n features (here 14 x 4)

I would like to concatenate the two tensors to feed them into the LSTM, so that the network learns the sequences but also the info associated with the subject.

I already tried:

torch.cat([subj_info.unsqueeze(0), operation_series.unsqueeze(0)], dim=0) 

but it doesn’t work because the tensors have different shapes. Creating a new dimension and concatenating along it didn’t work either, and neither did torch.stack. Am I doing something wrong with the dimensions of the tensors?

For the moment I am concatenating subj_info to every operation in the series, so my input data is:



input = torch.tensor([[26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 4.6638e-04, 2.2581e-02, 0.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 1.0000e+00, 2.6664e-03, 0.0000e+00, 1.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 1.0000e+00, 1.9997e-03, 0.0000e+00, 1.0000e+00],
                      [26., 0., 1., 0., 1.0000e+00, 3.4416e-04, 0.0000e+00, 1.0000e+00],
                      [26., 0., 1., 0., 1.0000e+00, 6.6638e-04, 0.0000e+00, 1.0000e+00],
                      [26., 0., 1., 0., 1.0000e+00, 9.9972e-04, 0.0000e+00, 1.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00],
                      [26., 0., 1., 0., 0.0000e+00, 3.3305e-04, 1.6129e-02, 0.0000e+00]])
# shape (14, 8): subj_info repeated on every row, followed by the operation features
          

but I don’t think this is optimal, because the LSTM won’t correctly learn the features of the operations.
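The torch.cat attempt above fails because subj_info.unsqueeze(0) has shape (1, 4) while operation_series.unsqueeze(0) has shape (1, 14, 4): torch.cat needs the same number of dimensions and matching sizes in every dimension except the one being concatenated. For reference, the per-row input shown above can be built without writing it out by hand, by repeating subj_info along the time axis. A minimal sketch, with random values standing in for the real operation_series:

```python
import torch

subj_info = torch.tensor([26., 0., 1., 0.])   # shape (4,)
operation_series = torch.rand(14, 4)          # stand-in for the (14, 4) tensor above

# Repeat the subject vector once per time step: (4,) -> (1, 4) -> (14, 4).
# expand() returns a view, so no data is actually copied.
subj_rep = subj_info.unsqueeze(0).expand(operation_series.size(0), -1)

# Now every dimension except the feature axis matches: (14, 4) + (14, 4) -> (14, 8)
lstm_input = torch.cat([subj_rep, operation_series], dim=1)
print(lstm_input.shape)  # torch.Size([14, 8])
```

This only automates the current workaround; it does not change the modelling question of whether subj_info belongs in the LSTM input at all.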

Moreover, let’s say I manage to concatenate them so that I have something like:

tensor[ [subj_info] ,[ [....]
                       [....]
                       [....]
                         ..
                       [....] ] ]

How should I use such input for the LSTM if I want it to “focus” only on the operation sequence?

Thanks to everyone who can help.

Can anyone help? I can’t find a solution for this situation.

Your subj_info is not a sequence and is not specific to any time step. So why would you want to give it to the LSTM layer?

The more intuitive choice, in my opinion, would be to give only operation_series to the LSTM and then merge the output of the LSTM (i.e., the last hidden state) with subj_info. So your forward() method could crudely look as follows:

def forward(self, operation_series, subj_info):
    ...
    # Probably good to have at least one linear layer before merging
    subj_info = self.fc(subj_info)
    ...
    outputs, (h, c) = self.lstm(operation_series)
    # Get last hidden state (assuming bidirectional=False for simplicity)
    h = h[-1]
    # Merge hidden state and subj_info; both are (batch, features), so concatenate along dim 1
    X = torch.cat([h, subj_info], dim=1)
    # Push X to some more linear layers (optional with Dropout)
    ...
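A runnable version of this idea might look like the sketch below. It assumes batched inputs with batch_first=True, a single-layer unidirectional LSTM, and a generic output head with n_outputs values (the thread does not say what the target is at this point); OpsWithSubject, subj_size, n_outputs and head are hypothetical names.

```python
import torch
import torch.nn as nn

class OpsWithSubject(nn.Module):
    """Hypothetical sketch: LSTM over operation_series, subj_info merged afterwards."""

    def __init__(self, n_op_features, n_subj_features, hidden_size=32,
                 subj_size=16, n_outputs=1):
        super().__init__()
        self.lstm = nn.LSTM(n_op_features, hidden_size, batch_first=True)
        self.fc = nn.Linear(n_subj_features, subj_size)  # linear layer before merging
        self.head = nn.Sequential(                       # "some more linear layers"
            nn.Linear(hidden_size + subj_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, n_outputs),
        )

    def forward(self, operation_series, subj_info):
        # operation_series: (batch, seq_len, n_op_features)
        # subj_info:        (batch, n_subj_features)
        subj = torch.relu(self.fc(subj_info))
        _, (h, _) = self.lstm(operation_series)
        h = h[-1]                          # final hidden state of the last layer: (batch, hidden_size)
        x = torch.cat([h, subj], dim=1)    # (batch, hidden_size + subj_size)
        return self.head(x)


model = OpsWithSubject(n_op_features=4, n_subj_features=4)
out = model(torch.rand(2, 14, 4), torch.rand(2, 4))
print(out.shape)  # torch.Size([2, 1])
```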


You are indeed right. My idea of putting the two things together came from wanting the LSTM to learn the operation patterns based also on subj_info because, after training, my goal is to use the network to generate new, unseen sequences given subj_info as input. In that case, I would have to loop for as many operations as I want to generate, each time running the network on its previous output. Is that correct?

Oh, wait… now I understand your task: given the information about a subject, you want to predict a sequence of operations. So subj_info is the input and operation_series is your target.

In terms of sequence-to-sequence tasks, this is a one-to-many task: one input (subj_info), a sequence of outputs (operation_series). This is basically similar to image caption generation, or to speech recognition.

So you basically need some encoder that converts subj_info into some internal representation, and use this representation as the initial hidden state for the LSTM decoder.

The Seq2Seq tutorial gives a good idea of how to set this up. The only difference is that in your case the input is not a sequence but just your subj_info, so the encoder will be even simpler.
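A minimal sketch of this setup, assuming subj_info and each operation are plain float vectors (4 features each, as in the example tensors above) and using a single Linear layer as the encoder; all class, function and parameter names are hypothetical. The encoder output is reshaped into an initial (h0, c0) for every decoder layer, because nn.LSTM with num_layers=3 expects states of shape (num_layers, batch, hidden_size). generate() is the loop asked about above: each prediction is fed back in as the next input.

```python
import torch
import torch.nn as nn

class Subj2SeqSketch(nn.Module):
    """Hypothetical one-to-many model: subj_info -> initial LSTM state -> operations."""

    def __init__(self, n_subj_features, n_op_features, hidden_size=64, num_layers=3):
        super().__init__()
        self.num_layers, self.hidden_size = num_layers, hidden_size
        # Encoder: subj_info is a single vector, so a Linear layer is enough.
        # It produces one (h0, c0) pair for every decoder layer.
        self.encoder = nn.Linear(n_subj_features, 2 * num_layers * hidden_size)
        self.decoder = nn.LSTM(n_op_features, hidden_size, num_layers=num_layers,
                               batch_first=True)
        self.out = nn.Linear(hidden_size, n_op_features)

    def init_state(self, subj):
        # subj: (batch, n_subj_features) -> h0, c0: (num_layers, batch, hidden_size)
        state = torch.tanh(self.encoder(subj))  # squash into (-1, 1)
        state = state.view(subj.size(0), 2, self.num_layers, self.hidden_size)
        h0 = state[:, 0].transpose(0, 1).contiguous()
        c0 = state[:, 1].transpose(0, 1).contiguous()
        return h0, c0

    def forward(self, subj, ops_in, state=None):
        # ops_in: (batch, seq_len, n_op_features)
        if state is None:
            state = self.init_state(subj)
        out, state = self.decoder(ops_in, state)
        return self.out(out), state


@torch.no_grad()
def generate(model, subj, first_op, n_steps):
    """Autoregressive generation: feed each predicted operation back as the next input."""
    model.eval()
    state = model.init_state(subj)
    x = first_op.unsqueeze(1)                  # (batch, 1, n_op_features)
    steps = []
    for _ in range(n_steps):
        pred, state = model(subj, x, state)    # pred: (batch, 1, n_op_features)
        steps.append(pred)
        # Raw predictions are fed back unchanged here; with mixed feature types you
        # would probably convert them to the training input format (0/1, one-hot) first.
        x = pred
    return torch.cat(steps, dim=1)             # (batch, n_steps, n_op_features)


torch.manual_seed(0)
model = Subj2SeqSketch(n_subj_features=4, n_op_features=4)

# Training (teacher forcing): predict operation t+1 from operations up to t.
subj = torch.rand(2, 4)
ops = torch.rand(2, 14, 4)
pred, _ = model(subj, ops[:, :-1, :])
print(pred.shape)  # torch.Size([2, 13, 4])

# Generation: start from a zero vector (or a real first operation).
seq = generate(model, torch.tensor([[26., 0., 1., 0.]]), torch.zeros(1, 4), n_steps=10)
print(seq.shape)   # torch.Size([1, 10, 4])
```

Using a Linear encoder instead of an encoder LSTM is a design choice, not a requirement: since subj_info has no time dimension, running an LSTM over it for a single step adds little.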


Thank you, I very much appreciate your help!!

Hi @vdw, I’m taking advantage of your kindness since you already went through this thread once. I am now making progress on my project based on the input you gave me, which makes total sense, since including the subject’s personal (demographic) information is essential for my purposes.
I now have some questions which I was hoping you could help me clarify. In particular:

  • Since operation_series is a variable-length sequence whose features are of different kinds (some numerical, such as the amount of the operation or the amount in cash; some categorical, such as the purpose of the operation or the country; and possibly some boolean ones, represented by 0/1 flags), is it possible for the LSTM to learn all of these features? I was thinking of applying different activation functions to the output sequence according to the positions of the features, which I know. For example:
def forward(self, subj, op):

    ...

    output[:,:,:1] = self.sigmoid(output[:,:,:1])     # binary feature

    output[:,:,1:3] = self.relu(output[:,:,1:3])      # 2 numerical features

    output[:,:,3:7] = self.softmax(output[:,:,3:7])   # categorical feature
                                                      # one-hot-encoded in 4 columns

and so on. Does it make sense to apply different activation functions to different columns of the predicted sequence? My reasoning is that the sequences I’m trying to generate combine regression and classification tasks, depending on the nature of each feature.

  • The second question follows from the first: in the training phase, should I use different loss functions for different columns of the output, depending on the type of feature each column contains? For example:
for batch in dataloader:
    subj, op = batch
    subj = subj.to(device)
    op = op.to(device)

    # Teacher forcing: predict operations 1..T from operations 0..T-1
    output, (_, _) = net(subj, op[:, :-1, :])

    optimizer.zero_grad()

    loss1 = loss_bin_ce(output[:, :, :1], op[:, 1:, :1])
    loss2 = loss_mse(output[:, :, 1:3], op[:, 1:, 1:3])
    # CrossEntropyLoss expects (N, C) scores and (N,) class indices
    loss3 = loss_cat_ce(output[:, :, 3:7].reshape(-1, 4),
                        op[:, 1:, 3:7].argmax(dim=-1).reshape(-1))
    loss = loss1 + loss2 + loss3

    loss.backward()
    optimizer.step()
  • To give you the bigger picture, this is the structure I am currently using:
Subj2Seq(
  (encoder): Encoder(
    (rnn): LSTM(num_features_subject, hidden_size, batch_first=True)
                                       # LSTM run once over subj to get
                                       # the initial hidden and cell states
  )
  (decoder): DecoderRNN(
    (lstm): LSTM(num_features_operations, 
                 hidden_size, 
                 num_layers=3, 
                 batch_first=True,  
                 dropout=0.2)
    (out): Linear(in_features=hidden_size, 
                  out_features=num_features_operations, 
                  bias=True)
    (sigmoid): Sigmoid()
    (softmax): Softmax(dim=None)
    (relu): ReLU()
  )
)
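Two PyTorch details matter when combining these losses: nn.BCEWithLogitsLoss and nn.CrossEntropyLoss apply the sigmoid and log-softmax internally, so they expect raw (unactivated) outputs, and nn.CrossEntropyLoss expects the class dimension at dim 1 (or (N, C) scores with (N,) class indices). A minimal sketch, assuming the column layout from the forward() example above (column 0 binary, columns 1-2 numerical, columns 3-6 a one-hot categorical feature) and a model that returns raw outputs; mixed_loss and the slice names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumed column layout, taken from the example above
BIN, NUM, CAT = slice(0, 1), slice(1, 3), slice(3, 7)

bce = nn.BCEWithLogitsLoss()   # sigmoid applied internally
mse = nn.MSELoss()             # regression: no activation on these columns
ce = nn.CrossEntropyLoss()     # log-softmax applied internally

def mixed_loss(output, target):
    """output: raw model outputs (batch, seq_len, 7); target: (batch, seq_len, 7)."""
    loss_bin = bce(output[..., BIN], target[..., BIN])
    loss_num = mse(output[..., NUM], target[..., NUM])
    # Flatten batch and time; turn the one-hot target into class indices.
    n_cls = CAT.stop - CAT.start
    logits = output[..., CAT].reshape(-1, n_cls)
    classes = target[..., CAT].argmax(dim=-1).reshape(-1)
    loss_cat = ce(logits, classes)
    return loss_bin + loss_num + loss_cat


# Usage with random stand-in data
output = torch.randn(2, 13, 7, requires_grad=True)
target = torch.zeros(2, 13, 7)
target[..., 0] = torch.randint(0, 2, (2, 13)).float()
target[..., 1:3] = torch.rand(2, 13, 2)
target[..., 3:7] = F.one_hot(torch.randint(0, 4, (2, 13)), num_classes=4).float()

loss = mixed_loss(output, target)
loss.backward()
```

The three terms are simply summed, as in the loop above; weighting them differently is a tuning choice. With zero-padded variable-length batches, the padded time steps should also be kept out of the loss, for example by computing each loss with reduction='none' and applying a mask.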

One last question concerns the activation function after the LSTM and before the linear layer: should there be one? I was thinking of Tanh, but I don’t understand whether it makes sense theoretically, or whether it should go right after the output.

Sorry for the long post. I hope I’m not abusing your patience (: and that I explained myself clearly enough.

While I’ve never had this use case myself, I can’t immediately see why this wouldn’t work. The alternative would be to have 4 output layers (e.g., out_bin, out_num1, out_num2, out_cat), one for each predicted feature. However, this should yield the same result as having one output layer and separating the different predictions using slicing. I assume your model is training?
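The alternative with one output layer per feature could look like this minimal sketch; the head names follow the answer, while MultiHeadOut, hidden_size=32 and the 4-class categorical feature are assumptions. The check at the end copies the heads’ weights into a single Linear(32, 7) to show that both layouts compute the same outputs, which is why slicing one output layer is equivalent.

```python
import torch
import torch.nn as nn

class MultiHeadOut(nn.Module):
    """Hypothetical sketch: one output layer per feature type, applied to LSTM outputs."""

    def __init__(self, hidden_size, n_classes=4):
        super().__init__()
        self.out_bin = nn.Linear(hidden_size, 1)
        self.out_num1 = nn.Linear(hidden_size, 1)
        self.out_num2 = nn.Linear(hidden_size, 1)
        self.out_cat = nn.Linear(hidden_size, n_classes)

    def forward(self, lstm_out):
        # lstm_out: (batch, seq_len, hidden_size); every head returns raw outputs
        return {
            "bin": self.out_bin(lstm_out),    # (batch, seq_len, 1)
            "num1": self.out_num1(lstm_out),  # (batch, seq_len, 1)
            "num2": self.out_num2(lstm_out),  # (batch, seq_len, 1)
            "cat": self.out_cat(lstm_out),    # (batch, seq_len, n_classes)
        }


heads = MultiHeadOut(hidden_size=32)
h = torch.randn(2, 13, 32)
outs = heads(h)

# Stack the four heads into one Linear(32, 7) and compare the outputs.
layers = [heads.out_bin, heads.out_num1, heads.out_num2, heads.out_cat]
single = nn.Linear(32, 7)
with torch.no_grad():
    single.weight.copy_(torch.cat([l.weight for l in layers], dim=0))
    single.bias.copy_(torch.cat([l.bias for l in layers], dim=0))
    same = torch.allclose(single(h), torch.cat(list(outs.values()), dim=-1), atol=1e-5)
print(same)  # expected: True
```

Separate heads can still be convenient if the feature groups later need different hidden layers, but with one linear layer each the two designs are interchangeable.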

Well, I’m not sure why you use relu for the numerical features. Numerical features mean regression, so you should not use any activation function here. Granted, if you only have positive values, it wouldn’t matter much. Still, it’s bad practice. Remember that the activation functions after the last layer are not there to introduce non-linearities (as in the hidden layers) but to yield suitable output values for the respective loss functions. In short, remove

output[:,:,1:3] = self.relu(output[:,:,1:3])
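If the model is trained on raw outputs (e.g. with BCEWithLogitsLoss and CrossEntropyLoss, which apply the sigmoid and log-softmax internally), the output activations are instead applied when a prediction is turned into an actual operation, for example during generation. A minimal sketch for one time step, assuming the same column layout as in the question; decode_step is a hypothetical name, and greedy decoding (threshold and argmax) is an assumption.

```python
import torch
import torch.nn.functional as F

def decode_step(raw):
    """raw: (batch, 7) unactivated outputs for one time step.
    Assumed layout: col 0 binary, cols 1-2 numerical, cols 3-6 one-hot categorical."""
    flag = (torch.sigmoid(raw[:, 0:1]) > 0.5).float()                   # binary: sigmoid + threshold
    nums = raw[:, 1:3]                                                  # numerical: no activation
    cat = F.one_hot(raw[:, 3:7].argmax(dim=-1), num_classes=4).float()  # categorical: most likely class
    return torch.cat([flag, nums, cat], dim=1)


raw = torch.tensor([[2.0, 0.5, -0.3, 0.1, 3.0, -1.0, 0.0]])
decoded = decode_step(raw)
# flag -> 1, numerical values unchanged, category -> one-hot of class 1
```

Sampling instead of thresholding (torch.bernoulli on the sigmoid, torch.multinomial on the softmax) would give more varied generated sequences.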

There’s no need for an activation function after the LSTM layer as there is an activation function within the LSTM cell.
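To see why: an LSTM cell computes its hidden state as h_t = o_t * tanh(c_t), where o_t is a sigmoid gate, so every value the LSTM outputs is already squashed into [-1, 1]. A quick check, with arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=16, num_layers=3, batch_first=True)
x = torch.randn(8, 20, 4)
out, (h, c) = lstm(x)

# h_t = o_t * tanh(c_t) with o_t in (0, 1), so no output exceeds 1 in magnitude
print(out.abs().max().item() <= 1.0)  # True
# With batch_first=True, the last time step of `out` is the last layer's final hidden state
print(torch.allclose(out[:, -1], h[-1]))  # True
```

An extra Tanh before the output Linear layer would therefore only squash values that are already bounded.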
