Maybe there’s another way of achieving my underlying goal, but in order to create tensors of the appropriate cudaness in my modules, I’ve been doing something like:
```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        # Namespace used to construct tensors; swapped to torch.cuda on .cuda()
        self.torch_constr = torch

    def cuda(self, device=None):
        self.torch_constr = torch.cuda
        return super().cuda(device)

    def forward(self, x):
        state = self.torch_constr.FloatTensor(...).zero_()
        ...
```
The problem is that my `cuda` override only seems to be called on the top-level module, and not on its child modules. What are the standard way(s) to handle this?
(edit: for now I've gone back to doing `torch_constr = torch.cuda if x.is_cuda else torch` inside `forward`.)
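For reference, here is a minimal self-contained sketch of that workaround, dispatching on the input's device inside `forward` so no `cuda` override is needed (the shape `(x.size(0), 4)` is just a placeholder for illustration; on newer PyTorch versions the `device=` keyword shown in the comment does the same job more directly):

```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def forward(self, x):
        # Pick the constructor namespace from the input, so the module
        # works on CPU and GPU without overriding .cuda()
        torch_constr = torch.cuda if x.is_cuda else torch
        state = torch_constr.FloatTensor(x.size(0), 4).zero_()
        # Equivalent on newer PyTorch: allocate directly on x's device
        state2 = torch.zeros(x.size(0), 4, device=x.device)
        return state + state2
```

This keeps the tensor "cudaness" tied to the actual input rather than to module-level state that can go stale.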