Complex modelling help?


I am very new to PyTorch and am working on my first major model construction project… As such, I am a bit lost and could use some guidance.

The model is as follows: there will be one LSTM network per feature, and each network will predict two things: the next value of its feature at time t+1 and the target at time t+1. However, the target prediction at time t+1 will be shared between all networks, whereas each feature prediction is specific to its own network. I have included a drawing to illustrate this. It follows that the loss function will depend on both kinds of prediction, but I am hoping autograd can take care of this: if I just sum the appropriate loss terms, the gradients should flow back through the computation graph to whichever parameters are relevant.

This may be a really simple problem and a trivial question, but I haven’t seen an example of what I’m trying to do nor have I seen an easy way to connect multiple modules together. My intuition tells me that the x_n,t+1 predictions should be linear layers tacked on to the end of each specific module, and that the target prediction node should also be a linear layer but somehow connected to each module.

Thanks for any and all help, and apologies if this question is too simple or has been asked too often.

Do you really want separate LSTM networks for each feature? Why not just have a single LSTM that takes a vector containing (X_1,t, X_2,t, …X_n,t) as input at time t and predicts a vector containing (X_1,t+1, X_2,t+1, …X_n,t+1, target_t+1)?

One advantage of this method is that the LSTM will be able to use the values of all of the features when predicting the next value of each one, rather than only being able to use X_1 at times <= t when predicting X_1,t+1.
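As a rough illustration of the single-LSTM idea, here is a minimal sketch. The class name `JointLSTM`, the hidden size, and the choice to predict from the last time step only are my own assumptions, not something from your description, so treat it as a starting point rather than a finished model.

```python
import torch
import torch.nn as nn


class JointLSTM(nn.Module):
    """Hypothetical example: one LSTM over all features at once."""

    def __init__(self, n_features, hidden_size=32):
        super().__init__()
        self.lstm = nn.LSTM(n_features, hidden_size, batch_first=True)
        # One output per next-step feature value, plus one for the target.
        self.head = nn.Linear(hidden_size, n_features + 1)

    def forward(self, x):
        # x has shape (batch, seq_len, n_features)
        out, _ = self.lstm(x)
        # Assumption: predict from the hidden state of the last time step.
        pred = self.head(out[:, -1])
        next_features = pred[:, :-1]  # (batch, n_features)
        next_target = pred[:, -1:]    # (batch, 1)
        return next_features, next_target


model = JointLSTM(n_features=4)
x = torch.randn(8, 10, 4)  # random stand-in data: batch of 8, 10 time steps
next_features, next_target = model(x)
```

Here every feature prediction can depend on the history of all features, which is the advantage described above.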

If you want to do it your way, you can make one LSTM model per feature that predicts the next feature value (I would suggest that each model also try to predict target_t+1). The outputs of all the LSTM models can then be concatenated and fed into a Linear layer or a simple feed-forward network to produce the combined prediction for target_t+1.
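A minimal sketch of the per-feature variant might look like the following. The names (`PerFeatureLSTM`, `feature_heads`, `target_head`) and sizes are hypothetical, and I am assuming that what gets concatenated is each LSTM's last hidden state; for brevity the per-model target prediction suggested above is left out.

```python
import torch
import torch.nn as nn


class PerFeatureLSTM(nn.Module):
    """Hypothetical example: one LSTM per feature, shared target head."""

    def __init__(self, n_features, hidden_size=16):
        super().__init__()
        # nn.ModuleList registers each sub-module so its parameters are trained.
        self.lstms = nn.ModuleList(
            [nn.LSTM(1, hidden_size, batch_first=True) for _ in range(n_features)]
        )
        self.feature_heads = nn.ModuleList(
            [nn.Linear(hidden_size, 1) for _ in range(n_features)]
        )
        # The shared head sees the concatenated hidden states of all LSTMs.
        self.target_head = nn.Linear(n_features * hidden_size, 1)

    def forward(self, x):
        # x has shape (batch, seq_len, n_features)
        hidden, feats = [], []
        for i, (lstm, head) in enumerate(zip(self.lstms, self.feature_heads)):
            out, _ = lstm(x[:, :, i:i + 1])  # feed only feature i
            h = out[:, -1]                   # last time step, (batch, hidden_size)
            hidden.append(h)
            feats.append(head(h))            # next value of feature i
        next_features = torch.cat(feats, dim=1)                    # (batch, n_features)
        next_target = self.target_head(torch.cat(hidden, dim=1))   # (batch, 1)
        return next_features, next_target


model = PerFeatureLSTM(n_features=3)
x = torch.randn(8, 10, 3)  # random stand-in data
next_features, next_target = model(x)
```

Because everything lives in one `nn.Module`, a single optimizer over `model.parameters()` covers all the LSTMs and both kinds of head.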

For the loss, you just add the individual losses up as you suggest and it works as expected: autograd propagates gradients to every module that contributed to the sum.

global_loss = loss1 + loss2 + ...
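To show that summing works, here is a small self-contained sketch. The two `nn.Linear` branches and the shared layer are hypothetical stand-ins for the per-feature networks and the target layer, and the data is random; the point is only that one `backward()` call on the summed loss fills in gradients for every module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for two per-feature networks and the shared target layer.
branch_a = nn.Linear(3, 1)
branch_b = nn.Linear(3, 1)
shared = nn.Linear(2, 1)

x = torch.randn(5, 3)
y_a = torch.randn(5, 1)
y_b = torch.randn(5, 1)
y_target = torch.randn(5, 1)

out_a = branch_a(x)
out_b = branch_b(x)
out_target = shared(torch.cat([out_a, out_b], dim=1))

mse = nn.MSELoss()
loss_a = mse(out_a, y_a)
loss_b = mse(out_b, y_b)
loss_target = mse(out_target, y_target)

# Summing the losses gives a single scalar that autograd can differentiate.
global_loss = loss_a + loss_b + loss_target
global_loss.backward()

# Every module now has gradients, including both branches via the shared loss.
grads_present = all(
    p.grad is not None
    for m in (branch_a, branch_b, shared)
    for p in m.parameters()
)
```

If some terms matter more than others, you can weight them before summing; the weights are a modelling choice and not something autograd decides for you.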