Like the title says, I am looking for a way to constrain the weights leading out of a given neuron to sum to 1. Is there a way to do this? I am getting an error about modifying values needed for gradient computation.
Yes, you can use the torch.nn.functional module.
Let's assume the fully connected layer has a weight matrix of size m x n and the input has dimension n.
As far as I understand your question (the sum of the weights coming out of a neuron), you want each column of the matrix to sum to 1.
Then, instead of m*n parameters, you have only (m-1)*n free parameters: the last row of the weight matrix is determined by the others, because each column must sum to 1.
You'll need to build the weight matrix from these parameters (enforcing the sum-to-1 condition yourself) every time in the forward pass.
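A minimal sketch of this idea in PyTorch, assuming the F.linear layout where the weight has shape (m, n) = (out_features, in_features). The class and method names (ColumnSumToOneLinear, full_weight) are hypothetical. Only the first m - 1 rows are stored as a Parameter; the last row is recomputed from them on every forward call, so gradients reach all free entries and no parameter is modified in place.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ColumnSumToOneLinear(nn.Module):
    """Hypothetical linear layer with an (m, n) weight whose columns each sum to 1."""

    def __init__(self, n: int, m: int):
        super().__init__()
        # Only (m - 1) * n free parameters are learned
        self.free = nn.Parameter(0.1 * torch.randn(m - 1, n))

    def full_weight(self) -> torch.Tensor:
        # Last row = 1 - (column sums of the free rows), shape (1, n)
        last_row = 1.0 - self.free.sum(dim=0, keepdim=True)
        # Stack into the full (m, n) matrix; rebuilt on every call
        return torch.cat([self.free, last_row], dim=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.linear expects weight of shape (out_features, in_features) = (m, n)
        return F.linear(x, self.full_weight())


layer = ColumnSumToOneLinear(n=4, m=3)
y = layer(torch.randn(2, 4))
print(y.shape)                          # torch.Size([2, 3])
print(layer.full_weight().sum(dim=0))   # every column sums to 1 (up to rounding)
```

Since the last row is 1 minus the sum of the others, its entries can be negative or larger than 1; this parameterization fixes the column sums, not the sign of individual weights.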
Get back to me if you need help.
I actually need the rows of the tensor to sum to 1. The synapses leading out of each neuron must sum to 1 because I need to do something with the network after training.
Oh okay, then maybe make a parameter tensor of size m x (n-1) and copy it into an m x n tensor. Compute the last column manually so that each row sums to 1, and use that tensor with the torch.nn.functional interface.
Start by creating an nn.Module, store the m x (n-1) parameter in __init__, build the full m x n weight matrix in each forward call, and you should be fine.
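A sketch of the m x (n-1) variant, with hypothetical names and again assuming the F.linear convention that the weight has shape (out_features, in_features). In that layout row i holds the weights feeding output unit i; if you store your matrix the other way round, the same trick applies along the other axis. If the original error came from writing into a Parameter in place (an assumption), building the full matrix with torch.cat avoids that, because autograd only sees an ordinary out-of-place computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RowSumToOneLinear(nn.Module):
    """Hypothetical linear layer with an (m, n) weight whose rows each sum to 1."""

    def __init__(self, n: int, m: int):
        super().__init__()
        # The m x (n - 1) tensor is the only learned parameter
        self.free = nn.Parameter(0.1 * torch.randn(m, n - 1))

    def full_weight(self) -> torch.Tensor:
        # Last column = 1 - (row sums of the free part), shape (m, 1)
        last_col = 1.0 - self.free.sum(dim=1, keepdim=True)
        # Concatenate instead of editing a Parameter in place
        return torch.cat([self.free, last_col], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n) -> (batch, m)
        return F.linear(x, self.full_weight())


layer = RowSumToOneLinear(n=4, m=3)
y = layer(torch.randn(5, 4))
print(y.shape)                          # torch.Size([5, 3])
print(layer.full_weight().sum(dim=1))   # every row sums to 1 (up to rounding)
```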
Now that @Naman-ntc summoned me:
The easy way is to put this in __init__:
self.eps = 1e-7
self.weight_raw = torch.nn.Parameter(torch.randn(n, m))
and this in forward:
weight = self.weight_raw / self.weight_raw.sum(1, keepdim=True).clamp(min=self.eps)
There are much more elaborate ways; let's see if this works for you first. For example, you can construct a wrapper that does this (see torch.nn.utils.spectral_norm), but it's probably overkill unless you're sure you need it.
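A self-contained version of the normalization approach, as a sketch. Assumptions: the input has n features, the weight is applied as x @ weight so that row i of the (n, m) matrix holds the weights leading out of input neuron i, and the class name NormalizedRowsLayer is hypothetical. One deliberate change from the snippet: torch.rand instead of torch.randn for initialization, so the initial row sums are positive. With randn, a row whose sum is negative gets clamped to eps and blown up rather than normalized.

```python
import torch
import torch.nn as nn


class NormalizedRowsLayer(nn.Module):
    """Hypothetical n -> m layer whose (n, m) weight has every row summing to 1."""

    def __init__(self, n: int, m: int, eps: float = 1e-7):
        super().__init__()
        self.eps = eps
        # Unconstrained raw weights; torch.rand (not randn) keeps the
        # initial row sums positive so the clamp does not kick in at the start
        self.weight_raw = nn.Parameter(torch.rand(n, m))

    def normalized_weight(self) -> torch.Tensor:
        # keepdim=True gives shape (n, 1), which broadcasts across the m columns
        row_sums = self.weight_raw.sum(1, keepdim=True).clamp(min=self.eps)
        return self.weight_raw / row_sums

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, n) @ (n, m) -> (batch, m)
        return x @ self.normalized_weight()


layer = NormalizedRowsLayer(n=4, m=3)
out = layer(torch.randn(2, 4))
print(out.shape)                          # torch.Size([2, 3])
print(layer.normalized_weight().sum(1))   # each row sums to 1 (up to rounding)
```

The clamp only guards against division by zero. If training pushes a row sum toward zero or below, that row will no longer sum to 1, so it is worth checking if your raw weights can become negative.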
P.S.: Maybe add an “s” to the topic, in case you want to search for it later.
I am getting an error that the tensor sizes do not match, and using the transpose gives an error as well.
Ah, sorry, use keepdim=True.
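Why keepdim=True matters, as a small shape check (the sizes 3 and 5 are arbitrary). sum(1) on an (n, m) tensor returns shape (n,), which broadcasting lines up with the last dimension of size m, so the division fails unless n == m. With keepdim=True the result has shape (n, 1) and each row is divided by its own sum.

```python
import torch

w = torch.rand(3, 5)  # arbitrary (n, m) = (3, 5) example

s_flat = w.sum(1)                # shape (3,)
s_keep = w.sum(1, keepdim=True)  # shape (3, 1)

# (3, 5) / (3, 1): each row is divided by its own sum
normalized = w / s_keep
print(normalized.sum(1))         # each entry is 1 up to rounding

# (3, 5) / (3,): broadcasting aligns (3,) with the last dim (5) and fails
try:
    w / s_flat
except RuntimeError as err:
    print("shape mismatch:", err)
```

If the matrix happens to be square, the version without keepdim does not raise; it silently divides column j by the sum of row j, which is an easy bug to miss.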
What are the benefits and risks of constraining your weights to sum to a fixed number?