Help me rewrite this from TensorFlow (Dense layers)

Hello guys,

I am rewriting a TensorFlow model to PyTorch. I have been stuck for two days trying to rewrite this layer:

import tensorflow as tf
from tensorflow.keras import layers


class MultiScaleFeatureFusion(tf.keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        self.multiply = layers.Multiply()
        self.dense1 = layers.Dense(filters, activation="relu")
        self.dense2 = layers.Dense(filters * 3, activation="sigmoid")
        self.conv = layers.Conv2D(filters=filters, kernel_size=(1, 1), padding="same")

    def get_config(self):
        config = super().get_config()
        config.update({"filters": self.filters})
        return config

    def call(self, inputs):
        _x = self.global_avg_pool(inputs)

        _x = self.dense1(_x)
        _x = self.dense2(_x)

        _x = self.multiply([inputs, _x])
        _x = self.conv(_x)
        return _x

From what I can tell, it multiplies the input with _x, where _x is the input passed through global average pooling and two Dense layers with activations, and then applies a 1x1 convolution to the result.

I have rewritten it like this in PyTorch:

import torch
import torch.nn as nn


class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.dense1 = nn.Linear(in_channels, out_channels)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(out_channels, out_channels)
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding="same")

    def forward(self, x):
        _x = self.global_avg_pool(x)
        _x = self.dense1(_x)
        _x = self.relu(_x)
        _x = self.dense2(_x)
        _x = self.sigmoid(_x)
        _x = self.conv(_x)
        _x = torch.mul(x, _x)
        return _x

It throws an error at the first dense1 layer:

    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x1 and 128x128)

I would appreciate any help to get me unstuck.

The activation shape in your PyTorch model seems to be [batch_size=128, features=1], so self.dense1 would need in_features=1. Could you check whether this shape matches the TF model?
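
For reference, here is a minimal repro of that shape error. I'm guessing batch_size=1 and 128 channels purely from the numbers in the message, and the exact wording may vary between PyTorch versions:

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(1)
dense1 = nn.Linear(128, 128)

x = torch.rand(1, 128, 32, 32)  # (batch, channels, H, W) -- example sizes
pooled = pool(x)                # (1, 128, 1, 1); Linear treats the trailing dim of size 1 as the feature dim
dense1(pooled)                  # RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x1 and 128x128)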

Welcome to the forums!

There is a lot to unpack here.

  1. When you call nn.AdaptiveAvgPool2d(1), you are telling PyTorch to pool whatever spatial size comes in down to 1x1, so the output of that layer is (batch_size, channels, 1, 1). Let’s assume that is what you want. Then you should call flatten() afterward:
_x = self.global_avg_pool(inputs).flatten(1)
  2. When you get down to your convolutional layer, you will need to add those two dims back:
_x = self.conv(_x.unsqueeze(2).unsqueeze(3))

Note: The above operation, in your case, is the mathematical equivalent of a Linear layer under the guise of a convolution layer.

Let’s demonstrate that to be the case with the following code:

import torch
import torch.nn as nn


# A Linear layer and a 1x1 Conv2d with matching in/out sizes
model1 = nn.Linear(3, 4)

model2 = nn.Conv2d(3, 4, kernel_size=1)

# Copy the Linear weights into the Conv2d: (out, in) -> (out, in, 1, 1)
model2.weight.data = model1.weight.data.unsqueeze(2).unsqueeze(3)
model2.bias.data = model1.bias.data

dummy_data = torch.rand((2, 3))

outputs1 = model1(dummy_data)
# Add the two spatial dims for the conv, then remove them again
outputs2 = model2(dummy_data.unsqueeze(2).unsqueeze(3)).squeeze(2).squeeze(2)
print(outputs1)
print(outputs2)
print(outputs1.allclose(outputs2))  # True

TL;DR: you could just use a Linear layer and do the same thing. Just make sure to unsqueeze dims 2 and 3 before the final elementwise multiplication with the original inputs.
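
Putting both fixes together, a rough sketch of what the forward could look like, keeping your layer names and the multiply-then-conv order of the original TF call(). Note this assumes dense2 outputs as many values as x has channels, so the elementwise multiply can broadcast:

def forward(self, x):
    # x: (batch, in_channels, H, W)
    _x = self.global_avg_pool(x).flatten(1)   # (batch, in_channels)
    _x = self.relu(self.dense1(_x))
    _x = self.sigmoid(self.dense2(_x))
    _x = _x.unsqueeze(2).unsqueeze(3)         # (batch, channels, 1, 1)
    _x = torch.mul(x, _x)                     # broadcasted elementwise gate, like layers.Multiply()
    return self.conv(_x)                      # 1x1 conv last, as in the TF call()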

Thanks, solved it using views at the end (I could have used flatten and unsqueeze as well; I believe it's the same).
Here is the final code:

class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.dense1 = nn.Linear(in_channels, in_channels)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(in_channels, in_channels)
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding="same")

    def forward(self, x):
        _x = self.global_avg_pool(x)                # (batch, in_channels, 1, 1)
        _x = _x.view(_x.shape[0], -1)               # flatten to (batch, in_channels)
        _x = self.relu(self.dense1(_x))
        _x = self.sigmoid(self.dense2(_x))
        _x = _x.view(_x.size(0), _x.size(1), 1, 1)  # back to (batch, in_channels, 1, 1)
        _x = torch.mul(x, _x)                       # broadcasted gate over the original input
        _x = self.conv(_x)
        return _x
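
A quick sanity check with made-up sizes, in case it helps someone later:

import torch

block = MultiScaleFeatureFusion(in_channels=64, out_channels=32)  # example channel counts
x = torch.rand(4, 64, 56, 56)
out = block(x)
print(out.shape)  # torch.Size([4, 32, 56, 56])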

Thanks for the help!