I ran into an issue with the DataParallel class that I'm not sure how to solve. Here's a minimal example:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    # Original forward, which uses the module's own parameters:
    # def forward(self, input, param):
    #     x = self.fc(input)
    #     return x

    def forward(self, input, param):
        # Functional forward: the weights are supplied on every call.
        w, b = param
        print("Shape of w: ", w.shape)
        output = F.linear(input, w, b)
        return output


input = torch.randn(10, 3).to(device)
w, b = torch.randn(4, 3).to(device), torch.randn(4).to(device)
model = Model(3, 4)
model.to(device)

# Single GPU
output = model(input, [w, b])
print("Outside: input size", input.size(), "output_size", output.size())

# DataParallel
model = nn.DataParallel(model)
output = model(input, [w, b])
print("Outside: input size", input.size(), "output_size", output.size())
```

Let's say I need to redefine the forward method with nn.functional calls, so that every forward pass takes newly defined parameters as inputs. The problem is that when I parallelize the model, the DataParallel class seems to split up not only the data but also my parameters along the first dimension, which makes the two output sizes above differ.

This is what I get with 2 GPUs available:

```
Shape of w: torch.Size([4, 3])
Outside: input size torch.Size([10, 3]) output_size torch.Size([10, 4])
Shape of w: torch.Size([2, 3])
Shape of w: torch.Size([2, 3])
Outside: input size torch.Size([10, 3]) output_size torch.Size([10, 2])
```
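The shapes above are consistent with DataParallel chunking every tensor argument along dim 0, not just the batch. The sketch below roughly emulates that on the CPU with `torch.chunk` and `torch.cat`, assuming 2 devices and ignoring device placement and module replication, and it reproduces the `[2, 3]` weight slices and the `[10, 2]` output:

```python
import torch
from torch.nn import functional as F

# Rough CPU emulation of DataParallel on 2 GPUs: every tensor argument
# (including w and b) is chunked along dim 0, one chunk per device.
num_devices = 2  # assumption: matches the 2-GPU run above
input = torch.randn(10, 3)
w, b = torch.randn(4, 3), torch.randn(4)

input_chunks = input.chunk(num_devices, dim=0)  # two [5, 3] chunks
w_chunks = w.chunk(num_devices, dim=0)          # two [2, 3] chunks
b_chunks = b.chunk(num_devices, dim=0)          # two [2] chunks

# Each "replica" computes on its own slices of data AND parameters.
outputs = [F.linear(x, wi, bi) for x, wi, bi in zip(input_chunks, w_chunks, b_chunks)]
for wi in w_chunks:
    print("Shape of w: ", wi.shape)

# Results are gathered back along dim 0.
output = torch.cat(outputs, dim=0)
print("output_size", output.size())  # torch.Size([10, 2])
```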

How can I fix this?
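One possible workaround, sketched under the assumption that DataParallel splits each tensor argument along dim 0 into one chunk per entry in `device_ids`: stack one copy of each parameter per device, so that each replica receives a `[1, ...]` slice holding the full tensor, and index it with `[0]` inside `forward`. The helper `replicate_for_scatter` is hypothetical, not a PyTorch API. On a CPU-only machine DataParallel has an empty `device_ids` and simply calls the wrapped module, so the same code also runs there.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input, param):
        # Each replica receives a [1, ...] slice of the stacked copies;
        # index 0 recovers the full, unsplit parameter.
        w, b = param
        return F.linear(input, w[0], b[0])


def replicate_for_scatter(t, n):
    # Hypothetical helper: stack n copies along a new dim 0 so that the
    # dim-0 split hands exactly one full copy to each of the n replicas.
    # repeat() is differentiable, so gradients flow back to the original t.
    return t.unsqueeze(0).repeat(n, *([1] * t.dim()))


input = torch.randn(10, 3, device=device)
w = torch.randn(4, 3, device=device, requires_grad=True)
b = torch.randn(4, device=device, requires_grad=True)

model = nn.DataParallel(Model(3, 4).to(device))
n = max(1, len(model.device_ids))  # device_ids is empty on CPU-only machines

output = model(input, [replicate_for_scatter(w, n), replicate_for_scatter(b, n)])
print("Outside: input size", input.size(), "output_size", output.size())

# Gradients from all replicas are summed back into the original w and b.
output.sum().backward()
print(w.grad.shape, b.grad.shape)
```

Because `repeat()` is part of the autograd graph, this should also work when `w` and `b` are non-leaf "fast weights". The cost is one extra copy of each parameter per GPU.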

I'm doing this for a MAML implementation, and this seems to be the only way to do it, since load_state_dict() isn't able to preserve the previous computational graphs.
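For context, here is a minimal sketch of the kind of MAML inner/outer step that needs a functional forward. It assumes a single device, a plain linear model, random data, and a hypothetical `inner_lr`; none of this is from the original code. The adapted weights `fast_w`/`fast_b` are ordinary tensors computed from `w`/`b`, so passing them into `F.linear` keeps the graph intact, which is what copying them in with `load_state_dict()` would lose:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

# Meta-parameters (the "slow" weights) that the outer loop updates.
w = torch.randn(4, 3, requires_grad=True)
b = torch.zeros(4, requires_grad=True)
inner_lr = 0.1  # hypothetical inner-loop step size

x_support, y_support = torch.randn(10, 3), torch.randn(10, 4)
x_query, y_query = torch.randn(10, 3), torch.randn(10, 4)

# Inner step: create_graph=True keeps the graph of the gradient itself,
# so the outer loss can be differentiated through this update.
support_loss = F.mse_loss(F.linear(x_support, w, b), y_support)
gw, gb = torch.autograd.grad(support_loss, (w, b), create_graph=True)
fast_w, fast_b = w - inner_lr * gw, b - inner_lr * gb

# Outer step: the query loss depends on w and b *through* fast_w and fast_b.
query_loss = F.mse_loss(F.linear(x_query, fast_w, fast_b), y_query)
query_loss.backward()
print(w.grad.shape, b.grad.shape)
```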