I'm writing a module and originally expected it to receive one input at a time, with shape (2, 2, 3, 3). I just realized
that nn.Linear expects a batch dimension, so my module needs to handle batches rather than individual inputs. How does
nn.Linear process batches, and how can I process batches in my forward()? Would I just wrap everything in a loop over the batch elements?
This is what my forward() method looks like:
    def forward(self, pair_of_graphs):
        embeddings = []
        for graph in pair_of_graphs:
            node_matrix, adjacency_matrix = graph
            steps = 5
            for step in range(steps):
                # Message passing: aggregate neighbor features via the adjacency matrix
                message_passed_node_matrix = torch.matmul(adjacency_matrix, node_matrix)
                alpha = torch.nn.Parameter(torch.zeros(1))
                node_matrix = alpha*node_matrix + (1-alpha)*message_passed_node_matrix
                # Apply a linear layer to each node, one node at a time
                new_node_matrix = torch.zeros(len(node_matrix), self.linear_2.in_features)
                for node_i in range(len(node_matrix)):
                    linear_layer = self.linear_1 if step == 0 else self.linear_2
                    new_node_matrix[node_i] = linear_layer(node_matrix[node_i])
                node_matrix = new_node_matrix
            # Weighted sum of the node vectors gives one embedding per graph
            weights_for_average = torch.zeros(len(node_matrix))
            for node_i in range(len(node_matrix)):
                weights_for_average[node_i] = self.linear_3(node_matrix[node_i])
            weighted_sum_node = torch.matmul(weights_for_average, node_matrix)
            embeddings.append(weighted_sum_node)
        concat = torch.cat(embeddings)
        out = self.linear_4(concat)
        return out
Any other suggestions about how I am handling my input (e.g. looping through each node of each graph) would be appreciated too.
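
For reference, here is a minimal sketch of how batching could look, not a definitive rewrite of the method above. It relies on two documented behaviors: nn.Linear acts on the last dimension and treats all leading dimensions as batch dimensions, and torch.matmul broadcasts over leading dimensions. The batch layout (batch, graph, matrix, node, feature), the batch size of 4, and the output width of 8 are assumptions made for illustration; they are not taken from the module in the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed layout: (batch, 2 graphs, 2 matrices, 3 nodes, 3 features),
# i.e. the (2, 2, 3, 3) input from the question with a batch dimension in front.
batch = torch.rand(4, 2, 2, 3, 3)
node_matrices = batch[:, :, 0]       # shape (4, 2, 3, 3)
adjacency_matrices = batch[:, :, 1]  # shape (4, 2, 3, 3)

# Hypothetical layer; the output width of 8 is an arbitrary choice.
linear = nn.Linear(3, 8)

# nn.Linear is applied to the last dimension, so every node of every graph
# of every batch element is transformed in a single call.
batched_out = linear(node_matrices)  # shape (4, 2, 3, 8)

# Equivalent explicit loops over batch elements, graphs, and nodes.
looped_out = torch.stack([
    torch.stack([
        torch.stack([linear(node) for node in graph])
        for graph in pair
    ])
    for pair in node_matrices
])

# torch.matmul multiplies the last two dimensions and broadcasts over the rest,
# so the message-passing step also needs no loop over graphs or batch elements.
messages = torch.matmul(adjacency_matrices, node_matrices)  # shape (4, 2, 3, 3)
```

Under these assumptions, the per-node loops in forward() could be replaced by calling the linear layer on the whole node matrix, and the loop over graphs by indexing along the graph dimension.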