I want to make weight of certain layer to have a norm of 1, so how can I implement it? Will the following work properly?
class mylin(nn.Module):
def __init__(self):
super(mylin, self).__init__()
self.lin= nn.Parameter(torch.Tensor(20,1,5))
nn.init.xavier_uniform_(self.lin)
def forward(self, x):
setattr(self.lin, 'data', self.lin / self.lin.norm(dim=(1,2)))
x = torch.bmm(self.lin, x)
return x