I am trying to simulate large numbers of 2d grid environments using GPU (a la Snake at a Million FPS).
This involves frequently looking up a precomputed mask tensor at each step for each environment in the batch.
The environment is of shape (b, H, W) and the mask tensor is of shape (b, H, W, H, W), since we store a mask of the whole board for each square of each environment in the batch. This mask tensor can be made to fit into memory by using a sparse tensor, since the masks are ‘small’.
So my first try was to perform a vmap on a tensor of indices
[[0, i_0, j_0], [1, i_1, j_1]...]
using this function:
```python
def get(t):
    x = torch.select(masks, 0, t[0])
    x = torch.select(x, 0, t[1])
    x = torch.select(x, 0, t[2])
    return x.to_dense()
```
which is essentially the sparse tensor version of
masks[t[0], t[1], t[2], :, :]
However this doesn’t work, because vmap doesn’t support this kind of indexing.
This is disappointing, because this tweet by PyTorch shows a similar use case with batched indexing.
Thanks for this, Ryan.
May I ask for a minimal reproducible example? It’ll be easier to see if we can come up with something (I’m pretty sure we can!)
- Is it just the indexing of the mask that you want to vmap?
- Also, how big are b, H and W? Do you really need a sparse tensor for this? As the values will be Boolean, it should not take much space if b, H and W are not crazy big.
In the meantime I’ll try a few things given my understanding of the problem.
This was my attempt, and verification that it’s the indexing that’s the problem:

```python
import torch
from functorch import vmap

# empty sparse mask tensor of shape (10**4, 20, 20, 20, 20)
idx = torch.zeros(size=(5, 0), dtype=torch.long)
values = torch.zeros(size=(0,))
masks = torch.sparse_coo_tensor(idx, values, (10**4, 20, 20, 20, 20))

# the function I want to vmap over rows of an index tensor
def get(t):
    x = torch.select(masks, 0, t[0])
    x = torch.select(x, 0, t[1])
    x = torch.select(x, 0, t[2])
    return x.to_dense()

# unbatched selection works fine:
x = torch.select(masks, 0, 0)
x = torch.select(x, 0, 1)
x = torch.select(x, 0, 2)

# but the batched version fails inside vmap:
i = torch.tensor([[0, 0, 0] for _ in range(10**4)])
vmap(get)(i)
```
There are other operations that must be batched, but they are more straightforward; I believe all the other rules of the environments can be implemented efficiently using convolutions and multiplications. The masks I’m retrieving are somewhat expensive to compute, so rather than recomputing them at every step I chose to memoize them. The batch size should be quite large if I’m to streamline this otherwise expensive RL algorithm (e.g. 10^5). Say H, W = 20 and b = 10^4 for starters.
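For what it’s worth, if the masks were stored densely, the batched lookup itself would need no vmap at all: plain advanced indexing with one index tensor per dimension does it. A minimal sketch with a scaled-down batch (b = 8 standing in for 10^4):

```python
import torch

b, H, W = 8, 20, 20  # scaled down; b would really be 10**4
masks = torch.zeros(b, H, W, H, W, dtype=torch.bool)

# one (i, j) lookup per environment in the batch
batch = torch.arange(b)
i_idx = torch.randint(H, (b,))
j_idx = torch.randint(W, (b,))

# advanced indexing broadcasts the three index tensors together,
# returning one (H, W) board per environment
out = masks[batch, i_idx, j_idx]  # shape (b, H, W)
```

This only helps if the dense tensor fits in memory, of course, which is the responder’s point below about sizes.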
Open an issue on their GitHub page and they might be able to fix it. Their GitHub is here.
Did you manage to come up with anything? In the meantime, I’m exploring how to circumvent this limitation, but it requires that the first n steps of my batch of states all be the same, which is not ideal.
I’m not sure what you’re trying to do can be done, but here are a few things we could explore:
in your example you’re not using Boolean tensors. A Boolean tensor of shape 10^4 x 20^4 will take approx. 1.5 GB. It’s big, but not impossible to handle. Of course, this is integrated in some more complex pipeline over which I have little grasp.
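That figure checks out with a quick back-of-envelope calculation (torch.bool uses one byte per element):

```python
# dense Boolean mask tensor of shape (b, H, W, H, W)
b, H, W = 10**4, 20, 20
n_bytes = b * H * W * H * W  # one byte per torch.bool element
print(n_bytes / 1e9)  # 1.6, i.e. roughly the 1.5 GB quoted above
```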
In TorchRL we have written a MemmapTensor class that you might be interested in. The idea is that you’re using a memory-mapped array that sits on disk and that you can access easily in a row-by-row fashion (i.e. without loading the full data in memory to read it).
To create it, just call
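The creation snippet is missing above, and rather than guess the exact MemmapTensor constructor, here is the same row-by-row, disk-backed idea sketched with numpy.memmap as a plainly-named stand-in:

```python
import os
import tempfile

import numpy as np
import torch

# scaled-down shapes; the real tensor would be (10**4, 20, 20, 20, 20)
b, H, W = 4, 5, 5
path = os.path.join(tempfile.mkdtemp(), "masks.dat")

# a Boolean array that lives on disk rather than in RAM
mm = np.memmap(path, dtype=np.bool_, mode="w+", shape=(b, H, W, H, W))

# copy a single row into memory without loading the whole file
row = torch.from_numpy(np.array(mm[0]))  # shape (H, W, H, W)
```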
Otherwise, we could perhaps work with the indices of the sparse tensor (match the indices from i with those from the sparse tensor)?
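A sketch of that index-matching idea, on a coalesced COO tensor with toy sizes (the layout is an assumption: the first three index rows are batch, i, j, and float values stand in for the Boolean masks):

```python
import torch

# toy sparse mask tensor of shape (b, H, W, H, W)
b, H, W = 4, 5, 5
idx = torch.tensor([[0, 1, 2, 3],   # batch
                    [0, 1, 2, 3],   # i
                    [0, 1, 2, 3],   # j
                    [1, 1, 1, 1],   # mask row
                    [2, 2, 2, 2]])  # mask col
vals = torch.ones(4)
masks = torch.sparse_coo_tensor(idx, vals, (b, H, W, H, W)).coalesce()

# one (batch, i, j) query per environment
queries = torch.tensor([[0, 0, 0], [2, 2, 2]])

ind = masks.indices()       # (5, nnz) stored coordinates
keys = ind[:3].T            # (nnz, 3) leading coordinates
# broadcasted comparison of every query against every stored entry
hits = (keys[None] == queries[:, None]).all(-1)  # (n_queries, nnz)
q, n = hits.nonzero(as_tuple=True)

# scatter the matching values into dense (n_queries, H, W) boards
out = torch.zeros(len(queries), H, W)
out[q, ind[3, n], ind[4, n]] = vals[n]
```

The all-pairs comparison is O(n_queries x nnz), so this only pays off while the masks stay sparse.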
I’ll look into some other solutions too…