I want to define an nn.ModuleDict() and iterate through it while preserving the order in which the keys were defined. That is, I would like to define the dict as follows:
D = nn.ModuleDict(
    {
        'b': nn.Linear(2, 4),
        'a': nn.Linear(4, 8)
    }
)
so that iterating over it
for k, v in D.items():
    print(k, v)
prints the entries in the order they were defined:
b Linear(in_features=2, out_features=4, bias=True)
a Linear(in_features=4, out_features=8, bias=True)
However, by default, PyTorch iterates through the dict with its keys sorted alphanumerically:
import torch
import torch.nn as nn
D = nn.ModuleDict(
    {
        'b': nn.Linear(2, 4),
        'a': nn.Linear(4, 8)
    }
)
D
ModuleDict(
  (a): Linear(in_features=4, out_features=8, bias=True)
  (b): Linear(in_features=2, out_features=4, bias=True)
)
for k, v in D.items():
    print(k, v)
a Linear(in_features=4, out_features=8, bias=True)
b Linear(in_features=2, out_features=4, bias=True)
In contrast, if we add the modules one at a time with update(), it works as intended. That is:
D = nn.ModuleDict()
D.update({'b': nn.Linear(2, 4)})
D.update({'a': nn.Linear(4, 8)})
D
ModuleDict(
  (b): Linear(in_features=2, out_features=4, bias=True)
  (a): Linear(in_features=4, out_features=8, bias=True)
)
for k, v in D.items():
    print(k, v)
b Linear(in_features=2, out_features=4, bias=True)
a Linear(in_features=4, out_features=8, bias=True)
Is there any way of iterating through the dict in the order the keys were given when the dict was defined? One solution would be to use keys that already sort in the desired order, but I would like to know whether this is possible in the general case.
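
For reference, here is a minimal sketch of one thing that might be worth trying. It assumes that the sorting seen above only happens when a plain dict is passed in, and that a collections.OrderedDict is registered in its own order. I have not checked this against every PyTorch version, so treat it as an assumption rather than documented behaviour. The variable name keys is only there for illustration.

```python
from collections import OrderedDict

import torch.nn as nn

# An OrderedDict carries an explicit ordering, so (under the assumption
# stated above) ModuleDict registers the entries in the order written here.
D = nn.ModuleDict(
    OrderedDict([
        ('b', nn.Linear(2, 4)),
        ('a', nn.Linear(4, 8)),
    ])
)

# Collect the keys in iteration order so the ordering can be inspected.
keys = list(D.keys())

for k, v in D.items():
    print(k, v)
```

If the assumption holds, keys should come out as ['b', 'a'], matching the result of the two separate update() calls above.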