UNet implementation

I’ve been trying to implement the network described in U-Net: Convolutional Networks for Biomedical Image Segmentation using PyTorch.

I’m still in the process of learning, so I’m not sure my implementation is right. Right now the loss becomes NaN quickly, while the network output “pixels” become 0 or 1 seemingly at random. I’m not sure whether it’s because of my implementation or because of my lack of understanding of the loss (I pass the last layer through a LogSoftmax and then use NLLLoss2d).

Anyway, here is the tentative implementation, feel free to comment :wink:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))

        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size)
        self.activation = activation

    def center_crop(self, layer, target_size):
        # crop a centered target_size x target_size window from the skip connection
        # (assumes square feature maps)
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_size) // 2
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.size()[2])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))

        return out


class UNet(nn.Module):
    def __init__(self, imsize):
        super(UNet, self).__init__()
        self.imsize = imsize

        self.activation = F.relu
        
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        self.conv_block1_64 = UNetConvBlock(1, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)

        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)

        self.last = nn.Conv2d(64, 2, 1)


    def forward(self, x):
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block512_1024(pool4)

        up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(up1, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        return F.log_softmax(self.last(up4))
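
For reference, a minimal sketch of how this could be trained with NLLLoss2d. The sizes are illustrative: with unpadded convolutions, as in the paper, a 572x572 input should give a 388x388 output, so the target label map has to be cropped to that size:

import torch.optim as optim

net = UNet(imsize=572)
criterion = nn.NLLLoss2d()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

input = Variable(torch.randn(1, 1, 572, 572))              # N x C x H x W
target = Variable(torch.LongTensor(1, 388, 388).zero_())   # one class index per output pixel

optimizer.zero_grad()
output = net(input)                  # 1 x 2 x 388 x 388 log-probabilities
loss = criterion(output, target)
loss.backward()
optimizer.step()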

I think you have a typo in the network definition (F.losoftmax). Apart from that it looks OK to me; I haven’t compared it to the original architecture in detail, but it looks good. Maybe your learning rate is too large? Did you investigate why the loss becomes NaN?

Indeed, there’s a typo.

A simpler fully convolutional network also eventually “converges” to NaN, albeit more slowly (after a few epochs).

I was using a similar setup with Theano with decent results, so I think my data is OK, but I’m not sure how to troubleshoot this.

I was using the Adam optimizer, but I’ll try SGD with a small learning rate and go from there if the results are promising.

That’s weird. Are you sure the networks are the same and that the data is preprocessed correctly? It’s important to find the source of the NaNs; that should not happen.

It appears there was a size mismatch I overlooked between the network output and the target. (8x2x68x68 vs 8x70x70)

I think there might be a bug there, as it gives an error when the tensors are on the CPU, but is silent and outputs something when they are on the GPU. Do you want me to open an issue?

That might be where the nans were coming from. I’ll let it train overnight and see.

To reproduce the error:

import torch
import torch.nn as nn
from torch.autograd import Variable

loss = nn.NLLLoss2d()

target = Variable(torch.Tensor(8, 70, 70).random_(0, 2)).long()  # random 0/1 labels
output = Variable(torch.randn(8, 2, 68, 68))

>>> loss(output, target)
RuntimeError: Assertion `input0 == target0 && input2 == target1 && input3 == target2' failed. size mismatch (got input: 8x2x68x68, target: 8x70x70)

>>> loss(output.cuda(), target.cuda())
Variable containing:
1.00000e-03 *
 -2.4779
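
For completeness, one way to make the shapes line up is to center-crop the target to the output size before computing the loss. A rough sketch (center_crop_target is just an illustrative helper, not code from the network above):

def center_crop_target(target, size):
    # target: N x H x W label map; keep a centered size x size window
    _, h, w = target.size()
    y1 = (h - size) // 2
    x1 = (w - size) // 2
    return target[:, y1:y1 + size, x1:x1 + size]

cropped = center_crop_target(target, output.size(2))  # 8 x 68 x 68
loss(output, cropped)                                  # shapes now match on CPU and GPU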

Yes please! It seems that this loss is indeed lacking some shape checks on the GPU.

It’s done. It seems I still have some problems though, and they might also be related to cuda.

I can’t reproduce the error with trivial tensors, but basically the log_softmax function produces NaNs and -inf.
I’m not even sure those are “reasonable” values to obtain for a last layer, but in any case the CUDA function doesn’t seem happy with them.

last_layer

Variable containing:
(0 ,0 ,.,.) = 
    3.5879    3.6678    3.8380  ...     3.1548    3.0576    2.9584
    3.4753    3.7363    3.8944  ...     2.9736    3.0051    2.9889
    3.3298    3.3160    3.5382  ...     2.8276    2.8111    2.8584
              ...                ⋱                ...             
    3.2416    3.1960    3.2502  ...    90.0304   98.9006   98.5473
    3.2719    3.2843    3.2724  ...    39.1980   67.7482   73.4172
    3.2535    3.3880    3.3061  ...    13.5371   25.7164   37.9838

(0 ,1 ,.,.) = 
   -3.7768   -3.8683   -4.0629  ...    -3.2815   -3.1703   -3.0568
   -3.6481   -3.9466   -4.1274  ...    -3.0742   -3.1102   -3.0917
   -3.4816   -3.4659   -3.7200  ...    -2.9072   -2.8883   -2.9424
              ...                ⋱                ...             
   -3.3808   -3.3285   -3.3906  ...  -122.8841 -137.5700 -122.2672
   -3.4154   -3.4296   -3.4160  ...   -48.6680  -82.5093  -91.6747
   -3.3943   -3.5482   -3.4545  ...   -17.6337  -33.1317  -46.8320
[torch.cuda.FloatTensor of size 1x2x68x68 (GPU 0)]

F.log_softmax(last_layer)

Variable containing:
(0 ,0 ,.,.) = 
   -0.0006   -0.0005   -0.0004  ...    -0.0016   -0.0020   -0.0024
   -0.0008   -0.0005   -0.0003  ...    -0.0024   -0.0022   -0.0023
   -0.0011   -0.0011   -0.0007  ...    -0.0032   -0.0033   -0.0030
              ...                ⋱                ...             
   -0.0013   -0.0015   -0.0013  ...        nan       nan       nan
   -0.0012   -0.0012   -0.0012  ...     0.0000    0.0000    0.0000
   -0.0013   -0.0010   -0.0012  ...     0.0000    0.0000    0.0000

(0 ,1 ,.,.) = 
   -7.3653   -7.5366   -7.9013  ...    -6.4379   -6.2299   -6.0176
   -7.1242   -7.6833   -8.0221  ...    -6.0501   -6.1175   -6.0829
   -6.8125   -6.7830   -7.2590  ...    -5.7381   -5.7027   -5.8038
              ...                ⋱                ...             
   -6.6237   -6.5260   -6.6420  ...       -inf      -inf      -inf
   -6.6886   -6.7152   -6.6896  ...   -87.8659      -inf      -inf
   -6.6491   -6.9371   -6.7618  ...   -31.1708  -58.8481  -84.8159
[torch.cuda.FloatTensor of size 1x2x68x68 (GPU 0)]

F.log_softmax(last_layer.cpu())

Variable containing:
(0 ,0 ,.,.) = 
   -0.0006   -0.0005   -0.0004  ...    -0.0016   -0.0020   -0.0024
   -0.0008   -0.0005   -0.0003  ...    -0.0024   -0.0022   -0.0023
   -0.0011   -0.0011   -0.0007  ...    -0.0032   -0.0033   -0.0030
              ...                ⋱                ...             
   -0.0013   -0.0015   -0.0013  ...     0.0000    0.0000    0.0000
   -0.0012   -0.0012   -0.0012  ...     0.0000    0.0000    0.0000
   -0.0013   -0.0010   -0.0012  ...    -0.0000    0.0000    0.0000

(0 ,1 ,.,.) = 
   -7.3653   -7.5366   -7.9013  ...    -6.4379   -6.2299   -6.0176
   -7.1242   -7.6833   -8.0221  ...    -6.0501   -6.1175   -6.0829
   -6.8125   -6.7830   -7.2590  ...    -5.7381   -5.7027   -5.8038
              ...                ⋱                ...             
   -6.6237   -6.5260   -6.6420  ...  -212.9145 -236.4706 -220.8145
   -6.6886   -6.7152   -6.6896  ...   -87.8659 -150.2575 -165.0919
   -6.6491   -6.9371   -6.7618  ...   -31.1708  -58.8481  -84.8159
[torch.FloatTensor of size 1x2x68x68]

It seems that it becomes numerically unstable when the differences between the values get too large. I’ve opened an issue.
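
For reference, the standard way to keep log-softmax stable is to subtract the per-pixel maximum over the class dimension before exponentiating. A manual sketch of the log-sum-exp trick (using the current tensor API with keepdim), just to illustrate what the built-in function is supposed to do internally:

def manual_log_softmax(x, dim=1):
    # log softmax(x) = x - m - log(sum(exp(x - m))), with m the max over dim,
    # so exp() never sees large positive values
    m = x.max(dim, keepdim=True)[0]
    return x - m - (x - m).exp().sum(dim, keepdim=True).log()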


I also implemented a UNet variant in PyTorch recently and managed to train it more or less successfully for a Kaggle competition, here it is: https://github.com/lopuhin/kaggle-dstl/blob/292840bf4faf49ecf7c74bed9b6d91982a139090/models.py#L211 - but in my case the classes were not mutually exclusive, so I used sigmoid activations.


@lopuhin I also tried to implement UNet, but unfortunately I don’t get any convergence (neither for image-to-image transforms nor for multi-class pixel-wise segmentation).

Do you have any idea what the problem could be? Do you think it may be because of missing batch norm layers?

What is your experience with the different UNet models you provide? How do e.g. SmallNet, OldNet, … compare to each other?

@bodokaiser the lack of any convergence even on the training set might be caused by some bug (maybe not even in the network but in how inputs/outputs are prepared) or by a bad learning rate, and it can very much depend on the dataset - sorry, I don’t have any more insights about this.

One thing that is different in my implementation is that I use upsampling instead of transposed convolutions, which worked significantly better in my case. Batch normalization speeds up convergence but is by no means essential; it worked fine without it too. Simpler models also gave okayish results, but UNet was consistently better: in this task the metric was intersection over union, and simple models gave results in the 0.2-0.3 range (averaged over 10 classes), while UNet gave 0.4+ without much tuning.
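
To illustrate, an up block based on upsampling looks roughly like this. This is only a sketch (UpsampleBlock and skip_size are my names here, not from the linked code), and it assumes padded 3x3 convolutions throughout so the skip connection already has the same spatial size as the upsampled features:

import torch
import torch.nn as nn
import torch.nn.functional as F


class UpsampleBlock(nn.Module):
    def __init__(self, in_size, skip_size, out_size):
        super(UpsampleBlock, self).__init__()
        # fixed nearest-neighbour upsampling instead of a learned ConvTranspose2d
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = nn.Conv2d(in_size + skip_size, out_size, 3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, 3, padding=1)

    def forward(self, x, bridge):
        # with padded convs the bridge already matches the upsampled size, so no cropping
        out = torch.cat([self.up(x), bridge], 1)
        out = F.relu(self.conv(out))
        return F.relu(self.conv2(out))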


@lopuhin after about ~200 iterations (batch size 1) my output images only have one color (i.e. every pixel is assigned the same segmentation label), and this does not change with different images. I tried CUDA and CPU mode (same problem). I also get this when using only one U-Net layer (but not with only one standard convolutional layer), so I’m not sure whether this can be a bug. I guess I need to try out your implementation to find bugs in my code. Big thanks for sharing!


I also spent a whole week on your code @bodokaiser and, as you mentioned, the blank output problem is really weird! I ran all kinds of checks to find the source of the bad behavior, but found nothing; everything works as it should.
I don’t know why the PyTorch developers don’t pay attention to this weird problem, which I believe is a clear bug in PyTorch!

Maybe I should mention @apaszke explicitly to grab his attention :slight_smile:

I don’t think it’s a bug in PyTorch, as I’ve commented in the other issue.
Using batch norm with such small batch sizes is probably not a good idea. If you need those layers (because of using a pre-trained network), I’d freeze the running mean/std and the parameters of the batch norm.
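
For example, freezing the batch norm layers could look roughly like this; a minimal sketch, assuming the usual nn.BatchNorm2d modules:

import torch.nn as nn


def freeze_batchnorm(net):
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                    # use the stored running mean/std, not batch statistics
            for p in m.parameters():    # stop the affine weight/bias from updating
                p.requires_grad = False

Note that a later net.train() call switches these layers back to training mode, so this would have to be re-applied (or handled explicitly in the training loop).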


The point is that the problem persists even after removing batch norm. Additionally, I’m using a batch size of 16 in my recent efforts. The same model works under the Theano/Lasagne implementation!

Hey,

I already commented on this issue on GitHub, but I got some new ideas today which might be worth checking out:

  1. Check if the loss is correct (correct sign +1, -1)
  2. Use Deconvolution instead of MaxPool

For 2. you could take a look at https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/utils.py#L94-L108

Good luck!

If you share the Lasagne and the PyTorch code, I can have a look.

Hi @bodokaiser,

For 1., I think it is not the problem, because I’m normalizing the ground truth to the range {0, 1} and the prediction after the sigmoid is also in the range [0, 1]. BCELoss is supposed to work with this range.

For 2., yeah, maybe there is a problem with ConvTranspose2d. I will replace it with upsampling and update you.

Here is the code for the Theano version, which works like a charm!

Here’s the PyTorch code, which is not working:
https://github.com/saeedizadi/UNET_POLYP

I had a quick look at your code, and one thing you should note is that you need the Sigmoid for BCELoss, and it seems that you commented it out?
Also, it would be great if you could explain what kind of problems you are having: is the network not converging? Are you getting worse results than with your Lasagne implementation? Does the network not learn at all?
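
For reference, the combination I mean looks roughly like this; a minimal sketch where net, input and target are placeholders for your own model, image batch and {0, 1} mask:

import torch.nn as nn
import torch.nn.functional as F

criterion = nn.BCELoss()

probs = F.sigmoid(net(input))     # squash the raw last-layer output into [0, 1]
loss = criterion(probs, target)   # target: float values in [0, 1], same shape as probs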