Training with batch_size = 1: all outputs are the same and the model trains poorly

I am trying to train a network to output target values between 0 and 1. I cannot batch my inputs, so I am using a batch size of 1. Since I want the gradient of the average loss rather than the sum of the per-example loss gradients, I add item_loss/num_items for each item so that I end up with an average epoch_loss, and I optimize that.
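(In other words, a rough sketch of the intent, not my exact code: accumulating item_loss / num_items over the epoch should give the same gradient as averaging the losses directly.)

    # mean loss computed directly
    epoch_loss = sum(loss_func(model(x), y) for x, y in items) / len(items)
    # vs. accumulated incrementally with batch_size = 1
    epoch_loss = 0.0
    for x, y in items:
        epoch_loss = epoch_loss + loss_func(model(x), y) / len(items)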

But it trains strangely and poorly. To illustrate, I will train with just 1 or 2 examples.

If I train with just 1 example (say y = .25) and the initial output is .15, for some reason the output keeps increasing over multiple epochs, even after it passes the target, e.g. [.15, .18, .21, .225, .24, .256, .275, .31].

Here are training logs from a real example:

Outputs: 0.525374174118042       Targets: 0.7524919509887695     Item loss: 0.05158248543739319
Epoch: 1         Batch: 0        Batch loss: 0.0515825
Outputs: 0.5907765030860901      Targets: 0.7524919509887695     Item loss: 0.026151886209845543
Epoch: 2         Batch: 1        Batch loss: 0.0261519
Outputs: 0.6628296971321106      Targets: 0.7524919509887695     Item loss: 0.008039319887757301
Epoch: 3         Batch: 2        Batch loss: 0.0080393
Outputs: 0.735643744468689       Targets: 0.7524919509887695     Item loss: 0.0002838620566762984
Epoch: 4         Batch: 3        Batch loss: 0.0002839
Outputs: 0.7974085807800293      Targets: 0.7524919509887695     Item loss: 0.0020175036042928696
Epoch: 5         Batch: 4        Batch loss: 0.0020175
Outputs: 0.8372832536697388      Targets: 0.7524919509887695     Item loss: 0.007189564872533083
Epoch: 6         Batch: 5        Batch loss: 0.0071896
Outputs: 0.8560269474983215      Targets: 0.7524919509887695     Item loss: 0.010719495825469494
Epoch: 7         Batch: 6        Batch loss: 0.0107195
Outputs: 0.8599795699119568      Targets: 0.7524919509887695     Item loss: 0.011553588323295116
Epoch: 8         Batch: 7        Batch loss: 0.0115536
Outputs: 0.8537989258766174      Targets: 0.7524919509887695     Item loss: 0.010263103060424328
Epoch: 9         Batch: 8        Batch loss: 0.0102631
Outputs: 0.8402236700057983      Targets: 0.7524919509887695     Item loss: 0.007696854416280985
Epoch: 10        Batch: 9        Batch loss: 0.0076969
Outputs: 0.8212108612060547      Targets: 0.7524919509887695     Item loss: 0.0047222888097167015
Epoch: 11        Batch: 10       Batch loss: 0.0047223
Outputs: 0.7986907958984375      Targets: 0.7524919509887695     Item loss: 0.0021343333646655083
Epoch: 12        Batch: 11       Batch loss: 0.0021343
Outputs: 0.7748352289199829      Targets: 0.7524919509887695     Item loss: 0.0004992220783606172
Epoch: 13        Batch: 12       Batch loss: 0.0004992
Outputs: 0.7519184350967407      Targets: 0.7524919509887695     Item loss: 3.289204641987453e-07
Epoch: 14        Batch: 13       Batch loss: 0.0000003

If I train with just 2 examples (say y_1 = .25, y_2 = .7), for some reason after 1 or 2 epochs the outputs for both examples are always the same, e.g. [(.354, .332), (.36, .36), (.38, .38), (.352, .352)], and I have no idea why. It doesn’t get close to either of the targets; I suspect it is optimizing for their average, (y_1 + y_2)/2.

Here are training logs from a real example:

Outputs: 0.6534302234649658      Targets: 0.12747164070606232    Item loss: 0.13831622898578644 
Outputs: 0.781857967376709       Targets: 0.6774895191192627     Item loss: 0.005446386523544788
Epoch: 1         Batch: 0        Batch loss: 0.1437626                                          
Outputs: 0.49351614713668823     Targets: 0.12747164070606232    Item loss: 0.06699429452419281 
Outputs: 0.49351614713668823     Targets: 0.6774895191192627     Item loss: 0.016923101618885994
Epoch: 2         Batch: 1        Batch loss: 0.0839174                                          
Outputs: 0.4287058711051941      Targets: 0.6774895191192627     Item loss: 0.030946651473641396
Outputs: 0.4287058711051941      Targets: 0.12747164070606232    Item loss: 0.04537103697657585 
Epoch: 3         Batch: 2        Batch loss: 0.0763177                                          
Outputs: 0.37630677223205566     Targets: 0.12747164070606232    Item loss: 0.030959460884332657
Outputs: 0.37630677223205566     Targets: 0.6774895191192627     Item loss: 0.045355524867773056
Epoch: 4         Batch: 3        Batch loss: 0.0763150                                          
Outputs: 0.3388264775276184      Targets: 0.6774895191192627     Item loss: 0.05734632909297943 
Outputs: 0.3388264775276184      Targets: 0.12747164070606232    Item loss: 0.022335434332489967
Epoch: 5         Batch: 4        Batch loss: 0.0796818                                          
Outputs: 0.31596532464027405     Targets: 0.6774895191192627     Item loss: 0.06534986943006516 
Outputs: 0.31596532464027405     Targets: 0.12747164070606232    Item loss: 0.01776493526995182 
Epoch: 6         Batch: 5        Batch loss: 0.0831148                                          
Outputs: 0.3056109547615051      Targets: 0.12747164070606232    Item loss: 0.015866806730628014
Outputs: 0.3056109547615051      Targets: 0.6774895191192627     Item loss: 0.06914683431386948 
Epoch: 7         Batch: 6        Batch loss: 0.0850136                                          
Outputs: 4.601355976774357e-05   Targets: 0.6774895191192627     Item loss: 0.2294648438692093  
Outputs: 0.013865623623132706    Targets: 0.12747164070606232    Item loss: 0.006453163921833038
Epoch: 8         Batch: 7        Batch loss: 0.2359180                                          
Outputs: 0.3054547607898712      Targets: 0.6774895191192627     Item loss: 0.06920493394136429 
Outputs: 0.3054547607898712      Targets: 0.12747164070606232    Item loss: 0.015838995575904846
Epoch: 9         Batch: 8        Batch loss: 0.0850439                                          
Outputs: 0.31345269083976746     Targets: 0.12747164070606232    Item loss: 0.017294475808739662
Outputs: 0.31345269083976746     Targets: 0.6774895191192627     Item loss: 0.0662614032626152  
Epoch: 10        Batch: 9        Batch loss: 0.0835559                                          
Outputs: 0.32778313755989075     Targets: 0.12747164070606232    Item loss: 0.020062347874045372
Outputs: 0.32778313755989075     Targets: 0.6774895191192627     Item loss: 0.06114727631211281 
Epoch: 11        Batch: 10       Batch loss: 0.0812096                                          
Outputs: 0.34695807099342346     Targets: 0.6774895191192627     Item loss: 0.05462551862001419 
Outputs: 0.34695807099342346     Targets: 0.12747164070606232    Item loss: 0.024087145924568176
Epoch: 12        Batch: 11       Batch loss: 0.0787127                                          

How can I fix these problems? What could I be doing wrong?

Here is my training code:

def train(training_dataloader, model, loss_func, optimizer, epochs, log_interval, save_as):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_items = len(training_dataloader.sampler.indices)

    for epoch in range(epochs):
        epoch_loss = 0.0
        for i, batch in enumerate(training_dataloader):
            a = torch.tensor(batch["a"]).float().to(device)
            b = torch.tensor(batch["b"]).float().to(device)
            inputs = [a,b]
            output_values = model(inputs)
            true_values = batch["c"].float().to(device)
            item_loss = loss_func(output_values, true_values) / num_items
            epoch_loss += item_loss

            if i % log_interval == 0:
                print(f"Outputs: {output_values.item()} \t Targets: {true_values.item()} \t Item loss: {item_loss}")

        optimizer.zero_grad()
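        # a single backward/step per epoch, driven by the accumulated epoch_loss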
        epoch_loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            print(f"Epoch: {epoch + 1} \t Batch: {epoch} \t Batch loss: {epoch_loss:.7f}")

I thought it might be this (Outputs from a simple DNN are always the same whatever the input is), but model.state_dict() suggests the weights and biases are all on a similar scale. @ptrblck could you lend a hand?

This doesn’t seem right: the model should be able to overfit a single sample perfectly, and the steadily increasing output points towards a potential bug in the code.

However, your output does converge in the end, and the final loss of 3.289204641987453e-07 looks good.

Could you post your model definition, so that we can have a look?


You’re right, it eventually converges, but I don’t understand how the gradients for a single sample can push it in the wrong direction for multiple steps.

My basic understanding of SGD says that with a single sample, the gradients should be guaranteed to be in the right direction. So how is this happening?

Here is my model:

class Net(torch.nn.Module):
    def __init__(self, num_node_features):
        super().__init__()
        self.num_node_features = num_node_features
        self.alpha = torch.nn.Parameter(torch.randn(1))
        self.linear_1 = torch.nn.Linear(num_node_features, num_node_features * 2)
        self.linear_2 = torch.nn.Linear(self.linear_1.out_features, self.linear_1.out_features)
        self.linear_3 = torch.nn.Linear(self.linear_2.out_features, 1)
        self.linear_4 = torch.nn.Linear(self.linear_2.out_features * 2, self.linear_2.out_features)
        self.linear_5 = torch.nn.Linear(self.linear_2.out_features, self.linear_1.in_features)
        self.linear_6 = torch.nn.Linear(self.linear_1.in_features, 1)

    def forward(self, protein_ligand_pair):
        graph_embeddings = []

        for molecule in protein_ligand_pair:
            #These are actually batches of matrices.
            node_matrix, adjacency_matrix = molecule
            propagations = 5
            for step in range(propagations):
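                # one propagation step: blend each node's own features with its
                # neighbours' (adjacency-smoothed) features, weighted by alpha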
                smoothed_node_matrix = torch.matmul(adjacency_matrix, node_matrix)
                node_matrix = self.alpha*node_matrix + (1-self.alpha)*smoothed_node_matrix

                batch_size, num_nodes, num_input_features = node_matrix.shape
                new_node_matrix = torch.empty(batch_size, num_nodes, self.linear_2.in_features, device = self.alpha.device)
                for node_i in range(num_nodes):
                    linear_layer = self.linear_1 if step == 0 else self.linear_2
                    node_features = node_matrix[:, node_i]
                    new_node_matrix[:, node_i] = torch.nn.ReLU()(linear_layer(node_features))
                node_matrix = new_node_matrix
            
            num_nodes, num_features = node_matrix.shape[1:]
            aggregation_weights = torch.empty(batch_size, num_nodes, device = self.alpha.device)
            for node_i in range(num_nodes):
                node_features = node_matrix[:, node_i]
                aggregation_weights[:, node_i] = torch.nn.ReLU()(self.linear_3(node_features))

            aggregate = torch.matmul(aggregation_weights, node_matrix)
            graph_embeddings.append(aggregate)
        
        protein_ligand_concatenated = torch.cat(graph_embeddings, axis = 2)
        last_hidden_output = torch.nn.Sequential(self.linear_4, torch.nn.ReLU(), self.linear_5, torch.nn.ReLU())(protein_ligand_concatenated)
        ligand_protein_affinity = torch.nn.Sigmoid()(self.linear_6(last_hidden_output)).flatten()

        return ligand_protein_affinity

Some context for the “for step in range(propagations)” loop: How does applying the same convolutional layer to its own output affect learning?

In my opinion, the main reason for this behavior is that you are doing one optimization step per epoch. This way you are optimizing for the best “epoch-level” loss, so the loss will converge, but the weights are not guaranteed to be optimal on a per-sample basis. Doing one optimization step per epoch is like doing “batch gradient descent” with all your data as one batch. The opposite would be doing an optimization step after every sample, which in your case would be “online” or “stochastic” gradient descent. A third option is “mini-batch” gradient descent, which in your case means accumulating gradients over several samples (e.g. 10-20, you decide) and doing one optimization step after that.

As you can see, this is in some sense a way of regularizing your model: when optimizing once per epoch, the model will have a hard time learning the concept of a single sample, but when optimizing after every sample it will be hard for the model to “see the forest for the trees”. If none of this makes sense to you, I would suggest familiarizing yourself with the concepts of batch, stochastic, and mini-batch gradient descent. You can find great explanations of the topic in the popular Andrew Ng Machine Learning or Deep Learning courses.
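For reference, a rough sketch of the gradient-accumulation (“mini-batch”) option with your batch_size = 1 loader, reusing the names from your training loop (accumulation_steps is just an illustrative number):

    accumulation_steps = 10                          # e.g. 10-20 samples per "virtual" mini-batch
    for epoch in range(epochs):
        optimizer.zero_grad()
        for i, batch in enumerate(training_dataloader):
            a = torch.tensor(batch["a"]).float().to(device)
            b = torch.tensor(batch["b"]).float().to(device)
            output_values = model([a, b])
            true_values = batch["c"].float().to(device)
            loss = loss_func(output_values, true_values) / accumulation_steps
            loss.backward()                          # gradients accumulate across the group
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()                     # one update per accumulated group
                optimizer.zero_grad()
        # (leftover samples at the end of the epoch are ignored here for simplicity)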

That doesn’t explain the situation where the entire dataset has exactly 1 sample. In light of @ptrblck’s reply, I’m more confident that something is wrong and it might be a bug.

It could be a bug, sure. But if I understand correctly, the first log is where you train with only one sample, and it converges perfectly. Did you try the mini-batch strategy, training for more epochs, or maybe increasing the learning rate for the batch strategy?

In the logs, the output starts below the target value, then the optimizer pushes the parameters to make the output increase and eventually it passes the target.

At this point, the optimizer should change the parameters so that the output moves in the opposite direction (i.e. decreases). Gradient descent should guarantee this decrease if the dataset has only 1 sample.

But instead, the output just keeps increasing for multiple steps before eventually coming back down and converging. I don’t understand how this could happen with only 1 sample: the optimizer keeps pushing the output up for multiple steps after it has passed the target, and the loss keeps rising the whole time.

Solved the case with only 1 sample! My expectations were based on SGD, but I was using Adam, which can overshoot and oscillate before it converges.
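A toy sketch that reproduces the effect (a single parameter chasing a single target with a squared-error loss; nothing from my real model):

    import torch

    # Adam's running momentum keeps pushing x upward for several steps after it
    # has already passed the target, so the loss rises before it comes back down.
    target = torch.tensor(0.25)
    x = torch.nn.Parameter(torch.tensor(0.15))
    optimizer = torch.optim.Adam([x], lr=0.01)

    for step in range(40):
        optimizer.zero_grad()
        loss = (x - target) ** 2
        loss.backward()
        optimizer.step()
        print(step, round(x.item(), 4))   # overshoots 0.25, then oscillates back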

But why, with 2 samples in the dataset, does the network end up outputting the same strange value for both of them instead of converging to each of their targets? I still don’t understand this.

I cleaned up and modularized my model and solved the case where multiple samples all get the same output. It turned out I was passing the original inputs into a linear layer instead of the convolved inputs.

Since all of the inputs contain similar sets of elements (the elements are just related to each other in different ways), passing the raw inputs to a linear layer and aggregating the per-element outputs results in the same or very similar output for all of them.

When I apply a convolution to each input first, so that each element is updated with information from the other elements it’s connected to, those structural differences show up in the convolved inputs. The linear layer’s outputs then diverge between the convolved inputs, and aggregating those outputs can lead to very different final outputs depending on how the elements were connected.

At least, that’s what I think is going on and it’s alright now.
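To illustrate with a toy sketch (made-up graphs, not my actual model): two graphs with identical node features but different connectivity get identical embeddings if the raw features go straight into a linear layer and are summed, but different embeddings once one propagation step mixes in neighbour information.

    import torch

    # two toy "molecules": same node features, different adjacency
    node_features = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
    adjacency_chain = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
    adjacency_triangle = torch.tensor([[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]])

    linear = torch.nn.Linear(2, 4)

    # raw features straight into the linear layer, summed over nodes:
    # the adjacency never enters, so both graphs get the exact same embedding
    raw_chain = linear(node_features).sum(dim=0)
    raw_triangle = linear(node_features).sum(dim=0)
    print(torch.allclose(raw_chain, raw_triangle))      # True

    # one propagation step (adjacency @ features) before the linear layer:
    # neighbour information now makes the two embeddings differ
    conv_chain = linear(adjacency_chain @ node_features).sum(dim=0)
    conv_triangle = linear(adjacency_triangle @ node_features).sum(dim=0)
    print(torch.allclose(conv_chain, conv_triangle))    # False (almost surely)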