Converting Keras code to PyTorch

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import keras
import keras.layers as layers 
from keras.models import Model
from keras.layers.core import Lambda
import encoder_models as EM
import cv2
import numpy as np


def GlobalAveragePooling2D_r(f):
    # Masked global average pooling: weight x by the mask f, average over the
    # spatial axes, and broadcast the result back to the original spatial size.
    def func(x):
        repc = int(x.shape[4])
        m    = keras.backend.repeat_elements(f, repc, axis=4)    # repeat mask across channels
        x    = layers.multiply([x, m])                           # apply mask elementwise
        repx = int(x.shape[2])
        repy = int(x.shape[3])
        x    = (keras.backend.sum(x, axis=[2, 3], keepdims=True) /
                keras.backend.sum(m, axis=[2, 3], keepdims=True))
        x    = keras.backend.repeat_elements(x, repx, axis=2)    # broadcast back over height
        x    = keras.backend.repeat_elements(x, repy, axis=3)    # ... and width
        return x
    return Lambda(func)

def Rep_mask(f):
    # Repeat the input f times along the time axis (axis 1).
    def func(x):
        x = keras.backend.repeat_elements(x, f, axis=1)
        return x
    return Lambda(func)
    
def common_representation(x1, x2):
    # Tile x2 along the time axis of x1, concatenate along channels, then apply
    # a shared Conv-BN-ReLU block to every time step.
    repc = int(x1.shape[1])
    x2   = layers.Reshape(target_shape=(1, np.int32(x2.shape[1]), np.int32(x2.shape[2]), np.int32(x2.shape[3])))(x2)
    x2   = Rep_mask(repc)(x2)
    x    = layers.concatenate([x1, x2], axis=4)
    x    = layers.TimeDistributed(layers.Conv2D(128, 3, padding='same', kernel_initializer='he_normal'))(x)
    x    = layers.TimeDistributed(layers.BatchNormalization(axis=3))(x)
    x    = layers.TimeDistributed(layers.Activation('relu'))(x)
    return x

I’m new to PyTorch and I’m having some problems converting this Keras code to PyTorch. I have converted half of the code, but I’m stuck on this part. Any help would be greatly appreciated.

Are you referring to the TimeDistributed layer? AFAIK there have been a few previous threads on how to implement this, e.g., Timedistributed CNN - #2 by ilyes.
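One common way to emulate TimeDistributed in PyTorch is to fold the time dimension into the batch dimension before applying the wrapped module and unfold it afterwards. A minimal sketch (the channel counts and sizes below are made up for illustration):

import torch
import torch.nn as nn

class TimeDistributed(nn.Module):
    # Applies `module` to every time step by folding time into the batch dim.
    # Assumes input shaped (batch, time, C, H, W).
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        b, t = x.shape[0], x.shape[1]
        y = self.module(x.reshape(b * t, *x.shape[2:]))  # (B*T, C, H, W) -> (B*T, C', H', W')
        return y.reshape(b, t, *y.shape[1:])             # back to (B, T, C', H', W')

# e.g. for the Keras TimeDistributed(Conv2D(128, 3, padding='same')) line
td_conv = TimeDistributed(nn.Conv2d(256, 128, kernel_size=3, padding=1))
out = td_conv(torch.randn(2, 5, 256, 32, 32))  # -> torch.Size([2, 5, 128, 32, 32])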

I got the intuition for the TimeDistributed layer, but I’m actually stuck on layers.multiply(), keras.backend.sum, and the other backend ops.

Is layers.multiply() an elementwise multiplication? That corresponds to torch.mul. Similarly, if keras.backend.sum is a reduction across some axes, that corresponds to torch.sum.
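For example, with a (batch, time, channels, H, W) tensor (PyTorch is channels-first, so the spatial axes become 3 and 4 instead of Keras’s 2 and 3), the two ops would look roughly like this:

import torch

x = torch.randn(2, 5, 128, 32, 32)             # (batch, time, C, H, W), sizes made up
m = torch.rand(2, 5, 128, 32, 32)              # mask already repeated across channels
prod = torch.mul(x, m)                         # elementwise, like layers.multiply([x, m])
s = torch.sum(prod, dim=(3, 4), keepdim=True)  # like keras.backend.sum(..., axis=[2, 3], keepdims=True)
print(s.shape)                                 # torch.Size([2, 5, 128, 1, 1])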

And does keras.backend.repeat_elements correspond to torch.repeat() or torch.repeat_interleave()?

I believe it should correspond to torch.repeat_interleave, as the docs say it is like numpy’s repeat.
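The difference is easy to see on a small tensor: Tensor.repeat tiles the whole tensor, while torch.repeat_interleave repeats each element, like np.repeat:

import torch

t = torch.tensor([1, 2, 3])
print(t.repeat(2))                    # tensor([1, 2, 3, 1, 2, 3]) -- tiles the whole tensor
print(torch.repeat_interleave(t, 2))  # tensor([1, 1, 2, 2, 3, 3]) -- repeats each element, like np.repeat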

I have tried torch.repeat_interleave() on a rank-4 tensor shaped [batch, channel, height, width] along axis=1 (channel), but the output tensor is 1D.

I think the default behavior is to use the flattened input unless a dimension is passed.
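For example:

import torch

A = torch.randn(5, 3, 224, 224)
print(torch.repeat_interleave(A, 4).shape)  # torch.Size([3010560]) -- flattened because no dim was given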
Can you share a code snippet which demonstrates your issue?

Okay, this worked similarly to keras.backend.repeat_elements():

import torch

A = torch.randn([5, 3, 224, 224])
print(A.shape)   # torch.Size([5, 3, 224, 224])

B = torch.repeat_interleave(A, 4, dim=1)
print(B.shape)   # torch.Size([5, 12, 224, 224])

I was probably missing something earlier.
Thanks!

Okay, one last question: how would I write the function below in a PyTorch way?

def GlobalAveragePooling2D_r(f):
    def func(x):
        repc = int(x.shape[4])
        m    = keras.backend.repeat_elements(f, repc, axis=4)
        x    = layers.multiply([x, m])
        repx = int(x.shape[2])
        repy = int(x.shape[3])
        x    = (keras.backend.sum(x, axis=[2, 3], keepdims=True) /
                keras.backend.sum(m, axis=[2, 3], keepdims=True))
        x    = keras.backend.repeat_elements(x, repx, axis=2)
        x    = keras.backend.repeat_elements(x, repy, axis=3)
        return x
    return Lambda(func)

If you mean the organization of the operator invocations into Modules, I think a good reference is some of the implementations in torchvision, e.g., the ResNet implementation here:
https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#resnet18
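For the masked pooling itself, a rough sketch of how it could be organized as a Module, assuming (batch, time, channels, H, W) tensors and a mask f of shape (batch, time, 1, H, W); note that broadcasting via expand_as replaces the explicit repeat_elements calls:

import torch
import torch.nn as nn

class MaskedGlobalAvgPool2d(nn.Module):
    # Rough counterpart of GlobalAveragePooling2D_r for (B, T, C, H, W) tensors,
    # with a mask f of shape (B, T, 1, H, W).
    def forward(self, x, f):
        m = f.expand_as(x)                                # broadcast mask across channels
        num = torch.sum(x * m, dim=(3, 4), keepdim=True)  # masked sum over H, W
        den = torch.sum(m, dim=(3, 4), keepdim=True)      # mask area
        avg = num / den                                   # masked average (assumes the mask is not all zeros)
        return avg.expand_as(x)                           # broadcast back to the spatial size

pool = MaskedGlobalAvgPool2d()
x = torch.randn(2, 5, 128, 32, 32)
f = torch.rand(2, 5, 1, 32, 32)
print(pool(x, f).shape)                                   # torch.Size([2, 5, 128, 32, 32])

Broadcasting avoids materializing the repeated tensors, but you could also use repeat_interleave exactly as in the Keras version if you prefer a one-to-one translation.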