Help with a PyTorch equivalent for some Keras code

Hey there, I am trying to reproduce some Keras code in PyTorch. Can anyone help me out?

# InceptionV3 convolutional base with ImageNet weights and no top classifier.
base_inception = InceptionV3(
    include_top=False,
    weights='imagenet',
    input_shape=(299, 299, 3),
)

# Classification head: global spatial average pooling followed by two
# 512-unit ReLU layers.
out = GlobalAveragePooling2D()(base_inception.output)
out = Dense(512, activation='relu')(out)
out = Dense(512, activation='relu')(out)

# One softmax unit per column of y_train_ohe
# (presumably the one-hot encoded training labels -- confirm).
total_classes = y_train_ohe.shape[1]
predictions = Dense(total_classes, activation='softmax')(out)

model = Model(inputs=base_inception.input, outputs=predictions)

# Optional: mark every layer of the base as non-trainable.
for layer in base_inception.layers:
    layer.trainable = False

# Compile
model.compile(
    Adam(lr=.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

I tried it out, but I couldn’t work out the Global Average / Average Pooling part:

# Load InceptionV3 with pretrained weights
# (`models` is presumably torchvision.models -- confirm against the imports).
model_ft = models.inception_v3(pretrained=True)

# Turn off the auxiliary-logits flag on the loaded model.
model_ft.aux_logits = False

# Exclude every existing parameter from gradient computation.
for param in model_ft.parameters():
    param.requires_grad = False

# Earlier attempt, kept for reference: it indexes a `classifier` list,
# whereas the line below reads the head through `fc`.
#num_ftrs = model_ft.classifier[6].in_features
#model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)

# Input width of the model's final fully connected layer.
num_ftrs = model_ft.fc.in_features

class inception_v3_see_smart(nn.Module):
    """Backbone followed by global average pooling and a two-layer head.

    The wrapped backbone's output is reduced to one vector per sample,
    passed through a ReLU-activated hidden layer, and projected to raw
    class scores (logits).

    Args:
        originalModel: Module producing either a 4-D feature map
            ``(N, C, H, W)`` or an already pooled 2-D tensor ``(N, C)``.
        in_features: Channel count ``C`` of the backbone output. Defaults
            to 512; pass the real width of the backbone being wrapped.
        hidden_features: Width of the hidden dense layer.
        num_classes: Number of output classes.
    """

    def __init__(self, originalModel, in_features=512, hidden_features=512,
                 num_classes=62):
        super(inception_v3_see_smart, self).__init__()
        self.Model = originalModel
        # Output size 1 makes this a *global* average pool: each channel is
        # averaged over its whole spatial extent, whatever H and W are.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(1)
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dense2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(N, num_classes)``."""
        x = self.Model(x)
        if x.dim() == 4:
            # (N, C, H, W) -> (N, C, 1, 1) -> (N, C); nn.Linear acts on the
            # last dimension, so the spatial axes must be removed first.
            x = self.adaptive_pool(x).flatten(1)
        x = F.relu(self.dense1(x))
        # No activation on the final layer: nn.CrossEntropyLoss expects raw
        # logits, and a ReLU here would clamp every negative score to zero.
        return self.dense2(x)

# Wrap the backbone with the custom pooling + dense head.
model_ft = inception_v3_see_smart(model_ft)

# NOTE(review): the original snippet had an unfinished `model_ft =` assignment
# here, which is a syntax error. Its intended right-hand side is not shown
# (possibly a device move), so it is dropped rather than guessed at.

# nn.CrossEntropyLoss applies log-softmax internally, so the model should
# emit raw logits rather than softmax probabilities.
criterion = nn.CrossEntropyLoss()

Any help is really appreciated!

1 Like

Isn’t global average pooling just average pooling where the kernel is the whole image?

1 Like

Try nn.AdaptiveAvgPool2d(1) for your self.adaptive_pool layer. Also, reshape the tensor from 4D to 2D after pooling, either by x = x.view(x.shape[:2]) or an nn.Flatten() layer.

1 Like

Sure, thanks for the help! I will look into it and let you know.