This will be an error in PyTorch 0.5?

So, I ran this:

for epoch in range(epochs):
    epoch += 1
    inputs = Variable(torch.from_numpy(x_train))
    labels = Variable(torch.from_numpy(y_train))
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs,labels)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {} '.format(epoch, loss.data[0]))

and got this error:

UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
  Remove the CWD from sys.path while we load stuff.

Can someone be kind enough to explain the warning above? If possible, please point out the problem in my code.
Thanks!

for epoch in range(epochs):
    epoch += 1      # <-- you don't need to increase epoch by one; the for statement above does this automatically
    inputs = Variable(torch.from_numpy(x_train))
    labels = Variable(torch.from_numpy(y_train))
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs,labels)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {} '.format(epoch, loss.data[0]))
    #                                             ^
    #                                             | 

This is probably the problem: replace loss.data[0] with loss.data.item().
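A minimal sketch of what is going on, assuming PyTorch 0.4 or later (where `torch.tensor` exists). A loss returned by a criterion such as `nn.MSELoss` is a 0-dim tensor: it has no dimension you can index, and `.item()` converts its single value into a plain Python number. The predictions and targets below are made up for illustration.

```python
import torch
import torch.nn as nn

# Made-up predictions and targets, only used to produce a loss value.
outputs = torch.tensor([[1.0], [2.0], [3.0]])
labels = torch.tensor([[1.5], [2.0], [2.0]])

criterion = nn.MSELoss()  # averages the squared errors by default
loss = criterion(outputs, labels)

print(loss.dim())    # 0: a scalar tensor has no dimensions to index
print(loss.shape)    # torch.Size([])
value = loss.item()  # plain Python float holding the single value
print(type(value), value)
```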


Thanks a lot! It worked, but I don’t understand the underlying problem. What does changing it to .item() do?
As for the epoch, I wanted it to start from 1 instead of the usual 0, hence the epoch += 1.

Or even just loss.item() :)
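For reference, here is a self-contained sketch of the training loop using loss.item(). The data, model, and hyperparameters are hypothetical stand-ins (a small linear regression with `nn.Linear`, `nn.MSELoss`, and `torch.optim.SGD`), since the original post doesn't show them. It also assumes PyTorch 0.4 or later, where `Variable` was merged into `Tensor`, so the `Variable(...)` wrappers can be dropped.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical training data: y = 2x + 1 plus a little noise, as float32 NumPy arrays.
rng = np.random.default_rng(0)
x_train = rng.uniform(0, 1, size=(50, 1)).astype(np.float32)
y_train = (2 * x_train + 1 + 0.05 * rng.standard_normal((50, 1))).astype(np.float32)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epochs = 100

# Convert once, outside the loop; no Variable wrapper needed in PyTorch >= 0.4.
inputs = torch.from_numpy(x_train)
labels = torch.from_numpy(y_train)

losses = []
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())  # .item() turns the 0-dim loss into a Python float
    print('epoch {}, loss {}'.format(epoch + 1, loss.item()))  # 1-based label for display
```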


Your loss.data is a 0-dim “tensor”: it is not a vector or matrix but holds only a single value.
Indexing it is like trying to index an int in Python, which makes no sense, right?
E.g.:

a = 3
b = a[0] # <-- indexing a single number makes no sense (raises TypeError: 'int' object is not subscriptable)

To stay close to how Python behaves, PyTorch emits the warning you were getting: indexing a 0-dim tensor like this is deprecated and, as the message says, will become an error.
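In current PyTorch releases the warning has indeed turned into an error. A small sketch, assuming a recent version: indexing a 0-dim tensor raises IndexError, a 1-D tensor with one element can still be indexed, and `.item()` works on both.

```python
import torch

scalar = torch.tensor(3.0)    # 0-dim tensor, like a loss value
vector = torch.tensor([3.0])  # 1-dim tensor holding one element

print(scalar.dim(), vector.dim())  # 0 1

try:
    scalar[0]  # there is no dimension to index into
    raised = False
except IndexError as err:
    raised = True
    print('IndexError:', err)

print(vector[0])  # allowed: indexing the 1-D tensor returns a 0-dim tensor
print(scalar.item(), vector.item())  # 3.0 3.0
```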


Regarding the for loop: if you want to start at a specific index, you can just use for epoch in range(1, epochs + 1). The + 1 is needed because range excludes its stop value.
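A quick pure-Python check, with a hypothetical epochs = 5, comparing the 1-based labels produced by the original epoch += 1 trick with those from a shifted range:

```python
epochs = 5  # hypothetical value

# Original approach: shift the loop variable inside the body.
labels_shifted = []
for epoch in range(epochs):
    epoch += 1  # only changes this iteration's value; range still drives the loop
    labels_shifted.append(epoch)

# Start the range at 1 and extend the stop by one, since range excludes its stop value.
labels_range = list(range(1, epochs + 1))

print(labels_shifted)         # [1, 2, 3, 4, 5]
print(labels_range)           # [1, 2, 3, 4, 5]
print(len(range(1, epochs)))  # 4: one epoch short
```

Both give the same labels; the range version just avoids reassigning the loop variable.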


Thanks a lot for solving the problem. I was stuck on a similar error as well.

Thanks again for the explanation.