Parameter order changes when using TorchScript

Hello,

After updating PyTorch from 1.0.1 to 1.4.0 I encountered some unwanted behavior: when creating a ScriptModule, the order of its parameters changes. I’ll attach an example based on the https://github.com/pytorch/benchmark/blob/master/rnns/fastrnns/custom_lstms.py file.

Two classes are defined below. The only difference between them is that one is a ScriptModule and the other an nn.Module. Printing their parameter shapes shows that the ScriptModule’s parameters are stored in a different order than they are declared in the class (they also stop matching the native cell version). This can cause backward-compatibility problems: e.g. in the case of ‘custom_lstms.py’, the ‘test_script…’ functions stopped working.

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
from typing import Tuple
from torch import Tensor

class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor
        # Stub forward: only the parameter registration matters for this repro.
        return input
    

class LSTMCellNoJit(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCellNoJit, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    def forward(self, input, state):
        return input

net = LSTMCell(2, 10)

for p in net.parameters():
    print(p.shape)

Output:
torch.Size([40, 10])
torch.Size([40])
torch.Size([40])
torch.Size([40, 2])

net = LSTMCellNoJit(2, 10)

for p in net.parameters():
    print(p.shape)

Output:
torch.Size([40, 2])
torch.Size([40, 10])
torch.Size([40])
torch.Size([40])
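Until this is fixed, one way to keep comparison code working is to match parameters by *name* (via `named_parameters()`) rather than by iteration order. A minimal sketch of the idea, using plain dicts of shapes as a stand-in for `{name: p.shape for name, p in net.named_parameters()}` (the shapes below are taken from the outputs above; the dict contents are an assumption about the name-to-shape mapping):

```python
# Hypothetical name -> shape mapping as a ScriptModule might yield it
# (parameters appear in a different iteration order).
scripted = {
    "weight_hh": (40, 10),
    "bias_ih": (40,),
    "bias_hh": (40,),
    "weight_ih": (40, 2),
}

# The nn.Module version, in declaration order.
plain = {
    "weight_ih": (40, 2),
    "weight_hh": (40, 10),
    "bias_ih": (40,),
    "bias_hh": (40,),
}

# Dict equality ignores insertion order, so comparing by name
# is robust to the reordering shown above.
assert scripted == plain
print("parameters match by name")
```

The same trick applies to checking gradients or copying weights between the scripted and non-scripted versions: look parameters up by name instead of zipping two `parameters()` iterators.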

Hey, thanks for the report and repro! This is an issue that’s been open for a while; we’re planning to fix it soon.