VMAP over GRU: Batching rule not implemented for aten::gru.input

Hello!

I need to apply torch.vmap to a complex function that at some point calls a standard torch.nn.GRU.

However, when the vmapped function is called, a “RuntimeError: Batching rule not implemented” error is raised from the internal _VF.gru call, as shown in the snippet below. The error seems very similar to the one reported and fixed in #1089, but for some reason that fix does not seem to work in my case.

Can you help me understand if I’m doing something wrong or if there is indeed still something that needs to be fixed?

This is critical for a project under development, so I would really appreciate your help.

Thank you in advance!


VERSIONS:
python --version => Python 3.10.12
torch.__version__ => 2.4.0+cu121

CODE TO REPRODUCE

import torch

# Set dimensions
input_size = 10
hidden_size = 2
num_layers = 1
sequence_length = 5
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set GRU
rnn = torch.nn.GRU(input_size, hidden_size, num_layers)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set unbatched input
input = torch.randn(sequence_length, input_size)
h0 = torch.zeros(num_layers, hidden_size)
# Call GRU
output, _ = rnn(input, h0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set batched input for VMAP (along first dimension)
batched_input = input.unsqueeze(0).repeat(3, 1, 1)
# Set GRU function
def my_function(input):
  h0 = torch.zeros(num_layers, hidden_size)
  output, _ = rnn(input, h0)
  return output
# Set VMAP GRU
vmap_gru = torch.vmap(my_function)
# Call VMAP GRU
vmap_output = vmap_gru(batched_input)
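
For reference, the result I expect from the vmap call can also be computed through the GRU's native batch support (with batch_first=False, nn.GRU takes input of shape (seq, batch, input) and h0 of shape (num_layers, batch, hidden)). The snippet below is only a sanity check I use to define the expected output, reusing rnn and batched_input from above:

# Expected result via the GRU's native batching (sanity check only)
h0_batched = torch.zeros(num_layers, batched_input.shape[0], hidden_size)
native_output, _ = rnn(batched_input.transpose(0, 1), h0_batched)
expected_output = native_output.transpose(0, 1)  # shape: (3, 5, 2) == (batch, seq, hidden)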

ERROR

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-dfebfb7b7a06> in <cell line: 26>()
     24 vmap_gru = torch.vmap(my_function)
     25 # Call VMAP GRU
---> 26 vmap_output = vmap_gru(batched_input)

7 frames
/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
    199 
    200     def wrapped(*args, **kwargs):
--> 201         return vmap_impl(
    202             func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    203         )

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    329 
    330     # If chunk_size is not specified.
--> 331     return _flat_vmap(
    332         func,
    333         batch_size,

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
     46     def fn(*args, **kwargs):
     47         with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 48             return f(*args, **kwargs)
     49 
     50     return fn

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    478             flat_in_dims, flat_args, vmap_level, args_spec
    479         )
--> 480         batched_outputs = func(*batched_inputs, **kwargs)
    481         return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    482 

<ipython-input-9-dfebfb7b7a06> in my_function(input)
     19 def my_function(input):
     20   h0 = torch.zeros(num_layers, hidden_size)
---> 21   output, _ = rnn(input, h0)
     22   return output
     23 # Set VMAP GRU

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1551             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552         else:
-> 1553             return self._call_impl(*args, **kwargs)
   1554 
   1555     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1560                 or _global_backward_pre_hooks or _global_backward_hooks
   1561                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562             return forward_call(*args, **kwargs)
   1563 
   1564         try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
   1137         self.check_forward_args(input, hx, batch_sizes)
   1138         if batch_sizes is None:
-> 1139             result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
   1140                              self.dropout, self.training, self.bidirectional, self.batch_first)
   1141         else:

RuntimeError: Batching rule not implemented for aten::gru.input. We could not generate a fallback.
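
In case it is useful, the manual loop below produces the result I currently expect from vmap. It is just a temporary workaround sketch while the batching rule for aten::gru is missing, not a real fix:

# Temporary workaround: apply my_function to each example and stack the outputs,
# which is what vmap is semantically expected to return.
manual_output = torch.stack([my_function(x) for x in batched_input])
print(manual_output.shape)  # torch.Size([3, 5, 2])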