Hi Juuso!
This is correct. Pytorch operations, including those in the backward pass,
operate, for efficiency reasons, on entire tensors (even if they only use or
modify a subset of elements).
Let me reword your code slightly:
loss_subset = loss[torch.topk(loss.detach(), int(self.opt.num_rays * hard_fraction), sorted=False)[1]] # a subset loss
loss_scalar = loss_subset.mean() # turn loss to scalar
(This is just to make clear that in your code, loss is a python name that
refers to different tensors at different points in the code.)
When you call loss_scalar.backward(), you backpropagate through the scalar
back to loss_subset – a tensor that does consist of only a subset of batch items.
This very trivial backpropagation is ever so slightly cheaper because
loss_subset is a smaller tensor (but any benefit is so small as to be irrelevant).
Autograd then backpropagates from loss_subset back to loss. (This is also a
trivial backpropagation.) However, loss now contains all of the batch elements,
even though the gradients being backpropagated for many of those batch elements
are zero.
All of the rest of the backpropagation is then carried out with all of the batch
elements (because pytorch is processing entire tensors). Even where the gradient
being backpropagated has zeros for the masked-out rows, pytorch still performs
the entire tensor operation, multiplying (or whatever) by those zeros. So there
are no savings in time (nor in memory).
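To make this concrete, here is a minimal sketch – using a hypothetical toy tensor
rather than your actual model – showing that the gradient flowing back to the full
tensor still has the full batch size, with exact zeros in the unselected rows:
import torch
torch.manual_seed(0)
x = torch.randn(8, 4, requires_grad=True)      # full "batch" of 8 elements
loss = (x ** 2).sum(dim=1)                     # per-element loss, shape [8]
hard_idx = torch.topk(loss.detach(), 3, sorted=False)[1]   # pick 3 "hard" elements
loss_subset = loss[hard_idx]
loss_scalar = loss_subset.mean()
loss_scalar.backward()
print(x.grad.shape)                             # torch.Size ([8, 4]) -- full-sized gradient
print((x.grad.abs().sum(dim=1) == 0).sum())     # 5 rows of exact zeros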
Because of how pytorch operates on entire tensors, there is no way to get your
hoped-for performance gain if you only perform the single forward pass.
Your scheme of performing two forward passes is the way to achieve your
performance gain. You need to perform the first forward pass in order to
compute the per-batch-element losses that you use to perform your “hard
mining.” But if you want to backpropagate for only the “hard” subset of your
batch elements, you also need to perform the second forward pass to
construct the computation graph that contains the smaller hard-subset
intermediate tensors.
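As a rough illustration, here is a minimal sketch of that two-forward-pass scheme.
(It uses a hypothetical toy model, and the names model, inputs, targets, and
loss_fn are placeholders, not your actual code.)
import torch
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
hard_fraction = 0.25
loss_fn = torch.nn.MSELoss(reduction='none')
# first forward pass: per-element losses for the full batch
per_elem_loss = loss_fn(model(inputs), targets).mean(dim=1)   # shape [8]
k = int(inputs.shape[0] * hard_fraction)
hard_idx = torch.topk(per_elem_loss.detach(), k, sorted=False)[1]
# second forward pass: only the hard subset, so the graph built for backward
# contains only the smaller hard-subset intermediate tensors
hard_loss = loss_fn(model(inputs[hard_idx]), targets[hard_idx]).mean()
hard_loss.backward()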
If hard_fraction is close to zero, you will likely get significant savings, while
if hard_fraction is close to one, the cost of the second forward pass will
likely exceed any savings from the subset backward pass.
As an aside, wrapping your first forward pass in a with torch.no_grad():
block will get you some additional savings, mostly in memory, because doing
so avoids constructing the full, non-subset computation graph that you won’t
be using anyway. (These memory savings are likely to be significant.)
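Concretely (continuing the sketch above), only the first forward pass changes:
# first forward pass, now without building a computation graph
with torch.no_grad():
    per_elem_loss = loss_fn(model(inputs), targets).mean(dim=1)
hard_idx = torch.topk(per_elem_loss, k, sorted=False)[1]   # .detach() no longer needed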
Last, let me second J’s comment that your hard mining may or may not improve
training speed and / or final performance. Trying it both ways, of course, is the
best way to find out.
Best.
K. Frank