TensorFlow/Keras to PyTorch translation

Hello, I am trying to translate a TensorFlow/Keras model to PyTorch, and I am specifically having trouble writing the PyTorch equivalent of the model's Dense layer.

Original TensorFlow/Keras code:

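from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, Dense, UpSampling2D

inputs = (256, 256, 1)  # input shape (H, W, C): Keras is channels-last

model = Sequential()
# encoder
model.add(Convolution2D(32, (3,3), input_shape=inputs,
                        activation='relu', padding='same'))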
model.add(MaxPooling2D((2,2), padding='same'))
model.add(Convolution2D(64, (3,3), activation='relu', padding='same'))
model.add(MaxPooling2D((2,2), padding='same'))
model.add(Convolution2D(128, (3,3), activation='relu', padding='same'))
model.add(MaxPooling2D((2,2), padding='same'))

model.add(Dense(128, activation='relu'))

# decoder
model.add(UpSampling2D((2,2)))
model.add(Convolution2D(128, (3,3), activation='sigmoid', padding='same'))
model.add(UpSampling2D((2,2)))
model.add(Convolution2D(64, (3,3), activation='sigmoid', padding='same'))
model.add(UpSampling2D((2,2)))
model.add(Convolution2D(1, (3,3), activation='sigmoid', padding='same'))

And the TensorFlow/Keras model summary is as follows:

Model: "sequential"
_________________________________________________________________
 Layer (type)                     Output Shape              Param #
=================================================================
 conv2d (Conv2D)                  (None, 256, 256, 32)      320
 max_pooling2d (MaxPooling2D)     (None, 128, 128, 32)      0
 conv2d_1 (Conv2D)                (None, 128, 128, 64)      18496
 max_pooling2d_1 (MaxPooling2D)   (None, 64, 64, 64)        0
 conv2d_2 (Conv2D)                (None, 64, 64, 128)       73856
 max_pooling2d_2 (MaxPooling2D)   (None, 32, 32, 128)       0
 dense (Dense)                    (None, 32, 32, 128)       16512
 up_sampling2d (UpSampling2D)     (None, 64, 64, 128)       0
 conv2d_3 (Conv2D)                (None, 64, 64, 128)       147584
 up_sampling2d_1 (UpSampling2D)   (None, 128, 128, 128)     0
 conv2d_4 (Conv2D)                (None, 128, 128, 64)      73792
 up_sampling2d_2 (UpSampling2D)   (None, 256, 256, 64)      0
 conv2d_5 (Conv2D)                (None, 256, 256, 1)       577
=================================================================
Total params: 331,137
Trainable params: 331,137
Non-trainable params: 0
_________________________________________________________________
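As a cross-check (a sketch added here, not from the thread), every row of that summary can be recomputed by hand: a Conv2D layer has k x k x C_in x C_out weights plus C_out biases, and a Dense layer has in_features x units weights plus units biases. The helper names below are hypothetical. The Dense row only comes out to 16,512 if in_features is 128, i.e. if Keras applies Dense to the channel axis of the (32, 32, 128) map rather than to a flattened vector.

```python
# Hypothetical helpers: recompute the Keras summary's parameter counts by hand.

def conv2d_params(k, c_in, c_out):
    # k x k kernel for every (input channel, filter) pair, plus one bias per filter
    return k * k * c_in * c_out + c_out

def dense_params(in_features, units):
    # one weight per (input feature, unit) pair, plus one bias per unit
    return in_features * units + units

layers = {
    "conv2d": conv2d_params(3, 1, 32),
    "conv2d_1": conv2d_params(3, 32, 64),
    "conv2d_2": conv2d_params(3, 64, 128),
    "dense": dense_params(128, 128),      # in_features = 128 channels, not 32*32*128
    "conv2d_3": conv2d_params(3, 128, 128),
    "conv2d_4": conv2d_params(3, 128, 64),
    "conv2d_5": conv2d_params(3, 64, 1),
}

for name, n in layers.items():
    print(f"{name:10s} {n:>8,}")
print(f"{'total':10s} {sum(layers.values()):>8,}")   # 331,137
```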

My attempt at writing the PyTorch equivalent is here:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        # encoder
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool2d(kernel_size=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.pool2 = nn.MaxPool2d(kernel_size=2, padding=0)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        self.pool3 = nn.MaxPool2d(kernel_size=2, padding=0)
        
        # fully-connected
        self.fc1 = nn.Linear(128*16*16, 128)
        
        # decoder
        self.up1 = nn.Upsample(scale_factor=2)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding='same')
        self.up2 = nn.Upsample(scale_factor=2)
        self.conv5 = nn.Conv2d(128, 64, kernel_size=3, padding='same')
        self.up3 = nn.Upsample(scale_factor=2)
        self.conv6 = nn.Conv2d(64, 1, kernel_size=3, padding='same')
        
        # calculate output sizes
        self.conv1_out = get_output_shape(self.conv1, (1, 1, 256, 256))
        self.pool1_out = get_output_shape(self.pool1, self.conv1_out)
        self.conv2_out = get_output_shape(self.conv2, self.pool1_out)
        self.pool2_out = get_output_shape(self.pool2, self.conv2_out)
        self.conv3_out = get_output_shape(self.conv3, self.pool2_out)
        self.pool3_out = get_output_shape(self.pool3, self.conv3_out)
        
        self.fc1_out = get_output_shape(self.fc1, self.pool3_out)
        
        # self.up1_out = get_output_shape(self.up1, self.fc1_out)
        # self.conv4_out = get_output_shape(self.conv4, self.up1_out)
        # self.up2_out = get_output_shape(self.up2, self.conv4_out)
        # self.conv5_out = get_output_shape(self.conv5, self.up2_out)
        # self.up3_out = get_output_shape(self.up3, self.conv5_out)
        # self.conv6_out = get_output_shape(self.conv6, self.up3_out)
        
    
    def forward(self, x):
        
        # encoder activations
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        
        # x = x.view(x.size(0), -1)
        # fully-connected activation
        x = F.relu(self.fc1(x))
        
        # decoder activations
        x = self.up1(x)
        x = torch.sigmoid(self.conv4(x))
        x = self.up2(x)
        x = torch.sigmoid(self.conv5(x))
        x = self.up3(x)
        x = torch.sigmoid(self.conv6(x))
        return x

I am getting a dimension-mismatch error in the Linear layer of my PyTorch code. Printing an instance of the class with print(model) gives the following:

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=32768, out_features=128, bias=True)
  (up1): Upsample(scale_factor=2.0, mode=nearest)
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (up2): Upsample(scale_factor=2.0, mode=nearest)
  (conv5): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (up3): Upsample(scale_factor=2.0, mode=nearest)
  (conv6): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=same)
)

Side note: while debugging, I used a code snippet from a Stack Overflow post that prints the output size of a layer, in case it is relevant:

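import torch

def get_output_shape(model, image_dim):
    # run a random tensor of shape image_dim through the layer and report the output shape
    return model(torch.rand(*(image_dim))).data.shape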

Using this snippet, the output sizes of the first three conv-pool pairs in my PyTorch code come out correct. They match the TensorFlow/Keras summary above, keeping in mind that PyTorch tensors are channels-first (N, C, H, W) while TensorFlow/Keras is channels-last (N, H, W, C) by default:

torch.Size([1, 32, 256, 256])
torch.Size([1, 32, 128, 128])

torch.Size([1, 64, 128, 128])
torch.Size([1, 64, 64, 64])

torch.Size([1, 128, 64, 64])
torch.Size([1, 128, 32, 32])
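The same trace can be reproduced without the helper by looping over the encoder layers under torch.no_grad(). This is a minimal sketch assuming PyTorch 1.9 or newer (where padding='same' is accepted); the final permute only shows the NHWC layout that Keras would report for the same tensor.

```python
import torch
import torch.nn as nn

# The encoder from the question, as a plain list of layers.
encoder = [
    nn.Conv2d(1, 32, kernel_size=3, padding='same'),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(32, 64, kernel_size=3, padding='same'),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(64, 128, kernel_size=3, padding='same'),
    nn.MaxPool2d(kernel_size=2),
]

shapes = []
x = torch.rand(1, 1, 256, 256)          # (N, C, H, W): PyTorch is channels-first
with torch.no_grad():                   # no autograd graph needed for a shape check
    for layer in encoder:
        x = layer(x)
        shapes.append(tuple(x.shape))

for s in shapes:
    print(s)

# Same tensor in Keras's channels-last order: (N, H, W, C)
nhwc = tuple(x.permute(0, 2, 3, 1).shape)
print(nhwc)                             # (1, 32, 32, 128)
```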

However, I get a dimension-mismatch error when computing the output size of the fc1 layer:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_10614/273569339.py in <module>
----> 1 model = Net()
      2 
      3 print(model)

/tmp/ipykernel_10614/1565241974.py in __init__(self)
     32         self.pool3_out = get_output_shape(self.pool3, self.conv3_out)
     33 
---> 34         self.fc1_out = get_output_shape(self.fc1, self.pool3_out)
     35 
     36         # self.up1_out = get_output_shape(self.up1, self.fc1_out)

/tmp/ipykernel_10614/1833495027.py in get_output_shape(model, image_dim)
      1 def get_output_shape(model, image_dim):
----> 2     return model(torch.rand(*(image_dim))).data.shape

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x32 and 32768x128)
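The numbers in that error follow from how nn.Linear treats a 4-D input: it contracts only the last dimension (here W = 32) against its weight and treats all leading dims as batch dims, so the (1, 128, 32, 32) tensor behaves like a 4096 x 32 matrix while the weight expects 32,768 input features. A small sketch reproducing this (the exact error text may vary between PyTorch versions):

```python
import torch
import torch.nn as nn

fc1 = nn.Linear(128 * 16 * 16, 128)   # same layer as in the question: in_features=32768
x = torch.rand(1, 128, 32, 32)        # pool3 output, (N, C, H, W)

# nn.Linear only contracts the LAST dim against its weight; every leading dim
# acts as a batch dim, so the input behaves like a (1*128*32) x 32 matrix.
rows = x.numel() // x.shape[-1]
print(rows, x.shape[-1])              # 4096 32

raised = False
try:
    fc1(x)                            # 32 features offered, 32768 expected
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```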

Any help would be appreciated.


EDIT: I read that Keras implicitly flattens the input before a Dense layer, which my PyTorch code was not doing. But even after adding the flattening step below, I get a dimension-mismatch error when running torch-summary.

Flattening code (added in forward(), just before x = F.relu(self.fc1(x))):

x = x.view(x.size(0), -1)

Dimension-mismatch error when trying to get the model summary with torch-summary:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/torchsummary/torchsummary.py in summary(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)
    139             with torch.no_grad():
--> 140                 _ = model.to(device)(*x, *args, **kwargs)  # type: ignore[misc]
    141         except Exception as e:

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used

/tmp/ipykernel_149/842565426.py in forward(self, x)
     57         # fully-connected activation
---> 58         x = F.relu(self.fc1(x))
     59 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1207 
-> 1208         result = forward_call(*input, **kwargs)
   1209         if _global_forward_hooks or self._forward_hooks:

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x131072 and 32768x128)

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_149/566945719.py in <module>
----> 1 summary(model, (1,256,256))

/opt/conda/lib/python3.7/site-packages/torchsummary/torchsummary.py in summary(model, input_data, batch_dim, branching, col_names, col_width, depth, device, dtypes, verbose, *args, **kwargs)
    144                 "Failed to run torchsummary. See above stack traces for more details. "
    145                 "Executed layers up to: {}".format(executed_layers)
--> 146             ) from e
    147         finally:
    148             if hooks is not None:

RuntimeError: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [Conv2d: 1-1, MaxPool2d: 1-2, Conv2d: 1-3, MaxPool2d: 1-4, Conv2d: 1-5, MaxPool2d: 1-6]

I can’t speak to what the TF model should or shouldn’t do. That said, your PyTorch model has three MaxPool2d layers, each cutting the size in half per spatial dim, and your input size is 256x256. Also, you have 128 channels coming out of your CNN encoder.


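So the model is giving the correct output size for the PyTorch version: 256 / 2^3 = 32, so the flattened encoder output has 128 x 32 x 32 = 131,072 features, four times the 128 x 16 x 16 = 32,768 that fc1 was built for. You can either make the Linear layer's in_features 4x larger or run through one more Conv2d + MaxPool2d layer.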
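The arithmetic behind that suggestion, with the first option written out, as a sketch (batch size 2 is assumed to match the 2x131072 in the torch-summary error). Note that nn.Linear(131072, 128) on its own holds about 16.8 million weights and returns an (N, 128) vector with no spatial layout left for the decoder, so it is not the same layer as the Keras Dense(128), whose summary shows only 16,512 parameters.

```python
import torch
import torch.nn as nn

# Three 2x2 max-pools halve each spatial dim three times: 256 -> 128 -> 64 -> 32.
side = 256 // 2**3
flat_features = 128 * side * side      # channels * H * W after pool3
expected = 128 * 16 * 16               # what nn.Linear(128*16*16, 128) was built for

print(side, flat_features, expected, flat_features // expected)   # 32 131072 32768 4

# Option 1: size the Linear layer for the real flattened input.
fc1 = nn.Linear(flat_features, 128)
x = torch.rand(2, 128, side, side)     # batch of 2, like torch-summary's dummy input
out = fc1(x.view(x.size(0), -1))       # flatten to (2, 131072), then project to (2, 128)
print(tuple(out.shape))                # (2, 128)
```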

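I had another look at your TF model printout. Its Dense layer uses 128 in_features (16,512 params = 128 x 128 weights + 128 biases), so Keras is applying it to the channel dim only. So what’s likely happening is something like the following: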

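# in __init__: Linear over the 128 channels only
self.fc1 = nn.Linear(128, 128)
...
# in forward(), after the encoder
x = self.pool3(x)                              # (N, 128, 32, 32)

N, c, h, w = x.shape
x = x.reshape(N, c, h*w).permute(0, 2, 1)     # (N, 1024, 128): channels last

x = F.relu(self.fc1(x))                        # matmul over the last dim only

x = x.permute(0, 2, 1).reshape(N, c, h, w)    # back to (N, 128, 32, 32)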
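Dropping that snippet into the question's model gives something like the full sketch below. It assumes PyTorch 1.9+ for padding='same', reuses one MaxPool2d and one Upsample module since those have no parameters, and uses torch.sigmoid in place of the deprecated F.sigmoid. The last lines check that a 256x256 input comes back at 256x256 and that the parameter count matches the Keras total of 331,137.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        self.pool = nn.MaxPool2d(kernel_size=2)        # parameter-free, safe to reuse
        # counterpart of Keras Dense(128): acts on the 128 channels only
        self.fc1 = nn.Linear(128, 128)
        # decoder
        self.up = nn.Upsample(scale_factor=2)          # nearest-neighbour, like UpSampling2D
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding='same')
        self.conv5 = nn.Conv2d(128, 64, kernel_size=3, padding='same')
        self.conv6 = nn.Conv2d(64, 1, kernel_size=3, padding='same')

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))           # (N, 128, 32, 32)

        # channels last so Linear contracts over them: (N, 32*32, 128)
        n, c, h, w = x.shape
        x = x.reshape(n, c, h * w).permute(0, 2, 1)
        x = F.relu(self.fc1(x))
        # back to channels-first for the decoder: (N, 128, 32, 32)
        x = x.permute(0, 2, 1).reshape(n, c, h, w)

        x = torch.sigmoid(self.conv4(self.up(x)))
        x = torch.sigmoid(self.conv5(self.up(x)))
        x = torch.sigmoid(self.conv6(self.up(x)))
        return x


model = Net()
with torch.no_grad():
    out = model(torch.rand(1, 1, 256, 256))
n_params = sum(p.numel() for p in model.parameters())
print(tuple(out.shape))   # (1, 1, 256, 256)
print(n_params)           # 331137
```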

Thank you, this works and I get the following PyTorch model summary using torch-summary:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [-1, 32, 256, 256]        320
├─MaxPool2d: 1-2                         [-1, 32, 128, 128]        --
├─Conv2d: 1-3                            [-1, 64, 128, 128]        18,496
├─MaxPool2d: 1-4                         [-1, 64, 64, 64]          --
├─Conv2d: 1-5                            [-1, 128, 64, 64]         73,856
├─MaxPool2d: 1-6                         [-1, 128, 32, 32]         --
├─Linear: 1-7                            [-1, 1024, 128]           16,512
├─Upsample: 1-8                          [-1, 128, 64, 64]         --
├─Conv2d: 1-9                            [-1, 128, 64, 64]         147,584
├─Upsample: 1-10                         [-1, 128, 128, 128]       --
├─Conv2d: 1-11                           [-1, 64, 128, 128]        73,792
├─Upsample: 1-12                         [-1, 64, 256, 256]        --
├─Conv2d: 1-13                           [-1, 1, 256, 256]         577
==========================================================================================
Total params: 331,137
Trainable params: 331,137
Non-trainable params: 0
Total mult-adds (G): 2.47
==========================================================================================
Input size (MB): 0.25
Forward/backward pass size (MB): 41.50
Params size (MB): 1.26
Estimated Total Size (MB): 43.01
==========================================================================================

A follow-up question, though: the TF model's Dense layer has output shape (None, 32, 32, 128), but the PyTorch model's Linear layer shows [-1, 1024, 128]. I don’t understand why.

32 x 32 = 1024: the H and W dims are flattened into a single dim of length 1024 before the Linear layer.

After the Linear layer's matmul and bias addition are complete, the code in my previous reply permutes the H x W dim back to the end and then reshapes it into separate H and W dims to feed into the decoder CNN.

When you run a tensor through a PyTorch Linear layer, it matmuls the last dim with the weights, while all other dims are effectively treated as batch dims. Keras Dense does the same on inputs with more than two dims, which is why its output keeps the (32, 32) spatial dims. This matches how I’ve seen Google’s research team (who primarily use TF) handle the initial patchify stem in their ViT paper.
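Because the Linear layer only touches the last dim, it applies one 128 x 128 matrix independently at each of the 32 x 32 positions. That is the same computation as a Linear on a channels-last (N, H, W, C) tensor, which is the layout Keras Dense sees, and also the same as a 1x1 convolution, which is one alternative way (not suggested in the thread) to port the Dense layer without any permutes. The sketch below checks these routes numerically with copied weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(128, 128)
x = torch.rand(2, 128, 32, 32)                                # (N, C, H, W)
n, c, h, w = x.shape

with torch.no_grad():
    # Route 1: the reply's reshape/permute, giving (N, H*W, C) for the Linear layer
    y_linear = fc(x.reshape(n, c, h * w).permute(0, 2, 1))
    y_linear = y_linear.permute(0, 2, 1).reshape(n, -1, h, w)   # back to (N, C_out, H, W)

    # Route 2: channels-last (N, H, W, C), then back to channels-first
    y_nhwc = fc(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    # Route 3: a 1x1 convolution carrying the same weights and bias
    conv = nn.Conv2d(128, 128, kernel_size=1)
    conv.weight.copy_(fc.weight.view(128, 128, 1, 1))         # (out, in) -> (out, in, 1, 1)
    conv.bias.copy_(fc.bias)
    y_conv = conv(x)

print(torch.allclose(y_linear, y_nhwc, atol=1e-5))   # True
print(torch.allclose(y_linear, y_conv, atol=1e-5))   # True
```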