You’re right, it eventually converges, but I don’t understand how the gradients for a single sample can push it in the wrong direction for multiple steps.
My basic understanding of SGD says that with a single sample, the gradients should be guaranteed to be in the right direction. So how is this happening?
Here is my model:
class Net(torch.nn.Module):
    """Predict a protein-ligand affinity score in (0, 1) from two graphs.

    Each molecule is given as a ``(node_matrix, adjacency_matrix)`` pair.
    Every graph is embedded by repeated "smooth over neighbours, then apply
    a dense layer" rounds, pooled into a single vector with learned
    non-negative node weights, and the two pooled vectors are concatenated
    and passed through a small MLP ending in a sigmoid.
    """

    # Number of smoothing + dense rounds applied to every graph. The first
    # round uses ``linear_1`` (widens the features); all later rounds reuse
    # ``linear_2``, i.e. the same weights are applied to their own output.
    NUM_PROPAGATIONS = 5

    def __init__(self, num_node_features):
        """Create the layers.

        Args:
            num_node_features: size of the feature vector of one input node.
        """
        super().__init__()
        self.num_node_features = num_node_features
        # Learned mixing coefficient between a node's own features and its
        # neighbourhood-smoothed features. NOTE(review): it is initialised
        # from N(0, 1) and never constrained, so it can lie outside [0, 1].
        self.alpha = torch.nn.Parameter(torch.randn(1))
        self.linear_1 = torch.nn.Linear(num_node_features, num_node_features * 2)
        self.linear_2 = torch.nn.Linear(self.linear_1.out_features, self.linear_1.out_features)
        self.linear_3 = torch.nn.Linear(self.linear_2.out_features, 1)
        self.linear_4 = torch.nn.Linear(self.linear_2.out_features * 2, self.linear_2.out_features)
        self.linear_5 = torch.nn.Linear(self.linear_2.out_features, self.linear_1.in_features)
        self.linear_6 = torch.nn.Linear(self.linear_1.in_features, 1)

    def _embed_graph(self, node_matrix, adjacency_matrix):
        """Return one pooled embedding per graph, shape (batch, 1, features).

        Assumes ``node_matrix`` is (batch, num_nodes, num_node_features) and
        ``adjacency_matrix`` is (batch, num_nodes, num_nodes) -- confirm
        against the data loader.
        """
        for step in range(self.NUM_PROPAGATIONS):
            smoothed_node_matrix = torch.matmul(adjacency_matrix, node_matrix)
            node_matrix = self.alpha * node_matrix + (1 - self.alpha) * smoothed_node_matrix
            linear_layer = self.linear_1 if step == 0 else self.linear_2
            # nn.Linear acts on the last dimension, so every node of every
            # graph in the batch is transformed in a single call.
            node_matrix = torch.relu(linear_layer(node_matrix))

        # One non-negative scalar weight per node: (batch, num_nodes).
        aggregation_weights = torch.relu(self.linear_3(node_matrix)).squeeze(-1)
        # Weighted sum of node features *within each graph*. The explicit
        # unsqueeze keeps this a batched (1 x N) @ (N x F) product; a bare
        # matmul of the 2-D weights with the 3-D node tensor would broadcast
        # every sample's weights against every sample's nodes.
        return torch.matmul(aggregation_weights.unsqueeze(1), node_matrix)

    def forward(self, protein_ligand_pair):
        """Score a batch of protein-ligand pairs.

        Args:
            protein_ligand_pair: iterable of ``(node_matrix,
                adjacency_matrix)`` tuples, one per molecule. These are
                batches of matrices; the two molecules may have different
                node counts but must share the batch size.

        Returns:
            1-D tensor with one sigmoid-activated affinity per batch element.
        """
        graph_embeddings = [
            self._embed_graph(node_matrix, adjacency_matrix)
            for node_matrix, adjacency_matrix in protein_ligand_pair
        ]
        protein_ligand_concatenated = torch.cat(graph_embeddings, dim=2)
        hidden = torch.relu(self.linear_4(protein_ligand_concatenated))
        last_hidden_output = torch.relu(self.linear_5(hidden))
        ligand_protein_affinity = torch.sigmoid(self.linear_6(last_hidden_output)).flatten()
        return ligand_protein_affinity
Some context for the `for step in range(propagations)` loop: see “How does applying the same convolutional layer to its own output affect learning?”