How does Module.cuda() work, and how do I write new modules whose data .cuda() moves to the GPU?

I want to write a loss function that creates new variables according to the shape of the input data. The problem is that when I call loss.cuda(0), the data in it isn't moved to the GPU, so if the input data is on the GPU, an error occurs.

I looked at the code of the loss functions that already exist in PyTorch and noticed a class named _WeightedLoss, which is used by NLLLoss. But because the weight NLLLoss uses is created in its __init__ and doesn't need to change with the input data, I don't think this approach suits my problem.

The function I want to write can be roughly described as follows:

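```python
import torch
import torch.nn as nn

class exampleLoss(nn.Module):
    def __init__(self):
        super(exampleLoss, self).__init__()

    def forward(self, input, groundTruth):
        data1 = torch.zeros_like(input)
        # some computation (placeholder so the snippet runs)
        result = ((input - groundTruth) ** 2 + data1).mean()
        return result
```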

I also noticed that a network I wrote earlier, which also inherits from nn.Module, can use net.cuda() to move all its layers and parameters to the GPU, even though I didn't use things like nn.ModuleList or self.register_buffer. What mechanism does PyTorch use to collect the members of a class and make them available to the cuda() operation?

In short, I want to use .cuda() to move my loss function, which creates variables during forward(), to the GPU, and I'm curious about the mechanism behind Module.cuda().

Any help is appreciated. Thanks!


Hi,

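When you call .cuda(), all the parameters and buffers of the module, and of its submodules (recursively), are moved to the GPU.
Parameters are everything that you saved as self.foo = nn.Parameter(args).
Buffers are tensors that you registered with self.register_buffer('bar', torch.rand(10)). A plain tensor saved as self.bar = torch.rand(10) is not registered and will not be moved.
Submodules are any nn.Module saved on self, e.g. self.conv = nn.Conv2d(...). This is why the layers of your net moved without nn.ModuleList (which you only need for lists of modules).
The registration is done in the __setattr__ function of Module (and the actual move in Module._apply) if you want to check the implementation details.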
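A minimal sketch of what gets registered and what doesn't. It uses .double() instead of .cuda() so it runs without a GPU; both go through the same Module._apply machinery, so on a GPU machine you can swap in .cuda() and check .device instead of .dtype. The class name Demo and its attribute names are hypothetical.

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Parameter: __setattr__ stores it in self._parameters
        self.foo = nn.Parameter(torch.rand(3))
        # Buffer: has to be registered explicitly
        self.register_buffer("bar", torch.rand(10))
        # Submodule: __setattr__ stores it in self._modules, converted recursively
        self.layer = nn.Linear(4, 2)
        # Plain tensor attribute: NOT registered, so .cuda()/.double() skip it
        self.plain = torch.rand(5)


m = Demo()
m.double()  # same code path as m.cuda(), but changes dtype instead of device

print(m.foo.dtype)           # torch.float64
print(m.bar.dtype)           # torch.float64
print(m.layer.weight.dtype)  # torch.float64
print(m.plain.dtype)         # torch.float32 (left untouched)
print(sorted(m.state_dict().keys()))  # the plain tensor is not in the state_dict either
```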

If you create new tensors during the forward, it's your responsibility to create them with the right type and device. The torch.*_like(input) functions and input.new_*(new_tensor_size) can be used to easily create new tensors with the same type/device (and, for *_like, size) as another tensor.
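A sketch of a loss that allocates tensors inside forward() without caring where the inputs live. The class name ExampleLoss and the weighted squared-error computation are hypothetical placeholders; the point is that ones_like and new_full take dtype and device from input, so the module needs no .cuda() call at all.

```python
import torch
import torch.nn as nn


class ExampleLoss(nn.Module):
    """Hypothetical loss that creates helper tensors during forward()."""

    def forward(self, input, ground_truth):
        # Same shape, dtype and device as `input`
        mask = torch.ones_like(input)
        # New size (one weight per sample), but dtype/device taken from `input`
        weights = input.new_full((input.shape[0],), 2.0)
        per_sample = (mask * (input - ground_truth) ** 2).mean(dim=1)
        return (weights * per_sample).mean()


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(4, 3, device=device, dtype=torch.float64)
y = torch.rand(4, 3, device=device, dtype=torch.float64)
loss = ExampleLoss()(x, y)
print(loss.device, loss.dtype)  # same device as x, and torch.float64
```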


Thanks for your reply. I hadn't noticed that torch.*_like(input) also uses the same device as input until you pointed it out. I will check __setattr__ for the details. That's the second time you've answered one of my questions. Thanks again! 😁

Hi,

Does model_cuda = model_cpu.cuda() also copy the model_cpu.state_dict() to model_cuda?

Thanks,
Bandhav

Hi,

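The .cuda() method is in-place for Modules (it also returns the module itself), so you can simply call model.cuda() to move model from the CPU to CUDA. In your snippet, model_cuda and model_cpu are the same object, so there is no separate state_dict() to copy.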
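A small CPU-only check of this, with .double() standing in for .cuda() (both modify the module in place and return self). nn.Linear(3, 2) is just an arbitrary example module.

```python
import torch
import torch.nn as nn

model_cpu = nn.Linear(3, 2)
model_cuda = model_cpu.double()  # stand-in for model_cpu.cuda(); returns self

# Both names refer to the very same module object...
print(model_cuda is model_cpu)  # True
# ...so there is only one state_dict, and it is already converted
print(model_cpu.state_dict()["weight"].dtype)  # torch.float64
```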


Oh right! I just learned that! Interesting that tensor.cuda() returns a copy of the tensor but model.cuda() is in-place.
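For contrast, tensor conversions such as tensor.cuda(), tensor.double() and tensor.to(...) return a new tensor and leave the original alone, so you have to keep the return value. According to the Tensor.to documentation, the one exception is when the tensor already has the requested dtype and device, in which case the original tensor is returned. A CPU-only sketch, again with .double() standing in for .cuda():

```python
import torch

t = torch.rand(3)
t64 = t.double()  # returns a new tensor; t itself is not modified

print(t64 is t)   # False
print(t.dtype)    # torch.float32: the original tensor is unchanged
print(t64.dtype)  # torch.float64

# Already the requested dtype: .to() returns self, so no copy is made
same = t64.to(torch.float64)
print(same.data_ptr() == t64.data_ptr())  # True
```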