I'm making a module and expected to receive one input of shape (2, 2, 3, 3) at a time. I just realized that nn.Linear expects a batch dimension, so I need to handle batches rather than individual inputs. How does nn.Linear process batches, and how should I process batches in my forward()? Would I just put everything in a loop over the batch elements?
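For reference, nn.Linear applies the same affine map to the last dimension of its input and treats every leading dimension as a batch dimension, so a (batch, nodes, features) tensor can be passed in directly. torch.matmul broadcasts over leading dimensions in the same way. The minimal sketch below checks both against explicit loops; the sizes (4 samples, 3 nodes, 3 input features, 8 output features) are arbitrary assumptions chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, num_nodes, in_features, out_features = 4, 3, 3, 8
linear = nn.Linear(in_features, out_features)

# A batch of node-feature matrices: (batch, nodes, features).
x = torch.randn(batch_size, num_nodes, in_features)

# nn.Linear acts on the last dimension only; the batch and node
# dimensions are treated as independent rows.
y_vectorized = linear(x)  # (batch, nodes, out_features)

# The equivalent explicit double loop over samples and nodes.
y_loop = torch.stack([
    torch.stack([linear(x[b, n]) for n in range(num_nodes)])
    for b in range(batch_size)
])

# torch.matmul also broadcasts over leading dimensions:
# (batch, N, N) @ (batch, N, F) -> (batch, N, F), one product per sample.
adj = torch.rand(batch_size, num_nodes, num_nodes)
messages = torch.matmul(adj, x)
messages_loop = torch.stack([adj[b] @ x[b] for b in range(batch_size)])

print(y_vectorized.shape)  # torch.Size([4, 3, 8])
print(torch.allclose(y_vectorized, y_loop, atol=1e-5))
print(torch.allclose(messages, messages_loop, atol=1e-5))
```

In other words, neither the loop over batch elements nor the loop over nodes is needed for the linear layers or the adjacency multiplication.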
This is what my forward() method looks like:
```python
import torch

def forward(self, pair_of_graphs):
    # pair_of_graphs: (2, 2, 3, 3) -> two graphs, each a (node_matrix, adjacency_matrix) pair
    embeddings = []
    for graph in pair_of_graphs:
        node_matrix, adjacency_matrix = graph
        steps = 5
        for step in range(steps):
            message_passed_node_matrix = torch.matmul(adjacency_matrix, node_matrix)
            # self.alpha = torch.nn.Parameter(torch.zeros(1)) is created once in __init__;
            # a Parameter built here would be re-created on every call and never trained.
            node_matrix = self.alpha * node_matrix + (1 - self.alpha) * message_passed_node_matrix
            linear_layer = self.linear_1 if step == 0 else self.linear_2
            # Size the buffer by the output width of the layer applied at this step.
            new_node_matrix = node_matrix.new_zeros(len(node_matrix), linear_layer.out_features)
            for node_i in range(len(node_matrix)):
                new_node_matrix[node_i] = linear_layer(node_matrix[node_i])
            node_matrix = new_node_matrix
        weights_for_average = node_matrix.new_zeros(len(node_matrix))
        for node_i in range(len(node_matrix)):
            weights_for_average[node_i] = self.linear_3(node_matrix[node_i])
        weighted_sum_node = torch.matmul(weights_for_average, node_matrix)
        embeddings.append(weighted_sum_node)
    concat = torch.cat(embeddings)
    out = self.linear_4(concat)
    return out
```
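A sketch of how the forward() above could be batched, assuming the input grows a leading batch dimension to shape (batch, 2, 2, 3, 3). The class name `GraphPairModel`, the `embed` helper, and the layer sizes (hidden width 16, one output) are hypothetical choices for illustration, not taken from the question. The only remaining Python loops are over the message-passing steps and the two graphs of each pair; every per-node and per-sample loop is replaced by a single batched call.

```python
import torch
from torch import nn


class GraphPairModel(nn.Module):
    """Hypothetical batched version of the forward() above.

    Assumed input shape: (batch, 2 graphs, 2 matrices, N, N), where
    x[:, g, 0] is graph g's node matrix and x[:, g, 1] its adjacency matrix.
    """

    def __init__(self, in_features=3, hidden=16, out_features=1, steps=5):
        super().__init__()
        self.steps = steps
        self.linear_1 = nn.Linear(in_features, hidden)
        self.linear_2 = nn.Linear(hidden, hidden)
        self.linear_3 = nn.Linear(hidden, 1)
        self.linear_4 = nn.Linear(2 * hidden, out_features)
        # Registered once here so an optimizer actually updates it.
        self.alpha = nn.Parameter(torch.zeros(1))

    def embed(self, node_matrix, adjacency_matrix):
        # node_matrix: (batch, N, F), adjacency_matrix: (batch, N, N)
        for step in range(self.steps):
            # Batched matmul: one A @ X per sample.
            messages = torch.matmul(adjacency_matrix, node_matrix)
            node_matrix = self.alpha * node_matrix + (1 - self.alpha) * messages
            linear_layer = self.linear_1 if step == 0 else self.linear_2
            # One call transforms every node of every sample.
            node_matrix = linear_layer(node_matrix)  # (batch, N, hidden)
        weights = self.linear_3(node_matrix)  # (batch, N, 1)
        # Weighted sum over the node dimension.
        return (weights * node_matrix).sum(dim=1)  # (batch, hidden)

    def forward(self, pairs_of_graphs):
        embeddings = [
            self.embed(pairs_of_graphs[:, g, 0], pairs_of_graphs[:, g, 1])
            for g in range(pairs_of_graphs.shape[1])
        ]
        concat = torch.cat(embeddings, dim=-1)  # (batch, 2 * hidden)
        return self.linear_4(concat)


torch.manual_seed(0)
model = GraphPairModel()
batch = torch.rand(8, 2, 2, 3, 3)
out = model(batch)
print(out.shape)  # torch.Size([8, 1])

# Batching should not change a sample's result.
single = model(batch[:1])
print(torch.allclose(out[:1], single, rtol=1e-4, atol=1e-4))
```

The weighted sum `(weights * node_matrix).sum(dim=1)` computes the same quantity as `torch.matmul(weights_for_average, node_matrix)` in the per-graph code, just for every sample at once.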
Any other suggestions about how I am handling my input (e.g. looping through each node of each graph) would be appreciated too.
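One more point worth checking in code like this: an nn.Parameter constructed inside forward() is rebuilt on every call and never registered with the module, so it does not appear in `model.parameters()` and an optimizer never updates it. A learnable mixing weight such as `alpha` belongs in `__init__`. The two toy classes below (`AlphaInForward` and `AlphaInInit`, hypothetical names) show the difference.

```python
import torch
from torch import nn


class AlphaInForward(nn.Module):
    def forward(self, x):
        # Created on every call: never registered, so no optimizer sees it.
        alpha = nn.Parameter(torch.zeros(1))
        return alpha * x


class AlphaInInit(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter to an attribute in __init__ registers it.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.alpha * x


print(len(list(AlphaInForward().parameters())))  # 0
print(len(list(AlphaInInit().parameters())))     # 1

# One SGD step on the registered version actually moves alpha away from 0.
model = AlphaInInit()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = (model(torch.ones(3)) - 1).pow(2).sum()
loss.backward()
opt.step()
print(model.alpha.item())
```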