Currently, I have to pass a device parameter into my custom layer and then manually move tensors onto the specified device using .to(device) or device=device.
Is this behavior expected? It looks kind of ugly to me.
Shouldn’t model.to(device) put all the layers, including my custom layer, to device for me?
OK, I found the problem. It is not with the weight, which has been declared as a Parameter. It is with a map that I define in the __init__ function. I need this map to be put onto the same CUDA device as the weight, but I do not want it to be a Parameter.
Sorry if this is a newbie question: how does register_buffer solve the .to(device) problem here, so that the custom module ends up on the same device as the model it is part of? Thank you so much!
The .to() method will be applied to all internal _parameters and _buffers, as can be seen here and here.
The _parameters and _buffers attributes will be registered if you set an nn.Parameter as an attribute or use self.register_buffer. Have a look at this example:
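A minimal sketch of this (module and attribute names are my own, not from the thread) showing that a registered parameter and a registered buffer both follow the module through .to(). A dtype move is used here so it runs without a GPU, but it goes through the same mechanism as model.to('cuda'):

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter registers it in self._parameters.
        self.weight = nn.Parameter(torch.randn(3))
        # register_buffer stores the tensor in self._buffers
        # (saved in state_dict, but not trained by the optimizer).
        self.register_buffer('running_stat', torch.zeros(3))

    def forward(self, x):
        return x * self.weight + self.running_stat

model = MyModule()
# .to() is applied to all parameters and buffers; a dtype move
# exercises the same path as a device move such as model.to('cuda').
model.to(torch.float64)
print(model.weight.dtype, model.running_stat.dtype)
# torch.float64 torch.float64
```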
Is there any way to make sure that tensors created in the forward method are also on the appropriate device without passing in the device explicitly? Do you need to register buffers in the forward pass in this case? Seems kinda weird…maybe you could say something like myForwardTensor.to(self.device)?
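One common pattern for this (my own illustration, not an answer from the thread) is to derive the device from the incoming tensor rather than storing a device attribute: any tensor created in forward can be allocated directly with device=inputs.device.

```python
import torch
import torch.nn as nn

class AddNoise(nn.Module):  # hypothetical module for illustration
    def forward(self, inputs):
        # Create the new tensor directly on the same device (and dtype)
        # as the input, so no explicit device argument is needed.
        noise = torch.randn(inputs.shape, device=inputs.device,
                            dtype=inputs.dtype)
        return inputs + noise

m = AddNoise()
x = torch.zeros(4)
out = m(x)
print(out.device == x.device)  # True
```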
One thing to be careful with here: make sure that you don't assign the tensor to an attribute and then register that attribute as a buffer, like this:
# Incorrect way
self.my_tensor = torch.randn(1)
self.register_buffer('my_buffer', self.my_tensor)
If you do this and then try to access the original attribute instead of the buffer, you will find that the device change has not propagated to it.
def forward(self, inputs):
inputs = inputs / self.my_tensor # ERROR: self.my_tensor is on cpu
This is the right way to do it:
def __init__(self):
self.register_buffer('my_buffer', torch.randn(1))
...
def forward(self, inputs):
inputs = inputs / self.my_buffer # This works. self.my_buffer is on gpu
That’s a good point and thanks for sharing.
Since you’ve tagged me I assume one of my code snippets shows this behavior? (I can’t find it here, so could you send me a link to it so that I could add a comment or correct it?)
Nope, your code is correct. I was following your code and ran into this problem because I used register_buffer slightly differently than you did, so I thought I'd share to help others avoid my mistake.