I am trying to break a conv2d into two smaller convs along the out_channels dimension.
The outputs of the two smaller convs are concatenated along out_channels to reconstruct the original output.
I copy the corresponding weight slices from the original conv into the two smaller ones and then compare the outputs and gradients. The outputs match exactly, but the weight gradients differ slightly (on the order of 2e-7). Although this error seems small, over a long training run (e.g. ResNet-18 on CIFAR-10) it causes the network to diverge.
Here is a minimal script to reproduce. The printed diff between the grads is not zero.
import torch
import torch.nn as nn

device = 'cuda'
c_out = 513
c_temp = 1

# Original conv and its two-way split along out_channels.
conv = nn.Conv2d(2, c_out, kernel_size=2, bias=False).to(device)
conv1 = nn.Conv2d(2, c_temp, kernel_size=2, bias=False).to(device)
conv2 = nn.Conv2d(2, c_out - c_temp, kernel_size=2, bias=False).to(device)

# Copy the matching weight slices from the original conv.
with torch.no_grad():
    conv1.weight[:] = conv.weight[:c_temp]
    conv2.weight[:] = conv.weight[c_temp:]

inp = torch.rand(1, 2, 3, 3).to(device)

out = conv(inp)
out1 = conv1(inp)
out2 = conv2(inp)
cat = torch.cat((out1, out2), dim=1)

grad = torch.ones(1, c_out, 2, 2).to(device)
cat.backward(grad)
out.backward(grad)

# Not exactly zero, even though out == cat.
print(conv.weight.grad[:c_temp] - conv1.weight.grad)
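If the mismatch is just float32 accumulation noise (the 513-channel and 1-channel convs can be dispatched to different kernels with different reduction orders), then repeating the comparison in float64 should shrink the diff by many orders of magnitude. Here is a sketch of that check, run in double precision on the CPU (not part of the original script; the CPU/float64 choice is my assumption, made to isolate precision effects):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c_out, c_temp = 513, 1

# Same split-conv setup as above, but in float64 to isolate precision effects.
conv = nn.Conv2d(2, c_out, kernel_size=2, bias=False).double()
conv1 = nn.Conv2d(2, c_temp, kernel_size=2, bias=False).double()
conv2 = nn.Conv2d(2, c_out - c_temp, kernel_size=2, bias=False).double()

with torch.no_grad():
    conv1.weight[:] = conv.weight[:c_temp]
    conv2.weight[:] = conv.weight[c_temp:]

inp = torch.rand(1, 2, 3, 3, dtype=torch.float64)
grad = torch.ones(1, c_out, 2, 2, dtype=torch.float64)

torch.cat((conv1(inp), conv2(inp)), dim=1).backward(grad)
conv(inp).backward(grad)

# Expected to be far below the ~2e-7 seen in float32.
diff = (conv.weight.grad[:c_temp] - conv1.weight.grad).abs().max()
print(diff)
```

If the diff collapses here, the float32 discrepancy is numerical rather than a bug in the split, and an exact-zero comparison is too strict a test.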