Slow training time for Unet model


import torch
import torch.nn as nn
from tqdm import tqdm_notebook

model = Unet(in_channels=1, out_channels=1)
criterion = nn.BCEWithLogitsLoss()                                                           # loss and optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
for epoch in tqdm_notebook(range(num_epochs)):
         
#         print(epoch)
        total_train = 0
        correct_train = 0
        epoch_loss = 0
        DICE = 0
      
        for k,train_img in enumerate(trainLoader):
            
            inputs,labels = train_img
#             print(inputs.shape)
            inputs = inputs.unsqueeze_(0)
            inputs = inputs.reshape(batch_size_train,1,img_width,img_height)
            labels = labels.unsqueeze_(0)
            labels = labels.reshape(batch_size_train,1,img_width,img_height)

            # Forward pass 
            optimizer.zero_grad()                                                 # zeroes the gradient buffers of all parameters
            outputs = model(inputs)                                               # outputs.shape =(batch_size, n_classes, img_cols, img_rows)

            outputs = outputs.permute(0, 2, 3, 1)                                 # outputs.shape =(batch_size, img_cols, img_rows, n_classes)

            m = outputs.shape[0]                                                  # m = batch size
            width_out  = outputs.shape[2] 
            height_out = outputs.shape[1] 
            outputs_new = outputs.reshape(m*width_out*height_out)             # flatten outputs and labels to compute the pixel-wise BCE loss
            labels_new = labels.reshape(m*width_out*height_out)
            loss = criterion(outputs_new,labels_new)       
            epoch_loss += loss.item()
         
            loss.backward()                                                      # backward pass
            optimizer.step()                                                     # update parameters
            print(k,'done')    
            
            DICE += dice_coeff(outputs[0,:,:,0], labels[0,0,:,:]).item()
            print(DICE)
            
            _, predicted = torch.max(outputs_new.data, 1)                       # this line raises the IndexError discussed below
            print(predicted.shape)
            total_train += labels_new.nelement()                                    
            correct_train += predicted.eq(labels_new.data).sum().item()
            train_accuracy = 100 * correct_train / total_train
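(`dice_coeff` isn't shown in this post; a minimal sketch of a common soft-Dice coefficient, assuming the predictions are probabilities in [0, 1], could look like this:)

def dice_coeff(pred, target, eps=1e-7):
    # soft Dice: 2*|A∩B| / (|A| + |B|), computed over flattened tensors
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)
    intersection = (pred * target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)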
**Can someone explain what this line does: `_, predicted = torch.max(outputs_new.data, 1)`?**

I am getting this error: `IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)`

My predicted binary segmentation output is 540×800.

`_, predicted = torch.max(outputs_new.data, 1)` stores the indices of the max values in dim1.
Alternatively, you could also write `predicted = torch.argmax(outputs_new, 1)`.
The error you get points to a missing dimension, i.e. it seems `outputs_new` has only a single dimension.
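For illustration, with hypothetical shapes (not the ones from the post):

import torch

logits = torch.randn(4, 3)                    # [batch, classes]
values, predicted = torch.max(logits, 1)      # max value and its index along dim1; both have shape [4]
print(torch.equal(predicted, torch.argmax(logits, 1)))  # True

flat = logits.reshape(-1)                     # 1-D tensor, like outputs_new above
# torch.max(flat, 1) would raise:
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)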

Thanks for the reply. I have one more question.

I am using Google Colab, and while running on its GPU, the GPU memory usage suddenly increases when executing this line:
`correct_train += predicted.eq(labels_new.data).sum().item()`
The shape of `predicted` is `torch.Size([291600])`.
I cannot figure out why it does this. Is there another way to write the same line of code that uses less memory?

Maybe I am making some other mistake.

Could you post the shape of `labels_new`?
You could see an increased memory usage if broadcasting happens in `predicted.eq(labels_new)`.
This would be the case if `predicted` has a shape of `[291600]` while `labels_new` has `[291600, 1]`.
The result of this operation would then have the shape `[291600, 291600]`.
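A tiny sketch of that blow-up (N shrunk from 291600 so it runs anywhere):

import torch

N = 5                                           # stands in for 291600
predicted = torch.zeros(N)                      # shape [N]
labels_new = torch.zeros(N, 1)                  # shape [N, 1]

print(predicted.eq(labels_new).shape)           # torch.Size([5, 5])  -> broadcast to [N, N]
print(predicted.eq(labels_new.view(-1)).shape)  # torch.Size([5])     -> flatten first to avoid it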

You were right: broadcasting was the reason for the sudden increase in GPU memory usage.
Thanks.

The training code above is for the U-Net I wrote, but the problem is that it takes around 15 seconds for one forward + backward pass when a single image of size 540×540 is passed.
Is this the normal time for a U-Net in PyTorch, or do I need to make some changes to the code to increase the speed?
It feels like it is taking more time than a U-Net implemented in Keras.

Below is the code I have written.

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class Unet(nn.Module):
    
    def contracting_block(self, in_channels, out_channels, kernel_size=3):    
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels),
                    )
        return block
    
    def expansive_block(self, in_channels, mid_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channels),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channels, out_channels=mid_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channels),
                    torch.nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3,
                                             stride=2, padding=0, output_padding=0, dilation=1)
                    )
        return block

    def final_block(self, in_channels, mid_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channels),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channels, out_channels=mid_channels),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(mid_channels),
                    torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channels, out_channels=out_channels, padding=(86, 86)),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm2d(out_channels)
                    )
        return block
    
    def __init__(self, in_channels, out_channels):
        super(Unet, self).__init__()
        
        self.layer1 = nn.Conv2d(1, 1, kernel_size=1, stride=1)
        
        #Encode
        self.conv_encode1 = self.contracting_block(in_channels=1, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(in_channels=64, out_channels=128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(in_channels=128, out_channels=256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode4 = self.contracting_block(in_channels=256, out_channels=512)
        self.conv_maxpool4 = torch.nn.MaxPool2d(kernel_size=2)
        
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
                            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=1024),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm2d(1024),
                            torch.nn.Conv2d(kernel_size=3, in_channels=1024, out_channels=1024),
                            torch.nn.ReLU(),
                            torch.nn.BatchNorm2d(1024),
                            torch.nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0, output_padding=0)
                            )
        # Decode
        self.conv_decode4 = self.expansive_block(1024, 512, 256)
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, 1)

        self.pad1 = nn.ConstantPad2d(padding=(1, 0, 1, 0), value=0)
        
    def crop_and_concat(self, upsampled, bypass, crop=False):
        # F.pad with negative padding crops: it removes c rows/cols from each
        # border of bypass so it matches the upsampled size before the concat
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)
    
    def forward(self, x):
        
        # Encode

        pad_x = self.layer1(x)        
        encode_block1 = self.conv_encode1(pad_x)
        encode_pool1 = self.conv_maxpool1(encode_block1)        
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)        
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_maxpool4(encode_block4)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool4)
        
        # Decode
        decode_block4 = self.crop_and_concat(bottleneck1, encode_block4, crop=True)        
        cat_layer3 = self.conv_decode4(decode_block4)
        cat_layer3 = self.pad1(cat_layer3)
        
        decode_block3 = self.crop_and_concat(cat_layer3, encode_block3, crop=True)
        cat_layer2 = self.conv_decode3(decode_block3)
        cat_layer2 = self.pad1(cat_layer2)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)        
        cat_layer1 = self.conv_decode2(decode_block2)
        
        pad1_cat_layer1 = self.pad1(cat_layer1)
             
        decode_block1 = self.crop_and_concat(pad1_cat_layer1, encode_block1, crop=True)
        
        final_layer = self.final_layer(decode_block1)

        return  final_layer

t1 = time.time()
batch_size_train = 1
batch_size_test = 1
num_epochs = 50
learning_rate = 0.1
img_width = 540
img_height = 540

trainLoader = DataLoader(traindata, batch_size=batch_size_train, shuffle=True, num_workers=0)
model = Unet(in_channels=1, out_channels=1)
criterion = nn.BCEWithLogitsLoss()                                                           # loss and optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.99)

Below is a short snippet of the results:

forward time taken:  5.263527870178223
loss:  0.6772812604904175
Backward prop time taken:  11.189908027648926
1  done
DICE:  0.6863318905234337
Train_accuracy:  61.27932098765432

The timing seems strange.
I get approx. 31ms and 55ms for the forward and backward pass on a TITAN V, respectively:

device = 'cuda'
model = Unet(in_channels=1, out_channels=1)
model.to(device)
x = torch.randn(1, 1, 540, 540, device=device)

torch.cuda.synchronize()
t0 = time.time()
output = model(x)
torch.cuda.synchronize()
t1 = time.time()
print('fwd {}s'.format(t1 - t0))

loss = output.mean()

torch.cuda.synchronize()
t0 = time.time()
loss.backward()
torch.cuda.synchronize()
t1 = time.time()
print('bwd {}s'.format(t1 - t0))

(1.25s and 1.77s for CPU)

Could your bottleneck be the data loading?
E.g. if you don’t use multiple workers and have a lot of data preprocessing, your GPU might just be starving.
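As a sketch of that idea (the parameter values here are only examples, and traindata is the dataset from the post):

from torch.utils.data import DataLoader

# more workers overlap CPU preprocessing with GPU compute;
# pin_memory=True speeds up host-to-device copies
trainLoader = DataLoader(traindata, batch_size=batch_size_train, shuffle=True,
                         num_workers=4, pin_memory=True)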

I am using the Google Colab GPU, which is a Tesla K80.
I tried changing num_workers to 2, 4, 8, and 16, but it didn't speed up the training.

But I tried your approach:
I ran my model without any data loading, like your code above, just randomly initializing a 540×540 input and passing it to the model in a separate Colab notebook, to check whether the problem is in the data loading or in the model itself.

But I still get the forward time as:

contracting block time:  2.4187822341918945
bottleneck block time:  0.31655025482177734
expanding block time:  2.5195484161376953
fwd 5.256795644760132s
bwd 10.607162237167358s

Below is the code I have written for your reference

Unet model:

class Unet(nn.Module):
    # contracting_block, expansive_block, final_block, __init__ and
    # crop_and_concat are identical to the Unet definition posted above;
    # only forward changed, to print per-block timings:
    def forward(self, x):
        
        t1 = time.time()
        
        # Encode
        pad_x = self.layer1(x)        
        encode_block1 = self.conv_encode1(pad_x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)

        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_maxpool4(encode_block4)

        t2 = time.time()
        print("contracting block time: ", t2-t1)
        
        t3 = time.time()
    
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool4)
        t4 = time.time()
        print("bottleneck block time: ", t4-t3)

        # Decode
        decode_block4 = self.crop_and_concat(bottleneck1, encode_block4, crop=True)
        
        cat_layer3 = self.conv_decode4(decode_block4)
        cat_layer3 = self.pad1(cat_layer3)
        
        decode_block3 = self.crop_and_concat(cat_layer3, encode_block3, crop=True)

        cat_layer2 = self.conv_decode3(decode_block3)
        cat_layer2 = self.pad1(cat_layer2)

        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        
        cat_layer1 = self.conv_decode2(decode_block2)
        pad1_cat_layer1 = self.pad1(cat_layer1)        
        
        decode_block1 = self.crop_and_concat(pad1_cat_layer1, encode_block1, crop=True)
        
        final_layer = self.final_layer(decode_block1)
        
        t5 = time.time()
        print("expanding block time: ", t5-t4)
        
        return final_layer

batch_size_train = 1
batch_size_test = 1
num_epochs = 50
learning_rate = 0.1
img_width = 540
img_height = 540

Training code: Same as yours

x = torch.randn(1, 1, 540, 540)


# torch.cuda.synchronize()
t0 = time.time()
output = model(x)
# torch.cuda.synchronize()
t1 = time.time()
print('fwd {}s'.format(t1 - t0))

loss = output.mean()

# torch.cuda.synchronize()
t0 = time.time()
loss.backward()
# torch.cuda.synchronize()
t1 = time.time()
print('bwd {}s'.format(t1 - t0))
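Note that in the snippet above the synchronize calls are commented out and neither the model nor x is moved to a device, so it measures CPU time. A minimal sketch of the CUDA variant, assuming a GPU is available and model is the Unet instance from above, matching the earlier reply:

device = 'cuda'                           # assumes a CUDA device is available
model.to(device)
x = torch.randn(1, 1, 540, 540, device=device)

torch.cuda.synchronize()                  # wait for pending GPU work before timing
t0 = time.time()
output = model(x)
torch.cuda.synchronize()                  # wait for the forward kernels to finish
print('fwd {}s'.format(time.time() - t0))

loss = output.mean()

torch.cuda.synchronize()
t0 = time.time()
loss.backward()
torch.cuda.synchronize()
print('bwd {}s'.format(time.time() - t0))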

Please help me; I just can't figure out why, or give me any ideas on how to approach this problem.