DataParallel output differs from its module

Hi,

I am trying to do multi-task learning with two classification layers on top of a single shared representation.
I believed that having two nn.Linear layers would not differ from having one nn.Linear layer and then splitting its output,
but it turns out they behave differently when used with nn.DataParallel.

Please see the code at the bottom.
<class 'Split'> has one nn.Linear and then splits its output.
<class 'TwoHeads'> has two nn.Linear layers.

What I do in the main code is to compare the outputs of

nn.DataParallel(net)

to

nn.DataParallel(net).module.

For Split, the outputs coincide.
For TwoHeads, they differ.

Here is the code:

import torch
import torch.nn as nn

class Split(nn.Module):
    def __init__(self):
        super(Split, self).__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        out = self.linear(x)
        # split the 5-dimensional output into chunks of size 3 and 2
        out1, out2 = torch.split(out, 3, dim=1)
        return out1, out2

class TwoHeads(nn.Module):
    def __init__(self):
        super(TwoHeads, self).__init__()
        self.linear1 = nn.Linear(10, 3)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        out1 = self.linear1(x)
        out2 = self.linear2(x)
        return out1, out2

if __name__ == '__main__':
    net1 = Split()
    net2 = TwoHeads()

    for net in [net1, net2]:
        net.cuda()
        net = nn.DataParallel(net, list(range(4)))  # replicate across 4 GPUs

        with torch.no_grad():
            data = torch.randn(500, 10).cuda()

            out1, out2 = net(data)         # DataParallel forward
            mod1, mod2 = net.module(data)  # underlying module called directly

            # prints 1 if the two outputs are not bitwise equal, 0 otherwise
            print(int(not torch.equal(out1, mod1)), end=' ')
            print(int(not torch.equal(out2, mod2)), end=' ')

            print((out1 - mod1).abs().max(), (out2 - mod2).abs().max())

For me, the output looks like

0 0 tensor(0., device='cuda:0') tensor(0., device='cuda:0')
0 1 tensor(0., device='cuda:0') tensor(2.3842e-07, device='cuda:0')

My questions are:

  • Is this intended?
  • If it is because the computation graph uses the same variable twice, should I always avoid 'branching' the computation graph?
  • What would be the correct way to do multi-task learning with DataParallel?

The error of ~1e-7 is most likely due to floating-point precision (for FP32 you would usually expect errors on the order of ~1e-6), so it seems your code is working fine.

import torch

x = torch.randn(10, 10, 10)
# summing the same values in a different order gives a slightly different result
sum1 = x.sum()
sum2 = x.sum(0).sum(0).sum(0)
print(sum1 - sum2)
> tensor(-3.8147e-06)
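
To connect this back to DataParallel: the batch is scattered into per-GPU chunks, the module runs on each chunk, and the results are gathered again. Here is a rough single-GPU sketch that only emulates that batch splitting (it is not what DataParallel does internally); the smaller per-chunk matmuls may be executed by different kernels than the full-batch matmul, so the gathered result is only expected to match up to floating-point precision, not bitwise.

import torch
import torch.nn as nn

# Rough emulation of DataParallel's scatter/gather on a single GPU:
# split the batch into chunks, apply the same module to each chunk,
# and concatenate the results again.
torch.manual_seed(0)
linear = nn.Linear(10, 2).cuda()
data = torch.randn(500, 10).cuda()

with torch.no_grad():
    full = linear(data)
    gathered = torch.cat([linear(chunk) for chunk in data.chunk(4, dim=0)])

print(torch.equal(full, gathered))    # may be False, depending on the kernels used
print((full - gathered).abs().max())  # any difference should be tiny (~1e-7)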

Thank you so much for your answer and the code.

I initially thought that too, but I repeated the same experiment many times and on another machine, and
the error still occurs only for the second output of <class 'TwoHeads'>.

Would it still most likely be a floating-point error?

Since the error is that low, I would still assume it's due to floating-point precision.
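
If you want to double-check, compare the outputs with a tolerance instead of bitwise equality. A minimal sketch (assuming at least two visible GPUs):

import torch
import torch.nn as nn

# torch.allclose is the appropriate comparison for FP32 outputs;
# it should report True even in cases where torch.equal does not.
model = nn.Linear(10, 2).cuda()
parallel = nn.DataParallel(model)
data = torch.randn(500, 10).cuda()

with torch.no_grad():
    out_parallel = parallel(data)
    out_module = parallel.module(data)

print(torch.allclose(out_parallel, out_module, atol=1e-6))  # expected: True

In your script you could apply the same check to out2 and mod2 from the TwoHeads run.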