How to 'hook' into .cuda() method, or detect that .cuda() has been called?


(Hugh Perkins) #1

Maybe there’s another way of achieving my underlying goal, but in order to create tensors of the appropriate cudaness in my modules, I’ve been doing something like:

import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        # Use this as the tensor "constructor" namespace; swap it for
        # torch.cuda once the module is moved to the GPU.
        self.torch_constr = torch

    def cuda(self, device=None):
        super().cuda(device)
        self.torch_constr = torch.cuda
        return self

    def forward(self, x):
        state = self.torch_constr.FloatTensor(...).zero_()
        ...

The problem I run into is that the cuda method only seems to get called on the top-level module, not on its children. What's the standard way to handle this?

(edit: I’ve gone back to doing torch_constr = torch.cuda if x.is_cuda else torch for now)
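Another way to track the module's device without overriding cuda() at all is to register a dummy buffer: .cuda()/.cpu()/.to() move buffers automatically (including on children), so its device always reflects where the module lives. A minimal sketch — the buffer name and shapes here are my own, just for illustration:

```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        # Dummy buffer; nn.Module moves buffers on .cuda()/.to(),
        # so its device tracks the module's current device.
        self.register_buffer("_device_tracker", torch.zeros(1))

    def forward(self, x):
        # Allocate new tensors on whatever device the module is on.
        state = torch.zeros(3, device=self._device_tracker.device)
        return x + state.sum()
```

This also works when the module is a child of a larger model, since buffers are moved recursively.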


(Simon Wang) #2

can’t you just do torch.zeros(..., device=x.device)?
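A sketch of that approach (the shape and module here are made up for illustration): allocating with device=x.device means the new tensor always lands wherever the input is, with no .cuda() bookkeeping at all.

```python
import torch
import torch.nn as nn

class Foo(nn.Module):
    def forward(self, x):
        # Match the input's device (and dtype) directly; no need to
        # detect whether .cuda() was called on the module.
        state = torch.zeros(x.shape[0], 4, device=x.device, dtype=x.dtype)
        return state
```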


(Hugh Perkins) #3

Thanks! That’s probably what I should do :slight_smile: