Submodule buffers not being reassigned when moving module to cuda()


I have a module that has multiple submodules as attributes, which I have wrapped in nn.ModuleList and nn.ModuleDict. When I call .cuda() on the parent module, the buffers are copied to the GPU, but the submodule attributes that originally referred to the same objects as the buffers are not reassigned to the buffers on the GPU. What is the best way to fix this? Below is a snippet of toy code:

import torch

class InnerToyModule(torch.nn.Module):
    def __init__(self, x):
        super(InnerToyModule, self).__init__()
        self.ten = torch.tensor(x)
        self.register_buffer("ten_" + str(x), self.ten)

class ToyModule(torch.nn.Module):

    def __init__(self):
        super(ToyModule, self).__init__()
        self.inner = torch.nn.ModuleList([InnerToyModule(1), InnerToyModule(2)])

if __name__ == '__main__':
    toy = ToyModule()
    print(toy.inner[0].ten is dict(toy.named_buffers())["inner.0.ten_1"])  # True
    toy.cuda()
    print(toy.inner[0].ten is dict(toy.named_buffers())["inner.0.ten_1"])  # False

{'inner.0.ten_1': tensor(1), 'inner.1.ten_2': tensor(2)}
tensor(1, device='cuda:0')
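The same effect can be reproduced without a GPU. The sketch below relies on `.double()` going through the same recursive conversion as `.cuda()` (both are built on the internal `Module._apply`). It uses a float buffer because `.double()` leaves integer tensors untouched, so an integer buffer would not be replaced at all. `Inner` is a trimmed-down, hypothetical version of the toy `InnerToyModule`.

```python
import torch


class Inner(torch.nn.Module):
    def __init__(self, x):
        super().__init__()
        # The plain attribute and the registered buffer start out as one tensor object.
        self.ten = torch.tensor(float(x))
        self.register_buffer("ten_" + str(x), self.ten)


inner = Inner(1)
print(inner.ten is inner.ten_1)  # True: both names refer to the same tensor

# .double() converts floating-point buffers the same way .cuda() moves them:
# a new tensor is stored in the module's buffer dict, the old one is left alone.
inner.double()
print(inner.ten is inner.ten_1)  # False: only the registered buffer was replaced
print(inner.ten.dtype, inner.ten_1.dtype)  # torch.float32 torch.float64
```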

When you call .cuda() on a parent module, it recursively moves the parameters and buffers of all of its submodules. Each registered buffer is replaced by a new GPU tensor in that submodule's buffer dictionary, but any other reference to the original CPU tensor (such as the plain attribute self.ten) is not updated. To fix this, reassign those attributes to the moved buffers after calling .cuda(). Here is a modification of your toy code that demonstrates this:


import torch
import torch.nn as nn

class InnerToyModule(nn.Module):
    def __init__(self, x):
        super().__init__()
        self.ten = torch.tensor(x)
        self.register_buffer("ten_" + str(x), self.ten)

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.ModuleList([InnerToyModule(1), InnerToyModule(2)])

model = ToyModule().cuda()

# .cuda() stored a new GPU tensor as each registered buffer, but the plain
# attribute `ten` still points at the old CPU tensor. Point it at the moved
# buffer again (each InnerToyModule registers exactly one buffer).
# model.modules() walks nested ModuleList and ModuleDict containers alike.
# This has to be repeated after every .cuda(), .cpu() or .to() call.
for module in model.modules():
    if isinstance(module, InnerToyModule):
        module.ten = next(module.buffers(recurse=False))

print(model.inner[0].ten is dict(model.named_buffers())["inner.0.ten_1"])  # True
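Reassigning after every move is easy to forget. A sketch of an alternative that avoids the stale reference altogether, assuming the submodule only needs to read the tensor as `ten`: either register the buffer under the name `ten` itself, so every attribute access goes through the module's buffer dictionary, or keep the `ten_<x>` names and expose `ten` as a read-only property that looks the buffer up by name. `AliasedInner` and `_buffer_name` are hypothetical names for illustration, and `.double()` stands in for `.cuda()` so the snippet also runs on a CPU-only machine.

```python
import torch


class InnerToyModule(torch.nn.Module):
    def __init__(self, x):
        super().__init__()
        # Register the buffer under the name it is read by. `self.ten` is then
        # resolved from the module's buffer dict on every access, so it always
        # returns the tensor produced by the latest .cuda()/.to()/.double().
        self.register_buffer("ten", torch.tensor(float(x)))


class AliasedInner(torch.nn.Module):
    """Hypothetical variant that keeps the original buffer name "ten_<x>"."""

    def __init__(self, x):
        super().__init__()
        self._buffer_name = "ten_" + str(x)
        self.register_buffer(self._buffer_name, torch.tensor(float(x)))

    @property
    def ten(self):
        # Look the buffer up by name each time instead of caching the tensor.
        return getattr(self, self._buffer_name)


class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.ModuleList([InnerToyModule(1), AliasedInner(2)])


toy = ToyModule()
# Use .cuda() in practice; .double() exercises the same conversion path on CPU.
toy.double()
buffers = dict(toy.named_buffers())
print(toy.inner[0].ten is buffers["inner.0.ten"])  # True
print(toy.inner[1].ten is buffers["inner.1.ten_2"])  # True
```

With the property version, `ten` cannot be assigned to; write to the buffer through its registered name instead.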