Hello!
I need to apply torch.vmap to a complex function that at some point calls a standard torch.nn.GRU.
However, when the vmapped function is called, a “RuntimeError: Batching rule not implemented” error is raised from the internal _VF.gru call, as shown in the snippet below. The error looks very similar to the one reported and resolved in #1089, but for some reason the fix does not seem to work for me.
Can you help me understand whether I’m doing something wrong or whether there is indeed still something that needs to be fixed?
This is critical for a project under development, so I would really appreciate your help.
Thank you in advance!
VERSIONS:
python --version => Python 3.10.12
torch.__version__ => 2.4.0+cu121
CODE TO REPRODUCE
import torch
# Set dimensions
input_size = 10
hidden_size = 2
num_layers = 1
sequence_length = 5
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set GRU
rnn = torch.nn.GRU(input_size, hidden_size, num_layers)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set unbatched input
input = torch.randn(sequence_length, input_size)
h0 = torch.zeros(num_layers, hidden_size)
# Call GRU
output, _ = rnn(input, h0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set batched input for VMAP (along first dimension)
batched_input = input.unsqueeze(0).repeat(3, 1, 1)
# Set GRU function
def my_function(input):
    h0 = torch.zeros(num_layers, hidden_size)
    output, _ = rnn(input, h0)
    return output
# Set VMAP GRU
vmap_gru = torch.vmap(my_function)
# Call VMAP GRU
vmap_output = vmap_gru(batched_input)
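For reference, here is a quick sanity-check sketch (assuming the snippet above has already been run; the names expected, h0_batched and direct_output are just illustrative): both a plain per-sample loop and the native batched GRU call work, only the vmap path fails.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sanity checks (sketch; assumes the reproduction snippet above was executed)
# Reference result: per-sample loop over the leading batch dimension
expected = torch.stack([my_function(x) for x in batched_input])
# Native batched call: nn.GRU accepts batched input directly when the batch
# dimension is second (batch_first=False is the default)
h0_batched = torch.zeros(num_layers, batched_input.shape[0], hidden_size)
direct_output, _ = rnn(batched_input.permute(1, 0, 2), h0_batched)
print(expected.shape, direct_output.shape)  # torch.Size([3, 5, 2]) torch.Size([5, 3, 2])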
ERROR
RuntimeError Traceback (most recent call last)
<ipython-input-9-dfebfb7b7a06> in <cell line: 26>()
24 vmap_gru = torch.vmap(my_function)
25 # Call VMAP GRU
---> 26 vmap_output = vmap_gru(batched_input)
/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
199
200 def wrapped(*args, **kwargs):
--> 201 return vmap_impl(
202 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
203 )
/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
329
330 # If chunk_size is not specified.
--> 331 return _flat_vmap(
332 func,
333 batch_size,
/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
46 def fn(*args, **kwargs):
47 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 48 return f(*args, **kwargs)
49
50 return fn
/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
478 flat_in_dims, flat_args, vmap_level, args_spec
479 )
--> 480 batched_outputs = func(*batched_inputs, **kwargs)
481 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
482
<ipython-input-9-dfebfb7b7a06> in my_function(input)
19 def my_function(input):
20 h0 = torch.zeros(num_layers, hidden_size)
---> 21 output, _ = rnn(input, h0)
22 return output
23 # Set VMAP GRU
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
1554
1555 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1563
1564 try:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
1137 self.check_forward_args(input, hx, batch_sizes)
1138 if batch_sizes is None:
-> 1139 result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
1140 self.dropout, self.training, self.bidirectional, self.batch_first)
1141 else:
RuntimeError: Batching rule not implemented for aten::gru.input. We could not generate a fallback.