This will be an error in PyTorch 0.5?


(Rio) #1

So, I ran this:

for epoch in range(epochs):
    epoch += 1
    inputs = Variable(torch.from_numpy(x_train))
    labels = Variable(torch.from_numpy(y_train))
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {} '.format(epoch, loss.data[0]))

and got this warning:
UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
Remove the CWD from sys.path while we load stuff.

Could someone kindly explain the message above and, if possible, point out the problem in my code?
Thanks!


#2
for epoch in range(epochs):
    epoch += 1      # <-- you don't need to increase epoch by one; the for statement above does this automatically
    inputs = Variable(torch.from_numpy(x_train))
    labels = Variable(torch.from_numpy(y_train))
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs,labels)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {} '.format(epoch, loss.data[0]))
    #                                             ^
    #                                             | 

I guess this is the problem: replace loss.data[0] with loss.data.item().


(Rio) #3

Thanks a lot! It worked, but I don’t understand the underlying problem. What did changing it to .item() do?
As for the epoch, I wanted it to start from 1 instead of the usual 0, hence the epoch += 1.
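For what it's worth, epoch += 1 inside the loop from #1 does what was intended: it shifts the printed number without skipping or adding passes, because on each iteration the for statement rebinds epoch to the next value from range(). A quick pure-Python check (the helper name epoch_labels is just for illustration):

```python
def epoch_labels(epochs):
    """Collect the epoch numbers the loop from #1 would print."""
    labels = []
    for epoch in range(epochs):
        # Rebinding the loop variable only changes the local name;
        # the next pass still takes the next value from range().
        epoch += 1
        labels.append(epoch)
    return labels

print(epoch_labels(3))  # [1, 2, 3]
```

So the loop still runs epochs times and the labels run from 1 to epochs.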


(Thomas V) #4

Or even just loss.item() 🙂
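A minimal, self-contained sketch of the loop from #1 using loss.item(). The original post doesn't show the model or data, so the one-feature linear regression, the synthetic x_train/y_train, the learning rate and the epoch count below are all hypothetical stand-ins. The Variable wrappers are dropped because Variable was merged into Tensor in PyTorch 0.4, and the arrays are float32 so they match the default dtype of nn.Linear's weights.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)  # reproducible initial weights

# Hypothetical data: y = 2x + 1, kept as float32 to match nn.Linear.
x_train = np.arange(10, dtype=np.float32).reshape(-1, 1)
y_train = 2 * x_train + 1

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Plain tensors work directly with autograd; no Variable needed.
inputs = torch.from_numpy(x_train)
labels = torch.from_numpy(y_train)

epochs = 100
losses = []
for epoch in range(1, epochs + 1):  # count epochs from 1
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)  # a 0-dim tensor
    loss.backward()
    optimizer.step()
    losses.append(loss.item())  # convert the 0-dim tensor to a Python float
    if epoch % 20 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))
```

Because .item() returns a plain Python float, the losses list holds no references to the autograd graph.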


#5

Your loss.data is a 0-dimensional tensor: not a vector or matrix, but a single scalar value.
Indexing it is like trying to index an int in Python, which makes no sense, right?
E.g.:

a = 3
b = a[0]  # <-- indexing a single value makes no sense (raises TypeError: 'int' object is not subscriptable)

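To stay close to this Python behavior, PyTorch emits the warning you saw, and as the message says, indexing a 0-dim tensor will become an error in a later release. tensor.item() instead returns the tensor's single value as a plain Python number.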
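A short sketch of the difference, assuming PyTorch 1.0 or later, where the warning has become an IndexError. The value 0.25 is arbitrary; a loss such as nn.MSELoss's default output has the same 0-dim shape.

```python
import torch

loss = torch.tensor(0.25)       # a 0-dim (scalar) tensor
print(loss.dim(), loss.shape)   # 0 torch.Size([])

value = loss.item()             # the single value as a plain Python number
print(type(value), value)       # <class 'float'> 0.25

# Indexing a 0-dim tensor, as loss.data[0] did, fails in recent versions.
try:
    loss[0]
    indexed = True
except IndexError as err:
    indexed = False
    print('IndexError:', err)
```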


#6

Regarding the for loop: if you want to start counting from 1, you can use for epoch in range(1, epochs + 1), which still runs epochs iterations.
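Note that range's stop value is exclusive, so range(1, epochs) runs only epochs - 1 times. To keep every pass while counting from 1, the stop has to be epochs + 1. A quick check with a hypothetical epochs = 5:

```python
epochs = 5

short = list(range(1, epochs))      # stop is exclusive: one pass too few
full = list(range(1, epochs + 1))   # all epochs passes, numbered from 1

print(short)  # [1, 2, 3, 4]
print(full)   # [1, 2, 3, 4, 5]
```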