I have a deep neural network that consists of a U-Net followed by two CNN layers. The U-Net is only there to learn point-wise features for a graph, and those learned features are then fed into the new CNN layers. However, this does not seem to work, and there appears to be a huge bottleneck. How can I train the U-Net first and then, once it has converged, train the second part for fine-tuning? I believe this is called stage-wise learning, but I have not found a good resource for PyTorch.
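A minimal sketch of stage-wise training in PyTorch follows. It assumes the U-Net maps input of shape (batch, in_channels, num_points) to point-wise features of shape (batch, feat_channels, num_points), and that a per-point classification target exists. Everything here is hypothetical: `unet` is a two-layer `Conv1d` stand-in rather than a real U-Net, `head` plays the role of the two CNN layers, and `aux_head` is an assumed temporary 1x1 layer that only exists to give the U-Net its own loss in stage 1. The core of the technique is stage 2: freeze the U-Net with `requires_grad_(False)` and pass only the head's parameters to the optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the U-Net: (batch, 3, n_points) -> (batch, 16, n_points).
unet = nn.Sequential(
    nn.Conv1d(3, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv1d(16, 16, kernel_size=3, padding=1),
)
# The two CNN layers that consume the learned point-wise features (sizes assumed).
head = nn.Sequential(
    nn.Conv1d(16, 16, kernel_size=1), nn.ReLU(),
    nn.Conv1d(16, 2, kernel_size=1),
)
# Assumed temporary 1x1 layer so the U-Net has a loss of its own in stage 1.
aux_head = nn.Conv1d(16, 2, kernel_size=1)

# Toy data: 8 graphs, 3 input channels, 50 points, 2 classes per point.
x = torch.randn(8, 3, 50)
y = torch.randint(0, 2, (8, 50))
loss_fn = nn.CrossEntropyLoss()  # accepts (N, C, L) logits with (N, L) targets


def train(params, forward, steps, lr):
    """Run a few optimizer steps on `params` only; return the last loss."""
    opt = torch.optim.Adam(params, lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(forward(x), y)
        loss.backward()
        opt.step()
    return loss.item()


# Stage 1: train the U-Net (with the auxiliary layer) on its own until it converges.
train(list(unet.parameters()) + list(aux_head.parameters()),
      lambda inp: aux_head(unet(inp)), steps=50, lr=1e-2)

# Stage 2: freeze the U-Net and train only the two CNN layers on its features.
unet.requires_grad_(False)
unet.eval()  # keeps BatchNorm/Dropout fixed if the real U-Net has them
unet_before = [p.detach().clone() for p in unet.parameters()]
head_before = [p.detach().clone() for p in head.parameters()]
stage2_loss = train(head.parameters(), lambda inp: head(unet(inp)), steps=50, lr=1e-2)
unet_after_stage2 = [p.detach().clone() for p in unet.parameters()]
head_after_stage2 = [p.detach().clone() for p in head.parameters()]

# Stage 3 (optional): unfreeze everything and fine-tune jointly with a small learning rate.
unet.requires_grad_(True)
unet.train()
train(list(unet.parameters()) + list(head.parameters()),
      lambda inp: head(unet(inp)), steps=20, lr=1e-4)
```

Freezing with `requires_grad_(False)` means no gradients are computed for the U-Net in stage 2, and leaving its parameters out of the optimizer guarantees they are not updated. The optional stage 3 uses a much smaller learning rate so that joint fine-tuning does not undo what the U-Net learned in stage 1; whether it helps for this particular graph task is something to check empirically.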