Torch.cat() and nn.Linear()

In an nn.Module, with:

        self.fc1a = nn.Linear(512+2*num_attrs, 512)
        self.fc1b = nn.Linear(512, 512)
        self.fc1c = nn.Linear(2*num_attrs, 512, bias=False)

I would expect both

        xa = F.relu(self.fc1a(torch.cat((x, 0.0*xb), 1)))

and

        xa = F.relu(self.fc1b(x)+self.fc1c(0.0*xb))

to give the same results as

        xa = F.relu(self.fc1b(x))

but only the second gives the same result, while the first consistently leads to a performance degradation. Both statements involve the (currently) unused parameters in the same way. Is there something wrong with the first statement? (x and xb are of size 512 and 2*num_attrs respectively in the non-batch dimension.) The PyTorch version is 1.5.0 and the optimizer is Adam.
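
For reference, here is a small standalone sketch of why I expect the first statement to behave like just another Linear(512, 512) applied to x (the batch size and data below are dummy values, not my actual model):

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        num_attrs = 312                       # my value; any positive number works for this check
        x  = torch.randn(4, 512)              # dummy batch
        xb = torch.randn(4, 2 * num_attrs)

        fc1a = nn.Linear(512 + 2 * num_attrs, 512)

        # With xb zeroed, only the first 512 input columns of fc1a contribute,
        # so the cat() version is a Linear(512, 512) on x with a different init.
        out_cat   = F.relu(fc1a(torch.cat((x, 0.0 * xb), 1)))
        out_slice = F.relu(x @ fc1a.weight[:, :512].t() + fc1a.bias)
        print(torch.allclose(out_cat, out_slice, atol=1e-5))  # True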

If your num_attrs is big, fc1a's default weight initialization may be the issue, as its output variance is reduced compared to fc1b's.
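
Roughly what I mean, as a quick check (num_attrs = 256 here is only an example; the numbers assume the default nn.Linear init, which is uniform in ±1/sqrt(fan_in)):

        import math
        import torch.nn as nn

        num_attrs = 256   # example value only
        fc1a = nn.Linear(512 + 2 * num_attrs, 512)
        fc1b = nn.Linear(512, 512)

        # Larger fan-in -> smaller default weights (and bias) for fc1a.
        print(fc1a.weight.std().item())   # ~0.018
        print(fc1b.weight.std().item())   # ~0.026

        # Since the extra 2*num_attrs inputs are zeroed, fc1a's pre-activation
        # std ends up smaller than fc1b's by roughly this factor:
        print(math.sqrt((512 + 2 * num_attrs) / 512))   # ~1.41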

Good point, thanks. I missed that adding input channels increases the fan-in (and therefore shrinks the default initialization). However, num_attrs is 312, so math.sqrt((512+2*num_attrs)/512) is just a bit less than 1.5. Trying:

        self.fc1a = nn.Linear(512+2*num_attrs, 512)
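        # scale the default init up by the fan-in ratio to match fc1b's init for fan_in = 512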
        with torch.no_grad():
            self.fc1a.weight *= math.sqrt((512+2*num_attrs)/512)
            self.fc1a.bias *= math.sqrt((512+2*num_attrs)/512)

still did not give the same result as the other two cases. While trying this, though, I eventually figured out where the difference was coming from: fc1b was also used elsewhere in the model, and the combined gradients from both uses helped, so not using it here had a negative impact. Thanks.
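
For completeness, a toy sketch of the effect I mean (layer names here are made up, not my actual model): when a layer is shared between two branches, its gradient accumulates contributions from both, so dropping it from one branch changes how it gets updated.

        import torch
        import torch.nn as nn

        shared = nn.Linear(512, 512)   # plays the role of fc1b
        other  = nn.Linear(512, 512)

        x = torch.randn(4, 512)
        loss = shared(x).sum() + other(shared(x)).sum()   # shared used in two places
        loss.backward()
        # shared.weight.grad now combines the gradients from both uses;
        # replacing the first use with a separate layer removes that signal.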