Multiple corresponding batch-inputs to forward-pass

Hello,

I would like to pass two batches of corresponding/related data to the forward-function of my network.

The idea is to have a state representation of shape (batch x channels x width x height) = (32x4x84x84) and a set of previously executed actions of shape (batch x nr_of_prev_actions) = (32x3), where the state X[i,:,:,:] corresponds to the action sequence Y[i,:] for all i.

Whether it is some sequence of actions or any other kind of sequence of numbers doesn’t really matter at this point. I just thought giving it a bit of context makes understanding the problem easier.

At the moment, however, my forward function only handles a batch of 1 element at a time, together with its corresponding sequence of 3 actions. That way I can conveniently match the single batch element X against the single action sequence Y.
My current code looks as follows:

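import torch as t                      # aliases as used in the method below
import torch.nn.functional as funct

def forward(self, observation, previous_actions):
    # Two convolutional layers, each followed by a ReLU
    observation = funct.relu(self.conv1(observation))
    observation = funct.relu(self.conv2(observation))

    # Flatten the conv output; the leading 1 hard-codes a batch size of one
    observation = observation.view(1, self.flattened_size)
    # Prepend the previous actions to the flattened conv output along dim 1
    observation = t.cat((previous_actions, observation), 1)
    observation = funct.relu(self.fc1(observation))
    output = self.fc2(observation)

    return output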

The idea is to concatenate the action sequence Y with the flattened output of the convolutional layers.
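
To make the concatenation step concrete, here is a toy example for the single-element case. The sizes and values are made up for illustration (a flattened output of length 5 rather than the real self.flattened_size):

```python
import torch as t

# Stand-in for the flattened conv output of one batch element (made-up size 5)
flat = t.arange(5, dtype=t.float32).view(1, 5)
# Stand-in for one sequence of 3 previous actions (made-up values)
prev = t.tensor([[7.0, 8.0, 9.0]])

# Concatenate along dim 1, actions first, as in forward()
joined = t.cat((prev, flat), 1)
print(joined.shape)  # torch.Size([1, 8])
print(joined)        # tensor([[7., 8., 9., 0., 1., 2., 3., 4.]])
```

So fc1 would need 3 + flattened_size input features.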

To increase efficiency, I'd like to generalize the approach above and pass real batches of data to the forward pass (not just batches of 1 element at a time).
Any thoughts on how I could do that? I don't really see how the forward pass could automatically match X[i,:,:,:] against Y[i,:] for all i.
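
For reference, here is a minimal sketch of what I imagine the batched version might look like. The class name Net, the layer sizes, and the number of outputs are assumptions made up for this example, not my actual network. The only real change is that view() keeps dimension 0 as the batch dimension instead of hard-coding 1. Since t.cat along dim 1 joins the tensors row by row, row i of the flattened conv output should end up next to row i of previous_actions, as long as both tensors are in the same order along dim 0.

```python
import torch as t
import torch.nn as nn
import torch.nn.functional as funct


class Net(nn.Module):
    # All layer sizes here are assumptions for illustration only.
    def __init__(self, nr_of_prev_actions=3, nr_of_outputs=4):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 16, kernel_size=8, stride=4)   # 84x84 -> 20x20
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)  # 20x20 -> 9x9
        self.flattened_size = 32 * 9 * 9
        self.fc1 = nn.Linear(nr_of_prev_actions + self.flattened_size, 256)
        self.fc2 = nn.Linear(256, nr_of_outputs)

    def forward(self, observation, previous_actions):
        observation = funct.relu(self.conv1(observation))
        observation = funct.relu(self.conv2(observation))
        # Keep dim 0 (the batch) and flatten everything else
        observation = observation.view(observation.size(0), -1)
        # Row i of previous_actions is joined with row i of the conv output
        observation = t.cat((previous_actions, observation), 1)
        observation = funct.relu(self.fc1(observation))
        return self.fc2(observation)


net = Net()
states = t.randn(32, 4, 84, 84)   # batch x channels x width x height
actions = t.randn(32, 3)          # batch x nr_of_prev_actions (float values)
out = net(states, actions)
print(out.shape)  # torch.Size([32, 4])
```

If the actions are stored as integers, they would presumably need a .float() conversion before the concatenation, since t.cat expects matching dtypes here.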

Thanks in advance!

Daniel