Batching a custom CUDA function with vmap

Hi everyone,

I’m trying to use torch.vmap to speed up my project.
I have a custom CUDA function, and I’m trying to wrap it with vmap so that it works with batches.

Here is a toy example to show what’s going on:

class MyModule(nn.Module):
    ....
    def forward(self, x):
        data = my_custom_cuda_func(x)
        ....

my_module1 = MyModule()
my_module2 = MyModule()

params, buffers = stack_module_state([my_module1, my_module2])

base_model = copy.deepcopy(my_module1)
base_model = base_model.to('meta')

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

one_shot_model = vmap(fmodel)

test_data = ..... # a tensor of the shape [2, B, C]

results = one_shot_model(params, buffers, test_data)

I’m not sure I’m properly understanding how vmap works, but I would expect that when forward is called it receives a tensor of shape [B, C]. However, when I run the code, I get this error:

RuntimeError: Batching rule not implemented for my_custom_cuda_func. We could not generate a fallback.

I also tried to wrap my_custom_cuda_func in vmap as follows:

data = vmap(my_custom_cuda_func)(x)

but I get the same error. Do you have any suggestions?
Thanks!

Hi @viciopoli,

torch.func.vmap does not automatically generate a batching rule for custom functions (as noted at the bottom of the torch.func.vmap docs: torch.func.vmap — PyTorch 2.3 documentation).

For a custom function, you need to specify the vmap rule manually (so torch.func.vmap can figure out how to vectorize over your function). This can be done via a torch.autograd.Function subclass on which you define a vmap staticmethod. You can read more in the documentation: Automatic differentiation package - torch.autograd — PyTorch 2.3 documentation

In this case, you must also manually specify the backward formula for the custom torch.autograd.Function object in order to compute gradients.
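
For example, something along these lines. This is just a minimal sketch, assuming my_custom_cuda_func takes one tensor and returns one tensor; grad_of_my_custom_cuda_func is a hypothetical placeholder for whatever computes its gradient:

import torch

class MyCustomFunc(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # call into the custom CUDA kernel
        return my_custom_cuda_func(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        # manual backward formula for the custom kernel
        return grad_out * grad_of_my_custom_cuda_func(x)

    @staticmethod
    def vmap(info, in_dims, x):
        # info.batch_size is the size of the vmapped dim,
        # in_dims holds the batch dim of each input (here just x)
        x_bdim, = in_dims
        x = x.movedim(x_bdim, 0)
        # simplest possible rule: loop over the batch dim and re-stack;
        # this launches one kernel per sample, so it is correct but not fast
        out = torch.stack([my_custom_cuda_func(xi) for xi in x])
        # return the output together with the dim that holds the batch
        return out, 0

Inside your module’s forward you would then call MyCustomFunc.apply(x) instead of my_custom_cuda_func(x).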

Hi @AlphaBetaGamma96 ,

Thanks for the answer. I tried to implement the batched version of my function by simply iterating over the non-batched function, following the writing_batching_rules guide:

std::vector<torch::Tensor> my_custom_func(
        torch::Tensor x) {
  ...
  return some_other_func(x);
}

std::pair<std::vector<std::vector<torch::Tensor>>, std::optional<int64_t>> my_custom_func_batch_rule(
        const std::vector<torch::Tensor>& x,
        const std::optional<int64_t>& batch_dim) {
    std::vector<std::vector<torch::Tensor>> results;
    for (size_t i = 0; i < x.size(); ++i) {
        auto sample_result = sample_pts_on_rays(x[i]);
        results.push_back(sample_result);
    }
    return std::make_pair(std::move(results), 0);
}

Is this the right way to proceed?
After doing this I tried to register the batched function using VMAP_SUPPORT:

TORCH_LIBRARY(utils, m) {
  m.def("my_custom_func", &my_custom_func);
  VMAP_SUPPORT(my_custom_func, my_custom_func_batch_rule);
  ....
}

However, it seems I’m missing something since when compiling I get the error:

error: ‘VMAP_SUPPORT’ was not declared in this scope

Do you know if there is a more in-depth guide for implementing these kinds of operations?

Thanks

UPDATE

#include <ATen/functorch/BatchRulesHelper.h>

solves the VMAP_SUPPORT error.

I don’t think simply looping over the data within torch.func.vmap will work, but I’m not 100% sure.

From my understanding of torch.func.vmap, it removes the batch dim of your input Tensor so ‘looping’ over the batch dim no longer makes sense (as it no longer exists).
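
For example, something like this small check shows what the function actually sees under vmap:

import torch
from torch.func import vmap

def f(x):
    # inside vmap, x is a BatchedTensor whose visible shape has the mapped dim removed
    print(x.shape)
    return x * 2

x = torch.randn(4, 3)  # dim 0 is the batch dim
out = vmap(f)(x)       # prints torch.Size([3]), not torch.Size([4, 3])

so a loop over the batch dim inside f has nothing to iterate over.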

I think it’d be best to get a dev’s opinion on this! Apologies for the tag, @ptrblck

hi @AlphaBetaGamma96,

Yes, I think that’s right. What I see when debugging the Python code is a BatchedTensor that has the same shape as the tensor my function normally works with.

I also noticed that in the batch_rule declarations in BatchRulesViews.cpp there is a moveBatchDimToFront call on the input tensor in almost every batch_rule. Probably that’s the way to get a regular tensor out of the BatchedTensor.

But yes, looking forward to hearing other suggestions!

Thanks.

Your batching rule should allow the usage of a batched input, but as @AlphaBetaGamma96 pointed out, using a for loop would still launch multiple kernels (one for each sample if I understand your code correctly) and might thus not yield a significant speedup (moving the loop from Python to C++ could still help). You could check if the actual operator (sample_pts_on_rays) could be rewritten in a way to accept batched inputs.
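
For instance, if my_custom_cuda_func already accepts a flat leading dimension, the vmap staticmethod from the autograd.Function sketch above could fold the extra batch dim into it instead of looping. This is a hypothetical sketch, assuming the kernel treats the rows of its [N, C] input independently and returns one tensor with the same leading dim:

import torch

class MyCustomFunc(torch.autograd.Function):
    # forward / setup_context / backward as in the earlier sketch

    @staticmethod
    def vmap(info, in_dims, x):
        x_bdim, = in_dims
        x = x.movedim(x_bdim, 0)                         # [B, N, C]
        B, N, C = x.shape
        out = my_custom_cuda_func(x.reshape(B * N, C))   # a single kernel launch for all samples
        return out.reshape(B, N, -1), 0                  # batch dim is dim 0 of the output

The analogous trick applies on the C++ side: instead of looping in the batch rule, flatten the batch dim into the dimension the kernel already handles, call it once, and reshape the result back.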


Hi @ptrblck,

Thanks for your answer.
Sorry, my bad: sample_pts_on_rays should be my_custom_func there. I was trying to call the function that does not work with batches several times. So, ideally, I should rewrite everything so that batched inputs are supported, but what exactly is a batched input?
When I debug the code in the Python script that calls the function, I see a BatchedTensor, but how do you work with this Tensor type?

Also, there is probably something wrong with the function declaration, since I get a my_custom_func_generated_plumbing error when trying VMAP_SUPPORT(my_custom_func, my_custom_func_batch_rule). Do you have any suggestions on why this could happen?

Thanks for your time!
Vincenzo

I don’t think you should directly interact with the BatchedTensor type, as I think torch.func handles that under the hood. I think the best approach would be to look at the writing_batching_rules doc you mentioned above: pytorch/functorch/writing_batching_rules.md at main · pytorch/pytorch · GitHub

I see, ok I’ll try to figure this out and I’ll update this post!

Thanks