Thanks for the code.
The suggestion in the repo won’t work, because the model is built from the innermost block outward.
Each block is passed as the `submodule` argument to the next (outer) block, so you actually only have to call `unet_block_4`.
As the code is quite nested, I think the easiest way would be to use forward hooks.
Here is an example of how to get the intermediate activations:
class UnetGenerator(nn.Module):
    """U-Net generator whose four outermost skip-connection blocks are exposed.

    The network is assembled from the innermost block outward: every
    ``UnetSkipConnectionBlock`` receives the previously built block as its
    ``submodule``. ``forward`` only calls the outermost block
    (``unet_block_4``); all inner blocks run as part of that call. The four
    outer blocks are kept as attributes so hooks can be registered on them.

    Args:
        input_nc: passed as ``input_nc`` to the outermost block (presumably the
            number of input image channels — confirm against
            ``UnetSkipConnectionBlock``).
        output_nc: passed as the first channel argument of the outermost block
            (presumably the number of output channels).
        num_downs: total number of blocks built: 1 innermost block,
            ``num_downs - 5`` middle ``ngf * 8`` blocks, and 4 outer blocks.
            Values below 5 add no middle blocks, so they behave like 5.
        ngf: base channel width; blocks use ``ngf``, ``2*ngf``, ``4*ngf`` and
            ``8*ngf``.
        norm_layer: normalization layer class handed to every block.
        use_dropout: forwarded only to the middle ``ngf * 8`` blocks.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super().__init__()
        # Build from the inside out: each new block wraps the one before it.
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None,
            norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=block,
                norm_layer=norm_layer, use_dropout=use_dropout)

        # The four outer blocks are stored as named attributes (in this order)
        # so that forward hooks can be attached to them individually.
        self.unet_block_1 = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None, submodule=block,
            norm_layer=norm_layer)
        self.unet_block_2 = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None, submodule=self.unet_block_1,
            norm_layer=norm_layer)
        self.unet_block_3 = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None, submodule=self.unet_block_2,
            norm_layer=norm_layer)
        self.unet_block_4 = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=self.unet_block_3,
            outermost=True, norm_layer=norm_layer)

        # ``model`` is an alias for the outermost block.
        # NOTE(review): the same module is registered under two attribute
        # names, which likely duplicates its keys in ``state_dict()`` —
        # verify before loading checkpoints saved from a different layout.
        self.model = self.unet_block_4

    def forward(self, input):
        """Run the whole U-Net by calling the outermost block on ``input``."""
        return self.model(input)
# Module-level store filled by the hooks below: maps a hook name to the most
# recent output captured for it.
activation = {}


def get_activation(name, detach=True):
    """Build a forward hook that stores a module's output in ``activation``.

    Register the returned callable with ``module.register_forward_hook``.
    Every forward pass of that module then overwrites ``activation[name]``
    with the module's output.

    Args:
        name: key under which the output is stored in ``activation``.
        detach: if True (the default, matching the original behavior), store
            ``output.detach()`` so the stored value is cut from the autograd
            graph. Pass False to keep the graph, e.g. to backpropagate through
            the captured activation. Assumes ``output`` is a tensor (has
            ``.detach()``) when ``detach`` is True.

    Returns:
        A hook callable with the ``(module, inputs, output)`` signature.
    """
    def hook(module, inputs, output):
        # ``module`` and ``inputs`` are part of the hook signature but unused.
        activation[name] = output.detach() if detach else output
    return hook
# Build a small generator: UnetGenerator(input_nc=1, output_nc=1, num_downs=5).
model = UnetGenerator(1, 1, 5)

# Attach one forward hook per exposed inner block; each stores that block's
# output in ``activation`` under the given key.
for hook_key, hooked_block in (
        ('block1', model.unet_block_1),
        ('block2', model.unet_block_2),
        ('block3', model.unet_block_3)):
    hooked_block.register_forward_hook(get_activation(hook_key))

# A single forward pass through the full model fires all registered hooks.
x = torch.randn(1, 1, 224, 224)
output = model(x)

for hook_key in ('block1', 'block2', 'block3'):
    print(activation[hook_key].shape)
If you want to backpropagate through these activations, you would have to remove the `.detach()` call in the hook.
Could you check if this yields valid results? I’m still not sure exactly how the code works.