Execution time on CPU

Hi,

I’m struggling with the CPU version of PyTorch.
I trained my network on GPU, and when I apply it on GPU I have no problem.

I would like to run inference on CPU. My problem is that sometimes I get a response in about 30s, while for the same network, but from a checkpoint taken later in training, it increases to 3 min or more.
I checked the individual layers, and the slowdown seems to come from the fully connected layers, which go from 10s to 2 min to apply.

My problem is not really the absolute execution time, but the difference in time between two checkpoints of the same model.

For information:
PyTorch 0.3.1 (installed with pip)
Python 3.5.2

The part of the code that creates/applies the network is this one, extracted from https://github.com/ZijunDeng/pytorch-semantic-segmentation:


import torch.nn as nn
from torchvision import models

# get_upsampling_weight comes from the same repository; it builds the bilinear
# interpolation kernels used to initialise the transposed convolutions below.

class FCN8s(nn.Module):
    def __init__(self, num_classes):
        super(FCN8s, self).__init__()
        vgg = models.vgg16()
        features, classifier = list(vgg.features.children()), list(vgg.classifier.children())
        features[0].padding = (100, 100)  # large padding so the crops in forward() line up

        # Ceil-mode pooling and in-place ReLUs, as in the original FCN implementation
        for f in features:
            if 'MaxPool' in f.__class__.__name__:
                f.ceil_mode = True
            elif 'ReLU' in f.__class__.__name__:
                f.inplace = True

        # Split VGG16 at pool3 / pool4 / pool5 so the intermediate maps can be scored
        self.features3 = nn.Sequential(*features[: 17])
        self.features4 = nn.Sequential(*features[17: 24])
        self.features5 = nn.Sequential(*features[24:])

        # 1x1 score layers for the skip connections, zero-initialised
        self.score_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)
        self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.score_pool3.weight.data.zero_()
        self.score_pool3.bias.data.zero_()
        self.score_pool4.weight.data.zero_()
        self.score_pool4.bias.data.zero_()

        # Convert VGG's fully connected classifier layers into equivalent convolutions
        fc6 = nn.Conv2d(512, 4096, kernel_size=7)
        fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7))
        fc6.bias.data.copy_(classifier[0].bias.data)
        fc7 = nn.Conv2d(4096, 4096, kernel_size=1)
        fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1))
        fc7.bias.data.copy_(classifier[3].bias.data)
        score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        score_fr.weight.data.zero_()
        score_fr.bias.data.zero_()
        self.score_fr = nn.Sequential(
            fc6, nn.ReLU(inplace=True), nn.Dropout(), fc7, nn.ReLU(inplace=True), nn.Dropout(), score_fr
        )

        # Learned upsampling layers, initialised with fixed bilinear kernels
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
        self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, bias=False)
        self.upscore2.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4))
        self.upscore_pool4.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4))
        self.upscore8.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 16))

    def forward(self, x):
        x_size = x.size()
        pool3 = self.features3(x)
        pool4 = self.features4(pool3)
        pool5 = self.features5(pool4)

        score_fr = self.score_fr(pool5)
        upscore2 = self.upscore2(score_fr)

        # Skip connection from pool4, cropped to align with upscore2
        score_pool4 = self.score_pool4(0.01 * pool4)
        upscore_pool4 = self.upscore_pool4(
            score_pool4[:, :, 5: (5 + upscore2.size()[2]), 5: (5 + upscore2.size()[3])]
            + upscore2)

        # Skip connection from pool3, cropped to align with upscore_pool4
        score_pool3 = self.score_pool3(0.0001 * pool3)
        upscore8 = self.upscore8(
            score_pool3[:, :, 9: (9 + upscore_pool4.size()[2]), 9: (9 + upscore_pool4.size()[3])]
            + upscore_pool4)
        # Crop the final upsampled map back to the input size
        return upscore8[:, :, 31: (31 + x_size[2]), 31: (31 + x_size[3])].contiguous()
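
For completeness, this is roughly how I apply it at inference time (a sketch; the checkpoint path and num_classes=21 are placeholders, not my real values):

import torch
from torch.autograd import Variable

model = FCN8s(num_classes=21)
# Load a checkpoint trained on GPU onto the CPU (PyTorch 0.3.x idiom)
model.load_state_dict(torch.load('checkpoint.pth', map_location=lambda storage, loc: storage))
model.eval()

image = Variable(torch.randn(1, 3, 500, 500), volatile=True)  # dummy input
output = model(image)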

Is there something wrong with the way I apply the network?
Thanks

I’m not sure at all, but slicing your data a bit everywhere may slow down your execution, especially if the data is spread across your GPU.

Thanks for your reply,

If you’re talking about the slices in the __init__, I think it’s just a way to define my operators and doesn’t affect the computation (but I may be wrong).
In the forward pass it may slow things down, but I didn’t see problems in that part of the network.
I really have problems with self.features3 and self.score_fr. On CPU, they go from a couple of seconds to several minutes.
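
For reference, I time each stage roughly like this (a minimal sketch; model and x stand for my network and an input batch defined elsewhere; plain wall-clock timing is fine on CPU, where no synchronisation is needed):

import time

def time_stage(module, inp, label):
    # Run one submodule and print how long it took
    start = time.time()
    out = module(inp)
    print('%s: %.2fs' % (label, time.time() - start))
    return out

pool3 = time_stage(model.features3, x, 'features3')
pool4 = time_stage(model.features4, pool3, 'features4')
pool5 = time_stage(model.features5, pool4, 'features5')
score_fr = time_stage(model.score_fr, pool5, 'score_fr')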

I don’t know then, they seem fine to me. Maybe someone else will find the issue :confused: Are you using eval and train modes for your model?

Yes, while testing it’s in eval mode.

I think it’s something related to my installation. I’ll try different modifications, since I’m working with Docker too.

All the methods using copy_ take some time to execute, and on CPU this might be slow. The nn.Sequential() modules are also slower than just executing the layers directly in the forward pass; I think this is due to some overhead created when the Sequential module is executed.
The extra time might come from memory access: running on the CPU is slow because the copy_ calls store a lot of data in RAM, and fetching data from RAM is slow. On the GPU you don’t notice it, because VRAM is much faster than RAM.

These are just some quick conclusions I drew from looking at your code, and I might be wrong. This comes on top of the fact that training on CPU is much slower than on GPU, due to the data parallelism that only a GPU can achieve.

Thanks, I’ll look into what you said.

I just want to do inference on CPU, not training. But what you said seems coherent for inference too.

Oh, I didn’t read that, sorry. Then what you really need to do is set the input Variable (or any parameter of the model) as volatile, which makes sure the least possible amount of memory is used. This is meant for inference mode, since you never need to call backward(). You can use it like this:

import torch
from torch.autograd import Variable

input = Variable(torch.randn(1, 3, 227, 227), volatile=True)
output = model(input)

Something like this might improve your speed considerably, although what I said before still applies.

Hi,
My variables were already volatile.

I’m going to try removing the Sequential wrappers, which are only there to simplify the code, and call the layers directly, as in the sketch below.
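
A minimal sketch of what I mean for the score_fr head (the class name ScoreHead is a placeholder, not my real code):

import torch.nn as nn
import torch.nn.functional as F

class ScoreHead(nn.Module):
    # Same layers as the score_fr Sequential above, but called directly in forward()
    def __init__(self, num_classes):
        super(ScoreHead, self).__init__()
        self.fc6 = nn.Conv2d(512, 4096, kernel_size=7)
        self.fc7 = nn.Conv2d(4096, 4096, kernel_size=1)
        self.score = nn.Conv2d(4096, num_classes, kernel_size=1)

    def forward(self, x):
        x = F.dropout(F.relu(self.fc6(x), inplace=True), training=self.training)
        x = F.dropout(F.relu(self.fc7(x), inplace=True), training=self.training)
        return self.score(x)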

The time spent is a problem too, but it’s related to this issue on GitHub: https://github.com/pytorch/pytorch/issues/4703

My first problem is still the same: when using the same data and the same network architecture with two different checkpoints, my runtime changes a lot.

I’m trying different changes related to the limits of my CPU, like using cropped images or a smaller network. The problem is still the same for now.

I had the same problem: the model after 4 iterations took 10 times longer to run inference on CPU than the model after 2 iterations. I found that in the convolution layers of all residual blocks the weights had become extremely small, in my case around 10^-27, and as a result they took much longer to compute. They got that small because I had trained for a very long time. So I just reverted to the previous model, trained for a smaller number of epochs, and it worked.
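
A quick way to check for this (a sketch; the 1e-20 threshold is an arbitrary choice of mine that catches weights like the ~10^-27 ones above, while actual float32 denormals only start below about 1.2e-38):

def report_tiny_weights(model, threshold=1e-20):
    # Count non-zero parameter values with a very small magnitude
    for name, param in model.named_parameters():
        mags = param.data.abs()
        tiny = ((mags < threshold) * (mags > 0)).sum()
        if tiny > 0:
            print(name, ':', int(tiny), 'values below', threshold)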

If you are running into performance issues with these small numbers, you might try to use torch.set_flush_denormal(True) to disable denormal floating point numbers on the CPU.
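
For example (a minimal check; set_flush_denormal returns True only when the CPU supports flushing, which requires SSE3 on x86):

import torch

if torch.set_flush_denormal(True):
    # Denormal results are now flushed to exact zeros in CPU ops
    print(torch.DoubleTensor([1e-323]).sum())  # 0.0 instead of a denormal value
else:
    print('flush-denormal is not supported on this CPU')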