Off by just 1 at a weird number when doing image segmentation with model parallelism

Hi,

I am doing a small study on the importance of preserving image resolution in semantic segmentation. I have a basic U-Net that I train from scratch using different image sizes. The original images are roughly 2500x2500, so very large. I use the same basic code and just resize the images to fit on my hardware. The code worked fine for every size from 64x64 up to 1024x1024. At 2056x2056 I get the following:

INFO:    underlay of /usr/bin/nvidia-smi required more than 50 (343) bind mounts
Available: True, Count: 2, Name: Tesla V100-SXM2-16GB
Data Loaded Successfully!
Number of Training Samples: 320
Number of Testing Samples: 142
Epoch: 0
  0%|          | 0/320 [00:04<?, ?it/s]
Traceback (most recent call last):
  File "/mnt/mp_unet_noencoder_2056x1.py", line 405, in <module>
    main()
  File "/mnt/mp_unet_noencoder_2056x1.py", line 385, in main
    loss_val = train_function(train_loader, model, optimizer, loss_function, DEVICE)
  File "/mnt/mp_unet_noencoder_2056x1.py", line 310, in train_function
    preds = model(X)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/mp_unet_noencoder_2056x1.py", line 234, in forward
    up4 = self.up_concat4(center.to(DEVICE_1), conv4.to(DEVICE_1)).to(DEVICE_1)  # 128*64*128
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/mp_unet_noencoder_2056x1.py", line 142, in forward
    outputs0 = torch.cat([outputs0, input[i]], 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 256 but got size 257 for tensor number 1 in the list.

What is weird is that I am off by only 1, in code that worked fine for smaller image sizes. If the images were simply too big for the GPUs, wouldn't I just get a CUDA OOM error instead?
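
For context, here is the back-of-the-envelope size check I can run outside the training script (a rough sketch, assuming each MaxPool2d(kernel_size=2) halves the spatial size with floor division and each up step doubles it; trace_sizes is just a throwaway helper name):

def trace_sizes(img_size, depth=4):
    # encoder path: MaxPool2d(kernel_size=2) floors odd sizes
    down = [img_size]
    for _ in range(depth):
        down.append(down[-1] // 2)
    # decoder path: each ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles the spatial size
    up = [down[-1]]
    for _ in range(depth):
        up.append(up[-1] * 2)
    return down, up

print(trace_sizes(1024))  # ([1024, 512, 256, 128, 64], [64, 128, 256, 512, 1024])
print(trace_sizes(2056))  # ([2056, 1028, 514, 257, 128], [128, 256, 512, 1024, 2048])

The 1024 case lines up on the way back up; the 2056 case is where the 256 and 257 from the traceback appear.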

Here is my code, with model parallelism implemented based on code I saw in a Stack Overflow post.

import os
import random
import time
import numpy as np
import pandas as pd
from PIL import Image 
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
from weights import init_weights
import segmentation_models_pytorch as smp
import torchvision.transforms.functional as TF
import torch.nn.functional as F

DEVICE_0 = "cuda:0"
DEVICE_1 = "cuda:1"
IMG_SIZE = 2056
SPLIT_SIZE = 1
BATCH_SIZE = 1
EPOCHS = 30 
NUM_WORKERS = 1
LEARNING_RATE = .001
PIN_MEMORY = True
DEVICE = 'cuda'
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = False

class TrainingDataset(Dataset):
    def __init__(self, image_dir, mask_dir, image_transform=None, mask_transform=None ):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.images = os.listdir(image_dir)
        
    def __len__(self):
        return len(self.images)
    
    def augment(self, image, mask):
        
        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
            mask = mask.unsqueeze(0)
            
        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
            
        # Random rotation
        if random.random() > 0.5:
            rotate_angle = random.randint(-5, 5)
            hshift = round(random.uniform(0,0.1), 2)
            vshift = round(random.uniform(0,0.1), 2)
            shear_angle = random.randint(-5, 5)
            image = TF.affine(image,
                              angle =rotate_angle,
                              translate = (hshift,vshift),
                              scale = 1 ,
                              shear = shear_angle,
                              fill = 0
                             )
            
            mask = TF.affine(mask,
                              angle =rotate_angle,
                              translate = (hshift,vshift),
                              scale = 1,
                              shear = shear_angle,
                              fill = 0
                             )
           
        return image, mask
    
    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index]).replace("\\","/")
        #mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
        mask_path = os.path.join(self.mask_dir, self.images[index]).replace("\\","/")
        image = Image.open(img_path).convert('RGB')    
        mask = Image.open(mask_path)#.convert('L')
        x, y = self.augment(image, mask)
        #y = y.squeeze() 
        return x, y

class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        s = stride
        p = padding
        
        if is_batchnorm:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
                                     nn.BatchNorm2d(out_size), nn.ReLU(inplace=True),)
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size
        else:
            for i in range(1, n + 1):
                conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), nn.ReLU(inplace=True), )
                setattr(self, 'conv%d' % i, conv)
                in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')


    def forward(self, inputs):
        x = inputs
        for i in range(1, self.n + 1):
            conv = getattr(self, 'conv%d' % i)
            x = conv(x)
        return x
        
class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp, self).__init__()
        self.conv = unetConv2(out_size * 2, out_size, False)
        
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')


    def forward(self, inputs0, *input):
        outputs0 = self.up(inputs0)
        for i in range(len(input)):
            outputs0 = torch.cat([outputs0, input[i]], 1)
        return self.conv(outputs0)

class unetUp_origin(nn.Module):
    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp_origin, self).__init__()
        # self.conv = unetConv2(out_size*2, out_size, False)
        if is_deconv:
            self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # initialise the blocks
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs0, *input):
        # print(self.n_concat)
        # print(input)
        outputs0 = self.up(inputs0)
        for i in range(len(input)):
            outputs0 = torch.cat([outputs0, input[i]], 1)
        return self.conv(outputs0)


class ModelParallelUnet(nn.Module):
    def __init__(self, in_channels=3, in_classes=3, bilinear=True, feature_scale=4, 
                 is_deconv=True, is_batchnorm=True):
        super(ModelParallelUnet, self).__init__()        
        self.n_channels = in_channels
        self.n_classes = in_classes
        self.bilinear = bilinear
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm
        filters = [64, 128, 256, 512, 1024]

        # downsampling
        self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm).to(DEVICE_0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)

        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm).to(DEVICE_0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)

        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm).to(DEVICE_0)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)

        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm).to(DEVICE_1)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2).to(DEVICE_1)

        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm).to(DEVICE_1)

        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv).to(DEVICE_1)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv).to(DEVICE_1)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv).to(DEVICE_1)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv).to(DEVICE_0)
        self.outconv1 = nn.Conv2d(filters[0], in_classes, 3, padding=1).to(DEVICE_0)

        # initialise weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def dotProduct(self,seg,cls):
        B, N, H, W = seg.size()
        seg = seg.view(B, N, H * W)
        final = torch.einsum("ijk,ij->ijk", [seg, cls])
        final = final.view(B, N, H, W)
        return final


    def forward(self, inputs):
        # (shape comments below are for a single 2056x2056 input, batch dimension omitted)
        conv1 = self.conv1(inputs.to(DEVICE_0))  # 64 x 2056 x 2056
        maxpool1 = self.maxpool1(conv1.to(DEVICE_0)).to(DEVICE_0)  # 64 x 1028 x 1028

        conv2 = self.conv2(maxpool1.to(DEVICE_0)).to(DEVICE_0)  # 128 x 1028 x 1028
        maxpool2 = self.maxpool2(conv2.to(DEVICE_0)).to(DEVICE_0)  # 128 x 514 x 514

        conv3 = self.conv3(maxpool2.to(DEVICE_0)).to(DEVICE_0)  # 256 x 514 x 514
        maxpool3 = self.maxpool3(conv3.to(DEVICE_0)).to(DEVICE_0)  # 256 x 257 x 257

        conv4 = self.conv4(maxpool3.to(DEVICE_1)).to(DEVICE_1)  # 512 x 257 x 257
        maxpool4 = self.maxpool4(conv4.to(DEVICE_1)).to(DEVICE_1)  # 512 x 128 x 128

        center = self.center(maxpool4.to(DEVICE_1)).to(DEVICE_1)  # 1024 x 128 x 128

        up4 = self.up_concat4(center.to(DEVICE_1), conv4.to(DEVICE_1)).to(DEVICE_1)  # decoder gives 256, skip is 257 -> the line the traceback points to
        up3 = self.up_concat3(up4.to(DEVICE_1), conv3.to(DEVICE_1)).to(DEVICE_1)
        up2 = self.up_concat2(up3.to(DEVICE_1), conv2.to(DEVICE_1)).to(DEVICE_1)
        up1 = self.up_concat1(up2.to(DEVICE_0), conv1.to(DEVICE_0)).to(DEVICE_0)

        d1 = self.outconv1(up1.to(DEVICE_0))
        return torch.sigmoid(d1.to(DEVICE_0)).to(DEVICE_0)


def get_loader(train_dir,train_maskdir,val_dir,val_maskdir,batch_size,train_transform_image,
    train_transform_mask,val_transform,num_workers=1,pin_memory=True,):
    
    train_ds = TrainingDataset(image_dir=train_dir,mask_dir=train_maskdir,
        image_transform=train_transform_image,mask_transform=train_transform_mask)

    train_loader = DataLoader(train_ds,batch_size=batch_size,num_workers=num_workers,
        pin_memory=pin_memory,shuffle=True,)

    test_ds = TrainingDataset(image_dir=val_dir,mask_dir=val_maskdir,
        image_transform=train_transform_image,mask_transform=train_transform_mask)

    test_loader = DataLoader(test_ds,batch_size=batch_size,num_workers=num_workers,
        pin_memory=pin_memory,shuffle=True,)

    return train_loader, test_loader

def tensor_to_numpy_preserve_scale(x):
    # convert the PIL mask to an integer tensor; ToTensor() on an int64 array keeps the raw label values (no 0-1 scaling)
    #print(f' input of preserve scale {x.size}')
    temp = transforms.ToTensor()(np.array(x,dtype='int64'))
    #temp = torch.squeeze(temp)
    #temp = torch.transpose(temp, 0, 1)
    #print(f'output of preserve scale {temp.shape}')
    return temp

def pad_image_with_aspect(x):
    w,h = x.size
    h_new = int(IMG_SIZE*h/w)
    pad_amount = int((IMG_SIZE-h_new)//2)
    square_image = transforms.Compose([
                                        transforms.Resize(size=(h_new,IMG_SIZE)),
                                        transforms.Pad((0,pad_amount),fill=0, padding_mode='constant'),
                                        transforms.Resize(size=(IMG_SIZE,IMG_SIZE))
    ])
    return square_image(x)

def pad_mask_with_aspect(x):
    c,h,w = x.shape
    h_new = int(IMG_SIZE*h/w)
    if (h_new % 2 == 0):
        pad_amount = int((abs(IMG_SIZE-h_new))//2)
        square_image = transforms.Compose([
                                        transforms.Resize(size=(h_new,IMG_SIZE)),
                                        transforms.Pad((0,pad_amount),fill=0, padding_mode='constant'),
                                        #transforms.Resize(size=(IMG_SIZE,IMG_SIZE))
        ])
    else:
        pad_amount = int((abs(IMG_SIZE-h_new))//2)
        square_image = transforms.Compose([
                                        transforms.Resize(size=(h_new,IMG_SIZE)),
                                        transforms.Pad((0,pad_amount,0,pad_amount+1),fill=0, padding_mode='constant'),
                                        #transforms.Resize(size=(IMG_SIZE,IMG_SIZE))
        ])
    
    temp = torch.squeeze(square_image(x))
    if BATCH_SIZE == 1:
        temp = temp[None,:,:]
    return temp


def train_function(data, model, optimizer, loss_fn, device):
    loss_values = []
    data = tqdm(data)
    for index, batch in enumerate(data):
        X, y = batch
        y = y.squeeze(dim=0)
        y = y.to(dtype=torch.long)
        X, y = X.to(device), y.to(device)
        preds = model(X)
        #print(f'the prediction size is {preds.shape} and label size is {y.shape}')
        loss = loss_fn(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_values.append(loss.item())
    # return the mean loss over the epoch rather than just the last batch
    return sum(loss_values) / len(loss_values)

def test(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # We set this just for the example to run quickly.
            if batch_idx * len(data) > BATCH_SIZE:
                break
            data, target = data.to(DEVICE_0), target.to(DEVICE_0)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total

def main(): 
    MODEL_PATH = '/mnt/models/'
    TRAIN_IMG_DIR = "/mnt/training_images/images/training"
    TRAIN_MASK_DIR = "/mnt/training_images/masks/training"
    VAL_IMG_DIR = "/mnt/training_images/images/testing"
    VAL_MASK_DIR = "/mnt/training_images/masks/testing"
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = False
    
    #Check if GPU is available ===================================
    avail = torch.cuda.is_available()
    devCnt = torch.cuda.device_count()
    devName = torch.cuda.get_device_name(0)
    print("Available: " + str(avail) + ", Count: " + str(devCnt) + ", Name: " + str(devName))
    epoch = 0 # starting epoch; this would be overwritten if resuming training from a saved checkpoint

    image_transform = transforms.Compose([
        transforms.Lambda(lambda x : pad_image_with_aspect(x)),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ]) 
    mask_transform = transforms.Compose([
        transforms.Lambda(lambda x : tensor_to_numpy_preserve_scale(x)),
        transforms.Lambda(lambda x : pad_mask_with_aspect(x)), 
    ])
    
    train_loader, test_loader = get_loader(TRAIN_IMG_DIR,TRAIN_MASK_DIR,VAL_IMG_DIR,
        VAL_MASK_DIR,BATCH_SIZE,image_transform,mask_transform,image_transform,
        NUM_WORKERS,PIN_MEMORY,)
    
    # Check Tensor shapes ======================================================
    #batch = next(iter(train_loader))
    #images, labels = batch

    print('Data Loaded Successfully!')
    print(f'Number of Training Samples: {len(train_loader)}')
    print(f'Number of Testing Samples: {len(test_loader)}')

    model = ModelParallelUnet(in_channels=3, in_classes=3)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    #0 Background, 1 Sample, 2 Top Platen
    loss_function = smp.losses.DiceLoss(mode='multiclass')

    LOSS_VALS = [] # Defining a list to store loss values after every epoch
    EPOCH_LIST = []
    EPOCH_RUNTIME = []
    TEST_ACCURACY = [] 
    #Training the model for every epoch. 
    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        epoch_start = time.time()
        loss_val = train_function(train_loader, model, optimizer, loss_function, DEVICE)
        acc = test(model, test_loader)
        print(f'loss value = {loss_val}')
        LOSS_VALS.append(loss_val) 
        EPOCH_LIST.append(e) 
        EPOCH_RUNTIME.append(time.time() - epoch_start)
        TEST_ACCURACY.append(acc)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': LOSS_VALS,
            'accuracy': TEST_ACCURACY,
            'epochs_run': EPOCH_LIST,
            'epoch_time': EPOCH_RUNTIME
        }, f'{MODEL_PATH}/MP_unet_NONE_backbone_{IMG_SIZE}x{BATCH_SIZE}_epoch_{e}.pth')
        print("Epoch completed and model successfully saved!")
    

if __name__ == '__main__':
    main()