Suppose I have a (fake) module that looks like this:
```python
import torch
import torch.nn as nn

class MyLinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = torch.ones(1, 100)
        self.mask[:, :50] = 0
        self.kernel = nn.Parameter(torch.rand(1, 100))

    def forward(self, x):
        return x @ (self.kernel * self.mask)
```
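As a quick sanity check (my own sketch, repeating the class so the snippet runs standalone), on CPU this behaves as expected, with the masked columns contributing nothing:

```python
import torch
import torch.nn as nn

class MyLinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = torch.ones(1, 100)
        self.mask[:, :50] = 0          # zero out the first 50 columns
        self.kernel = nn.Parameter(torch.rand(1, 100))

    def forward(self, x):
        return x @ (self.kernel * self.mask)

m = MyLinearModule()
out = m(torch.rand(4, 1))              # (4, 1) @ (1, 100) -> (4, 100)
print(out.shape)                       # torch.Size([4, 100])
print(out[:, :50].abs().sum().item())  # 0.0 -- masked columns are zero
```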
This code is fine on CPU but doesn’t work on GPU, because `self.mask` is a plain CPU tensor and `.cuda()` won’t move it, since it is neither a buffer nor a parameter. Fine. So we add the following method to the module:
```python
    [...]

    def cuda(self, device=None):
        self.mask = self.mask.cuda(device)
        return super().cuda(device)
```
That gets the job done. Then we decide to train on multiple GPUs by wrapping our module in `DataParallel`, and we get this failure:
```
File "test.py", line 13, in forward
    return x @ (self.kernel * self.mask)
RuntimeError: expected device cuda:1 but got device cuda:0
```
I could write

```python
return x @ (self.kernel * self.mask.to(self.kernel.device))
```

but it seems silly to copy a constant from CPU to each GPU on every forward call. I feel like I’m missing something fundamental here.
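For completeness, here is that workaround in context (my own sketch, repeating the class so it runs standalone; on CPU `.to()` is a no-op, so the snippet runs anywhere):

```python
import torch
import torch.nn as nn

class MyLinearModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mask = torch.ones(1, 100)
        self.mask[:, :50] = 0
        self.kernel = nn.Parameter(torch.rand(1, 100))

    def forward(self, x):
        # Move the mask to wherever the kernel lives. On a GPU replica this
        # does a device-to-device copy every forward pass; on CPU it's a no-op.
        return x @ (self.kernel * self.mask.to(self.kernel.device))

m = MyLinearModule()
out = m(torch.rand(4, 1))
print(out.shape)  # torch.Size([4, 100])
```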
What’s the right way to make sure `self.mask` is on the appropriate device?