Confused about loss.data[0]


(Karl Tum) #1

Hi,

I came across the following in some demo code:

            outputs = model(inputs)
            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)
            # backward + optimize only if in training phase
            if phase == 'train':
                loss.backward()
                optimizer.step()
            # statistics
            running_loss += loss.data[0]

If we want to extract the loss tensor from the loss Variable, why not just use loss.data?

What does loss.data[0] mean here?


(Thomas V) #2

Hi,

There are three things here:

  • loss the Variable,
  • loss.data the (presumably size 1) Tensor,
  • loss.data[0] the (python) float at position 0 in the tensor.

As such, by just using loss.data you would not run into the “keeping track of everything” problem (which would happen if you accumulated loss itself and something is not volatile), but you would be adding torch tensors instead of plain Python numbers.
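To make the difference concrete, here is a minimal sketch written against PyTorch 0.4 or later (where Variable and Tensor are merged, so it uses loss.item() rather than loss.data[0]). The tiny model, the random data, and the names running_tensor and running_float are made up for illustration.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)              # hypothetical stand-in model
criterion = torch.nn.CrossEntropyLoss()
inputs = torch.randn(8, 4)                 # made-up batch
labels = torch.randint(0, 3, (8,))

running_tensor = 0.0   # will accumulate the loss tensor itself
running_float = 0.0    # will accumulate plain Python numbers

for _ in range(3):
    loss = criterion(model(inputs), labels)
    # Adding the loss itself yields a tensor that stays attached to the
    # autograd graph of every iteration, so that history is kept alive.
    running_tensor = running_tensor + loss
    # .item() returns a Python float with no graph behind it.
    running_float += loss.item()

print(type(running_tensor), running_tensor.requires_grad)
print(type(running_float))
```

The first accumulator ends up as a tensor that still requires grad, while the second is an ordinary float, which is what you normally want for statistics.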

Best regards

Thomas


(Thomas V) #3

As the above still gets likes:

Note that the above post is outdated.

Nowadays, with PyTorch >= 0.4 you have

  • loss the Tensor (which previously was the variable),
  • loss.data (shouldn’t be needed much anymore), which is roughly equivalent to loss.detach(): a Tensor that no longer tracks operations for derivatives (useful when you want to keep the value around but don’t want to move it off the GPU yet),
  • loss.item() the Python number contained in a 1-element tensor.
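A small sketch of the three objects, assuming PyTorch 0.4 or later; the input x and the squared-sum "loss" are made up purely for illustration.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 2).sum()        # 0-dim tensor attached to the autograd graph

detached = loss.detach()     # still a tensor on the same device, but not tracked
value = loss.item()          # plain Python float

print(loss.requires_grad)      # True
print(detached.requires_grad)  # False
print(type(value), value)      # <class 'float'> 14.0
```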

Best regards

Thomas


(Gabriel Chu) #4

So .item() can only be used when the tensor has exactly one element, right? If the tensor has multiple elements, .item() is not applicable?

Just to be clear, for the following case, which is a line copied from an outdated implementation, should we use detach() instead?

batch.trg.data[i, 0]


(Thomas V) #5

You can change that to batch.trg.detach()[i, 0] unless you do funny things with it afterwards. (But note that keeping the result around will keep the whole storage of batch.trg alive, not just the first column; the same is true of .data.)
There is also .tolist() if you want to convert a tensor into a list (of lists).
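A rough sketch of these points, with a small made-up tensor trg standing in for batch.trg. It also shows that .item() refuses a tensor with more than one element; the exact exception type is an assumption that may differ between PyTorch versions, so both likely types are caught.

```python
import torch

# Hypothetical stand-in for batch.trg
trg = torch.tensor([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], requires_grad=True)

elem = trg.detach()[1, 0]      # 0-dim tensor, no gradient tracking
print(elem.item())             # 2.0

# The indexed result is a view: it points into trg's storage, so holding on
# to it keeps that whole storage alive. clone() gives an independent copy.
shares = elem.data_ptr() == trg[1, 0].data_ptr()
standalone = elem.clone()
print(shares, standalone.data_ptr() == elem.data_ptr())

# .tolist() converts the whole tensor into nested Python lists.
print(trg.detach().tolist())   # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]

# .item() only works for tensors with exactly one element.
item_error = None
try:
    trg.detach().item()
except (RuntimeError, ValueError) as err:
    item_error = err
print(type(item_error).__name__)
```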

Best regards

Thomas