How to 'hook' into the .cuda() method, or detect that .cuda() has been called?

Maybe there’s another way to achieve my underlying goal, but to create tensors on the appropriate device (CPU or GPU) inside my modules, I’ve been doing something like this:

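```python
import torch
from torch import nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.torch_constr = torch  # namespace used to construct new tensors

    def cuda(self, device=None):
        # nn.Module.cuda() returns self; the override should too,
        # otherwise `model = Foo().cuda()` would bind None
        module = super().cuda(device)
        self.torch_constr = torch.cuda
        return module

    def forward(self, x):
        state = self.torch_constr.FloatTensor(...).zero_()
        ...
```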

The problem is that `cuda()` only seems to be called on the top-level module; I’m not sure it is ever called on the child modules. What are the standard ways to handle this?
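The suspicion is correct. In current PyTorch releases, `nn.Module.cuda()` is implemented as `self._apply(lambda t: t.cuda(device))` (`.cpu()` and `.to()` go through `_apply` in the same way), and `_apply` recurses into submodules by calling their `_apply`, not their `cuda()`. So a `cuda()` override on a child never runs when `.cuda()` is called on the parent. The minimal sketch below demonstrates this with `.cpu()` so it runs without a GPU; `Child`, `Parent` and `cpu_override_called` are hypothetical names.

```python
import torch
from torch import nn


class Child(nn.Module):
    """Hypothetical child module that records when its own cpu() runs."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)
        self.cpu_override_called = False

    def cpu(self):
        # Analogous to overriding cuda(): only runs if cpu() is called on THIS module
        self.cpu_override_called = True
        return super().cpu()


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()


parent = Parent()
# Parent.cpu() -> self._apply(...), which recurses via child._apply(...), not child.cpu()
parent.cpu()
print(parent.child.cpu_override_called)  # False

# Calling cpu() directly on the child does run the override
parent.child.cpu()
print(parent.child.cpu_override_called)  # True
```

Overriding the private `_apply` method would catch all of these calls, but it is not public API and its signature has changed between releases, so device-agnostic allocation is usually the safer route.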

(Edit: for now I’ve gone back to `torch_constr = torch.cuda if x.is_cuda else torch`.)
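Both the `cuda()` override and the `x.is_cuda` check exist only to track the device by hand. When the zero state can be owned by the module itself, one standard alternative is a registered buffer: buffers are moved and cast by `.cuda()`, `.cpu()` and `.to()` together with the parameters, including when those calls are made on a parent module. A minimal sketch, assuming a hypothetical `hidden_size` for the elided shape:

```python
import torch
from torch import nn


class Foo(nn.Module):
    def __init__(self, hidden_size=4):  # hidden_size is a hypothetical parameter
        super().__init__()
        # A buffer follows .cuda()/.cpu()/.to() along with the parameters.
        # persistent=False (PyTorch >= 1.6) keeps it out of state_dict().
        self.register_buffer("init_state", torch.zeros(hidden_size), persistent=False)

    def forward(self, x):
        # init_state already has the module's current device and dtype;
        # expand to the batch size and clone so the result is writable.
        return self.init_state.expand(x.size(0), -1).clone()


model = nn.Sequential(Foo())  # Foo is a child module, as in the question
model.to(torch.float64)       # model.cuda() on a GPU machine takes the same _apply path
state = model(torch.randn(3, 5, dtype=torch.float64))
print(state.shape, state.dtype)  # torch.Size([3, 4]) torch.float64
```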

Can’t you just use `torch.zeros(..., device=x.device)`?
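A minimal sketch of this suggestion, assuming (hypothetically) that the elided shape is `(batch_size, hidden_size)`. Passing `dtype=x.dtype` as well is an optional extra choice here; `x.new_zeros(...)` is an equivalent shorthand that inherits both device and dtype from `x`.

```python
import torch
from torch import nn


class Foo(nn.Module):
    def __init__(self, hidden_size=4):  # hidden_size is a hypothetical parameter
        super().__init__()
        self.hidden_size = hidden_size

    def forward(self, x):
        # Allocate the state on the input's device (and with its dtype),
        # so no cuda() override or torch/torch.cuda switch is needed.
        state = torch.zeros(x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        # Equivalent shorthand: x.new_zeros(x.size(0), self.hidden_size)
        return state


foo = Foo()
state = foo(torch.randn(3, 5))
print(state.shape, state.device)  # torch.Size([3, 4]) cpu
```

The same module works unchanged when `x` is on a GPU, whether or not `.cuda()` was called on `Foo` or on a parent module.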

Thanks! That’s probably what I should do 🙂