Dictionary model inputs .to(device) issue

For context, my model consists of two separate NNs whose outputs are added together to give the final output. The input (generated by the dataloader) is a dictionary whose keys correspond to the NN each value needs to go to, and the values are the tensors that are fed through the corresponding NN. My issue arises when trying to send the inputs to the device (CUDA, specifically), since the .to method only works on tensors. Is there a way around this? Can I simply feed the inputs to the model (output = model(inputs)) and have the model send the tensor values to the device once it reads the dictionary? Or does that not work? Thanks!
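
To make the setup concrete, here is a minimal sketch, assuming a hypothetical `PairDataset` whose items are dicts with the hypothetical keys `"input_1"` and `"input_2"`. PyTorch's default collate function batches dicts of tensors key by key, so each batch from the `DataLoader` is again a dict whose values are batched tensors (still on the CPU).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    """Hypothetical dataset: each item is a dict with one tensor per sub-network."""

    def __init__(self, n: int = 8):
        self.x1 = torch.randn(n, 3)
        self.x2 = torch.randn(n, 3)

    def __len__(self):
        return self.x1.shape[0]

    def __getitem__(self, idx):
        return {"input_1": self.x1[idx], "input_2": self.x2[idx]}


loader = DataLoader(PairDataset(), batch_size=4)
batch = next(iter(loader))

# The default collate keeps the dict structure and stacks each value along dim 0.
print({k: tuple(v.shape) for k, v in batch.items()})
```

The batch itself is a plain dict, which is why calling `batch.to(device)` fails: only its values are tensors.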

It’s not very orthodox, but you can send them to the GPU inside the forward function. Note that this won’t be compatible with DataParallel.

Anyway, generating a dictionary seems complicated. If you have two NNs, why don’t you write the dataloader so that it loads two sets of inputs?
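
A sketch of that alternative, assuming both inputs fit in a `TensorDataset` (the names `x1`, `x2` and `TwoBranchNet` are hypothetical): each batch then arrives as a pair of plain tensors, so each one can be moved with `.to(device)` before the forward call, and the model takes two arguments instead of a dict.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TwoBranchNet(nn.Module):
    """Hypothetical model: two sub-networks whose outputs are summed."""

    def __init__(self):
        super().__init__()
        self.branch1 = nn.Linear(3, 1)
        self.branch2 = nn.Linear(3, 1)

    def forward(self, x1, x2):
        return self.branch1(x1) + self.branch2(x2)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TwoBranchNet().to(device)

x1, x2 = torch.randn(8, 3), torch.randn(8, 3)
loader = DataLoader(TensorDataset(x1, x2), batch_size=4)

for b1, b2 in loader:
    # Plain tensors, so .to(device) works directly, outside the model.
    out = model(b1.to(device), b2.to(device))
    print(out.shape)
```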


You would have to define this inside your network. A really basic example would be the following:

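```python
import torch


class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.linear1 = torch.nn.Linear(3, 1)
        self.linear2 = torch.nn.Linear(3, 1)

    def forward(self, input_dict: dict):
        # Move each input to the device of the layer that consumes it.
        out1 = self.linear1(input_dict["input_1"].to(self.linear1.weight.device))
        out2 = self.linear2(input_dict["input_2"].to(self.linear2.weight.device))

        # Make sure both outputs are on the same device before adding them.
        return out1 + out2.to(out1.device)
```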

Note: you don’t want to hardcode your devices here, since this would break if you move your network to a different device.
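
A short usage sketch of this approach, assuming CPU-side inputs under the same hypothetical keys: because every forward call reads the device from the layers, the same code runs unchanged whether the module sits on the CPU or, when one is available, on a GPU.

```python
import torch
from torch import nn


class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 1)
        self.linear2 = nn.Linear(3, 1)

    def forward(self, input_dict: dict):
        # Devices are looked up on every call, never hardcoded.
        out1 = self.linear1(input_dict["input_1"].to(self.linear1.weight.device))
        out2 = self.linear2(input_dict["input_2"].to(self.linear2.weight.device))
        return out1 + out2.to(out1.device)


inputs = {"input_1": torch.randn(4, 3), "input_2": torch.randn(4, 3)}  # CPU tensors

model = CustomModule()
out = model(inputs)  # model on CPU: the .to calls inside forward are no-ops

if torch.cuda.is_available():
    model = model.to("cuda")  # move the network; forward needs no edits
    out = model(inputs)       # inputs are copied to the GPU inside forward

print(out.shape, out.device)
```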

Also, as @JuanFMontesinos mentioned, this only works for simple NNs (getting it to work in general is possible, but requires a bit more effort).

Following up on this, since I implemented the .to() calls in the forward function: if I want to use DataParallel, does this method stop working? If so, why? Do I need to structure my data so that it is sent to the device before being fed into the forward function?


DataParallel replicates the whole model across the different GPUs. That said, the code would still work, but the .to operations would become no-ops, since each replica’s inputs would already be on the same device as its self.linear1 and self.linear2.
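
The "no-op" part can be checked directly: per the `Tensor.to` documentation, when a tensor already has the requested dtype and device, `.to` returns the tensor itself rather than a copy. A minimal sketch on the CPU, with a hypothetical single layer standing in for one replica:

```python
import torch
from torch import nn

layer = nn.Linear(3, 1)
x = torch.randn(4, 3)  # already on the same device as the layer (CPU here)

moved = x.to(layer.weight.device)
print(moved is x)  # True: no transfer, .to hands back the same tensor object

forced = x.to(layer.weight.device, copy=True)
print(forced is x)  # False: copy=True always creates a new tensor
```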

Hmm, I don’t think I quite follow. Does that mean I don’t have to worry about sending my data to the GPU, since the model is there and the inputs will just be fed in? I’m struggling to see what I would need to modify to ensure I can parallelize properly and efficiently.

The problem is that DataParallel tries to split your batch (the PyTorch tensor). This would cause the error, since it can’t handle dicts. You’d have to move the device mapping outside the model.
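
One way to move the device mapping outside the model is a small helper applied to every batch before the forward call. This is a sketch assuming batches are dicts of tensors, possibly nested inside lists or tuples; `move_to_device` is a hypothetical name, not a PyTorch API.

```python
import torch


def move_to_device(batch, device):
    """Hypothetical helper: recursively move every tensor in `batch` to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, list):
        return [move_to_device(v, device) for v in batch]
    if isinstance(batch, tuple):
        # Note: a namedtuple comes back as a plain tuple here.
        return tuple(move_to_device(v, device) for v in batch)
    return batch  # non-tensor leaves (ints, strings, ...) are returned unchanged


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"input_1": torch.randn(4, 3), "input_2": torch.randn(4, 3)}
batch = move_to_device(batch, device)
# output = model(batch)  # forward no longer needs any .to calls
print({k: v.device for k, v in batch.items()})
```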

Ahhh that makes sense now. Thank you!

Well, that’s not the main problem, @justusschock. If you set the GPU id inside the model, that overrides DataParallel’s splits, so it will throw an error saying the data is allocated on another GPU.

Yes, you are right, but this should still work, since I did not hardcode any of the devices; I referenced the devices of the processing layers instead.

Indeed, I did not check the code 🙂