I am trying to simulate a large number of 2D grid environments on the GPU (à la Snake at a Million FPS).

This involves frequently looking up a precomputed mask tensor at each step for each environment in the batch.

The environment has shape (b, H, W) and the mask tensor has shape (b, H, W, H, W), since we store a mask of the whole board for each square of each environment in the batch. Since the masks are ‘small’, this tensor can be made to fit into memory by storing it as a sparse tensor.

So my first try was to `vmap` over a tensor of indices `[[0, i_0, j_0], [1, i_1, j_1], ...]`

using this function:

`def get(t): x = torch.select(masks, 0, t[0]); x = torch.select(x, 0, t[1]); x = torch.select(x, 0, t[2]); return x.to_dense()`

which is essentially the sparse tensor version of `masks[t[0], t[1], t[2], :, :]`
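For reference, here is a minimal sketch of what that lookup computes in the dense case, where no `vmap` is needed because plain advanced indexing already batches over the index rows. The sizes and the names `idx`, `get_dense`, `looped` and `batched` are made up for illustration; only `masks` and the index layout come from the description above.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from the original post).
b, H, W = 4, 3, 3

# Dense stand-in for the precomputed masks: one (H, W) mask per square per environment.
masks = torch.rand(b, H, W, H, W) > 0.5

# One row [env, i, j] per environment, as in [[0, i_0, j_0], [1, i_1, j_1], ...].
idx = torch.tensor([[0, 1, 2], [1, 0, 0], [2, 2, 1], [3, 1, 1]])

def get_dense(masks, t):
    # Dense version of the per-row lookup.
    return masks[t[0], t[1], t[2], :, :]

# Reference result: a Python loop over the index rows.
looped = torch.stack([get_dense(masks, t) for t in idx])

# Batched result: advanced indexing over the first three dimensions at once.
batched = masks[idx[:, 0], idx[:, 1], idx[:, 2]]

assert batched.shape == (b, H, W)
assert torch.equal(looped, batched)
```

This only covers the dense layout; the difficulty in my case is getting the same batched behaviour out of the sparse tensor.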

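However, this doesn’t work because vmap doesn’t support the use of `.item()`, which, as far as I can tell, is what happens implicitly when a tensor element such as `t[0]` is passed to `torch.select` as the integer index.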
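One possible way to sidestep `vmap` entirely is to filter the sparse tensor's COO entries directly. The sketch below is only an illustration under several assumptions: the masks are stored as a coalesced sparse COO tensor with all five dimensions sparse, there is exactly one `(i, j)` query per environment, and the helper name `lookup` and the example sizes are hypothetical.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from the original post).
b, H, W = 4, 3, 3

# Build a sparse COO stand-in for the masks from a mostly-zero dense tensor.
dense = (torch.rand(b, H, W, H, W) > 0.8).float()
masks = dense.to_sparse().coalesce()  # indices() has shape (5, nnz)

# One query square (i, j) per environment.
i = torch.tensor([1, 0, 2, 1])
j = torch.tensor([2, 0, 1, 1])

def lookup(masks, i, j, H, W):
    """Return a dense (b, H, W) tensor equal to masks[k, i[k], j[k]] for each k."""
    ind = masks.indices()  # rows: env, i, j, mask row, mask column
    val = masks.values()
    env = ind[0]
    # Keep only the nonzeros whose (i, j) matches the query of their environment.
    keep = (ind[1] == i[env]) & (ind[2] == j[env])
    out = torch.zeros(i.shape[0], H, W, dtype=val.dtype, device=val.device)
    # Scatter the surviving values into the dense output.
    out[env[keep], ind[3][keep], ind[4][keep]] = val[keep]
    return out

result = lookup(masks, i, j, H, W)
expected = dense[torch.arange(b), i, j]
assert torch.equal(result, expected)
```

Note that this scans every stored nonzero on each call, so whether it is fast enough for a per-step lookup would need to be measured.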

This is disappointing, because this tweet by PyTorch shows a similar use case with batched indexing.