As a general rule, computations on tensors in PyTorch run faster if
you use built-in tensor operations rather than looping over indices.

Try running torch.flatten() on the last two dimensions of your tensor
(e.g. torch.flatten(x, start_dim=2)) to get a tensor of shape
(B x C x L), where L = W*H.
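Here is a minimal sketch of the flatten step. The input x and its sizes are made up for illustration; I'm assuming your tensor is laid out as B x C x W x H.

```python
import torch

# Hypothetical input: batch of 4, 3 channels, 8 x 5 spatial grid (B x C x W x H).
x = torch.randn(4, 3, 8, 5)

# Merge the last two dimensions: (B, C, W, H) -> (B, C, W*H).
flat = torch.flatten(x, start_dim=2)
print(flat.shape)  # torch.Size([4, 3, 40])
```

After flattening, spatial position (w, h) ends up at flat index w * H + h.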

Then do your computations with things like torch.max(), using its dim
argument to specify that you want to take the max over your tensor’s
last dimension (W*H). When dim is given, torch.max() returns both the
max values and their indices.
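As a rough sketch, assuming you want a per-channel max over the spatial positions (I'm guessing at your actual computation, and the tensor here is hypothetical), the vectorized version and the loop it replaces look like this:

```python
import torch

x = torch.randn(4, 3, 8, 5)      # hypothetical B x C x W x H tensor
flat = x.flatten(start_dim=2)    # B x C x (W*H)

# Max over the last dimension; with dim given, torch.max returns (values, indices).
values, indices = torch.max(flat, dim=-1)  # each has shape B x C

# Recover (w, h) from the flat index, since flat index = w * H + h.
H = x.shape[-1]
w_idx = torch.div(indices, H, rounding_mode="floor")
h_idx = indices % H

# The explicit Python loop that the single torch.max call replaces.
loop_max = torch.empty(4, 3)
for b in range(4):
    for c in range(3):
        loop_max[b, c] = x[b, c].max()

assert torch.equal(values, loop_max)
```

The same pattern works for other reductions that take a dim argument, such as torch.sum() or torch.mean().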

Give it a try, and if you have issues, post your code (working or
not) together with any errors.

(If you do get this working, it might be nice to post a follow-up with
before-and-after timings so we can see how much it helped.)
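If it helps, here's one simple way to get before-and-after numbers. It's only a sketch: the loop_max and vectorized_max functions and the input size are hypothetical stand-ins for your own code, and it times on CPU with time.perf_counter(). On a GPU you would need to call torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously.

```python
import time
import torch

def loop_max(x):
    # Baseline: explicit Python loops over batch and channel indices.
    B, C = x.shape[:2]
    out = torch.empty(B, C, dtype=x.dtype)
    for b in range(B):
        for c in range(C):
            out[b, c] = x[b, c].max()
    return out

def vectorized_max(x):
    # Flatten W x H into one dimension and take the max over it in one call.
    return torch.max(x.flatten(start_dim=2), dim=-1).values

def time_it(fn, x, repeats=10):
    # Average wall-clock time per call over a few repeats.
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn(x)
    return result, (time.perf_counter() - start) / repeats

x = torch.randn(32, 16, 28, 28)  # hypothetical input size
slow, t_loop = time_it(loop_max, x)
fast, t_vec = time_it(vectorized_max, x)
assert torch.equal(slow, fast)
print(f"loop: {t_loop * 1e3:.2f} ms, vectorized: {t_vec * 1e3:.2f} ms")
```

For more careful measurements, torch.utils.benchmark.Timer handles warm-up and CUDA synchronization for you.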