Grayscale images in U-Net


I was looking at the following U-Net example for multiclass segmentation: Unet.
Maybe this is a silly question, but is it possible to use U-Net with grayscale images?
Ideally I would like to reuse the example in the link, if possible. Otherwise, if somebody could point me to some examples, that would also be greatly appreciated!
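One way to reuse an RGB example without touching the model is to copy the single gray channel three times so the batch has the shape a 3-channel network expects. This is an alternative I'm suggesting, not something the linked example does, and it assumes images arrive as (N, 1, H, W) float tensors:

```python
import torch

def gray_to_rgb(x: torch.Tensor) -> torch.Tensor:
    """Repeat a (N, 1, H, W) grayscale batch into (N, 3, H, W).

    The three output channels are identical copies, so no information is
    added; this only matches the input shape a 3-channel model expects.
    """
    if x.dim() != 4 or x.size(1) != 1:
        raise ValueError(f"expected (N, 1, H, W), got {tuple(x.shape)}")
    return x.repeat(1, 3, 1, 1)

batch = torch.rand(2, 1, 64, 64)   # two fake 64x64 grayscale images
rgb = gray_to_rgb(batch)
print(rgb.shape)                   # torch.Size([2, 3, 64, 64])
```

torchvision's `transforms.Grayscale(num_output_channels=3)` does the same replication at the PIL level. Changing the first convolution to accept 1 channel instead avoids the redundant copies.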


I am far from an expert, but I am using UNetMini from this example. Here is a post with some of my code where I got some help not too long ago. You might find it somewhat useful.

Hi Ryan,

Thanks for your reply! I've been looking at your code, the example you linked, and your post here (I quote the original).

So I also tried to modify the first layer of the U-Net I'm using by writing:
self.dconv_down1 = double_conv(1, 64)

Previously this was 3, for the three RGB channels, I guess.
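For reference, a sketch of what that change amounts to. The `double_conv` below is written to mirror the helper in the pytorch-unet example as I understand it (two 3x3 convolutions with padding=1, each followed by ReLU); treat that as an assumption and compare it with the actual file. Only the first argument of the first block changes from 3 to 1:

```python
import torch
import torch.nn as nn

def double_conv(in_channels: int, out_channels: int) -> nn.Sequential:
    # Two 3x3 convs; padding=1 keeps height and width unchanged.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )

# First encoder block for grayscale input: 1 input channel instead of 3.
dconv_down1 = double_conv(1, 64)

x = torch.rand(4, 1, 64, 64)   # batch of 4 single-channel 64x64 images
y = dconv_down1(x)
print(y.shape)                 # torch.Size([4, 64, 64, 64])
```

Since padding=1 preserves the spatial size, this change on its own should only affect the channel dimension (dimension 1), not any height or width further down the network.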

I ran the training and now I get this error:

invalid argument 0: Sizes of tensors must match except in dimension 1. Got 263 and 272 in dimension 2 at C:/w/1/s/tmp_conda_3.7_055457/conda/conda-bld/pytorch_1565416617654/work/aten/src\THC/generic/

I've checked all the images: they are 1024x1024 grayscale PNGs, but I'm not sure how to solve this issue.
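A quick way to reason about this kind of error is to trace one spatial dimension through the network by hand. The sketch below assumes a U-Net built like the example: 3x3 convs that preserve size (padding=1), `MaxPool2d(2)` on the way down (which floors odd sizes) and a x2 upsample on the way up. The number of pooling steps is a parameter because I haven't verified the exact depth of your model. Each skip connection can only be concatenated if the upsampled size equals the saved encoder size:

```python
def skip_sizes(size: int, n_pools: int = 3):
    """Trace one spatial dimension through a padded U-Net.

    Returns (encoder_size, upsampled_size) for each skip connection,
    deepest first. Assumes size-preserving convs (3x3, padding=1),
    MaxPool2d(2) going down and a x2 upsample going up.
    """
    encoder = []
    s = size
    for _ in range(n_pools):
        encoder.append(s)   # feature map kept for the skip connection
        s //= 2             # MaxPool2d(2) floors odd sizes
    pairs = []
    for skip in reversed(encoder):
        pairs.append((skip, s * 2))  # torch.cat needs these two to be equal
        s = skip                     # decoder continues at the skip size (only meaningful if they matched)
    return pairs

for size in (1024, 64, 100):
    bad = [p for p in skip_sizes(size) if p[0] != p[1]]
    print(size, "ok" if not bad else f"mismatch {bad}")
# 1024 ok
# 64 ok
# 100 mismatch [(25, 24)]
```

Under these assumptions the input height and width need to be divisible by 2**n_pools (8 for three pooling steps).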


Judging by the size 263 in the error, my guess would be to check the padding values you have set.

Also, could you share more of the error log?
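The padding hint follows from the standard output-size formula for a convolution with dilation 1 (as given in the PyTorch `Conv2d` docs): out = floor((in + 2*padding - kernel_size) / stride) + 1. A small sketch:

```python
def conv_out(size: int, kernel_size: int = 3, padding: int = 0, stride: int = 1) -> int:
    """Output height/width of a Conv2d with dilation=1."""
    return (size + 2 * padding - kernel_size) // stride + 1

print(conv_out(1024, padding=1))   # 1024: a 3x3 conv with padding=1 keeps the size
print(conv_out(1024, padding=0))   # 1022: without padding each 3x3 conv loses 2 pixels
print(conv_out(conv_out(64)))      # 60: two unpadded 3x3 convs lose 4 pixels
```

A single 3x3 convolution with padding=0 is enough to make encoder and decoder sizes disagree at the next concatenation.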

Hi! Thanks for your reply!

Currently I'm getting this:

RuntimeError Traceback (most recent call last)
17 exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=75, gamma=0.1)
---> 19 model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=100)

in train_model(model, optimizer, scheduler, num_epochs)
62 # track history if only in train
63 with torch.set_grad_enabled(phase == 'train'):
---> 64 outputs = model(inputs)
65 loss = calc_loss(outputs, labels, metrics)

F:\aaa\Anaconda3\lib\site-packages\torch\nn\modules\ in __call__(self, *input, **kwargs)
545 result = self._slow_forward(*input, **kwargs)
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
549 hook_result = hook(self, input, result)

~\Desktop\graph_rec_png\pytorch-unet\ in forward(self, x)
45 x = self.upsample(x)
---> 46 x =[x, conv3], dim=1)
48 x = self.dconv_up3(x)

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 23 and 32 in dimension 2 at C:/w/1/s/tmp_conda_3.7_055457/conda/conda-bld/pytorch_1565416617654/work/aten/src\THC/generic/

Following the example, I'm loading my 1024x1024 grayscale PNG images with the code below. The idea is to feed them into the U-Net from the example (Unet).

from PIL import Image
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform  # stored, but resizing is done in self.transforms below

    def transforms(self, image, mask):
        #img = img.resize((wsize, baseheight), PIL.Image.ANTIALIAS)

        # resize both to 64x64; NEAREST avoids blending label values in the mask
        image = image.resize((64, 64), Image.NEAREST)
        mask = mask.resize((64, 64), Image.NEAREST)
        # PIL image -> (C, H, W) float tensor; 8-bit images are scaled to [0, 1]
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        #mask = np.array(mask)
        #mask = np.moveaxis(mask,2,0)
        #mask = torch.from_numpy(np.array(mask))
        #mask = mask.unsqueeze(0)
        #mask = mask.unsqueeze(0)
        return [image, mask]

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.mask_paths[index])
        x, y = self.transforms(image, mask)
        return x, y

    def __len__(self):
        return len(self.image_paths)

The training loop is the same as in that example (I previously used this model to successfully train on a multi-class segmentation problem).

Thanks for any suggestions!

Thanks for the error logs.

Looking at lines 45-48 of the traceback, I believe an invalid padding setting is consuming a pixel at each layer.
Could you double-check the padding values? I suspect your issue is similar to the one addressed in this answer.
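If the sizes still disagree by a pixel or so, a common workaround (used in some other U-Net implementations; it is not part of the example linked above) is to pad the upsampled tensor to the skip connection's size right before concatenating. A sketch, assuming (N, C, H, W) tensors where the upsampled map is not larger than the skip:

```python
import torch
import torch.nn.functional as F

def pad_to_match(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Zero-pad x so its H and W equal skip's (extra pixel goes right/bottom)."""
    diff_h = skip.size(2) - x.size(2)
    diff_w = skip.size(3) - x.size(3)
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                     diff_h // 2, diff_h - diff_h // 2])

up = torch.rand(1, 256, 24, 24)     # upsampled decoder features
conv3 = torch.rand(1, 256, 25, 25)  # encoder features saved for the skip
x = torch.cat([pad_to_match(up, conv3), conv3], dim=1)
print(x.shape)                      # torch.Size([1, 512, 25, 25])
```

This masks the mismatch rather than explaining it, so it is probably still worth finding out where the sizes diverge.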

I'm currently using padding = 1, and I had no issues before. I'm not sure why it's giving me this problem now.

I've checked the link, but I couldn't solve the issue. Do you have any ideas on how to solve this, or things I could look at?
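One generic way to see where the sizes diverge is to record the output shape of every leaf module while a single batch runs through the model. A sketch using forward hooks; `TinyNet` is a hypothetical stand-in for the real U-Net, and in practice you would pass your actual model and one batch from the DataLoader:

```python
import torch
import torch.nn as nn

def trace_shapes(model: nn.Module, x: torch.Tensor):
    """Run x through model and record (module name, output shape) for every leaf module."""
    records, handles = [], []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:   # leaf modules only
            handles.append(module.register_forward_hook(
                lambda m, inp, out, name=name: records.append((name, tuple(out.shape)))))
    try:
        with torch.no_grad():
            model(x)
    except RuntimeError as err:
        # keep the shapes recorded before the failing layer or torch.cat
        records.append(("<error>", str(err)))
    finally:
        for h in handles:
            h.remove()
    return records

# Hypothetical stand-in model; replace with the real U-Net.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=1)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(self.conv(x))

for name, shape in trace_shapes(TinyNet(), torch.rand(1, 1, 64, 64)):
    print(name, shape)
# conv (1, 8, 64, 64)
# pool (1, 8, 32, 32)
```

Since `torch.cat` is not a module, it won't appear in the list, but the last shapes recorded before the error show which sizes reached it.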