I'm trying to understand how gradients are calculated when permuting / unpermuting tensors.
Here is a minimal repro:
```python
import torch


def permute(X, gather_indices, topk):
    return X[gather_indices // topk]


def unpermute(X_permuted, gather_indices):
    X = torch.empty_like(X_permuted)
    X.index_copy_(0, gather_indices, X_permuted)
    return X


def check_permute(X, gather_indices, topk):
    X_permuted = permute(X, gather_indices, topk)
    X_unpermuted = unpermute(X_permuted, gather_indices)
    for i in range(topk):
        x = X_unpermuted[i::topk]
        assert x.shape == X.shape
        assert torch.equal(x, X)


def calculate_permute_grad(grad_output, gather_indices, seq_len, topk):
    assert grad_output.shape[0] == seq_len * topk
    dX = unpermute(grad_output, gather_indices)
    dX = dX.reshape(seq_len, topk, -1)
    dX = dX.sum(dim=1)
    return dX


if __name__ == "__main__":
    num_experts = 4
    seq_len = 32
    hidden_size = 64
    device = "cuda"
    torch.manual_seed(0)
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        print(f"dtype: {dtype}")
        for topk in [1, 4, 8]:
            X = torch.randn(seq_len, hidden_size, dtype=dtype, device=device, requires_grad=True)
            X.retain_grad()
            selected_experts = torch.randint(0, num_experts, (seq_len, topk), device=device).view(-1)
            gather_indices = selected_experts.argsort(dim=0)
            assert gather_indices.shape[0] == seq_len * topk
            # Check permute
            check_permute(X, gather_indices, topk)
            X_permuted = permute(X, gather_indices, topk)
            grad_out = torch.randn_like(X_permuted)
            X_permuted.backward(grad_out)
            assert X.grad is not None
            # Calculate grad manually
            dX = calculate_permute_grad(grad_out, gather_indices, seq_len, topk)
            assert dX.shape == X.grad.shape
            diff = (X.grad - dX).abs().max()
            print(f"topk: {topk}, diff: {diff.item():.6f}")
```
Here is the output for each dtype:
```
dtype: torch.float32
topk: 1, diff: 0.000000
topk: 4, diff: 0.000000
topk: 8, diff: 0.000001
dtype: torch.float16
topk: 1, diff: 0.000000
topk: 4, diff: 0.003906
topk: 8, diff: 0.007812
dtype: torch.bfloat16
topk: 1, diff: 0.000000
topk: 4, diff: 0.062500
topk: 8, diff: 0.062500
```
The tiny diffs for float32 lead me to think that my manual gradient calculation is mathematically correct, but the larger diffs for float16/bfloat16 imply that the way I'm calculating the gradient is not numerically equivalent to the autograd implementation. (The half-precision diffs also look like they are roughly one ULP at these value magnitudes, i.e. rounding rather than a logic error.)
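For reference, here is an alternative scatter-add formulation of the same reduction, which is mathematically identical to my reshape-then-sum version but accumulates in a different order. My understanding (possibly wrong) is that autograd's backward for the advanced indexing in `permute()` is effectively such a scatter-add into a zero tensor, and on CUDA it uses atomic adds, so its accumulation order differs from mine. The function name is just mine for illustration:

```python
def calculate_permute_grad_scatter(grad_output, gather_indices, seq_len, topk):
    # dX[i] accumulates grad_output[j] for every permuted row j that was
    # gathered from source row i, i.e. gather_indices[j] // topk == i.
    dX = torch.zeros(seq_len, grad_output.shape[-1],
                     dtype=grad_output.dtype, device=grad_output.device)
    dX.index_add_(0, gather_indices // topk, grad_output)
    return dX
```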
The reason I need to calculate the permutation grad manually is for a custom autograd.Function.
Any recommendations for how to better align the gradients when using lower-precision data types (bfloat16)?
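One idea I've been considering (not sure whether it's the standard approach) is doing the backward reduction in float32 and casting back to the input dtype, roughly like this sketch, which reuses `unpermute` from the repro above:

```python
def calculate_permute_grad_fp32(grad_output, gather_indices, seq_len, topk):
    # Accumulate in float32 to reduce rounding error, then cast back to the
    # original dtype. This changes where rounding happens, so it may still
    # not match autograd bit-for-bit.
    grad_fp32 = grad_output.float()
    dX = unpermute(grad_fp32, gather_indices)
    dX = dX.reshape(seq_len, topk, -1).sum(dim=1)
    return dX.to(grad_output.dtype)
```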
Thanks!