Training time gets slower and slower on CPU

Hi all,

I am facing an issue when training an autoencoder on CPU. I am designing a lab for students that will run on a platform with no GPU; the problem I describe below does not happen on GPU.

After some point, the duration of each epoch starts increasing a lot. I provide a minimal working example to reproduce the issue. At the beginning, one iteration of the loop (get the batch, forward pass, gradient computation, and optimizer step) takes 0.25 seconds; after some time, it can take 2 or 3 times longer. This is what the graph below shows:

I tried many things including:

The issue occurs with “PyTorch 1.10.2 + CUDA 11.3”, with “PyTorch 1.10.2 CPU-only”, and with “PyTorch with CUDA 11.1 on Google Colab”.

Here is a minimal working example to reproduce the issue:

# coding=utf-8

import time
import gc
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    def __init__(self, h=128, e=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, h, 5, 1, 0), nn.ReLU(),
            nn.Conv2d(h, h, 5, 1, 0), nn.ReLU(),
            nn.Conv2d(h, h, 4, 2, 0), nn.ReLU(),
            nn.Conv2d(h, h, 3, 2, 0), nn.ReLU(),
            nn.Conv2d(h, e, 5, 1, 0),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(e, h, 5, 1, 0), nn.ReLU(),
            nn.ConvTranspose2d(h, h, 3, 2, 0), nn.ReLU(),
            nn.ConvTranspose2d(h, h, 4, 2, 0), nn.ReLU(),
            nn.ConvTranspose2d(h, h, 5, 1, 0), nn.ReLU(),
            nn.ConvTranspose2d(h, 1, 5, 1, 0),
        )

    def forward(self, x):
        x = self.decoder(self.encoder(x))
        return x

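model = AutoEncoder(32, 32)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=1e-4)
# Random inputs: 16384 single-channel 32x32 images
train_images = torch.randn(16384, 1, 32, 32)
for epoch in range(40):
    # Minibatches of 64 images, i.e. 256 iterations per epoch
    for i, batch in enumerate(train_images.split(64)):
        tic = time.perf_counter()
        output = model(batch)
        # Squared reconstruction error, averaged over the batch
        loss = 0.5 * (output - batch).pow(2).sum() / batch.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        toc = time.perf_counter()  # time for forward + backward + optimizer step
        # Explicit cleanup; the slowdown still occurs with it
        gc.collect()
        del loss, output, batch

        print(f"epoch {epoch} {i} took {toc-tic:.2f} seconds")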

Again, I have stripped the code down to the minimum needed to show the issue. Does anyone have an idea? As I am designing a lab for students, I would like to avoid hacks like “reload the model every epoch to keep the training fast”.

Thanks for your help.


Could you check whether your workstation is reducing its clock speed, e.g. due to overheating?
Based on your description, I would start by profiling the system and making sure it can run at a high load for an extended period of time.

Thanks a lot for your reply. I have tried to monitor the CPU frequency using psutil.cpu_freq().current, though I am not sure I did it properly.

However, I have run the same experiment on many different platforms, and the behaviour is consistent across them. The reported frequency does not seem to change (but I don’t know whether this is how one checks if a processor is down-clocking). Besides, since the time always starts increasing at minibatch 1500, it would be strange if down-clocking were the cause.
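One way to collect such data is to record psutil's reported frequency right after each timed step. The sketch below is a minimal illustration, not the poster's script: `current_cpu_freq_mhz` and `time_steps` are hypothetical helper names, psutil is treated as an optional dependency, and `None` is recorded wherever the platform cannot report a frequency. A flat reading from `psutil.cpu_freq()` may not reflect short-lived throttling, so treat it as a hint rather than proof.

```python
import time

try:
    import psutil  # optional third-party package, as used in the post
except ImportError:
    psutil = None


def current_cpu_freq_mhz():
    """Return the CPU frequency (MHz) reported by psutil, or None if unavailable."""
    if psutil is None or not hasattr(psutil, "cpu_freq"):
        return None
    try:
        freq = psutil.cpu_freq()
    except Exception:  # some VMs/containers cannot report a frequency at all
        return None
    return None if freq is None else freq.current


def time_steps(step_fn, n_steps):
    """Call step_fn n_steps times and record (duration_s, freq_mhz) per call."""
    records = []
    for _ in range(n_steps):
        tic = time.perf_counter()
        step_fn()  # e.g. forward + backward + optimizer.step()
        toc = time.perf_counter()
        records.append((toc - tic, current_cpu_freq_mhz()))
    return records


if __name__ == "__main__":
    # Hypothetical CPU-bound stand-in for one training step.
    def dummy_step():
        sum(i * i for i in range(100_000))

    for i, (dt, mhz) in enumerate(time_steps(dummy_step, n_steps=5)):
        print(f"step {i}: {dt:.4f} s, reported freq: {mhz} MHz")
```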

On the graph below,

  • the x-axis is the number of minibatches processed (forward/backward/step),
  • the left y-axis is the time one forward/backward/step takes,
  • the right y-axis is the CPU frequency returned by psutil.cpu_freq().current.

This is strangely consistent across platforms: once 1500 minibatches have been processed, the time increases. This is not the case on the AMD processor. On JupyterHub/Kubernetes the trend is less obvious, but the time still increases slightly from step 1500 onward.
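To locate the onset of the slowdown on each platform without eyeballing the plot, the per-step durations (the `toc - tic` values printed by the reproducer) can be scanned for the first sustained run of slow steps. This is a minimal sketch; `find_slowdown_onset` and its defaults (a baseline from the first 100 steps, a 1.5× threshold, 50 consecutive slow steps) are hypothetical choices, not something from the thread.

```python
from statistics import median


def find_slowdown_onset(durations, baseline_steps=100, factor=1.5, window=50):
    """Return the index of the first step that starts a run of `window`
    consecutive steps, each slower than `factor` x the baseline median,
    or None if no such run exists."""
    if len(durations) < baseline_steps + window:
        return None
    # Typical "healthy" step time, taken from the start of training.
    threshold = factor * median(durations[:baseline_steps])
    run = 0  # length of the current run of slow steps
    for i in range(baseline_steps, len(durations)):
        run = run + 1 if durations[i] > threshold else 0
        if run == window:
            return i - window + 1  # first step of the sustained slow run
    return None


if __name__ == "__main__":
    # Synthetic timings shaped like the reported behaviour: 0.25 s per step,
    # then more than 2x slower from step 1500 on.
    fake = [0.25] * 1500 + [0.6] * 500
    print(find_slowdown_onset(fake))  # prints 1500
```

Requiring a sustained run rather than a single slow step keeps occasional spikes (e.g. from other processes) from being reported as the onset.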

I also tried PyTorch 1.1, 1.6, 1.7, 1.8, and 1.9 (not shown here); all show the same trend.

Does this ring any bells? I am less sure about the AMD result, as it was on a server where I don’t control what else is running. I will try to find a laptop with an AMD processor so that I can control what is running on the machine.

Thanks for your help.

Hi,

Could you try adding torch.set_flush_denormal(True) at the top of your script to see if that fixes the issue?
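A minimal sketch of where the call goes and how to check that it took effect, assuming a CPU build of PyTorch. `torch.set_flush_denormal` returns `True` only if the CPU supports flushing denormals (see the PyTorch docs for the supported architectures), and here it is called before any computation, at the top of the script. The check multiplies two normal float32 values whose product, 1e-40, lies below the smallest normal float32 (about 1.18e-38), so without flushing it would be stored as a denormal.

```python
import torch

# Enable flush-to-zero for denormal floats on the CPU; call this before any
# other work. Returns False if the CPU does not support the mode.
supported = torch.set_flush_denormal(True)
print("flush denormal enabled:", supported)

# 1e-30 * 1e-10 = 1e-40 is below the smallest normal float32 (~1.18e-38),
# so the exact result would be a denormal (subnormal) number.
x = torch.tensor([1e-30], dtype=torch.float32)
y = x * 1e-10
print(y)  # flushed to 0 when the mode is enabled, ~1e-40 otherwise

# The rest of the training script (model, optimizer, loop) follows unchanged.
```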


Thanks a lot for your suggestion: torch.set_flush_denormal(True) solved the issue. I ran each experiment multiple times and the results are consistent: the time no longer increases after step 1500.

Thanks a lot.

Great to hear!

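For future reference, this happens when numbers get very close to 0 (so-called denormal, or subnormal, numbers). To keep precision for such values, the CPU has to fall back to slower special-case arithmetic, which slows everything down. torch.set_flush_denormal(True) makes the CPU treat them as zero instead, giving up that tiny extra precision in exchange for speed.
This has happened to me before on “toy” examples where the model converges very well, so that many values end up very close to 0.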
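To make “very close to 0” concrete: a normal float32 value has the form 1.f × 2^(e − 127) with a biased exponent e ≥ 1, so the smallest normal float32 is 2^−126 ≈ 1.18e-38. Values between 0 and that bound are stored with all exponent bits zero as 0.f × 2^−126 (denormals), down to 2^−149 ≈ 1.4e-45. The sketch below uses only the Python standard library and hypothetical helper names to classify values by their float32 bit pattern. With PyTorch tensors, the analogous test compares magnitudes against torch.finfo(torch.float32).tiny, the smallest positive normal float32.

```python
import struct

FLOAT32_MIN_NORMAL = 2.0 ** -126     # smallest positive normal float32
FLOAT32_MIN_SUBNORMAL = 2.0 ** -149  # smallest positive denormal float32


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def is_float32_subnormal(x):
    """True if x, stored as float32, is denormal: exponent bits all zero,
    mantissa bits not all zero."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    return exponent == 0 and mantissa != 0


def subnormal_fraction(values):
    """Fraction of values that become denormals when stored as float32."""
    values = list(values)
    if not values:
        return 0.0
    return sum(is_float32_subnormal(v) for v in values) / len(values)


if __name__ == "__main__":
    for v in [1.0, 1e-30, FLOAT32_MIN_NORMAL, 1e-39, 1e-45, 0.0]:
        print(f"{v:g} -> float32 {to_float32(v):g}, denormal: {is_float32_subnormal(v)}")
    print(subnormal_fraction([1.0, 1e-39, 0.0, 1e-40]))  # prints 0.5
```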
