How to extract intermediate feature maps from U-Net?

I am using the pix2pix code from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix. Does anyone know how to extract intermediate feature maps from U-Net? The same question was asked here: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/398.

Thank you in advance.

The author posted a solution in the GitHub issue. Did you encounter any problems with returning the activations in your forward method?
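In other words, something along these lines (a generic sketch, not the pix2pix code; the layers are just placeholders):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = nn.Conv2d(3, 16, 3, padding=1)  # placeholder layers
        self.decoder = nn.Conv2d(16, 3, 3, padding=1)

    def forward(self, x):
        feat = self.encoder(x)  # the intermediate activation you are after
        out = self.decoder(feat)
        return out, feat        # return it alongside the final output

out, feat = Net()(torch.randn(1, 3, 64, 64))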

I changed the code accordingly, but I got the following error:

File "/home/Special/h_106/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size [512, 256, 4, 4], expected input[4, 6, 256, 256] to have 256 channels, but got 6 channels instead

I have set input_nc=6, output_nc=3, and batch_size=4.
Any suggestions?

Could you check that the right tensor is passed to the right unet_block?
It looks like you are mixing up the tensors passed to the next block.
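For reference, here is a minimal standalone sketch of how that exact error arises: a conv layer whose weight has shape [512, 256, 4, 4] (i.e. it expects 256 input channels) receives the raw 6-channel input instead of a 256-channel feature map:

import torch
import torch.nn as nn

conv = nn.Conv2d(256, 512, kernel_size=4)  # weight shape: [512, 256, 4, 4]
x = torch.randn(4, 6, 256, 256)            # the raw 6-channel network input
conv(x)  # RuntimeError: expected input[4, 6, 256, 256] to have 256 channels, but got 6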

I also think this is caused by passing the wrong tensors, but I am not sure where to check.
And input should be [4, 6, 256, 256].

Any advice and suggestions will be greatly appreciated.

Could you post your model code so that I can have a look?

@ptrblck Sure, here is my UnetGenerator

class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        self.unet_block_1 = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.unet_block_2 = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=self.unet_block_1, norm_layer=norm_layer)
        self.unet_block_3 = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=self.unet_block_2, norm_layer=norm_layer)
        self.unet_block_4 = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=self.unet_block_3, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        out1 = self.unet_block_1(input)
        out2 = self.unet_block_2(out1)
        out3 = self.unet_block_3(out2)
        out_image = self.unet_block_4(out3)
        return out_image, out1, out2, out3

And the calling code is:

combine = torch.cat((self.real_A, self.real_B), 1)
self.fake_B, self.f1, self.f2, self.f3 = self.netG(combine)

The dimension of combine is [4,6,256,256].

Let me know if you want other information. Thank you in advance.

Thanks for the code.
The suggestion in the repo won't work, as the model is actually called from bottom to top: each submodule is passed as submodule to the next block, so you only have to call unet_block_4.
As the code is quite complicated, I think the easiest way would be to use hooks.
Here is an example of how to get the intermediate activations:

import torch
import torch.nn as nn
# UnetSkipConnectionBlock is defined in models/networks.py of the pix2pix repo

class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        self.unet_block_1 = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.unet_block_2 = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=self.unet_block_1, norm_layer=norm_layer)
        self.unet_block_3 = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=self.unet_block_2, norm_layer=norm_layer)
        self.unet_block_4 = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=self.unet_block_3, outermost=True, norm_layer=norm_layer)

        self.model = self.unet_block_4

    def forward(self, input):
        out = self.model(input)
        return out

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook
        
model = UnetGenerator(1, 1, 5)
model.unet_block_1.register_forward_hook(get_activation('block1'))
model.unet_block_2.register_forward_hook(get_activation('block2'))
model.unet_block_3.register_forward_hook(get_activation('block3'))


x = torch.randn(1, 1, 224, 224)
output = model(x)
print(activation['block1'].shape)
print(activation['block2'].shape)
print(activation['block3'].shape)

If you want to backpropagate through these activations, you would have to remove the .detach() call in the hook.
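For example, replacing the hook above (the extra feature loss here is made up purely for illustration):

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output  # no .detach(): stays in the autograd graph
    return hook

handle = model.unet_block_1.register_forward_hook(get_activation('block1'))
output = model(x)
loss = output.mean() + activation['block1'].pow(2).mean()  # illustrative loss
loss.backward()  # gradients now flow through the stored activation as well
handle.remove()  # remove the hook once it is no longer needed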

Could you check if this yields valid results? I'm still not sure how exactly this code works.


Hi @ptrblck,
Thanks for your code, but I got the following error:

Traceback (most recent call last):
  File "train.py", line 32, in <module>
    model.optimize_parameters()
  File "/home/csdept/projects/pytorch-CycleGAN-and-pix2pix/models/pix2pix_model.py", line 209, in optimize_parameters
    self.forward()
  File "/home/csdept/projects/pytorch-CycleGAN-and-pix2pix/models/pix2pix_model.py", line 112, in forward
    self.model_feature.unet_block_1.register_forward_hook(self.get_activation('block1'))
TypeError: get_activation() takes 1 positional argument but 2 were given

Are you trying to register the hooks in the forward function?
If so, move the code outside, just after the model instantiation, as shown in my example. Also, the TypeError suggests get_activation was defined as an instance method: Python then passes self as an extra positional argument, so either define it as a standalone function (or closure) or add self to its signature.
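For example, something like this in pix2pix_model.py (a sketch under assumptions: self.netG is the generator created in the repo's __init__, and the DataParallel unwrapping is only needed if the generator was wrapped for GPU training):

# in Pix2PixModel.__init__, right after self.netG is created
self.activation = {}

def get_activation(name):  # a plain closure, not a method: no implicit self argument
    def hook(model, input, output):
        self.activation[name] = output.detach()
    return hook

# unwrap nn.DataParallel if the generator was wrapped
net_g = self.netG.module if hasattr(self.netG, 'module') else self.netG
net_g.unet_block_1.register_forward_hook(get_activation('block1'))
net_g.unet_block_2.register_forward_hook(get_activation('block2'))
net_g.unet_block_3.register_forward_hook(get_activation('block3'))

After each self.forward() call you can then read the feature maps from self.activation.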