Custom weight initialization

Dear experienced ones,

What would be the right way to implement a custom weight initialization method?
I believe I can’t directly add a method to torch.nn.init, but I wish to initialize my model’s weights with my own proprietary method.

Thanks!


Have a look at the other init implementations: code.
They are basically vanilla Python functions setting the values.
Just create a new function with your proprietary method and call it with model.apply(my_init).
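For instance, a minimal sketch of such a function (the uniform-fill rule and the module types checked here are only placeholders for your proprietary method):

import torch
import torch.nn as nn

def my_init(m):
    # placeholder for a proprietary rule: uniform weights, zero bias
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        with torch.no_grad():
            m.weight.uniform_(-0.05, 0.05)
            if m.bias is not None:
                m.bias.zero_()

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
model.apply(my_init)  # applies my_init to every submodule recursively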


Let’s say this is my nn:


import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

and I want my weights to be like this:
W = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])

Is there any way to do that?


and I want my weights to be like this:
W = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])

Is there any way to do that?

You don’t even have to implement a special PyTorch function for that. E.g., you can assign these values like this:

with torch.no_grad():
    self.fc1.weight = ... 

where ... would be replaced by a tensor (e.g. W in your case) or a function that creates a tensor with the right dimensions (since in a real use case, it’s probably too tedious to type out the whole tensor).
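For a concrete, shape-compatible sketch (using a small Linear layer rather than fc1, whose weight is 120x400), one variant is to copy the values in place, which keeps the existing Parameter object:

import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
W = torch.tensor([[1., 0., -1.],
                  [2., 0., -2.]])  # must match lin.weight's shape: (2, 3)

with torch.no_grad():
    lin.weight.copy_(W)  # in-place copy; lin.weight remains a Parameter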


So, for example, for self.conv1 can I do

with torch.no_grad():
    self.conv1.weight = ... 

?

Yes, that should work.

Sorry for my stupid question,
but I get this error:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, 3)
        K = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        # I think I should make the shape/size like this?
        K = torch.unsqueeze(torch.unsqueeze(K, 0), 0)
        with torch.no_grad():
            self.conv1.weight = K
        
    def forward(self, x):
        x = self.conv1(x)
        return x
    
    


net = Net()
net(torch.rand(4, 3, 10, 10))

TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Ah, sorry about that. I forgot to mention torch.nn.Parameter, which basically makes the weight recognizable as a parameter when you then do something like

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

So, in your case, the following should fix it:

    with torch.no_grad():
        self.conv1.weight = torch.nn.Parameter(K)

EDIT: I am actually not sure if it even requires no_grad, because this is still in the __init__ part and not in the forward call.
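Putting the pieces together, a sketch of the fixed example (note that conv1.weight expects shape (out_channels, in_channels, kH, kW) = (3, 3, 3, 3), so the 3x3 kernel is replicated across both channel dimensions here, which the snippet above did not do):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)
K = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
# replicate the 3x3 kernel over the 3 output and 3 input channels
K = K.unsqueeze(0).unsqueeze(0).repeat(3, 3, 1, 1)  # -> (3, 3, 3, 3)
conv.weight = nn.Parameter(K)

out = conv(torch.rand(4, 3, 10, 10))  # runs without the TypeError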


@isalirezag
Did @rasbt’s suggestion work for you?

Because for me, it doesn’t. Torch 0.4.

In [22]: emb.weight
Out[22]: 
Parameter containing:
tensor([[-0.5733,  0.8611,  0.0699],
        [ 1.7558,  0.7413, -0.4029],
        [-0.7413, -1.0902,  0.0550],
        [-0.9363, -0.6691, -0.4870],
        [ 0.0080, -0.4598,  0.9987]])

In [23]: K = torch.Tensor([[1 ,0, -1],[2, 0 ,-2], [1, 0 ,-1],[3,4,5],[6,7,8]])

In [24]: K
Out[24]: 
tensor([[ 1.,  0., -1.],
        [ 2.,  0., -2.],
        [ 1.,  0., -1.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.]])

In [25]: emb.weight = emb.weight+ nn.Parameter(K)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-25-9f658db7039f> in <module>()
----> 1 emb.weight = emb.weight+ nn.Parameter(K)

~/anaconda3/envs/san/lib/python3.6/site-packages/torch/nn/modules/module.py in __setattr__(self, name, value)
    549                 raise TypeError("cannot assign '{}' as parameter '{}' "
    550                                 "(torch.nn.Parameter or None expected)"
--> 551                                 .format(torch.typename(value), name))
    552             self.register_parameter(name, value)
    553         else:

TypeError: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

How do I add an arbitrary tensor to the embedding’s weight?

emb.weight.data = emb.weight.data + nn.Parameter(K)

weight.data worked.
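For what it’s worth, the same update can also be written as an in-place op under torch.no_grad(), without touching .data or wrapping K in a Parameter (a sketch, assuming an embedding like the one above):

import torch
import torch.nn as nn

emb = nn.Embedding(5, 3)
K = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1], [3, 4, 5], [6, 7, 8]])

with torch.no_grad():
    emb.weight.add_(K)  # in-place add; emb.weight stays a Parameter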

But I think since you have used .data, it is unnecessary to convert K to a Parameter. You can remove the with torch.no_grad() block and remove the nn.Parameter.

So, the revised code example will be as follows:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, 3)
        K = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        # I think I should make the shape/size like this?
        K = torch.unsqueeze(torch.unsqueeze(K, 0), 0)
        # with torch.no_grad():
        self.conv1.weight.data = self.conv1.weight.data + K

    def forward(self, x):
        x = self.conv1(x)
        return x

net = Net()
net(torch.randn(4, 3, 10, 10))

What if you want to optimize the weight?


The weight will still be optimized. It’s just that its initial values have changed, since the question was how to apply a custom initialization.

Vahid is right that in the case of his example

self.conv1.weight.data = self.conv1.weight.data + K

this will work because “weight” is already a parameter, and you are just modifying its value. But if you want to assign a completely new tensor to “weight”, you would need to wrap it in Parameter to get the correct behavior.
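A quick sketch of that distinction (with K shaped to match conv.weight):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)
K = torch.randn(3, 3, 3, 3)  # same shape as conv.weight

# fine: "weight" is already a Parameter, only its values change
conv.weight.data = conv.weight.data + K

# fine: a completely new tensor, wrapped in nn.Parameter
conv.weight = nn.Parameter(conv.weight.detach() + K)

# TypeError: the sum is a plain tensor, not a Parameter
# conv.weight = conv.weight + K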


But there is still a problem: what if we use nn.Sequential to build our model? How could we initialize our model with this method? Just like this:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):                                             
        super(Net,self).__init__()
        # convolutional layers
        self.conv1 =torch.nn.Sequential(
            torch.nn.Conv1d(4,64,kernel_size=3,stride=1,padding=1),

            torch.nn.ReLU(),
            torch.nn.Conv1d(64,128,kernel_size=3,stride=1,padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(stride=2,kernel_size=2)
        )
        # fully connected layers
        self.dense=torch.nn.Sequential(
            torch.nn.Linear(640,320),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(320,10)
        )
    def forward(self,x):
        x=self.conv1(x)
        x=x.view(-1,640)
        x=self.dense(x)
        return x  

You can still do that. I created an instance of the class Net:

>>> net = Net()
>>> net
Net(
  (conv1): Sequential(
    (0): Conv1d(4, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (3): ReLU()
    (4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dense): Sequential(
    (0): Linear(in_features=640, out_features=320, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5)
    (3): Linear(in_features=320, out_features=10, bias=True)
  )
)

So, we can access the layers of this object by name and index:

>>> net.conv1[0]
Conv1d(4, 64, kernel_size=(3,), stride=(1,), padding=(1,))

So, we can assign the weights of each layer similarly:

net.conv1[0].weight.data = net.conv1[0].weight.data + K
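As a side note, model.apply from the first reply recurses into nn.Sequential containers too, so a sketch like the following (with a placeholder rule, and assuming the Net class above) reaches the nested layers without indexing each one:

import torch.nn as nn

def my_init(m):
    # placeholder rule: fill every Conv1d kernel with a constant
    if isinstance(m, nn.Conv1d):
        nn.init.constant_(m.weight, 0.5)

net = Net()
net.apply(my_init)  # visits the layers inside both Sequentials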

It is so great, thank you so much!!





Sure, I am happy to help!