Hello, I want to vectorize the function foo below using torch.func.vmap:
import torch
import torch.nn.functional as F

def foo(logprobs, zero_mask, action_space_size):
    ndim = zero_mask.shape[0]
    zero_mask = zero_mask.flip(0)
    # one axis of size action_space_size per action dimension
    reshape_sizes = torch.tensor([action_space_size], dtype=torch.int64).repeat_interleave(ndim)
    logprobs = logprobs.reshape(*reshape_sizes.tolist())
    # argsort puts the False (unmasked) dims first and the True (masked) dims last
    sorted_indices = torch.argsort(zero_mask)
    logprobs = logprobs.permute(*sorted_indices)
    # collapse the trailing masked dims into one axis (ndim == action_space_size == 3 here)
    logprobs = logprobs.reshape(-1, ndim ** zero_mask.sum())
    # marginalize over the masked dims
    logprobs = torch.logsumexp(logprobs, dim=-1, keepdim=True)
    # pad the collapsed axis back to its original width with a very negative value
    logprobs = F.pad(logprobs, [0, ndim ** zero_mask.sum() - 1], value=-1e7)
    logprobs = logprobs.reshape(*reshape_sizes.tolist())
    # undo the permutation and flatten again
    invert_indices = torch.argsort(sorted_indices)
    logprobs = logprobs.permute(*invert_indices)
    logprobs = logprobs.reshape(-1)
    return logprobs

def bar(logprobs, zero_mask):
    # use vmap to apply foo to each row of logprobs / zero_mask
    batched_foo = torch.func.vmap(foo, in_dims=(0, 0, None))
    return batched_foo(logprobs, zero_mask, 3)

logits = torch.ones((4, 1, 27), dtype=torch.float32)
logprob = F.log_softmax(logits, dim=-1)
zero_mask = torch.tensor([[False, True, True], [True, False, True], [False, False, False], [True, True, True]])
logprob = bar(logprob, zero_mask)
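For reference, the per-sample loop below is what I am effectively trying to replace with vmap; it would be called in place of bar above (bar_loop is just an illustrative name, assuming foo behaves as intended on a single sample):

def bar_loop(logprobs, zero_mask, action_space_size=3):
    # apply foo to each batch element separately and stack the results
    return torch.stack([foo(lp, zm, action_space_size)
                        for lp, zm in zip(logprobs, zero_mask)])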
The permute call inside foo is the tricky part. After I wrap foo with vmap (the bar function above), it throws:

TypeError: permute(): argument 'dims' (position 1) must be tuple of ints, but found element of type Tensor at pos 0.
If I change that line to

logprobs = logprobs.permute(*sorted_indices.tolist())

the error message becomes:

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
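As far as I can tell, the problem is that sorted_indices is computed from zero_mask, so inside vmap it is a batched tensor, while permute wants plain Python ints for its dims; converting it with .tolist() then needs data that the batched tensor doesn't own. A minimal snippet that I believe isolates the same issue (tiny and the shapes here are made up for illustration):

def tiny(x, mask):
    order = torch.argsort(mask)        # data-dependent permutation of the dims
    return x.permute(*order.tolist())  # order lives inside vmap, so .tolist() needs its data

x = torch.randn(4, 2, 3)               # batch of 4 samples, each of shape (2, 3)
mask = torch.tensor([[0, 1], [1, 0], [0, 1], [1, 0]])
torch.func.vmap(tiny)(x, mask)         # I expect this to hit the same RuntimeError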
I have no idea how to solve it.