@kpedro88
Why don’t you drop the model into Lightning and try it? (I added the few lines needed to make that work.)
Lightning handles a lot of these performance optimizations, such as the syncing issue you ran into.
I converted it and ran some quick profiling (all of this took about 10 minutes end to end).
On CPU I was able to get it about 2.3x faster… want to try it on the GPU?
Lightning version
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
from tqdm import tqdm
import numpy as np
# import PL
import pytorch_lightning as pl
pl.seed_everything(0)
class Classifier(pl.LightningModule):
    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def training_step(self, data, batch_idx):
        # train loop
        event_data, target_data = data
        output = self.model(event_data)
        batch_loss = self.criterion(output, target_data)
        self.log('train_loss', batch_loss)
        return batch_loss

    def validation_step(self, data, batch_idx):
        # val loop
        event_data, target_data = data
        output = self.model(event_data)
        batch_loss = self.criterion(output, target_data)
        self.log('val_loss', batch_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        return optimizer
# -------------------
# generate random data
# -------------------
N_train_class0 = N_train_class1 = 1_250_000
event_dim = 8
rng_seed = 0
rng = np.random.Generator(np.random.PCG64(rng_seed))
class0_events = rng.uniform(100,500,size=(N_train_class0, event_dim))
class0_ytarget = np.zeros(shape=(N_train_class0, 1))
class1_events = rng.uniform(100,500,size=(N_train_class1, event_dim))
class1_ytarget = np.ones(shape=(N_train_class1, 1))
permutation = rng.permutation(N_train_class0 + N_train_class1)
dataset = TensorDataset(
    torch.Tensor(np.concatenate([class0_events, class1_events])[permutation]),
    torch.Tensor(np.concatenate([class0_ytarget, class1_ytarget])[permutation])
)
# prepare for training
pct_valid = 0.2
num_valid = int(len(dataset)*pct_valid)
split_train, split_valid = random_split(dataset, [len(dataset)-num_valid, num_valid])
batch_size = 5000
loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=True)
# --------------
# model
# --------------
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.0)
hidden_node_counts = [128, 64, 64, 64, 32]
layers = [
    nn.Linear(in_features=event_dim, out_features=hidden_node_counts[0]),
    nn.ReLU()
]
for counter in range(len(hidden_node_counts)-1):
    layers.extend([
        nn.Linear(in_features=hidden_node_counts[counter], out_features=hidden_node_counts[counter+1]),
        nn.ReLU()
    ])
layers.extend([
    nn.Linear(in_features=hidden_node_counts[-1], out_features=1),
    nn.Sigmoid()
])
model = nn.Sequential(*layers)
model.apply(init_weights)
print(model)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable params: {}".format(total_params))
criterion = nn.BCELoss()
# -------------
# TRAIN with Lightning
# -------------
classifier = Classifier(model, criterion)
trainer = pl.Trainer(max_epochs=10)  # set gpus=1 to train on a GPU
trainer.fit(classifier, loader_train, loader_valid)
I ran it on CPU (but you can run on GPUs by setting gpus=1 in the Trainer).
I think your bottleneck also comes from the DataLoader… Lightning gives you useful warnings when it detects issues:
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
warnings.warn(*args, **kwargs)
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
Epoch 1: 61%|███████████████████████████████████████████████████▋ | 304/500 [00:23<00:15, 12.85it/s, loss=8.94e-12, v_num=4
Profiling
You can also run with profiling to see where the speed issues are coming from.
(profiling might slow things down though…)
Trainer(profiler='pytorch')
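For reference, a minimal sketch of how that slots into the script above (the only change is the profiler argument; max_epochs=1 is just my choice here to keep the profiled run short):
trainer = pl.Trainer(max_epochs=1, profiler='pytorch')
trainer.fit(classifier, loader_train, loader_valid)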
--------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
AddmmBackward 30.23% 763.912ms 30.23% 763.912ms 1.900ms 402
aten::mm 29.64% 749.083ms 29.64% 749.083ms 1.016ms 737
aten::addmm 16.95% 428.369ms 16.95% 428.369ms 1.066ms 402
aten::relu 4.23% 106.824ms 4.23% 106.824ms 318.878us 335
aten::threshold 4.02% 101.632ms 4.02% 101.632ms 303.378us 335
ReluBackward0 3.97% 100.330ms 3.97% 100.330ms 299.493us 335
aten::threshold_backward 3.89% 98.345ms 3.89% 98.345ms 293.566us 335
aten::sum 1.61% 40.591ms 1.61% 40.591ms 86.548us 469
aten::copy_ 1.47% 37.197ms 1.47% 37.197ms 46.265us 804
aten::t 0.78% 19.632ms 0.78% 19.632ms 10.104us 1943
aten::binary_cross_entropy 0.36% 9.013ms 0.36% 9.013ms 134.518us 67
torch::autograd::AccumulateGrad 0.32% 8.046ms 0.32% 8.046ms 10.008us 804
aten::transpose 0.30% 7.688ms 0.30% 7.688ms 3.957us 1943
aten::empty 0.26% 6.545ms 0.26% 6.545ms 2.505us 2613
aten::sigmoid 0.22% 5.542ms 0.22% 5.542ms 41.361us 134
aten::add_ 0.20% 5.042ms 0.20% 5.042ms 6.367us 792
aten::as_strided 0.17% 4.377ms 0.17% 4.377ms 1.519us 2881
TBackward 0.14% 3.533ms 0.14% 3.533ms 8.788us 402
aten::view 0.14% 3.463ms 0.14% 3.463ms 8.613us 402
aten::mean 0.11% 2.833ms 0.11% 2.833ms 42.280us 67
--------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.527s
Or try any of the other profilers… here we use the “simple” profiler:
Trainer(profiler='simple')
Profiler Report
Action | Mean duration (s) |Num calls | Total time (s) | Percentage % |
------------------------------------------------------------------------------------------------------------------------------------
Total | - |_ | 37.739 | 100 % |
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch | 37.493 |1 | 37.493 | 99.348 |
get_train_batch | 0.053623 |400 | 21.449 | 56.836 |
run_training_batch | 0.024157 |400 | 9.6627 | 25.604 |
optimizer_step_and_closure_0 | 0.023673 |400 | 9.4692 | 25.091 |
training_step_and_backward | 0.022778 |400 | 9.1112 | 24.143 |
model_backward | 0.013783 |400 | 5.5132 | 14.609 |
model_forward | 0.0088057 |400 | 3.5223 | 9.3333 |
training_step | 0.0085555 |400 | 3.4222 | 9.068 |
evaluation_step_and_end | 0.0089042 |102 | 0.90823 | 2.4066 |
validation_step | 0.0087426 |102 | 0.89174 | 2.3629 |
on_train_batch_end | 0.00036939 |400 | 0.14776 | 0.39152 |
cache_result | 1.859e-05 |2323 | 0.043184 | 0.11443 |
on_validation_batch_end | 0.0002122 |102 | 0.021644 | 0.057353 |
on_batch_start | 3.5443e-05 |400 | 0.014177 | 0.037566 |
on_after_backward | 1.7219e-05 |400 | 0.0068876 | 0.018251 |
on_batch_end | 1.3641e-05 |400 | 0.0054563 | 0.014458 |
on_before_zero_grad | 1.2449e-05 |400 | 0.0049796 | 0.013195 |
on_train_batch_start | 9.6932e-06 |400 | 0.0038773 | 0.010274 |
training_step_end | 8.5934e-06 |400 | 0.0034374 | 0.0091083 |
on_validation_end | 0.0016228 |2 | 0.0032456 | 0.0086001 |
on_validation_batch_start | 2.1877e-05 |102 | 0.0022315 | 0.005913 |
validation_step_end | 8.1364e-06 |102 | 0.00082992 | 0.0021991 |
on_train_end | 0.00024225 |1 | 0.00024225 | 0.00064191 |
on_epoch_start | 0.00021004 |1 | 0.00021004 | 0.00055657 |
on_validation_start | 7.725e-05 |2 | 0.0001545 | 0.00040939 |
on_train_start | 0.00014138 |1 | 0.00014138 | 0.00037461 |
on_validation_epoch_end | 1.4126e-05 |2 | 2.8251e-05 | 7.4859e-05 |
on_epoch_end | 1.7333e-05 |1 | 1.7333e-05 | 4.5929e-05 |
on_validation_epoch_start | 7.854e-06 |2 | 1.5708e-05 | 4.1623e-05 |
on_before_accelerator_backend_setup | 1.2458e-05 |1 | 1.2458e-05 | 3.3011e-05 |
on_train_epoch_end | 9.25e-06 |1 | 9.25e-06 | 2.4511e-05 |
on_fit_start | 8.709e-06 |1 | 8.709e-06 | 2.3077e-05 |
on_train_epoch_start | 8e-06 |1 | 8e-06 | 2.1198e-05 |
Logging
And finally, if you want more speedups, turn off the logger (Lightning automatically generates TensorBoard files):
Trainer(logger=False)
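In the script above that is just (a minimal sketch, everything else unchanged):
trainer = pl.Trainer(max_epochs=10, logger=False)  # no TensorBoard files are written
trainer.fit(classifier, loader_train, loader_valid)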
Bonus
Notice that Lightning is much simpler (but very different from Keras).
If you want more of the “Keras” feel in PyTorch, use our other library, Flash (built on top of Lightning).
Just add these last lines to your code:
import flash

# explicit optimizer instance to hand to the Task
optimizer = optim.Adam(model.parameters(), lr=0.001)

# task
classifier = flash.Task(model, loss_fn=criterion, optimizer=optimizer)
# train
flash.Trainer(max_epochs=10, gpus=1).fit(classifier, loader_train, loader_valid)
Example (with num_workers)
Notice in the profiler output that you spent a non-trivial amount of time on data loading, and that Lightning gave you a warning about num_workers?
DataLoader(num_workers=8)
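Applied to the loaders defined earlier, that looks like this (a minimal sketch; 8 matches the CPU count Lightning suggested in its warning, so tune it for your machine):
loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True, num_workers=8)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=True, num_workers=8)  # the validation shuffle is a separate issue, see Bonus 2 below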
Now if I run and profile again… the data-loading issue is eliminated and you are left with the model-optimization speedups (Lightning handles most of the bottlenecks there):
Profiler Report
Action | Mean duration (s) |Num calls | Total time (s) | Percentage % |
------------------------------------------------------------------------------------------------------------------------------------
Total | - |_ | 16.307 | 100 % |
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch | 15.476 |1 | 15.476 | 94.909 |
run_training_batch | 0.029962 |400 | 11.985 | 73.496 |
optimizer_step_and_closure_0 | 0.029476 |400 | 11.791 | 72.305 |
training_step_and_backward | 0.028529 |400 | 11.412 | 69.983 |
model_backward | 0.015577 |400 | 6.2306 | 38.209 |
model_forward | 0.012756 |400 | 5.1025 | 31.291 |
training_step | 0.012475 |400 | 4.9902 | 30.602 |
evaluation_step_and_end | 0.017031 |102 | 1.7372 | 10.653 |
validation_step | 0.016785 |102 | 1.712 | 10.499 |
get_train_batch | 0.0024146 |400 | 0.96584 | 5.923 |
on_train_batch_end | 0.00035222 |400 | 0.14089 | 0.86398 |
cache_result | 2.2822e-05 |2323 | 0.053016 | 0.32512 |
on_validation_batch_end | 0.00017669 |102 | 0.018023 | 0.11052 |
on_batch_start | 2.3768e-05 |400 | 0.0095072 | 0.058303 |
on_after_backward | 1.6999e-05 |400 | 0.0067996 | 0.041698 |
on_batch_end | 1.5546e-05 |400 | 0.0062184 | 0.038134 |
on_before_zero_grad | 1.4941e-05 |400 | 0.0059764 | 0.03665 |
on_validation_end | 0.0021403 |2 | 0.0042806 | 0.026251 |
training_step_end | 1.0234e-05 |400 | 0.0040936 | 0.025104 |
on_train_batch_start | 9.4705e-06 |400 | 0.0037882 | 0.023231 |
on_validation_batch_start | 1.853e-05 |102 | 0.00189 | 0.011591 |
validation_step_end | 1.3889e-05 |102 | 0.0014167 | 0.0086879 |
on_validation_start | 0.00023427 |2 | 0.00046854 | 0.0028733 |
on_validation_epoch_end | 0.00011319 |2 | 0.00022638 | 0.0013882 |
on_epoch_start | 0.00022617 |1 | 0.00022617 | 0.001387 |
on_train_end | 0.00022525 |1 | 0.00022525 | 0.0013813 |
on_train_start | 0.000178 |1 | 0.000178 | 0.0010916 |
on_epoch_end | 1.8708e-05 |1 | 1.8708e-05 | 0.00011473 |
on_validation_epoch_start | 9.2705e-06 |2 | 1.8541e-05 | 0.0001137 |
on_train_epoch_start | 1.0667e-05 |1 | 1.0667e-05 | 6.5415e-05 |
on_train_epoch_end | 9.209e-06 |1 | 9.209e-06 | 5.6474e-05 |
on_fit_start | 8.209e-06 |1 | 8.209e-06 | 5.0341e-05 |
on_before_accelerator_backend_setup | 6.417e-06 |1 | 6.417e-06 | 3.9352e-05 |
Summary
So, in summary… it just looks like Keras does more optimizations under the hood to help the user avoid mistakes. Lightning does this as well.
Bonus 2
Also notice that Lightning gives this warning…
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
It caught a mistake you made… don’t shuffle your validation data (it’s bad practice).
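A minimal sketch of the fix, applied to the validation loader from the script above:
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=False)  # keep validation order fixed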