[DataParallel] Get updated module attributes

It seems that DataParallel.module returns the initial version of the module passed to it, i.e. attribute updates made inside forward are lost. For example, in the following code:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import DataParallel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.tmp = 0

    def forward(self, x):
        self.tmp = x  # store the input on the module inside forward
        print("Actual:", self.tmp)


net = DataParallel(Model())

for i in range(3):
    input = Variable(torch.rand(2, 1))
    net(input)
    print("Returned:", net.module.tmp)  # read the attribute back afterwards
    print("============")

I get (the two "Actual:" prints interleave because the replicas on both GPUs print concurrently):

Actual: Actual: Variable containing:
 0.1303
[torch.cuda.FloatTensor of size 1x1 (GPU 0)]
Variable containing:
 0.9771
[torch.cuda.FloatTensor of size 1x1 (GPU 1)]


Returned: 0
============
Actual: Actual: Variable containing:
1.00000e-03 *
  5.5766
[torch.cuda.FloatTensor of size 1x1 (GPU 0)]

Variable containing:
 0.8851
[torch.cuda.FloatTensor of size 1x1 (GPU 1)]

Returned: 0
============
Actual: Actual: Variable containing:
 0.6126
[torch.cuda.FloatTensor of size 1x1 (GPU 1)]

Variable containing:
 0.4636
[torch.cuda.FloatTensor of size 1x1 (GPU 0)]

Returned: 0
============

How can I retrieve the updated values?


Apparently, something like this is not possible out of the box: DataParallel replicates the wrapped module onto each GPU at every forward call, so the attribute assignment happens on throwaway replicas while the original module behind net.module is never modified.
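A quick probe makes the replication visible (a minimal sketch, assuming at least two visible GPUs; the Probe name is just for illustration): the object executing forward is a fresh replica on each call, not the module you wrapped.

import torch
import torch.nn as nn
from torch.nn import DataParallel


class Probe(nn.Module):
    def forward(self, x):
        # id(self) identifies the replica actually running forward
        print("replica id:", id(self))
        return x


net = DataParallel(Probe())
net(torch.rand(2, 1))
print("wrapped id:", id(net.module))
# The ids printed inside forward differ from id(net.module) and
# change between calls, because replicas are rebuilt every forward.
net(torch.rand(2, 1))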

So is there a workaround for this?
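One common workaround (a minimal sketch of the pattern, not an official API): return the value from forward instead of stashing it on self. DataParallel gathers forward outputs from all replicas back onto the primary device, so the per-GPU values survive the call. (On a recent PyTorch you can also drop the Variable wrapper, as done here.)

import torch
import torch.nn as nn
from torch.nn import DataParallel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.tmp = 0

    def forward(self, x):
        self.tmp = x
        # Return the value: DataParallel gathers the outputs of all
        # replicas, concatenating tensors along dim 0.
        return self.tmp


net = DataParallel(Model())

for i in range(3):
    out = net(torch.rand(2, 1))  # gathered result from both GPUs
    print("Returned:", out)

Anything you need back from forward has to come out as a return value; in-place changes to Python attributes (or buffers) on the replicas are discarded. net.module is still the right handle for parameters, since gradients flow back to the original module, just not for ad-hoc attributes set during forward.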