PyTorch is Much Slower Than Keras

Hi, I am trying to train a U-Net model with PyTorch. Before that I used Keras with a U-Net on the same dataset, and training for one epoch took 10-12 seconds, but in PyTorch it takes 6 minutes. There should not be this sort of difference, right?
Here is my dataloader code:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from tqdm import tqdm

num_classes = np.unique(labels).shape[0]  # replace with the number of classes in your dataset


class CustomImageDataset(Dataset):
    def __init__(self, mask, img, transform=None):
        self.img = img
        self.transform = transform
        # one-hot encode the masks once up front and move the class
        # dimension into the channel position (N, C, H, W)
        self.mask = torch.from_numpy(mask).long()
        self.mask = nn.functional.one_hot(self.mask, num_classes=2)
        self.mask = torch.movedim(self.mask, 3, 4)
        self.mask = torch.movedim(self.mask, (1, 2, 3), (2, 3, 1))
        self.mask = self.mask[:, :, :, :, 0]

    def __len__(self):
        return self.img.shape[0]

    def __getitem__(self, idx):
        image = self.img[idx]
        mask = self.mask[idx]
        if self.transform:
            image = self.transform(image)

        return image, mask


train_data = CustomImageDataset(labels, image_dataset, ToTensor())
train, valid = torch.utils.data.random_split(train_data, [2434, 270], generator=torch.Generator().manual_seed(42))


model = UNetWithResnet50Encoder()
model.double()
model.to(device)
loss_fn = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

train_loader = DataLoader(train, batch_size=16, shuffle=True, pin_memory=True, num_workers=12)
validation_loader = DataLoader(valid, batch_size=1, shuffle=False, pin_memory=True, num_workers=12)

scaler = torch.cuda.amp.GradScaler()

And here is my training function:

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=device)
        targets = targets.to(device=device)

        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions.float(), targets.float())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

And here is the U-Net implementation that I use:

import torch
import torch.nn as nn
import torchvision
resnet = torchvision.models.resnet.resnet50(pretrained=True)


class ConvBlock(nn.Module):
    """
    Helper module that consists of a Conv -> BN -> ReLU
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """
    This is the middle layer of the U-Net, which just consists of two ConvBlocks.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """

        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class UNetWithResnet50Encoder(nn.Module):
    DEPTH = 6

    def __init__(self, n_classes=2):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        self.input_pool = list(resnet.children())[3]
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))

        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)

    def forward(self, x, with_output_feature_map=False):
        pre_pools = dict()
        pre_pools[f"layer_0"] = x
        x = self.input_block(x)
        pre_pools[f"layer_1"] = x
        x = self.input_pool(x)

        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (UNetWithResnet50Encoder.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{UNetWithResnet50Encoder.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        output_feature_map = x
        x = self.out(x)
        del pre_pools
        if with_output_feature_map:
            return x, output_feature_map
        else:
            return x

Did you check the model architectures and the number of parameters in both implementations and make sure they are equal?
We have seen similar observations in the past which were sometimes narrowed down to quite different model architectures, as the author wasn’t familiar with both frameworks.

To check the number of parameters in PyTorch, you can use:

print(sum(p.numel() for p in model.parameters()))
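
For example, assuming your Keras model is still available in the same session (keras_model is just a placeholder name here), a quick comparison could look like this:

# PyTorch parameter count
pytorch_params = sum(p.numel() for p in model.parameters())
print(f"PyTorch parameters: {pytorch_params}")

# Keras parameter count (keras_model is a placeholder for your Keras U-Net)
print(f"Keras parameters: {keras_model.count_params()}")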

Also, are you using Windows or Linux?

Hi ptrblck,
I tried the dataset with fcn_resnet50 from torchvision.models and even that takes 3 minutes per epoch. This makes me believe that my dataloader or training function is doing something wrong.

I don’t know which dataset you are using, but are you seeing an unexpectedly long loading time per sample? I.e., how many samples are you loading in an epoch and what kind of processing is used?
Also, where is the data stored? Any medium other than a fast SSD can often cause a bottleneck.
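
If you want to check this, you could time a pure pass over your DataLoader without any model work. A rough sketch, assuming train_loader is defined as in your code:

import time

# time one full pass over the DataLoader without running the model
start = time.perf_counter()
for data, targets in train_loader:
    pass
print(f"pure data loading took {time.perf_counter() - start:.2f} seconds")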

The dataset should not be the problem, since I used the same dataset with Keras and it was almost 50 times faster. I am using Google Colab Pro and I shared the dataloader code. Since I am new to PyTorch, I believe I am making mistakes in writing my dataloader. In the code I first preprocess all my data and load it into CPU RAM as numpy arrays. Then, using my dataloader, I upload it to the GPU in batches.
Here is my dataloader code.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from tqdm import tqdm

num_classes = np.unique(labels).shape[0]  # replace with the number of classes in your dataset


class CustomImageDataset(Dataset):
    def __init__(self, mask, img, transform=None):
        self.img = img
        self.transform = transform
        # one-hot encode the masks once up front and move the class
        # dimension into the channel position (N, C, H, W)
        self.mask = torch.from_numpy(mask).long()
        self.mask = nn.functional.one_hot(self.mask, num_classes=2)
        self.mask = torch.movedim(self.mask, 3, 4)
        self.mask = torch.movedim(self.mask, (1, 2, 3), (2, 3, 1))
        self.mask = self.mask[:, :, :, :, 0]

    def __len__(self):
        return self.img.shape[0]

    def __getitem__(self, idx):
        image = self.img[idx]
        mask = self.mask[idx]
        if self.transform:
            image = self.transform(image)

        return image, mask


train_data = CustomImageDataset(labels, image_dataset, ToTensor())
train, valid = torch.utils.data.random_split(train_data, [2434, 270], generator=torch.Generator().manual_seed(42))


model = UNetWithResnet50Encoder()
model.double()
model.to(device)
loss_fn = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

train_loader = DataLoader(train, batch_size=16, shuffle=True, pin_memory=True, num_workers=12)
validation_loader = DataLoader(valid, batch_size=1, shuffle=False, pin_memory=True, num_workers=12)

scaler = torch.cuda.amp.GradScaler()

And here is my training function:

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=device)
        targets = targets.to(device=device)

        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions.float(), targets.float())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

I don’t see any obvious issues in your dataset, but you should also check if using the large number of workers is causing a slowdown. Since you are preloading the entire dataset in the __init__ method, the actual sample loading might be cheap, as it’s only an indexing operation in __getitem__.
Multiple workers are beneficial if the loading and transformation take some time, but I’ve seen a negative effect when multiple workers just try to index an array, due to the added overhead.
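
As a quick sanity check, you could compare the pure loading time for different worker counts. A rough sketch, assuming train is the dataset split from your code:

import time
from torch.utils.data import DataLoader

# compare pure loading time for different numbers of workers
for num_workers in [0, 2, 4, 8, 12]:
    loader = DataLoader(train, batch_size=16, shuffle=True,
                        pin_memory=True, num_workers=num_workers)
    start = time.perf_counter()
    for data, targets in loader:
        pass
    print(f"num_workers={num_workers}: {time.perf_counter() - start:.2f}s")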