Having troubles to vmap a custom function with permute

Hello, I want to vectorize the function foo

import torch
import torch.nn.functional as F

def foo(logprobs, zero_mask, action_space_size):
    ndim = zero_mask.shape[0]
    zero_mask = zero_mask.flip(0)
    reshape_sizes = torch.tensor([action_space_size], dtype=torch.int64).repeat_interleave(ndim)
    logprobs = logprobs.reshape(*reshape_sizes.tolist())

    sorted_indices = torch.argsort(zero_mask)
    logprobs = logprobs.permute(*sorted_indices)

    logprobs = logprobs.reshape(-1, ndim ** zero_mask.sum())
    logprobs = torch.logsumexp(logprobs, dim=-1, keepdim=True)
    logprobs = F.pad(logprobs, [0, ndim ** zero_mask.sum() - 1], value=-1e7)

    logprobs = logprobs.reshape(*reshape_sizes.tolist())
    invert_indices = torch.argsort(sorted_indices)
    logprobs = logprobs.permute(*invert_indices)
    logprobs = logprobs.reshape(-1)

    return logprobs

def bar(logprobs, zero_mask):
    # use vmap
    batched_foo = torch.func.vmap(foo, in_dims=(0, 0, None))
    return batched_foo(logprobs, zero_mask, 3)

logits = torch.ones((4, 1, 27), dtype=torch.float32)
logprob = F.log_softmax(logits, dim=-1)
zero_mask = torch.tensor([[False, True, True], [True, False, True], [False, False, False], [True, True, True]])

logprob = bar(logprob, zero_mask)

The permute in this function is really tricky. After I use vmap, it throws

TypeError: permute(): argument 'dims' (position 1) must be tuple of ints, but found element of type Tensor at pos 0.

If I change the code into

logprobs = logprobs.permute(*sorted_indices.tolist())

The error msg became

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

I have no idea how to solve it.