Module.cuda() not moving Module tensor?

import torch

class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super(ToyModule, self).__init__()
        self.layer = torch.nn.Linear(2, 2)
        self.expected_moved_cuda_tensor = torch.tensor([0, 2, 3])

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.layer(input)

toy_module = ToyModule()
toy_module.cuda()

When we call .cuda(), all the parameters and buffers of the module are moved to the GPU:

next(toy_module.layer.parameters()).device
>>> device(type='cuda', index=0)

But when we inspect the plain tensor attribute of toy_module, we see device(type='cpu'):

toy_module.expected_moved_cuda_tensor.device
>>> device(type='cpu')

Is this expected or am I missing anything? Thank you.

self.expected_moved_cuda_tensor is neither a parameter nor a buffer, which is why its device is unchanged. If you want to create a parameter and use it, you can do it as follows:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(2, 1)
        # nn.Linear(2, 1) stores its weight as (out_features, in_features) = (1, 2)
        self.linear1.weight = torch.nn.Parameter(torch.ones(1, 2))
        self.linear1.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        x = self.linear1(x)
        return x

You can even use such parameters directly in the forward method, like this:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.weight = torch.nn.Parameter(torch.ones(2, 1))
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # linear regression completely from scratch,
        # using parameters created in __init__
        x = torch.mm(x, self.weight) + self.bias
        return x

And calling .cuda() on the above model does move the model parameters to the GPU:

model = Model()
model.cuda()
print(model.weight.device) # prints device(type='cuda', index=0)
print(model.bias.device) # prints device(type='cuda', index=0)
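
And if you simply want the plain tensor from the original ToyModule on the GPU, you can move it yourself; Tensor.cuda() returns a GPU copy rather than moving the tensor in place, so you need to reassign the attribute (a minimal sketch, assuming a CUDA device is available):

toy_module = ToyModule()
toy_module.cuda()  # moves only parameters and buffers
# reassign the attribute with a GPU copy of the tensor
toy_module.expected_moved_cuda_tensor = toy_module.expected_moved_cuda_tensor.cuda()
print(toy_module.expected_moved_cuda_tensor.device)  # device(type='cuda', index=0)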

Thanks a lot!

But doesn’t it defeat the intuition of .cuda() if the Module’s tensor attribute stays on the CPU?

Although .cuda() arguably “should” work the way you expect, I don’t think moving every torch.Tensor attribute of an nn.Module subclass by default would be a good idea. It might be helpful in your use case, but other users may not want that behavior, which I think is why it isn’t done by default.
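
As a side note (just a sketch, not something from the reply above): if you don’t want to move the tensor by hand, a common device-agnostic pattern is to send the plain attribute to the input’s device inside forward. The module name here is made up for illustration:

import torch
import torch.nn as nn

class DeviceAgnosticToyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(2, 2)
        self.expected_moved_cuda_tensor = torch.tensor([0, 2, 3])

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # move the plain attribute to the input's device on each call;
        # .to() is a no-op when the tensor is already there
        constant = self.expected_moved_cuda_tensor.to(input.device)
        return self.layer(input) * constant.sum()

module = DeviceAgnosticToyModule().cuda()
out = module(torch.randn(4, 2).cuda())  # the constant follows the input onto the GPU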

One more thing: if you want to create just a constant tensor (not a trainable parameter), you can do that as

self.a_constant_tensor = nn.Parameter(torch.ones(2, 1), requires_grad=False)

and then use it in the forward method.
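
For completeness, a minimal sketch of that approach (the ModelWithConstant name and the scalar constant are just illustrative): the frozen nn.Parameter is registered on the module, so .cuda() moves it, while requires_grad=False keeps it out of gradient computation:

import torch
import torch.nn as nn

class ModelWithConstant(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2, 1))
        self.bias = nn.Parameter(torch.zeros(1))
        # registered like any other parameter, but frozen (no gradients)
        self.a_constant_tensor = nn.Parameter(torch.tensor([0.5]), requires_grad=False)

    def forward(self, x):
        return torch.mm(x, self.weight) + self.bias + self.a_constant_tensor

model = ModelWithConstant().cuda()
print(model.a_constant_tensor.device)         # device(type='cuda', index=0)
print(model.a_constant_tensor.requires_grad)  # False

Unlike a buffer, this frozen parameter still shows up in model.parameters(), which is one reason the buffer approach below is usually preferred.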

Or you can use buffers, which is the recommended approach:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.weight = torch.nn.Parameter(torch.zeros(2, 1))
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.register_buffer('a_constant_tensor', torch.tensor([0.5]))

    def forward(self, x):
        # linear regression completely from scratch,
        # using parameters created in __init__
        x = torch.mm(x, self.weight) + self.bias + self.a_constant_tensor
        return x


model = Model().cuda()

Doing this doesn’t register self.a_constant_tensor as a parameter, so iterating over model.parameters() won’t return it:

for param in model.parameters():
    print(param)
# Only prints about self.weight and self.bias
'''
Parameter containing:
tensor([[0.],
        [0.]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([0.], device='cuda:0', requires_grad=True)
'''
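
A quick check of what the buffer approach gives you (assuming the model defined above): the buffer is moved by .cuda() and is saved in the state_dict alongside the parameters, even though it is not a parameter:

print(model.a_constant_tensor.device)             # device(type='cuda', index=0)
print('a_constant_tensor' in model.state_dict())  # True, buffers are saved with the model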