Problem adding a new layer in the middle of a model

I wanted to add a new conv layer in the middle of ResNet-50. I was able to add it and to run inference with the new model. I initialized the weights of the new layer to ones so that there would be no difference between the outputs of the new and the original model: the kernel size is 1, every weight is 1, and the bias is 0.
I checked the outputs of both the new and the original model right before the layer I added, and they are the same.
But once that output passes through the new layer, they become different. I am not sure why, and I haven't been able to fix it.

I also created a standalone layer identical to the one I added to the model (same attributes and initialization), passed that output through it, and the output stayed the same.
I don't understand why this doesn't happen inside the model. Outside the model it works fine.

Could anyone please help me with this?

Could you post the definition of the new conv layer, please?
Note that even if you set the kernel size to 1 and fill all weights with 1s, the output might still differ, since the filter kernel sums over the in_channels dimension by default (or are you using a depthwise conv?).
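To make the channel summation concrete: for a 1x1 convolution without bias, each output channel is a weighted sum over the input channels at every pixel, out[o] = sum_i W[o, i] * x[i]. With all weights set to 1, every output channel therefore equals the sum of all input channels. A minimal check, assuming 4 channels purely for readability:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1x1 conv with 4 input and 4 output channels, no bias, all weights set to 1
conv = nn.Conv2d(4, 4, kernel_size=1, bias=False)
with torch.no_grad():
    conv.weight.fill_(1.0)

x = torch.randn(1, 4, 3, 3)
out = conv(x)

# Every output channel equals the per-pixel sum over the input channels
channel_sum = x.sum(dim=1, keepdim=True)  # shape (1, 1, 3, 3)
print(torch.allclose(out, channel_sum.expand_as(out), atol=1e-6))  # True
print(torch.allclose(out, x))  # False: the layer does not pass its input through
```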

nn.Conv2d(2048, 2048, kernel_size=[1, 1], stride=(1, 1), bias=False)
I added this layer to the model and set its weights with

with torch.no_grad():
  self.layer.weight.fill_(1.0)

in the __init__ method, right after defining the layer.

But the thing is: when I took the outputs up to the previous layer in both the original and the new network, they were the same,
but when I passed them through this layer, they changed.

I used this

for k,v in model_dict.items():
  if(k in trained_original_state_dict):
    model_dict[k] = trained_original_state_dict[k]

to copy the weights from one model to the other,
i.e. from the pretrained model to the new model.
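As an aside, the manual copy loop above can likely be replaced by `load_state_dict(..., strict=False)`, which loads every key that exists in both state dicts and reports the rest. A sketch, assuming two hypothetical toy models, `Original` and `WithExtraLayer`, that differ only by an inserted 1x1 conv named `extra`:

```python
import torch
import torch.nn as nn

class Original(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pool -> (N, 8)
        return self.fc(x)

class WithExtraLayer(Original):
    def __init__(self):
        super().__init__()
        # Hypothetical inserted layer with no counterpart in Original
        self.extra = nn.Conv2d(8, 8, kernel_size=1, bias=False)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.extra(x)
        x = x.mean(dim=(2, 3))
        return self.fc(x)

original = Original()
new_model = WithExtraLayer()

# strict=False copies all matching keys and reports missing/unexpected ones
result = new_model.load_state_dict(original.state_dict(), strict=False)
print(result.missing_keys)     # ['extra.weight']
print(result.unexpected_keys)  # []
```

Checking `missing_keys` is a cheap way to confirm that only the new layer was left uninitialized by the copy.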

I got the model from here:

http://www.robots.ox.ac.uk/~vgg/data/vgg_face2/models/pytorch/resnet50_128_pytorch.tar.gz

I tried to reproduce the same issue with a smaller model, so I took LeNet-5:

import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet5(nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.ac1 = nn.Tanh()
        self.pool1 = nn.AvgPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.ac2 = nn.Tanh()
        self.pool2 = nn.AvgPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1)
        self.ac3 = nn.Tanh()
        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.ac4 = nn.Tanh()
        self.fc2 = nn.Linear(in_features=84, out_features=n_classes)


    def forward(self, x):
        x = self.conv1(x)
        x = self.ac1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.ac2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.ac3(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.ac4(x)
        x = self.fc2(x)
        return x

my_model = LeNet5(5)
ran = torch.randn(1,1,32,32)
my_model(ran)
my_model_state_dict = my_model.state_dict()

In the second model I added a conv4 layer:

class LeNet5_new(nn.Module):

    def __init__(self, n_classes):
        super(LeNet5_new, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.ac1 = nn.Tanh()
        self.pool1 = nn.AvgPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.ac2 = nn.Tanh()
        self.pool2 = nn.AvgPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1)
        self.ac3 = nn.Tanh()
        self.conv4 = nn.Conv2d(in_channels=120, out_channels=120, kernel_size=1, stride=1, bias=False)
        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.ac4 = nn.Tanh()
        self.fc2 = nn.Linear(in_features=84, out_features=n_classes)

        with torch.no_grad():
          self.conv4.weight.fill_(1.0)


    def forward(self, x):
        x = self.conv1(x)
        x = self.ac1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.ac2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.ac3(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.ac4(x)
        x = self.fc2(x)
        return x

my_model_new = LeNet5_new(5)

my_model_new_state_dict = my_model_new.state_dict()

# Below is the code to copy the weights from old state-dict to the new one
for k,v in my_model_new_state_dict.items():
  if(k in my_model_state_dict):
    my_model_new_state_dict[k] = my_model_state_dict[k]

# my_model_new_state_dict['conv4.weight']  # Checked weights are all 1s for that layer

# Loaded the State-dict
my_model_new.load_state_dict(my_model_new_state_dict)

# Passed the same input as for the earlier model, but the outputs are different
my_model_new(ran)

I also added hooks to check whether the outputs up to ac3 are the same for both models, and they are.
The change only happens when the output passes through that particular layer.
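A hook-based comparison like the one described could look roughly like this. It is a sketch with hypothetical tiny `nn.Sequential` stand-ins for LeNet5 and LeNet5_new: the `Tanh` plays the role of ac3 and the inserted all-ones 1x1 conv plays the role of conv4. The activations before the inserted layer match, while the layer's own output does not equal its input:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the original model
original = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5),  # index 0
    nn.Tanh(),                       # index 1, plays the role of ac3
    nn.Flatten(),
    nn.Linear(6 * 28 * 28, 5),
)

extra = nn.Conv2d(6, 6, kernel_size=1, bias=False)  # plays the role of conv4
with torch.no_grad():
    extra.weight.fill_(1.0)

# Same (deep-copied) weights, with the extra conv inserted after the Tanh
new = nn.Sequential(
    copy.deepcopy(original[0]),
    copy.deepcopy(original[1]),
    extra,
    copy.deepcopy(original[2]),
    copy.deepcopy(original[3]),
)

captured = {}

def save_output(name):
    # Forward hooks are called as hook(module, inputs, output)
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

original[1].register_forward_hook(save_output("original_tanh"))
new[1].register_forward_hook(save_output("new_tanh"))
new[2].register_forward_hook(save_output("new_extra"))

x = torch.randn(1, 1, 32, 32)
out_original = original(x)
out_new = new(x)

print(torch.allclose(captured["original_tanh"], captured["new_tanh"]))  # True
print(torch.allclose(captured["new_tanh"], captured["new_extra"]))      # False
print(torch.allclose(out_original, out_new))                            # False
```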

I also tried to see whether it happens when I create the layer separately:

md = nn.Conv2d(1, 1, kernel_size=[1, 1], stride=(1, 1), bias=False)

with torch.no_grad():
  md.weight.fill_(1.0)

Then I created a random input:

inp = torch.randn(1, 1, 2, 2)
inp
tensor([[[[-0.5713,  0.0483],
          [ 0.3396,  0.9315]]]])

and passed it through that layer:

ap = md(inp)
ap
tensor([[[[-0.5713,  0.0483],
          [ 0.3396,  0.9315]]]], grad_fn=<MkldnnConvolutionBackward>)

The output was the same as the input.
Passing through that layer had no effect.

How can I achieve the same thing when I add the layer to the model?

OK, I get it now.
I think if we set the weights like one-hot encodings, the output will be the same as the input, since each output channel will then pick out only its corresponding input channel, and we can do this for all the channels.
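The one-hot idea can be sketched as follows: set `weight[o, i, 0, 0] = 1` when `o == i` and 0 otherwise, which turns the 1x1 conv into an identity mapping. PyTorch's `torch.nn.init.dirac_` fills a conv weight with exactly this pattern; the manual `torch.eye` construction is shown for comparison. The channel count of 8 is an arbitrary assumption (the ResNet-50 case would use 2048):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels = 8  # assumption for illustration; 2048 in the ResNet-50 case

conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
with torch.no_grad():
    # Option 1: built-in Dirac (identity) initialization
    nn.init.dirac_(conv.weight)

# Option 2, equivalent for a 1x1 kernel: one-hot rows of an identity matrix,
# reshaped to (out_channels, in_channels, 1, 1)
manual = torch.eye(channels).view(channels, channels, 1, 1)

print(torch.equal(conv.weight, manual))  # True

x = torch.randn(2, channels, 5, 5)
print(torch.allclose(conv(x), x))  # True: the layer now passes its input through
```

Unlike a grouped convolution, this keeps a full weight matrix, so the layer can still learn to mix channels if it is trained afterwards.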

The problem with your small example is that you are using in_channels=out_channels=1, in which case the output is indeed the same as the input.
However, as explained in my previous post, this won’t be the case for multiple input channels, since the convolution sums over the in_channels dimension:

md = nn.Conv2d(3, 3, kernel_size=[1, 1], stride=(1, 1), bias=False)
with torch.no_grad():
  md.weight.fill_(1.0)

x = torch.ones(1, 3, 4, 4)
out = md(x)
print(out)
> tensor([[[[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.]],

         [[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.]],

         [[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.]]]], grad_fn=<MkldnnConvolutionBackward>)

You could use a grouped convolution instead (but it might not fit your use case):

md = nn.Conv2d(3, 3, kernel_size=[1, 1], stride=(1, 1), bias=False, groups=3)
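With `groups=3`, each output channel only sees its own input channel, so the weight has shape `(3, 1, 1, 1)`, and filling it with ones gives an identity mapping. A short check of that claim, assuming the same all-ones initialization as in the earlier examples:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Depthwise 1x1 conv: groups == in_channels, so each channel is convolved separately
md = nn.Conv2d(3, 3, kernel_size=[1, 1], stride=(1, 1), bias=False, groups=3)
print(md.weight.shape)  # torch.Size([3, 1, 1, 1])

with torch.no_grad():
    md.weight.fill_(1.0)

x = torch.randn(1, 3, 4, 4)
out = md(x)
print(torch.allclose(out, x))  # True: no summation across channels
```

The trade-off is likely why it "might not fit your use case": a depthwise 1x1 conv can only learn a per-channel scale, not any mixing between channels.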