Using DataParallel with functional API

Let’s say I have the following example (modified from the [Data parallel tutorial](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html#simple-model)):

import torch
import torch.nn as nn

class Model(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.f = torch.ones(1)

    def forward(self, input):
        output = self.fc(input) * self.f
        print(self.f)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())

        return output

wrapped in an nn.DataParallel.
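Concretely, I wrap and run it roughly as in the tutorial; the sizes and device below are placeholders, not my actual values:

input_size, output_size, batch_size = 5, 2, 30   # placeholder sizes
device = torch.device("cuda:0")

model = Model(input_size, output_size)
model = nn.DataParallel(model)   # replicates the module across the visible GPUs
model.to(device)

output = model(torch.randn(batch_size, input_size, device=device))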
Doing a forward pass then results in the following error:

RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #2 'other'

If I try to call .cuda() on the f field, it ends up on the first CUDA device, and the forward pass then fails because the tensors are on different devices:

class Model(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.f = torch.ones(1).cuda()

    def forward(self, input):
        output = self.fc(input) * self.f
        print(self.f)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())

        return output

Log:

RuntimeError: arguments are located on different GPUs at 

So how do I use DataParallel with the functional API?

Most likely self.f is not pushed to the right device, since it is neither registered as an nn.Parameter nor as a buffer (via self.register_buffer).
Use the former if self.f should require gradients, and the latter if not.
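For example, here is a minimal sketch of the buffer variant, keeping the rest of your model unchanged (switch to the commented nn.Parameter line if f should be trainable):

import torch
import torch.nn as nn

class Model(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)
        # registered as a buffer, so DataParallel moves it to each replica's device
        self.register_buffer('f', torch.ones(1))
        # or, if f should require gradients:
        # self.f = nn.Parameter(torch.ones(1))

    def forward(self, input):
        output = self.fc(input) * self.f
        return output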

The manual cuda() call inside your __init__ method won’t work, as the model will be replicated onto each GPU passed via device_ids.
If you want to manually push the tensor (not necessary in your use case), you could use:

output = self.fc(input) * self.f.to(input.device)

in your forward.

Thanks! That was exactly the problem. I am new to PyTorch and had no idea I had to register it, but it definitely makes sense; how else would PyTorch know?