Hi!
I am currently trying to trace a custom RNN model so that it executes faster (I think tracing should help given that the forward pass uses a for loop, but feel free to correct me if I'm wrong). A minimal example is as follows:
import torch
from torch.nn import Module, Parameter
from torch.autograd import Variable
# CUDA float tensor type, used below to pre-allocate the hidden-state buffer
# (the model and the inputs both live on the GPU).
from torch.cuda import FloatTensor


class minimal_ex(Module):
    def __init__(self, input_size=1, recurrent_size=1):
        super(minimal_ex, self).__init__()
        n_h, n_i = recurrent_size, input_size
        self.input_size = input_size
        self.recurrent_size = recurrent_size
        self.activation = lambda x: x
        # Fixed (non-trainable) masks and signs applied to the weights.
        self.mask_rec = Parameter(torch.ones(n_h, n_h), requires_grad=False)
        self.mask_in = Parameter(torch.ones(n_h, n_i), requires_grad=False)
        self.neuron_signs = Parameter(torch.ones(n_h).float(), requires_grad=False)
        # Trainable input weights, recurrent weights and bias.
        self.w_i = Parameter(torch.ones(recurrent_size, input_size))
        self.w_h = Parameter(torch.zeros(recurrent_size, recurrent_size))
        self.bias = Parameter(torch.zeros(recurrent_size))
        self.cuda()

    def forward(self, inp):
        h_0 = Variable(torch.zeros(1, self.recurrent_size)).cuda()
        batch_size, seq_len, dim = inp.shape
        # Masked weights, transposed so each time step is a plain mm.
        w_i = (self.w_i * self.mask_in).transpose(0, 1)
        w_h = (self.w_h * self.mask_rec).transpose(0, 1)
        w_h_with_signs = torch.mm(torch.diag(self.neuron_signs), w_h)
        # Pre-allocate the hidden-state buffer; h_0 goes in the last slot,
        # which is read back at t = 0 through the t-1 index.
        h = FloatTensor(batch_size, seq_len, self.recurrent_size)
        h[:, -1, :] = h_0
        for t in torch.arange(seq_len):
            h[:, t, :] = self.activation(torch.mm(inp[:, t, :].clone(), w_i) +
                                         torch.mm(h[:, t - 1, :].clone(), w_h_with_signs) -
                                         self.bias)
        return h
bs = 64
seq_len = 500
in_size = 2
inp = torch.FloatTensor(bs, seq_len, in_size).uniform_().cuda()
model = minimal_ex(input_size=in_size, recurrent_size=128)
traced_forward = torch.jit.trace(model, inp)
When I do this, the output contains several warnings:
- TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  h = FloatTensor(batch_size, seq_len, self.recurrent_size)
- RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
- TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  torch.mm(h[:, t-1, :].clone(), w_h_with_signs) - self.bias)
- TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the repeated trace. Detailed error:
  Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 26, 0] (0.48441100120544434 vs. 3.6893488147419103e+19) and 4095999 other locations (100.00%)
  _check_trace([example_inputs], func, executor_options, module, check_tolerance, _force_outplace)
I was expecting some of these warnings (in particular, I would not expect the trace to generalize to a different seq_len, which is fine for my use case), but I don't understand why the output of the traced model blows up while the output of the Python model does not.
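In case it helps, this is the kind of check I would run to compare the two outputs directly (just a sketch, assuming the model, inp and traced_forward defined above, and the same tolerances the trace checker reports):
# Compare the eager forward pass against the traced one
# (sketch; assumes model, inp and traced_forward from the snippet above).
with torch.no_grad():
    out_eager = model(inp)
    out_traced = traced_forward(inp)
print(torch.allclose(out_eager, out_traced, rtol=1e-5, atol=1e-5))
print(out_eager.abs().max().item(), out_traced.abs().max().item())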
Is there something wrong with the way I use trace?
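For what it's worth, I see that torch.jit.trace also takes a check_trace argument; below is a sketch of the call with the consistency check disabled, though I assume this would only silence the mismatch rather than explain it:
# Sketch: trace without the automatic re-trace consistency check.
# I assume this only suppresses the comparison, not the underlying divergence.
traced_forward_nocheck = torch.jit.trace(model, inp, check_trace=False)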
Thank you in advance for your help!