Hello guys,
I am rewriting tensorflow model to pytorch. I am stuck for 2 days on trying to rewrite this layer
class MultiScaleFeatureFusion(tf.keras.layers.Layer):
def __init__(self, filters, **kwargs):
super().__init__(**kwargs)
self.filters = filters
self.global_avg_pool = layers.GlobalAveragePooling2D()
self.multiply = layers.Multiply()
self.dense1 = layers.Dense(filters, activation="relu")
self.dense2 = layers.Dense(filters * 3, activation="sigmoid")
self.conv = layers.Conv2D(filters=filters, kernel_size=(1, 1), padding="same")
def get_config(self):
config = super().get_config()
config.update({"filters": self.filters})
return config
def call(self, inputs):
_x = self.global_avg_pool(inputs)
_x = self.dense1(_x)
_x = self.dense2(_x)
_x = self.multiply([inputs, _x])
_x = self.conv(_x)
return _x
apparently it multiplies input with _x that goes through average pooling and 2 dense layers with activations.
I have rewriten it like this in pytorch
class MultiScaleFeatureFusion(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.dense1 = nn.Linear(in_channels, out_channels)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(out_channels, out_channels)
self.sigmoid = nn.Sigmoid()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding="same")
def forward(self, x):
_x = self.global_avg_pool(x)
_x = self.dense1(_x)
_x = self.relu(_x)
_x = self.dense2(_x)
_x = self.sigmoid(_x)
_x = self.conv(_x)
_x = torch.mul(x, _x)
return _x
It throws error at first dense1 layer and it is
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x1 and 128x128)
I would appreciate help to unstuck me.