Proper way to write modules portable with .to(device)

I would like to know the correct way to write a module that will avoid RuntimeError: Expected all tensors to be on the same device. (I am using torch.jit.script if this is important)

I recently asked a related question about the .to(device) function and was helpfully told that “valid members for this operation is other nn.Modules, Parameters and Buffers”. What should you do with other kinds of data?

Here is an illustration of what I would like to do:

import torch
import torch.nn as nn

class TestModule(nn.Module):

    def __init__(self, net, param_float, param_tensor):
        super().__init__()
        self.net = net
        self.tensor_ = nn.Parameter(param_tensor)
        self.float_ = nn.Parameter(torch.tensor(param_float))
        self.tuple_of_tuples_ = ((1, 4), (1, 5))  # plain Python data
        self.tensor_list_ = []                    # plain list of tensors

    def forward(self, x):
        # something that uses all module attributes
        ...

Many thanks!

As already described, parameters, buffers, and other modules are registered to the parent module, so the to() calls are also applied to them.
Plain tensors used as attributes in your module (or lists of tensors, etc.) won't get this treatment.
The main question would be: why do you need these tensors if they are neither parameters nor buffers nor inputs to the forward method?
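
For illustration, here is a minimal sketch (the class and attribute names are made up) of how .to() treats registered vs. plain attributes:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))    # registered -> moved by .to()
        self.register_buffer('scale', torch.ones(3))  # registered -> moved by .to()
        self.plain = torch.zeros(3)                   # plain attribute -> stays put

m = Demo()
if torch.cuda.is_available():
    m.to('cuda')
    print(m.weight.device, m.scale.device, m.plain.device)
    # cuda:0 cuda:0 cpu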

In my forward function, the first step is a featurization which looks like this:

x = torch.stack([
    torch.sqrt(torch.sum(torch.square(r[j] - r[i])))
    for (i, j) in self.pairs
]).reshape(1, -1)

where self.pairs is a Tuple[Tuple[int]] (I’m not sure how to write that properly for type hints.)

Another example that has me stumped is a growing list of tensors. Occasionally, during some forward steps, this list will have a tensor appended to it. This list is used during calls to forward.

Lastly, I have constants. For these, it appears I want nn.Parameter(param, requires_grad=False).

Is it true that these kinds of features weren't intended to be used in an nn.Module? I might be missing something bigger. Thanks for your help!

It depends a bit on your use case and where these tensors are coming from.
Generally you could use:

  • Trainable parameters, registering them to the model via nn.Parameter (or nn.ParameterList for a list of them)
  • Non-trainable tensors via self.register_buffer(name, tensor) (or by setting the requires_grad attribute of a parameter to False): this properly registers the tensors in the model and they also show up in the state_dict, so you can save/load them
  • Non-trainable tensors via self.register_buffer(name, tensor, persistent=False), which will be registered in the model but won't show up in the state_dict
  • Tensors passed to the forward method (with the right dtype and on the right device)
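
A minimal sketch of these options (class and attribute names are made up):

import torch
import torch.nn as nn

class RegistrationDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # trainable parameter: in the state_dict, receives gradients
        self.w = nn.Parameter(torch.randn(4))
        # persistent buffer: moved by .to(), saved in the state_dict
        self.register_buffer('offset', torch.zeros(4))
        # non-persistent buffer: moved by .to(), NOT saved in the state_dict
        self.register_buffer('tmp', torch.zeros(4), persistent=False)

    def forward(self, x, extra):
        # tensors that change every call are simply passed in
        return (x + extra) * self.w + self.offset + self.tmp

print(RegistrationDemo().state_dict().keys())
# odict_keys(['w', 'offset'])  ('tmp' is absent because persistent=False)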

However, if you need to create new tensors in the forward, you could use the .device attribute of any parameter (assuming they are all pushed to the device already). In your use case I don’t know how self.pairs is treated, but as already described a plain list won’t be registered, so you would need to push these tensors to the device manually.
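
For example, a small sketch (the module name is made up) of creating a new tensor on the right device inside forward:

import torch
import torch.nn as nn

class DeviceAwareDemo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        # create the new tensor directly on the device the module lives on
        noise = torch.randn(x.shape, device=self.linear.weight.device)
        return self.linear(x + noise)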


Thank you again. So for the case of the integer tuple, I can imagine a slightly hacky way to substitute a tensor with dtype int and register it as a buffer.
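
Something like this rough sketch (class name and shapes are my own assumptions, and I haven't checked it under torch.jit.script):

import torch
import torch.nn as nn

class PairwiseDistances(nn.Module):
    def __init__(self, pairs):
        super().__init__()
        # store the index pairs as a long tensor so they follow .to(device);
        # persistent=False keeps them out of the state_dict
        self.register_buffer('pairs', torch.tensor(pairs, dtype=torch.long),
                             persistent=False)

    def forward(self, r):
        i, j = self.pairs[:, 0], self.pairs[:, 1]
        # vectorized pairwise Euclidean distances, shape (1, num_pairs)
        return torch.sqrt(torch.square(r[j] - r[i]).sum(dim=-1)).reshape(1, -1)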

I’m still mystified about the growing list of tensors. Could I self.register_buffer an empty tensor and safely reassign it inside forward so that it can grow longer? Like this:

    def __init__(self, a):
        super().__init__()
        self.register_buffer('a', a)

    def forward(self, x: torch.Tensor, extend: bool):
        if extend:
            new_tensor = ...
            self.a = torch.cat([self.a, new_tensor])

I think this use case should be possible if you make sure the shapes are as expected.
Here is a small example, which concatenates a new tensor to the internal buffer and checks that it’s also added to the state_dict:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('a', torch.ones(1, 1))

    def forward(self, x: torch.Tensor, extend: bool):
        if extend:
            new_tensor = torch.randn(1, 1)
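            # reassigning the attribute keeps 'a' registered as a buffer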
            self.a = torch.cat([self.a, new_tensor], dim=0)
        return x

model = MyModel()
x = torch.randn(1, 1)

for _ in range(10):
    out = model(x, extend=True)
    print(model.state_dict())