PyTorch equivalent of the Keras/TensorFlow Deep Ranking model

Hi,
I’m trying to implement the given network (Deep Ranking) in PyTorch.


Here is the Keras/TensorFlow implementation of it.

def convnet_model_():
    """Build the deep branch: a VGG16 backbone with a dense embedding head.

    The backbone is created with ``weights=None`` and ``include_top=False``.
    Its output is global-average-pooled, passed through two
    ``Dense(4096, relu)`` + ``Dropout(0.6)`` stages and finally L2-normalised
    along axis 1.

    Returns:
        A Keras ``Model`` mapping ``vgg_model.input`` to the normalised
        embedding.
    """
    vgg_model = VGG16(weights=None, include_top=False)
    x = vgg_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(4096, activation='relu')(x)
    x = Dropout(0.6)(x)
    x = Dense(4096, activation='relu')(x)
    x = Dropout(0.6)(x)
    # Normalise the tensor the layer actually receives. The lambda previously
    # ignored its own argument and closed over the outer ``x``, which only
    # worked because ``x`` had not been rebound yet on the first call.
    x = Lambda(lambda x_: K.l2_normalize(x_, axis=1))(x)
    convnet_model = Model(inputs=vgg_model.input, outputs=x)
    return convnet_model

def deep_rank_model():
    """Assemble the three-input ranking network.

    Two shallow branches (strided 8x8 convolution -> max-pool -> flatten ->
    L2 normalisation) are concatenated with each other and then with the
    output of ``convnet_model_()``. The merged vector goes through
    ``Dense(4096)`` and a final L2 normalisation along axis 1.

    Returns:
        A Keras ``Model`` whose inputs are, in order, the first shallow input,
        the second shallow input and the deep branch input.
    """
    deep_branch = convnet_model_()

    def shallow_branch(conv_stride, pool_size, pool_stride):
        # Layers are created in the order input -> conv -> pool -> flatten ->
        # normalise, so construction order is the same for both branches.
        branch_input = Input(shape=(224, 224, 3))
        features = Conv2D(96, kernel_size=(8, 8), strides=conv_stride, padding='same')(branch_input)
        features = MaxPool2D(pool_size=pool_size, strides=pool_stride, padding='same')(features)
        features = Flatten()(features)
        features = Lambda(lambda x: K.l2_normalize(x, axis=1))(features)
        return branch_input, features

    first_input, first_feat = shallow_branch((16, 16), (3, 3), (4, 4))
    second_input, second_feat = shallow_branch((32, 32), (7, 7), (2, 2))

    # Shallow features are merged first, then joined with the deep embedding.
    shallow_merged = concatenate([first_feat, second_feat])
    all_merged = concatenate([shallow_merged, deep_branch.output])

    embedding = Dense(4096)(all_merged)
    normalised = Lambda(lambda x: K.l2_normalize(x, axis=1))(embedding)

    return Model(
        inputs=[first_input, second_input, deep_branch.input],
        outputs=normalised,
    )

I’m confused about multiple things.

  1. How do I merge (concatenate) layers in PyTorch the way it is done in Keras?
  2. How can I provide the same input image to 3 different networks?
  3. Are there any other mistakes I’m making while implementing this network in PyTorch?
    Also, could you please walk me through a PyTorch implementation of the same network? I’ve written a few lines of code so far, using ResNet instead of VGG:
# Deep branch: a torchvision ResNet-50 constructed with pretrained=True.
model_conv = torchvision.models.resnet50(pretrained=True)
# Replace the final fully connected layer with one that keeps the same input
# width but produces 2000 outputs.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs,2000)
class Network(nn.Module):
    """Work-in-progress PyTorch port of the ranking network.

    Owns one convolution and one max-pool layer; the deep branch is the
    module-level ``model_conv``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=96, kernel_size=8, stride=16, padding=1
        )
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=4, padding=1)

    def _shallow(self, image):
        """Run conv -> max-pool -> L2 normalisation (p=2, dim=1) on ``image``."""
        pooled = self.maxpool1(self.conv1(image))
        return F.normalize(pooled, p=2, dim=1)

    def forward(self, x, y, z):
        """Return ``model_conv(x)``.

        NOTE(review): ``y`` and ``z`` are pushed through the same shallow
        layers, but their results are discarded; only the deep branch output
        is returned. Nothing is concatenated yet.
        """
        deep = model_conv(x)
        self._shallow(y)
        self._shallow(z)
        return deep
# Instantiate the network; no forward pass is run here.
test = Network()

I hope this helps

class ShallowConv(nn.Module):
    """Shallow branch: resize -> 8x8 convolution -> max-pool -> flatten.

    Args:
        size: Target spatial size handed to ``nn.functional.interpolate``;
            the input is resampled to this size before the convolution.
        mp_k_size: Kernel size of the max-pooling layer.
    """

    def __init__(self, size, mp_k_size):
        super(ShallowConv, self).__init__()
        self.size = size
        self.conv = nn.Conv2d(3, 96, 8, stride=(4, 4))
        self.maxp = nn.MaxPool2d(mp_k_size)
        # Use the built-in layer: a bare ``Flatten`` is not defined in this
        # snippet and would raise NameError on construction.
        self.flat = nn.Flatten()

    def forward(self, x):
        """Return the flattened feature map, one row per sample in ``x``."""
        x = nn.functional.interpolate(x, size=self.size)
        x = self.conv(x)
        x = self.maxp(x)
        return self.flat(x)

class DeepRank(nn.Module):
    """Concatenate L2-normalised features from shallow branches and a deep model.

    Args:
        deep_model: Module applied to the input to produce the deep feature.
        shallow_nets: Iterable of modules, each applied to the same input.
            They are registered through ``nn.ModuleList``.

    This module only normalises and concatenates; no projection layer is
    applied to the merged vector.
    """

    def __init__(self, deep_model, shallow_nets):
        super(DeepRank, self).__init__()
        self.deep_model = deep_model
        self.shallow_nets = nn.ModuleList(shallow_nets)

    def forward(self, x):
        """Return shallow features followed by the deep feature, joined on dim 1.

        Every feature is L2-normalised along dim 1 before concatenation.
        Assumes each sub-module returns a (batch, features) tensor, since
        ``torch.cat(..., dim=1)`` needs matching remaining dims - TODO confirm
        for the chosen deep model.
        """
        deep_feat = self.deep_model(x)
        feats = [net(x) for net in self.shallow_nets]
        # The deep feature used to be computed and then dropped; include it so
        # the deep model actually contributes to the output.
        feats.append(deep_feat)

        # A fresh list is fine here: normalise each feature independently.
        normed = [F.normalize(feat, p=2, dim=1) for feat in feats]
        return torch.cat(normed, dim=1)

DeepRank(models.resnet34(), [...Shallow Models...])

1 Like