Permute / Unpermute Gradient

I'm trying to understand how gradients are calculated when permuting / unpermuting tensors (as in expert token routing).

Here is a minimal repro:

import torch


def permute(X, gather_indices, topk):
    # Gather rows of X into expert-sorted order; each row of X appears topk times.
    return X[gather_indices // topk]


def unpermute(X_permuted, gather_indices):
    # Scatter the permuted rows back to their original (pre-sort) positions.
    X = torch.empty_like(X_permuted)
    X.index_copy_(0, gather_indices, X_permuted)
    return X


def check_permute(X, gather_indices, topk):
    # After a permute -> unpermute round trip, every topk-strided slice should equal X.
    X_permuted = permute(X, gather_indices, topk)
    X_unpermuted = unpermute(X_permuted, gather_indices)
    for i in range(topk):
        x = X_unpermuted[i::topk]
        assert x.shape == X.shape
        assert torch.equal(x, X)


def calculate_permute_grad(grad_output, gather_indices, seq_len, topk):
    # Route the incoming gradient back to the original order, then sum the
    # topk copies that belong to each token.
    assert grad_output.shape[0] == seq_len * topk
    dX = unpermute(grad_output, gather_indices)
    dX = dX.reshape(seq_len, topk, -1)
    dX = dX.sum(dim=1)
    return dX


if __name__ == "__main__":
    num_experts = 4
    seq_len = 32
    hidden_size = 64
    device = "cuda"
    torch.manual_seed(0)
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        print(f"dtype: {dtype}")
        for topk in [1, 4, 8]:
            X = torch.randn(seq_len, hidden_size, dtype=dtype, device=device, requires_grad=True)
            X.retain_grad()

            # Simulate top-k routing: each token gets topk expert assignments,
            # then the flattened copies are sorted by expert id.
            selected_experts = torch.randint(0, num_experts, (seq_len, topk), device=device).view(-1)
            gather_indices = selected_experts.argsort(dim=0)
            assert gather_indices.shape[0] == seq_len * topk
            
            # Check permute
            check_permute(X, gather_indices, topk)

            X_permuted = permute(X, gather_indices, topk)
            grad_out = torch.randn_like(X_permuted)
            X_permuted.backward(grad_out)
            assert X.grad is not None
        
            # Calculate grad manually
            dX = calculate_permute_grad(grad_out, gather_indices, seq_len, topk)
            assert dX.shape == X.grad.shape
            diff = (X.grad - dX).abs().max()
            print(f"topk: {topk}, diff: {diff.item():.6f}")

Here are the outputs for each dtype:

dtype: torch.float32
topk: 1, diff: 0.000000
topk: 4, diff: 0.000000
topk: 8, diff: 0.000001

dtype: torch.float16
topk: 1, diff: 0.000000
topk: 4, diff: 0.003906
topk: 8, diff: 0.007812

dtype: torch.bfloat16
topk: 1, diff: 0.000000
topk: 4, diff: 0.062500
topk: 8, diff: 0.062500

The small diffs for float32 lead me to think that my manual gradient calculation is mathematically correct, but the larger diffs for float16 / bfloat16 suggest that the way I'm calculating the gradient is not numerically equivalent to the autograd implementation.
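
For reference, my working assumption is that autograd implements the backward of the advanced indexing in permute as a scatter-add (index_add) of grad_out into a zero tensor at the gathered row indices, which accumulates the topk contributions per row in a different order than my reshape-then-sum. A rough sketch of that assumption (calculate_permute_grad_scatter is just an illustrative name, not an existing API):

def calculate_permute_grad_scatter(grad_output, gather_indices, seq_len, topk):
    # Assumed autograd behaviour: accumulate each grad_output row back into the
    # source row it was gathered from, in gather (expert-sorted) order.
    dX = torch.zeros(seq_len, grad_output.shape[1], dtype=grad_output.dtype, device=grad_output.device)
    dX.index_add_(0, gather_indices // topk, grad_output)
    return dX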

I need to calculate the permutation gradient manually because I'm implementing it inside a custom autograd.Function.

Any recommendations for how to better align the manual gradient with autograd when using lower-precision dtypes (float16 / bfloat16)?
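
For context, one workaround I have been considering (I'm not sure it is the recommended approach) is to do the unpermute + sum in float32 and cast the result back to the gradient's dtype, roughly:

def calculate_permute_grad_fp32(grad_output, gather_indices, seq_len, topk):
    # Upcast to float32 for the accumulation, then cast back to the original dtype.
    # This trades extra memory/bandwidth for accumulation accuracy.
    dX = unpermute(grad_output.float(), gather_indices)
    dX = dX.reshape(seq_len, topk, -1).sum(dim=1)
    return dX.to(grad_output.dtype)

But I would prefer to understand whether matching autograd's accumulation order (e.g. a scatter-add in the original dtype) is the better option.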

Thanks!