I'm working on some training code that computes the total log probabilities of prediction sequences (i.e. outputs from a language model). I had previously implemented this using F.log_softmax followed by torch.gather, but realized it can also be done with F.cross_entropy with reduction='none' followed by a sum over the last dimension. I rewrote my code using F.cross_entropy, thinking it would be faster since it's a single builtin call instead of two, but to my surprise, training is now significantly slower.
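(To be concrete, for a single position the two formulations compute the same quantity; a minimal check of the identity, using made-up shapes and an arbitrary target id, looks something like this:)

import torch
import torch.nn.functional as F

x = torch.randn(1, 50257)                         # logits for a single position
y = torch.tensor([123])                           # arbitrary target id
via_gather = -F.log_softmax(x, dim=-1)[0, y]      # log_softmax + gather/index
via_ce = F.cross_entropy(x, y, reduction="none")  # builtin cross-entropy
assert torch.allclose(via_gather, via_ce)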
Here’s a hacky microbenchmark script capturing what I’m talking about:
import torch
import torch.nn.functional as F
import time
torch.manual_seed(42)
batch_size, seq_len, vocab_size = 2, 512, 50257
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[labels % 3 == 0] = -100
labels_safe = labels.detach()
labels_safe[labels_safe == -100] = 0
logits = logits.to("cuda")
labels = labels.to("cuda")
labels_safe = labels_safe.to("cuda")
def benchmark(func, num_iters=20):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        value = func()
    torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / num_iters, value
def original_method():
    token_logprobs = F.log_softmax(logits, dim=-1)
    gathered_logprobs = torch.gather(
        token_logprobs, 2, labels_safe.unsqueeze(-1)
    ).squeeze(-1)
    mask = (labels_safe != -100).float()
    loss = (gathered_logprobs * mask).sum(dim=-1)
    return -loss
def new_method():
    token_loss = F.cross_entropy(
        logits.permute(0, 2, 1),
        labels,
        reduction="none",
    )
    loss = token_loss.sum(dim=-1)
    return loss
num_iters = 20
original_time, original_value = benchmark(original_method, num_iters)
new_time, new_value = benchmark(new_method, num_iters)
print(f"Original Average Time: {original_time * 1e3:.3f} ms")
print(f"Original Value: {original_value.tolist()}")
print(f"New Average Time: {new_time * 1e3:.3f} ms")
print(f"New Value: {new_value.tolist()}")
Outputs on 1x A100 GPU:
Original Average Time: 2.653 ms
Original Value: [5802.1328125, 5825.033203125]
New Average Time: 76.821 ms
New Value: [5802.1328125, 5825.0341796875]
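(The small difference in the second value is just float32 accumulation noise; on the tensors returned by the script above, a check along these lines passes with torch.allclose's default tolerances:)

# Both methods produce the same per-sequence values up to float32 rounding.
assert torch.allclose(original_value, new_value)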
I thought it might be due to reduction='none', but commenting it out doesn't lead to any noticeable speedup.
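What I mean by commenting it out is roughly the following (reusing the logits and labels from the script above), which falls back to the default 'mean' reduction and drops the trailing sum since the result is already a scalar:

def new_method_default_reduction():
    return F.cross_entropy(
        logits.permute(0, 2, 1),  # same (batch, vocab, seq) layout as above
        labels,                   # default ignore_index=-100 still applies
    )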
Does anyone know what might be happening here?