Why is there no nn.BufferList-like container for registered buffer tensors?

For running-mean/variance-style tensors, we can use self.register_buffer(name, tensor) to manage them.
However, self.register_buffer requires a name to refer to the tensor. In some cases I have a list of buffer tensors and I don't want to manage these buffers by name, for instance:

class MyModule(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.params = [torch.Tensor(3, 3) for i in range(n)]
        for i, p in enumerate(self.params):
            self.register_buffer('tensor' + str(i), p)

and then I want to access self.params. However, once we call MyModule(n).to('cuda'), the tensors in self.params still point to the initial CPU tensors, and we have to use self.tensor0/1/... to refer to the CUDA tensors. That's not what I want.
I wish I could still refer to these tensors via self.params.
Could anyone help me?

class MyModule(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.params = []
        _params = [torch.Tensor(3, 3) for i in range(n)]
        for i, p in enumerate(_params):
            self.register_buffer('tensor' + str(i), p)
            self.params.append(getattr(self, 'tensor' + str(i)))

This should work. I haven't tested it with modifying the params, but since Python lists only store references, it should work this way.


This would not work; you will get something like "Overflow when unpacking long" (because params points to uninitialized values that are too large to be converted to int for the print).

I would suggest something like this for simple access by index:

class MyModule(nn.Module):
    def __init__(self, n):
        super(MyModule, self).__init__()
        _params = [torch.Tensor(3, 3) for i in range(n)]
        for i, p in enumerate(_params):
            self.register_buffer('tensor' + str(i), p)

    def params(self, i):
        return getattr(self, 'tensor' + str(i))
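
A quick usage sketch (assuming the MyModule above and that a GPU is available; untested, just to illustrate that the lookup always goes through the currently registered buffer):

m = MyModule(3)
print(m.params(0).device)   # cpu

m.to('cuda')
print(m.params(0).device)   # cuda:0 -- the accessor sees the moved buffer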

And if you want a list-like structure so you can loop over the buffers, you can return a generator with yield:

class MyModule(nn.Module):
    def __init__(self, n):
        super(MyModule, self).__init__()
        self.n = n
        _params = [torch.Tensor(3, 3) for i in range(n)]
        for i, p in enumerate(_params):
            self.register_buffer('tensor' + str(i), p)

    def params(self):
        for i in range(self.n):
            yield getattr(self, 'tensor' + str(i))
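
A quick sketch of looping over it (assuming the MyModule above; each call to params() builds a fresh generator, so you can iterate as often as you like):

m = MyModule(3)
m.to(torch.float64)         # buffers are converted by the module

for p in m.params():
    print(p.dtype)          # torch.float64 for every buffer

for p in m.params():        # a second pass works: params() returns a new generator
    print(p.shape)          # torch.Size([3, 3])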

The code @justusschock posted works for the purpose of registering buffers. The printing error you saw has already been fixed.


Ah, cool! I should update then.

When I move the module to CUDA, the registered tensors become CUDA tensors, but the tensors in the self.params list stay on the CPU. This is not expected, I guess.
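
If it helps to see why: .to() goes through nn.Module._apply, which stores a new tensor under each buffer name, while a plain Python list keeps pointing at the old objects. A minimal sketch of the effect (using a dtype change so it runs without a GPU; the names are just illustrative):

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.stale = []                     # plain list, not tracked by the module
        buf = torch.randn(2, 2)
        self.register_buffer('buf0', buf)   # tracked in self._buffers
        self.stale.append(buf)              # holds a reference to the original tensor

m = Demo()
m.to(torch.float64)         # the module registers a *new* float64 tensor as buf0
print(m.buf0.dtype)         # torch.float64
print(m.stale[0].dtype)     # torch.float32 -- the list entry went stale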


It does not work in PyTorch 1.6.

Your second implementation with the generator worked for me when creating "list buffer tensors" that work with PyTorch Lightning trainers. 🙌

I needed a while loop to fully replicate the list-like functionality I wanted; otherwise, StopIteration errors occur on repeated access.

class MyModule(nn.Module):
    def __init__(self, n):
        super(MyModule, self).__init__()
        self.n = n
        _params = [torch.Tensor(3, 3) for i in range(n)]
        for i, p in enumerate(_params):
            self.register_buffer('tensor' + str(i), p)

    def params(self):
        while True:
            for i in range(self.n):
                yield getattr(self, 'tensor' + str(i))
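
One caveat with the while True version (just my reading, not specific to Lightning): the generator never raises StopIteration, so a plain for loop over params() would not terminate on its own; pull values with next() or bound the loop yourself:

m = MyModule(3)
gen = m.params()            # a single generator that cycles forever

for _ in range(2 * m.n):    # bound the iteration count yourself
    p = next(gen)           # repeated access keeps working, no StopIteration
    print(p.shape)          # torch.Size([3, 3])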

For anyone reading this in the future, just use this code (I've tested it) and note that one of the benefits of using buffers is that the dtype and device of your floating-point attributes are updated automatically when .to() is called on the module.
Copy it and try it in your notebook.

import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        params = [torch.randn(3, 3) for i in range(n)]
        for i, p in enumerate(params):
            self.register_buffer(f"param_{i}", p)

    @property
    def params(self):
        return [getattr(self, f"param_{i}") for i in range(self.n)]


m = MyModule(3)
print(m.params[0].dtype, m.params[0].device)  # torch.float32 cpu
m.to("cuda")
print(m.params[0].dtype, m.params[0].device)  # torch.float32 cuda:0
m.to(torch.float64)
print(m.params[0].dtype, m.params[0].device)  # torch.float64 cuda:0