How to iterate through nn.ModuleDict without sorting the keys

I want to define an nn.ModuleDict() and iterate through it in the order in which the keys were defined. That is, I would like to define the dict as follows:

D = nn.ModuleDict(
    {
        'b': nn.Linear(2, 4),
        'a': nn.Linear(4, 8)
    }
)

so that

for k, v in D.items():
    print(k, v)

prints

b Linear(in_features=2, out_features=4, bias=True)
a Linear(in_features=4, out_features=8, bias=True)

However, when a plain dict is passed in, PyTorch sorts its keys alphanumerically, so the ModuleDict is iterated in sorted key order:

import torch
import torch.nn as nn

D = nn.ModuleDict(
    {
        'b': nn.Linear(2, 4),
        'a': nn.Linear(4, 8)
    }
)

D
ModuleDict(
  (a): Linear(in_features=4, out_features=8, bias=True)
  (b): Linear(in_features=2, out_features=4, bias=True)
)

for k, v in D.items():
    print(k, v)

a Linear(in_features=4, out_features=8, bias=True)
b Linear(in_features=2, out_features=4, bias=True)

In contrast, if I add the modules one at a time with update(), the order is preserved:

D = nn.ModuleDict()
D.update({'b': nn.Linear(2, 4)})
D.update({'a': nn.Linear(4, 8)})

D
ModuleDict(
  (b): Linear(in_features=2, out_features=4, bias=True)
  (a): Linear(in_features=4, out_features=8, bias=True)
)

for k, v in D.items():
    print(k, v)

b Linear(in_features=2, out_features=4, bias=True)
a Linear(in_features=4, out_features=8, bias=True)
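The single-key update() calls keep the definition order because each call registers only one module, so there is nothing to sort. A minimal sketch of the same idea using item assignment, which nn.ModuleDict also supports (D[key] = module), with the same two layers as above:

```python
import torch.nn as nn

# Register the submodules one at a time. Each assignment adds a single key,
# so no sorting can happen and the insertion order is kept.
D = nn.ModuleDict()
D['b'] = nn.Linear(2, 4)
D['a'] = nn.Linear(4, 8)

for k, v in D.items():
    print(k, v)
```

This avoids building an intermediate dict at all, at the cost of one line per module.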

Is there any way to iterate through the ModuleDict in the order in which the keys were given when it was defined? One workaround would be to use keys that already sort in the desired order, but I would like to know whether this is possible in the general case.

If you pass in an OrderedDict, the ordering will be preserved:

from collections import OrderedDict

nn.ModuleDict(OrderedDict([
    ('b', nn.Linear(2, 4)),
    ('a', nn.Linear(4, 8)),
]))

The crux is that Python 2, and Python 3 before 3.6, do not preserve insertion order in a dict (in CPython 3.6 it is only an implementation detail; it is guaranteed from Python 3.7 on), so to get a deterministic dict -> OrderedDict conversion, the keys are sorted.
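To make the rule concrete, here is a simplified, hypothetical model of it in plain Python. The helper insertion_order is not part of PyTorch; it only assumes the behaviour described above: plain dicts are sorted, while OrderedDicts and iterables of (key, module) pairs keep their given order. Strings stand in for modules.

```python
from collections import OrderedDict


def insertion_order(modules):
    """Hypothetical helper (not a PyTorch API): return the key order a
    ModuleDict would end up with under the sorting rule described above."""
    # Check OrderedDict first, since it is also a subclass of dict.
    if isinstance(modules, OrderedDict):
        return list(modules.keys())
    if isinstance(modules, dict):
        # Plain dict: its order was not guaranteed on older Pythons,
        # so the keys are sorted to make the result deterministic.
        return sorted(modules)
    # Anything else is treated as an iterable of (key, module) pairs
    # and keeps the order in which the pairs are produced.
    return [key for key, _ in modules]


layers = {'b': 'Linear(2, 4)', 'a': 'Linear(4, 8)'}
print(insertion_order(layers))
print(insertion_order(OrderedDict([('b', 'Linear(2, 4)'), ('a', 'Linear(4, 8)')])))
print(insertion_order([('b', 'Linear(2, 4)'), ('a', 'Linear(4, 8)')]))
```

Under this model the plain dict comes out as ['a', 'b'], while the OrderedDict and the list of pairs both keep ['b', 'a'].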

Alternatively, you can pass in {...}.items(), i.e. an iterable of (key, module) pairs instead of a dict. It looks funny and probably deserves an explanatory comment in the code, but it works.
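A short sketch of the .items() variant next to an explicit list of (key, module) pairs, which nn.ModuleDict also accepts. It assumes Python 3.7+, so that the dict literal itself is in definition order before .items() is taken:

```python
import torch.nn as nn

# Passing dict.items() hands ModuleDict an iterable of (key, module) pairs
# rather than a dict, so the pairs are registered in the order they arrive.
# A comment like this one is worth keeping in real code, since it looks odd.
from_items = nn.ModuleDict({
    'b': nn.Linear(2, 4),
    'a': nn.Linear(4, 8),
}.items())

# An explicit list of pairs does the same without relying on dict ordering.
from_pairs = nn.ModuleDict([
    ('b', nn.Linear(2, 4)),
    ('a', nn.Linear(4, 8)),
])

print(list(from_items.keys()))
print(list(from_pairs.keys()))
```

The list-of-pairs form is the more explicit choice when the code may run on older Python versions.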

Best regards

Thomas

Hi @tom, many thanks for your quick response!

I see. That’s not a bad solution, of course, but especially in the case of a “Module Dictionary” I cannot see why this isn’t the default behavior. When someone defines a module dictionary like I did above, they would also expect to iterate over it in the order in which it was defined.

Many thanks again!

Also, if you use Python 3.7, a regular dict should preserve the order of its keys.

Hi @dhpollack, I’m using Python 3.7.3, but ModuleDict still sorts the keys.
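Since the dict literal is already ordered on Python 3.7, the reordering happens inside ModuleDict, not in the dict. Whether it happens therefore depends on the installed PyTorch version, and later releases changed this to keep insertion order. A small check as a sketch; the helper name moduledict_sorts_plain_dicts is hypothetical:

```python
import torch
import torch.nn as nn


def moduledict_sorts_plain_dicts():
    """Hypothetical helper: return True if this PyTorch build reorders the
    keys of a plain dict passed to nn.ModuleDict, False if it keeps the
    dict's insertion order (assumes Python 3.7+ dict ordering)."""
    probe = nn.ModuleDict({'b': nn.Linear(1, 1), 'a': nn.Linear(1, 1)})
    return list(probe.keys()) != ['b', 'a']


print(torch.__version__, moduledict_sorts_plain_dicts())
```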