DataParallel splits up input parameters

I ran into an issue related to the DataParallel class that I’m not sure how to solve. Here’s a minimal example:

import torch
import torch.nn as nn
from torch.nn import functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
    # Ordinary forward pass, kept for reference:
    # def forward(self, input, param):
    #     x = self.fc(input)
    #     return x

    def forward(self, input, param):
        # Use the externally supplied weights instead of self.fc's own
        w, b = param
        print("Shape of w: ", w.shape)
        output = F.linear(input, w, b)
        return output

input = torch.randn(10, 3).cuda()
# Extra "functional" parameters passed to forward() at call time,
# not registered on the module
w, b = torch.randn(4, 3).cuda(), torch.randn(4).cuda()
model = Model(3, 4)
model.to(device)

# Single GPU
output = model(input, [w, b])
print("Outside: input size", input.size(), "output_size", output.size())

# DataParallel
model = nn.DataParallel(model)
output = model(input, [w, b])
print("Outside: input size", input.size(), "output_size", output.size())

Let’s say I need to redefine the forward method with nn.functional calls so that every forward pass uses newly supplied parameters. The problem is that when I try to parallelize the model, DataParallel seems to split up not only the data but also my input parameters, which causes the output sizes above to differ.

This is what I get with 2 GPUs available:

Shape of w: torch.Size([4, 3])
Outside: input size torch.Size([10, 3]) output_size torch.Size([10, 4])
Shape of w: torch.Size([2, 3])
Shape of w: torch.Size([2, 3])
Outside: input size torch.Size([10, 3]) output_size torch.Size([10, 2])

How can I fix this?

I’m doing this for a MAML implementation and this seems to be the only way to do it, since load_state_dict() can’t preserve the previous computational graph.

DataParallel will scatter/split the input and the args ([w, b]) across the different devices, but it replicates model.parameters().
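
As far as I can tell, scatter only splits tensors (recursing into lists, tuples, and dicts), so if you really need to pass extra tensors unsplit, one option might be to hide them in a plain Python object and move them to the right device inside forward(). This is an untested sketch; FastWeights is just a placeholder name I made up, and the cross-device copy adds some overhead:

import torch
import torch.nn as nn
from torch.nn import functional as F

class FastWeights:
    # Plain container: as far as I know, DataParallel's scatter does not
    # recurse into arbitrary objects, so each replica should receive a
    # reference to the same, unsplit tensors.
    def __init__(self, w, b):
        self.w = w
        self.b = b

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input, params):
        # Copy the full-size weights onto this replica's device;
        # .to() is differentiable, so the graph back to the original
        # w and b should be preserved.
        w = params.w.to(input.device)
        b = params.b.to(input.device)
        return F.linear(input, w, b)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input = torch.randn(10, 3, device=device)
w = torch.randn(4, 3, device=device, requires_grad=True)
b = torch.randn(4, device=device, requires_grad=True)

model = nn.DataParallel(Model(3, 4).to(device))
output = model(input, FastWeights(w, b))
print(output.size())  # expect torch.Size([10, 4]) on one or many GPUs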

Why do you need to pass [w, b] as arguments, though? model.parameters() would include them, right?

Thank you for your response. In the MAML setting, we need two sets of parameters to be preserved, so I keep one of them as model.parameters() and apply the other through nn.functional operations.

Because the loss eventually comes from forward passes through both sets of parameters, their computational graphs need to be preserved as well, so this can’t be done with load_state_dict() or by storing them in separate models.
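
To make it concrete, here is roughly the inner step I mean (a minimal sketch, not my actual code; lr_inner, support_x, support_y and the MSE loss are placeholders):

import torch
import torch.nn.functional as F

def inner_step(model, support_x, support_y, lr_inner=0.01):
    # Loss computed with the "slow" weights stored on the module
    w, b = model.fc.weight, model.fc.bias
    loss = F.mse_loss(F.linear(support_x, w, b), support_y)
    # create_graph=True keeps the graph of this update, so the outer
    # meta-loss can backpropagate through it; this is exactly what
    # writing the fast weights back via load_state_dict() would break
    grad_w, grad_b = torch.autograd.grad(loss, (w, b), create_graph=True)
    fast_w = w - lr_inner * grad_w
    fast_b = b - lr_inner * grad_b
    return fast_w, fast_b

# fast_w, fast_b are then passed into forward() and applied with
# F.linear(), like the [w, b] argument in the example above.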

This is related to the discussion here: Gradient computation in meta-learning algorithms