Keras trains significantly faster than PyTorch for simple network

While converting a colleague’s Keras network into PyTorch, I noticed that training became significantly slower. The conversion itself has been validated (it gives the same results on real data).

Below, I’ve provided some minimal examples that demonstrate the behavior using random data and a simple fully-connected network.

Summary: using an RTX 2080 Super GPU (driver version 460.80, CUDA version 11.2) with an Ubuntu 18.04.5 LTS container, I get ~2 seconds/epoch from Keras and ~15 seconds/epoch from PyTorch.

While generic suggestions to make PyTorch faster are always appreciated, I particularly want to understand what Keras is doing that PyTorch isn’t (or vice versa) in such a simple setup.
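When comparing per-epoch times like these, keep in mind that CUDA kernels are launched asynchronously, so a wall-clock timer around PyTorch GPU code should call torch.cuda.synchronize() before reading the clock. The sketch below is a minimal, assumption-laden helper; the name `timed` is hypothetical and not part of any library.

```python
import time

import torch


def timed(fn, device):
    """Call fn() and return (result, elapsed_seconds).

    Hypothetical helper. CUDA kernels are queued asynchronously, so on a GPU
    we synchronize before starting and before stopping the clock; otherwise
    the timer can stop while work is still pending on the device.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    result = fn()
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return result, time.perf_counter() - start


# Usage: time a small matrix multiply on whichever device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = torch.rand(200, 200, device=device)
b = torch.rand(200, 200, device=device)
product, elapsed = timed(lambda: a @ b, device)
print(f"elapsed: {elapsed:.6f} s")
```

In the PyTorch training loop below, the per-batch .item() call already forces a sync, so the tqdm timings there are meaningful; explicit synchronization matters once such per-batch syncs are removed.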

Keras code:

# imports and basic setup
import numpy as np
import tensorflow as tf

# print versions
print_versions = True
if print_versions:
    import platform
    print("Software versions")
    print(f"  * Python: {platform.python_version()}")
    print(f"    * numpy: {np.__version__}")
    print(f"    * tensorflow: {tf.__version__}")

# generate random data
N_train_class0 = N_train_class1 = 1_250_000
event_dim = 8

rng_seed = 0
rng = np.random.Generator(np.random.PCG64(rng_seed))

class0_events = rng.uniform(100,500,size=(N_train_class0, event_dim))
class0_ytarget = np.zeros(shape=(N_train_class0, 1))

class1_events = rng.uniform(100,500,size=(N_train_class1, event_dim))
class1_ytarget = np.zeros(shape=(N_train_class1, 1))

permutation = rng.permutation(N_train_class0 + N_train_class1)
events_train = np.concatenate([class0_events, class1_events])[permutation]
ytarget_train = np.concatenate([class0_ytarget, class1_ytarget])[permutation]

# setup model
tf.random.set_seed(0)

network = tf.keras.Sequential(name="event_variable")
network.add(tf.keras.layers.InputLayer(input_shape=(event_dim,)))

hidden_node_counts = [128, 64, 64, 64, 32]
for node_count in hidden_node_counts:
    network.add(tf.keras.layers.Dense(node_count, activation='relu'))

network.add(tf.keras.layers.Dense(1, activation='sigmoid'))

network.summary()

event_input_tensor = tf.keras.Input(shape=(event_dim,), name='event_input')
output_tensor = network(event_input_tensor)

model = tf.keras.Model(
    inputs = event_input_tensor,
    outputs = output_tensor
)
model.summary()

# prepare for training
model.compile(optimizer='adam', loss='binary_crossentropy')

# do training
model.fit(x=events_train, y=ytarget_train, batch_size=5000, epochs=10, validation_split=0.2)

Keras output:

Software versions
  * Python: 3.6.9
    * numpy: 1.19.5
    * tensorflow: 2.5.0
Model: "event_variable"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 128)               1152      
_________________________________________________________________
dense_1 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_4 (Dense)              (None, 32)                2080      
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 33        
=================================================================
Total params: 19,841
Trainable params: 19,841
Non-trainable params: 0
_________________________________________________________________
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
event_input (InputLayer)     [(None, 8)]               0         
_________________________________________________________________
event_variable (Sequential)  (None, 1)                 19841     
=================================================================
Total params: 19,841
Trainable params: 19,841
Non-trainable params: 0
_________________________________________________________________
Epoch 1/10
400/400 [==============================] - 2s 4ms/step - loss: 0.0222 - val_loss: 3.5520e-16
Epoch 2/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 3/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 4/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 5/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 6/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 7/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 8/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 9/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16
Epoch 10/10
400/400 [==============================] - 2s 4ms/step - loss: 3.2005e-16 - val_loss: 3.5520e-16

PyTorch code:

# imports and basic setup
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
from tqdm import tqdm
import numpy as np

# print versions
print_versions = True
if print_versions:
    import platform
    print("Software versions")
    print(f"  * Python: {platform.python_version()}")
    print(f"    * numpy: {np.__version__}")
    print(f"    * torch: {torch.__version__}")
    
# choose cpu or gpu 
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using GPU")
else:
    device = torch.device('cpu')
    print("Using CPU")

# generate random data
N_train_class0 = N_train_class1 = 1_250_000
event_dim = 8

rng_seed = 0
rng = np.random.Generator(np.random.PCG64(rng_seed))

class0_events = rng.uniform(100,500,size=(N_train_class0, event_dim))
class0_ytarget = np.zeros(shape=(N_train_class0, 1))

class1_events = rng.uniform(100,500,size=(N_train_class1, event_dim))
class1_ytarget = np.zeros(shape=(N_train_class1, 1))

permutation = rng.permutation(N_train_class0 + N_train_class1)
dataset = TensorDataset(
    torch.Tensor(np.concatenate([class0_events, class1_events])[permutation]).to(device),
    torch.Tensor(np.concatenate([class0_ytarget, class1_ytarget])[permutation]).to(device)
)

# setup model
torch.manual_seed(0)

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)  # in-place, non-deprecated form
        torch.nn.init.zeros_(m.bias)

hidden_node_counts = [128, 64, 64, 64, 32]
layers = [
    nn.Linear(in_features = event_dim, out_features = hidden_node_counts[0]),
    nn.ReLU()
]
for counter in range(len(hidden_node_counts)-1):
    layers.extend([
        nn.Linear(in_features = hidden_node_counts[counter], out_features = hidden_node_counts[counter+1]),
        nn.ReLU()
    ])
layers.extend([
    nn.Linear(in_features = hidden_node_counts[-1], out_features = 1),
    nn.Sigmoid()
])
model = nn.Sequential(*layers).to(device)
model.apply(init_weights)  # Glorot-uniform weights and zero biases, matching Keras' Dense defaults

print(model)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable params: {}".format(total_params))

# prepare for training
pct_valid = 0.2
num_valid = int(len(dataset)*pct_valid)
split_train, split_valid = random_split(dataset, [len(dataset)-num_valid, num_valid])
batch_size = 5000
loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=True)

criterion = nn.BCELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# do training
epochs = 10
for epoch in range(epochs):
    print("Epoch {}/{}".format(epoch+1,epochs))
    model.train()  # training mode, set once per epoch
    train_loss = 0
    for i, data in tqdm(enumerate(loader_train), unit="batch", total=len(loader_train)):
        optimizer.zero_grad()  # covers all model parameters, so model.zero_grad() is redundant
        event_data, target_data = data
        output = model(event_data)
        batch_loss = criterion(output, target_data)
        batch_loss.backward()
        optimizer.step()
        train_loss += batch_loss.item()
    train_loss /= len(loader_train)
    tqdm.write("loss: {}".format(train_loss))

    # validation
    model.eval()  # evaluation mode for the validation pass
    valid_loss = 0
    with torch.no_grad():
        for i, data in enumerate(loader_valid):
            event_data, target_data = data
            output = model(event_data)
            batch_loss = criterion(output, target_data)
            valid_loss += batch_loss.item()
    valid_loss /= len(loader_valid)
    tqdm.write("val_loss: {}".format(valid_loss))

PyTorch output:

Software versions
  * Python: 3.8.8
    * numpy: 1.19.2
    * torch: 1.8.1
Using GPU
Sequential(
  (0): Linear(in_features=8, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=64, bias=True)
  (3): ReLU()
  (4): Linear(in_features=64, out_features=64, bias=True)
  (5): ReLU()
  (6): Linear(in_features=64, out_features=64, bias=True)
  (7): ReLU()
  (8): Linear(in_features=64, out_features=32, bias=True)
  (9): ReLU()
  (10): Linear(in_features=32, out_features=1, bias=True)
  (11): Sigmoid()
)
Trainable params: 19841
Epoch 1/10

100%|██████████| 400/400 [00:15<00:00, 25.90batch/s]

loss: 0.0007586651195278834
val_loss: 9.65595376399564e-11
Epoch 2/10

100%|██████████| 400/400 [00:15<00:00, 26.06batch/s]

loss: 8.720163338594642e-11
val_loss: 6.139279044338475e-11
Epoch 3/10

100%|██████████| 400/400 [00:15<00:00, 26.20batch/s]

loss: 5.653502557594059e-11
val_loss: 3.945827847622041e-11
Epoch 4/10

100%|██████████| 400/400 [00:15<00:00, 26.17batch/s]

loss: 3.585220385111595e-11
val_loss: 2.4437905700794295e-11
Epoch 5/10

100%|██████████| 400/400 [00:15<00:00, 25.85batch/s]

loss: 2.306700272415238e-11
val_loss: 1.6093254558841032e-11
Epoch 6/10

100%|██████████| 400/400 [00:15<00:00, 26.12batch/s]

loss: 1.5616419936706484e-11
val_loss: 9.655952790815769e-12
Epoch 7/10

100%|██████████| 400/400 [00:15<00:00, 26.15batch/s]

loss: 1.0102988931559586e-11
val_loss: 5.602836770923769e-12
Epoch 8/10

100%|██████████| 400/400 [00:15<00:00, 26.10batch/s]

loss: 7.271767608636043e-12
val_loss: 3.4570694570912332e-12
Epoch 9/10

100%|██████████| 400/400 [00:15<00:00, 25.95batch/s]

loss: 5.5432326155450965e-12
val_loss: 2.6226043559063328e-12
Epoch 10/10

100%|██████████| 400/400 [00:15<00:00, 26.10batch/s]

loss: 4.321337200834802e-12
val_loss: 1.907348680038612e-12

For general performance improvements, you could take a look at the Performance Guide. While you could potentially improve the performance, I doubt it would explain the large difference.
Based on the model architecture, I would guess that the general overhead might be high compared to the actual GPU workload, which should be visible as gaps in a profiling timeline (i.e., periods with no CUDA kernels executing). To verify this, you could create an Nsight Systems profile and check the timeline.
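As an in-Python alternative to Nsight Systems (a swapped-in tool, not what was suggested above), torch.profiler, available from PyTorch 1.8.1, records CPU-side operator times and, on a GPU, CUDA kernel times. A minimal sketch on a stand-in copy of the thread's MLP with random data; the batch and the 10-step loop are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile

# Stand-in for the thread's MLP: 8 -> 128 -> 64 -> 64 -> 64 -> 32 -> 1.
model = nn.Sequential(
    nn.Linear(8, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 1), nn.Sigmoid(),
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

# One illustrative batch of random inputs and all-zero targets.
x = torch.rand(5000, 8, device=device)
y = torch.zeros(5000, 1, device=device)

activities = [ProfilerActivity.CPU]
if device.type == "cuda":
    activities.append(ProfilerActivity.CUDA)  # also record GPU kernel times

with profile(activities=activities) as prof:
    for _ in range(10):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

# Aggregate per-operator statistics, most expensive (host-side) first.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```

For a timeline view comparable to Nsight Systems, prof.export_chrome_trace("trace.json") writes a trace that can be opened in chrome://tracing, where idle gaps between kernels are visible.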

Thanks; I had tried a few of these recommendations and observed no improvement:

  • model = torch.jit.script(model) before training
  • pin_memory=True in loaders

However, I had not realized that tensor.item() caused a synchronization. Is there a recommended way to sum the loss over batches in an epoch without using item()?

(I’ll try the profiler separately…)
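On the pin_memory=True attempt: pinning only affects batches that the DataLoader builds from CPU tensors, so it cannot help when the whole dataset has already been moved to the GPU, as in the original script. A minimal sketch of the setup where it does apply, keeping the data on the CPU and copying each batch with non_blocking=True; the sizes are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_pinned = device.type == "cuda"  # pinning only matters when copying to a GPU

# The dataset stays on the CPU: pin_memory applies only to CPU tensors.
events = torch.rand(20_000, 8)
targets = torch.zeros(20_000, 1)
loader = DataLoader(
    TensorDataset(events, targets),
    batch_size=5000,
    shuffle=True,
    pin_memory=use_pinned,  # collate batches into page-locked host memory
)

batch_shapes = []
for event_data, target_data in loader:
    # With a pinned source, non_blocking=True lets the host-to-device copy
    # run asynchronously with respect to the host.
    event_data = event_data.to(device, non_blocking=True)
    target_data = target_data.to(device, non_blocking=True)
    batch_shapes.append((tuple(event_data.shape), tuple(target_data.shape)))
```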

You could detach() the loss tensors and accumulate the losses without using item().
Sure, I think the profile could help as it might show the actual kernel execution times and when they are launched.
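A minimal sketch of the detach()-and-accumulate suggestion: the running sum stays on the device and .item() is called once per epoch, so there is one host-device synchronization per epoch instead of one per batch. `run_epoch` is a hypothetical helper name, and the tiny model and data in the usage lines are illustrative stand-ins.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, criterion, optimizer, device):
    """Train for one epoch and return the mean batch loss as a Python float."""
    model.train()
    total = torch.zeros((), device=device)  # 0-dim accumulator on the device
    n_batches = 0
    for event_data, target_data in loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(event_data), target_data)
        batch_loss.backward()
        optimizer.step()
        total += batch_loss.detach()  # no sync, and no autograd graph kept alive
        n_batches += 1
    return (total / max(n_batches, 1)).item()  # single sync per epoch


# Usage with a tiny stand-in model and random data.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid()).to(device)
dataset = TensorDataset(torch.rand(2000, 8, device=device), torch.zeros(2000, 1, device=device))
loader = DataLoader(dataset, batch_size=500, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mean_loss = run_epoch(model, loader, nn.BCELoss(), optimizer, device)
```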

@kpedro88
Why don’t you drop the model into Lightning and try that (I added the few lines needed to make it work)?
Lightning handles a lot of these performance optimizations, such as the syncing issue you ran into.

I converted it and ran some quick profiling (all of this took about 10 minutes end to end).
On CPU I was able to get it about 2.3x faster; could you try it on the GPU?

Lightning version

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.optim as optim
from tqdm import tqdm
import numpy as np

# import PL
import pytorch_lightning as pl
pl.seed_everything(0)


class Classifier(pl.LightningModule):

    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion
 
    def training_step(self, data, batch_idx):
        # train loop
        event_data, target_data = data
        output = self.model(event_data)
        batch_loss = self.criterion(output, target_data)
        self.log('train_loss', batch_loss)
        return batch_loss
    
    def validation_step(self, data, batch_idx):
        # val loop
        event_data, target_data = data
        output = self.model(event_data)
        batch_loss = self.criterion(output, target_data)
        self.log('val_loss', batch_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        return optimizer

# -------------------
# generate random data
# -------------------
N_train_class0 = N_train_class1 = 1_250_000
event_dim = 8

rng_seed = 0
rng = np.random.Generator(np.random.PCG64(rng_seed))

class0_events = rng.uniform(100,500,size=(N_train_class0, event_dim))
class0_ytarget = np.zeros(shape=(N_train_class0, 1))

class1_events = rng.uniform(100,500,size=(N_train_class1, event_dim))
class1_ytarget = np.zeros(shape=(N_train_class1, 1))

permutation = rng.permutation(N_train_class0 + N_train_class1)
dataset = TensorDataset(
    torch.Tensor(np.concatenate([class0_events, class1_events])[permutation]),
    torch.Tensor(np.concatenate([class0_ytarget, class1_ytarget])[permutation])
)

# prepare for training
pct_valid = 0.2
num_valid = int(len(dataset)*pct_valid)
split_train, split_valid = random_split(dataset, [len(dataset)-num_valid, num_valid])
batch_size = 5000
loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=True)


# --------------
# model
# --------------
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)  # in-place, non-deprecated form
        torch.nn.init.zeros_(m.bias)

hidden_node_counts = [128, 64, 64, 64, 32]
layers = [
    nn.Linear(in_features = event_dim, out_features = hidden_node_counts[0]),
    nn.ReLU()
]
for counter in range(len(hidden_node_counts)-1):
    layers.extend([
        nn.Linear(in_features = hidden_node_counts[counter], out_features = hidden_node_counts[counter+1]),
        nn.ReLU()
    ])
layers.extend([
    nn.Linear(in_features = hidden_node_counts[-1], out_features = 1),
    nn.Sigmoid()
])
model = nn.Sequential(*layers)
model.apply(init_weights)  # Glorot-uniform weights and zero biases, matching Keras' Dense defaults

print(model)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable params: {}".format(total_params))


criterion = nn.BCELoss()

# -------------
# TRAIN with Lightning
# -------------
classifier = Classifier(model, criterion)
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(classifier, loader_train, loader_valid)

I ran it on CPU (the code above sets gpus=1, which runs it on one GPU; drop that argument to run on CPU).

I think your bottleneck also comes from the dataloader; Lightning gives you useful warnings when it detects issues like this:

/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.
  warnings.warn(*args, **kwargs)
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  warnings.warn(*args, **kwargs)
Epoch 1:  61%|███████████████████████████████████████████████████▋                                 | 304/500 [00:23<00:15, 12.85it/s, loss=8.94e-12, v_num=4

Profiling

You can also run with profiling to see where the speed issues are coming from.

(profiling might slow things down though…)

Trainer(profiler='pytorch')
---------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          AddmmBackward        30.23%     763.912ms        30.23%     763.912ms       1.900ms           402  
                               aten::mm        29.64%     749.083ms        29.64%     749.083ms       1.016ms           737  
                            aten::addmm        16.95%     428.369ms        16.95%     428.369ms       1.066ms           402  
                             aten::relu         4.23%     106.824ms         4.23%     106.824ms     318.878us           335  
                        aten::threshold         4.02%     101.632ms         4.02%     101.632ms     303.378us           335  
                          ReluBackward0         3.97%     100.330ms         3.97%     100.330ms     299.493us           335  
               aten::threshold_backward         3.89%      98.345ms         3.89%      98.345ms     293.566us           335  
                              aten::sum         1.61%      40.591ms         1.61%      40.591ms      86.548us           469  
                            aten::copy_         1.47%      37.197ms         1.47%      37.197ms      46.265us           804  
                                aten::t         0.78%      19.632ms         0.78%      19.632ms      10.104us          1943  
             aten::binary_cross_entropy         0.36%       9.013ms         0.36%       9.013ms     134.518us            67  
        torch::autograd::AccumulateGrad         0.32%       8.046ms         0.32%       8.046ms      10.008us           804  
                        aten::transpose         0.30%       7.688ms         0.30%       7.688ms       3.957us          1943  
                            aten::empty         0.26%       6.545ms         0.26%       6.545ms       2.505us          2613  
                          aten::sigmoid         0.22%       5.542ms         0.22%       5.542ms      41.361us           134  
                             aten::add_         0.20%       5.042ms         0.20%       5.042ms       6.367us           792  
                       aten::as_strided         0.17%       4.377ms         0.17%       4.377ms       1.519us          2881  
                              TBackward         0.14%       3.533ms         0.14%       3.533ms       8.788us           402  
                             aten::view         0.14%       3.463ms         0.14%       3.463ms       8.613us           402  
                             aten::mean         0.11%       2.833ms         0.11%       2.833ms      42.280us            67  
---------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.527s

Or try one of the other profilers; here we use the “simple” profiler:

Trainer(profiler='simple')
Profiler Report

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  37.739               |  100 %                |
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  37.493               |1                      |  37.493               |  99.348               |
get_train_batch                         |  0.053623             |400                    |  21.449               |  56.836               |
run_training_batch                      |  0.024157             |400                    |  9.6627               |  25.604               |
optimizer_step_and_closure_0            |  0.023673             |400                    |  9.4692               |  25.091               |
training_step_and_backward              |  0.022778             |400                    |  9.1112               |  24.143               |
model_backward                          |  0.013783             |400                    |  5.5132               |  14.609               |
model_forward                           |  0.0088057            |400                    |  3.5223               |  9.3333               |
training_step                           |  0.0085555            |400                    |  3.4222               |  9.068                |
evaluation_step_and_end                 |  0.0089042            |102                    |  0.90823              |  2.4066               |
validation_step                         |  0.0087426            |102                    |  0.89174              |  2.3629               |
on_train_batch_end                      |  0.00036939           |400                    |  0.14776              |  0.39152              |
cache_result                            |  1.859e-05            |2323                   |  0.043184             |  0.11443              |
on_validation_batch_end                 |  0.0002122            |102                    |  0.021644             |  0.057353             |
on_batch_start                          |  3.5443e-05           |400                    |  0.014177             |  0.037566             |
on_after_backward                       |  1.7219e-05           |400                    |  0.0068876            |  0.018251             |
on_batch_end                            |  1.3641e-05           |400                    |  0.0054563            |  0.014458             |
on_before_zero_grad                     |  1.2449e-05           |400                    |  0.0049796            |  0.013195             |
on_train_batch_start                    |  9.6932e-06           |400                    |  0.0038773            |  0.010274             |
training_step_end                       |  8.5934e-06           |400                    |  0.0034374            |  0.0091083            |
on_validation_end                       |  0.0016228            |2                      |  0.0032456            |  0.0086001            |
on_validation_batch_start               |  2.1877e-05           |102                    |  0.0022315            |  0.005913             |
validation_step_end                     |  8.1364e-06           |102                    |  0.00082992           |  0.0021991            |
on_train_end                            |  0.00024225           |1                      |  0.00024225           |  0.00064191           |
on_epoch_start                          |  0.00021004           |1                      |  0.00021004           |  0.00055657           |
on_validation_start                     |  7.725e-05            |2                      |  0.0001545            |  0.00040939           |
on_train_start                          |  0.00014138           |1                      |  0.00014138           |  0.00037461           |
on_validation_epoch_end                 |  1.4126e-05           |2                      |  2.8251e-05           |  7.4859e-05           |
on_epoch_end                            |  1.7333e-05           |1                      |  1.7333e-05           |  4.5929e-05           |
on_validation_epoch_start               |  7.854e-06            |2                      |  1.5708e-05           |  4.1623e-05           |
on_before_accelerator_backend_setup     |  1.2458e-05           |1                      |  1.2458e-05           |  3.3011e-05           |
on_train_epoch_end                      |  9.25e-06             |1                      |  9.25e-06             |  2.4511e-05           |
on_fit_start                            |  8.709e-06            |1                      |  8.709e-06            |  2.3077e-05           |
on_train_epoch_start                    |  8e-06                |1                      |  8e-06                |  2.1198e-05           |

Logging

And finally, if you want more speed-ups, turn off the logger (Lightning automatically generates TensorBoard files).

Trainer(logger=False)

Bonus

Notice that the Lightning version is much simpler (but very different from Keras).

If you want more of a Keras feel in PyTorch, use our other library, Flash (built on top of Lightning).

Just add these last lines to your code

import flash

# task
classifier = flash.Task(model, loss_fn=criterion, optimizer=optim.Adam, learning_rate=0.001)  # Task takes the optimizer class, not an instance

# train
flash.Trainer(max_epochs=10, gpus=1).fit(classifier, loader_train, loader_valid)

Example (with num_workers)

Notice in the profiler report above that a non-trivial amount of time went to data loading (get_train_batch: about 21 s of the 38 s total), and that Lightning warned about num_workers.

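loader_train = DataLoader(split_train, batch_size=batch_size, shuffle=True, num_workers=8)
loader_valid = DataLoader(split_valid, batch_size=batch_size, shuffle=False, num_workers=8)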

Now if I run and profile again, get_train_batch drops from about 21 s to under 1 s: the data-loading bottleneck is gone, and what remains is the model's forward/backward time (where Lightning already handles most of the bottlenecks).

Profiler Report

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  16.307               |  100 %                |
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  15.476               |1                      |  15.476               |  94.909               |
run_training_batch                      |  0.029962             |400                    |  11.985               |  73.496               |
optimizer_step_and_closure_0            |  0.029476             |400                    |  11.791               |  72.305               |
training_step_and_backward              |  0.028529             |400                    |  11.412               |  69.983               |
model_backward                          |  0.015577             |400                    |  6.2306               |  38.209               |
model_forward                           |  0.012756             |400                    |  5.1025               |  31.291               |
training_step                           |  0.012475             |400                    |  4.9902               |  30.602               |
evaluation_step_and_end                 |  0.017031             |102                    |  1.7372               |  10.653               |
validation_step                         |  0.016785             |102                    |  1.712                |  10.499               |
get_train_batch                         |  0.0024146            |400                    |  0.96584              |  5.923                |
on_train_batch_end                      |  0.00035222           |400                    |  0.14089              |  0.86398              |
cache_result                            |  2.2822e-05           |2323                   |  0.053016             |  0.32512              |
on_validation_batch_end                 |  0.00017669           |102                    |  0.018023             |  0.11052              |
on_batch_start                          |  2.3768e-05           |400                    |  0.0095072            |  0.058303             |
on_after_backward                       |  1.6999e-05           |400                    |  0.0067996            |  0.041698             |
on_batch_end                            |  1.5546e-05           |400                    |  0.0062184            |  0.038134             |
on_before_zero_grad                     |  1.4941e-05           |400                    |  0.0059764            |  0.03665              |
on_validation_end                       |  0.0021403            |2                      |  0.0042806            |  0.026251             |
training_step_end                       |  1.0234e-05           |400                    |  0.0040936            |  0.025104             |
on_train_batch_start                    |  9.4705e-06           |400                    |  0.0037882            |  0.023231             |
on_validation_batch_start               |  1.853e-05            |102                    |  0.00189              |  0.011591             |
validation_step_end                     |  1.3889e-05           |102                    |  0.0014167            |  0.0086879            |
on_validation_start                     |  0.00023427           |2                      |  0.00046854           |  0.0028733            |
on_validation_epoch_end                 |  0.00011319           |2                      |  0.00022638           |  0.0013882            |
on_epoch_start                          |  0.00022617           |1                      |  0.00022617           |  0.001387             |
on_train_end                            |  0.00022525           |1                      |  0.00022525           |  0.0013813            |
on_train_start                          |  0.000178             |1                      |  0.000178             |  0.0010916            |
on_epoch_end                            |  1.8708e-05           |1                      |  1.8708e-05           |  0.00011473           |
on_validation_epoch_start               |  9.2705e-06           |2                      |  1.8541e-05           |  0.0001137            |
on_train_epoch_start                    |  1.0667e-05           |1                      |  1.0667e-05           |  6.5415e-05           |
on_train_epoch_end                      |  9.209e-06            |1                      |  9.209e-06            |  5.6474e-05           |
on_fit_start                            |  8.209e-06            |1                      |  8.209e-06            |  5.0341e-05           |
on_before_accelerator_backend_setup     |  6.417e-06            |1                      |  6.417e-06            |  3.9352e-05           |

Summary

So, in summary, it looks like Keras does more optimizations under the hood to help the user avoid mistakes. This is also something that Lightning does.

Bonus 2

Also notice that Lightning gives this warning:

/Users/williamfalcon/miniconda3/envs/flash/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders.

It caught a mistake you made: don’t shuffle your validation data (it’s bad practice).


Thanks, this is great! I’ve only recently learned about Lightning and hadn’t tried it yet, but this is very good motivation. (And thanks also for pointing out the validation shuffling: I was trying to replicate Keras’ settings, but it turns out that shuffle=True in Keras only shuffles the training data anyway.)

It turns out that the solution is very simple:

I initially thought that num_workers was unnecessary; because the dataset is small, I just transferred the whole thing to the GPU at the beginning (torch.Tensor(np.concatenate([class0_events, class1_events])[permutation]).to(device)). A loader with num_workers won’t even run in this case, because all the data is already loaded on the GPU.

However, it turns out to be faster not to do this, and instead to let the DataLoader use multiple workers to load each batch’s data and send it to the GPU. I suspect this is because of shuffle=True: it’s better to use all the CPU cores to assemble one batch’s data and send it to the GPU as a single contiguous block of memory. (As usual, it makes sense in hindsight…) Lightning also takes care of the to(device) calls for the data automatically in the training loop, which is nice.
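A likely contributor, at least in the PyTorch versions discussed here: with automatic batching, DataLoader calls the dataset's __getitem__ once per sample index and then stacks the 5,000 samples in the default collate function, so keeping the tensors on the GPU turns every batch into thousands of small per-sample operations driven from Python, which workers then spread over CPU cores. An alternative not tried in the thread is to skip DataLoader for on-GPU data and gather each batch with a single indexing operation from a per-epoch torch.randperm. A minimal sketch; iterate_minibatches is a hypothetical helper, not a PyTorch API.

```python
import torch


def iterate_minibatches(events, targets, batch_size, shuffle=True):
    """Yield (events, targets) minibatches from tensors on a single device.

    Hypothetical helper, not a PyTorch API: each batch is gathered with one
    advanced-indexing op instead of one __getitem__ call per sample.
    """
    n = events.shape[0]
    if shuffle:
        order = torch.randperm(n, device=events.device)  # fresh shuffle on each call (epoch)
    else:
        order = torch.arange(n, device=events.device)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield events[idx], targets[idx]


# Usage with illustrative data; on a GPU, create these tensors with device="cuda".
events = torch.arange(12_000, dtype=torch.float32).unsqueeze(1).repeat(1, 8)
targets = events[:, :1].clone()  # targets mirror the first column so alignment is checkable
batches = list(iterate_minibatches(events, targets, batch_size=5000))
```

For validation, the same helper can be called with shuffle=False on tensors split once up front (for example, by indexing with a fixed torch.randperm).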
