Weight type is torch.FloatTensor despite being on GPU

I have a simple implementation of a VAE. The training code is as follows:

model = torch_VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=vae_params.lr)
for epoch in range(vae_params.no_epochs):
    result = train_VAE(epoch, train_loader, model, optimizer, device)

train_VAE is constructed in the standard way, plus some code to check whether the model weights and inputs are on the GPU:

for i, img in enumerate(loader):
    model.zero_grad()
    img = img.to(device)
    for device_thing in model.parameters():
        print('model param is on device: ', device_thing.device)
    print('data is on device', img.device)
    recon_image, mu, logvar = model(img)

Running this code prints exactly what is expected, since we have all run code like this hundreds of times:

model param is on device:  cuda:0
model param is on device:  cuda:0
model param is on device:  cuda:0
model param is on device:  cuda:0
model param is on device:  cuda:0
model param is on device:  cuda:0
data is on device cuda:0

However, the line recon_image, mu, logvar = model(img) throws an error saying that the input is (torch.cuda.FloatTensor) and the model weights are (torch.FloatTensor). This would suggest that I made the simple mistake of forgetting a .to(device), but my checks seem to confirm that I did not. Another candidate is that I have a .cpu() somewhere, or that I define a new tensor somewhere in the model that defaults to the CPU, but I have checked this too and it all seems to be done properly. The part of the model throwing the error is at the beginning, and is as follows:

class torch_VAE(nn.Module):
    def __init__(self, n=32, z_dim=100):
        super().__init__()
        self.encoder = nn.Sequential(
            ConvRelu(1, n),
            ConvRelu(n, n, downsample=True),
            ConvRelu(n, n),
            ConvRelu(n, n, downsample=True),
            ConvRelu(n, n // 2),
            nn.Flatten())

where ConvRelu is

class ConvRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, downsample=False, upsample=False):
        super().__init__()
        self.convrelu = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2),
            nn.ReLU()]
        if downsample:
            self.convrelu.append(nn.MaxPool2d(2))

        if upsample:
            self.convrelu.append(nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2))

    def forward(self, x):
        return nn.Sequential(*self.convrelu)(x)

I am completely and utterly stumped, as every debugging tactic I have tried tells me that I am doing this correctly. I am extremely interested in why this happens, and what mechanism of PyTorch I am overlooking (if any). Any help would be greatly appreciated!

Exact error:

Traceback (most recent call last):
  File "C:\Users\turbo\AppData\Roaming\Python\Python37\site-packages\IPython\core\interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-7f748147c5a3>", line 1, in <module>
    runfile('C:/Users/turbo/Python projects/Physics/Calo-ML/Calo-ML/experiment.py', wdir='C:/Users/turbo/Python projects/Physics/Calo-ML/Calo-ML')
  File "C:\Program Files\JetBrains\PyCharm 2019.3.4\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm 2019.3.4\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/turbo/Python projects/Physics/Calo-ML/Calo-ML/experiment.py", line 107, in <module>
    main()
  File "C:/Users/turbo/Python projects/Physics/Calo-ML/Calo-ML/experiment.py", line 77, in main
    for out in result:
  File "C:\Users\turbo\Python projects\Physics\Calo-ML\Calo-ML\helpers\training.py", line 201, in train_VAE
    recon_image, mu, logvar = model(img)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\turbo\Python projects\Physics\Calo-ML\Calo-ML\models\VAE.py", line 150, in forward
    encoded_image = self.encoder(x)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\turbo\Python projects\Physics\Calo-ML\Calo-ML\models\VAE.py", line 113, in forward
    return nn.Sequential(*self.convrelu)(x)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "C:\Users\turbo\Anaconda3\lib\site-packages\torch\nn\modules\conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Hi,

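I think you are storing self.convrelu as a plain Python list, which prevents methods like .parameters() or .cuda() from working properly: modules held in a plain list are not registered as submodules, so .to(device) never moves their weights and .parameters() never returns them.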
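A small check makes this visible. The sketch below uses two toy modules (PlainListBlock and ModuleListBlock are names made up for illustration, not from your code) and compares what .parameters() can see in each case. It runs on the CPU, so no GPU is needed.

```python
import torch.nn as nn


class PlainListBlock(nn.Module):
    """Hypothetical module that keeps its layers in a plain Python list."""

    def __init__(self):
        super().__init__()
        # A plain list is an ordinary attribute: nn.Module does not register its contents.
        self.layers = [nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU()]


class ModuleListBlock(nn.Module):
    """Same layers, but held in an nn.ModuleList so they are registered."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU()])


plain_count = len(list(PlainListBlock().parameters()))
registered_count = len(list(ModuleListBlock().parameters()))

print("parameters seen with a plain list:", plain_count)        # 0
print("parameters seen with nn.ModuleList:", registered_count)  # 2 (conv weight and bias)
```

If that is what is happening here, the loop over model.parameters() in your check would only have printed the parameters of the layers that are registered, which would explain why every line said cuda:0 while the convolution weights inside ConvRelu stayed on the CPU.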
Replace this list with nn.ModuleList(), which supports append, indexing, and iteration like a regular Python list.
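For reference, here is a minimal sketch of how the ConvRelu class from the question might look with that change. Only the container is swapped. Looping over the layers in forward is my choice for the sketch (building the nn.Sequential once in __init__ would also register them), and the 28x28 input at the end is an arbitrary size picked just to exercise the block.

```python
import torch
import torch.nn as nn


class ConvRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, downsample=False, upsample=False):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule, so .to(device),
        # .cuda() and .parameters() all see the convolution weights.
        self.convrelu = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1,
                      padding=(kernel_size - 1) // 2),
            nn.ReLU()])
        if downsample:
            self.convrelu.append(nn.MaxPool2d(2))
        if upsample:
            self.convrelu.append(nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2))

    def forward(self, x):
        # Apply the layers in order, as the nn.Sequential call in the question did.
        for layer in self.convrelu:
            x = layer(x)
        return x


# Falls back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
block = ConvRelu(1, 32, downsample=True).to(device)
out = block(torch.randn(2, 1, 28, 28, device=device))

# Every parameter should now report the same device type as the input.
param_devices = {p.device.type for p in block.parameters()}
print(param_devices, tuple(out.shape))
```

With the layers registered, the single .to(device) call on the top-level model should be enough to move the convolution weights along with everything else.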

Huzzah! You are exactly correct. Thank you so much, sir!