Hi!
I am implementing a custom model for spatio-temporal graph data and I want it to be based on RNNs (of any kind).
My data is of shape (b, n_nodes, n_timesteps, n_features).
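With data laid out as (b, n_nodes, n_timesteps, n_features), one common arrangement (an assumption, not the only option) is to fold the node dimension into the batch for the recurrent update and unfold it again for the spatial step. The sketch below assumes the spatial operator is a multiplication by a row-normalized adjacency matrix of shape (n_nodes, n_nodes); `normalize_adjacency` and `spatial_operator` are hypothetical names used only for illustration.

```python
import torch


def normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Row-normalize an (n_nodes, n_nodes) adjacency matrix after adding self-loops."""
    adj = adj + torch.eye(adj.size(0), dtype=adj.dtype)
    return adj / adj.sum(dim=1, keepdim=True)


def spatial_operator(h: torch.Tensor, adj_norm: torch.Tensor) -> torch.Tensor:
    """Hypothetical spatial operator: average each node's state over its neighbours.

    h: (b, n_nodes, hidden), adj_norm: (n_nodes, n_nodes) -> (b, n_nodes, hidden)
    """
    return torch.einsum("ij,bjh->bih", adj_norm, h)


b, n_nodes, n_timesteps, n_features = 4, 5, 3, 2
x = torch.randn(b, n_nodes, n_timesteps, n_features)

# Input at time step t = 0 with nodes folded into the batch: (b * n_nodes, n_features)
x_t = x[:, :, 0, :].reshape(b * n_nodes, n_features)

# A hidden state comes back as (b * n_nodes, hidden) and is unfolded before the spatial step
hidden = 8
h_flat = torch.randn(b * n_nodes, hidden)
adj = (torch.rand(n_nodes, n_nodes) > 0.5).float()
h_sp = spatial_operator(h_flat.view(b, n_nodes, hidden), normalize_adjacency(adj))
```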
For example, if I have a vanilla RNN in my model, I would want it to do something along the lines of:
for t in range(T):
    h = spatial_operator(torch.tanh(w_ih @ x[t] + w_hh @ h))  # h_t computed from x_t and h_{t-1}
instead of:
for t in range(T):
    h = torch.tanh(w_ih @ x[t] + w_hh @ h)  # standard RNN update
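The modified recurrence can be written as a single-step cell. This is a minimal sketch, assuming the spatial operator is any callable that maps a hidden state to a tensor of the same shape; `SpatialRNNCell` is a hypothetical name, and the two `nn.Linear` layers play the role of w_ih and w_hh (both with biases, as in torch.nn.RNNCell).

```python
import torch
import torch.nn as nn


class SpatialRNNCell(nn.Module):
    """One step of h_t = spatial_operator(tanh(W_ih x_t + W_hh h_{t-1}))."""

    def __init__(self, input_size, hidden_size, spatial_operator):
        super().__init__()
        self.w_ih = nn.Linear(input_size, hidden_size)
        self.w_hh = nn.Linear(hidden_size, hidden_size)
        self.spatial_operator = spatial_operator

    def forward(self, x_t, h_prev):
        # x_t: (b, n_nodes, input_size), h_prev: (b, n_nodes, hidden_size)
        h_t = torch.tanh(self.w_ih(x_t) + self.w_hh(h_prev))
        # The only change compared with a vanilla RNN cell: apply the spatial operator
        return self.spatial_operator(h_t)


torch.manual_seed(0)
n_nodes = 3
adj_norm = torch.full((n_nodes, n_nodes), 1.0 / n_nodes)  # toy operator: mean over all nodes
cell = SpatialRNNCell(2, 4, lambda h: adj_norm @ h)

x_t = torch.randn(5, n_nodes, 2)
h = torch.zeros(5, n_nodes, 4)
h = cell(x_t, h)
```

With the toy "mean over all nodes" operator every node ends up with the same state, which makes it easy to see that the operator really runs inside each step.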
Is there a way to “insert” the spatial operator in the computation of the hidden state of the existing RNNs? If yes, what do you think is the easiest way to do so?
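One low-effort option (a suggestion, not the only way) is to keep PyTorch's own torch.nn.RNNCell and write the loop over time by hand, so the spatial operator can run between steps. The sketch assumes the (b, n_nodes, n_timesteps, n_features) layout, folds nodes into the batch for the cell, and uses a fixed normalized adjacency matrix as a stand-in spatial operator; `SpatialRNN` and `adj_norm` are hypothetical names.

```python
import torch
import torch.nn as nn


class SpatialRNN(nn.Module):
    """Runs nn.RNNCell over time and applies a graph operator after every step."""

    def __init__(self, n_features, hidden_size, adj_norm):
        super().__init__()
        self.cell = nn.RNNCell(n_features, hidden_size)  # tanh nonlinearity by default
        self.hidden_size = hidden_size
        self.register_buffer("adj_norm", adj_norm)  # (n_nodes, n_nodes), follows .to(device)

    def forward(self, x, h=None):
        # x: (b, n_nodes, n_timesteps, n_features)
        b, n_nodes, n_timesteps, _ = x.shape
        if h is None:
            h = x.new_zeros(b, n_nodes, self.hidden_size)
        outputs = []
        for t in range(n_timesteps):
            x_t = x[:, :, t, :].reshape(b * n_nodes, -1)
            h_flat = self.cell(x_t, h.reshape(b * n_nodes, self.hidden_size))
            # Spatial step: mix hidden states across neighbouring nodes
            h = self.adj_norm @ h_flat.view(b, n_nodes, self.hidden_size)
            outputs.append(h)
        # outputs: (b, n_nodes, n_timesteps, hidden_size); h: last hidden state
        return torch.stack(outputs, dim=2), h


n_nodes = 4
adj = torch.ones(n_nodes, n_nodes)
adj_norm = adj / adj.sum(dim=1, keepdim=True)
model = SpatialRNN(n_features=3, hidden_size=6, adj_norm=adj_norm)
out, h_last = model(torch.randn(2, n_nodes, 5, 3))
```

Because the loop runs in Python, this is typically slower than a single fused torch.nn.RNN call on long sequences, but it reuses PyTorch's cell and needs no custom autograd code. nn.GRUCell has the same call signature and can be swapped in directly.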
Edit: A more detailed explanation and example.
So I want to do the regular hidden-state computation, pass that value through the spatial operator, then move on to the next step and repeat.
Say I have 2 previous time steps (x_0 and x_1) and I want to predict 2 future time steps (x_2 and x_3).
Encoder: I want to generate an aggregate representation of the previous time steps.
_, h_0 = rnn(x_0)  # initial hidden state defaults to zeros
h_0_sp = spatial_operator(h_0)  # applied right after computing the hidden state of the first time step
_, h_1 = rnn(x_1, h_0_sp)
h_1_sp = spatial_operator(h_1)  # same as above; this is the latent representation of the input interval
Decoder: I want to decode the next output states from the latent representation.
_, h_2 = rnn(..., h_1_sp)
h_2_sp = spatial_operator(h_2)
y_hat_2 = torch.sigmoid(mlp(h_2_sp))
_, h_3 = rnn(..., h_2_sp)
h_3_sp = spatial_operator(h_3)
y_hat_3 = torch.sigmoid(mlp(h_3_sp))
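The encoder/decoder steps above can be sketched as one module. The pseudocode leaves the decoder input as "...", so this sketch makes two labeled assumptions: decoding starts from the last observed step, and each prediction is fed back as the next input (which requires the hypothetical `mlp` to output n_features values). A single cell is shared by encoder and decoder, as in the pseudocode; `SpatialSeq2Seq` and `adj_norm` are hypothetical names.

```python
import torch
import torch.nn as nn


class SpatialSeq2Seq(nn.Module):
    """Encoder-decoder where every hidden update is followed by a spatial operator."""

    def __init__(self, n_features, hidden_size, adj_norm):
        super().__init__()
        self.rnn = nn.RNNCell(n_features, hidden_size)  # shared by encoder and decoder
        self.mlp = nn.Linear(hidden_size, n_features)
        self.hidden_size = hidden_size
        self.register_buffer("adj_norm", adj_norm)

    def step(self, x_t, h):
        # One RNN step on (b, n_nodes, .) tensors, followed by the spatial operator
        b, n_nodes, _ = x_t.shape
        h_new = self.rnn(x_t.reshape(b * n_nodes, -1), h.reshape(b * n_nodes, -1))
        return self.adj_norm @ h_new.view(b, n_nodes, self.hidden_size)

    def forward(self, x, horizon):
        # x: (b, n_nodes, n_timesteps, n_features), e.g. x_0 and x_1
        b, n_nodes, n_timesteps, _ = x.shape
        h = x.new_zeros(b, n_nodes, self.hidden_size)
        for t in range(n_timesteps):  # encoder: h_0_sp, h_1_sp, ...
            h = self.step(x[:, :, t, :], h)
        dec_in = x[:, :, -1, :]  # assumption: start decoding from the last observation
        preds = []
        for _ in range(horizon):  # decoder: h_2_sp, h_3_sp, ...
            h = self.step(dec_in, h)
            y_hat = torch.sigmoid(self.mlp(h))
            preds.append(y_hat)
            dec_in = y_hat  # assumption: feed the prediction back in (autoregressive)
        return torch.stack(preds, dim=2)  # (b, n_nodes, horizon, n_features)


n_nodes = 3
adj_norm = torch.eye(n_nodes)  # identity operator, only to check shapes
model = SpatialSeq2Seq(n_features=2, hidden_size=8, adj_norm=adj_norm)
y_hat = model(torch.randn(4, n_nodes, 2, 2), horizon=2)
```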
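The problem is that torch.nn.RNN processes a whole sequence in one call: with batch_first=True it takes an input of shape (N, L, H_in) (the default layout is (L, N, H_in)), so when I call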
_, h = rnn(x)
I can’t apply the spatial operator after each hidden-state computation, because the loop over the temporal dimension runs inside the module, which works like a black box.
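A possible workaround, sketched here under the assumption that per-call overhead is acceptable: the existing torch.nn.RNN module can be called on sequences of length 1, passing the (modified) hidden state back in as the second argument. The check below uses an identity placeholder for the spatial operator and confirms that stepping one time step at a time reproduces the single full-sequence call, so any operator can be inserted between the calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=5, batch_first=True)  # expects (N, L, H_in)
x = torch.randn(2, 4, 3)  # (N, L, H_in)


def spatial_operator(h):
    # Placeholder: identity, so the result can be compared with the full-sequence call
    return h


# Full-sequence call: the loop over time happens inside the module
out_full, h_full = rnn(x)

# Step-by-step: feed one time step at a time and modify h in between
h = torch.zeros(1, x.size(0), rnn.hidden_size)  # (num_layers, N, H_out), not affected by batch_first
outs = []
for t in range(x.size(1)):
    out_t, h = rnn(x[:, t : t + 1, :], h)
    h = spatial_operator(h)
    outs.append(out_t)
out_steps = torch.cat(outs, dim=1)
```

For a real graph operator, the hidden state has shape (num_layers, N, H_out), so with nodes folded into N (N = b * n_nodes) it has to be reshaped to (num_layers, b, n_nodes, H_out) before the spatial step. With num_layers > 1 the operator can only run between time steps, not between layers inside one step.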
Is there anything I can do here short of re-implementing my own RNN? Is there something I can reuse? Would it be a good idea to re-create torch.nn.RNN with an additional operator, or would that be too much hassle and a waste of time?
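Since the model may use any kind of RNN, the same hand-written loop carries over to torch.nn.LSTMCell; the extra decision is whether to apply the spatial operator to the hidden state only or also to the cell state. This sketch exposes that as a flag (a design choice, not a requirement); `spatial_lstm_step` and `mix_cell_state` are hypothetical names, and the adjacency matrix is again a stand-in.

```python
import torch
import torch.nn as nn


def spatial_lstm_step(cell, x_t, state, adj_norm, mix_cell_state=True):
    """One nn.LSTMCell step on (b, n_nodes, .) tensors followed by a graph operator.

    state: (h, c), each of shape (b, n_nodes, hidden_size)
    """
    h, c = state
    b, n_nodes, hidden_size = h.shape
    h_new, c_new = cell(
        x_t.reshape(b * n_nodes, -1),
        (h.reshape(b * n_nodes, hidden_size), c.reshape(b * n_nodes, hidden_size)),
    )
    h_new = adj_norm @ h_new.view(b, n_nodes, hidden_size)
    c_new = c_new.view(b, n_nodes, hidden_size)
    if mix_cell_state:  # design choice: also propagate the cell state over the graph
        c_new = adj_norm @ c_new
    return h_new, c_new


b, n_nodes, n_features, hidden_size = 2, 3, 4, 5
cell = nn.LSTMCell(n_features, hidden_size)
adj_norm = torch.full((n_nodes, n_nodes), 1.0 / n_nodes)  # toy operator: mean over all nodes
state = (torch.zeros(b, n_nodes, hidden_size), torch.zeros(b, n_nodes, hidden_size))
x = torch.randn(b, n_nodes, 6, n_features)
for t in range(x.size(2)):
    state = spatial_lstm_step(cell, x[:, :, t, :], state, adj_norm)
```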