Hi everyone,
I’m trying to use torch.vmap to speed up my project. I have a custom CUDA function, and I’m trying to wrap it in vmap so it works with batches. Here is a toy example to illustrate what’s going on:
import copy
import torch.nn as nn
from torch.func import stack_module_state, functional_call, vmap

class MyModule(nn.Module):
    ....
    def forward(self, x):
        data = my_custom_cuda_func(x)
        ....

my_module1 = MyModule()
my_module2 = MyModule()

params, buffers = stack_module_state([my_module1, my_module2])

base_model = copy.deepcopy(my_module1)
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

one_shot_model = vmap(fmodel)

test_data = .....  # a tensor of shape [2, B, C]
results = one_shot_model(params, buffers, test_data)
I’m not sure I’m properly understanding how vmap works, but I would expect forward to receive a tensor of shape [B, C] when it is called. However, when I run the code, I get the error:
RuntimeError: Batching rule not implemented for my_custom_cuda_func. We could not generate a fallback.
I also tried wrapping my_custom_cuda_func in vmap as follows:
data = vmap(my_custom_cuda_func)(x)
but I get the same error. Do you have any suggestions?
Thanks!
Hi @viciopoli,
torch.func.vmap does not automatically generate a batching rule for custom functions (as noted at the bottom of the torch.func.vmap docs: torch.func.vmap — PyTorch 2.3 documentation).
For a custom function, you need to specify the vmap rule manually, so torch.func.vmap can figure out how to vectorize over your function. This can be implemented via a torch.autograd.Function object on which you define the vmap static method. You can read more in the documentation: Automatic differentiation package - torch.autograd — PyTorch 2.3 documentation
In this case, you must also manually specify the backward formula for the custom torch.autograd.Function object in order to compute gradients.
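To make this concrete, here is a minimal sketch, with torch.sin standing in for my_custom_cuda_func and its analytic gradient standing in for your real backward. Note that torch.func transforms require the “new-style” autograd.Function (a forward without ctx plus a setup_context static method), and the vmap rule below just does the simplest correct thing and loops over the batch:

import torch
from torch.func import vmap

class MyCustomFunc(torch.autograd.Function):
    # New-style forward (no ctx) is required for torch.func support.
    @staticmethod
    def forward(x):
        return torch.sin(x)  # stand-in for my_custom_cuda_func(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * torch.cos(x)  # gradient of the stand-in op

    @staticmethod
    def vmap(info, in_dims, x):
        # in_dims has one entry per forward input: the dim of x that
        # vmap batches over (assumed not None in this sketch).
        x_bdim, = in_dims
        x = x.movedim(x_bdim, 0)  # move the batch dim to the front
        # Simplest possible rule: apply the op per sample and stack.
        out = torch.stack([MyCustomFunc.apply(xi) for xi in x])
        return out, 0  # the output is batched along dim 0

x = torch.randn(2, 5, 3)           # [num_models, B, C]
y = vmap(MyCustomFunc.apply)(x)    # forward now sees [B, C] slices

Calling your kernel through MyCustomFunc.apply inside the module’s forward should then let vmap dispatch to this rule instead of raising the “Batching rule not implemented” error.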
Hi @AlphaBetaGamma96,
Thanks for the answer. I tried to implement a batched version of my function by simply iterating over the non-batched one, following the writing_batching_rules guide:
std::vector<torch::Tensor> my_custom_func(
    torch::Tensor x) {
  ...
  return some_other_func(x);
}

std::tuple<std::vector<std::vector<torch::Tensor>>, std::optional<int64_t>> my_custom_func_batch_rule(
    const std::vector<torch::Tensor>& x,
    const std::optional<int64_t>& batch_dim) {
  std::vector<std::vector<torch::Tensor>> results;
  for (size_t i = 0; i < x.size(); ++i) {
    auto sample_result = sample_pts_on_rays(x[i]);
    results.push_back(sample_result);
  }
  return std::make_tuple(std::move(results), 0);
}
Is this the right way to proceed?
After doing this, I tried to register the batch rule using VMAP_SUPPORT:
TORCH_LIBRARY(utils, m) {
  m.def("my_custom_func", &my_custom_func);
  VMAP_SUPPORT(my_custom_func, my_custom_func_batch_rule);
  ....
}
However, it seems I’m missing something, since I get the following error when compiling:
error: ‘VMAP_SUPPORT’ was not declared in this scope
Do you know if there is a more in-depth guide for implementing these kinds of operations?
Thanks
UPDATE
Adding
#include <ATen/functorch/BatchRulesHelper.h>
solves the VMAP_SUPPORT error.
I don’t think simply looping over the data within torch.func.vmap will work, but I’m not 100% sure.
From my understanding of torch.func.vmap, it removes the batch dim from your input tensor, so ‘looping’ over the batch dim no longer makes sense (as it no longer exists at that point).
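A tiny illustration of what I mean (not your code, just a toy function):

import torch
from torch.func import vmap

def f(x):
    print(x.shape)  # prints torch.Size([3]): the batch dim is hidden inside vmap
    return x * 2

x = torch.randn(4, 3)
vmap(f)(x)  # f runs once on a BatchedTensor, not 4 times in a loop

So inside vmap your function only ever sees the per-sample view, which is why looping over a batch dim that “isn’t there” doesn’t fit.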
I think it’d be best to get a dev’s opinion on this! Apologies for the tag, @ptrblck
Hi @AlphaBetaGamma96,
Yes, I think that’s right: what I see when debugging the Python code is a BatchedTensor, which has the same shape as the tensor my function normally works with.
I also noticed that in the batch rule definitions in BatchRulesViews.cpp there is a moveBatchDimToFront call on the input tensor in almost every batch rule. That’s probably the way to get a regular tensor out of the BatchedTensor.
But yes, looking forward to hearing other suggestions!
Thanks.
Your batching rule should allow the usage of a batched input, but as @AlphaBetaGamma96 pointed out, using a for loop would still launch multiple kernels (one per sample, if I understand your code correctly) and might thus not yield a significant speedup (moving the loop from Python to C++ could still help). You could check whether the actual operator (sample_pts_on_rays) could be rewritten to accept batched inputs.
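As a sketch of that idea (assuming, hypothetically, that the kernel treats its leading dimension as independent samples/rays, which depends entirely on what your operator actually does), you could fold the extra batch dim into the leading dimension, launch the kernel once, and then restore the batch dim, instead of looping:

import torch

def batched_call(kernel, x):  # x: [num_models, B, C]
    n, b = x.shape[:2]
    out = kernel(x.reshape(n * b, *x.shape[2:]))  # one launch over all samples
    return out.reshape(n, b, *out.shape[1:])      # restore the batch dim

# Example with a stand-in kernel:
y = batched_call(torch.sin, torch.randn(2, 5, 3))  # y: [2, 5, 3]

Whether this is valid depends on the kernel’s semantics, but it avoids the per-sample kernel launches of the loop-based batch rule.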
Hi @ptrblck,
Thanks for your answer.
Sorry, my bad: sample_pts_on_rays should be my_custom_func there. I was trying to call the non-batched function several times. So, ideally, I should rewrite everything so that batched inputs are supported, but what exactly is a batched input?
When I debug the Python script that calls the function, I see a BatchedTensor, but how do you work with this tensor type?
Also, there is probably something wrong with the function declaration, since I get a my_custom_func_generated_plumbing error when calling VMAP_SUPPORT(my_custom_func, my_custom_func_batch_rule). Do you have any suggestions on why this could happen?
Thanks for your time!
Vincenzo
I don’t think you should directly interact with the BatchedTensor type, as torch.func handles that under the hood. I think the best approach would be to look at the writing_batching_rules doc you mentioned above: pytorch/functorch/writing_batching_rules.md at main · pytorch/pytorch · GitHub
I see. OK, I’ll try to figure this out and update this post!
Thanks