Backwards with only a subset of the output losses?

Hi everyone,

I see that the forward pass of my neural network takes only half as long as the backward pass (a pretty common ratio for neural networks). I would like to reduce the amount of backward computation by only doing it for outputs that achieve an above-average loss. How do I do this correctly? Here’s the code:

preds, truths, loss = self.train_step(data) # each is [4096, 1]
# Get only the hard samples
masked_loss = loss[0:5].mean() # for now, explicitly taking only the first 5 samples
self.scaler.scale(masked_loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
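(To be clear about “above average loss”: loss[0:5] above is just a placeholder while debugging. The selection I actually have in mind is a sketch like the following, keeping only the samples whose detached loss is above the batch mean; the timing behaviour described below is the same either way.)

# Sketch: keep only samples whose per-sample loss is above the batch mean.
# loss is assumed to have shape [4096, 1], one loss value per sample.
hard_mask = loss.detach() > loss.detach().mean()   # boolean mask, same shape as loss
masked_loss = loss[hard_mask].mean()               # scalar mean over the "hard" subset only
self.scaler.scale(masked_loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()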

But I noticed that the backward() call still takes the same amount of time as doing it for the whole loss. What am I missing?

The time is likely dominated by the number of sequential calculation steps involved; calculations across the batch dimension, by contrast, run in parallel.

Let me explain it this way. Suppose you’re using an H100 with 18,432 CUDA cores and your batch size is 100. Each of those cores behaves like an independent worker, capable of making quick work of tensor operations, handling the multiplication of hundreds of elements of a matmul per batch sample.

Decreasing the batch size won’t necessarily decrease the time it takes to perform a given set of calculations that must be done in order, such as a backward pass.
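To make that concrete, here is a self-contained sketch (a small made-up MLP, not your model) that times one forward+backward at a few batch sizes; on a GPU the measured time typically shrinks far less than linearly as the batch gets smaller:

import time
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 1)).to(device)

for batch_size in (4096, 1024, 256):
    x = torch.randn(batch_size, 256, device=device)
    target = torch.randn(batch_size, 1, device=device)

    # Warm-up so kernel launches and allocations don't skew the measurement.
    F.mse_loss(model(x), target).backward()
    model.zero_grad()

    if device == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    F.mse_loss(model(x), target).backward()
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"batch {batch_size}: {time.time() - start:.5f} s")
    model.zero_grad()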

But it would be interesting to see whether training a model on only the outliers has any benefit for accuracy. My initial inclination is that it might result in overfitting.

@J_Johnson Thanks for the answer. That makes sense. Actually, this Online Hard Example Mining (OHEM) paper seems to suggest that focusing on the hard examples might lead to better test accuracy. They leave the reasons to be investigated, and that’s what I’m trying to do.

I actually found that in my particular problem the cost of the forward pass is negligible compared to the backward calculations. So I devised the following plan: run a massive number of samples through the forward pass, then choose only a subset of them for a second forward pass and the backward pass:

            start_time_forward = time.time()
            if self.opt.OHEM: # Online Hard Example Mining, data has 2*K samples
                # Massive forward pass (with 2*K samples), do not calculate gradients
                with torch.no_grad():
                    start_time_forward1 = time.time()
                    _,_,loss = self.train_step(data) # _,_, [2*K,1]
                    end_time_forward1 = time.time() # takes ~ 0.0025, negligible
                    # choose only K hard samples
                    start_time_masking = time.time()
                    data = self.mask_data(data, torch.topk(loss.detach(), K, sorted=False)[1])
                    end_time_masking = time.time() # takes ~ 5e-5, negligible
                # Mini forward pass (with K samples), calculate the gradients
                start_time_forward2 = time.time()
                preds, truths, loss = self.train_step(data) # data has now only K samples, so SHOULD TAKE THE SAME TIME AS REGULAR BELOW!!!
                end_time_forward2 = time.time() # takes ~ 0.06
                end_time_forward = time.time() # takes ~ 0.06, so the mini forward pass TAKES UP ALL THE TIME!!! TODO: solve why?
            else: # Regular forward pass with all samples, data has K samples
                preds, truths, loss = self.train_step(data)
                end_time_forward = time.time() # takes ~ 0.002, negligible
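Here mask_data is just a small helper of mine that keeps the selected batch entries; roughly it does something like the following sketch (assuming data is a dict of per-sample tensors indexed along the batch dimension):

def mask_data(self, data, indices):
    # Keep only the selected batch entries; assumes every value in `data`
    # is a tensor whose first dimension is the batch dimension.
    return {key: value[indices] for key, value in data.items()}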

But as you can see, something weird is happening. The second forward pass in the OHEM case takes 0.06 seconds, whereas the regular forward pass takes 0.002. They should take the same time, since both run on K samples. Am I doing some step wrong there?

CUDA operations are executed asynchronously so you would need to synchronize your code before starting or stopping your host timers.
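For example, a minimal sketch of both approaches (host timer with explicit synchronization, and CUDA events), using a dummy matmul as the workload:

import time
import torch

x = torch.randn(4096, 4096, device="cuda")

# Host timer: synchronize before starting and before stopping the timer.
torch.cuda.synchronize()              # wait for previously queued kernels
start = time.time()
y = x @ x                             # the work you want to time
torch.cuda.synchronize()              # wait for the matmul to actually finish
print("host timer: ", time.time() - start)

# Alternative: CUDA events measure the elapsed time on the GPU itself.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
y = x @ x
end_event.record()
torch.cuda.synchronize()              # make sure both events were recorded
print("cuda events:", start_event.elapsed_time(end_event) / 1e3, "s")  # elapsed_time returns ms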

I suppose it might work if the outliers are in fact accurately labeled and you maintain a loss cutoff that lets a substantial number of the samples through. But most large public datasets contain mislabeled data. Something to watch out for.

This looks like a decent tutorial on code synchronization in PyTorch, as @ptrblck raised:

Thanks @J_Johnson and @ptrblck for advising me to use synchronization. It revealed that the running times I reported in my second post were wrong, and that the backward pass indeed takes about 2x the time of the forward pass in my case as well (the standard rule of thumb).

Going back to the original question, I see that I can decrease the time the backward pass takes if I run the forward pass twice (first with all samples, then with only the hard samples):

            torch.cuda.synchronize()
            start_time_forward = time.time()
            if self.opt.OHEM: # Online Hard Example Mining
                # 1st forward pass, do not calculate gradients
                with torch.no_grad():
                    _,_,loss = self.train_step(data) # _,_, [K,1]
                    torch.cuda.synchronize()
                    # choose only K hard samples
                    start_time_masking = time.time()
                    hard_fraction = 0.25 # reduction in computation time should be roughly: hard_fraction * 2/3
                    data = self.mask_data(data, torch.topk(loss.detach(), int(self.opt.num_rays * hard_fraction), sorted=False)[1])
                    end_time_masking = time.time() # takes ~ 5e-5, negligible
                # 2nd forward pass, calculate the gradients. Data now only has a <<hard_fraction>> of the samples
                preds, truths, loss = self.train_step(data)
                torch.cuda.synchronize()
                end_time_forward = time.time() # takes ~ 0.06
            else: # Regular forward pass with all samples, data has K samples
                preds, truths, loss = self.train_step(data)
                torch.cuda.synchronize()
                end_time_forward = time.time() # takes ~ 0.03

            loss = loss.mean()
            start_time_backward = time.time()
            # Backward pass
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            torch.cuda.synchronize()
            end_time_backward = time.time() # takes ~ 0.02 with OHEM, and ~ 0.06 otherwise.

BUT I notice that if I do it this simpler way (with only one forward pass, selecting only a subset of the loss for the backward pass):

            start_time_forward = time.time()
            preds, truths, loss = self.train_step(data) # data has K samples
            torch.cuda.synchronize()
            end_time_forward = time.time() # takes ~ 0.03 with K samples
            # OHEM: pick only the hard samples for the costly backward pass (backward pass cost ~ 2 x forward pass cost)
            if self.opt.OHEM:
                hard_fraction = 0.25 # reduction in computation time SHOULD be roughly: hard_fraction * 2/3
                loss = loss[torch.topk(loss.detach(), int(self.opt.num_rays * hard_fraction), sorted=False)[1]] # a subset loss

            loss = loss.mean() # turn loss to scalar
            start_time_backward = time.time()
            # Backward pass
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            torch.cuda.synchronize()
            end_time_backward = time.time() # still takes ~ 0.06 even with the OHEM.

The time it takes to calculate the backward pass stays roughly the same as if I didn’t do the OHEM selection of hard samples at all. This makes me think that it is still doing the gradient calculations for all the samples. I would guess it should be possible to gain performance here, since the two-forward-pass version above does reduce the time the backward pass takes.

Hi Juuso!

This is correct. Pytorch operations, including those in the backward pass,
operate, for efficiency reasons, on entire tensors (even if they only use or
modify a subset of elements).

Let me reword your code slightly:

loss_subset = loss[torch.topk(loss.detach(), int(self.opt.num_rays * hard_fraction), sorted=False)[1]] # a subset loss

loss_scalar = loss_subset.mean() # turn loss to scalar

(This is just to make clear that in your code, loss is a python name that refers
to different tensors at different points in the code.)

When you call loss_scalar.backward(), you backpropagate through the scalar
back to loss_subset – a tensor that does consist of only a subset of batch items.
This very trivial backpropagation is ever so slightly cheaper because loss_subset
is a smaller tensor (but any benefit is so small as to be irrelevant).

Autograd then backpropagates from loss_subset back to loss. (This is also a
trivial backpropagation.) However, loss now contains all of the batch elements,
even though the gradients being backpropagated for many of those batch elements
are zero.

All of the rest of the backpropagation is then carried out with all of the batch
elements (because pytorch is processing entire tensors). Even where the gradient
being backpropagated has zeros for the masked-out rows, pytorch still performs
the entire tensor operation, multiplying (or whatever) by those zeros. So there is
no savings in time (nor in memory).
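Here is a tiny sketch (made-up shapes, not your model) that illustrates this: the rows you didn’t select simply receive zero gradient, but autograd still runs the backward operations over the full-size tensors:

import torch

torch.manual_seed(0)
x = torch.randn(8, 4, requires_grad=True)     # eight "batch elements"
w = torch.randn(4, 1, requires_grad=True)

per_sample_loss = (x @ w).squeeze(1) ** 2     # shape [8], one loss per batch element
subset = torch.topk(per_sample_loss.detach(), 3).indices
per_sample_loss[subset].mean().backward()

print(x.grad)   # rows not in `subset` are exactly zero, but the backward
                # matmul was still performed over the full [8, 4] tensor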

Because of how pytorch operates on entire tensors, there is no way to get your
hoped-for performance gain if you only perform the single forward pass.

Your scheme of performing two forward passes is the way to achieve your
performance gain. You need to perform the first forward pass in order to
compute the per-batch-element losses that you use to perform your “hard
mining.” But if you want to backpropagate for only the “hard” subset of your
batch elements, you also need to perform the second forward pass to
construct the computation graph that contains the smaller hard-subset
intermediate tensors.

If hard_fraction is close to zero, you will likely get significant savings, while
if hard_fraction is close to one, the cost of the second forward pass will
likely exceed any savings from the subset backward pass.
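As a rough back-of-the-envelope check (assuming, as in your measurements, that a forward pass costs F and a backward pass about 2F, and that cost scales linearly with the number of batch elements, which, as J noted, it often doesn’t on a GPU): a regular step costs about 3F, while the two-pass scheme costs about F for the no-grad forward plus hard_fraction * 3F for the subset forward and backward. The savings are therefore roughly 2/3 - hard_fraction of a regular step, so the scheme breaks even around hard_fraction of about 2/3 and pays off more the smaller hard_fraction is.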

As an aside, wrapping your first forward pass in a with torch.no_grad():
block will get you some additional savings, mostly in memory, because doing
so avoids constructing the full, non-subset computation graph that you won’t
be using anyway. (These memory savings are likely to be significant.)
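For example, a quick sketch (toy model, made-up sizes) to see the effect, using pytorch’s peak-memory statistics:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 4096), nn.ReLU(), nn.Linear(4096, 1)).cuda()
x = torch.randn(8192, 512, device="cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    model(x)                          # no graph; activations are not kept for backward
print("no_grad peak (MB):     ", torch.cuda.max_memory_allocated() / 1e6)

torch.cuda.reset_peak_memory_stats()
out = model(x)                        # full graph; intermediate activations are saved
print("grad-enabled peak (MB):", torch.cuda.max_memory_allocated() / 1e6)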

Last, let me second J’s comment that your hard mining may or may not improve
training speed and / or final performance. Trying it both ways, of course, is the
best way to find out.

Best.

K. Frank

Thank you @KFrank. Alright, so the (only) way to do it in PyTorch is with two forward passes. Is this specific to PyTorch and the way it does backpropagation through the graph? Do you know if other frameworks, like TensorFlow or JAX, would differ in this regard?