EDIT: already solved it, see the bottom.
I have an nn.Module that has multiple other nn.Modules as attributes.
I want to access the weights of the underlying modules (they contain e.g. a Conv2d layer).
Dummy model here:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.module_list = self.create_modules()

    @staticmethod
    def create_modules():
        module_list = nn.ModuleList()
        for i in range(3):
            modules = nn.Sequential()
            if i == 0:
                modules.add_module('Conv2d', nn.Conv2d(in_channels=1,
                                                       out_channels=1,
                                                       kernel_size=3))
                modules.add_module('BatchNorm2d', nn.BatchNorm2d(1, momentum=0.03, eps=1E-4))
                modules.add_module('activation', nn.LeakyReLU(0.1, inplace=True))
            if i == 1:
                modules = nn.MaxPool2d(kernel_size=(2, 2))
            if i == 2:
                modules.add_module('Conv2d', nn.Conv2d(in_channels=1,
                                                       out_channels=1,
                                                       kernel_size=3))
            module_list.append(modules)
        return module_list

model = Net()
for params in model.parameters():
    print(params)
This prints
Parameter containing:
tensor([[[[ 0.0781, -0.2562, -0.2447],
[-0.2310, -0.2009, -0.1493],
[ 0.1241, 0.0419, -0.2956]]]], requires_grad=True)
Parameter containing:
tensor([-0.1773], requires_grad=True)
Parameter containing:
tensor([1.], requires_grad=True)
Parameter containing:
tensor([0.], requires_grad=True)
Parameter containing:
tensor([[[[ 0.0511, -0.0956, -0.0024],
[-0.0377, 0.1948, -0.1699],
[-0.3166, -0.0658, -0.0656]]]], requires_grad=True)
Parameter containing:
tensor([0.2309], requires_grad=True)
As you can see, the Conv2d layers are randomly initialized.
Now I want to set the weights of these two layers to torch.ones().
But this has to happen after the model is created.
So in my dummy code, after
model = Net()
For this I need to overwrite the parameters of my model with torch.ones(),
but I am unsure how to do that.
After the overwrite,
for param in model.parameters():
should show
[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]
for my conv layers.
So how do I access the weights of a layer that is inside a model?
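For targeted access (rather than touching every parameter), submodules registered via add_module are reachable as attributes, and named_parameters() yields each parameter together with its dotted path. Here is a minimal sketch assuming the structure built by create_modules() above (a ModuleList whose first entry is a Sequential holding a submodule named 'Conv2d'):

```python
import torch
import torch.nn as nn

# Stand-in mirroring the relevant part of the model in the question.
module_list = nn.ModuleList([nn.Sequential()])
module_list[0].add_module('Conv2d', nn.Conv2d(1, 1, kernel_size=3))

# Direct access: index into the ModuleList, then use the submodule name.
conv = module_list[0].Conv2d
print(conv.weight.shape)  # torch.Size([1, 1, 3, 3])

# Name-based access: filter named_parameters() for the conv weights only.
for name, param in module_list.named_parameters():
    if name.endswith('Conv2d.weight'):
        with torch.no_grad():
            param.fill_(1.0)

print(bool(torch.all(module_list[0].Conv2d.weight == 1.0)))  # True
```

The name-based variant is useful when you only want to overwrite the conv weights and leave e.g. the BatchNorm parameters untouched.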
EDIT fixed it:
An in-place operation can do it:
for param in model.parameters():
    param.data.fill_(1.0)
works.
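Writing through .data works, but a variant worth noting (not from the original post) is to do the in-place fill under torch.no_grad(), which tells autograd explicitly that these writes should not be tracked. A minimal sketch on a stand-alone module:

```python
import torch
import torch.nn as nn

# Any module works here; a single Conv2d keeps the sketch small.
net = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3))

# no_grad() suspends gradient tracking for the duration of the writes,
# instead of bypassing autograd through the .data attribute.
with torch.no_grad():
    for param in net.parameters():
        param.fill_(1.0)

print(all(bool(torch.all(p == 1.0)) for p in net.parameters()))  # True
```

Both the weight tensor and the bias end up filled with ones; the loop covers every parameter the module registers.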