Suppose I have a (fake) module that looks like this:
import torch
import torch.nn as nn

class MyLinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = torch.ones(1, 100)
        self.mask[:, :50] = 0
        self.kernel = nn.Parameter(torch.rand(1, 100))

    def forward(self, x):
        return x @ (self.kernel * self.mask)
This code is fine on the CPU but doesn’t work on the GPU, because self.mask is a plain tensor attribute: .cuda() won’t move it, since it is neither a buffer nor a parameter. Fine. So we add the following method to MyLinearModule:
[...]
    def cuda(self, device=None):
        self.mask = self.mask.cuda(device)
        return super().cuda(device)
Which gets the job done. Then we decide to train on multiple GPUs by wrapping our module in DataParallel, and we get this failure:
File "test.py", line 13, in forward
return x @ (self.kernel * self.mask)
RuntimeError: expected device cuda:1 but got device cuda:0
I could write return x @ (self.kernel * self.mask.to(self.kernel.device)), but it seems silly to copy a constant tensor to each GPU on every forward call. I feel like I’m missing something fundamental here. What’s the right way to make sure self.mask ends up on the appropriate device?
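For reference, the direction I’m currently leaning toward (not sure if it’s the intended mechanism) is register_buffer, since buffers are moved by .cuda()/.to(), replicated per device by DataParallel, and saved in the state_dict. A minimal sketch of the module rewritten that way:

```python
import torch
import torch.nn as nn

class MyLinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Registering the mask as a (non-trainable) buffer means
        # .cuda()/.to() move it along with the parameters, and
        # DataParallel replicates it to each device.
        mask = torch.ones(1, 100)
        mask[:, :50] = 0
        self.register_buffer('mask', mask)
        self.kernel = nn.Parameter(torch.rand(1, 100))

    def forward(self, x):
        return x @ (self.kernel * self.mask)
```

But I don’t know whether that is what buffers are actually for, or whether there is a better idiom for constant tensors like this.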