Error: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)

Hi all,
I’m trying to solve simple problems based on MNIST.

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
dev = 'cuda'
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

batch_size = 1
test_batch_size = 1
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=batch_size, shuffle=True)

image, label = next(iter(train_loader))

with torch.no_grad():

    flatten = image.view(1, 28 * 28).to(dev) 
    print(flatten.shape)

    lin = nn.Linear(784, 10)(flatten) 
    print(lin)

Even though I moved flatten to 'cuda', the error message still comes up.

Traceback (most recent call last):
  File "C:/Users/yoonh/Works/main.py", line 59, in <module>
    lin = nn.Linear(784, 10)(flatten) 
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)

I am looking forward to seeing any help. Thanks in advance.
Kind regards,
Yoonho

The nn.Linear layer is created on the CPU, so you would need to transfer it to the GPU as well:

layer = nn.Linear(784, 10).to(dev)
lin = layer(flatten)
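
For completeness, here is a minimal sketch of the corrected snippet put together, assuming you also want a CPU fallback when CUDA is unavailable; the torch.cuda.is_available() check and the variable name layer are additions for illustration, not part of the original post:

import torch
import torch.nn as nn
from torchvision import datasets, transforms

# Fall back to the CPU if no GPU is available (assumption, not in the original code)
dev = 'cuda' if torch.cuda.is_available() else 'cpu'

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor()])),
    batch_size=1, shuffle=True)

image, label = next(iter(train_loader))

with torch.no_grad():
    flatten = image.view(1, 28 * 28).to(dev)   # input tensor on the chosen device
    layer = nn.Linear(784, 10).to(dev)         # layer parameters on the same device
    lin = layer(flatten)
    print(lin)

In a full training script you would usually create the model once, call .to(device) on it before the training loop, and move each batch to the same device inside the loop, so the parameters and the inputs always live on the same device.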