Maybe there’s another way of achieving my underlying goal, but in order to create tensors of the appropriate cudaness in my modules, I’ve been doing something like:
```python
class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.torch_constr = torch

    def cuda(self):
        super().cuda()
        self.torch_constr = torch.cuda

    def forward(self, x):
        state = self.torch_constr.FloatTensor(...).zero_()
        ...
```
The problem is that the `cuda` method only seems to get called on the top-level module, not on its children, so submodules never get their `torch_constr` switched over. What are the standard way(s) to handle this?
(edit: for now I’ve gone back to doing `torch_constr = torch.cuda if x.is_cuda else torch`)
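For reference, a minimal sketch of what that workaround looks like in a module's `forward` — the class name `Foo` and the tensor shape are just placeholders:

```python
import torch
import torch.nn as nn


class Foo(nn.Module):
    def forward(self, x):
        # Pick the constructor namespace from the input's device,
        # so no cuda()-override bookkeeping is needed on the module.
        torch_constr = torch.cuda if x.is_cuda else torch
        # Fresh zeroed state matching the batch size (shape is illustrative)
        state = torch_constr.FloatTensor(x.size(0), 4).zero_()
        return state
```

Since the constructor is chosen per call, the same module works whether its input lives on CPU or GPU, with no override of `cuda()` at all.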