I want to apply transfer learning by loading the state dict of one model into another. However, for the final layer, which is a 2D linear layer, I would like to partially randomly initialise it. Say the linear layer's weight is of size [3, 256] — how do I apply torch.nn.init.xavier_uniform_ to only 20% of the parameters in that layer?
Hi Wei Jie!

You could use a number of approaches. I would probably use multinomial() to choose which elements of the Linear to partially initialize. Then, really just for convenience, I would apply xavier_uniform_() to a temporary tensor of the same shape as the Linear.weight in question. Lastly, I would index into the weight tensor to update the desired elements in place.

For example:
>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> pretrained_linear = torch.nn.Linear (3, 5)
>>> xavier_weight = torch.nn.init.xavier_uniform_ (torch.empty_like (pretrained_linear.weight))
>>>
>>> pretrained_linear.weight
Parameter containing:
tensor([[-0.1203,  0.4872,  0.2988],
        [-0.1372, -0.5471, -0.1624],
        [ 0.3387,  0.3245, -0.0412],
        [ 0.1535,  0.1951, -0.3115],
        [ 0.3467,  0.2717,  0.2861]], requires_grad=True)
>>> xavier_weight
tensor([[ 5.6547e-01, -4.9935e-01,  6.1908e-01],
        [-1.2431e-01, -1.7949e-01, -6.1666e-01],
        [-8.6022e-01,  7.8006e-01, -4.8605e-01],
        [-3.6499e-01,  3.0906e-01, -5.1421e-04],
        [ 7.1781e-01, -3.7528e-01,  1.2801e-01]])
>>>
>>> mask = torch.multinomial (torch.tensor ([0.8, 0.2]), pretrained_linear.weight.nelement(), replacement = True).reshape_as (pretrained_linear.weight)
>>> mask
tensor([[1, 1, 0],
        [0, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [0, 0, 0]])
>>>
>>> inds = torch.nonzero (mask)
>>> inds
tensor([[0, 0],
        [0, 1],
        [2, 0],
        [3, 0],
        [3, 1]])
>>>
>>> with torch.no_grad():
...     pretrained_linear.weight[inds[:, 0], inds[:, 1]] = xavier_weight[inds[:, 0], inds[:, 1]]
...
>>> pretrained_linear.weight
Parameter containing:
tensor([[ 0.5655, -0.4994,  0.2988],
        [-0.1372, -0.5471, -0.1624],
        [-0.8602,  0.3245, -0.0412],
        [-0.3650,  0.3091, -0.3115],
        [ 0.3467,  0.2717,  0.2861]], requires_grad=True)
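As an aside, for the [3, 256] layer in your question you could also build the selection mask directly with torch.rand() and boolean indexing instead of multinomial() plus nonzero(). A minimal sketch of that variant (my own rewrite, equivalent in spirit to the transcript above):

```python
import torch

torch.manual_seed(2022)

# final layer from the question: in_features = 256, out_features = 3
linear = torch.nn.Linear(256, 3)

# xavier-initialized values in a temporary tensor of the same shape
xavier_weight = torch.nn.init.xavier_uniform_(torch.empty_like(linear.weight))

# boolean mask that selects each element independently with probability 0.2
mask = torch.rand_like(linear.weight) < 0.2

# overwrite only the masked elements, in place, without tracking gradients
with torch.no_grad():
    linear.weight[mask] = xavier_weight[mask]
```

Note that this re-initializes approximately 20% of the elements (each is selected independently); if you need exactly 20%, pick a fixed number of flat indices with torch.randperm() instead.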
Best.
K. Frank
Thanks Frank! That worked out well.