Get encoder from trained UNet

Hi,

I have trained a UNet model on some images, but now I want to extract the encoder part of the model. My UNet has the following architecture:

UNet(
  (conv_final): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
  (down_convs): ModuleList(
    (0): DownConv(
      (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): DownConv(
      (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): DownConv(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): DownConv(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): DownConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (up_convs): ModuleList(
    (0): UpConv(
      (upconv): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): UpConv(
      (upconv): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): UpConv(
      (upconv): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (3): UpConv(
      (upconv): ConvTranspose2d(16, 8, kernel_size=(2, 2), stride=(2, 2))
      (conv1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)

I tried to access the encoder layers through model.down_convs, but I get the following error:

TypeError Traceback (most recent call last)
in
----> 1 res = encoder(train_img)

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

TypeError: forward() takes 1 positional argument but 2 were given

What should I do?

Can you also post the code you use to load the encoder?

If you mean a separate model, I don't have any encoder model to load. I have only trained a UNet, from which I want to get the encoder weights.

Are you doing this?

encoder = model.down_convs
result = encoder(train_img)

This works for me. It'll be easier to debug if you post your code so we can reproduce the error.

Yup, that's what I am doing. I am loading a .npz file and encoding it. This is the code:

import numpy as np
import torch

img = np.load('/HMI20181101_0000_bz.npz', allow_pickle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_img = img['x']
train_img = train_img.reshape(1, 1, 512, 512)  # (batch, channels, height, width)
train_img = np.array(train_img, dtype=np.float32)  # np.float was removed in NumPy 1.24
train_img = torch.from_numpy(train_img)
train_img = train_img.to(device, dtype=torch.float)

res = encoder(train_img)  # encoder = model.down_convs

This is my model for the encoder:

ModuleList(
  (0): DownConv(
    (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): DownConv(
    (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): DownConv(
    (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (3): DownConv(
    (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (4): DownConv(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

And the error I get after running the code is:


TypeError Traceback (most recent call last)
in
----> 1 res = encoder(train_img)

~/anaconda3/envs/work/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)

TypeError: forward() takes 1 positional argument but 2 were given

I have attached the model so you can try it out, and the weights are available here.

@ptrblck, any suggestions?

You cannot call nn.ModuleList directly, as it’s used as a container to store other modules.
If you want to call encoder with an input, you could either try to wrap all modules into nn.Sequential or apply the loop over all modules manually.
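A minimal sketch of both options, using a hypothetical ModuleList of plain nn.Sequential blocks in place of model.down_convs. Note that this only works as written when each block takes and returns a single tensor:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for model.down_convs: a ModuleList of simple blocks.
blocks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
    nn.Sequential(nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)),
])

x = torch.randn(1, 1, 64, 64)

# Option 1: unpack the ModuleList into nn.Sequential (the same module objects are reused).
encoder = nn.Sequential(*blocks)
out_seq = encoder(x)

# Option 2: loop over the modules manually.
out_loop = x
for block in blocks:
    out_loop = block(out_loop)

print(out_seq.shape)                   # torch.Size([1, 16, 16, 16])
print(torch.equal(out_seq, out_loop))  # True
```

Because nn.Sequential reuses the same module objects, no parameters are copied; both options run exactly the same weights.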

Are you talking about something like this?

test_block = torch.nn.Sequential(*list(test_model.children()))[:2]

Because that gives me the following model:

Sequential(
  (0): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))
  (1): ModuleList(
    (0): DownConv(
      (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): DownConv(
      (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): DownConv(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): DownConv(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): DownConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)

When I run res = test_block(train_img), I get the following error:

RuntimeError: Given groups=1, weight of size [1, 8, 1, 1], expected input[1, 1, 512, 512] to have 8 channels, but got 1 channels instead

This input layer expects an 8-band input, since the signature is Conv2d(in_channels, out_channels, kernel_size) and in_channels is 8 here. That looks like the bug to me, so modify the original network to have an input layer for 1 band and then extract the encoder.
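The channel count in the error message can be read directly from the layer's weight, whose shape is (out_channels, in_channels // groups, kernel_height, kernel_width). A small sketch using a standalone layer configured like conv_final (a stand-in, not the trained model):

```python
import torch.nn as nn

# Same configuration as the printed conv_final layer: 8 input channels, 1 output channel.
conv_final = nn.Conv2d(8, 1, kernel_size=1)

# Weight layout: (out_channels, in_channels // groups, kH, kW)
print(conv_final.weight.shape)                          # torch.Size([1, 8, 1, 1])
print(conv_final.in_channels, conv_final.out_channels)  # 8 1
```

This matches "weight of size [1, 8, 1, 1]" in the error: the layer wants 8 input channels, while the image has 1.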

Thanks, I will change the layer. But I wanted to ask: if it is a problem here, how come it didn't cause any problems during training, since I trained on the same type of images?

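I would be a bit careful using this code:

test_block = torch.nn.Sequential(*list(test_model.children()))[:2]

as it will create an nn.Sequential module from the child modules in the order in which they were registered in __init__, not the order in which forward() calls them.
In your initial code it seems that

(conv_final): Conv2d(8, 1, kernel_size=(1, 1), stride=(1, 1))

is initialized as the first layer, but it's supposed to be used as the last layer, which would create this issue. Training was not affected, because the model's forward method calls the layers explicitly in the correct order.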
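The registration-order behavior can be reproduced with a hypothetical two-layer model (TinyNet is illustrative only) whose final layer is created first in __init__:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical toy model: the final layer is registered before the body,
    mirroring how conv_final appears first in the printed UNet."""

    def __init__(self):
        super().__init__()
        self.conv_final = nn.Conv2d(8, 1, kernel_size=1)        # registered first
        self.body = nn.Conv2d(1, 8, kernel_size=3, padding=1)   # registered second

    def forward(self, x):
        # forward() sets the execution order explicitly: body first, final layer last
        return self.conv_final(self.body(x))

model = TinyNet()
x = torch.randn(1, 1, 32, 32)
print(model(x).shape)  # torch.Size([1, 1, 32, 32]); forward() runs fine

# children() follows registration order, so this Sequential starts with conv_final
print([name for name, _ in model.named_children()])  # ['conv_final', 'body']
naive = nn.Sequential(*model.children())

naive_failed = False
try:
    naive(x)  # conv_final (8 input channels) now receives the 1-channel image
except RuntimeError as err:
    naive_failed = True
    print("RuntimeError:", err)
```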

Pass the nn.ModuleList to nn.Sequential as seen here:

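import torch
import torch.nn as nn

mlist = nn.ModuleList([
    nn.Linear(1, 2),
    nn.Linear(2, 3),
    nn.Linear(3, 4)])

# Unpack the ModuleList so its modules become the Sequential's children, in list order
model = nn.Sequential(*mlist)
model(torch.randn(1, 1))  # output shape: [1, 4]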

or iterate it manually.

So, using this approach, would it be possible to assign the UNet encoder parameters to a completely different model? Say I have the weights and biases of the encoder part of the UNet: can I assign those values to a model that I created the way you described, with the only difference being that the tuple contains the filter sizes mentioned above?

Yes, that would be possible, but you would usually have to load these weights manually or manipulate the state_dict keys so that the parameter names match again.
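A minimal sketch of the key-renaming approach, assuming the encoder lives under the down_convs prefix as in the printed UNet. DownBlock and TinyUNet are hypothetical stand-ins for the real classes; with the actual checkpoint you would start from the loaded state_dict instead of trained.state_dict():

```python
import torch
import torch.nn as nn

class DownBlock(nn.Module):
    """Hypothetical stand-in for the UNet's DownConv (conv1, conv2, optional pool)."""

    def __init__(self, in_ch, out_ch, pooling=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2) if pooling else nn.Identity()

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.pool(x)

class TinyUNet(nn.Module):
    """Hypothetical model that stores its encoder in `down_convs`, like the UNet above."""

    def __init__(self):
        super().__init__()
        self.conv_final = nn.Conv2d(8, 1, kernel_size=1)
        self.down_convs = nn.ModuleList([DownBlock(1, 8), DownBlock(8, 16, pooling=False)])

trained = TinyUNet()  # pretend this holds the trained weights

# A separately built encoder; its keys look like "0.conv1.weight", "1.conv2.bias", ...
new_encoder = nn.Sequential(DownBlock(1, 8), DownBlock(8, 16, pooling=False))

# Keep only the encoder entries and strip the "down_convs." prefix so the names match.
prefix = "down_convs."
encoder_state = {
    key[len(prefix):]: value
    for key, value in trained.state_dict().items()
    if key.startswith(prefix)
}
new_encoder.load_state_dict(encoder_state)  # strict by default: every key must match

print(torch.equal(new_encoder[0].conv1.weight, trained.down_convs[0].conv1.weight))  # True
```

Keeping the default strict loading is useful here: any key or shape that still does not line up raises an error instead of being silently skipped.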

Thanks. I have tried that, but even though I have changed nothing and am loading the same model, I get a size mismatch error on every layer, and I don't know why:

RuntimeError: Error(s) in loading state_dict for UNet:
	size mismatch for conv_final.weight: copying a param with shape torch.Size([1, 8, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 64, 1, 1]).
.
.
.

I don’t know how the number of channels went to 8 from 64.

That’s why I was trying to load the whole model and edit it from there.

For reference, this is the model definition and these are the weights.

How was conv_final defined in the original model and how is it defined now?
Could you check the definition of this layer and how you’ve manipulated the state_dict?
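One way to do this check is to compare each checkpoint tensor's shape with the freshly built model before calling load_state_dict. find_shape_mismatches is a hypothetical helper, and the two Conv2d layers below only mimic the 8 vs. 64 channel situation from the error:

```python
import torch.nn as nn

def find_shape_mismatches(model, state_dict):
    """Return (key, checkpoint_shape, model_shape) for every entry whose shape differs."""
    model_state = model.state_dict()
    mismatches = []
    for key, tensor in state_dict.items():
        if key in model_state and model_state[key].shape != tensor.shape:
            mismatches.append((key, tuple(tensor.shape), tuple(model_state[key].shape)))
    return mismatches

# Hypothetical example: checkpoint saved with 8 input channels, model rebuilt with 64.
checkpoint = nn.Conv2d(8, 1, kernel_size=1).state_dict()
rebuilt = nn.Conv2d(64, 1, kernel_size=1)

for key, ckpt_shape, model_shape in find_shape_mismatches(rebuilt, checkpoint):
    print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
# weight: checkpoint (1, 8, 1, 1) vs model (1, 64, 1, 1)
```

If every layer mismatches by the same factor, the model was most likely constructed with different arguments than the one that was saved.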

I guess I was missing an argument value when creating the UNet object: I had to change start_filts from 64 to 8, like this:
test_model = UNet(2, start_filts=8)
Now I can try your method to define an encoder.

Hi,

I have been able to extract the encoder and get rid of most of the errors. I followed these steps:

test_modules = test_model.down_convs
test_modules = nn.Sequential(*test_modules)
test_modules.eval()
print((test_modules._modules))

And I got the following model:

OrderedDict([('0', DownConv(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)), ('1', DownConv(
  (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)), ('2', DownConv(
  (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)), ('3', DownConv(
  (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)), ('4', DownConv(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
))])

I loaded my image like this (since it is an .npz file):

tmp = np.load(file, allow_pickle=True)
img = tmp['x']
img = np.reshape(img,(1,1,512,512))
img = torch.from_numpy(img)

But I got the following error after running test_modules(img):

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-621fab09e884> in <module>
      1 with torch.no_grad():
----> 2     test_modules(img)

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/Desktop/solar/solar/model.py in forward(self, x)
     65 
     66     def forward(self, x):
---> 67         x = F.relu(self.conv1(x))
     68         x = F.relu(self.conv2(x))
     69         before_pool = x

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    417 
    418     def forward(self, input: Tensor) -> Tensor:
--> 419         return self._conv_forward(input, self.weight)
    420 
    421 class Conv3d(_ConvNd):

~/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    413                             weight, self.bias, self.stride,
    414                             _pair(0), self.dilation, self.groups)
--> 415         return F.conv2d(input, weight, self.bias, self.stride,
    416                         self.padding, self.dilation, self.groups)
    417 

TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not tuple

The input image to the network is of type <class 'torch.Tensor'>. So what’s happening now?

@albanD, you can also let me know.
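Judging from the traceback, DownConv.forward saves before_pool = x, which suggests it returns a tuple such as (x, before_pool) that the UNet keeps for its skip connections; this is an assumption, since the full forward is not shown. nn.Sequential passes each module's output straight into the next module, so the second block would receive a tuple. A minimal sketch that unpacks the tuple in a manual loop, using a hypothetical DownConv modeled on the traceback and a hypothetical encode helper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownConv(nn.Module):
    """Hypothetical block modeled on the traceback: it returns (x, before_pool)."""

    def __init__(self, in_ch, out_ch, pooling=True):
        super().__init__()
        self.pooling = pooling
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool  # a tuple, so nn.Sequential breaks on the next block

down_convs = nn.ModuleList([DownConv(1, 8), DownConv(8, 16), DownConv(16, 32, pooling=False)])

def encode(blocks, x):
    """Run the blocks in order, passing on the pooled output and collecting the skips."""
    skips = []
    for block in blocks:
        x, before_pool = block(x)  # unpack instead of forwarding the tuple
        skips.append(before_pool)
    return x, skips

img = torch.randn(1, 1, 64, 64)  # float32 input
with torch.no_grad():
    features, skips = encode(down_convs, img)

print(features.shape)                    # torch.Size([1, 32, 16, 16])
print([tuple(s.shape) for s in skips])   # [(1, 8, 64, 64), (1, 16, 32, 32), (1, 32, 16, 16)]
```

Separately, if tmp['x'] is stored as float64, torch.from_numpy produces a double tensor, and calling .float() on it avoids a follow-up dtype error against the float32 weights.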