How to make sure a custom module is split onto the right GPUs?

Let’s say I make a custom module as follows:

import torch
import torch.nn as nn

class MyModule(nn.Module):

	def __init__(self):
		super(MyModule, self).__init__()
		self.t = torch.rand(10)

	def forward(self, x):
		print(self.t.device)  # For demoing my question
		return x + self.t

Now, this module is included in some other module, say my network:

class MyNetwork(nn.Module):

	def __init__(self):
		super(MyNetwork, self).__init__()
		self.mymod = MyModule()

	def forward(self, x):
		x = self.mymod(x)
		return x

When I create an instance of MyNetwork and put it on the GPU, why don’t the fields of MyModule get put on the same GPU?

For example:

device = torch.device(0)
net = MyNetwork().to(device)
out = net(torch.rand(10, device=device))

# Prints out
>> cpu
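As far as I can tell, the reason is that `Module.to()` only visits registered parameters and buffers, not plain tensor attributes like `self.t`. Here is a minimal sketch that demonstrates this on CPU only, using a dtype conversion instead of a GPU device (the `Demo` class and its attribute names are just for illustration):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        self.t = torch.rand(10)                    # plain tensor attribute
        self.register_buffer("b", torch.rand(10))  # registered buffer

demo = Demo().to(torch.float64)
print(demo.b.dtype)  # torch.float64 -- the registered buffer is converted
print(demo.t.dtype)  # torch.float32 -- the plain attribute is left untouched
```

The same mechanism applies to `.to(device)`: the buffer would move to the GPU, while `self.t` would stay on the CPU.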

This is particularly frustrating when I want to use nn.DataParallel, as I always have to make sure to copy self.t onto the right GPU inside forward (based on the input’s device).
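For reference, the per-call copy I mean looks roughly like this (a sketch; `.to()` is a no-op when the tensor is already on the input's device, and an extra copy otherwise):

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.t = torch.rand(10)

    def forward(self, x):
        # Copy the plain tensor to wherever the input lives on every call,
        # so the addition also works on each replica under DataParallel.
        return x + self.t.to(x.device)

m = MyModule()
out = m(torch.rand(10))  # runs on CPU; the same code would run on any GPU
```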

Is there a better way to do this?