self.expected_moved_cuda_tensor is neither a parameter nor a buffer, which is why its device is unchanged.
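For illustration, a minimal sketch of that behavior (the attribute name is taken from your question); a plain tensor attribute is invisible to .cuda():

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # plain tensor attribute: neither a parameter nor a buffer
        self.expected_moved_cuda_tensor = torch.ones(1)

net = Net().cuda()
print(net.expected_moved_cuda_tensor.device)  # prints cpu

If you want to create a parameter and use it, you can do it as follows: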
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = nn.Linear(2, 1)
        # nn.Linear(2, 1) stores its weight with shape (out_features, in_features) = (1, 2)
        self.linear1.weight = torch.nn.Parameter(torch.ones(1, 2))
        self.linear1.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        x = self.linear1(x)
        return x
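Since the tensors are wrapped in nn.Parameter before assignment, they replace the layer's registered parameters; a quick check (a minimal sketch):

model = Model()
for name, param in model.named_parameters():
    print(name, param.shape)
# linear1.weight torch.Size([1, 2])
# linear1.bias torch.Size([1])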
You can even use those parameters directly in the forward method:
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.weight = torch.nn.Parameter(torch.ones(2, 1))
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # linear regression completely from scratch,
        # using parameters created in __init__
        x = torch.mm(x, self.weight) + self.bias
        return x
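A quick usage sketch (the input shape is an assumption for illustration):

model = Model()
x = torch.randn(4, 2)   # a batch of 4 samples with 2 features each
y = model(x)            # (4, 2) @ (2, 1) + (1,) -> shape (4, 1)
print(y.shape)          # torch.Size([4, 1])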
And moving the above model with .cuda() does move its parameters to the GPU:
model = Model()
model.cuda()
print(model.weight.device)  # prints cuda:0
print(model.bias.device)    # prints cuda:0
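Equivalently, you can use .to(), which also moves only the registered parameters and buffers (a minimal sketch):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
print(model.weight.device)  # cuda:0 if a GPU is available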
Though .cuda() “should” do what you said, I don’t think moving every torch.Tensor attribute of an nn.Module subclass by default is a good idea. In your use case it might be helpful, but in other cases the user may not want that, which I think is why it doesn’t do it by default.
One more thing: if you want to create just a constant tensor (not a parameter), you can register it as a buffer:
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.weight = torch.nn.Parameter(torch.zeros(2, 1))
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.register_buffer('a_constant_tensor', torch.tensor([0.5]))

    def forward(self, x):
        # linear regression completely from scratch,
        # using parameters created in __init__
        x = torch.mm(x, self.weight) + self.bias + self.a_constant_tensor
        return x
model = Model().cuda()
Doing this doesn’t register self.a_constant_tensor as a parameter, so iterating over the model’s parameters won’t return it:
for param in model.parameters():
    print(param)
# Only prints self.weight and self.bias
'''
Parameter containing:
tensor([[0.],
        [0.]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([0.], device='cuda:0', requires_grad=True)
'''
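Even though it isn’t a parameter, the buffer still moved to the GPU with the module, and it is included in the state dict (a quick check; named_buffers() lists registered buffers):

for name, buf in model.named_buffers():
    print(name, buf.device)
# a_constant_tensor cuda:0
print('a_constant_tensor' in model.state_dict())  # True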