F.cross_entropy unexpectedly slower than F.log_softmax + torch.gather

I’m working on some training code that computes the total log probabilities of prediction sequences (i.e. outputs from a language model).

I had previously implemented this using F.log_softmax followed by torch.gather, but realized it can also be done with F.cross_entropy with reduction='none', followed by a sum over the last dimension. I rewrote my code to use F.cross_entropy, thinking it would be faster since it’s a single builtin function instead of two, but to my surprise training is now significantly slower.
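
As a quick sanity check on the equivalence (a minimal sketch I added for illustration, not part of my training code):

import torch
import torch.nn.functional as F

x = torch.randn(4, 7)          # (num_tokens, vocab)
y = torch.randint(0, 7, (4,))  # target token ids

# per-token negative log-likelihood via log_softmax + gather
a = -F.log_softmax(x, dim=-1).gather(1, y.unsqueeze(1)).squeeze(1)
# the same quantity via cross_entropy with reduction='none'
b = F.cross_entropy(x, y, reduction="none")

torch.testing.assert_close(a, b)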

Here’s a hacky microbenchmark script capturing what I’m talking about:

import torch
import torch.nn.functional as F
import time

torch.manual_seed(42)

batch_size, seq_len, vocab_size = 2, 512, 50257
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[labels % 3 == 0] = -100

labels_safe = labels.detach()
# replace ignore entries (-100) with a valid index so gather doesn't go out of range
labels_safe[labels_safe == -100] = 0

logits = logits.to("cuda")
labels = labels.to("cuda")
labels_safe = labels_safe.to("cuda")


def benchmark(func, num_iters=20):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        value = func()
        torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / num_iters, value


def original_method():
    token_logprobs = F.log_softmax(logits, dim=-1)
    gathered_logprobs = torch.gather(
        token_logprobs, 2, labels_safe.unsqueeze(-1)
    ).squeeze(-1)
    mask = (labels_safe != -100).float()
    loss = (gathered_logprobs * mask).sum(dim=-1)
    return -loss


def new_method():
    token_loss = F.cross_entropy(
        logits.permute(0, 2, 1),
        labels,
        reduction="none",
    )
    loss = token_loss.sum(dim=-1)
    return loss


num_iters = 20
original_time, original_value = benchmark(original_method, num_iters)
new_time, new_value = benchmark(new_method, num_iters)

print(f"Original Average Time: {original_time * 1e3:.3f} ms")
print(f"Original Value: {original_value.tolist()}")
print(f"New Average Time: {new_time * 1e3:.3f} ms")
print(f"New Value: {new_value.tolist()}")

Outputs on 1x A100 GPU:

Original Average Time: 2.653 ms
Original Value: [5802.1328125, 5825.033203125]
New Average Time: 76.821 ms
New Value: [5802.1328125, 5825.0341796875]

I thought it might be due to reduction='none', but commenting it out doesn’t lead to any noticeable speedup.

Does anyone know what might be happening here?

Profiling your code shows that the main difference comes from the softmax implementation.
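
The two tables below (first the original method, then the cross-entropy version) can be reproduced with something along these lines; this is just a sketch, and the exact profiler arguments are my guess rather than the actual session:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(num_iters):      # num_iters = 20 in the script above
        original_method()           # swap in new_method() for the second table
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))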

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::log_softmax         0.16%      95.089us        19.85%      12.051ms     602.532us       0.000us         0.00%      11.780ms     588.995us            20  
                                     aten::_log_softmax         0.47%     286.834us        19.70%      11.956ms     597.777us      11.780ms        95.37%      11.780ms     588.995us            20  
void at::native::(anonymous namespace)::cunn_SoftMax...         0.00%       0.000us         0.00%       0.000us       0.000us      11.780ms        95.37%      11.780ms     588.995us            20  
                                           aten::gather         0.65%     392.715us        21.77%      13.214ms     660.682us     208.353us         1.69%     208.353us      10.418us            20  
void at::native::_scatter_gather_elementwise_kernel<...         0.00%       0.000us         0.00%       0.000us       0.000us     208.353us         1.69%     208.353us      10.418us            20  
                                              aten::sum         0.61%     372.280us         8.25%       5.007ms     250.331us     108.608us         0.88%     108.608us       5.430us            20  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us     108.608us         0.88%     108.608us       5.430us            20  
                                               aten::to         0.06%      38.452us         8.64%       5.244ms     262.199us       0.000us         0.00%      95.902us       4.795us            20  
                                         aten::_to_copy         0.21%     125.527us         8.58%       5.206ms     260.277us       0.000us         0.00%      95.902us       4.795us            20  
                                            aten::copy_         0.26%     157.669us         8.16%       4.954ms     247.697us      95.902us         0.78%      95.902us       4.795us            20  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 60.702ms
Self CUDA time total: 12.352ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                               aten::cross_entropy_loss         0.01%     127.490us         6.00%      94.683ms       4.734ms       0.000us         0.00%        1.575s      78.755ms            20  
                                      aten::log_softmax         0.00%      78.236us         0.12%       1.895ms      94.743us       0.000us         0.00%        1.575s      78.749ms            20  
                                     aten::_log_softmax         0.03%     421.520us         0.11%       1.811ms      90.529us        1.549s        98.32%        1.575s      78.749ms            20  
void at::native::(anonymous namespace)::cunn_Spatial...         0.00%       0.000us         0.00%       0.000us       0.000us        1.549s        98.32%        1.549s      77.434ms            20  
                                       aten::contiguous         0.00%      57.140us         0.08%       1.224ms      61.219us       0.000us         0.00%      26.312ms       1.316ms            20  
                                            aten::clone         0.01%      87.524us         0.07%       1.167ms      58.362us       0.000us         0.00%      26.312ms       1.316ms            20  
                                            aten::copy_         0.01%     220.568us         0.03%     495.111us      24.756us      26.312ms         1.67%      26.312ms       1.316ms            20  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      26.312ms         1.67%      26.312ms       1.316ms            20  
                                      aten::nll_loss_nd         0.01%     170.559us         5.88%      92.660ms       4.633ms       0.000us         0.00%     116.958us       5.848us            20  
                                       aten::nll_loss2d         0.00%      58.176us         5.86%      92.382ms       4.619ms       0.000us         0.00%     116.958us       5.848us            20  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  

which should dispatch to the spatial softmax implementation (the cunn_SpatialSoftMax kernel visible in the trace above) rather than the regular softmax kernel used by the first method. CC @eqy in case you’re familiar with the dispatching logic.
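
One possible workaround (a sketch, not something I have benchmarked on the exact setup above) is to flatten the batch and sequence dimensions so that F.cross_entropy sees a plain 2D (N, C) input and shouldn't hit the spatial path:

def flattened_method():
    token_loss = F.cross_entropy(
        logits.reshape(-1, vocab_size),  # (batch_size * seq_len, vocab_size)
        labels.reshape(-1),              # (batch_size * seq_len,)
        reduction="none",
    )
    # restore the per-sequence layout and sum over tokens
    return token_loss.view(batch_size, seq_len).sum(dim=-1)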
