Hi there,
Note the quotes in the title: it’s not PyTorch that is to blame, but the careless user that I am.
Here’s the guilty code:
total_loss = 0.0  # total loss of the epoch
for ...:
    loss = criterion(...)
    loss.backward()
    # ...
    total_loss += loss
For reference, here is what’s happening: total_loss is initially a Python float, but because we use the += operator with a Tensor on the right-hand side, it is silently turned into a Tensor. Since loss has requires_grad=True, the new total_loss tensor does too. The subsequent computations involving total_loss are then recorded during the entire epoch, and eventually clutter the GPU memory.
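To see the conversion concretely, here is a minimal sketch (the ones(...).sum() is just a stand-in for a real criterion output):

import torch

loss = torch.ones(3, requires_grad=True).sum()  # stand-in for criterion(...)

total_loss = 0.0
total_loss += loss                # Tensor.__radd__ is called, not float addition

print(type(total_loss))           # <class 'torch.Tensor'>
print(total_loss.requires_grad)   # True: every later += keeps extending the graph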
The solution is simply to write
total_loss += loss.item()
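Under the same assumptions, a corrected epoch loop could look like the sketch below (the model, criterion, and optimizer are purely illustrative placeholders):

import torch

model = torch.nn.Linear(4, 1)                # illustrative model
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_loss = 0.0                             # total loss of the epoch
for _ in range(10):                          # stand-in for iterating over batches
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    total_loss += loss.item()                # a plain Python float, no graph kept

print(type(total_loss))                      # <class 'float'>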
The type casting of the total_loss variable may have been harmless before the Variable and Tensor classes were merged, but not anymore.
As there is no way to add a tensor in place to a float, Python falls back to the __radd__ method of the Tensor class. I think this should at least warn the user about the dangerous path he’s taking.
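For illustration, the fallback is easy to observe directly (a tiny sketch):

import torch

t = torch.tensor(2.0, requires_grad=True)

print((1.0).__add__(t))   # NotImplemented: float cannot handle a Tensor operand
print(t.__radd__(1.0))    # a new Tensor carrying a grad_fn
print(1.0 + t)            # same result, reached through the implicit fallback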
What’s your point of view on this matter?
This is known behavior and is explained in the section Accumulating losses in the Migration Guide.
You would want this kind of behavior, for example, if you have different loss functions and would like to accumulate them:
total_loss = loss1 + loss2 + loss3
total_loss.backward()
You could also accumulate the losses of several mini-batches and backward them together:
total_loss = 0.0
for ...:
    loss = criterion(...)
    total_loss += loss
total_loss.backward()
Therefore I’m not sure it would be a good idea to add a warning, since accumulating tensors this way is used precisely to keep the computation graphs around.
Hi,
Thanks for the link, which I should have read when migrating…
I think you misread me: I’m not complaining about the fact that we can aggregate losses when they are all tensors (I suppose that would be the case in your examples). What I’m pointing at is the implicit “type-cast” happening on Python numbers. In particular, I find the combination of these two things very error-prone:
- float += ScalarTensor is valid and changes the type of the variable on the left
- formatting a zero-dimensional tensor only shows the scalar value, without any indication that the variable is indeed a tensor
The second point is easy to fix. As for the first, __radd__ (and its friends) could print a warning, which would be easily silenced by an explicit cast using torch.tensor.
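To make both points concrete, here is a small sketch (the outputs indicated in the comments are what I would expect):

import torch

loss = torch.tensor(0.5, requires_grad=True)

total = 0.0
total += loss                       # point 1: the float silently becomes a Tensor

print("total = {}".format(total))   # point 2: prints only the scalar, no tensor(...) hint
print(type(total))                  # <class 'torch.Tensor'>

# The explicit version that would silence the hypothetical warning:
total = torch.tensor(0.0)
total += loss                       # the intent to accumulate a Tensor is now visible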