Not achieving a significant time improvement with GPU compared to CPU

Hi,

I am training a 7-layer fully connected network with 48 neurons in each hidden layer (which gives 14353 learnable parameters). My data has 3 input features and 1 output, and the dataset size is around 51230. I am using a DataLoader with 20 batches. However, moving from CPU to GPU only reduces training time by 30-40%. After experimenting, I have noticed that the GPU only gives a significant speed-up if the total number of learnable parameters is increased substantially, say to the order of millions (then training time drops by roughly a factor of 7). Can we not get a significant benefit from the GPU for an NN model with 14353 parameters?
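(For reference, the 8 Linear layers in the code below with H = 48 give (3·48 + 48) + 6·(48·48 + 48) + (48·1 + 1) = 192 + 14112 + 49 = 14353 parameters; the same architecture with H = 768, as in the posted code, comes to about 3.5 million.)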

Overall, if I train it for 200 epochs, these are the time comparisons:
For 14k model parameters:
CPU: 5.7 min
GPU: 4.1 min

For 3.54 million model parameters:
CPU: 31.0 min
GPU: 4.27 min

Is there any other way I can reduce my training time for the model with around 14k parameters?

If you are working with a small dataset, you could preload the whole dataset and push it to the GPU before training, to avoid loading times. Also, make sure not to create any unnecessary synchronization points in your training loop, e.g. printing the loss often.
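A minimal sketch of what this could look like (assuming dataX and dataY are CPU tensors holding the full dataset; the batch size is just illustrative):

import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device('cuda')

# one-time host-to-device copy before training; the batches drawn from these
# tensors then already live on the GPU
dataX = dataX.to(device)
dataY = dataY.to(device)
train_loader = DataLoader(TensorDataset(dataX, dataY), batch_size=1024, shuffle=True)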

Since the whole dataset is transferred to the GPU before creating the dataset and train loader, I think there are no loading-time issues. pin_memory wouldn't work for data that is already on the GPU, if that's what you're suggesting. I have also tried to minimize loss printing. Here are the major parts of the code for clarity.

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader


class Net(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Net, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.linear3 = torch.nn.Linear(H, H)
        self.linear4 = torch.nn.Linear(H, H)
        self.linear5 = torch.nn.Linear(H, H)
        self.linear6 = torch.nn.Linear(H, H)
        self.linear7 = torch.nn.Linear(H, H)
        self.linear8 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        out = F.relu(self.linear1(x))
        out = F.relu(self.linear2(out))
        out = F.relu(self.linear3(out))
        out = F.relu(self.linear4(out))
        out = F.relu(self.linear5(out))
        out = F.relu(self.linear6(out))
        out = F.relu(self.linear7(out))
        y_pred = self.linear8(out)
        return y_pred

D_in, H, D_out = 3, 768, 1
model = Net(D_in, H, D_out)
criterion = nn.MSELoss(reduction='sum')
optimizer = Adam(model.parameters(), lr=5e-4)

device = torch.device('cuda')
model.to(device)

dataX = dataX.to(device)
dataY = dataY.to(device)
dataset = TensorDataset(dataX, dataY)
training_batches = 20
batch_size_train = int(len(dataX) / training_batches) + 1
train_loader = DataLoader(dataset, batch_size=batch_size_train, shuffle=True)

start_time = time.time()
for epoch in range(201):
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        features, target = data
        optimizer.zero_grad()
        forward = model(features)
        loss = criterion(forward, target)
        if epoch % 100 == 0:
            # loss.item() synchronizes with the GPU, so only accumulate on logging epochs
            running_loss += loss.item()
        loss.backward()
        optimizer.step()
    if epoch % 100 == 0:
        print('Epoch: {},    Training Loss: {:.2e}'.format(epoch, running_loss/training_batches))

elapsed = time.time() - start_time
print('GPU Time: {:.2f} min'.format(elapsed/60))

I’ve added torch.cuda.synchronize() before starting and stopping the timer and used some dummy data:

dataX = torch.randn(1000, 3).to(device)
dataY = torch.randn(1000, 1).to(device)

CPU: 0.44854 seconds/epoch
GPU (TitanV): 0.0877 seconds/epoch
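For reference, a minimal sketch of this timing pattern (synchronize before reading the clock, since CUDA kernels run asynchronously; train_one_epoch() is just a placeholder for the loop shown above):

import time
import torch

torch.cuda.synchronize()      # wait for pending GPU work before starting the timer
start_time = time.time()

for epoch in range(200):
    train_one_epoch()         # placeholder for the training loop shown above

torch.cuda.synchronize()      # make sure all GPU work has finished before stopping the timer
elapsed = time.time() - start_time
print('Time: {:.5f} s/epoch'.format(elapsed / 200))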

Based on your estimate (4.1 minutes for 200 epochs), it seems that each epoch takes (4.1*60)/200 = 1.22 seconds on your GPU.

Which GPU, CUDA, cudnn, PyTorch versions are you using?

I have tried the same dummy data with 1000 samples, and the GPU I am using does it in 0.117 seconds/epoch.

GPU: Tesla P100-SXM2-16GB
CUDA driver version: 10010
cudnn: 7.5.1_10.1
PyTorch: 1.1.0

Could the issue be with the data I am loading from files? I am converting it from a pandas DataFrame to a FloatTensor using the command

dataX = torch.tensor(dataX.values).to(dtype=torch.float)

The data is available in a .csv text file. The numbers are in this format:

2.489500679193377281e-03,0.000000000000000000e+00,3.944573057764559962e+02,1.833216381585000068e-02
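The loading path is roughly the following (the file name and the split into 3 feature columns + 1 target column are illustrative):

import pandas as pd
import torch

df = pd.read_csv('data.csv', header=None)                            # 4 comma-separated float columns per row
dataX = torch.tensor(df.iloc[:, :3].values).to(dtype=torch.float)    # 3 input features
dataY = torch.tensor(df.iloc[:, 3:].values).to(dtype=torch.float)    # 1 target column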

Thanks.

This shouldn’t make a speed difference, but could you try to use torch.from_numpy instead of torch.tensor to create dataX?
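For example, mirroring the line above (assuming dataX is still the pandas DataFrame at that point):

# torch.from_numpy shares memory with the underlying NumPy array;
# the dtype conversion afterwards still produces a float32 copy
dataX = torch.from_numpy(dataX.values).to(dtype=torch.float)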

torch.from_numpy didn't make any difference. Exactly the same time as before.

If I understand the issue correctly at the moment, you are seeing a time of 0.117 s/epoch using random data on the GPU and 1.22 s/epoch if you use your real data?

We are aware of denormal values, which might slow down the execution on the CPU, but this shouldn’t be the case on the GPU.
Could you nevertheless set torch.set_flush_denormal(True) and time it again?
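A minimal sketch of where the call would go (once, before the training loop):

import torch

# flush denormal floats to zero on supported CPUs; returns False if unsupported
torch.set_flush_denormal(True)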

0.117 s/epoch was the time for 1000 samples of random dummy data, while my real data size is 51320, and 1.22 s/epoch corresponds to that.

My concern is the comparison of CPU and GPU time for this real data of size 51320, for which I am not getting a significant reduction in training time with the GPU. (Also, if I increase the size of the random dummy data to 51320, the times are almost the same.)

For 14k model parameters: CPU: 5.7 min, GPU: 4.1 min

I am not sure if I am doing everything I can to get the maximum benefit from the GPU.
Shouldn't the GPU benefit be significant even for a model with only 14k parameters?

Unfortunately, torch.set_flush_denormal(True) didn't have any effect.

If you change the data size to 51320, the batch size will increase to 2566, as you are calculating it dynamically using:

training_batches = 20
batch_size_train = int(len(dataX)/training_batches)

Running the script with this batch size gives:
GPU: 0.44s/epoch
CPU: 3.27s/epoch

Thanks a lot @ptrblck_de for your help throughout the discussion.

Here's an enigma. Although the Tesla P100 GPU gives better times for a data size of 1000, the training time increases rapidly just by increasing the data size; the corresponding times are listed below:

1k  :  0.117 s/epoch
11k :  0.327 s/epoch
21k :  0.546 s/epoch
31k :  0.762 s/epoch
41k :  0.981 s/epoch
51k :  1.206 s/epoch

Do you understand why the increase in time differs so much between the two GPUs?

I also tried code from a research paper written with TensorFlow. Having made no changes to it, training time on the Tesla P100 was 6.3 min, while the paper mentions that an NVIDIA Titan X did it in 1 min. I can't figure out a reason, since I think the P100 is more powerful.

I can rerun these tests tomorrow on a few GPUs and report some numbers.
In the meantime, could you update PyTorch to the latest stable release (1.2.0) so that we get comparable numbers?

I am unable to update it to version 1.2.0. I have tried updating it using these commands:

conda update pytorch
conda install pytorch=1.2.0 -c pytorch

I am running it on the university's computing resources, where I access the GPU remotely, and I think I am not allowed to update packages freely.

If you can run Docker, you could try to run the code inside a container.

Here are my results:
P100, 16GB

N = 1000
GPU Time: 0.076056s/epoch
N = 11000
GPU Time: 0.148172s/epoch
N = 21000
GPU Time: 0.220223s/epoch
N = 31000
GPU Time: 0.299151s/epoch
N = 41000
GPU Time: 0.368989s/epoch
N = 51000
GPU Time: 0.448879s/epoch

V100, 16GB

N = 1000
GPU Time: 0.081930s/epoch
N = 11000
GPU Time: 0.147144s/epoch
N = 21000
GPU Time: 0.254728s/epoch
N = 31000
GPU Time: 0.282772s/epoch
N = 41000
GPU Time: 0.382091s/epoch
N = 51000
GPU Time: 0.440113s/epoch

Note that I just executed the code once for each config.

So that’s a concern. Your results are as expected. I am not sure why they aren’t like that for me. What’s the clock speed on your GPUs?

On an unrelated note, I think it's worth mentioning that this many linear layers in sequence is not advisable in most cases. You will probably be better off using two linear layers (in total) and playing with the number of neurons in the first linear layer.
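For example, something along these lines (a sketch only; the hidden size is an arbitrary starting point to tune):

import torch
import torch.nn.functional as F


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


model = TwoLayerNet(3, 128, 1)   # try different hidden sizes here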

If you want more sensible comparisons, I'd fix the batch size at the same value regardless of the dataset size. I'm not sure why you'd want a fixed number of batches per epoch. Not doing this will result in varying GPU utilization across different dataset sizes.

I wouldn't be surprised if the performance, even on the GPU, is CPU-bound due to Dataset/DataLoader Python overhead. Calling __getitem__ per batch element for many thousands of elements chews up a lot of CPU time when done in Python. I'd try building a random index and assembling the batches from the X and Y tensors manually.
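A rough sketch of that idea, combined with a fixed batch size (assumes model, criterion, optimizer, dataX and dataY already exist as above, with the data on the GPU; the batch size of 512 is arbitrary):

import torch

batch_size = 512                       # fixed, independent of the dataset size
for epoch in range(200):
    perm = torch.randperm(dataX.size(0), device=dataX.device)   # fresh shuffle each epoch
    for i in range(0, dataX.size(0), batch_size):
        idx = perm[i:i + batch_size]
        features, target = dataX[idx], dataY[idx]   # pure tensor indexing, no per-sample __getitem__
        optimizer.zero_grad()
        loss = criterion(model(features), target)
        loss.backward()
        optimizer.step()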

I have worked with the administrators of the computing resources, and it turns out that the P100 GPU I was using wasn't performing optimally for some reason. The same GPU on another system gave performance comparable to what you reported.

Many thanks for your time and support.


It doesn't make much difference, except that increasing the number of batches for the same data size increases the computing time. Also, the whole dataset is already on the GPU before training starts.