Appropriate method for weight sharing in convolutional layers?

I’m imagining a scenario where I want to apply a learnable convolution layer to multiple tensor inputs in a module. I want the layer to learn from all inputs, using the average of the shared convolution filter’s gradients with respect to each input.
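In symbols, let $W$ be the shared filter, $x_1, \dots, x_N$ the inputs, $y_i = \mathrm{conv}(x_i, W)$ the per-input outputs (bias ignored), $L$ a scalar loss, and $g_i$ the contribution of input $i$ to the filter gradient. This notation is introduced here for illustration and is not from the original post. By the multivariable chain rule the total gradient is the sum of these terms, while the update described above uses their mean:

```latex
g_i = \frac{\partial L}{\partial y_i}\,\frac{\partial y_i}{\partial W},
\qquad
\frac{\partial L}{\partial W} = \sum_{i=1}^{N} g_i,
\qquad
\bar{g} = \frac{1}{N}\sum_{i=1}^{N} g_i = \frac{1}{N}\,\frac{\partial L}{\partial W}
```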

I see this question has been asked before, so let me expand on it a bit.

Say I have a convolution module that shares weights like this:

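```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedFilterConv2d(nn.Module):
  def __init__(self):
    super(SharedFilterConv2d, self).__init__()
    self.conv = nn.Conv2d(3, 3, 3)
    # New Parameter that shares storage with self.conv.weight
    self.shared_weight = nn.Parameter(self.conv.weight)

  def forward(self, inputs):
    outputs = []
    for x in inputs:
      # Same filter reused for every input (remaining arguments elided)
      outputs.append(F.conv2d(x, self.shared_weight, ...))
    # Concatenate the per-input outputs along the channel dimension
    return torch.cat(outputs, dim=1)
```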

The output of this operation is passed on through the network, and eventually .backward() is called on some value that depends on it. Since the shared weight needs gradients with respect to multiple inputs in a single backward pass, what happens to those gradients? Are they averaged when optimizer.step() is called? Should they be averaged? Do the gradients accumulate over all inputs in the list, even if the list varies in size?
I am trying to get a grasp on how autograd handles this.

No, optimizer.step() will just update the passed parameters using the .grad attribute of each parameter; it does not average anything.
The gradients will be accumulated (summed) in the .grad attribute if you are reusing the shared_weight parameter.
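One way to check this is to compare the gradient from a single forward pass that reuses one weight against the gradients each input produces on its own. The sketch below assumes a bias-free `F.conv2d` with default stride/padding and a `.sum()` loss chosen only for illustration. It also checks that a plain `torch.optim.SGD` step (no momentum or weight decay) simply applies `weight - lr * weight.grad` using the summed gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A single leaf tensor standing in for the shared filter
weight = torch.randn(3, 3, 3, 3, requires_grad=True)
inputs = [torch.randn(1, 3, 8, 8) for _ in range(3)]

# Reuse the same weight for every input within one graph
out = torch.cat([F.conv2d(x, weight) for x in inputs], dim=1)
out.sum().backward()
shared_grad = weight.grad.clone()

# Reference: the gradient contributed by each input on its own
per_input = []
for x in inputs:
    weight.grad = None
    F.conv2d(x, weight).sum().backward()
    per_input.append(weight.grad.clone())

summed = torch.stack(per_input).sum(dim=0)
is_sum = torch.allclose(shared_grad, summed, atol=1e-5)
print(is_sum)  # True: the shared gradient is the sum, not the mean

# optimizer.step() only reads .grad; plain SGD does weight -= lr * grad
opt = torch.optim.SGD([weight], lr=0.1)
weight.grad = shared_grad.clone()
before = weight.detach().clone()
opt.step()
step_matches = torch.allclose(weight.detach(), before - 0.1 * shared_grad, atol=1e-5)
print(step_matches)  # True
```

Because one gradient term is added per use, a longer input list simply adds more terms to the sum.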

Thanks @ptrblck.
Since the gradients are summed, would dividing them by the number of times the shared weight is used in the forward pass (e.g. with a backward hook) be a good way to average the update across inputs?

I see that in this topic @albanD recommended sharing weights by .clone()ing the same weight into the weight attribute of different modules in the forward pass. Is there a reason that approach would be better or more memory-efficient than passing the weight tensor to a functional call such as F.conv2d?

Both approaches should work and the difference might be the coding style you prefer (functional API vs. modules).
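A minimal comparison showing that both styles compute the same output and gradient. `FunctionalShared` passes one `nn.Parameter` to `F.conv2d`. `CloneIntoModules` is only one possible reading of the clone-based approach (an assumption, not necessarily the exact code from that topic): each `nn.Conv2d` has its own `weight` parameter deleted and replaced by a `.clone()` of the shared parameter on every forward pass. Both class names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalShared(nn.Module):
    # Hypothetical: one Parameter passed to the functional API for every input
    def __init__(self, weight):
        super().__init__()
        self.weight = nn.Parameter(weight.clone())

    def forward(self, inputs):
        return torch.cat([F.conv2d(x, self.weight) for x in inputs], dim=1)


class CloneIntoModules(nn.Module):
    # Hypothetical: one conv module per input, all fed clones of one Parameter
    def __init__(self, weight, n_inputs):
        super().__init__()
        self.weight = nn.Parameter(weight.clone())
        self.convs = nn.ModuleList(
            nn.Conv2d(3, 3, 3, bias=False) for _ in range(n_inputs)
        )
        for conv in self.convs:
            del conv.weight  # drop each module's own (unshared) parameter

    def forward(self, inputs):
        outputs = []
        for conv, x in zip(self.convs, inputs):
            # A plain tensor attribute, not a Parameter, but still in the graph
            conv.weight = self.weight.clone()
            outputs.append(conv(x))
        return torch.cat(outputs, dim=1)


torch.manual_seed(0)
w0 = torch.randn(3, 3, 3, 3)
inputs = [torch.randn(1, 3, 8, 8) for _ in range(2)]

a = FunctionalShared(w0)
b = CloneIntoModules(w0, n_inputs=2)
out_a, out_b = a(inputs), b(inputs)
out_a.sum().backward()
out_b.sum().backward()

print(torch.allclose(out_a, out_b))
print(torch.allclose(a.weight.grad, b.weight.grad, atol=1e-6))
print(len(list(b.parameters())))  # 1: only the shared weight remains
```

In this sketch each `.clone()` materializes a copy of the weight per module per forward pass, and `CloneIntoModules` fixes the number of inputs at construction, whereas the functional version accepts a list of any length.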

Yes, dividing the gradients by the number of steps (or dividing the loss accordingly) should work, too.
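A minimal sketch of the hook-based averaging, assuming the average should be taken over however many inputs arrive in a given forward call. `AveragedSharedConv2d` is a hypothetical name. Rather than registering the hook on the parameter itself (a hook registered with `register_hook` on a parameter stays attached until its handle is removed, so registering one per forward would stack them), this version registers it on a fresh `.clone()` of the weight made in each forward pass. The hook receives the gradient for that clone, which already includes the contributions from all uses, so dividing by `n` yields the mean.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AveragedSharedConv2d(nn.Module):
    # Hypothetical module: one filter applied to a variable-length list of inputs,
    # with the filter's gradient averaged over the number of inputs
    def __init__(self, in_ch=3, out_ch=3, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_ch, in_ch, k, k))

    def forward(self, inputs):
        n = len(inputs)
        # Fresh non-leaf copy per call, so the hook is discarded with the graph
        w = self.weight.clone()
        if w.requires_grad:  # False under torch.no_grad(), where hooks can't be registered
            # Receives the gradient summed over all n uses; rescale to the mean
            w.register_hook(lambda g: g / n)
        return torch.cat([F.conv2d(x, w) for x in inputs], dim=1)


torch.manual_seed(0)
m = AveragedSharedConv2d()
inputs = [torch.randn(1, 3, 8, 8) for _ in range(3)]
m(inputs).sum().backward()
averaged = m.weight.grad.clone()

# Reference: mean of the gradients from each input on its own (no hook involved)
per_input = []
for x in inputs:
    m.weight.grad = None
    F.conv2d(x, m.weight).sum().backward()
    per_input.append(m.weight.grad.clone())
reference = torch.stack(per_input).mean(dim=0)
print(torch.allclose(averaged, reference, atol=1e-5))  # True
```

Scaling the loss instead (e.g. `loss / n`) would also scale the gradients of every other parameter that contributes to that loss, not just the shared filter, which may or may not be what you want.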
