Dropout in eval mode of GoogLeNet: model(x) or layer(layer(x))

Hello!

I've run into a really tricky situation with the dropout layer in the pretrained GoogLeNet. The output of running the whole model directly differs from the output of running the model layer by layer. It seems that when the model is executed layer by layer, dropout does not take effect, while it does when the whole model is executed.

Details are shown below. Thank you!

Version info (google colab):
torch – 1.7.0+cu101
torchvision – 0.8.1+cu101

import torch
from torch import nn
from torchvision import models
x = torch.randn(10, 3, 224, 224)

glnet = models.googlenet(pretrained=True)
glnet.eval()
glnet(x)

outputs

tensor([[ 0.4788,  0.8616,  1.1871,  ...,  0.3228, -0.7227, -0.7120],
        [-1.3281,  0.3395, -0.9889,  ..., -0.4708, -0.4659,  1.5156],
        [ 1.0499,  1.3531,  0.3119,  ..., -1.3807,  0.2523, -1.0316],
        ...,
        [-0.5240, -1.7480, -0.0047,  ..., -0.8630, -0.2038, -0.6803],
        [ 1.1060,  1.3443, -0.1791,  ..., -1.1149,  1.2490, -0.0174],
        [-1.9743, -0.2748, -0.5077,  ..., -1.3009,  0.3457,  0.0843]],
       grad_fn=<AddmmBackward>)

while

out = x
for name, layer in glnet.named_children():
    if name == 'fc':
        # flatten before the final linear layer, as forward() does
        out = nn.Flatten()(out)
    out = layer(out)
out

outputs

tensor([[ 1.2480,  0.3792,  2.1190,  ...,  0.1834,  0.3580, -1.0001],
        [-0.7878, -0.0876, -0.8435,  ..., -0.8066,  0.2189,  1.0071],
        [ 1.3254, -0.1794, -0.4596,  ..., -0.1602, -0.1182, -0.5954],
        ...,
        [-0.8594, -0.8541,  0.2045,  ..., -0.7589,  0.4976, -0.4356],
        [-0.6665, -0.1939,  0.1854,  ..., -1.1892,  1.4693,  0.5318],
        [-2.6076, -0.0268, -1.8729,  ..., -1.2042,  0.0576,  0.2504]],
       grad_fn=<AddmmBackward>)

It seems that it is NOT ONLY THE DROPOUT.
Run the following code and you will find that features[0] is different from features[1]!

import torch
from torch import nn
from torchvision import models

features = []

def hook(module, inputs, outputs):
    # store conv1's output so the two runs can be compared
    features.append(outputs.clone())

x = torch.randn(10, 3, 224, 224).float()
glnet = models.googlenet(pretrained=True)
glnet.eval()

glnet.conv1.register_forward_hook(hook)

with torch.no_grad():
    model_output = glnet(x.clone())
    print(model_output)

    out = x.clone()
    for name, layer in glnet.named_children():
        if name == 'fc':
            out = nn.Flatten()(out)
        out = layer(out)
    print(out)

for feature in features:
    print("feature", feature[0, 0, 0, :10])

It’s so strange…I don’t know why…


The forward method of this model does not only call all submodules sequentially, but also transforms the input if the pretrained model is used, as seen here.
Generally, I would not recommend wrapping all submodules in e.g. an nn.Sequential container, or calling them sequentially, without checking the forward method first and making sure that no functional calls would be dropped.
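
Roughly speaking, the transform re-normalizes each input channel before conv1. The snippet below is a minimal sketch of what _transform_input does, paraphrased from the torchvision 0.8 source (check your installed version for the exact constants; the function name here is only for illustration):

import torch

def transform_input_sketch(x):
    # re-scale each channel from the usual ImageNet normalization
    # (per-channel mean/std) to the range the ported GoogLeNet weights expect
    x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
    x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
    x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
    return torch.cat((x_ch0, x_ch1, x_ch2), 1)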

CC @acoder_acoder

Thanks, @ptrblck!
In the forward() function, although _transform_input() is called, it does nothing to the input, since the parameter transform_input controls whether the transform is applied.

In fact, this happens not only with GoogLeNet but also with some other networks, so I wonder whether there is anything weird about the dropout function (see the quick check below). Thank you very much!
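
For what it's worth, in eval mode nn.Dropout simply passes its input through unchanged, so dropout itself should not be able to explain the difference. A quick sanity check (the tensor shape is arbitrary, just for illustration):

import torch
from torch import nn

drop = nn.Dropout(p=0.2)
drop.eval()
t = torch.randn(4, 8)
# in eval mode dropout is an identity op, so this prints True
print(torch.equal(drop(t), t))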

@Eta_C , thank you! According to ptrblck's answer, the problem is the transform_input parameter. I thought transform_input was False, while it was actually True…
That is the reason…
This line

Thank you very much!

Since this attribute is set to True in your example, the input will be transformed.
Applying the same transform in the manual loop then yields the same results:

x = torch.randn(10, 3, 224, 224)
glnet = models.googlenet(pretrained=True)
glnet.eval()
res = glnet(x)

out = x
if glnet.transform_input:
    # apply the same input transform that forward() applies
    out = glnet._transform_input(out)
for name, layer in glnet.named_children():
    if name == 'fc':
        out = nn.Flatten()(out)
    out = layer(out)

print((out - res).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)