Moving tensors to devices

I’m trying to do some manual model pipelining simply by using the .to(device) approach. I have a network built from some nested nn.Module classes. My initial idea was to move the individual layers to their devices in the class constructors, after which everything should be handled automatically (apart from moving the input/output data), i.e.:

class Encoder(nn.Module):
    def __init__(self, device="cpu", chs=(3,16,32,64,128)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]).to(device) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2).to(device)
        self.device     = device

(Block in the above is another nn.Module subclass.)

However, when I experimented with this, it didn’t actually seem to move those layers onto the device. I had to move the calls into the forward function instead, i.e.:

def forward(self, x):
    return_values = []
    for block in self.enc_blocks:
        x = block.to(self.device)(x)
        return_values.append(x.to(self.device))
        x = self.pool.to(self.device)(x)
    return return_values

This now puts things where I expect them. Is there a reason the .to() calls in the constructor wouldn’t work, but the ones in the forward function do? Or is my thinking about what to() does incorrect?
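A minimal sketch of how to check whether a constructor-time .to() actually moves a layer’s parameters (the nn.Conv2d layer and the cuda:0 device string below are just placeholders):

import torch
import torch.nn as nn

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# nn.Module.to() moves parameters and buffers in place and returns the module,
# so a constructor-time Block(...).to(device) should leave the weights on `device`.
layer = nn.Conv2d(3, 16, kernel_size=3).to(device)
print(next(layer.parameters()).device)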

thanks

It works for me using:

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, device="cpu"):
        super().__init__()
        self.enc_blocks = nn.ModuleList([nn.Linear(1, 1).to(device) for i in range(3)])
        self.device     = device

    def forward(self, x):
        for module in self.enc_blocks:
            x = module(x)
        return x


model = Encoder(device="cuda:5")
print(model.enc_blocks[0].weight.device)
x = torch.randn(1, 1, device="cuda:5")

out = model(x)
print(out.device)

Output:

cuda:5
cuda:5
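For the pipelining use case, the same constructor-time placement should also extend to putting different blocks on different devices, as long as the activations are moved between stages in forward. A minimal sketch, assuming two CUDA devices (cuda:0 and cuda:1) are available:

import torch
import torch.nn as nn

class TwoStage(nn.Module):
    def __init__(self, dev0="cuda:0", dev1="cuda:1"):
        super().__init__()
        # Each stage is placed on its own device at construction time.
        self.stage0 = nn.Linear(8, 8).to(dev0)
        self.stage1 = nn.Linear(8, 8).to(dev1)
        self.dev0, self.dev1 = dev0, dev1

    def forward(self, x):
        x = self.stage0(x.to(self.dev0))
        # Only the activation is moved between stages.
        x = self.stage1(x.to(self.dev1))
        return x

model = TwoStage()
out = model(torch.randn(2, 8))
print(model.stage0.weight.device, model.stage1.weight.device, out.device)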

:thinking: Ok, looks like I just needed to upgrade PyTorch to get the correct behaviour.
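(A quick way to check which version is installed, in case anyone else runs into this:)

import torch
print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # whether CUDA devices are visible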

Thanks (I’m jealous of your cuda:5 :grin:)