RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2'

I am trying to do binary classification, so I’ve set up a simple neural network with one input and two outputs:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # One input feature, two outputs (one logit per class)
        self.fc1 = nn.Linear(1, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x

Using cross-entropy loss, I want to classify each input as label 0 or 1.
However, when I try to train this model, the forward pass fails while computing a prediction:

for epoch in range(n_epochs):
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to('cpu')
        y_batch = y_batch.to('cpu')

        yhat = net(x_batch.long()).unsqueeze(dim=0)

        loss = loss_fn(yhat, y_batch)

x_batch and y_batch are both long tensors, but when x = self.fc1(x) is evaluated in forward, this error comes up:

Traceback (most recent call last):
  File "", line 65, in <module>
    yhat = net(x_batch).unsqueeze(dim=0)
  File "/home/george/.local/lib/python3.6/site-packages/torch/nn/modules/", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "", line 16, in forward
    x = self.fc1(x)
  File "/home/george/.local/lib/python3.6/site-packages/torch/nn/modules/", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/george/.local/lib/python3.6/site-packages/torch/nn/modules/", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/george/.local/lib/python3.6/site-packages/torch/nn/", line 1372, in linear
    output = input.matmul(weight.t())
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2' in call to _th_mm

I’ve searched around and found that I apparently have to convert the weight of the Linear layer to a different scalar type, but I’m not sure whether this is the correct solution, nor do I know how to do it. Can someone advise?

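nn.Linear uses float32 parameters by default, so the input tensor must have the same dtype. The matrix multiplication inside the layer fails because x_batch.long() is int64 while the weight is float32.
Cast x_batch with .float() instead of .long() and run the code again. y_batch can stay long, since nn.CrossEntropyLoss expects the class indices as a long tensor.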
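
For illustration, here is a minimal sketch of one training step with the cast applied. The data is a hypothetical stand-in for one batch from your train_loader, and I am assuming x_batch has shape (batch_size, 1) and y_batch has shape (batch_size,); under that assumption the unsqueeze(dim=0) is not needed, so it is left out here. Adjust if your shapes differ.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 2)  # 1 input feature -> 2 class logits

    def forward(self, x):
        return self.fc1(x)


net = Net()
loss_fn = nn.CrossEntropyLoss()

# Hypothetical stand-in for one batch: integer-valued inputs and 0/1 labels.
x_batch = torch.tensor([[0], [1], [2], [3]])  # int64, shape (4, 1)
y_batch = torch.tensor([0, 0, 1, 1])          # int64 class indices, shape (4,)

print(net.fc1.weight.dtype)  # parameters are float32 by default

# Cast only the input to float so it matches the layer's parameters.
yhat = net(x_batch.float())    # logits, shape (4, 2)

# The target stays long: CrossEntropyLoss takes class indices here.
loss = loss_fn(yhat, y_batch)
loss.backward()
```

If the dataset can be changed, another option is to store the inputs as float32 when building it, so that no cast is needed inside the loop.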