Clone a variable on multiple GPUs

Hi,

In my code, I want to keep the latest input feature map to the layer, subtract it from the current input feature map, and then update the stored value.
The code is working on a single GPU, but I get the following error when running on multiple GPUs. Any thoughts?

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

import torch
import torch.nn as nn
from torch.autograd import Variable

class Subtraction(nn.Module):
    def __init__(self, channel, width, height):
        super(Subtraction, self).__init__()
        self.latest_input = Variable(torch.zeros(channel, width, height), requires_grad=False).cuda()

    def forward(self, input):
        # difference of feature maps
        output = input - self.latest_input
        self.latest_input = input.clone()
        return output

If you are using DataParallel or DistributedDataParallel, remove the .cuda() call on self.latest_input.
These wrappers push each model replica to its corresponding device for you, while you are explicitly creating this tensor on cuda:0, which can yield the device mismatch error.

PS: Variables are deprecated since PyTorch 0.4, so you can use tensors directly in newer versions. :wink:
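
A rough sketch of that change, assuming the rest of the module stays as in your snippet (the stored map is now just a plain tensor attribute):

import torch
import torch.nn as nn

class Subtraction(nn.Module):
    def __init__(self, channel, width, height):
        super(Subtraction, self).__init__()
        # plain tensor attribute: no Variable, no explicit .cuda() call
        self.latest_input = torch.zeros(channel, width, height)

    def forward(self, input):
        # difference of feature maps
        output = input - self.latest_input
        self.latest_input = input.clone()
        return output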

Thanks for your help. I’m using DataParallel. I removed the “.cuda()” call from the “self.latest_input” statement, but it seems that “input.clone()” copies the tensor on the CPU, as I get the following error.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Sorry for forgetting the second part of the right approach. :wink:

The new error is not raised by input.clone(), but should still be raised by self.latest_input, which is now on the CPU as it wasn’t properly registered.

To make sure model.to() pushes all internal parameters and buffers to the desired device, you have to register self.latest_input as an nn.Parameter if it should be trained, or as a buffer if it shouldn’t:

self.latest_input = nn.Parameter(torch.zeros(channel, width, height))
# or
self.register_buffer('latest_input', torch.zeros(channel, width, height))
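
Not your exact model, just a minimal sketch of the buffer-based version together with a hypothetical DataParallel usage (the shapes in the usage example are made up):

import torch
import torch.nn as nn

class Subtraction(nn.Module):
    def __init__(self, channel, width, height):
        super(Subtraction, self).__init__()
        # registered buffer: moved by model.to() / DataParallel, not trained
        self.register_buffer('latest_input', torch.zeros(channel, width, height))

    def forward(self, input):
        # difference of feature maps (the buffer broadcasts over the batch dim)
        output = input - self.latest_input
        # store the current input as the new reference map
        self.latest_input = input.clone()
        return output

# hypothetical usage on multiple GPUs
model = nn.DataParallel(Subtraction(64, 32, 32)).cuda()
x = torch.randn(8, 64, 32, 32, device='cuda')
out = model(x)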

Thanks @ptrblck. It worked :slightly_smiling_face:.
