Hi,
I have a kernel that works well on non-batched data, and I want to extend it to a batchwise format and compute gradients through it. The problem is that our batch is not homogeneous. Currently the code looks like this:

result = []
for i in range(len(stride) - 1):
    ans = kernel(x[stride[i]:stride[i + 1]])
    result.append(ans)
return torch.concat(result)

What we want to do is use vmap instead of the for loop:

vmap(kernel, in_dims=(None, 0, 0))(x, stride[:-1], stride[1:])
Apparently it’s not working. What should I do? Is there any other solution besides padding?
You might be interested in nested tensors, which are a special subclass of tensor designed to work on batches of data where the items are jagged/have varying lengths.
(Note that some parts of it are out of date; any call to torch.nested.nested_tensor should be passed a layout=torch.jagged argument. The part about implementing MHA has been updated relatively recently, though.)
NestedTensors unfortunately don’t work with vmap yet, so you’ll need to explicitly write out the batch dim in your program. (We have plans to support this soon though)
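For example, here's a rough, untested sketch of what that could look like (relu stands in for your kernel, and the sizes/strides are made up):

import torch

# Build a jagged NestedTensor from the variable-length slices of x.
# as_nested_tensor preserves autograd history, so gradients can flow back to x
# (torch.nested.nested_tensor would copy the data without autograd history).
x = torch.randn(10, 4, requires_grad=True)
stride = [0, 3, 7, 10]

pieces = [x[s:e] for s, e in zip(stride[:-1], stride[1:])]
nt = torch.nested.as_nested_tensor(pieces, layout=torch.jagged)

# Apply the per-item computation to the whole batch at once; pointwise ops
# like relu broadcast over the ragged dimension.
out = torch.nn.functional.relu(nt)

# Reduce to a scalar via the underlying values buffer and backprop.
out.values().sum().backward()
print(x.grad.shape)  # torch.Size([10, 4])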
Thank you so much! I’ve used nested tensors in my collate_fn before, but I haven’t tried applying it in this context yet. I’ll test it out and update this post once I have some results.
I have read the doc but am still slightly confused: when should I use jagged?
# Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
# dimension and works with torch.compile. A nested tensor with this layout has shape
# (B, S*, D), where B = batch size, S* = ragged sequence length (varies per batch item),
# and D = embedding dimension.
I believe for those tensors I don’t need a jagged layout. But if I have a bunch of tensors with shapes [(n,), (m,), ...], do I need jagged? Does jagged mean that the dimension in the middle is irregular?
I believe for those tensors I don’t need a jagged layout.
Yep
But if I have a bunch of tensors with shapes [(n,), (m,), ...], do I need jagged?
The jagged layout only supports a single jagged dimension, i.e. among the tensors that you have, only a single dimension can vary. So if you have a bunch of 2-D tensors where both dimensions vary, the jagged layout would not work.
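To make that concrete, a small sketch (sizes made up):

import torch

# 2-D components where only the leading dimension varies (trailing dim fixed
# at 4): this is the single-ragged-dimension case the jagged layout supports.
ok = [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)]
nt = torch.nested.nested_tensor(ok, layout=torch.jagged)  # logical shape (3, ragged, 4)

# Components where more than one dimension varies can't be expressed with a
# single ragged dimension, so a construction like this is expected to fail:
# torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(5, 7)], layout=torch.jagged)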
Does jagged mean that the dimension in the middle is irregular?
Your jagged dimension cannot be the very first one, because there needs to be a preceding batch dimension.
But you can have a 2-D nested tensor, e.g. with shape [3, [1, 2, 3]], where the varying dimension is the very last one.
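On a recent PyTorch that looks roughly like this (untested sketch):

import torch

# Three 1-D tensors of lengths 1, 2, 3 -> a 2-D nested tensor: batch dimension
# of size 3 first, ragged dimension last.
ts = [torch.randn(1), torch.randn(2), torch.randn(3)]
nt = torch.nested.nested_tensor(ts, layout=torch.jagged)
print(nt.size(0))    # 3 (batch dim)
print(nt.is_nested)  # True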