Hi, there,
I face a problem when I use vmap to calculate the batch of per-sample gradient of a RNN model. It always raise the dimension mismatch error and batching rule not implemented error. Is there any one can help me solve this?
Here is my codes:
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, grad
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
# batch_first=True means input/output is (batch, sequence, features)
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x: (sequence_length, input_size) for a single sequence
# x: (batch_size, sequence_length, input_size) for a batch
h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
out, _ = self.rnn(x, h0)
output = self.fc(out[:, -1, :])
return output
def compute_loss_for_single_sequence(model, single_input_sequence):
input_with_batch = single_input_sequence.unsqueeze(0)
loss = functional_call(model, dict(model.named_parameters()), input_with_batch)
return loss
if __name__ == '__main__':
# 1. Prepare batched data and model
input_size = 10
hidden_size = 20
output_size = 1
batch_size = 1000
sequence_length = 5
feature_dim = input_size
model = SimpleRNN(input_size, hidden_size, output_size)
params = {name: p for name, p in model.named_parameters()}
# Input data: (batch_size, sequence_length, feature_dim)
batch_input = torch.randn(batch_size, sequence_length, feature_dim, requires_grad=True)
# 2. test batch evaluate the forward
batch_output = vmap(compute_loss_for_single_sequence, in_dims=(None, 0))(model, batch_input)
# 3. test batch per-sample gradients
per_example_grad_fn = vmap(grad(compute_loss_for_single_sequence), in_dims=(None, 0))(model, batch_input)
- when test batch evaluate the forward, it raises the error:
RuntimeError: Batching rule not implemented for aten::rnn_tanh.input. We could not generate a fallback. - when test batch per-sample gradients, it raises the error:
I cannot correctly use vmap to evaluate function “compute_loss_for_single_sequence”, it obviously raises the error when test batch per-sample gradients.
ValueError: Thing passed to transform API must be Tensor, got <class ‘main.SimpleRNN’>