manozzm
(Manoj)
July 28, 2017, 4:41am
1
I have recently come across this note, which proposes a better initialization scheme for sub-pixel/pixel-shuffle convolution:
https://arxiv.org/abs/1707.02937
Here is a sample implementation I found in Lasagne (https://github.com/Lasagne/Lasagne/issues/862 )
How can we implement the same in PyTorch?
I’m new to PyTorch. I have gone through the forums and found that we can custom-initialize weights with:
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        m.weight.data.copy_(some_weights)
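For reference, that loop pattern works as-is. A minimal self-contained sketch (the two-layer model and the Kaiming-normal stand-in for some_weights are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical two-layer model, just to illustrate the loop
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, 3, padding=1),
)

for m in model.modules():
    if isinstance(m, nn.Conv2d):
        # some_weights must have the same shape as m.weight; here we
        # draw fresh Kaiming-normal values as a stand-in
        some_weights = torch.empty_like(m.weight)
        nn.init.kaiming_normal_(some_weights)
        m.weight.data.copy_(some_weights)

out = model(torch.randn(1, 3, 6, 6))
print(out.shape)  # torch.Size([1, 4, 6, 6])
```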
It would be great if someone could provide a sample implementation of what the authors suggest in the paper.
Thanks
Kaixhin
(Kai Arulkumaran)
July 28, 2017, 11:47am
2
@alykhantejani is this something that you could contribute to nn.init?
Sorry - just saw this (I don’t check the forums much). Is this still something that’s desired?
Kaixhin
(Kai Arulkumaran)
December 1, 2017, 3:18pm
4
If it’s not too difficult, it would be nice to have in nn.init. It’ll make it much easier for people to drop in and see if it helps for their problems, and we can also link PixelShuffle to it so that people are aware that there’s a principled weight initialisation available.
mohit117
(Mohit Lamba)
June 3, 2020, 3:43am
5
https://github.com/pytorch/pytorch/pull/5429/files has code for it. Below I am posting a slightly cleaned-up version of it that performs the required transformation:
import torch
import torch.nn as nn

def ICNR(tensor, upscale_factor=2, initializer=nn.init.kaiming_normal_):
    # Initialize one sub-kernel, then tile it so that each group of
    # upscale_factor**2 output channels shares the same weights
    new_shape = [int(tensor.shape[0] / (upscale_factor ** 2))] + list(tensor.shape[1:])
    subkernel = torch.zeros(new_shape)
    subkernel = initializer(subkernel)
    subkernel = subkernel.transpose(0, 1)
    subkernel = subkernel.contiguous().view(subkernel.shape[0],
                                            subkernel.shape[1], -1)
    kernel = subkernel.repeat(1, 1, upscale_factor ** 2)
    transposed_shape = [tensor.shape[1]] + [tensor.shape[0]] + list(tensor.shape[2:])
    kernel = kernel.contiguous().view(transposed_shape)
    kernel = kernel.transpose(0, 1)
    return kernel
upscale = 2
num_classes = 1
previous_layer_features = torch.randn(1, 2, 6, 6)
conv_shuffle = nn.Conv2d(2, num_classes * (upscale ** 2), 3, padding=1, bias=False)
print('initial weights 2 ip and 4 op {}'.format(conv_shuffle.weight))
ps = nn.PixelShuffle(upscale)
kernel = ICNR(conv_shuffle.weight, upscale)
conv_shuffle.weight.data.copy_(kernel)
print('\nICNR weights {}'.format(conv_shuffle.weight))
output = ps(conv_shuffle(previous_layer_features))
print(output.shape)
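As a quick sanity check of the ICNR property: since every group of upscale**2 output channels shares one sub-kernel, conv + PixelShuffle should reproduce nearest-neighbour upsampling of that sub-kernel’s output. A self-contained sketch (it builds the tiled kernel with repeat_interleave, which I believe gives the same channel grouping as the reshape/repeat code above, but verify for your shapes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
upscale = 2
conv = nn.Conv2d(2, upscale ** 2, 3, padding=1, bias=False)

# Draw one sub-kernel and repeat it upscale**2 times along the
# output-channel dimension, so all sub-pixel kernels start identical
sub = torch.empty(1, 2, 3, 3)
nn.init.kaiming_normal_(sub)
with torch.no_grad():
    conv.weight.copy_(sub.repeat_interleave(upscale ** 2, dim=0))

x = torch.randn(1, 2, 6, 6)
shuffled = nn.PixelShuffle(upscale)(conv(x))

# Reference: convolve with the single sub-kernel, then upsample
# with nearest-neighbour interpolation
reference = F.interpolate(F.conv2d(x, sub, padding=1),
                          scale_factor=upscale, mode='nearest')
print(torch.allclose(shuffled, reference))  # True
```

If the two tensors match, the initialization is checkerboard-free at step 0, which is exactly what the paper argues for.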