Hello,
After updating PyTorch from 1.0.1 to 1.4.0 I encountered some unwanted behavior. When creating a ScriptModule, the order of its Parameters changes. I’ll attach an example based on the https://github.com/pytorch/benchmark/blob/master/rnns/fastrnns/custom_lstms.py file.
Two classes are defined. The only difference between the two is that one is a ScriptModule and the other an nn.Module. By printing their parameter shapes we see that the ScriptModule’s parameters are now stored in a different order than declared in the class (they also stop matching the native cell version). This might cause backward compatibility problems (e.g. in the case of ‘custom_lstms.py’ the ‘test_script…’ functions stopped working).
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
from typing import List, Tuple
from torch import Tensor
class LSTMCell(jit.ScriptModule):
    """TorchScript stub of an LSTM cell.

    Registers the four standard LSTM parameter tensors in the declared
    order (weight_ih, weight_hh, bias_ih, bias_hh); forward() simply
    echoes its input. Used to inspect parameter registration order
    under torch.jit.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # All four gates are stacked along dim 0, hence 4 * hidden_size rows.
        gate_size = 4 * hidden_size
        self.weight_ih = Parameter(torch.randn(gate_size, input_size))
        self.weight_hh = Parameter(torch.randn(gate_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(gate_size))
        self.bias_hh = Parameter(torch.randn(gate_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
        return input
class LSTMCellNoJit(nn.Module):
    """Plain nn.Module twin of LSTMCell for comparison.

    Identical parameter declarations (weight_ih, weight_hh, bias_ih,
    bias_hh, in that order); forward() echoes its input. Serves as the
    non-scripted baseline for parameter-order comparison.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Four stacked gates per weight/bias, hence 4 * hidden_size rows.
        gate_size = 4 * hidden_size
        self.weight_ih = Parameter(torch.randn(gate_size, input_size))
        self.weight_hh = Parameter(torch.randn(gate_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(gate_size))
        self.bias_hh = Parameter(torch.randn(gate_size))

    def forward(self, input, state):
        return input
# Instantiate the scripted cell and print its parameter shapes;
# the iteration order of parameters() is the subject of this report.
net = LSTMCell(2, 10)
pars = list(net.parameters())
for p in pars:
    print(p.shape)
Output:
torch.Size([40, 10])
torch.Size([40])
torch.Size([40])
torch.Size([40, 2])
# Same shape dump for the plain nn.Module version, for comparison.
net = LSTMCellNoJit(2, 10)
pars = list(net.parameters())
for p in pars:
    print(p.shape)
Output:
torch.Size([40, 2])
torch.Size([40, 10])
torch.Size([40])
torch.Size([40])