Model.cuda() does not convert all variables to cuda

Hi, I am trying to write an architecture where I have to convert entire models to CUDA using model.cuda(). However, some of the elements are plain tensors initialised in the __init__() method of the nn.Module subclass. How do I convert them to CUDA? For example:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.xyz = torch.tensor([1,2,3,4...])
        # Convert this to cuda without using .cuda() on tensor xyz, but by using model.cuda()
        ...

    def forward(self, x):
        ...

Is it possible to do this?

torch.tensor([1,2,3]).cuda() 🙂

EDIT: Wait, I just read your comment! At what point in your script do you want to push those tensors to CUDA? Why don’t you want to use .cuda()?

I want to do this:

model = Net()
model.cuda()  # convert everything to CUDA

I don’t want to use .cuda() explicitly, as later on I will be migrating the code to a CPU-only device where I need to test on the CPU.

EDIT: I will train on GPU and test on CPU, so the weights will need to be moved back with .cpu().
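For context, one common way to do this (sketch only, with a placeholder file name) is to save the state_dict on the GPU machine and reload it with map_location on the CPU-only one:

# on the GPU machine
torch.save(model.state_dict(), "net_weights.pt")

# on the CPU-only machine
model = Net()
model.load_state_dict(torch.load("net_weights.pt", map_location="cpu"))
model.eval()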

I see, I also run things both on the CPU and GPU. My solution is generally something like:

import torch
import torch.nn as nn

def make_cuda(fun):
    # move to the GPU only if one is actually available
    return fun.cuda() if torch.cuda.is_available() else fun

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.xyz = make_cuda(torch.tensor([1,2,3,4...]))
        ...

    def forward(self, x):
        ...

This will automatically push things to CUDA only if you are on a machine with a GPU.
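The same helper also works on the inputs at call time, so the whole script stays device-agnostic, e.g. (purely illustrative):

model = Net()
x = make_cuda(torch.randn(4))  # input ends up on the same device as the model's tensors
out = model(x)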


Won’t work, as I will be packaging the model and converting it to TF-Lite and the like… I figured out something that would work though:

model=model.cpu()
model.xyz=model.xyz.cpu()

while saving the weights.
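For the actual saving step, something along these lines (placeholder file name; xyz has to be stored explicitly because a plain tensor attribute is not part of the state_dict):

torch.save({"state_dict": model.state_dict(), "xyz": model.xyz}, "net_cpu.pt")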

This will push tensors to cuda when you call model.cuda():


import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.xyz = torch.tensor([1, 2, 3, 4])

    def forward(self, x):
        pass

    def cuda(self):
        super().cuda()
        # tensor.cuda() is not in-place, so reassign the moved tensor to the attribute
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                setattr(self, k, v.cuda())
        return self


a = Net()
a.cuda()

edit: I mistakenly wrote self.cuda() instead of super().cuda(). Fixed.


While this approach would work, the proper way to register tensors inside an nn.Module would be to either use nn.Parameter (if this tensor requires gradients and should be trained) or via self.register_buffer. Both approaches will make sure that this tensor will be pushed to the specified device (in model.to(device)) and will also be added to the state_dict (which would be important if you want to save and load this model).

CC @Hmrishav_Bandyopadhy


Can you add an example of self.register_buffer for this use case?

import torch
import torch.nn as nn

# nn.Parameter version: xyz becomes a trainable parameter
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        xyz = torch.tensor([1, 2, 3, 4], dtype=torch.float)
        self.xyz = nn.Parameter(xyz)

    def forward(self, x):
        return x

model = Net()
model.xyz
"""
Parameter containing:
tensor([1., 2., 3., 4.], requires_grad=True)
"""

model.cuda()
model.xyz
"""
Parameter containing:
tensor([1., 2., 3., 4.], device='cuda:0', requires_grad=True)
"""
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        xyz = torch.tensor([1, 2, 3, 4], dtype=torch.float)
        self.register_buffer("xyz", xyz)

    def forward(self, x):
        return x

model = Net()
model.xyz
"""
tensor([1., 2., 3., 4.])
"""

model.cuda()
model.xyz
"""
tensor([1., 2., 3., 4.], device='cuda:0')
"""
