Hi,
I am doing a small study on the importance of image resolution preservation in semantic segmentation. I have a basic U-Net that I train from scratch using different image sizes. The original image is roughly 2500x2500, so very large. I have the same basic code where I resize to fit on my HW infrastructure. Iterations of the code worked fine for 64x64 up to 1024x1024. At 2056x2056 I get the following
INFO: underlay of /usr/bin/nvidia-smi required more than 50 (343) bind mounts
Available: True, Count: 2, Name: Tesla V100-SXM2-16GB
Data Loaded Successfully!
Number of Training Samples: 320
Number of Testing Samples: 142
Epoch: 0
^M 0%| | 0/320 [00:00<?, ?it/s]^M 0%| | 0/320 [00:04<?, ?it/s]
Traceback (most recent call last):
File "/mnt/mp_unet_noencoder_2056x1.py", line 405, in <module>
main()
File "/mnt/mp_unet_noencoder_2056x1.py", line 385, in main
loss_val = train_function(train_loader, model, optimizer, loss_function, DEVICE)
File "/mnt/mp_unet_noencoder_2056x1.py", line 310, in train_function
preds = model(X)
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/mp_unet_noencoder_2056x1.py", line 234, in forward
up4 = self.up_concat4(center.to(DEVICE_1), conv4.to(DEVICE_1)).to(DEVICE_1) # 128*64*128
File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/mp_unet_noencoder_2056x1.py", line 142, in forward
outputs0 = torch.cat([outputs0, input[i]], 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 256 but got size 257 for tensor number 1 in the list.
What is weird is that I am off by only 1, in code that worked for smaller image sizes. If the image were simply too large, wouldn't I just get an OOM error from CUDA instead?
Here is my code, with model parallelism implemented based on code I saw in a Stack Overflow post.
import os
import random
import time
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision import transforms
from weights import init_weights
import segmentation_models_pytorch as smp
import torchvision.transforms.functional as TF
import torch.nn.functional as F
# Model-parallel device assignment: shallow stages on cuda:0, deep on cuda:1.
DEVICE_0 = "cuda:0"
DEVICE_1 = "cuda:1"
# NOTE(review): 2056 is NOT divisible by 16 (2056 / 8 = 257, which is odd).
# The fourth MaxPool2d floors 257 -> 128, the decoder upsamples back to 256,
# and the skip concatenation then fails with
# "Expected size 256 but got size 257" -- exactly the traceback above.
# Use a multiple of 16 (e.g. 2048), or align sizes inside the decoder.
IMG_SIZE = 2056
SPLIT_SIZE = 1
BATCH_SIZE = 1
EPOCHS = 30
NUM_WORKERS = 1
LEARNING_RATE = .001
PIN_MEMORY = True
DEVICE = 'cuda'
torch.backends.cudnn.benchmark = True
# NOTE(review): disabling cuDNN entirely makes the benchmark flag above a
# no-op (and slows down convolutions) -- confirm which setting is intended.
torch.backends.cudnn.enabled = False
class TrainingDataset(Dataset):
    """Paired image/segmentation-mask dataset with on-the-fly augmentation.

    Mask files are assumed to share their image's filename in ``mask_dir``
    (the commented-out line in ``__getitem__`` suggests a "_mask.gif" naming
    convention was used previously).
    """

    def __init__(self, image_dir, mask_dir, image_transform=None, mask_transform=None ):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        # Index is built once from the image directory listing.
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def augment(self, image, mask):
        """Apply configured transforms, then identical random flip/affine to both
        image and mask so they stay spatially aligned."""
        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)
        mask = mask.unsqueeze(0)
        # Random horizontal flipping -- one coin flip shared by image and mask.
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        # Random rotation / translation / shear, same parameters for both.
        if random.random() > 0.5:
            rotate_angle = random.randint(-5, 5)
            # NOTE(review): TF.affine's ``translate`` is expressed in pixels;
            # these 0-0.1 values look like fractions of the image size, so the
            # translation is effectively zero -- confirm intent.
            hshift = round(random.uniform(0,0.1), 2)
            vshift = round(random.uniform(0,0.1), 2)
            shear_angle = random.randint(-5, 5)
            image = TF.affine(image,
                              angle =rotate_angle,
                              translate = (hshift,vshift),
                              scale = 1 ,
                              shear = shear_angle,
                              fill = 0
                              )
            mask = TF.affine(mask,
                             angle =rotate_angle,
                             translate = (hshift,vshift),
                             scale = 1,
                             shear = shear_angle,
                             fill = 0
                             )
        return image, mask

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index]).replace("\\","/")
        #mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
        mask_path = os.path.join(self.mask_dir, self.images[index]).replace("\\","/")
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path)#.convert('L')
        x, y = self.augment(image, mask)
        #y = y.squeeze()
        return x, y
class unetConv2(nn.Module):
    """A stack of ``n`` (Conv2d [+ BatchNorm2d] + ReLU) blocks applied in order.

    The blocks are registered as attributes ``conv1`` .. ``convN`` so that
    parameters are tracked by ``nn.Module``.
    """

    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding
        channels = in_size
        for idx in range(1, n + 1):
            if is_batchnorm:
                block = nn.Sequential(
                    nn.Conv2d(channels, out_size, ks, stride, padding),
                    nn.BatchNorm2d(out_size),
                    nn.ReLU(inplace=True),
                )
            else:
                block = nn.Sequential(
                    nn.Conv2d(channels, out_size, ks, stride, padding),
                    nn.ReLU(inplace=True),
                )
            setattr(self, 'conv%d' % idx, block)
            # All blocks after the first map out_size -> out_size.
            channels = out_size
        # Kaiming-initialise every child block (project helper from weights.py).
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        out = inputs
        for idx in range(1, self.n + 1):
            out = getattr(self, 'conv%d' % idx)(out)
        return out
class unetUp(nn.Module):
    """U-Net decoder stage: upsample, concatenate skip tensors, double conv.

    Fix for the reported crash: when the input H/W is not divisible by 16
    (e.g. 2056 -> 2056/8 = 257), MaxPool2d floors 257 -> 128 and the upsample
    doubles it back to only 256, one pixel smaller than the 257-wide skip
    tensor, so ``torch.cat`` raised
    "Expected size 256 but got size 257". We now pad the upsampled tensor to
    the skip tensor's spatial size before concatenating -- the standard U-Net
    treatment for odd encoder sizes.
    """

    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp, self).__init__()
        self.conv = unetConv2(out_size * 2, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        # initialise the blocks (skip the nested unetConv2 -- it initialises itself)
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs0, *input):
        outputs0 = self.up(inputs0)
        for skip in input:
            # Align spatial dims with the skip tensor: integer division in the
            # pooling path can lose a pixel (257 -> 128 -> 256 vs 257).
            dh = skip.size(2) - outputs0.size(2)
            dw = skip.size(3) - outputs0.size(3)
            if dh != 0 or dw != 0:
                outputs0 = F.pad(outputs0,
                                 [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
            outputs0 = torch.cat([outputs0, skip], 1)
        return self.conv(outputs0)
class unetUp_origin(nn.Module):
    """Original-style U-Net decoder stage supporting ``n_concat`` skip inputs.

    Same latent defect as ``unetUp``: if the encoder produced odd spatial
    sizes, the upsampled tensor is 1 px smaller than its skip tensor and
    ``torch.cat`` fails. Fixed the same way -- pad the upsampled tensor to
    each skip tensor's spatial size before concatenation.
    """

    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp_origin, self).__init__()
        if is_deconv:
            self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        # initialise the blocks (skip the nested unetConv2 -- it initialises itself)
        for m in self.children():
            if m.__class__.__name__.find('unetConv2') != -1: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs0, *input):
        outputs0 = self.up(inputs0)
        for skip in input:
            # Pad to match the skip tensor when pooling lost a pixel.
            dh = skip.size(2) - outputs0.size(2)
            dw = skip.size(3) - outputs0.size(3)
            if dh != 0 or dw != 0:
                outputs0 = F.pad(outputs0,
                                 [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
            outputs0 = torch.cat([outputs0, skip], 1)
        return self.conv(outputs0)
class ModelParallelUnet(nn.Module):
    """U-Net split across two GPUs (simple model parallelism).

    Encoder stages 1-3 and the final decoder stage live on DEVICE_0; encoder
    stage 4, the bottleneck and decoder stages 4-2 live on DEVICE_1. Activations
    are moved between devices with per-call ``.to(...)`` in ``forward``.

    NOTE(review): the pooling path requires input H/W divisible by 16;
    otherwise the up_concat skip concatenations fail with a 1-px size mismatch
    (e.g. 2056 -> 257 -> 128, upsampled back to 256 vs skip 257).
    """

    def __init__(self, in_channels=3, in_classes=3, bilinear=True, feature_scale=4,
                 is_deconv=True, is_batchnorm=True):
        super(ModelParallelUnet, self).__init__()
        self.n_channels = in_channels
        self.n_classes = in_classes
        self.bilinear = bilinear
        self.feature_scale = feature_scale
        self.is_deconv = is_deconv
        self.is_batchnorm = is_batchnorm
        # Channel widths per depth level.
        filters = [64, 128, 256, 512, 1024]
        # downsampling -- first three stages on DEVICE_0
        self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm).to(DEVICE_0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm).to(DEVICE_0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm).to(DEVICE_0)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2).to(DEVICE_0)
        # deepest encoder stage + bottleneck on DEVICE_1
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm).to(DEVICE_1)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2).to(DEVICE_1)
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm).to(DEVICE_1)
        # upsampling -- last stage returns to DEVICE_0 for the output conv
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv).to(DEVICE_1)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv).to(DEVICE_1)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv).to(DEVICE_1)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv).to(DEVICE_0)
        self.outconv1 = nn.Conv2d(filters[0], in_classes, 3, padding=1).to(DEVICE_0)
        # initialise weights (project helper from weights.py)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d):
                init_weights(m, init_type='kaiming')

    def dotProduct(self,seg,cls):
        # Scale each of the N channel maps of `seg` by the matching per-channel
        # weight in `cls` (shape B x N), preserving the B x N x H x W layout.
        B, N, H, W = seg.size()
        seg = seg.view(B, N, H * W)
        final = torch.einsum("ijk,ij->ijk", [seg, cls])
        final = final.view(B, N, H, W)
        return final

    def forward(self, inputs):
        # NOTE(review): the trailing "# 16*512*1024"-style size comments below
        # are from an earlier run at a different resolution -- treat as stale.
        conv1 = self.conv1(inputs.to(DEVICE_0)) # 16*512*1024
        maxpool1 = self.maxpool1(conv1.to(DEVICE_0)).to(DEVICE_0) # 16*256*512
        conv2 = self.conv2(maxpool1.to(DEVICE_0)).to(DEVICE_0) # 32*256*512
        maxpool2 = self.maxpool2(conv2.to(DEVICE_0)).to(DEVICE_0) # 32*128*256
        conv3 = self.conv3(maxpool2.to(DEVICE_0)).to(DEVICE_0) # 64*128*256
        maxpool3 = self.maxpool3(conv3.to(DEVICE_0)).to(DEVICE_0) # 64*64*128
        conv4 = self.conv4(maxpool3.to(DEVICE_1)).to(DEVICE_1) # 128*64*128
        maxpool4 = self.maxpool4(conv4.to(DEVICE_1)).to(DEVICE_1) # 128*32*64
        center = self.center(maxpool4.to(DEVICE_1)).to(DEVICE_1) # 256*32*64
        # Decoder: each stage concatenates the matching encoder skip tensor.
        up4 = self.up_concat4(center.to(DEVICE_1), conv4.to(DEVICE_1)).to(DEVICE_1) # 128*64*128
        up3 = self.up_concat3(up4.to(DEVICE_1), conv3.to(DEVICE_1)).to(DEVICE_1) # 64*128*256
        up2 = self.up_concat2(up3.to(DEVICE_1), conv2.to(DEVICE_1)).to(DEVICE_1) # 32*256*512
        up1 = self.up_concat1(up2.to(DEVICE_0), conv1.to(DEVICE_0)).to(DEVICE_0) # 16*512*1024
        d1 = self.outconv1(up1.to(DEVICE_0)) # 256
        # NOTE(review): sigmoid over class logits feeds DiceLoss(mode='multiclass'),
        # which applies its own softmax -- confirm the sigmoid here is intended.
        return torch.sigmoid(d1.to(DEVICE_0)).to(DEVICE_0)
def get_loader(train_dir,train_maskdir,val_dir,val_maskdir,batch_size,train_transform_image,
               train_transform_mask,val_transform,num_workers=1,pin_memory=True,):
    """Build the training and testing DataLoaders.

    Fix: ``val_transform`` was accepted but silently ignored -- the test set
    reused the training image transform. It is now applied to the test images.
    (The caller in this file passes the same transform for both, so behaviour
    there is unchanged.)
    """
    train_ds = TrainingDataset(image_dir=train_dir, mask_dir=train_maskdir,
                               image_transform=train_transform_image,
                               mask_transform=train_transform_mask)
    train_loader = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers,
                              pin_memory=pin_memory, shuffle=True,)
    test_ds = TrainingDataset(image_dir=val_dir, mask_dir=val_maskdir,
                              image_transform=val_transform,
                              mask_transform=train_transform_mask)
    # NOTE(review): shuffle=True here means test() (which stops after ~one
    # batch) scores a different random sample each epoch; TrainingDataset also
    # applies random flips/affine to test data -- confirm both are intended.
    test_loader = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers,
                             pin_memory=pin_memory, shuffle=True,)
    return train_loader, test_loader
def tensor_to_numpy_preserve_scale(x):
    """Convert a PIL mask to a tensor without rescaling its integer labels."""
    # ToTensor() only divides by 255 for uint8 input, so routing the mask
    # through an int64 ndarray keeps the class ids (0, 1, 2, ...) intact.
    label_array = np.array(x, dtype='int64')
    return transforms.ToTensor()(label_array)
def pad_image_with_aspect(x):
    """Resize a PIL image to IMG_SIZE x IMG_SIZE, preserving aspect ratio by
    letterboxing (zero-padding) top and bottom.

    NOTE(review): assumes landscape-or-square input (w >= h). For a portrait
    image, h_new > IMG_SIZE makes pad_amount negative, which transforms.Pad
    does not support -- confirm input orientation or clamp.
    """
    w,h = x.size
    h_new = int(IMG_SIZE*h/w)
    pad_amount = int((IMG_SIZE-h_new)//2)
    square_image = transforms.Compose([
        transforms.Resize(size=(h_new,IMG_SIZE)),
        transforms.Pad((0,pad_amount),fill=0, padding_mode='constant'),
        # Final resize guarantees an exact IMG_SIZE square even when
        # IMG_SIZE - h_new is odd (padding alone would be one pixel short).
        transforms.Resize(size=(IMG_SIZE,IMG_SIZE))
    ])
    return square_image(x)
def pad_mask_with_aspect(x):
    """Resize a label-mask tensor (C, H, W) to IMG_SIZE x IMG_SIZE, preserving
    aspect ratio by zero-padding top and bottom.

    Fix: the resize now uses NEAREST interpolation. The default bilinear
    interpolation blends neighbouring class ids (this file uses integer
    classes 0/1/2 with DiceLoss(mode='multiclass')), producing invalid
    in-between labels at region boundaries.
    """
    c, h, w = x.shape
    h_new = int(IMG_SIZE * h / w)
    pad_top = abs(IMG_SIZE - h_new) // 2
    # Any odd leftover row goes to the bottom so the result is exactly
    # IMG_SIZE tall (this unifies the previous even/odd h_new branches).
    pad_bottom = pad_top + (abs(IMG_SIZE - h_new) % 2)
    square_mask = transforms.Compose([
        transforms.Resize(size=(h_new, IMG_SIZE),
                          interpolation=transforms.InterpolationMode.NEAREST),
        # Pad tuple is (left, top, right, bottom).
        transforms.Pad((0, pad_top, 0, pad_bottom), fill=0, padding_mode='constant'),
    ])
    temp = torch.squeeze(square_mask(x))
    if BATCH_SIZE == 1:
        # Keep a leading channel dim when squeeze() removed it.
        temp = temp[None, :, :]
    return temp
def train_function(data, model, optimizer, loss_fn, device):
    """Run one training epoch and return the mean loss over all batches.

    Fixes: ``loss_values`` was declared but never used and the function
    returned only the *last* batch's loss -- a noisy per-epoch metric; we now
    average over the epoch. Also puts the model back in train mode, since
    test() switches it to eval between epochs.
    """
    model.train()
    loss_values = []
    for X, y in tqdm(data):
        # DiceLoss(mode='multiclass') expects integer class-index targets.
        y = y.squeeze(dim=0).to(dtype=torch.long)
        X, y = X.to(device), y.to(device)
        preds = model(X)
        loss = loss_fn(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_values.append(loss.item())
    # Guard against an empty loader.
    return sum(loss_values) / len(loss_values) if loss_values else 0.0
def test(model, data_loader):
    """Estimate accuracy on a small sample of the loader.

    Fix: the model was left in eval() mode permanently, so BatchNorm ran with
    frozen running statistics for every training epoch after the first. Train
    mode is now restored in a ``finally`` block.

    NOTE(review): ``total`` counts batches while ``correct`` sums per-pixel
    matches, so the returned ratio is not a true pixel accuracy -- confirm the
    intended metric.
    """
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(data_loader):
                # Deliberate early exit: score only ~one batch for speed.
                if batch_idx * len(data) > BATCH_SIZE:
                    break
                data, target = data.to(DEVICE_0), target.to(DEVICE_0)
                outputs = model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
    finally:
        # Restore training mode (BatchNorm/Dropout) for subsequent epochs.
        model.train()
    # Avoid ZeroDivisionError on an empty loader.
    return correct / total if total else 0.0
def main():
    """Train the model-parallel U-Net: build loaders, loop over epochs,
    evaluate, and checkpoint after every epoch."""
    MODEL_PATH = '/mnt/models/'
    TRAIN_IMG_DIR = "/mnt/training_images/images/training"
    TRAIN_MASK_DIR = "/mnt/training_images/masks/training"
    VAL_IMG_DIR = "/mnt/training_images/images/testing"
    VAL_MASK_DIR = "/mnt/training_images/masks/testing"
    # NOTE(review): duplicates the module-level cudnn settings; disabling
    # cuDNN makes the benchmark flag a no-op -- confirm intent.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = False
    #Check if GPU is available ===================================
    avail = torch.cuda.is_available()
    devCnt = torch.cuda.device_count()
    devName = torch.cuda.get_device_name(0)
    print("Available: " + str(avail) + ", Count: " + str(devCnt) + ", Name: " + str(devName))
    epoch = 0 # epoch is initially assigned to 0. If LOAD_MODEL is true then
    # Images: letterbox to IMG_SIZE square, then convert to float tensor.
    image_transform = transforms.Compose([
        transforms.Lambda(lambda x : pad_image_with_aspect(x)),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ])
    # Masks: keep integer class ids, then letterbox to IMG_SIZE square.
    mask_transform = transforms.Compose([
        transforms.Lambda(lambda x : tensor_to_numpy_preserve_scale(x)),
        transforms.Lambda(lambda x : pad_mask_with_aspect(x)),
    ])
    # NOTE(review): image_transform is passed again as val_transform.
    train_loader, test_loader = get_loader(TRAIN_IMG_DIR,TRAIN_MASK_DIR,VAL_IMG_DIR,
        VAL_MASK_DIR,BATCH_SIZE,image_transform,mask_transform,image_transform,
        NUM_WORKERS,PIN_MEMORY,)
    # Check Tensor shapes ======================================================
    #batch = next(iter(train_loader))
    #images, labels = batch
    print('Data Loaded Successfully!')
    # NOTE(review): len(loader) is the number of *batches*; with BATCH_SIZE=1
    # it coincides with the sample count.
    print(f'Number of Training Samples: {len(train_loader)}')
    print(f'Number of Testing Samples: {len(test_loader)}')
    model = ModelParallelUnet(in_channels=3, in_classes=3)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    #0 Backgroun 1 Sample 2 Top Platen
    loss_function = smp.losses.DiceLoss(mode='multiclass')
    LOSS_VALS = [] # Defining a list to store loss values after every epoch
    EPOCH_LIST = []
    EPOCH_RUNTIME = []
    TEST_ACCURACY = []
    #Training the model for every epoch.
    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        epoch_start = time.time()
        loss_val = train_function(train_loader, model, optimizer, loss_function, DEVICE)
        acc = test(model, test_loader)
        print(f'loss value = {loss_val}')
        LOSS_VALS.append(loss_val)
        EPOCH_LIST.append(e)
        EPOCH_RUNTIME.append(time.time() - epoch_start)
        TEST_ACCURACY.append(acc)
        # Checkpoint every epoch: weights, optimizer state, and metric history.
        torch.save({
            'model_state_dict': model.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': LOSS_VALS,
            'accuracy': TEST_ACCURACY,
            'epochs_run': EPOCH_LIST,
            'epoch_time': EPOCH_RUNTIME
        }, f'{MODEL_PATH}/MP_unet_NONE_backbone_{IMG_SIZE}x{BATCH_SIZE}_epoch_{e}.pth')
        print("Epoch completed and model successfully saved!")
if __name__ == '__main__':
    main()