Identity Convolution Weights for 3 Channel Image

Hi,
I wanted to initialize a Conv2D kernel that returns the same image after the image is passed through it. This is the code I am currently using:

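import numpy as np
import torch
import torch.nn as nn

# conv_layer and img_tensor are defined elsewhere (not shown in the post)
wts = np.zeros((3,3,3,3))
nn.init.dirac_(torch.from_numpy(wts))
with torch.no_grad():
  conv_layer.weight = nn.Parameter(torch.tensor(wts,dtype=torch.float))
new_img = conv_layer(img_tensor).detach()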

However, whenever I run the image through the conv layer, the new image is always distorted at certain pixels, and the means of the two images also differ. Could you guide me on the best way to create a convolution filter whose weights return the same image? I know that for the single-channel case, the weights would simply be

[0, 0, 0
 0, 1, 0
 0, 0, 0]

but I’m unsure what they should be for a 3-channel image.
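As a quick check of the single-channel case described above, the 3x3 kernel with a 1 in the center can be passed directly to `torch.nn.functional.conv2d`. With `padding=1` the output keeps the input's height and width, and the functional call adds no bias. The batch and image sizes below are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F

# Single-channel identity kernel, shape (out_channels=1, in_channels=1, 3, 3)
kernel = torch.tensor([[0., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.]]).view(1, 1, 3, 3)

# Arbitrary grayscale batch: (batch, channels=1, height, width)
img = torch.randn(2, 1, 32, 32)

# padding=1 keeps the spatial size for a 3x3 kernel; no bias term is involved
out = F.conv2d(img, kernel, padding=1)

print(out.shape)  # torch.Size([2, 1, 32, 32])
print((out - img).abs().max())
```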

You would need to use a grouped convolution, where the posted kernel is repeated along dim0 (the out_channels dimension), as seen here:

import torch
import torch.nn as nn

# setup
in_channels = out_channels = 7

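# init: single-channel identity (Dirac) kernel of shape (1, 1, 3, 3)
wts = torch.zeros(1, 1, 3, 3)
nn.init.dirac_(wts)
# repeat along dim0 -> (out_channels, 1, 3, 3), one kernel per group
wts = wts.repeat(out_channels, 1, 1, 1)

# groups=out_channels: each output channel only sees its matching input channel
conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=3, bias=False, padding=1, groups=out_channels)
with torch.no_grad():
  conv_layer.weight.copy_(wts)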

# test
batch_size = 16
h, w = 224, 224
x = torch.randn(batch_size, in_channels, h, w)
new_img = conv_layer(x)

print((new_img - x).abs().max())
# tensor(0., grad_fn=<MaxBackward1>)
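For comparison, here is a sketch of a dense (non-grouped) variant that reuses the question's `nn.init.dirac_` call on a `(3, 3, 3, 3)` weight. `nn.init.dirac_` sets `weight[i, i, 1, 1] = 1` for each channel `i` and zeros everything else, so each output channel copies its matching input channel, provided the layer uses `padding=1` and `bias=False`. One likely cause of the shifted means in the question (an assumption, since the definition of `conv_layer` isn't shown) is `nn.Conv2d`'s default `bias=True`, which adds a randomly initialized offset to every output channel; the second layer below illustrates that effect. The channel count and input size are arbitrary.

```python
import torch
import torch.nn as nn

channels = 3

# Dense identity kernel: weight[i, i, 1, 1] = 1, all other entries 0
dense = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
nn.init.dirac_(dense.weight)

# Same kernel, but with the default bias=True (bias stays randomly initialized)
with_bias = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
nn.init.dirac_(with_bias.weight)

x = torch.randn(4, channels, 64, 64)

with torch.no_grad():
    out_dense = dense(x)
    out_bias = with_bias(x)

# Max deviation from the input for the bias-free layer
print((out_dense - x).abs().max())

# Per-channel mean shift from the bias; approximately equal to with_bias.bias
print(out_bias.mean(dim=(0, 2, 3)) - x.mean(dim=(0, 2, 3)))
print(with_bias.bias.detach())
```

Both the grouped and the dense version produce an identity mapping here; the grouped version stores one 3x3 kernel per channel, while the dense version stores a full `(C, C, 3, 3)` weight that is zero off the diagonal.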