Is it possible to use `torch.cat` in the model definition so that the skip connections show up when I visualize the model in Netron?
If I understand correctly, by skip connections you mean something like U-Net or ResNet, right? If so, the answer is yes! I have not used Netron, but when I visualized a model (a U-Net with skip connections between the encoder and decoder layers) using TensorBoard or TensorboardX, it worked fine and I could see the connections.
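A minimal sketch of the TensorBoard route, assuming the `tensorboard` package is installed alongside PyTorch. `TinySkipNet`, its layer sizes, and the `runs/skip_demo` log directory are hypothetical, chosen only to give the graph one `torch.cat` skip connection to display.

```python
import os

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # requires the `tensorboard` package


class TinySkipNet(nn.Module):
    """Hypothetical model with one U-Net-style concatenation skip connection."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.mid = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.dec = nn.Conv2d(16, 3, kernel_size=3, padding=1)  # 8 + 8 channels after the cat

    def forward(self, x):
        e = torch.relu(self.enc(x))
        m = torch.relu(self.mid(e))
        # Skip connection: earlier features e are concatenated with later features m.
        return self.dec(torch.cat((e, m), dim=1))


log_dir = "runs/skip_demo"  # hypothetical log directory
writer = SummaryWriter(log_dir=log_dir)
# add_graph runs the example input through the model and records the resulting graph.
writer.add_graph(TinySkipNet(), torch.randn(1, 3, 32, 32))
writer.close()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(event_files)
```

Running `tensorboard --logdir runs` and opening the Graphs tab should then show the concatenation joining the two branches.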
Here is the link to my implementation on GitHub
The main part of the code that does the concatenation is here:
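```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# CE is a custom block from the linked repository (not shown here).


class Expand(nn.Module):
    def __init__(self, input_channel, output_channel, ks=4, s=2):
        """
        This path consists of an upsampling of the feature map followed by a 4x4 convolution
        ("up-convolution" or transposed convolution) that halves the number of feature channels,
        and a concatenation with the corresponding feature map from the Contract phase
        (zero-padded here so the spatial sizes match).

        :param input_channel: input channel size
        :param output_channel: output channel size
        """
        super(Expand, self).__init__()
        self.layers = CE(input_channel * 2, output_channel, ks, s)

    def forward(self, x1, x2):
        # Spatial size differences: dim 2 is height, dim 3 is width.
        delta_h = x1.size(2) - x2.size(2)
        delta_w = x1.size(3) - x2.size(3)
        # F.pad expects (left, right, top, bottom): the last dimension comes first.
        x2 = F.pad(x2, pad=(delta_w // 2, delta_w - delta_w // 2,
                            delta_h // 2, delta_h - delta_h // 2),
                   mode='constant', value=0)
        # The skip connection: concatenate along the channel dimension.
        x = torch.cat((x2, x1), dim=1)
        x = self.layers(x)
        return x
```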
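The pad-then-concatenate step can be checked in isolation, without the repository's `CE` block. This is a sketch: `pad_and_concat` is a hypothetical helper that mirrors the body of `forward`, and the tensor shapes are made up so that the height and width differences are odd and uneven.

```python
import torch
import torch.nn.functional as F


def pad_and_concat(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Zero-pad x2 (N, C2, H2, W2) to x1's spatial size, then concatenate on channels."""
    delta_h = x1.size(2) - x2.size(2)
    delta_w = x1.size(3) - x2.size(3)
    # (left, right, top, bottom); any odd remainder goes to the right/bottom side.
    x2 = F.pad(x2, (delta_w // 2, delta_w - delta_w // 2,
                    delta_h // 2, delta_h - delta_h // 2))
    return torch.cat((x2, x1), dim=1)


x1 = torch.randn(1, 64, 56, 56)  # hypothetical larger feature map
x2 = torch.randn(1, 64, 53, 55)  # hypothetical smaller feature map
out = pad_and_concat(x1, x2)
print(out.shape)  # torch.Size([1, 128, 56, 56])
```

Here x2 gets 1 row on top, 2 rows at the bottom, and 1 column on the right, so it ends up as 56x56 and the channel counts add up to 128.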
For visualizing, I did not specify any particular argument, and you do not need to either: the model graph is built by tracing the flow of data through the network, so if you have a skip connection, the data flows through it too and it appears in the graph.
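To see that a plain `torch.cat` in `forward` is captured when data flows through the model, you can trace it with `torch.jit.trace` and list the recorded operators. This is a sketch using a hypothetical `TinySkipNet`. As far as I know, TensorBoard's `add_graph` also relies on tracing, though that detail is not needed here.

```python
import torch
import torch.nn as nn


class TinySkipNet(nn.Module):
    """Hypothetical model with one concatenation skip connection."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.mid = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.dec = nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def forward(self, x):
        e = torch.relu(self.enc(x))
        m = torch.relu(self.mid(e))
        return self.dec(torch.cat((e, m), dim=1))  # skip connection


model = TinySkipNet().eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 32, 32))
# Operator kinds recorded while the example input flowed through forward().
op_kinds = [node.kind() for node in traced.graph.nodes()]
print("aten::cat" in op_kinds)  # True
```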
Yes exactly, but this way of doing the concatenation does not appear in the Netron visualization. When I converted the model to ONNX first, it worked correctly. I think the problem is with the `torch.save` function.
The skip connections do not contain any weights, right? So I think `torch.save` won't save them. I am not sure about this, though; or maybe eager execution prevents the save function from capturing everything?
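One way to check this hypothesis is to save a `state_dict`, which is what `torch.save` is commonly used for, and look at what comes back. The sketch below uses a hypothetical `TinySkipNet`: the saved dictionary contains only the convolution parameters, and nothing records that `forward` concatenates two tensors.

```python
import io

import torch
import torch.nn as nn


class TinySkipNet(nn.Module):
    """Hypothetical model with one concatenation skip connection."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.mid = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.dec = nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def forward(self, x):
        e = torch.relu(self.enc(x))
        m = torch.relu(self.mid(e))
        return self.dec(torch.cat((e, m), dim=1))  # no parameters, so nothing to save


model = TinySkipNet()
buffer = io.BytesIO()  # in-memory stand-in for a .pt file
torch.save(model.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)
print(list(state.keys()))
# ['enc.weight', 'enc.bias', 'mid.weight', 'mid.bias', 'dec.weight', 'dec.bias']
```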
Yes, you are correct. To visualise a PyTorch model with Netron, you have to export it to ONNX, because ONNX also stores the computation graph. I use `torch.onnx.export()` for this. With `torch.save()` you can only visualise individual blocks and their weights in Netron, not the network graph that connects them.