UNet implementation in PyTorch is too slow

I had a simple TensorFlow-based UNet model that I used in an optimization program. I want to implement it in PyTorch, and I have written what I believe is the equivalent version. Below are the two implementations:

TensorFlow code -

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

# cropsize (spatial size) and frames (input channels) are defined elsewhere in the script
k_size = 3

def u_net():
    
    inputs = Input((cropsize,cropsize,frames))
    block1 = Conv2D(32, (k_size, k_size), padding="same", activation = 'relu')(inputs)
    
    block2 = Conv2D(32, (k_size, k_size), padding="same", activation = 'relu')(block1)
    block2 = Conv2D(32, (k_size, k_size), padding="same", activation = 'relu')(block2)
    down1 = MaxPooling2D(pool_size=(2,2))(block2)
    
    block3 = Conv2D(64, (k_size, k_size), padding="same", activation = 'relu')(down1)
    block3 = Conv2D(64, (k_size, k_size), padding="same", activation = 'relu')(block3)
    down2 = MaxPooling2D(pool_size=(2,2))(block3)
    
    block4 = Conv2D(128, (k_size, k_size), padding="same", activation = 'relu')(down2)
    block4 = Conv2D(128, (k_size, k_size), padding="same", activation = 'relu')(block4)
    down3 = MaxPooling2D(pool_size=(2,2))(block4)
    
    block5 = Conv2D(256, (k_size, k_size), padding="same", activation = 'relu')(down3)
    block5 = Conv2D(256, (k_size, k_size), padding="same", activation = 'relu')(block5)
    
    up1 = UpSampling2D(size=(2,2))(block5)
    cat1 = concatenate([block4,up1])
    block6 = Conv2D(128, (k_size, k_size), padding="same", activation = 'relu')(cat1)
    block6 = Conv2D(128, (k_size, k_size), padding="same", activation = 'relu')(block6)
    
    up2 = UpSampling2D(size=(2,2))(block6)
    cat2 = concatenate([block3,up2])
    block7 = Conv2D(64, (k_size, k_size), padding="same", activation = 'relu')(cat2)
    block7 = Conv2D(64, (k_size, k_size), padding="same", activation = 'relu')(block7)
    
    up3 = UpSampling2D(size=(2,2))(block7)
    cat3 = concatenate([block2,up3])
    block8 = Conv2D(32, (k_size, k_size), padding="same", activation = 'relu')(cat3)
    block8 = Conv2D(32, (k_size, k_size), padding="same", activation = 'relu')(block8)
    block9 = Conv2D(32, (k_size, k_size), padding="same", activation = 'relu')(block8)
    output = Conv2D(1, (1, 1), padding="same")(block9)
    output = tf.keras.layers.ReLU(max_value=1.0)(output)
    
    model = tf.keras.Model(inputs=[inputs], outputs=[output])
    
    return model
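For a like-for-like comparison with the parameter count printed in the PyTorch script further below, the Keras model can be built and inspected like this (a minimal usage sketch; cropsize and frames must already be defined):

model = u_net()
model.summary()  # per-layer output shapes and the total parameter count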

PyTorch code -

# full assembly of the sub-parts to form the complete net
from unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 32)
        self.down1 = down(32, 64)
        self.down2 = down(64, 128)
        self.down3 = down(128, 256)
        # remember whether you are using bilinear interpolation or ConvTranspose2d for upsampling:
        # with bilinear interpolation, in_channels for up1 is 256+128
        # with ConvTranspose2d, in_channels for up1 is 256 and for up2 is 128
        self.up1 = up(384, 128)
        self.up2 = up(192, 64)
        self.up3 = up(96, 32)

        self.outc4 = outconv(32, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.outc4(x)
        return x

where its parts are defined as follows -

# sub-parts of the U-Net model

import torch
import torch.nn as nn


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  if your machine does not have enough memory to handle all those weights,
        #  bilinear interpolation can be used for the upsampling instead.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x
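As a quick check of the channel arithmetic in the up blocks (for the bilinear path, up1 sees 256 upsampled + 128 skip = 384 channels, and so on), a dummy forward pass can be run. A small sketch, assuming a 9-channel 128x128 input:

import torch

net = UNet(n_channels=9, n_classes=1)
dummy = torch.randn(1, 9, 128, 128)  # NCHW; 128 is divisible by 2**3 for the three poolings
out = net(dummy)
print(out.shape)  # torch.Size([1, 1, 128, 128])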

I plan to use a custom loss function for my problem, but even with a basic SSIM loss the TensorFlow code runs far faster than the PyTorch code (roughly 2 minutes vs. 15 minutes). Earlier I was using ConvTranspose2d, but I noticed that the TensorFlow model uses bilinear upsampling, so I changed that in my PyTorch version; I also had to modify the input channels in the up part of the UNet accordingly. I've also tried removing the BatchNorm layers, since the TensorFlow code doesn't use them, but that didn't solve the problem. I've read the PyTorch performance tuning guide and am zeroing the gradients by setting them to None, as the guide recommends.

I've implemented the model the same way as in TensorFlow, so why am I seeing such a large time difference? Obviously I'm wrong somewhere, but I fail to see where. Any help would be appreciated; I've been trying to fix this for over a month to no avail.
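One thing worth ruling out before comparing the 2 min vs. 15 min numbers: CUDA kernels launch asynchronously, so naive wall-clock timing of a PyTorch loop can mislead. A minimal timing sketch, reusing the names from the PyTorch training loop shown further below:

import time
import torch

torch.cuda.synchronize()          # wait for pending GPU work before starting the clock
start = time.perf_counter()

outputs = model(input_frames)     # forward pass
loss = l(input_frames, outputs)
loss.backward()
optimizer.step()

torch.cuda.synchronize()          # wait for the GPU to finish before stopping the clock
print(f"one step took {time.perf_counter() - start:.3f} s")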

I’m running two models as follows -
TensorFlow -

learning_rate = 0.001

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    learning_rate,
    decay_steps=50,
    decay_rate=0.90,
    staircase=True)

opt = tf.keras.optimizers.Adam(lr_schedule)
batch_size = 1
eps = 400
l = Physics_Loss

def get_lr_metric(optimizer):
    def lr(y_true, y_pred):
        return optimizer._decayed_lr(tf.float32)
    return lr

lr_metric=get_lr_metric(opt)

model.compile(loss=l, optimizer=opt,metrics=['accuracy',lr_metric],run_eagerly=True)
model_history = model.fit(x=single_input,y=single_input,epochs=eps,batch_size=batch_size)

PyTorch -

import torch
import torch.optim as optim

model = UNet(n_channels=9, n_classes=1)  # UNet as defined above
print("{} parameters in total".format(sum(x.numel() for x in model.parameters())))
learning_rate = 0.001

cuda=torch.device('cuda:0')
model=model.to(cuda)

# Create the Adam optimizer with the initial learning rate
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Create an ExponentialLR scheduler to decay the learning rate
lr_schedule = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.90)

# Set batch size and number of epochs
batch_size = 1
epochs = 50

l = Physics_loss
train_losses=[]

for epoch in range(epochs):
    running_loss = 0.0
    inputs = input_frames  # the same low-res frames every epoch (already on the GPU)
    # zero the parameter gradients (set to None, as the performance guide recommends)
    for param in model.parameters():
        param.grad = None
    outputs = model(inputs)   # forward pass
    loss = l(inputs, outputs)
    loss.backward()           # backpropagate the loss
    optimizer.step()          # update the weights
    lr_schedule.step()        # update the learning rate
    running_loss += loss.item()  # .item() detaches the value so the graph can be freed
    train_loss = running_loss    # one optimization step per epoch
    train_losses.append(train_loss)
    # print average training loss for the epoch
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))

Also, in PyTorch I'm running for only 50 epochs because they take so long.
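In case it helps narrow things down, here are a few standard PyTorch speed settings that are easy to try on a fixed-shape loop like this one. This is only a sketch of common options (mixed precision assumes a reasonably recent GPU), not a diagnosis:

import torch

# let cuDNN benchmark convolution algorithms, which helps when input shapes are fixed
torch.backends.cudnn.benchmark = True

# move the constant input to the GPU once, outside the epoch loop
input_frames = input_frames.to(cuda)

# optional: automatic mixed precision, often faster for conv-heavy models
scaler = torch.cuda.amp.GradScaler()
for epoch in range(epochs):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        outputs = model(input_frames)
        loss = l(input_frames, outputs)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    lr_schedule.step()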

Are you running Windows or Linux? The default vanilla dataloader for Windows is incredibly slow because it only uses 1 worker.

Hi, I’m currently on Windows, but I don’t think that’s the problem, because I’m not using a DataLoader in each epoch as in a conventional training loop.

My program is something like an “untrained” neural net, in the sense that once I load the input data (experimental low-res images) into the variable “input_frames”, I don’t have to load any other data: I pass the same frames as input in every epoch. Once the UNet generates an image, that image is passed through a physics-based model (say, convolution with the known experimental point spread function) to produce low-res images, and the loss between “input_frames” and those low-res images is what gets optimized.

Overall, my point is that in conventional training algorithms one has to load, say, low-res and high-res image pairs and train the model, but in this case there is no such need, since each epoch uses the same low-res images and the newly learned state of the model to generate a new image.
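For context (this resembles a deep-image-prior style optimization), here is a purely hypothetical sketch of what a loss along these lines might look like. The real Physics_loss isn't shown in the thread, and physics_loss_sketch and psf below are illustrative names, not the author's code:

import torch
import torch.nn.functional as F

def physics_loss_sketch(input_frames, predicted_image, psf):
    """Hypothetical stand-in for Physics_loss: blur the UNet output with the
    known PSF and compare it to the measured low-res frames."""
    # psf: assumed (1, 1, k, k) kernel with odd k, built from the measured PSF
    simulated = F.conv2d(predicted_image, psf, padding=psf.shape[-1] // 2)
    # compare the single simulated frame against every measured frame
    return F.mse_loss(simulated.expand_as(input_frames), input_frames)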

Thank you for your reply!