Keras trains significantly faster than PyTorch for simple network

@kpedro88
Why don’t you drop the model into Lightning and try that? (I added the few lines needed to make it work.)
Lightning handles a lot of these performance optimizations, such as the syncing issue you ran into.

I converted it and ran some quick profiling (all of this took about 10 minutes end to end).
On CPU I was able to get it about 2.3x faster… try it on the GPU?

Lightning version

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split

# import PL
import pytorch_lightning as pl
pl.seed_everything(0)


class Classifier(pl.LightningModule):

    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion
 
    def training_step(self, data, batch_idx):
        # train loop
        event_data, target_data = data
        output = self.model(event_data)
        batch_loss = self.criterion(output, target_data)
        self.log('train_loss', batch_loss)
        return batch_loss
    
    def validation_step(self, data, batch_idx):
        # val loop
        event_data, target_data = data
        output = self.model(event_data)
        batch_loss = self.criterion(output, target_data)
        self.log('val_loss', batch_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        return optimizer

# -------------------
# generate random data
# -------------------
N_train_class0 = N_train_class1 = 1_250_000
event_dim = 8

rng_seed = 0
rng = np.random.Generator(np.random.PCG64(rng_seed))

class0_events = rng.uniform(100, 500, size=(N_train_class0, event_dim))
class0_ytarget = np.zeros(shape=(N_train_class0, 1))

class1_events = rng.uniform(100, 500, size=(N_train_class1, event_dim))
class1_ytarget = np.ones(shape=(N_train_class1, 1))  # class 1 targets are ones

permutation = rng.permutation(N_train_class0 + N_train_class1)
dataset = TensorDataset(
    torch.Tensor(np.concatenate([class0_events, class1_events])[permutation]),
    torch.Tensor(np.concatenate([class0_ytarget, class1_ytarget])[permutation])
)

# prepare for training
pct_valid = 0.2
num_valid = int(len(dataset)*pct_valid)
split_train, split_valid = random_split(dataset, [len(dataset)-num_valid, num_valid])
batch_size = 5000
loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=True)  # shuffle=True on a validation loader is a mistake; Lightning flags it (see Bonus 2)


# --------------
# model
# --------------
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.0)

hidden_node_counts = [128, 64, 64, 64, 32]
layers = [
    nn.Linear(in_features = event_dim, out_features = hidden_node_counts[0]),
    nn.ReLU()
]
for counter in range(len(hidden_node_counts)-1):
    layers.extend([
        nn.Linear(in_features = hidden_node_counts[counter], out_features = hidden_node_counts[counter+1]),
        nn.ReLU()
    ])
layers.extend([
    nn.Linear(in_features = hidden_node_counts[-1], out_features = 1),
    nn.Sigmoid()
])
model = nn.Sequential(*layers)
model.apply(init_weights)  # apply the Xavier init defined above

print(model)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable params: {}".format(total_params))


criterion = nn.BCELoss()

# -------------
# TRAIN with Lightning
# -------------
classifier = Classifier(model, criterion)
trainer = pl.Trainer(max_epochs=10)  # pass gpus=1 to train on GPU
trainer.fit(classifier, loader_train, loader_valid)

I ran it on CPUs, but you can run on GPUs by passing gpus=1 to the Trainer.
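
A minimal sketch of the GPU variant (assuming a CUDA-capable device is available):

trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(classifier, loader_train, loader_valid)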

I think your bottleneck also comes from the dataloader… Lightning gives you useful warnings when it detects issues:

/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
  warnings.warn(*args, **kwargs)
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 1:  61%|███████████████████████████████████████████████████▋                                 | 304/500 [00:23<00:15, 12.85it/s, loss=8.94e-12, v_num=4

Profiling

You can also run with profiling to see where the speed issues come from.

(Profiling itself might slow things down a bit, though…)

Trainer(profiler='pytorch')
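
In the script above, that’s a single extra argument to the Trainer (a minimal sketch; the report below comes from a run configured this way):

trainer = pl.Trainer(max_epochs=10, profiler='pytorch')
trainer.fit(classifier, loader_train, loader_valid)
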
---------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          AddmmBackward        30.23%     763.912ms        30.23%     763.912ms       1.900ms           402  
                               aten::mm        29.64%     749.083ms        29.64%     749.083ms       1.016ms           737  
                            aten::addmm        16.95%     428.369ms        16.95%     428.369ms       1.066ms           402  
                             aten::relu         4.23%     106.824ms         4.23%     106.824ms     318.878us           335  
                        aten::threshold         4.02%     101.632ms         4.02%     101.632ms     303.378us           335  
                          ReluBackward0         3.97%     100.330ms         3.97%     100.330ms     299.493us           335  
               aten::threshold_backward         3.89%      98.345ms         3.89%      98.345ms     293.566us           335  
                              aten::sum         1.61%      40.591ms         1.61%      40.591ms      86.548us           469  
                            aten::copy_         1.47%      37.197ms         1.47%      37.197ms      46.265us           804  
                                aten::t         0.78%      19.632ms         0.78%      19.632ms      10.104us          1943  
             aten::binary_cross_entropy         0.36%       9.013ms         0.36%       9.013ms     134.518us            67  
        torch::autograd::AccumulateGrad         0.32%       8.046ms         0.32%       8.046ms      10.008us           804  
                        aten::transpose         0.30%       7.688ms         0.30%       7.688ms       3.957us          1943  
                            aten::empty         0.26%       6.545ms         0.26%       6.545ms       2.505us          2613  
                          aten::sigmoid         0.22%       5.542ms         0.22%       5.542ms      41.361us           134  
                             aten::add_         0.20%       5.042ms         0.20%       5.042ms       6.367us           792  
                       aten::as_strided         0.17%       4.377ms         0.17%       4.377ms       1.519us          2881  
                              TBackward         0.14%       3.533ms         0.14%       3.533ms       8.788us           402  
                             aten::view         0.14%       3.463ms         0.14%       3.463ms       8.613us           402  
                             aten::mean         0.11%       2.833ms         0.11%       2.833ms      42.280us            67  
---------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.527s

Or try any of the other profilers… here we use the “simple” profiler

Trainer(profiler='simple')
Profiler Report

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  37.739               |  100 %                |
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  37.493               |1                      |  37.493               |  99.348               |
get_train_batch                         |  0.053623             |400                    |  21.449               |  56.836               |
run_training_batch                      |  0.024157             |400                    |  9.6627               |  25.604               |
optimizer_step_and_closure_0            |  0.023673             |400                    |  9.4692               |  25.091               |
training_step_and_backward              |  0.022778             |400                    |  9.1112               |  24.143               |
model_backward                          |  0.013783             |400                    |  5.5132               |  14.609               |
model_forward                           |  0.0088057            |400                    |  3.5223               |  9.3333               |
training_step                           |  0.0085555            |400                    |  3.4222               |  9.068                |
evaluation_step_and_end                 |  0.0089042            |102                    |  0.90823              |  2.4066               |
validation_step                         |  0.0087426            |102                    |  0.89174              |  2.3629               |
on_train_batch_end                      |  0.00036939           |400                    |  0.14776              |  0.39152              |
cache_result                            |  1.859e-05            |2323                   |  0.043184             |  0.11443              |
on_validation_batch_end                 |  0.0002122            |102                    |  0.021644             |  0.057353             |
on_batch_start                          |  3.5443e-05           |400                    |  0.014177             |  0.037566             |
on_after_backward                       |  1.7219e-05           |400                    |  0.0068876            |  0.018251             |
on_batch_end                            |  1.3641e-05           |400                    |  0.0054563            |  0.014458             |
on_before_zero_grad                     |  1.2449e-05           |400                    |  0.0049796            |  0.013195             |
on_train_batch_start                    |  9.6932e-06           |400                    |  0.0038773            |  0.010274             |
training_step_end                       |  8.5934e-06           |400                    |  0.0034374            |  0.0091083            |
on_validation_end                       |  0.0016228            |2                      |  0.0032456            |  0.0086001            |
on_validation_batch_start               |  2.1877e-05           |102                    |  0.0022315            |  0.005913             |
validation_step_end                     |  8.1364e-06           |102                    |  0.00082992           |  0.0021991            |
on_train_end                            |  0.00024225           |1                      |  0.00024225           |  0.00064191           |
on_epoch_start                          |  0.00021004           |1                      |  0.00021004           |  0.00055657           |
on_validation_start                     |  7.725e-05            |2                      |  0.0001545            |  0.00040939           |
on_train_start                          |  0.00014138           |1                      |  0.00014138           |  0.00037461           |
on_validation_epoch_end                 |  1.4126e-05           |2                      |  2.8251e-05           |  7.4859e-05           |
on_epoch_end                            |  1.7333e-05           |1                      |  1.7333e-05           |  4.5929e-05           |
on_validation_epoch_start               |  7.854e-06            |2                      |  1.5708e-05           |  4.1623e-05           |
on_before_accelerator_backend_setup     |  1.2458e-05           |1                      |  1.2458e-05           |  3.3011e-05           |
on_train_epoch_end                      |  9.25e-06             |1                      |  9.25e-06             |  2.4511e-05           |
on_fit_start                            |  8.709e-06            |1                      |  8.709e-06            |  2.3077e-05           |
on_train_epoch_start                    |  8e-06                |1                      |  8e-06                |  2.1198e-05           |

Logging

And finally, if you want more speed-ups, turn off the logger (Lightning automatically generates TensorBoard files).

Trainer(logger=False)
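
In context, it’s just one more argument to the Trainer from the script above (a sketch; it can be combined with the profiler and GPU flags):

trainer = pl.Trainer(max_epochs=10, logger=False)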

Bonus

Notice that Lightning is much simpler (but very different from Keras).

If you want more of the “Keras” feel in PyTorch, use our other library, Flash (built on top of Lightning).

Just add these last lines to your code:

import flash

# task (reusing the model and criterion from above)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # the optimizer referenced below
classifier = flash.Task(model, loss_fn=criterion, optimizer=optimizer)

# train
flash.Trainer(max_epochs=10, gpus=1).fit(classifier, loader_train, loader_valid)

Example (with num_workers)

Notice in the profiler report that you spent a non-trivial amount of time on data loading, and that Lightning gave you a warning about num_workers? Add workers to both loaders:

loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True, num_workers=8)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=True, num_workers=8)

Now if I run and profile again… the data-loading issue is gone, and you are left with the model-optimization speed-ups (Lightning already handles most of the bottlenecks there).

Profiler Report

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  16.307               |  100 %                |
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  15.476               |1                      |  15.476               |  94.909               |
run_training_batch                      |  0.029962             |400                    |  11.985               |  73.496               |
optimizer_step_and_closure_0            |  0.029476             |400                    |  11.791               |  72.305               |
training_step_and_backward              |  0.028529             |400                    |  11.412               |  69.983               |
model_backward                          |  0.015577             |400                    |  6.2306               |  38.209               |
model_forward                           |  0.012756             |400                    |  5.1025               |  31.291               |
training_step                           |  0.012475             |400                    |  4.9902               |  30.602               |
evaluation_step_and_end                 |  0.017031             |102                    |  1.7372               |  10.653               |
validation_step                         |  0.016785             |102                    |  1.712                |  10.499               |
get_train_batch                         |  0.0024146            |400                    |  0.96584              |  5.923                |
on_train_batch_end                      |  0.00035222           |400                    |  0.14089              |  0.86398              |
cache_result                            |  2.2822e-05           |2323                   |  0.053016             |  0.32512              |
on_validation_batch_end                 |  0.00017669           |102                    |  0.018023             |  0.11052              |
on_batch_start                          |  2.3768e-05           |400                    |  0.0095072            |  0.058303             |
on_after_backward                       |  1.6999e-05           |400                    |  0.0067996            |  0.041698             |
on_batch_end                            |  1.5546e-05           |400                    |  0.0062184            |  0.038134             |
on_before_zero_grad                     |  1.4941e-05           |400                    |  0.0059764            |  0.03665              |
on_validation_end                       |  0.0021403            |2                      |  0.0042806            |  0.026251             |
training_step_end                       |  1.0234e-05           |400                    |  0.0040936            |  0.025104             |
on_train_batch_start                    |  9.4705e-06           |400                    |  0.0037882            |  0.023231             |
on_validation_batch_start               |  1.853e-05            |102                    |  0.00189              |  0.011591             |
validation_step_end                     |  1.3889e-05           |102                    |  0.0014167            |  0.0086879            |
on_validation_start                     |  0.00023427           |2                      |  0.00046854           |  0.0028733            |
on_validation_epoch_end                 |  0.00011319           |2                      |  0.00022638           |  0.0013882            |
on_epoch_start                          |  0.00022617           |1                      |  0.00022617           |  0.001387             |
on_train_end                            |  0.00022525           |1                      |  0.00022525           |  0.0013813            |
on_train_start                          |  0.000178             |1                      |  0.000178             |  0.0010916            |
on_epoch_end                            |  1.8708e-05           |1                      |  1.8708e-05           |  0.00011473           |
on_validation_epoch_start               |  9.2705e-06           |2                      |  1.8541e-05           |  0.0001137            |
on_train_epoch_start                    |  1.0667e-05           |1                      |  1.0667e-05           |  6.5415e-05           |
on_train_epoch_end                      |  9.209e-06            |1                      |  9.209e-06            |  5.6474e-05           |
on_fit_start                            |  8.209e-06            |1                      |  8.209e-06            |  5.0341e-05           |
on_before_accelerator_backend_setup     |  6.417e-06            |1                      |  6.417e-06            |  3.9352e-05           |

Summary

So, in summary… it just looks like Keras does more optimizations under the hood to help the user avoid mistakes. Lightning does this as well.

Bonus 2

Also notice that Lightning gives this warning:

/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.

It caught a mistake you made… don’t shuffle your validation data (it’s bad practice) :wink:
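
The fix is a one-line change to the validation loader from the script above (a sketch; num_workers=8 follows the earlier warning’s suggestion for this 8-CPU machine):

loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=False, num_workers=8)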
