UNet implementation

No, I’ve commented it out in the unet.py file, but it is still applied in main.py:

outputs = F.sigmoid(model(inputs))

The problem is that the network starts to converge and the loss goes from ~0.7 down to ~0.2 very naturally! So we have convergence, right? However, when I evaluate the learned model, even on the training images, the output is no better than a blank image!
I was thinking that there might be some problem with loading the learned weights. So, I incorporated an evaluation phase in each batch update. It is evident that the prediction collapses to a blank map early in the first epoch!

I tested many things! I also replaced my own architecture with UNet implementations by others, and got the same problem! I changed the input range to [0,255] instead of [0,1]: the same problem! I changed the loss function from BCELoss to MSELoss and to CrossEntropyLoss2d: again the same problem! With and without batch normalization! Different gradient descent algorithms!

That’s so weird! This is why I suspect it might be a bug in PyTorch!

Are there any updates on this?
@fmassa

I finally found the problem!!
For the last set of convolutions, that is 128 -> 64 -> 64 -> 1, the activation function should not be used!
The activation function causes the values to vanish!

I just removed the nn.ReLU() modules on top of these convolution layers and now everything works like a charm!
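A minimal sketch of the fix described above, not the original unet.py. The helper name make_head and the kernel sizes (3x3, 3x3, then 1x1) are assumptions for illustration. It also shows one concrete way a ReLU after the last convolution hurts when sigmoid is applied afterwards in main.py: the sigmoid of a non-negative value can never fall below 0.5.

```python
import torch
import torch.nn as nn


def make_head(use_relu: bool) -> nn.Sequential:
    """Hypothetical 128 -> 64 -> 64 -> 1 output block (kernel sizes are assumed)."""
    channels = [128, 64, 64, 1]
    kernels = [3, 3, 1]
    layers = []
    for c_in, c_out, k in zip(channels[:-1], channels[1:], kernels):
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2))
        if use_relu:
            # The variant that caused trouble: ReLU on top of each convolution.
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


torch.manual_seed(0)
x = torch.randn(2, 128, 16, 16)  # stand-in for the last decoder feature map

with torch.no_grad():
    probs_relu = torch.sigmoid(make_head(use_relu=True)(x))
    probs_plain = torch.sigmoid(make_head(use_relu=False)(x))

# With the trailing ReLU the logits are >= 0, so every probability is >= 0.5.
print("with ReLU, min probability:", probs_relu.min().item())
# Without it the logits are unconstrained and can take either sign.
print("without ReLU, range:", probs_plain.min().item(), probs_plain.max().item())
```

With the ReLUs removed, the block outputs raw logits and the sigmoid in main.py maps them to the full (0, 1) range.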

Saeed

Dear all, I went through almost all of your UNet implementations. I cannot find a weight initialization function or statement in them. I am wondering whether I need to initialize the weights myself or whether something is missing from the code.

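This is only the model definition code. I do weight initialization as follows:

    # Inside the model class, e.g. at the end of __init__ (needs `import math`)
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # std = sqrt(2 / n), with n = kernel height * kernel width * output channels
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))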
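The same scheme can be sketched as a standalone function, in case that is easier to reuse. This is only an illustration: the function name init_conv_weights and the small test network are made up, and zeroing the biases is an extra assumption that the snippet above does not make. If no explicit initialization is done, PyTorch layers simply keep their default initialization.

```python
import math

import torch
import torch.nn as nn


def init_conv_weights(model: nn.Module) -> None:
    """Draw every Conv2d weight from N(0, sqrt(2 / n)), n = kH * kW * out_channels."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
            if m.bias is not None:
                # Assumption: the snippet above leaves the biases untouched.
                nn.init.zeros_(m.bias)


torch.manual_seed(0)
# Hypothetical tiny network, only used to exercise the function.
net = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 1, kernel_size=1),
)
init_conv_weights(net)

expected_std = math.sqrt(2.0 / (3 * 3 * 64))
print("expected std:", expected_std, "measured std:", net[0].weight.std().item())
```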

Sorry for reviving this topic after one year :confused: I am also trying to implement UNet, but I cannot understand how the center_crop function works.
Another question: what is imsize in the UNet class? Any replies would be appreciated.
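The center_crop from the implementation being asked about is not reproduced in this thread, so the following is only a sketch of what such a helper usually does in a UNet, with an assumed signature. When the convolutions are unpadded, the encoder feature map is larger than the upsampled decoder map, so its central window is cut out before the two are concatenated along the channel dimension.

```python
import torch


def center_crop(features: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
    """Return the central target_h x target_w window of an (N, C, H, W) tensor."""
    _, _, h, w = features.shape
    top = (h - target_h) // 2
    left = (w - target_w) // 2
    return features[:, :, top:top + target_h, left:left + target_w]


# Example sizes are made up: a 64x64 encoder map and a 56x56 decoder map.
encoder_map = torch.arange(64 * 64, dtype=torch.float32).reshape(1, 1, 64, 64)
decoder_map = torch.zeros(1, 1, 56, 56)

cropped = center_crop(encoder_map, decoder_map.shape[2], decoder_map.shape[3])
merged = torch.cat([cropped, decoder_map], dim=1)  # skip connection
print(merged.shape)  # torch.Size([1, 2, 56, 56])
```

Here 4 rows and 4 columns are trimmed from each side, so the spatial sizes match and the concatenation works.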
