The nn.Linear layer is created on the CPU by default, so you need to move it to the GPU as well, onto the same device as its input:
import torch.nn as nn
layer = nn.Linear(784, 10).to(dev)  # move the weight and bias to the same device as `flatten`
lin = layer(flatten)
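Here is a minimal self-contained sketch of the fix. It assumes the input is a hypothetical batch of 28x28 single-channel images (784 pixels each) that already lives on `dev`, and it falls back to the CPU when no GPU is present so that it runs anywhere. The names `images` and `dev` are illustrative and do not come from the original code.

```python
import torch
import torch.nn as nn

# Pick the GPU if one is available, otherwise fall back to the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical input: a batch of 4 single-channel 28x28 images,
# created directly on the target device.
images = torch.randn(4, 1, 28, 28, device=dev)
flatten = nn.Flatten()(images)  # shape (4, 784), stays on dev

# nn.Linear allocates its weight and bias on the CPU by default.
layer = nn.Linear(784, 10)
print(layer.weight.device)  # cpu

# For modules, .to(dev) moves the parameters in place and returns the module itself.
layer = layer.to(dev)

# Parameters and input now share a device, so the forward pass succeeds.
lin = layer(flatten)
print(lin.shape)  # torch.Size([4, 10])
```

If the layer stays on the CPU while `flatten` is on the GPU, the forward pass raises a device-mismatch RuntimeError. In recent PyTorch versions (1.9 and later) you can also create the layer directly on the device with `nn.Linear(784, 10, device=dev)`.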